/*
 * Decompiled with CFR 0.152.
 */
package mill.libsvm;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.ListIterator;
import mill.common.Classifier;
import mill.common.CommandLineParameters;
import mill.common.Kernel;
import mill.common.Node;
import mill.common.Nodes;
import mill.common.NodesFactory;
import mill.common.Outcome;
import mill.common.Sample;
import mill.common.Score;
import mill.common.TrainingParameters;
import mill.libsvm.svm;
import mill.libsvm.svm_model;
import mill.libsvm.svm_parameter;
import mill.libsvm.svm_problem;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;

/**
 * {@link Classifier} backed by a LibSVM model.
 *
 * <p>{@link #predict} ranks classes with LibSVM probability estimates. It throws a
 * {@code RuntimeException} when {@code svm.svm_check_probability_model} reports that the model
 * has no probability information. When both a NIL label and a minimum probability are configured,
 * a top-ranked NIL outcome is demoted below the non-NIL outcomes whose probability exceeds that
 * minimum.
 */
public class svm_classifier extends Classifier {
    /** The LibSVM model; {@code null} until {@link #train} or {@link #loadModel} assigns it. */
    svm_model mModel;
    /** Label of the NIL category, or {@code null} to disable NIL demotion. */
    private Integer mNilLabel;
    /** Probability a non-NIL outcome must exceed to be ranked above NIL, or {@code null}. */
    private Double mMinProb;
    static Logger mLog = Logger.getLogger(svm_classifier.class.getName());

    /** Orders outcomes by descending confidence. */
    private static final Comparator<Outcome> BY_CONFIDENCE_DESC = new Comparator<Outcome>() {
        @Override
        public int compare(Outcome a, Outcome b) {
            return Double.compare(b.getConfidence(), a.getConfidence());
        }
    };

    /** Creates a classifier without a model and with NIL demotion disabled. */
    public svm_classifier() {
        this(null, null);
    }

    /**
     * Creates a classifier with a NIL label but no minimum-probability threshold (so NIL
     * demotion stays disabled).
     *
     * @param nilLabel label of the NIL category, may be {@code null}
     */
    public svm_classifier(Integer nilLabel) {
        this(nilLabel, null);
    }

    /**
     * Creates a classifier that demotes a top-ranked NIL outcome when other outcomes exceed
     * {@code minProb}.
     *
     * @param nilLabel label of the NIL category, may be {@code null}
     * @param minProb  probability threshold used for NIL demotion, may be {@code null}
     */
    public svm_classifier(Integer nilLabel, Double minProb) {
        this.mNilLabel = nilLabel;
        this.mMinProb = minProb;
    }

    /**
     * Ranks all classes for {@code sample}; delegates to
     * {@link #predictWithSvmPredictProbability}.
     */
    @Override
    public void predict(Nodes sample, List<Outcome> outcomes, Kernel customKernel) {
        this.predictWithSvmPredictProbability(sample, outcomes, customKernel);
    }

    /**
     * Appends the single label returned by {@code svm.svm_predict} to {@code outcomes}.
     * Unlike the other predict variants this does not clear {@code outcomes} first.
     * {@code customKernel} is ignored.
     */
    public void predictWithSvmPredict(Nodes sample, List<Outcome> outcomes, Kernel customKernel) {
        Node[] x = sample.getAll();
        double v = svm.svm_predict(this.mModel, x);
        outcomes.add(new Outcome((int) v));
    }

    /**
     * Replaces {@code outcomes} with every class ranked by one-vs-one vote count (the vote count
     * is the outcome's confidence), highest first; tied classes keep the model's label order.
     * {@code customKernel} is ignored.
     */
    public void predictWithSvmPredictValues(Nodes sample, List<Outcome> outcomes, Kernel customKernel) {
        Node[] x = sample.getAll();
        int nrClasses = svm.svm_get_nr_class(this.mModel);
        double[] values = new double[nrClasses * (nrClasses - 1) / 2];
        svm.svm_predict_values(this.mModel, x, values);
        // values[pos] is treated as the decision value of class pair (i, j), i < j, visited in the
        // same order as these loops: a positive value is a vote for i, otherwise a vote for j.
        int[] vote = new int[nrClasses];
        int pos = 0;
        for (int i = 0; i < nrClasses; ++i) {
            for (int j = i + 1; j < nrClasses; ++j) {
                if (values[pos] > 0.0) {
                    ++vote[i];
                } else {
                    ++vote[j];
                }
                ++pos;
            }
        }
        Outcome[] sorted = new Outcome[nrClasses];
        for (int i = 0; i < nrClasses; ++i) {
            sorted[i] = new Outcome(this.mModel.getLabels()[i], vote[i]);
        }
        replaceWithSorted(outcomes, sorted);
    }

    /**
     * Replaces {@code outcomes} with every class ranked by LibSVM probability estimate, highest
     * first, then applies NIL demotion if a NIL label and minimum probability are configured and
     * NIL ranks first. {@code customKernel} is ignored.
     *
     * @throws RuntimeException if the model has no probability information
     */
    public void predictWithSvmPredictProbability(Nodes sample, List<Outcome> outcomes, Kernel customKernel) {
        Node[] x = sample.getAll();
        if (svm.svm_check_probability_model(this.mModel) == 0) {
            throw new RuntimeException("svm_check_probability_model returns 0!");
        }
        int nrClass = svm.svm_get_nr_class(this.mModel);
        int[] labels = new int[nrClass];
        svm.svm_get_labels(this.mModel, labels);
        double[] probs = new double[nrClass];
        int bestClass = (int) svm.svm_predict_probability(this.mModel, x, probs);
        if (mLog.isDebugEnabled()) {
            StringBuilder msg = new StringBuilder("best class ").append(bestClass).append(':');
            for (int i = 0; i < nrClass; ++i) {
                msg.append(" (").append(labels[i]).append(", ").append(probs[i]).append(')');
            }
            mLog.debug(msg.toString());
        }
        Outcome[] sorted = new Outcome[nrClass];
        for (int i = 0; i < nrClass; ++i) {
            sorted[i] = new Outcome(labels[i], probs[i]);
        }
        replaceWithSorted(outcomes, sorted);
        if (this.mMinProb != null && this.mNilLabel != null && outcomes.size() > 1
                && outcomes.get(0).getLabel() == this.mNilLabel.intValue()) {
            this.demoteNil(outcomes);
        }
    }

    /**
     * Moves the NIL outcome at index 0 down past each following outcome whose confidence exceeds
     * {@link #mMinProb}. If it moved at all, NIL's confidence is multiplied by the summed
     * confidence of its new neighbours (the one below counts as 0 when NIL ends up last), and all
     * confidences are renormalized to sum to 1. Expects {@code outcomes} sorted by descending
     * confidence with NIL first and at least two entries.
     */
    private void demoteNil(List<Outcome> outcomes) {
        int nilPosition = 0;
        // Exchange label and confidence in place with the successor; stops at the first successor
        // whose confidence is not above mMinProb, or at the end of the list.
        while (nilPosition < outcomes.size() - 1
                && outcomes.get(nilPosition + 1).getConfidence() > this.mMinProb) {
            Outcome nil = outcomes.get(nilPosition);
            Outcome below = outcomes.get(nilPosition + 1);
            double nilConfidence = nil.getConfidence();
            nil.setConfidence(below.getConfidence());
            below.setConfidence(nilConfidence);
            nil.setLabel(below.getLabel());
            below.setLabel(this.mNilLabel);
            ++nilPosition;
        }
        if (nilPosition == 0) {
            return;
        }
        double prevConfidence = outcomes.get(nilPosition - 1).getConfidence();
        double nextConfidence = 0.0;
        if (nilPosition < outcomes.size() - 1) {
            nextConfidence = outcomes.get(nilPosition + 1).getConfidence();
        }
        Outcome nilOutcome = outcomes.get(nilPosition);
        nilOutcome.setConfidence((prevConfidence + nextConfidence) * nilOutcome.getConfidence());
        double sum = 0.0;
        for (Outcome outcome : outcomes) {
            sum += outcome.getConfidence();
        }
        for (Outcome outcome : outcomes) {
            outcome.setConfidence(outcome.getConfidence() / sum);
        }
    }

    /** Sorts {@code sorted} by descending confidence and makes it the new content of {@code outcomes}. */
    private static void replaceWithSorted(List<Outcome> outcomes, Outcome[] sorted) {
        svm_classifier.sort(sorted);
        outcomes.clear();
        outcomes.addAll(Arrays.asList(sorted));
    }

    /**
     * Sorts by descending confidence. {@link Arrays#sort(Object[], Comparator)} is stable, so
     * outcomes with equal confidence keep their input (model label) order.
     */
    private static void sort(Outcome[] outcomes) {
        Arrays.sort(outcomes, BY_CONFIDENCE_DESC);
    }

    /**
     * Trains the model on {@code trainSamples} and, if {@code testSamples} is non-null, evaluates
     * it on them and logs the score.
     *
     * @param pars must be an {@link svm_parameter}; any other type fails with ClassCastException
     */
    @Override
    public void train(List<Sample> trainSamples, List<Sample> testSamples, TrainingParameters pars) {
        svm_problem problem = new svm_problem();
        problem.l = trainSamples.size();
        problem.y = new double[problem.l];
        problem.x = new Node[problem.l][];
        int index = 0;
        for (Sample sample : trainSamples) {
            problem.y[index] = sample.getLabel();
            problem.x[index] = sample.getNodes().getAll();
            ++index;
        }
        svm_parameter parameters = (svm_parameter) pars;
        this.mModel = svm.svm_train(problem, parameters);
        if (testSamples != null) {
            Score score = this.test(testSamples, pars.mNilCategory, null);
            // NOTE(review): scores are logged at ERROR level, presumably so they survive a quiet
            // log4j configuration; confirm before changing the level.
            mLog.error(score);
        }
    }

    /** Writes the current model to {@code fileName} via {@code svm.svm_save_model}. */
    @Override
    public void saveModel(String fileName) throws IOException {
        svm.svm_save_model(fileName, this.mModel);
    }

    /** Loads the model from {@code fileName}; {@code factory} is ignored. */
    @Override
    public void loadModel(String fileName, NodesFactory factory) throws IOException {
        this.mModel = svm.svm_load_model(fileName);
    }

    /** Prints command-line usage to stderr and exits with status 1. */
    public static void usage() {
        System.err.println("java " + svm_classifier.class.getName() + " \\\n"
                + "\t--log4j=<log4j configuration file> \\\n"
                + "\t--train=<training samples> \\\n"
                + "\t--test=<testing samples> \\\n"
                + "\t--model=<model file> \\\n"
                + "\t--gamma=<kernel gamma> \\\n"
                + "\t--cost=<penalty cost> \\\n"
                + "\t--kernel=<kernel type> (same as libsvm) \\\n"
                + "\t--degree=<kernel degree> \\\n"
                + "\t--nil=<number of the NIL category>");
        System.exit(1);
    }

    /**
     * Command-line entry point. With {@code --train}, trains, optionally evaluates on
     * {@code --test}, and saves to {@code --model}. With only {@code --test}, loads
     * {@code --model} and logs the test score. {@code --log4j} and {@code --model} are required.
     */
    public static void main(String[] args) throws Exception {
        CommandLineParameters.read(args);
        String log4j = CommandLineParameters.getString("log4j");
        if (log4j == null) {
            svm_classifier.usage();
        }
        PropertyConfigurator.configure(log4j);
        String trainFile = CommandLineParameters.getString("train");
        String testFile = CommandLineParameters.getString("test");
        if (trainFile == null && testFile == null) {
            svm_classifier.usage();
        }
        String modelFile = CommandLineParameters.getString("model");
        if (modelFile == null) {
            svm_classifier.usage();
        }
        Integer nil = CommandLineParameters.getInteger("nil");
        Integer kernel = CommandLineParameters.getInteger("kernel");
        Integer degree = CommandLineParameters.getInteger("degree");
        Double gamma = CommandLineParameters.getDouble("gamma");
        Double cost = CommandLineParameters.getDouble("cost");
        NodesFactory factory = new NodesFactory();
        List<Sample> trainSamples = null;
        if (trainFile != null) {
            trainSamples = Sample.readSamples(trainFile, factory);
        }
        List<Sample> testSamples = null;
        if (testFile != null) {
            testSamples = Sample.readSamples(testFile, factory);
        }
        svm_classifier cls = new svm_classifier();
        if (trainSamples != null) {
            svm_parameter pars = new svm_parameter(nil, kernel, degree, gamma, cost);
            cls.train(trainSamples, testSamples, pars);
            cls.saveModel(modelFile);
        } else if (testSamples != null) {
            cls.loadModel(modelFile, null);
            mLog.error("Started testing...");
            Score score = cls.test(testSamples, nil, null);
            mLog.error(score);
            mLog.error("Done testing.");
        }
    }
}

