/*
 * Decompiled with CFR 0.152.
 */
package mill.perk;

import java.io.BufferedReader;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import mill.common.Classifier;
import mill.common.CommandLineParameters;
import mill.common.Kernel;
import mill.common.LinearKernel;
import mill.common.Nodes;
import mill.common.NodesFactory;
import mill.common.Outcome;
import mill.common.Sample;
import mill.common.Score;
import mill.common.SimpleTokenize;
import mill.common.TrainingParameters;
import mill.perk.MulticlassPerceptronClassifierTrainingParameters;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;

/**
 * Multiclass perceptron classifier that keeps one weight vector ("row") per
 * class label.
 *
 * <p>During {@link #train}, every sample is scored against every row with the
 * supplied kernel. When some other label scores at least as high as the gold
 * label (after subtracting a per-label margin {@code minScore * labelWeight}),
 * the gold row is moved toward the sample. Each offending row is moved away by
 * an equal share of the sample. Besides its live vector, each row accumulates a
 * survival-weighted running vector, and that vector is what {@link #predict}
 * and {@link #saveModel} use. Prediction scores are converted into a
 * distribution with a softmax scaled by {@code mSoftmaxGamma}.
 *
 * <p>Only the primal representation is implemented; dual mode is rejected at
 * training time.
 */
public class MulticlassPerceptronClassifier
extends Classifier {
    /** Whether samples are normalized (in place) before training/prediction. */
    private boolean mDoNormalize;
    /** Dual-mode flag; only {@code false} is supported by {@link #train}. */
    private boolean mIsDual;
    /** Label of the NIL category, or {@code null} when there is none. */
    private Integer mNilCategory;
    /** Number of features: largest feature index seen in training plus one (or as read from a model file). */
    private int mFeatureCount;
    /** One row per label; {@code null} until the classifier is trained or loaded. */
    private Row[] mRows;
    /** Scale applied to raw scores before the softmax in {@link #predict}. */
    double mSoftmaxGamma;
    /** Softmax scale used when none is given to the constructor. */
    public static final double DEFAULT_SOFTMAX_GAMMA = 1.0E-5;
    static Logger mLog = Logger.getLogger(MulticlassPerceptronClassifier.class.getName());

    /** Creates a primal classifier that normalizes samples and uses {@link #DEFAULT_SOFTMAX_GAMMA}. */
    public MulticlassPerceptronClassifier() {
        this(null, null, null);
    }

    /**
     * Creates a classifier with optional overrides; any {@code null} argument
     * falls back to its default.
     *
     * @param dual        dual-mode flag (default {@code false}); {@code true} makes {@link #train} fail
     * @param doNormalize whether samples are normalized (default {@code true})
     * @param gamma       softmax scale (default {@link #DEFAULT_SOFTMAX_GAMMA})
     */
    public MulticlassPerceptronClassifier(Boolean dual, Boolean doNormalize, Double gamma) {
        this.mIsDual = dual != null && dual.booleanValue();
        // The default used to be written to the parameter instead of the field,
        // which left mDoNormalize false whenever doNormalize was null.
        this.mDoNormalize = doNormalize == null || doNormalize.booleanValue();
        this.mSoftmaxGamma = gamma != null ? gamma.doubleValue() : DEFAULT_SOFTMAX_GAMMA;
        this.mRows = null;
    }

    /**
     * Trains one row per label for {@code mNumberOfEpochs} epochs.
     *
     * <p>Training samples are normalized in place when normalization is enabled.
     * After each epoch, the model is saved to {@code <mModelFile>.<epoch>} when
     * a model file is configured. The model is also scored on
     * {@code testSamples} when those are given.
     *
     * @param trainSamples samples to learn from; labels are used as row indices
     * @param testSamples  optional samples evaluated after every epoch (may be {@code null})
     * @param pars         must be a {@link MulticlassPerceptronClassifierTrainingParameters}
     * @throws RuntimeException if the classifier was configured for dual mode
     */
    @Override
    public void train(List<Sample> trainSamples, List<Sample> testSamples, TrainingParameters pars) {
        MulticlassPerceptronClassifierTrainingParameters parameters = (MulticlassPerceptronClassifierTrainingParameters)pars;
        this.mNilCategory = parameters.mNilCategory;
        this.mFeatureCount = this.prepareTrainingSamples(trainSamples);
        mLog.error("Feature count: " + this.mFeatureCount);
        int modelCount = maxLabel(trainSamples) + 1;
        mLog.warn("Loaded " + trainSamples.size() + " training examples with " + modelCount + " classes.");
        double[] labelWeights = buildLabelWeights(modelCount, parameters);
        this.mRows = this.createRows(modelCount);
        for (int epoch = 0; epoch < parameters.mNumberOfEpochs; ++epoch) {
            int epochNumber = epoch + 1;
            mLog.error("Starting epoch " + epochNumber + "...");
            this.trainEpoch(trainSamples, parameters.mKernel, parameters.mMinScore, labelWeights);
            mLog.error("Ended epoch " + epochNumber + ".");
            // Avoid building one debug string per label when debug is off.
            if (mLog.isDebugEnabled()) {
                for (int i = 0; i < this.mRows.length; ++i) {
                    mLog.debug("\tUpdate count for label " + i + ": " + this.mRows[i].getUpdateCount());
                }
            }
            if (parameters.mModelFile != null) {
                this.saveEpochModel(parameters.mModelFile + "." + epochNumber);
            }
            if (testSamples != null) {
                logScore(this.test(testSamples, this.mNilCategory, parameters.mKernel));
            }
        }
    }

    /**
     * Normalizes every training sample in place (when enabled) and returns the
     * feature count: the largest feature index seen plus one, or 0 for no samples.
     */
    private int prepareTrainingSamples(List<Sample> trainSamples) {
        int maxFeatIndex = -1;
        for (Sample sample : trainSamples) {
            Nodes nodes = sample.getNodes();
            if (this.mDoNormalize) {
                nodes.normalize();
            }
            maxFeatIndex = Math.max(maxFeatIndex, nodes.getMaxIndex());
        }
        return maxFeatIndex + 1;
    }

    /** Returns the largest gold label among {@code samples}, or -1 when there are none. */
    private static int maxLabel(List<Sample> samples) {
        int maxLabel = -1;
        for (Sample sample : samples) {
            maxLabel = Math.max(maxLabel, sample.getLabel());
        }
        return maxLabel;
    }

    /**
     * Builds the per-label margin weights: 1.0 by default, overridden by the
     * configured label weights.
     *
     * <p>Weights for labels outside {@code [0, modelCount)} are ignored with a
     * warning. Such labels never occur as gold labels, so their weight would
     * never be read; previously they caused an ArrayIndexOutOfBoundsException.
     */
    private static double[] buildLabelWeights(int modelCount, MulticlassPerceptronClassifierTrainingParameters parameters) {
        double[] labelWeights = new double[modelCount];
        for (int i = 0; i < labelWeights.length; ++i) {
            labelWeights[i] = 1.0;
        }
        for (MulticlassPerceptronClassifierTrainingParameters.LabelWeight lw : parameters.mLabelWeights) {
            if (lw.mLabel < 0 || lw.mLabel >= modelCount) {
                mLog.warn("Ignoring weight for label " + lw.mLabel + " which is not present in the training data.");
                continue;
            }
            labelWeights[lw.mLabel] = lw.mWeight;
        }
        return labelWeights;
    }

    /**
     * Creates one empty primal row per label.
     *
     * @throws RuntimeException in dual mode (checked before any state is touched)
     */
    private Row[] createRows(int modelCount) {
        if (this.mIsDual) {
            throw new RuntimeException("Dual mode not supported yet!");
        }
        Row[] rows = new Row[modelCount];
        for (int i = 0; i < modelCount; ++i) {
            rows[i] = new PrimalRow(this.mFeatureCount);
        }
        return rows;
    }

    /** Saves an intermediate per-epoch model and logs (rather than propagates) I/O failures. */
    private void saveEpochModel(String epochModel) {
        try {
            this.saveModel(epochModel);
            mLog.error("Epoch model saved in file: " + epochModel);
        }
        catch (IOException e) {
            mLog.error("Failed to save model in file: " + epochModel, e);
        }
    }

    /** Logs an evaluation score and its individual counters. */
    private static void logScore(Score scr) {
        mLog.error(scr);
        mLog.error("\tCorrect: " + scr.mCorrect);
        mLog.error("\tTotal: " + scr.mTotal);
        mLog.error("\tCorrectNonNil: " + scr.mCorrectNonNil);
        mLog.error("\tPredictedNonNil: " + scr.mPredictedNonNil);
        mLog.error("\tTotalNonNil: " + scr.mTotalNonNil);
    }

    /**
     * Runs one perceptron pass over the training samples, then flushes every
     * row's pending weight into its accumulated vector.
     *
     * @param minScore     margin multiplier; the gold score is reduced by {@code minScore * labelWeights[gold]}
     * @param labelWeights per-label margin weights indexed by gold label
     */
    private void trainEpoch(List<Sample> trainSamples, Kernel kernel, double minScore, double[] labelWeights) {
        List<Integer> betterLabels = new ArrayList<Integer>();
        for (Sample sample : trainSamples) {
            Nodes x = sample.getNodes();
            int goldLabel = sample.getLabel();
            double offset = minScore * labelWeights[goldLabel];
            double goldScore = this.mRows[goldLabel].multiply(x, kernel) - offset;
            // Collect every competing label that beats (or ties) the margin-reduced gold score.
            betterLabels.clear();
            for (int label = 0; label < this.mRows.length; ++label) {
                if (label != goldLabel && this.mRows[label].multiply(x, kernel) >= goldScore) {
                    betterLabels.add(label);
                }
            }
            if (!betterLabels.isEmpty()) {
                // Updated rows get -1 here and +1 below, so they gain no survival weight for this sample.
                this.mRows[goldLabel].addVector(x, 1.0);
                this.mRows[goldLabel].addToWeight(-1);
                double weight = -1.0 / betterLabels.size();
                for (int label : betterLabels) {
                    this.mRows[label].addVector(x, weight);
                    this.mRows[label].addToWeight(-1);
                }
            }
            for (Row row : this.mRows) {
                row.addToWeight(1);
            }
        }
        for (Row row : this.mRows) {
            row.makeAverage();
        }
    }

    /**
     * Scores {@code sample} against every label and stores the results in
     * {@code outcomes}.
     *
     * <p>The list is cleared first and then receives one outcome per label,
     * sorted by descending score. Confidences are replaced by softmax
     * probabilities. The sample is normalized in place when normalization is
     * enabled.
     *
     * @throws RuntimeException      if {@code customKernel} is null
     * @throws IllegalStateException if the classifier has not been trained or loaded
     */
    @Override
    public void predict(Nodes sample, List<Outcome> outcomes, Kernel customKernel) {
        if (customKernel == null) {
            throw new RuntimeException("Attempted prediction without kernel!");
        }
        if (this.mRows == null) {
            throw new IllegalStateException("Attempted prediction before the model was trained or loaded!");
        }
        if (this.mDoNormalize) {
            sample.normalize();
        }
        Outcome[] sorted = new Outcome[this.mRows.length];
        for (int label = 0; label < this.mRows.length; ++label) {
            sorted[label] = new Outcome(label, this.mRows[label].predict(sample, customKernel));
        }
        sort(sorted);
        this.convertToProbabilities(sorted);
        outcomes.clear();
        for (Outcome outcome : sorted) {
            outcomes.add(outcome);
        }
    }

    /**
     * Replaces each outcome's confidence c with softmax probability
     * {@code exp(c*gamma) / sum_j exp(c_j*gamma)}.
     *
     * @throws RuntimeException if the normalization factor is not a positive
     *         number (no outcomes, or NaN/infinite scores)
     */
    private void convertToProbabilities(Outcome[] outcomes) {
        double gamma = this.mSoftmaxGamma;
        // Softmax is invariant to shifting all exponents by the same constant.
        // Subtracting the largest one keeps Math.exp from overflowing to
        // Infinity, which used to produce NaN probabilities for large scores.
        double maxExponent = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < outcomes.length; ++i) {
            maxExponent = Math.max(maxExponent, outcomes[i].getConfidence() * gamma);
        }
        double[] exps = new double[outcomes.length];
        double sum = 0.0;
        for (int i = 0; i < outcomes.length; ++i) {
            exps[i] = Math.exp(outcomes[i].getConfidence() * gamma - maxExponent);
            sum += exps[i];
        }
        // After the shift the largest term is exp(0) = 1, so a valid sum is >= 1.
        // The negated comparison also rejects NaN.
        if (!(sum > 0.0)) {
            System.err.println("Invalid normalization factor: " + sum);
            for (int i = 0; i < outcomes.length; ++i) {
                System.err.println(outcomes[i] + " ==> " + exps[i]);
            }
            throw new RuntimeException("Invalid normalization factor: " + sum);
        }
        for (int i = 0; i < outcomes.length; ++i) {
            outcomes[i].setConfidence(exps[i] / sum);
        }
    }

    /**
     * Sorts {@code outcomes} in place by descending confidence.
     *
     * <p>This is an exchange sort that swaps whenever a later element is
     * strictly larger. It is kept instead of a library sort so the relative
     * order of tied outcomes stays exactly as before. The array holds one entry
     * per label.
     */
    private static void sort(Outcome[] outcomes) {
        for (int i = 0; i < outcomes.length; ++i) {
            for (int j = i + 1; j < outcomes.length; ++j) {
                if (outcomes[i].getConfidence() < outcomes[j].getConfidence()) {
                    Outcome tmp = outcomes[i];
                    outcomes[i] = outcomes[j];
                    outcomes[j] = tmp;
                }
            }
        }
    }

    /**
     * Writes the model as text.
     *
     * <p>The first line is {@code "<rows> <isDual> <nil or -1> <featureCount>"},
     * followed by one line per row.
     *
     * @throws IOException           if the file cannot be opened or any write fails
     * @throws IllegalStateException if there is no trained or loaded model
     */
    @Override
    public void saveModel(String fileName) throws IOException {
        if (this.mRows == null) {
            throw new IllegalStateException("No model to save: the classifier has not been trained or loaded");
        }
        PrintStream os = new PrintStream(new FileOutputStream(fileName));
        try {
            os.print(this.mRows.length + " " + this.mIsDual + " ");
            os.print(this.mNilCategory != null ? this.mNilCategory.intValue() : -1);
            os.println(" " + this.mFeatureCount);
            for (Row row : this.mRows) {
                row.save(os);
            }
        }
        finally {
            os.close();
        }
        // PrintStream swallows IOExceptions; surface them instead of leaving a
        // silently truncated model file behind.
        if (os.checkError()) {
            throw new IOException("Failed to write Perceptron model file: " + fileName);
        }
    }

    /**
     * Reads a model written by {@link #saveModel}. All rows are loaded as
     * primal rows; the {@code factory} argument is not used.
     *
     * <p>Fields are only updated after the whole file has been read, so a
     * failed load leaves the previous model intact.
     *
     * @throws IOException      if the file cannot be read, is empty, or ends early
     * @throws RuntimeException if the meta-info line does not have four tokens
     */
    @Override
    public void loadModel(String fileName, NodesFactory factory) throws IOException {
        BufferedReader is = new BufferedReader(new FileReader(fileName));
        try {
            String line = is.readLine();
            if (line == null) {
                throw new IOException("Empty Perceptron model file: " + fileName);
            }
            ArrayList<String> tokens = SimpleTokenize.tokenize(line);
            if (tokens.size() != 4) {
                throw new RuntimeException("Invalid Perceptron model meta info: " + line);
            }
            int modelCount = Integer.parseInt(tokens.get(0));
            boolean isDual = Boolean.parseBoolean(tokens.get(1));
            int nilCategory = Integer.parseInt(tokens.get(2));
            int featureCount = Integer.parseInt(tokens.get(3));
            Row[] rows = new Row[modelCount];
            for (int i = 0; i < rows.length; ++i) {
                rows[i] = new PrimalRow(is, featureCount);
            }
            this.mIsDual = isDual;
            // -1 is the on-disk marker for "no NIL category".
            this.mNilCategory = nilCategory == -1 ? null : Integer.valueOf(nilCategory);
            this.mFeatureCount = featureCount;
            this.mRows = rows;
        }
        finally {
            is.close();
        }
    }

    /** Prints the command-line usage to stderr and exits with status 1. */
    public static void usage() {
        System.err.println("java mill.perk.MulticlassPerceptronClassifier \\\n\t--train=<training samples> \\\n\t--test=<testing samples> \\\n\t--model=<model file> \\\n\t--epochs=<number of epochs> \\\n\t--nil=<number of the NIL category> \\\n\t[--nil-penalty=<penalty for the NIL category>] \\\n\t[--min-score=<minimum score margin>] \\\n\t[--log4j=<log4j properties file>]");
        System.exit(1);
    }

    /**
     * Command-line entry point.
     *
     * <p>When {@code --train} is given, trains a model (evaluating on
     * {@code --test} after each epoch, if given) and saves it to
     * {@code --model}. Otherwise it loads {@code --model} and evaluates it on
     * {@code --test}.
     */
    public static void main(String[] args) throws Exception {
        CommandLineParameters.read(args);
        String log4j = CommandLineParameters.getString("log4j");
        if (log4j != null) {
            PropertyConfigurator.configure(log4j);
        }
        String trainFile = CommandLineParameters.getString("train");
        String testFile = CommandLineParameters.getString("test");
        if (trainFile == null && testFile == null) {
            usage();
        }
        String modelFile = CommandLineParameters.getString("model");
        if (modelFile == null) {
            usage();
        }
        Integer nil = CommandLineParameters.getInteger("nil");
        Integer epochs = CommandLineParameters.getInteger("epochs");
        Double mins = CommandLineParameters.getDouble("min-score");
        Double nilp = CommandLineParameters.getDouble("nil-penalty");
        NodesFactory factory = new NodesFactory();
        List<Sample> trainSamples = trainFile != null ? Sample.readSamples(trainFile, factory) : null;
        List<Sample> testSamples = testFile != null ? Sample.readSamples(testFile, factory) : null;
        MulticlassPerceptronClassifier cls = new MulticlassPerceptronClassifier();
        LinearKernel kernel = new LinearKernel();
        if (trainSamples != null) {
            MulticlassPerceptronClassifierTrainingParameters pars = new MulticlassPerceptronClassifierTrainingParameters(nil, kernel, epochs, mins, modelFile);
            if (nil != null && nilp != null) {
                // The NIL label's margin weight is the inverse of the requested penalty.
                pars.addLabelWeight(nil, 1.0 / nilp);
            }
            cls.train(trainSamples, testSamples, pars);
            cls.saveModel(modelFile);
        } else if (testSamples != null) {
            cls.loadModel(modelFile, null);
            mLog.error("Started testing...");
            Score score = cls.test(testSamples, nil, kernel);
            mLog.error(score);
            mLog.error("Done testing.");
        }
    }

    /**
     * Primal row: an explicit live weight vector updated during training, plus
     * an accumulated vector used for prediction and persistence.
     */
    class PrimalRow
    implements Row {
        /** Live weight vector; {@code null} for rows loaded from a model file (prediction-only). */
        private Nodes mLastVector;
        /** Pending survival weight of the live vector since the last flush into mAvgVector. */
        private int mLastWeight;
        /** Accumulated vector used by predict() and save(). */
        private Nodes mAvgVector;
        int mFeatCount;
        int mUpdateCount;

        /** Creates an empty trainable row of {@code featCount} features. */
        public PrimalRow(int featCount) {
            this.mFeatCount = featCount;
            this.reset();
        }

        /**
         * Reads one row (a single line) from a model file. The resulting row
         * has no live vector and can only be used for prediction and saving.
         *
         * @throws IOException if the file ends before the row line
         */
        public PrimalRow(BufferedReader is, int featCount) throws IOException {
            this.mFeatCount = featCount;
            this.mUpdateCount = -1;
            this.mLastWeight = -1;
            String line = is.readLine();
            if (line == null) {
                throw new IOException("Unexpected end of Perceptron model file while reading a row");
            }
            this.mAvgVector = new Nodes(line);
            this.mAvgVector.expand(this.mFeatCount);
        }

        @Override
        public void reset() {
            this.mLastVector = new Nodes(this.mFeatCount);
            this.mAvgVector = new Nodes(this.mFeatCount);
            this.mUpdateCount = 0;
            this.mLastWeight = 0;
        }

        /**
         * Flushes the pending weight: adds the live vector, scaled by its
         * pending weight, into mAvgVector. This assumes
         * {@code sumExpanded(v, w)} means {@code avg += w * v}; confirm against
         * Nodes.
         *
         * <p>The pending weight is then reset. Without the reset, the
         * end-of-epoch flush left the weight in place and the next flush added
         * the same steps a second time.
         */
        @Override
        public void makeAverage() {
            this.mAvgVector.sumExpanded(this.mLastVector, this.mLastWeight);
            this.mLastWeight = 0;
        }

        /**
         * Flushes the pending weight, then adds {@code sign * x} to the live
         * vector and counts the update.
         */
        @Override
        public void addVector(Nodes x, double sign) {
            this.makeAverage();
            this.mLastVector.sumNonExpanded(x, sign);
            ++this.mUpdateCount;
        }

        @Override
        public int getUpdateCount() {
            return this.mUpdateCount;
        }

        @Override
        public void addToWeight(int i) {
            this.mLastWeight += i;
        }

        /** Scores {@code vector} with the live (training-time) weights. */
        @Override
        public double multiply(Nodes vector, Kernel kernel) {
            return kernel.multiply(this.mLastVector, vector);
        }

        /** Scores {@code vector} with the accumulated weights. */
        @Override
        public double predict(Nodes vector, Kernel kernel) {
            return kernel.multiply(this.mAvgVector, vector);
        }

        /** Writes the accumulated vector as one line. */
        @Override
        public void save(PrintStream os) {
            os.println(this.mAvgVector);
        }
    }

    /** Per-label weight storage used by the classifier. */
    static interface Row {
        /** Clears all weights and counters. */
        public void reset();

        /** Flushes pending survival weight into the accumulated vector. */
        public void makeAverage();

        /** Adds {@code scale * x} to the live vector. */
        public void addVector(Nodes x, double scale);

        /** Returns how many times {@link #addVector} has been applied. */
        public int getUpdateCount();

        /** Adjusts the pending survival weight by {@code delta}. */
        public void addToWeight(int delta);

        /** Scores {@code vector} with the live weights. */
        public double multiply(Nodes vector, Kernel kernel);

        /** Scores {@code vector} with the accumulated weights. */
        public double predict(Nodes vector, Kernel kernel);

        /** Serializes the row. */
        public void save(PrintStream os);
    }
}

