/*
 * Decompiled with CFR 0.152.
 */
package mill.perk;

import java.io.BufferedReader;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import mill.common.Classifier;
import mill.common.CommandLineParameters;
import mill.common.DotKernel;
import mill.common.Kernel;
import mill.common.LinearKernel;
import mill.common.Nodes;
import mill.common.NodesFactory;
import mill.common.Outcome;
import mill.common.PolyKernel;
import mill.common.Sample;
import mill.common.Score;
import mill.common.SimpleTokenize;
import mill.common.TrainingParameters;
import mill.perk.BinaryPerceptronClassifier;
import mill.perk.OneVsRestPerceptronClassifierTrainingParameters;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;

public class OneVsRestPerceptronClassifier
extends Classifier {
    private boolean mDoNormalize;
    private boolean mIsDual;
    private BinaryPerceptronClassifier[] mModels;
    private Integer mNilCategory;
    private int mFeatureCount;
    private Double mMinProbThresh;
    private double mSoftmaxGamma;
    public static final double DEFAULT_SOFTMAX_GAMMA = 1.0E-5;
    static Logger mLog = Logger.getLogger((String)OneVsRestPerceptronClassifier.class.getName());

    public OneVsRestPerceptronClassifier() {
        this.mIsDual = false;
        this.mMinProbThresh = null;
        this.mDoNormalize = true;
        this.mSoftmaxGamma = 1.0E-5;
    }

    public OneVsRestPerceptronClassifier(Boolean dual, Double minThresh, Boolean doNormalize, Double gamma) {
        this.mIsDual = dual != null ? dual : false;
        this.mMinProbThresh = minThresh;
        if (doNormalize != null) {
            this.mDoNormalize = doNormalize;
        } else {
            doNormalize = true;
        }
        this.mSoftmaxGamma = gamma != null ? gamma : 1.0E-5;
    }

    /*
     * Unable to fully structure code
     */
    @Override
    public void train(List<Sample> trainSamples, List<Sample> testSamples, TrainingParameters pars) {
        parameters = (OneVsRestPerceptronClassifierTrainingParameters)pars;
        this.mNilCategory = parameters.mNilCategory;
        maxFeatIndex = -1;
        it = trainSamples.listIterator();
        while (it.hasNext()) {
            s = it.next();
            if (this.mDoNormalize) {
                s.getNodes().normalize();
            }
            if ((idx = s.getNodes().getMaxIndex()) <= maxFeatIndex) continue;
            maxFeatIndex = idx;
        }
        this.mFeatureCount = maxFeatIndex + 1;
        OneVsRestPerceptronClassifier.mLog.error((Object)("Feature count: " + this.mFeatureCount));
        if (this.mDoNormalize && testSamples != null) {
            it = testSamples.listIterator();
            while (it.hasNext()) {
                s = it.next();
                s.getNodes().normalize();
            }
        }
        categoriesToIndeces = new HashMap<Integer, Integer>();
        it = trainSamples.listIterator();
        while (it.hasNext()) {
            sample = it.next();
            if (categoriesToIndeces.containsKey(sample.getLabel())) continue;
            categoriesToIndeces.put(sample.getLabel(), categoriesToIndeces.size());
        }
        OneVsRestPerceptronClassifier.mLog.warn((Object)("Loaded " + trainSamples.size() + " training examples with " + categoriesToIndeces.size() + " labels."));
        keys = categoriesToIndeces.keySet();
        modelCount = 0;
        modelCount = keys.size();
        this.mModels = new BinaryPerceptronClassifier[modelCount];
        current = 0;
        for (Integer key : keys) {
            this.mModels[current] = new BinaryPerceptronClassifier((int)key, null, this.mIsDual, this.mFeatureCount);
            this.mModels[current].makeCache(trainSamples.size());
            ++current;
        }
        for (model = 0; model < this.mModels.length; ++model) {
            if (this.mNilCategory != null && this.mModels[model].getPositiveLabel() == this.mNilCategory.intValue()) {
                this.mModels[model].setMargins(0.0, 0.0);
                continue;
            }
            if (parameters.mType == 0) {
                this.mModels[model].setMargins(parameters.mPosCoverage, parameters.mNegCoverage);
                continue;
            }
            if (parameters.mType == 2) {
                this.mModels[model].setMarginStartValues(parameters.mPosCoverage, parameters.mNegCoverage);
                continue;
            }
            if (parameters.mType != 1 && parameters.mType != 3) continue;
            this.mModels[model].calculateMargins(trainSamples, parameters.mKernel, parameters.mPosCoverage, parameters.mNegCoverage);
        }
        scores = new ArrayList<Score>();
        if (parameters.mType == 2) ** GOTO lbl-1000
        if (parameters.mType == 3) lbl-1000:
        // 2 sources

        {
            v0 = true;
        } else {
            v0 = false;
        }
        isDynamic = v0;
        for (epoch = 0; epoch < parameters.mNumberOfEpochs; ++epoch) {
            OneVsRestPerceptronClassifier.mLog.error((Object)("Starting epoch " + (epoch + 1) + "..."));
            svCount = 0;
            for (model = 0; model < this.mModels.length; ++model) {
                modelIsDynamic = isDynamic;
                if (this.mNilCategory != null && this.mModels[model].getPositiveLabel() == this.mNilCategory.intValue()) {
                    modelIsDynamic = false;
                }
                this.mModels[model].trainEpoch(trainSamples, parameters.mKernel, epoch, modelIsDynamic);
                svCount += this.mModels[model].getModel().getSupportVectorCount();
            }
            OneVsRestPerceptronClassifier.mLog.error((Object)("Ended epoch " + (epoch + 1) + ". Total number of SVs: " + svCount));
            if (testSamples != null) {
                scr = this.test(testSamples, this.mNilCategory, parameters.mKernel);
                OneVsRestPerceptronClassifier.mLog.error((Object)scr);
                OneVsRestPerceptronClassifier.mLog.error((Object)("\tCorrect: " + scr.mCorrect));
                OneVsRestPerceptronClassifier.mLog.error((Object)("\tTotal: " + scr.mTotal));
                OneVsRestPerceptronClassifier.mLog.error((Object)("\tCorrectNonNil: " + scr.mCorrectNonNil));
                OneVsRestPerceptronClassifier.mLog.error((Object)("\tPredictedNonNil: " + scr.mPredictedNonNil));
                OneVsRestPerceptronClassifier.mLog.error((Object)("\tTotalNonNil: " + scr.mTotalNonNil));
                scores.add(scr);
            }
            if (parameters.mModelPrefix == null) continue;
            epochModel = parameters.mModelPrefix + "." + (epoch + 1);
            try {
                this.saveModel(epochModel);
                OneVsRestPerceptronClassifier.mLog.error((Object)("Epoch model saved in file: " + epochModel));
                continue;
            }
            catch (IOException e) {
                OneVsRestPerceptronClassifier.mLog.error((Object)("Failed to save model in file: " + epochModel));
            }
        }
        if (!scores.isEmpty() && parameters.mScoreFile != null) {
            try {
                os = new PrintStream(new FileOutputStream(parameters.mScoreFile));
                for (i = 0; i < scores.size(); ++i) {
                    os.println(i + 1 + " " + scores.get(i));
                }
                os.close();
            }
            catch (IOException e) {
                OneVsRestPerceptronClassifier.mLog.error((Object)("Failed to save scores in file: " + parameters.mScoreFile));
            }
        }
    }

    @Override
    public void predict(Nodes sample, List<Outcome> outcomes, Kernel customKernel) {
        int offset;
        int i;
        if (customKernel == null) {
            throw new RuntimeException("Attempted prediction without kernel!");
        }
        if (this.mDoNormalize) {
            sample.normalize();
        }
        int modelCount = this.mModels.length;
        Outcome[] preds = new Outcome[modelCount];
        for (i = offset = 0; i < preds.length; ++i) {
            double conf = this.mModels[i - offset].predict(sample, customKernel);
            preds[i] = new Outcome(this.mModels[i - offset].getPositiveLabel(), conf);
        }
        OneVsRestPerceptronClassifier.sort(preds);
        this.convertToProbabilities(preds);
        outcomes.clear();
        for (i = 0; i < preds.length; ++i) {
            outcomes.add(preds[i]);
        }
        if (this.mMinProbThresh != null && this.mNilCategory != null && outcomes.size() > 0 && outcomes.get(0).getLabel() == this.mNilCategory.intValue() && outcomes.size() > 1) {
            int nilPosition = 0;
            for (int i2 = 0; i2 < outcomes.size() - 1 && outcomes.get(i2 + 1).getConfidence() > this.mMinProbThresh; ++i2) {
                Outcome crt = outcomes.get(i2);
                Outcome next = outcomes.get(i2 + 1);
                double tmp = crt.getConfidence();
                crt.setConfidence(next.getConfidence());
                next.setConfidence(tmp);
                crt.setLabel(next.getLabel());
                next.setLabel(this.mNilCategory);
                nilPosition = i2 + 1;
            }
            if (nilPosition > 0) {
                int i3;
                double prev = outcomes.get(nilPosition - 1).getConfidence();
                double next = 0.0;
                if (nilPosition < outcomes.size() - 1) {
                    next = outcomes.get(nilPosition + 1).getConfidence();
                }
                Outcome nilOutcome = outcomes.get(nilPosition);
                nilOutcome.setConfidence((prev + next) * nilOutcome.getConfidence());
                double sum = 0.0;
                for (i3 = 0; i3 < outcomes.size(); ++i3) {
                    sum += outcomes.get(i3).getConfidence();
                }
                for (i3 = 0; i3 < outcomes.size(); ++i3) {
                    Outcome crt = outcomes.get(i3);
                    crt.setConfidence(crt.getConfidence() / sum);
                }
            }
        }
    }

    private void convertToProbabilities(Outcome[] outcomes) {
        int i;
        double sum = 0.0;
        double gamma = this.mSoftmaxGamma;
        for (i = 0; i < outcomes.length; ++i) {
            sum += Math.exp(outcomes[i].getConfidence() * gamma);
        }
        if (sum == 0.0) {
            System.err.println("Invalid normalization factor: " + sum);
            for (i = 0; i < outcomes.length; ++i) {
                System.err.println(outcomes[i] + " ==> " + Math.exp(outcomes[i].getConfidence() * gamma));
            }
            throw new RuntimeException("Invalid normalization factor: " + sum);
        }
        for (i = 0; i < outcomes.length; ++i) {
            double den = Math.exp(outcomes[i].getConfidence() * gamma);
            outcomes[i].setConfidence(den / sum);
        }
    }

    private static void sort(Outcome[] outcomes) {
        for (int i = 0; i < outcomes.length; ++i) {
            for (int j = i + 1; j < outcomes.length; ++j) {
                if (!(outcomes[i].getConfidence() < outcomes[j].getConfidence())) continue;
                Outcome tmp = outcomes[i];
                outcomes[i] = outcomes[j];
                outcomes[j] = tmp;
            }
        }
    }

    @Override
    public void saveModel(String fileName) throws IOException {
        PrintStream os = new PrintStream(new FileOutputStream(fileName));
        os.print(this.mModels.length + " " + this.mIsDual + " ");
        if (this.mNilCategory != null) {
            os.print("true " + this.mNilCategory + " ");
        } else {
            os.print("false 0 ");
        }
        os.print(this.mFeatureCount + " ");
        os.println();
        for (int i = 0; i < this.mModels.length; ++i) {
            this.mModels[i].save(os);
        }
        os.close();
    }

    @Override
    public void loadModel(String fileName, NodesFactory factory) throws IOException {
        BufferedReader is = new BufferedReader(new FileReader(fileName));
        String line = is.readLine();
        ArrayList<String> tokens = SimpleTokenize.tokenize(line);
        if (tokens.size() != 5) {
            throw new RuntimeException("Invalid Perceptron model meta info: " + line);
        }
        int modelCount = Integer.parseInt(tokens.get(0));
        this.mIsDual = Boolean.parseBoolean(tokens.get(1));
        boolean hasNil = Boolean.parseBoolean(tokens.get(2));
        this.mNilCategory = hasNil ? new Integer(tokens.get(3)) : null;
        this.mFeatureCount = Integer.parseInt(tokens.get(4));
        this.mModels = new BinaryPerceptronClassifier[modelCount];
        for (int i = 0; i < modelCount; ++i) {
            this.mModels[i] = BinaryPerceptronClassifier.load(is, this.mIsDual, this.mFeatureCount);
        }
        is.close();
    }

    public static void usage() {
        System.err.println("java mill.svm.OneVsOneSvmClassifier \\\n\t--train=<training samples> \\\n\t--test=<testing samples> \\\n\t--model=<model file> \\\n\t--kernel=<kernel type> (same as libsvm)\\\n\t--degree=<kernel degree> \\\n\t--epochs=<number of epochs> \\\n\t--tpos=<threshold for positive examples> \\\n\t--tneg=<threshold for negative examples> \\\n\t--type=<PAUM type> \\\n\t--nil=<number of the NIL category>");
        System.exit(1);
    }

    public static void main(String[] args) throws Exception {
        String modelFile;
        CommandLineParameters.read(args);
        String log4j = CommandLineParameters.getString("log4j");
        PropertyConfigurator.configure((String)log4j);
        String trainFile = CommandLineParameters.getString("train");
        String testFile = CommandLineParameters.getString("test");
        if (trainFile == null && testFile == null) {
            OneVsRestPerceptronClassifier.usage();
        }
        if ((modelFile = CommandLineParameters.getString("model")) == null) {
            OneVsRestPerceptronClassifier.usage();
        }
        Integer nil = CommandLineParameters.getInteger("nil");
        Integer kernelType = CommandLineParameters.getInteger("kernel");
        Integer degree = CommandLineParameters.getInteger("degree");
        Integer epochs = CommandLineParameters.getInteger("epochs");
        Double tpos = CommandLineParameters.getDouble("tpos");
        Double tneg = CommandLineParameters.getDouble("tneg");
        Boolean dual = CommandLineParameters.getBoolean("dual");
        String paramType = CommandLineParameters.getString("type");
        int paumType = 0;
        if (paramType != null) {
            if (paramType.equalsIgnoreCase("static-const")) {
                paumType = 0;
            } else if (paramType.equalsIgnoreCase("static-distr")) {
                paumType = 1;
            } else if (paramType.equalsIgnoreCase("dynamic-const")) {
                paumType = 2;
            } else if (paramType.equalsIgnoreCase("dynamic-distr")) {
                paumType = 3;
            } else {
                throw new RuntimeException("Invalid PAUM type: " + paramType);
            }
        }
        NodesFactory factory = new NodesFactory();
        List<Sample> trainSamples = null;
        if (trainFile != null) {
            trainSamples = Sample.readSamples(trainFile, factory);
        }
        List<Sample> testSamples = null;
        if (testFile != null) {
            testSamples = Sample.readSamples(testFile, factory);
        }
        OneVsRestPerceptronClassifier cls = new OneVsRestPerceptronClassifier(dual, null, null, null);
        DotKernel kernel = null;
        if (kernelType == null) {
            kernel = new LinearKernel();
        } else if (kernelType == 0) {
            kernel = new LinearKernel();
        } else if (kernelType == 1) {
            kernel = new PolyKernel(degree);
        } else {
            throw new RuntimeException("Unsupported kernel type: " + kernelType);
        }
        if (trainSamples != null) {
            OneVsRestPerceptronClassifierTrainingParameters pars = new OneVsRestPerceptronClassifierTrainingParameters(nil, kernel, epochs, tpos, tneg, paumType, modelFile, modelFile + ".scores");
            cls.train(trainSamples, testSamples, pars);
            cls.saveModel(modelFile);
        } else if (testSamples != null) {
            cls.loadModel(modelFile, null);
            mLog.error((Object)"Started testing...");
            Score score = cls.test(testSamples, nil, kernel);
            mLog.error((Object)score);
            mLog.error((Object)"Done testing.");
        }
    }
}

