package mill.perk;

import java.io.BufferedReader;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import mill.common.Classifier;
import mill.common.CommandLineParameters;
import mill.common.Kernel;
import mill.common.LinearKernel;
import mill.common.Nodes;
import mill.common.NodesFactory;
import mill.common.Outcome;
import mill.common.PolyKernel;
import mill.common.Sample;
import mill.common.Score;
import mill.common.SimpleTokenize;
import mill.common.StringDictionary;
import mill.common.TrainingParameters;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;

/* loaded from: input_file:mill/perk/OneVsRestPerceptronClassifier.class */
/**
 * Multi-class classifier made of one {@link BinaryPerceptronClassifier} per label (one-vs-rest).
 *
 * <p>At prediction time every binary model scores the input. The scores are sorted in
 * descending order and turned into probabilities with a softmax scaled by the softmax gamma.
 * When a NIL category and a minimum probability threshold are configured, a top-ranked NIL can
 * be pushed below competing labels whose probability exceeds the threshold.
 */
public class OneVsRestPerceptronClassifier extends Classifier {
    /** When true, {@code Nodes.normalize()} is called on samples before training/prediction. */
    private boolean mDoNormalize;
    /** Passed through to every BinaryPerceptronClassifier (presumably selects the dual form). */
    private boolean mIsDual;
    /** One binary model per distinct training label. */
    private BinaryPerceptronClassifier[] mModels;
    /** Label treated as "no category"; null when there is none. */
    private Integer mNilCategory;
    /** Largest feature index seen in the training data, plus one. */
    private int mFeatureCount;
    /** Probability a label needs to be promoted above a top-ranked NIL; null disables this. */
    private Double mMinProbThresh;
    /** Scale applied to raw scores before the softmax. */
    private double mSoftmaxGamma;
    public static final double DEFAULT_SOFTMAX_GAMMA = 1.0E-5d;
    static Logger mLog = Logger.getLogger(OneVsRestPerceptronClassifier.class.getName());

    /** Creates a primal classifier with normalization enabled and the default softmax gamma. */
    public OneVsRestPerceptronClassifier() {
        this.mIsDual = false;
        this.mMinProbThresh = null;
        this.mDoNormalize = true;
        this.mSoftmaxGamma = DEFAULT_SOFTMAX_GAMMA;
    }

    /**
     * Creates a classifier; every argument may be null.
     *
     * @param isDual dual flag for the binary models; null means {@code false}
     * @param minProbThresh NIL re-ranking threshold (see {@link #predict}); null disables it
     * @param doNormalize whether to normalize samples; null leaves normalization DISABLED.
     *     NOTE(review): this differs from the no-arg constructor, which enables it — confirm.
     * @param softmaxGamma softmax scale; null means {@link #DEFAULT_SOFTMAX_GAMMA}
     */
    public OneVsRestPerceptronClassifier(
            Boolean isDual, Double minProbThresh, Boolean doNormalize, Double softmaxGamma) {
        this.mIsDual = isDual != null && isDual.booleanValue();
        this.mMinProbThresh = minProbThresh;
        if (doNormalize != null) {
            this.mDoNormalize = doNormalize.booleanValue();
        }
        this.mSoftmaxGamma = softmaxGamma != null ? softmaxGamma.doubleValue() : DEFAULT_SOFTMAX_GAMMA;
    }

    /**
     * Trains one binary perceptron per distinct label in {@code list}.
     *
     * <p>If normalization is enabled, {@code normalize()} is called on the Nodes of every training
     * and test sample, which mutates the caller's samples. PAUM margin setup depends on
     * {@code mType}: 0 = static-const, 1 = static-distr, 2 = dynamic-const, 3 = dynamic-distr
     * (names as parsed in {@link #main}). The NIL model always gets zero margins and never uses
     * dynamic margins. After each epoch the model is optionally evaluated on {@code list2} and
     * saved to {@code mModelPrefix + "." + epoch}. Per-epoch scores are optionally written to
     * {@code mScoreFile}.
     *
     * @param list training samples
     * @param list2 optional test samples evaluated after every epoch; may be null
     * @param trainingParameters must be an {@link OneVsRestPerceptronClassifierTrainingParameters}
     */
    @Override // mill.common.Classifier
    public void train(List<Sample> list, List<Sample> list2, TrainingParameters trainingParameters) {
        OneVsRestPerceptronClassifierTrainingParameters params =
                (OneVsRestPerceptronClassifierTrainingParameters) trainingParameters;
        this.mNilCategory = params.mNilCategory;

        // Normalize training samples and find the largest feature index in one pass.
        int maxFeatureIndex = -1;
        for (Sample sample : list) {
            if (this.mDoNormalize) {
                sample.getNodes().normalize();
            }
            maxFeatureIndex = Math.max(maxFeatureIndex, sample.getNodes().getMaxIndex());
        }
        this.mFeatureCount = maxFeatureIndex + 1;
        mLog.error("Feature count: " + this.mFeatureCount);
        if (this.mDoNormalize && list2 != null) {
            for (Sample sample : list2) {
                sample.getNodes().normalize();
            }
        }

        // Collect distinct labels; the HashMap key order decides the model order.
        Map<Integer, Integer> labelIndex = new HashMap<>();
        for (Sample sample : list) {
            labelIndex.putIfAbsent(Integer.valueOf(sample.getLabel()), labelIndex.size());
        }
        mLog.warn("Loaded " + list.size() + " training examples with " + labelIndex.size() + " labels.");

        this.mModels = new BinaryPerceptronClassifier[labelIndex.size()];
        int slot = 0;
        for (Integer label : labelIndex.keySet()) {
            this.mModels[slot] = new BinaryPerceptronClassifier(label, null, this.mIsDual, this.mFeatureCount);
            this.mModels[slot].makeCache(list.size());
            slot++;
        }

        for (BinaryPerceptronClassifier model : this.mModels) {
            if (isNilModel(model)) {
                model.setMargins(0.0d, 0.0d);
            } else if (params.mType == 0) {
                model.setMargins(params.mPosCoverage, params.mNegCoverage);
            } else if (params.mType == 2) {
                model.setMarginStartValues(params.mPosCoverage, params.mNegCoverage);
            } else if (params.mType == 1 || params.mType == 3) {
                model.calculateMargins(list, params.mKernel, params.mPosCoverage, params.mNegCoverage);
            }
        }

        List<Score> scores = new ArrayList<>();
        boolean dynamicMargins = params.mType == 2 || params.mType == 3;
        for (int epoch = 0; epoch < params.mNumberOfEpochs; epoch++) {
            mLog.error("Starting epoch " + (epoch + 1) + "...");
            int supportVectors = 0;
            for (BinaryPerceptronClassifier model : this.mModels) {
                boolean adjustMargins = dynamicMargins && !isNilModel(model);
                model.trainEpoch(list, params.mKernel, epoch, adjustMargins);
                supportVectors += model.getModel().getSupportVectorCount();
            }
            mLog.error("Ended epoch " + (epoch + 1) + ". Total number of SVs: " + supportVectors);

            if (list2 != null) {
                Score score = test(list2, this.mNilCategory, params.mKernel);
                mLog.error(score);
                mLog.error("\tCorrect: " + score.mCorrect);
                mLog.error("\tTotal: " + score.mTotal);
                mLog.error("\tCorrectNonNil: " + score.mCorrectNonNil);
                mLog.error("\tPredictedNonNil: " + score.mPredictedNonNil);
                mLog.error("\tTotalNonNil: " + score.mTotalNonNil);
                scores.add(score);
            }

            if (params.mModelPrefix != null) {
                String epochFile = params.mModelPrefix + "." + (epoch + 1);
                try {
                    saveModel(epochFile);
                    mLog.error("Epoch model saved in file: " + epochFile);
                } catch (IOException e) {
                    mLog.error("Failed to save model in file: " + epochFile, e);
                }
            }
        }

        if (scores.isEmpty() || params.mScoreFile == null) {
            return;
        }
        try (PrintStream out = new PrintStream(new FileOutputStream(params.mScoreFile))) {
            for (int i = 0; i < scores.size(); i++) {
                out.println((i + 1) + " " + scores.get(i));
            }
            // PrintStream swallows write failures; surface them explicitly.
            if (out.checkError()) {
                throw new IOException("Error while writing scores");
            }
        } catch (IOException e) {
            mLog.error("Failed to save scores in file: " + params.mScoreFile, e);
        }
    }

    /** Returns true when {@code model} is the binary model for the configured NIL category. */
    private boolean isNilModel(BinaryPerceptronClassifier model) {
        return this.mNilCategory != null && model.getPositiveLabel() == this.mNilCategory.intValue();
    }

    /**
     * Scores {@code nodes} with every binary model and replaces the contents of {@code list}
     * with one Outcome per model. The outcomes are sorted by descending score and their
     * confidences are converted to softmax probabilities.
     *
     * <p>If a NIL category and a minimum probability threshold are both set and NIL ranks first,
     * NIL is moved down past every following outcome whose probability exceeds the threshold.
     * If NIL moved, its probability is multiplied by the sum of its new neighbours'
     * probabilities, and all probabilities are renormalized to sum to 1.
     *
     * @throws IllegalArgumentException if {@code kernel} is null
     */
    @Override // mill.common.Classifier
    public void predict(Nodes nodes, List<Outcome> list, Kernel kernel) {
        if (kernel == null) {
            throw new IllegalArgumentException("Attempted prediction without kernel!");
        }
        if (this.mDoNormalize) {
            nodes.normalize();
        }
        Outcome[] outcomes = new Outcome[this.mModels.length];
        for (int i = 0; i < outcomes.length; i++) {
            outcomes[i] = new Outcome(this.mModels[i].getPositiveLabel(), this.mModels[i].predict(nodes, kernel));
        }
        sort(outcomes);
        convertToProbabilities(outcomes);
        list.clear();
        for (Outcome outcome : outcomes) {
            list.add(outcome);
        }

        if (this.mMinProbThresh == null || this.mNilCategory == null || list.size() <= 1) {
            return;
        }
        int nil = this.mNilCategory.intValue();
        if (list.get(0).getLabel() != nil) {
            return;
        }
        double threshold = this.mMinProbThresh.doubleValue();

        // Move NIL (currently at `pos`) one slot down while the next label beats the threshold.
        int nilPos = 0;
        for (int pos = 0; pos < list.size() - 1 && list.get(pos + 1).getConfidence() > threshold; pos++) {
            Outcome nilSlot = list.get(pos);
            Outcome nextSlot = list.get(pos + 1);
            double nilConfidence = nilSlot.getConfidence();
            nilSlot.setConfidence(nextSlot.getConfidence());
            nextSlot.setConfidence(nilConfidence);
            nilSlot.setLabel(nextSlot.getLabel());
            nextSlot.setLabel(nil);
            nilPos = pos + 1;
        }
        if (nilPos > 0) {
            double above = list.get(nilPos - 1).getConfidence();
            double below = nilPos < list.size() - 1 ? list.get(nilPos + 1).getConfidence() : 0.0d;
            Outcome nilOutcome = list.get(nilPos);
            nilOutcome.setConfidence((above + below) * nilOutcome.getConfidence());
            double total = 0.0d;
            for (Outcome outcome : list) {
                total += outcome.getConfidence();
            }
            for (Outcome outcome : list) {
                outcome.setConfidence(outcome.getConfidence() / total);
            }
        }
    }

    /**
     * Replaces each confidence c with softmax(c * mSoftmaxGamma) over the whole array.
     *
     * <p>The maximum scaled score is subtracted before exponentiating. This is mathematically the
     * same softmax, but it cannot overflow {@code Math.exp} to Infinity (which would silently
     * produce NaN probabilities). An empty array is left unchanged.
     *
     * @throws RuntimeException if the normalization factor is not a finite positive number
     */
    private void convertToProbabilities(Outcome[] outcomeArr) {
        if (outcomeArr.length == 0) {
            return;
        }
        double gamma = this.mSoftmaxGamma;
        double[] weights = new double[outcomeArr.length];
        double maxScaled = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < outcomeArr.length; i++) {
            weights[i] = outcomeArr[i].getConfidence() * gamma;
            if (weights[i] > maxScaled) {
                maxScaled = weights[i];
            }
        }
        double norm = 0.0d;
        for (int i = 0; i < weights.length; i++) {
            weights[i] = Math.exp(weights[i] - maxScaled);
            norm += weights[i];
        }
        // After max-shifting, norm >= 1 for finite scores; anything else means NaN/Inf inputs.
        if (!(norm > 0.0d) || Double.isInfinite(norm)) {
            mLog.error("Invalid normalization factor: " + norm);
            for (int i = 0; i < outcomeArr.length; i++) {
                mLog.error(outcomeArr[i] + " ==> " + weights[i]);
            }
            throw new RuntimeException("Invalid normalization factor: " + norm);
        }
        for (int i = 0; i < outcomeArr.length; i++) {
            outcomeArr[i].setConfidence(weights[i] / norm);
        }
    }

    /**
     * Sorts outcomes by descending confidence in place with a simple O(n^2) exchange sort.
     * This ordering is not stable for ties. It is kept as-is rather than replaced with
     * Arrays.sort so that tied outcomes come out in the same order as before.
     */
    private static void sort(Outcome[] outcomeArr) {
        for (int i = 0; i < outcomeArr.length; i++) {
            for (int j = i + 1; j < outcomeArr.length; j++) {
                if (outcomeArr[i].getConfidence() < outcomeArr[j].getConfidence()) {
                    Outcome swap = outcomeArr[i];
                    outcomeArr[i] = outcomeArr[j];
                    outcomeArr[j] = swap;
                }
            }
        }
    }

    /**
     * Writes the model to {@code str}.
     *
     * <p>The file starts with a one-line header: {@code <model count> <dual> <has nil> <nil label
     * or 0> <feature count>}. Each binary model's own serialization follows. The normalization
     * flag, probability threshold and softmax gamma are not persisted.
     *
     * @throws IOException if the file cannot be opened or any write fails
     */
    @Override // mill.common.Classifier
    public void saveModel(String str) throws IOException {
        try (PrintStream out = new PrintStream(new FileOutputStream(str))) {
            out.print(this.mModels.length + " " + this.mIsDual + " ");
            if (this.mNilCategory != null) {
                out.print("true " + this.mNilCategory + " ");
            } else {
                out.print("false 0 ");
            }
            out.print(this.mFeatureCount + " ");
            out.println();
            for (BinaryPerceptronClassifier model : this.mModels) {
                model.save(out);
            }
            // PrintStream never throws on write failure; check explicitly.
            if (out.checkError()) {
                throw new IOException("Error while writing model file: " + str);
            }
        }
    }

    /**
     * Loads a model written by {@link #saveModel}. The classifier's models are replaced only
     * after every binary model has been read successfully.
     *
     * @param nodesFactory unused here
     * @throws RuntimeException if the header line is missing or does not have 5 fields
     * @throws IOException on read errors
     */
    @Override // mill.common.Classifier
    public void loadModel(String str, NodesFactory nodesFactory) throws IOException {
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            String header = reader.readLine();
            if (header == null) {
                throw new RuntimeException("Invalid Perceptron model meta info: missing header line in " + str);
            }
            List<String> fields = SimpleTokenize.tokenize(header);
            if (fields.size() != 5) {
                throw new RuntimeException("Invalid Perceptron model meta info: " + header);
            }
            int modelCount = Integer.parseInt(fields.get(0));
            boolean isDual = Boolean.parseBoolean(fields.get(1));
            Integer nilCategory = Boolean.parseBoolean(fields.get(2)) ? Integer.valueOf(fields.get(3)) : null;
            int featureCount = Integer.parseInt(fields.get(4));

            BinaryPerceptronClassifier[] models = new BinaryPerceptronClassifier[modelCount];
            for (int i = 0; i < modelCount; i++) {
                models[i] = BinaryPerceptronClassifier.load(reader, isDual, featureCount);
            }

            this.mIsDual = isDual;
            this.mNilCategory = nilCategory;
            this.mFeatureCount = featureCount;
            this.mModels = models;
        }
    }

    /** Prints command-line usage to stderr and exits with status 1. */
    public static void usage() {
        System.err.println("java mill.perk.OneVsRestPerceptronClassifier \\\n"
                + "\t--train=<training samples> \\\n"
                + "\t--test=<testing samples> \\\n"
                + "\t--model=<model file> \\\n"
                + "\t--log4j=<log4j properties file> \\\n"
                + "\t--kernel=<kernel type> (0 = linear, 1 = polynomial)\\\n"
                + "\t--degree=<kernel degree> \\\n"
                + "\t--epochs=<number of epochs> \\\n"
                + "\t--tpos=<threshold for positive examples> \\\n"
                + "\t--tneg=<threshold for negative examples> \\\n"
                + "\t--type=<PAUM type> (static-const | static-distr | dynamic-const | dynamic-distr)\\\n"
                + "\t--dual=<true | false> \\\n"
                + "\t--nil=<number of the NIL category>");
        System.exit(1);
    }

    /**
     * Maps a PAUM type name (case-insensitive) to the integer code used by the training
     * parameters. Null maps to 0 (static-const).
     *
     * @throws RuntimeException for an unknown type name
     */
    private static int parsePaumType(String type) {
        if (type == null || type.equalsIgnoreCase("static-const")) {
            return 0;
        }
        if (type.equalsIgnoreCase("static-distr")) {
            return 1;
        }
        if (type.equalsIgnoreCase("dynamic-const")) {
            return 2;
        }
        if (type.equalsIgnoreCase("dynamic-distr")) {
            return 3;
        }
        throw new RuntimeException("Invalid PAUM type: " + type);
    }

    /**
     * Command-line entry point. With {@code --train}, trains (optionally evaluating on
     * {@code --test} every epoch) and saves to {@code --model}. Otherwise, with {@code --test},
     * loads {@code --model} and logs the test score.
     */
    public static void main(String[] strArr) throws Exception {
        CommandLineParameters.read(strArr);
        PropertyConfigurator.configure(CommandLineParameters.getString("log4j"));
        String trainFile = CommandLineParameters.getString("train");
        String testFile = CommandLineParameters.getString("test");
        if (trainFile == null && testFile == null) {
            usage();
        }
        String modelFile = CommandLineParameters.getString("model");
        if (modelFile == null) {
            usage();
        }
        Integer nilCategory = CommandLineParameters.getInteger(StringDictionary.NIL_VALUE);
        Integer kernelType = CommandLineParameters.getInteger("kernel");
        Integer degree = CommandLineParameters.getInteger("degree");
        Integer epochs = CommandLineParameters.getInteger("epochs");
        Double posThreshold = CommandLineParameters.getDouble("tpos");
        Double negThreshold = CommandLineParameters.getDouble("tneg");
        Boolean dual = CommandLineParameters.getBoolean("dual");
        int paumType = parsePaumType(CommandLineParameters.getString("type"));

        NodesFactory nodesFactory = new NodesFactory();
        List<Sample> trainSamples = trainFile != null ? Sample.readSamples(trainFile, nodesFactory) : null;
        List<Sample> testSamples = testFile != null ? Sample.readSamples(testFile, nodesFactory) : null;

        OneVsRestPerceptronClassifier classifier = new OneVsRestPerceptronClassifier(dual, null, null, null);
        Kernel kernel;
        if (kernelType == null || kernelType.intValue() == 0) {
            kernel = new LinearKernel();
        } else if (kernelType.intValue() == 1) {
            kernel = new PolyKernel(degree);
        } else {
            throw new RuntimeException("Unsupported kernel type: " + kernelType);
        }

        if (trainSamples != null) {
            classifier.train(trainSamples, testSamples, new OneVsRestPerceptronClassifierTrainingParameters(
                    nilCategory, kernel, epochs, posThreshold, negThreshold, paumType, modelFile, modelFile + ".scores"));
            classifier.saveModel(modelFile);
            return;
        }
        if (testSamples != null) {
            classifier.loadModel(modelFile, null);
            mLog.error("Started testing...");
            mLog.error(classifier.test(testSamples, nilCategory, kernel));
            mLog.error("Done testing.");
        }
    }
}
