/*
 * Decompiled with CFR 0.152.
 */
package mill.perk;

import java.io.BufferedReader;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.ListIterator;
import java.util.Vector;
import mill.common.Kernel;
import mill.common.Nodes;
import mill.common.Sample;
import mill.common.SimpleTokenize;
import mill.perk.BinaryModel;
import mill.perk.Cache;
import mill.perk.DualBinaryModel;
import mill.perk.PrimalExpandedBinaryModel;
import org.apache.log4j.Logger;

/**
 * A binary (two-class) perceptron with separate positive and negative margins ("tau" values),
 * backed either by a {@link DualBinaryModel} or a {@link PrimalExpandedBinaryModel}.
 *
 * <p>A sample whose label equals the positive label has gold sign +1, every other accepted
 * sample has gold sign -1. When a negative label is set, only samples carrying the positive or
 * the negative label are used; when it is {@code null}, every sample is used and all
 * non-positive samples count as negative.
 *
 * <p>No synchronization is performed; training methods mutate the model, the margins and the
 * mistake counters in place.
 */
class BinaryPerceptronClassifier {
    /** Support-vector count at which {@code update} compiles the average vector mid-epoch. */
    private static final int SV_COMPILE_THRESHOLD = 500;
    /** {@code calculateDistributions} logs progress every this many accepted samples. */
    private static final int PROGRESS_LOG_INTERVAL = 1000;
    /** Width of the histogram buckets written by {@code calculateDistributions}. */
    private static final double DISTRIBUTION_BUCKET_WIDTH = 0.05;

    /**
     * Scale applied to the positive mistake rate to obtain the dynamic positive margin; also the
     * positive margin in force at the start of a dynamic-threshold epoch.
     */
    private double mPositiveMarginRatio;
    /** Negative-class counterpart of {@link #mPositiveMarginRatio}. */
    private double mNegativeMarginRatio;
    private final Integer mPositiveLabel;
    /** Negative-class label, or {@code null} when every other label counts as negative. */
    private final Integer mNegativeLabel;
    private final boolean mIsDual;
    int mFeatureCount;
    private final BinaryModel mModel;
    /** Positive samples whose margin is at or below this value trigger a model update. */
    double mPositiveMargin;
    /** Negative samples whose margin is at or below this value trigger a model update. */
    double mNegativeMargin;
    /** Optional per-sample prediction cache (used for dual models only), indexed by sample position. */
    Cache[] mCache;
    static Logger mLog = Logger.getLogger(BinaryPerceptronClassifier.class.getName());
    // Counters driving the dynamic margins. Only the constructor zeroes them, so they
    // accumulate across every epoch trained on this instance.
    private int mPosMistakes;
    private int mPosVisited;
    private int mNegMistakes;
    private int mNegVisited;

    /**
     * Creates an untrained classifier.
     *
     * @param posLabel  label treated as the positive class; must not be {@code null}
     * @param negLabel  label treated as the negative class, or {@code null} so that every other
     *                  label is negative
     * @param isDual    {@code true} for a dual (kernel) model, {@code false} for a primal model
     * @param featCount number of features, passed to the primal model and to {@code compile}
     * @throws NullPointerException if {@code posLabel} is {@code null}
     */
    public BinaryPerceptronClassifier(Integer posLabel, Integer negLabel, boolean isDual, int featCount) {
        // Checked explicitly (not just asserted) because almost every method unboxes this label.
        if (posLabel == null) {
            throw new NullPointerException("posLabel must not be null");
        }
        this.mPositiveLabel = posLabel;
        this.mNegativeLabel = negLabel;
        this.mIsDual = isDual;
        this.mFeatureCount = featCount;
        this.mModel = isDual ? new DualBinaryModel() : new PrimalExpandedBinaryModel(featCount);
        this.mPositiveMargin = 0.0;
        this.mNegativeMargin = 0.0;
        this.mCache = null;
        this.mPosMistakes = 0;
        this.mPosVisited = 0;
        this.mNegMistakes = 0;
        this.mNegVisited = 0;
        this.mPositiveMarginRatio = 1.0;
        this.mNegativeMarginRatio = 0.01;
    }

    /** Returns the positive-class label. */
    public int getPositiveLabel() {
        return this.mPositiveLabel;
    }

    /** Returns {@code true} when an explicit negative label is set. */
    public boolean hasNegativeLabel() {
        return this.mNegativeLabel != null;
    }

    /**
     * Returns the negative-class label.
     *
     * @throws NullPointerException if no negative label is set (see {@link #hasNegativeLabel()})
     */
    public int getNegativeLabel() {
        if (this.mNegativeLabel == null) {
            throw new NullPointerException("Classifier for label " + this.mPositiveLabel
                    + " has no negative label; all other labels are treated as negative");
        }
        return this.mNegativeLabel;
    }

    /** Returns the underlying model (shared, not copied). */
    public BinaryModel getModel() {
        return this.mModel;
    }

    /**
     * Allocates a fresh prediction cache entry for each of {@code sampleCount} samples.
     * {@code trainEpoch} indexes it by position in the full sample list (accepted or not),
     * so presumably {@code sampleCount} should be that list's size — TODO confirm with callers.
     */
    void makeCache(int sampleCount) {
        this.mCache = new Cache[sampleCount];
        for (int i = 0; i < this.mCache.length; ++i) {
            this.mCache[i] = new Cache();
        }
    }

    /** Sets the current positive and negative margins directly. */
    void setMargins(double pos, double neg) {
        this.mPositiveMargin = pos;
        this.mNegativeMargin = neg;
    }

    /** Sets the ratios used as epoch-start margins and as dynamic-margin scale factors. */
    void setMarginStartValues(double pos, double neg) {
        this.mPositiveMarginRatio = pos;
        this.mNegativeMarginRatio = neg;
    }

    /**
     * Estimates the PAUM margins (tau values) from one zero-margin training pass.
     *
     * <p>For each class, the margins whose sign matches the threshold's sign are sorted
     * (ascending for a positive threshold, descending for a negative one) and the element at
     * position {@code floor(|count * thresh|)} (1-based, clamped to the list size) becomes the
     * new margin and margin ratio. A zero threshold leaves that class's margin at 0.
     *
     * @param samples   training samples; only those accepted by this classifier are used
     * @param kernel    kernel passed to the model
     * @param posThresh signed fraction selecting the positive tau
     * @param negThresh signed fraction selecting the negative tau
     */
    void calculateMargins(List<Sample> samples, Kernel kernel, double posThresh, double negThresh) {
        mLog.debug("Calculating PAUM margins for label: " + this.mPositiveLabel + "...");
        List<Double> posMargins = new ArrayList<Double>();
        List<Double> negMargins = new ArrayList<Double>();
        this.runMarginPass(samples, kernel, posMargins, negMargins, false);

        List<Double> posSelected = selectBySign(posMargins, posThresh);
        if (!posSelected.isEmpty()) {
            double[] posSorted = this.sortMargins(posSelected, posThresh > 0.0);
            this.mPositiveMargin = pickMargin(posSorted, posThresh);
            this.mPositiveMarginRatio = this.mPositiveMargin;
            mLog.debug("  + tau for label " + this.mPositiveLabel + ": ["
                    + this.printableDouble(posSorted[0]) + " ==> "
                    + this.printableDouble(this.mPositiveMargin) + " <== "
                    + this.printableDouble(posSorted[posSorted.length - 1]) + "]");
        }
        List<Double> negSelected = selectBySign(negMargins, negThresh);
        if (!negSelected.isEmpty()) {
            double[] negSorted = this.sortMargins(negSelected, negThresh > 0.0);
            this.mNegativeMargin = pickMargin(negSorted, negThresh);
            this.mNegativeMarginRatio = this.mNegativeMargin;
            mLog.debug("  - tau for label " + this.mPositiveLabel + ": ["
                    + this.printableDouble(negSorted[0]) + " ==> "
                    + this.printableDouble(this.mNegativeMargin) + " <== "
                    + this.printableDouble(negSorted[negSorted.length - 1]) + "]");
        }
    }

    /**
     * Runs one zero-margin training pass and writes a histogram of the positive samples'
     * margins to the file {@code "<positiveLabel>.margins"}, one {@code "<bucketStart> <count>"}
     * line per bucket of width {@value #DISTRIBUTION_BUCKET_WIDTH}.
     *
     * <p>{@code posThresh} and {@code negThresh} are not used; they are kept for signature
     * compatibility.
     *
     * @throws IOException if the output file cannot be opened or written
     */
    void calculateDistributions(List<Sample> samples, Kernel kernel, double posThresh, double negThresh) throws IOException {
        mLog.warn("Calculating margin distributions for label: " + this.mPositiveLabel + "...");
        List<Double> margins = new ArrayList<Double>();
        this.runMarginPass(samples, kernel, margins, null, true);
        if (margins.isEmpty()) {
            return;
        }
        List<DoubleInteger> points = this.bucketMargins(this.sortMargins(margins, true), DISTRIBUTION_BUCKET_WIDTH);

        String fileName = this.mPositiveLabel.toString() + ".margins";
        PrintStream os = new PrintStream(new FileOutputStream(fileName));
        boolean writeFailed;
        try {
            for (DoubleInteger point : points) {
                os.println(point.mDouble + " " + point.mInt);
            }
            // PrintStream swallows IOExceptions; checkError() must run before close(),
            // because after close() it always reports an error.
            writeFailed = os.checkError();
        } finally {
            os.close();
        }
        if (writeFailed) {
            throw new IOException("Failed to write margin distribution to " + fileName);
        }
    }

    /**
     * Copies {@code margins} into an array sorted ascending, or descending when
     * {@code ascending} is {@code false}.
     */
    double[] sortMargins(List<Double> margins, boolean ascending) {
        double[] sorted = new double[margins.size()];
        int next = 0;
        for (Double margin : margins) {
            sorted[next++] = margin;
        }
        Arrays.sort(sorted);
        if (!ascending) {
            for (int lo = 0, hi = sorted.length - 1; lo < hi; ++lo, --hi) {
                double tmp = sorted[lo];
                sorted[lo] = sorted[hi];
                sorted[hi] = tmp;
            }
        }
        return sorted;
    }

    /** Truncates {@code d} toward zero to two decimal places, for log output. */
    double printableDouble(double d) {
        // long (not int) so that margins beyond ~2.1e7 are not clamped to Integer.MAX_VALUE.
        long hundredths = (long) (d * 100.0);
        return (double) hundredths / 100.0;
    }

    /** Returns whether this classifier trains on {@code sample} (see the class comment). */
    private boolean accepts(Sample sample) {
        return sample.getLabel() == this.mPositiveLabel.intValue()
                || this.mNegativeLabel == null
                || sample.getLabel() == this.mNegativeLabel.intValue();
    }

    /** Returns +1 if {@code sample} carries the positive label, otherwise -1. */
    private int goldSign(Sample sample) {
        return sample.getLabel() == this.mPositiveLabel.intValue() ? 1 : -1;
    }

    /** Returns e.g. {@code "3 vs 7"}, or just {@code "3"} when there is no negative label. */
    private String describeLabels() {
        return String.valueOf(this.mPositiveLabel) + (this.mNegativeLabel != null ? " vs " + this.mNegativeLabel : "");
    }

    /**
     * Runs one training pass with both margins reset to 0, recording every accepted sample's
     * margin (gold sign times prediction, taken before the sample is learned) into
     * {@code posMargins} or {@code negMargins} by gold sign; a {@code null} sink discards that
     * class. Ends with {@code mModel.addWeight} and {@code mModel.reset()}, which presumably
     * discards the model built during the pass — TODO confirm against BinaryModel.
     */
    private void runMarginPass(List<Sample> samples, Kernel kernel, List<Double> posMargins,
            List<Double> negMargins, boolean logProgress) {
        this.mPositiveMargin = 0.0;
        this.mNegativeMargin = 0.0;
        int weight = 0;
        int sampleCount = 0;
        for (Sample sample : samples) {
            if (!this.accepts(sample)) {
                continue;
            }
            double prediction = this.mModel.multiply(sample.getNodes(), kernel, null);
            int goldSign = this.goldSign(sample);
            List<Double> sink = goldSign > 0 ? posMargins : negMargins;
            if (sink != null) {
                sink.add(goldSign * prediction);
            }
            weight = this.update(sample, prediction, weight, false);
            ++sampleCount;
            if (logProgress && sampleCount % PROGRESS_LOG_INTERVAL == 0) {
                mLog.debug("Processed " + sampleCount + " samples.");
                mLog.debug("Currently storing " + this.mModel.getSupportVectorCount() + " SVs.");
            }
        }
        this.mModel.addWeight(weight);
        this.mModel.reset();
    }

    /** Keeps the margins with the same (strict) sign as {@code thresh}; none if it is 0. */
    private static List<Double> selectBySign(List<Double> margins, double thresh) {
        List<Double> selected = new ArrayList<Double>();
        for (Double margin : margins) {
            if ((thresh < 0.0 && margin < 0.0) || (thresh > 0.0 && margin > 0.0)) {
                selected.add(margin);
            }
        }
        return selected;
    }

    /**
     * Returns the element at 1-based position {@code floor(|length * thresh|)} of {@code sorted},
     * clamped to the array length, or 0.0 when that position is 0.
     */
    private static double pickMargin(double[] sorted, double thresh) {
        // Clamp so that |thresh| > 1 selects the last element instead of overrunning the array.
        int position = Math.min((int) Math.abs((double) sorted.length * thresh), sorted.length);
        return position > 0 ? sorted[position - 1] : 0.0;
    }

    /**
     * Groups ascending-sorted, non-empty {@code sortedMargins} into buckets: each bucket starts at
     * its first margin and absorbs following margins up to {@code start + bucketWidth}.
     */
    private List<DoubleInteger> bucketMargins(double[] sortedMargins, double bucketWidth) {
        List<DoubleInteger> points = new ArrayList<DoubleInteger>();
        DoubleInteger current = new DoubleInteger(sortedMargins[0], 1);
        points.add(current);
        for (int i = 1; i < sortedMargins.length; ++i) {
            if (sortedMargins[i] > current.mDouble + bucketWidth) {
                current = new DoubleInteger(sortedMargins[i], 1);
                points.add(current);
            } else {
                ++current.mInt;
            }
        }
        return points;
    }

    /**
     * Perceptron step for one sample: adds it as a support vector when its margin is at or below
     * the current margin for its class, otherwise extends the survival weight.
     *
     * @param prediction              model output for the sample before this step
     * @param weight                  survival count of the current vector so far
     * @param dynamicMarginThresholds when {@code true}, rescale the class margin to
     *                                ratio * (mistakes / visits) after deciding
     * @return the new survival weight: {@code weight + 1} if no update happened, else 1
     */
    private int update(Sample sample, double prediction, int weight, boolean dynamicMarginThresholds) {
        int goldSign = this.goldSign(sample);
        double margin = goldSign * prediction;
        // Decide using the margins in force *before* this sample adjusts them.
        boolean doUpdate = goldSign > 0 ? margin <= this.mPositiveMargin : margin <= this.mNegativeMargin;

        if (dynamicMarginThresholds) {
            if (goldSign > 0) {
                ++this.mPosVisited;
                if (margin <= 0.0) {
                    ++this.mPosMistakes;
                }
                this.mPositiveMargin = this.mPositiveMarginRatio * ((double) this.mPosMistakes / (double) this.mPosVisited);
            } else {
                ++this.mNegVisited;
                if (margin <= 0.0) {
                    ++this.mNegMistakes;
                }
                this.mNegativeMargin = this.mNegativeMarginRatio * ((double) this.mNegMistakes / (double) this.mNegVisited);
            }
        }

        if (!doUpdate) {
            return weight + 1;
        }
        this.mModel.addWeight(weight);
        if (this.mModel.getSupportVectorCount() == SV_COMPILE_THRESHOLD) {
            mLog.debug("Reached " + SV_COMPILE_THRESHOLD + " SVs. Compiling AVG vector...");
            this.mModel.compile(this.mFeatureCount);
        }
        this.mModel.addVector(sample.getNodes(), goldSign);
        return 1;
    }

    /** Logs the current non-zero margins, prefixed with {@code phase} ("Start" / "End"). */
    private void logTau(String phase, int epochNumber, String labels) {
        if (this.mPositiveMargin != 0.0) {
            mLog.debug(phase + " of epoch " + epochNumber + " for label " + labels + ": + tau is "
                    + this.printableDouble(this.mPositiveMargin));
        }
        if (this.mNegativeMargin != 0.0) {
            mLog.debug(phase + " of epoch " + epochNumber + " for label " + labels + ": - tau is "
                    + this.printableDouble(this.mNegativeMargin));
        }
    }

    /**
     * Trains one epoch over the accepted samples, then compiles the average vector.
     *
     * @param epoch                   zero-based epoch index (logged as {@code epoch + 1})
     * @param dynamicMarginThresholds when {@code true}, margins start from the configured ratios
     *                                and are rescaled after every sample
     */
    void trainEpoch(List<Sample> samples, Kernel kernel, int epoch, boolean dynamicMarginThresholds) {
        int epochNumber = epoch + 1;
        String labels = this.describeLabels();
        if (dynamicMarginThresholds) {
            this.mPositiveMargin = this.mPositiveMarginRatio;
            this.mNegativeMargin = this.mNegativeMarginRatio;
        }
        String type = this.mIsDual ? "dual" : "primal";
        mLog.debug("Epoch " + epochNumber + ": training binary model (" + type + ") for label " + labels + "...");
        if (dynamicMarginThresholds) {
            this.logTau("Start", epochNumber, labels);
        }

        boolean useCache = this.mCache != null && this.mIsDual;
        int weight = 0;
        int sampleIndex = 0;
        for (Sample sample : samples) {
            if (this.accepts(sample)) {
                // The cache is indexed by position in the full list, skipped samples included.
                Cache cache = useCache ? this.mCache[sampleIndex] : null;
                double prediction = this.mModel.multiply(sample.getNodes(), kernel, cache);
                if (useCache) {
                    cache.mOffset = this.mModel.getSupportVectorCount();
                    cache.mValue = prediction;
                }
                weight = this.update(sample, prediction, weight, dynamicMarginThresholds);
            }
            ++sampleIndex;
        }
        this.mModel.addWeight(weight);

        mLog.debug("Epoch " + epochNumber + ": have " + this.mModel.getSupportVectorCount() + " SVs for label " + labels + ".");
        if (dynamicMarginThresholds) {
            this.logTau("End", epochNumber, labels);
        }
        mLog.debug("Compiling average vector...");
        this.mModel.compile(this.mFeatureCount);
    }

    /** Returns the averaged model's score for {@code vector}. */
    double predict(Nodes vector, Kernel kernel) {
        return this.mModel.predictAverage(vector, kernel);
    }

    /** Writes a {@code "<pos> [<neg>]"} header line followed by the model. */
    void save(PrintStream os) throws IOException {
        os.println(this.mPositiveLabel + (this.mNegativeLabel != null ? " " + this.mNegativeLabel : ""));
        this.mModel.save(os);
    }

    /**
     * Reads a classifier written by {@link #save(PrintStream)}.
     *
     * @throws IOException if the stream ends before the header or the header has no label
     * @throws NumberFormatException if a label token is not an integer
     */
    static BinaryPerceptronClassifier load(BufferedReader is, boolean isDual, int featCount) throws IOException {
        String line = is.readLine();
        if (line == null) {
            throw new IOException("Unexpected end of stream while reading binary classifier header");
        }
        List<String> tokens = SimpleTokenize.tokenize(line);
        if (tokens.isEmpty()) {
            throw new IOException("Binary classifier header has no positive label: \"" + line + "\"");
        }
        Integer posLabel = Integer.valueOf(tokens.get(0));
        Integer negLabel = tokens.size() > 1 ? Integer.valueOf(tokens.get(1)) : null;
        BinaryPerceptronClassifier bc = new BinaryPerceptronClassifier(posLabel, negLabel, isDual, featCount);
        bc.mModel.load(is, featCount);
        return bc;
    }

    /** Mutable (value, count) pair: a histogram bucket's start margin and its sample count. */
    class DoubleInteger {
        public double mDouble;
        public int mInt;

        DoubleInteger(double d, int i) {
            this.mDouble = d;
            this.mInt = i;
        }
    }
}

