package mill.perk;

import java.io.BufferedReader;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.ListIterator;
import java.util.Vector;
import mill.common.Kernel;
import mill.common.Nodes;
import mill.common.Sample;
import mill.common.SimpleTokenize;
import mill.me.MaxEntClassifierTrainingParameters;
import org.apache.log4j.Logger;

/* loaded from: input_file:mill/perk/BinaryPerceptronClassifier.class */
/**
 * Binary perceptron used as one component of a multi-class perceptron.
 *
 * <p>The classifier separates {@code mPositiveLabel} from either a single {@code mNegativeLabel}
 * or, when the negative label is {@code null}, from every other label in the training data.
 * Training follows the PAUM scheme (perceptron with uneven margins): a sample is added to the
 * model whenever its signed score is at or below the margin ("tau") of its own class, and the
 * positive and negative margins may differ. Prediction delegates to
 * {@code BinaryModel.predictAverage}.
 *
 * <p>NOTE(review): instances hold mutable training state without synchronization; presumably
 * each instance is trained by a single thread — confirm with callers.
 */
class BinaryPerceptronClassifier {
    /** Support-vector count at which {@link #update} compiles the average vector early. */
    private static final int EARLY_COMPILE_SV_COUNT = 500;

    /** Bin width used when grouping positive-sample scores in {@link #calculateDistributions}. */
    private static final double DISTRIBUTION_BIN_WIDTH = 0.05d;

    static final Logger mLog = Logger.getLogger(BinaryPerceptronClassifier.class.getName());

    /** Start value of the positive margin; also scales the adaptive positive margin. */
    private double POSITIVE_MARGIN_RATIO = 1.0d;
    /** Start value of the negative margin; also scales the adaptive negative margin. */
    private double NEGATIVE_MARGIN_RATIO = 0.01d;
    private final Integer mPositiveLabel;
    /** Label treated as negative, or {@code null} to treat every non-positive label as negative. */
    private final Integer mNegativeLabel;
    private final boolean mIsDual;
    int mFeatureCount;
    private final BinaryModel mModel;
    /** Current positive margin (tau+): positives scoring at or below it trigger an update. */
    double mPositiveMargin;
    /** Current negative margin (tau-): negatives scoring at or below it trigger an update. */
    double mNegativeMargin;
    /**
     * Optional per-sample score cache for dual models. It is indexed by position in the training
     * list (filtered-out samples included), so it must be at least as long as that list.
     */
    Cache[] mCache;
    // Running counters for the adaptive-margin mode of trainEpoch; this class never resets them.
    private int mPosMistakes;
    private int mPosVisited;
    private int mNegMistakes;
    private int mNegVisited;

    /** One histogram bin: lower score bound ({@code mDouble}) and sample count ({@code mInt}). */
    static class DoubleInteger {
        public double mDouble;
        public int mInt;

        DoubleInteger(double d, int i) {
            this.mDouble = d;
            this.mInt = i;
        }
    }

    /**
     * Creates an untrained classifier.
     *
     * @param positiveLabel label treated as positive; must not be {@code null}
     * @param negativeLabel label treated as negative, or {@code null} for one-vs-rest training
     * @param isDual {@code true} to use a {@code DualBinaryModel}, otherwise a
     *     {@code PrimalExpandedBinaryModel}
     * @param featureCount number of features, passed to the primal model and to {@code compile}
     */
    public BinaryPerceptronClassifier(Integer positiveLabel, Integer negativeLabel, boolean isDual, int featureCount) {
        assert positiveLabel != null : "positive label must not be null";
        this.mPositiveLabel = positiveLabel;
        this.mNegativeLabel = negativeLabel;
        this.mIsDual = isDual;
        this.mFeatureCount = featureCount;
        if (isDual) {
            this.mModel = new DualBinaryModel();
        } else {
            this.mModel = new PrimalExpandedBinaryModel(featureCount);
        }
    }

    public int getPositiveLabel() {
        return this.mPositiveLabel.intValue();
    }

    /** Returns the negative label; throws {@code NullPointerException} in one-vs-rest mode. */
    public int getNegativeLabel() {
        return this.mNegativeLabel.intValue();
    }

    public BinaryModel getModel() {
        return this.mModel;
    }

    /** Allocates {@code size} fresh {@code Cache} entries, replacing any existing cache. */
    public void makeCache(int size) {
        this.mCache = new Cache[size];
        for (int k = 0; k < size; k++) {
            this.mCache[k] = new Cache();
        }
    }

    /** Sets the current positive and negative margins directly. */
    public void setMargins(double positiveMargin, double negativeMargin) {
        this.mPositiveMargin = positiveMargin;
        this.mNegativeMargin = negativeMargin;
    }

    /** Sets the margin start values used by {@link #trainEpoch} in adaptive mode. */
    public void setMarginStartValues(double positiveStart, double negativeStart) {
        this.POSITIVE_MARGIN_RATIO = positiveStart;
        this.NEGATIVE_MARGIN_RATIO = negativeStart;
    }

    /**
     * Runs one zero-margin training pass and derives the PAUM margins from the observed scores.
     *
     * <p>During the pass, it collects the signed score of each positive sample whose sign matches
     * the sign of {@code positiveRatio}. It does the same for negatives against
     * {@code negativeRatio}.
     *
     * <p>Each list is sorted: ascending when its ratio is positive, descending otherwise. The
     * margin becomes the element at position {@code |size * ratio|} (1-based, clamped to the list
     * size), or 0 when that position is 0. That margin is also stored as the margin start value.
     * The model is updated during the pass and {@code reset()} afterwards.
     *
     * @param samples training samples; only samples relevant to this binary problem are used
     * @param kernel kernel passed to the model
     * @param positiveRatio signed quantile for the positive margin
     * @param negativeRatio signed quantile for the negative margin
     */
    public void calculateMargins(List<Sample> samples, Kernel kernel, double positiveRatio, double negativeRatio) {
        mLog.debug("Calculating PAUM margins for label: " + this.mPositiveLabel + "...");
        this.mPositiveMargin = 0.0d;
        this.mNegativeMargin = 0.0d;
        Vector<Double> negativeScores = new Vector<>();
        Vector<Double> positiveScores = new Vector<>();
        int weight = 0;
        for (Sample sample : samples) {
            if (!isRelevant(sample)) {
                continue;
            }
            double score = this.mModel.multiply(sample.getNodes(), kernel, null);
            int sign = labelSign(sample);
            double signedScore = sign * score;
            if (sign < 0) {
                if (hasSameStrictSign(negativeRatio, signedScore)) {
                    negativeScores.add(signedScore);
                }
            } else if (hasSameStrictSign(positiveRatio, signedScore)) {
                positiveScores.add(signedScore);
            }
            weight = update(sample, score, weight, false);
        }
        this.mModel.addWeight(weight);
        this.mModel.reset();
        if (!positiveScores.isEmpty()) {
            double[] sorted = sortMargins(positiveScores, positiveRatio > 0.0d);
            this.mPositiveMargin = marginAt(sorted, positiveRatio);
            this.POSITIVE_MARGIN_RATIO = this.mPositiveMargin;
            mLog.debug("  + tau for label " + this.mPositiveLabel + ": [" + printableDouble(sorted[0]) + " ==> " + printableDouble(this.mPositiveMargin) + " <== " + printableDouble(sorted[sorted.length - 1]) + "]");
        }
        if (!negativeScores.isEmpty()) {
            double[] sorted = sortMargins(negativeScores, negativeRatio > 0.0d);
            this.mNegativeMargin = marginAt(sorted, negativeRatio);
            this.NEGATIVE_MARGIN_RATIO = this.mNegativeMargin;
            mLog.debug("  - tau for label " + this.mPositiveLabel + ": [" + printableDouble(sorted[0]) + " ==> " + printableDouble(this.mNegativeMargin) + " <== " + printableDouble(sorted[sorted.length - 1]) + "]");
        }
    }

    /**
     * Runs one zero-margin training pass and writes a histogram of positive-sample scores.
     *
     * <p>The output file is named {@code "<positiveLabel>.margins"}. Each line holds a bin's lower
     * bound and its count. A new bin starts whenever a score exceeds the current bin's lower bound
     * by more than {@value #DISTRIBUTION_BIN_WIDTH}. No file is written when there are no positive
     * samples.
     *
     * @param samples training samples; only samples relevant to this binary problem are used
     * @param kernel kernel passed to the model
     * @param positiveRatio not used by this method
     * @param negativeRatio not used by this method
     * @throws IOException if the file cannot be created or written
     */
    void calculateDistributions(List<Sample> samples, Kernel kernel, double positiveRatio, double negativeRatio) throws IOException {
        mLog.warn("Calculating margin distributions for label: " + this.mPositiveLabel + "...");
        this.mPositiveMargin = 0.0d;
        this.mNegativeMargin = 0.0d;
        Vector<Double> positiveScores = new Vector<>();
        int weight = 0;
        int processed = 0;
        for (Sample sample : samples) {
            if (!isRelevant(sample)) {
                continue;
            }
            double score = this.mModel.multiply(sample.getNodes(), kernel, null);
            int sign = labelSign(sample);
            if (sign > 0) {
                positiveScores.add(sign * score);
            }
            weight = update(sample, score, weight, false);
            processed++;
            if (processed % MaxEntClassifierTrainingParameters.DEFAULT_ITERATIONS == 0) {
                mLog.debug("Processed " + processed + " samples.");
                mLog.debug("Currently storing " + this.mModel.getSupportVectorCount() + " SVs.");
            }
        }
        this.mModel.addWeight(weight);
        this.mModel.reset();
        if (positiveScores.isEmpty()) {
            return;
        }
        double[] sorted = sortMargins(positiveScores, true);
        List<DoubleInteger> bins = new ArrayList<>();
        DoubleInteger current = new DoubleInteger(sorted[0], 1);
        bins.add(current);
        for (int k = 1; k < sorted.length; k++) {
            if (sorted[k] > current.mDouble + DISTRIBUTION_BIN_WIDTH) {
                current = new DoubleInteger(sorted[k], 1);
                bins.add(current);
            } else {
                current.mInt++;
            }
        }
        String fileName = this.mPositiveLabel.toString() + ".margins";
        // try-with-resources closes the stream even if a write fails.
        try (PrintStream out = new PrintStream(new FileOutputStream(fileName))) {
            for (DoubleInteger bin : bins) {
                out.println(bin.mDouble + " " + bin.mInt);
            }
            // PrintStream swallows IOExceptions; surface them through checkError().
            if (out.checkError()) {
                throw new IOException("Failed to write margin distribution to " + fileName);
            }
        }
    }

    /**
     * Copies {@code vector} into an array sorted ascending, or descending when {@code z} is false.
     *
     * @param vector scores to sort; must not contain {@code null}
     * @param z {@code true} for ascending order, {@code false} for descending order
     * @return a new sorted array
     */
    double[] sortMargins(Vector<Double> vector, boolean z) {
        double[] sorted = new double[vector.size()];
        for (int k = 0; k < sorted.length; k++) {
            sorted[k] = vector.get(k);
        }
        Arrays.sort(sorted);
        if (!z) {
            for (int lo = 0, hi = sorted.length - 1; lo < hi; lo++, hi--) {
                double tmp = sorted[lo];
                sorted[lo] = sorted[hi];
                sorted[hi] = tmp;
            }
        }
        return sorted;
    }

    /** Truncates {@code d} toward zero to two decimal places, for logging. */
    double printableDouble(double d) {
        return ((int) (d * 100.0d)) / 100.0d;
    }

    /**
     * Processes one sample: updates adaptive margins if requested and adds the sample to the model
     * when its signed score is at or below the margin of its class.
     *
     * <p>When the sample is added, the current {@code weight} is first passed to
     * {@code addWeight}. Presumably this is the survival count used for averaging — confirm in
     * BinaryModel.
     *
     * @param sample the sample being processed
     * @param score raw model score for the sample
     * @param weight number of samples processed since the last model update
     * @param adaptive when {@code true}, recompute the margin of the sample's class as its start
     *     value times the observed mistake rate for that class
     * @return 1 if the sample was added to the model, otherwise {@code weight + 1}
     */
    private int update(Sample sample, double score, int weight, boolean adaptive) {
        int sign = labelSign(sample);
        double signedScore = sign * score;
        // Evaluated against the margins in force before any adaptive adjustment below.
        boolean needsUpdate = sign > 0 ? signedScore <= this.mPositiveMargin : signedScore <= this.mNegativeMargin;
        if (adaptive) {
            // The double casts matter: integer division truncated the mistake rate to 0
            // (or 1 only when every sample so far was a mistake), collapsing the margin.
            if (sign > 0) {
                this.mPosVisited++;
                if (signedScore <= 0.0d) {
                    this.mPosMistakes++;
                }
                this.mPositiveMargin = this.POSITIVE_MARGIN_RATIO * ((double) this.mPosMistakes / this.mPosVisited);
            } else {
                this.mNegVisited++;
                if (signedScore <= 0.0d) {
                    this.mNegMistakes++;
                }
                this.mNegativeMargin = this.NEGATIVE_MARGIN_RATIO * ((double) this.mNegMistakes / this.mNegVisited);
            }
        }
        if (!needsUpdate) {
            return weight + 1;
        }
        this.mModel.addWeight(weight);
        if (this.mModel.getSupportVectorCount() == EARLY_COMPILE_SV_COUNT) {
            mLog.debug("Reached 500 SVs. Compiling AVG vector...");
            this.mModel.compile(this.mFeatureCount);
        }
        this.mModel.addVector(sample.getNodes(), sign);
        return 1;
    }

    /**
     * Runs one training epoch over {@code list}, then compiles the model's average vector.
     *
     * <p>When a cache exists and the model is dual, each relevant sample's cache entry is passed
     * to {@code multiply}. Afterwards the entry is refreshed with the current support-vector count
     * and the computed score.
     *
     * @param list training samples; only samples relevant to this binary problem are used
     * @param kernel kernel passed to the model
     * @param i zero-based epoch index (used for logging only)
     * @param z when {@code true}, margins restart from their start values and adapt to the
     *     observed mistake rate during the epoch
     */
    public void trainEpoch(List<Sample> list, Kernel kernel, int i, boolean z) {
        if (z) {
            this.mPositiveMargin = this.POSITIVE_MARGIN_RATIO;
            this.mNegativeMargin = this.NEGATIVE_MARGIN_RATIO;
        }
        int epoch = i + 1;
        String labels = labelDescription();
        mLog.debug("Epoch " + epoch + ": training binary model (" + (this.mIsDual ? "dual" : "primal") + ") for label " + labels + "...");
        if (z) {
            logMargins("Start", epoch, labels);
        }
        // mCache is not modified during the pass, so the check can be hoisted.
        boolean useCache = this.mCache != null && this.mIsDual;
        int weight = 0;
        int position = 0;
        for (Sample sample : list) {
            if (isRelevant(sample)) {
                Cache cache = useCache ? this.mCache[position] : null;
                double score = this.mModel.multiply(sample.getNodes(), kernel, cache);
                if (useCache) {
                    cache.mOffset = this.mModel.getSupportVectorCount();
                    cache.mValue = score;
                }
                weight = update(sample, score, weight, z);
            }
            position++;
        }
        this.mModel.addWeight(weight);
        mLog.debug("Epoch " + epoch + ": have " + this.mModel.getSupportVectorCount() + " SVs for label " + labels + ".");
        if (z) {
            logMargins("End", epoch, labels);
        }
        mLog.debug("Compiling average vector...");
        this.mModel.compile(this.mFeatureCount);
    }

    /** Returns the model's averaged score for {@code nodes}. */
    public double predict(Nodes nodes, Kernel kernel) {
        return this.mModel.predictAverage(nodes, kernel);
    }

    /**
     * Writes a header line (the positive label, followed by the negative label if present),
     * then the model.
     */
    public void save(PrintStream printStream) throws IOException {
        printStream.println(this.mPositiveLabel + (this.mNegativeLabel != null ? " " + this.mNegativeLabel : ""));
        this.mModel.save(printStream);
    }

    /**
     * Reads a classifier in the format written by {@link #save}.
     *
     * @param bufferedReader source positioned at the header line
     * @param z {@code true} if the stored model is dual
     * @param i number of features
     * @return the loaded classifier
     * @throws IOException if the header is missing or empty, or the model cannot be read
     * @throws NumberFormatException if a label is not an integer
     */
    public static BinaryPerceptronClassifier load(BufferedReader bufferedReader, boolean z, int i) throws IOException {
        String header = bufferedReader.readLine();
        if (header == null) {
            throw new IOException("Unexpected end of stream while reading binary classifier header");
        }
        List<String> tokens = SimpleTokenize.tokenize(header);
        if (tokens.isEmpty()) {
            throw new IOException("Missing positive label in binary classifier header: \"" + header + "\"");
        }
        Integer positiveLabel = Integer.valueOf(tokens.get(0));
        Integer negativeLabel = null;
        if (tokens.size() > 1) {
            negativeLabel = Integer.valueOf(tokens.get(1));
        }
        BinaryPerceptronClassifier classifier = new BinaryPerceptronClassifier(positiveLabel, negativeLabel, z, i);
        classifier.mModel.load(bufferedReader, i);
        return classifier;
    }

    /** True if {@code sample} takes part in this binary problem. */
    private boolean isRelevant(Sample sample) {
        return sample.getLabel() == this.mPositiveLabel.intValue()
                || this.mNegativeLabel == null
                || sample.getLabel() == this.mNegativeLabel.intValue();
    }

    /** Returns +1 for the positive label and -1 for everything else. */
    private int labelSign(Sample sample) {
        return sample.getLabel() == this.mPositiveLabel.intValue() ? 1 : -1;
    }

    /** True when {@code a} and {@code b} are both strictly negative or both strictly positive. */
    private static boolean hasSameStrictSign(double a, double b) {
        return (a < 0.0d && b < 0.0d) || (a > 0.0d && b > 0.0d);
    }

    /**
     * Returns the element at 1-based position {@code |length * ratio|} of {@code sorted}, or 0
     * when that position is 0. The position is clamped to the array length, so a ratio outside
     * [-1, 1] no longer throws {@code ArrayIndexOutOfBoundsException}.
     */
    private static double marginAt(double[] sorted, double ratio) {
        int count = Math.min(sorted.length, (int) Math.abs(sorted.length * ratio));
        return count > 0 ? sorted[count - 1] : 0.0d;
    }

    /** Label text used in log messages: the positive label, plus " vs <neg>" if set. */
    private String labelDescription() {
        return this.mPositiveLabel + (this.mNegativeLabel != null ? " vs " + this.mNegativeLabel : "");
    }

    /** Logs each non-zero margin at the given phase ("Start"/"End") of an epoch. */
    private void logMargins(String phase, int epoch, String labels) {
        if (this.mPositiveMargin != 0.0d) {
            mLog.debug(phase + " of epoch " + epoch + " for label " + labels + ": + tau is " + printableDouble(this.mPositiveMargin));
        }
        if (this.mNegativeMargin != 0.0d) {
            mLog.debug(phase + " of epoch " + epoch + " for label " + labels + ": - tau is " + printableDouble(this.mNegativeMargin));
        }
    }
}
