package mill.perk;

import java.io.BufferedReader;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import mill.common.Classifier;
import mill.common.CommandLineParameters;
import mill.common.Kernel;
import mill.common.LinearKernel;
import mill.common.Nodes;
import mill.common.NodesFactory;
import mill.common.Outcome;
import mill.common.Sample;
import mill.common.Score;
import mill.common.SimpleTokenize;
import mill.common.StringDictionary;
import mill.common.TrainingParameters;
import mill.perk.MulticlassPerceptronClassifierTrainingParameters;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;

/* loaded from: input_file:mill/perk/MulticlassPerceptronClassifier.class */
/**
 * Multiclass perceptron classifier that keeps one weight vector ("row") per label.
 *
 * <p>Training performs perceptron updates on each row's current vector. Each row also accumulates
 * the sum of its intermediate vectors, each weighted by the number of training samples it
 * processed without being updated; this is an unnormalized averaged perceptron. Prediction scores
 * the input against every row's accumulated vector and turns the scores into probabilities with a
 * softmax whose inputs are scaled by {@code mSoftmaxGamma}.
 *
 * <p>Model file format (written by {@link #saveModel(String)}): a header line
 * {@code <rowCount> <isDual> <nilCategory or -1> <featureCount>}, followed by one line per row
 * holding that row's accumulated vector. The normalization flag and the softmax gamma are NOT
 * stored in the file; they come from the constructor of the instance that loads it.
 *
 * <p>Only the primal representation is implemented; requesting dual mode makes
 * {@link #train} fail.
 */
public class MulticlassPerceptronClassifier extends Classifier {
    /** Whether sample vectors are normalized (in place) before training and prediction. */
    private boolean mDoNormalize;
    /** Dual-mode flag; only {@code false} (primal) is supported by {@link #train}. */
    private boolean mIsDual;
    /** NIL label passed to {@code test(...)} during evaluation, or {@code null} if there is none. */
    private Integer mNilCategory;
    /** Weight-vector dimensionality: the largest feature index seen in training plus one. */
    private int mFeatureCount;
    /** One row per label, indexed by label; {@code null} until trained or loaded. */
    private Row[] mRows;
    /** Factor applied to raw row scores before the softmax in {@code convertToProbabilities}. */
    double mSoftmaxGamma;
    public static final double DEFAULT_SOFTMAX_GAMMA = 1.0E-5d;
    static Logger mLog = Logger.getLogger(MulticlassPerceptronClassifier.class.getName());

    /** Orders outcomes by decreasing confidence; used with a stable sort so ties keep label order. */
    private static final Comparator<Outcome> DESCENDING_CONFIDENCE = new Comparator<Outcome>() {
        @Override
        public int compare(Outcome a, Outcome b) {
            return Double.compare(b.getConfidence(), a.getConfidence());
        }
    };

    /**
     * Row backed by an explicit (primal) weight vector.
     *
     * <p>{@code mLastVector} is the current perceptron vector used while training.
     * {@code mAvgVector} accumulates {@code mLastVector} scaled by {@code mLastWeight} each time
     * the pending weight is flushed, and is the vector {@link #predict} scores against.
     */
    public class PrimalRow implements Row {
        /** Current training vector; {@code null} for rows loaded from a model file. */
        private Nodes mLastVector;
        /** Weight pending for {@code mLastVector} since the last flush into {@code mAvgVector}. */
        private int mLastWeight;
        /** Accumulated (unnormalized averaged) vector used for prediction. */
        private Nodes mAvgVector;
        int mFeatCount;
        int mUpdateCount;

        /** Creates an empty trainable row with {@code featCount} features. */
        public PrimalRow(int featCount) {
            this.mFeatCount = featCount;
            reset();
        }

        /**
         * Loads a prediction-only row from the next line of {@code reader}.
         *
         * <p>The row has no training vector, so it can only be used for {@link #predict}.
         *
         * @throws IOException if reading fails or the reader is already at end of input
         */
        public PrimalRow(BufferedReader reader, int featCount) throws IOException {
            String line = reader.readLine();
            if (line == null) {
                throw new IOException("Unexpected end of model file: missing a weight vector line");
            }
            this.mFeatCount = featCount;
            this.mUpdateCount = -1;
            this.mLastWeight = -1;
            this.mAvgVector = new Nodes(line);
            this.mAvgVector.expand(this.mFeatCount);
        }

        @Override
        public void reset() {
            this.mLastVector = new Nodes(this.mFeatCount);
            this.mAvgVector = new Nodes(this.mFeatCount);
            this.mUpdateCount = 0;
            this.mLastWeight = 0;
        }

        @Override
        public void makeAverage() {
            this.mAvgVector.sumExpanded(this.mLastVector, this.mLastWeight);
            // The pending weight is now part of the average. Clear it so the next flush (at the
            // next update or the next epoch end) does not count it a second time.
            this.mLastWeight = 0;
        }

        @Override
        public void addVector(Nodes nodes, double d) {
            // Flush the outgoing vector's pending weight (this also resets it to 0) before changing it.
            makeAverage();
            this.mLastVector.sumNonExpanded(nodes, d);
            this.mUpdateCount++;
        }

        @Override
        public int getUpdateCount() {
            return this.mUpdateCount;
        }

        @Override
        public void addToWeight(int delta) {
            this.mLastWeight += delta;
        }

        @Override
        public double multiply(Nodes nodes, Kernel kernel) {
            return kernel.multiply(this.mLastVector, nodes);
        }

        @Override
        public double predict(Nodes nodes, Kernel kernel) {
            return kernel.multiply(this.mAvgVector, nodes);
        }

        @Override
        public void save(PrintStream out) {
            out.println(this.mAvgVector);
        }
    }

    /** One per-label weight row of the multiclass perceptron. */
    public interface Row {
        /** Discards all training and averaged state. */
        void reset();

        /** Folds the current vector's pending weight into the accumulated (averaged) vector. */
        void makeAverage();

        /** Updates the current training vector with {@code nodes} scaled by {@code d}. */
        void addVector(Nodes nodes, double d);

        /** Returns how many times {@link #addVector} has been applied. */
        int getUpdateCount();

        /** Adds {@code i} to the weight pending for the current training vector. */
        void addToWeight(int i);

        /** Scores {@code nodes} against the current training vector. */
        double multiply(Nodes nodes, Kernel kernel);

        /** Scores {@code nodes} against the accumulated vector used for prediction. */
        double predict(Nodes nodes, Kernel kernel);

        /** Writes this row's persistent form to {@code printStream}. */
        void save(PrintStream printStream);
    }

    /** Creates a primal classifier that normalizes inputs and uses {@link #DEFAULT_SOFTMAX_GAMMA}. */
    public MulticlassPerceptronClassifier() {
        this(null, null, null);
    }

    /**
     * Creates a classifier; each {@code null} argument selects the same default as the no-arg
     * constructor.
     *
     * @param isDual dual mode (default {@code false}; {@code true} is not supported by training)
     * @param doNormalize whether to normalize sample vectors (default {@code true})
     * @param softmaxGamma softmax scale factor (default {@link #DEFAULT_SOFTMAX_GAMMA})
     */
    public MulticlassPerceptronClassifier(Boolean isDual, Boolean doNormalize, Double softmaxGamma) {
        this.mIsDual = isDual != null && isDual.booleanValue();
        // A null flag means "use the default", which is true, as in the no-arg constructor.
        this.mDoNormalize = doNormalize == null || doNormalize.booleanValue();
        this.mSoftmaxGamma = softmaxGamma != null ? softmaxGamma.doubleValue() : DEFAULT_SOFTMAX_GAMMA;
    }

    /**
     * Trains one row per label for {@code mNumberOfEpochs} epochs.
     *
     * <p>The label count is the largest training label plus one. When normalization is enabled,
     * the training samples' vectors are normalized in place. After each epoch, a snapshot is saved
     * to {@code <modelFile>.<epoch>} when a model file is configured. The model is also evaluated
     * on {@code list2} when that list is non-null.
     *
     * @param list training samples
     * @param list2 optional evaluation samples, or {@code null}
     * @param trainingParameters must be a {@link MulticlassPerceptronClassifierTrainingParameters}
     * @throws UnsupportedOperationException if this classifier was configured for dual mode
     */
    @Override
    public void train(List<Sample> list, List<Sample> list2, TrainingParameters trainingParameters) {
        if (this.mIsDual) {
            // Fail before touching (normalizing) the caller's samples.
            throw new UnsupportedOperationException("Dual mode not supported yet!");
        }
        MulticlassPerceptronClassifierTrainingParameters params =
            (MulticlassPerceptronClassifierTrainingParameters) trainingParameters;
        this.mNilCategory = params.mNilCategory;

        int maxFeatureIndex = -1;
        int maxLabel = -1;
        for (Sample sample : list) {
            if (this.mDoNormalize) {
                sample.getNodes().normalize();
            }
            maxFeatureIndex = Math.max(maxFeatureIndex, sample.getNodes().getMaxIndex());
            maxLabel = Math.max(maxLabel, sample.getLabel());
        }
        this.mFeatureCount = maxFeatureIndex + 1;
        mLog.error("Feature count: " + this.mFeatureCount);
        int labelCount = maxLabel + 1;
        mLog.warn("Loaded " + list.size() + " training examples with " + labelCount + " classes.");

        // Per-label multiplier for the min-score margin; 1.0 unless overridden.
        double[] labelWeights = new double[labelCount];
        Arrays.fill(labelWeights, 1.0d);
        for (MulticlassPerceptronClassifierTrainingParameters.LabelWeight labelWeight : params.mLabelWeights) {
            if (labelWeight.mLabel < 0 || labelWeight.mLabel >= labelCount) {
                // No training sample carries this label, so its weight could never be used.
                mLog.warn("Ignoring weight for label " + labelWeight.mLabel
                    + ": outside the training label range [0, " + labelCount + ").");
                continue;
            }
            labelWeights[labelWeight.mLabel] = labelWeight.mWeight;
        }

        this.mRows = new Row[labelCount];
        for (int label = 0; label < labelCount; label++) {
            this.mRows[label] = new PrimalRow(this.mFeatureCount);
        }

        for (int epoch = 1; epoch <= params.mNumberOfEpochs; epoch++) {
            mLog.error("Starting epoch " + epoch + "...");
            trainEpoch(list, params.mKernel, params.mMinScore, labelWeights);
            mLog.error("Ended epoch " + epoch + ".");
            if (mLog.isDebugEnabled()) {
                for (int label = 0; label < this.mRows.length; label++) {
                    mLog.debug("\tUpdate count for label " + label + ": " + this.mRows[label].getUpdateCount());
                }
            }
            if (params.mModelFile != null) {
                String epochFile = params.mModelFile + "." + epoch;
                try {
                    saveModel(epochFile);
                    mLog.error("Epoch model saved in file: " + epochFile);
                } catch (IOException e) {
                    mLog.error("Failed to save model in file: " + epochFile, e);
                }
            }
            if (list2 != null) {
                Score score = test(list2, this.mNilCategory, params.mKernel);
                mLog.error(score);
                mLog.error("\tCorrect: " + score.mCorrect);
                mLog.error("\tTotal: " + score.mTotal);
                mLog.error("\tCorrectNonNil: " + score.mCorrectNonNil);
                mLog.error("\tPredictedNonNil: " + score.mPredictedNonNil);
                mLog.error("\tTotalNonNil: " + score.mTotalNonNil);
            }
        }
    }

    /**
     * Runs one pass of perceptron updates over {@code samples}.
     *
     * <p>For each sample with gold label {@code y}, every other row whose training score is at least
     * the gold row's score minus {@code minScore * labelWeights[y]} is a violator. If any violator
     * exists, the sample is added to the gold row with factor 1. It is also added to each violator
     * with factor {@code -1 / violatorCount}. The pending averaging weight of every row not updated
     * by the sample grows by one. At the end, all pending weights are flushed into the averaged
     * vectors.
     */
    private void trainEpoch(List<Sample> samples, Kernel kernel, double minScore, double[] labelWeights) {
        List<Integer> violators = new ArrayList<Integer>();
        for (Sample sample : samples) {
            Nodes nodes = sample.getNodes();
            int label = sample.getLabel();
            Row gold = this.mRows[label];
            double threshold = gold.multiply(nodes, kernel) - (minScore * labelWeights[label]);

            violators.clear();
            for (int other = 0; other < this.mRows.length; other++) {
                if (other != label && this.mRows[other].multiply(nodes, kernel) >= threshold) {
                    violators.add(Integer.valueOf(other));
                }
            }

            if (!violators.isEmpty()) {
                // Updated rows get -1 here and +1 below: a new vector earns no weight for the
                // sample that produced it.
                gold.addVector(nodes, 1.0d);
                gold.addToWeight(-1);
                double share = -1.0d / violators.size();
                for (Integer other : violators) {
                    Row row = this.mRows[other.intValue()];
                    row.addVector(nodes, share);
                    row.addToWeight(-1);
                }
            }
            for (Row row : this.mRows) {
                row.addToWeight(1);
            }
        }
        for (Row row : this.mRows) {
            row.makeAverage();
        }
    }

    /**
     * Scores {@code nodes} against every label and fills {@code list} with one outcome per label,
     * sorted by decreasing probability.
     *
     * <p>When normalization is enabled, {@code nodes} is normalized in place.
     *
     * @throws RuntimeException if {@code kernel} is {@code null}
     * @throws IllegalStateException if no model has been trained or loaded
     */
    @Override
    public void predict(Nodes nodes, List<Outcome> list, Kernel kernel) {
        if (kernel == null) {
            throw new RuntimeException("Attempted prediction without kernel!");
        }
        if (this.mRows == null) {
            throw new IllegalStateException("Attempted prediction without a trained or loaded model!");
        }
        if (this.mDoNormalize) {
            nodes.normalize();
        }
        Outcome[] outcomes = new Outcome[this.mRows.length];
        for (int label = 0; label < this.mRows.length; label++) {
            outcomes[label] = new Outcome(label, this.mRows[label].predict(nodes, kernel));
        }
        sort(outcomes);
        convertToProbabilities(outcomes);
        list.clear();
        for (Outcome outcome : outcomes) {
            list.add(outcome);
        }
    }

    /**
     * Replaces each outcome's raw score {@code s} with {@code softmax(s * mSoftmaxGamma)}.
     *
     * <p>The maximum scaled score is subtracted before exponentiating. This yields the same
     * probabilities but cannot overflow, and the largest term contributes {@code exp(0) = 1}.
     *
     * @throws RuntimeException if the normalization factor is not a positive number (e.g. for an
     *     empty array or NaN scores)
     */
    private void convertToProbabilities(Outcome[] outcomeArr) {
        double gamma = this.mSoftmaxGamma;
        double[] exps = new double[outcomeArr.length];
        double maxScaled = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < outcomeArr.length; i++) {
            exps[i] = outcomeArr[i].getConfidence() * gamma;
            if (exps[i] > maxScaled) {
                maxScaled = exps[i];
            }
        }
        double norm = 0.0d;
        for (int i = 0; i < exps.length; i++) {
            exps[i] = Math.exp(exps[i] - maxScaled);
            norm += exps[i];
        }
        // Negated comparison so that NaN is rejected as well as zero.
        if (!(norm > 0.0d)) {
            mLog.error("Invalid normalization factor: " + norm);
            for (int i = 0; i < outcomeArr.length; i++) {
                mLog.error(outcomeArr[i] + " ==> " + exps[i]);
            }
            throw new RuntimeException("Invalid normalization factor: " + norm);
        }
        for (int i = 0; i < outcomeArr.length; i++) {
            outcomeArr[i].setConfidence(exps[i] / norm);
        }
    }

    /** Sorts outcomes by decreasing confidence (stable, so equal scores keep label order). */
    private static void sort(Outcome[] outcomeArr) {
        Arrays.sort(outcomeArr, DESCENDING_CONFIDENCE);
    }

    /**
     * Writes the model to {@code str} in the format described on the class.
     *
     * @throws IOException if the file cannot be created or any write fails
     * @throws IllegalStateException if no model has been trained or loaded
     */
    @Override
    public void saveModel(String str) throws IOException {
        if (this.mRows == null) {
            throw new IllegalStateException("No model to save: train or load one first!");
        }
        PrintStream out = new PrintStream(new FileOutputStream(str));
        try {
            out.print(this.mRows.length + " " + this.mIsDual + " ");
            out.print(this.mNilCategory != null ? this.mNilCategory.intValue() : -1);
            out.println(" " + this.mFeatureCount);
            for (Row row : this.mRows) {
                row.save(out);
            }
            // PrintStream swallows IOExceptions; surface them so a truncated model is not reported as saved.
            if (out.checkError()) {
                throw new IOException("Failed to write model file: " + str);
            }
        } finally {
            out.close();
        }
    }

    /**
     * Replaces this classifier's model with the one stored in {@code str}.
     *
     * <p>Rows are always loaded as {@link PrimalRow}s. Fields are updated only after the whole file
     * has been read successfully.
     *
     * @param nodesFactory unused
     * @throws IOException if the file cannot be read or is truncated
     * @throws RuntimeException if the header line does not have exactly four fields
     */
    @Override
    public void loadModel(String str, NodesFactory nodesFactory) throws IOException {
        BufferedReader reader = new BufferedReader(new FileReader(str));
        try {
            String header = reader.readLine();
            if (header == null) {
                throw new IOException("Empty Perceptron model file: " + str);
            }
            ArrayList<String> fields = SimpleTokenize.tokenize(header);
            if (fields.size() != 4) {
                throw new RuntimeException("Invalid Perceptron model meta info: " + header);
            }
            int rowCount = Integer.parseInt(fields.get(0));
            boolean isDual = Boolean.parseBoolean(fields.get(1));
            int nil = Integer.parseInt(fields.get(2));
            int featureCount = Integer.parseInt(fields.get(3));
            Row[] rows = new Row[rowCount];
            for (int i = 0; i < rows.length; i++) {
                rows[i] = new PrimalRow(reader, featureCount);
            }
            this.mIsDual = isDual;
            this.mNilCategory = nil == -1 ? null : Integer.valueOf(nil);
            this.mFeatureCount = featureCount;
            this.mRows = rows;
        } finally {
            reader.close();
        }
    }

    /** Prints command-line usage to standard error and exits with status 1. */
    public static void usage() {
        System.err.println("java mill.perk.MulticlassPerceptronClassifier \\\n\t--train=<training samples> \\\n\t--test=<testing samples> \\\n\t--model=<model file> \\\n\t--epochs=<number of epochs> \\\n\t--nil=<number of the NIL category> \\\n\t--min-score=<minimum score margin> \\\n\t--nil-penalty=<NIL penalty> \\\n\t--log4j=<log4j configuration file>");
        System.exit(1);
    }

    /**
     * Command-line entry point. With {@code --train}, trains (evaluating on {@code --test} after each
     * epoch if given) and saves to {@code --model}. With only {@code --test}, loads
     * {@code --model} and evaluates it.
     */
    public static void main(String[] strArr) throws Exception {
        CommandLineParameters.read(strArr);
        PropertyConfigurator.configure(CommandLineParameters.getString("log4j"));
        String trainFile = CommandLineParameters.getString("train");
        String testFile = CommandLineParameters.getString("test");
        if (trainFile == null && testFile == null) {
            usage();
        }
        String modelFile = CommandLineParameters.getString("model");
        if (modelFile == null) {
            usage();
        }
        Integer nilLabel = CommandLineParameters.getInteger(StringDictionary.NIL_VALUE);
        Integer epochs = CommandLineParameters.getInteger("epochs");
        Double minScore = CommandLineParameters.getDouble("min-score");
        Double nilPenalty = CommandLineParameters.getDouble("nil-penalty");

        NodesFactory nodesFactory = new NodesFactory();
        List<Sample> trainSamples = null;
        if (trainFile != null) {
            trainSamples = Sample.readSamples(trainFile, nodesFactory);
        }
        List<Sample> testSamples = null;
        if (testFile != null) {
            testSamples = Sample.readSamples(testFile, nodesFactory);
        }

        MulticlassPerceptronClassifier classifier = new MulticlassPerceptronClassifier();
        LinearKernel kernel = new LinearKernel();
        if (trainSamples != null) {
            MulticlassPerceptronClassifierTrainingParameters params =
                new MulticlassPerceptronClassifierTrainingParameters(nilLabel, kernel, epochs, minScore, modelFile);
            if (nilLabel != null && nilPenalty != null) {
                // The NIL label's margin is scaled by 1 / penalty.
                params.addLabelWeight(nilLabel.intValue(), 1.0d / nilPenalty.doubleValue());
            }
            classifier.train(trainSamples, testSamples, params);
            classifier.saveModel(modelFile);
            return;
        }
        if (testSamples != null) {
            classifier.loadModel(modelFile, null);
            mLog.error("Started testing...");
            mLog.error(classifier.test(testSamples, nilLabel, kernel));
            mLog.error("Done testing.");
        }
    }
}
