package mill.libsvm;

import java.io.IOException;
import java.util.List;
import java.util.ListIterator;
import mill.common.Classifier;
import mill.common.CommandLineParameters;
import mill.common.Node;
import mill.common.Nodes;
import mill.common.NodesFactory;
import mill.common.Outcome;
import mill.common.Sample;
import mill.common.StringDictionary;
import mill.common.TrainingParameters;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;

/**
 * {@link Classifier} backed by the bundled libsvm port ({@code svm}, {@code svm_model}).
 *
 * <p>{@link #predict} ranks every class by libsvm's probability estimates, highest first. When
 * both a NIL label and a minimum probability are configured and NIL is ranked first, NIL is moved
 * down past every directly following outcome whose confidence exceeds that minimum. NIL's
 * confidence is then rescaled and all confidences are renormalized by their sum.
 */
public class svm_classifier extends Classifier {
    /** Trained or loaded libsvm model; {@code null} until {@link #train} or {@link #loadModel}. */
    svm_model mModel;

    /** Label of the NIL category, or {@code null} to disable NIL re-ranking. */
    private Integer mNilLabel;

    /** Confidence an outcome must exceed to outrank a top-ranked NIL, or {@code null} to disable. */
    private Double mMinProb;

    static Logger mLog = Logger.getLogger(svm_classifier.class.getName());

    /** Creates a classifier with no model and NIL re-ranking disabled. */
    public svm_classifier() {
        this.mModel = null;
    }

    /**
     * Records the NIL label. Because no minimum probability is set, NIL re-ranking stays
     * disabled.
     *
     * @param num label of the NIL category
     */
    public svm_classifier(Integer num) {
        this.mNilLabel = num;
        this.mMinProb = null;
    }

    /**
     * Creates a classifier that demotes a top-ranked NIL below outcomes more confident than
     * {@code d}.
     *
     * @param num label of the NIL category
     * @param d   minimum confidence an outcome needs to outrank NIL
     */
    public svm_classifier(Integer num, Double d) {
        this.mNilLabel = num;
        this.mMinProb = d;
    }

    /**
     * Replaces the contents of {@code list} with all classes ranked by probability; see
     * {@link #predictWithSvmPredictProbability}.
     */
    @Override // mill.common.Classifier
    public void predict(Nodes nodes, List<Outcome> list, mill.common.Kernel kernel) {
        predictWithSvmPredictProbability(nodes, list, kernel);
    }

    /**
     * Appends the single label returned by {@code svm.svm_predict} to {@code list}.
     *
     * <p>Unlike the other predict variants, {@code list} is NOT cleared first. {@code kernel}
     * is unused.
     */
    public void predictWithSvmPredict(Nodes nodes, List<Outcome> list, mill.common.Kernel kernel) {
        list.add(new Outcome((int) svm.svm_predict(this.mModel, nodes.getAll())));
    }

    /**
     * Replaces the contents of {@code list} with all classes ranked by one-vs-one votes.
     *
     * <p>The votes are computed from {@code svm.svm_predict_values}. Each outcome's confidence
     * is the number of pairwise votes its class won. {@code kernel} is unused.
     */
    public void predictWithSvmPredictValues(Nodes nodes, List<Outcome> list, mill.common.Kernel kernel) {
        Node[] features = nodes.getAll();
        int nrClass = svm.svm_get_nr_class(this.mModel);
        // One decision value per class pair, ordered (0,1), (0,2), ..., (1,2), ... as consumed below.
        double[] decisionValues = new double[(nrClass * (nrClass - 1)) / 2];
        svm.svm_predict_values(this.mModel, features, decisionValues);

        int[] votes = new int[nrClass];
        int pair = 0;
        for (int i = 0; i < nrClass; i++) {
            for (int j = i + 1; j < nrClass; j++) {
                // A positive decision value is a vote for class i, otherwise for class j.
                if (decisionValues[pair] > 0.0d) {
                    votes[i]++;
                } else {
                    votes[j]++;
                }
                pair++;
            }
        }

        Outcome[] outcomes = new Outcome[nrClass];
        for (int i = 0; i < nrClass; i++) {
            outcomes[i] = new Outcome(this.mModel.getLabels()[i], votes[i]);
        }
        replaceWithSorted(list, outcomes);
    }

    /**
     * Replaces the contents of {@code list} with all classes ranked by the probabilities from
     * {@code svm.svm_predict_probability}, highest first.
     *
     * <p>If NIL re-ranking is configured, it is applied afterwards. {@code kernel} is unused.
     *
     * @throws IllegalStateException if the model carries no probability information
     */
    public void predictWithSvmPredictProbability(Nodes nodes, List<Outcome> list, mill.common.Kernel kernel) {
        Node[] features = nodes.getAll();
        if (svm.svm_check_probability_model(this.mModel) == 0) {
            throw new IllegalStateException("svm_check_probability_model returns 0!");
        }
        int nrClass = svm.svm_get_nr_class(this.mModel);
        int[] labels = new int[nrClass];
        svm.svm_get_labels(this.mModel, labels);
        double[] probabilities = new double[nrClass];
        svm.svm_predict_probability(this.mModel, features, probabilities);

        Outcome[] outcomes = new Outcome[nrClass];
        for (int i = 0; i < nrClass; i++) {
            outcomes[i] = new Outcome(labels[i], probabilities[i]);
        }
        replaceWithSorted(list, outcomes);
        demoteNil(list);
    }

    /**
     * Demotes a top-ranked NIL outcome in place.
     *
     * <p>Nothing happens unless both {@link #mNilLabel} and {@link #mMinProb} are set, the list
     * has at least two entries, and the first entry's label is NIL. In that case NIL moves down
     * past every directly following outcome whose confidence exceeds {@link #mMinProb}. If NIL
     * moved at all, its confidence is multiplied by the sum of its new neighbours' confidences.
     * All confidences are then divided by their total.
     */
    private void demoteNil(List<Outcome> list) {
        if (this.mMinProb == null || this.mNilLabel == null || list.size() <= 1
                || list.get(0).getLabel() != this.mNilLabel.intValue()) {
            return;
        }
        double minProb = this.mMinProb.doubleValue();
        int nilLabel = this.mNilLabel.intValue();

        // Move NIL down one slot at a time. The Outcome objects stay where they are; their
        // labels and confidences are swapped, so NIL carries its original confidence along.
        int nilPos = 0;
        while (nilPos < list.size() - 1 && list.get(nilPos + 1).getConfidence() > minProb) {
            Outcome current = list.get(nilPos);
            Outcome next = list.get(nilPos + 1);
            double nilConfidence = current.getConfidence();
            current.setConfidence(next.getConfidence());
            current.setLabel(next.getLabel());
            next.setConfidence(nilConfidence);
            next.setLabel(nilLabel);
            nilPos++;
        }
        if (nilPos == 0) {
            return;
        }

        double above = list.get(nilPos - 1).getConfidence();
        double below = nilPos < list.size() - 1 ? list.get(nilPos + 1).getConfidence() : 0.0d;
        Outcome nil = list.get(nilPos);
        nil.setConfidence((above + below) * nil.getConfidence());

        double total = 0.0d;
        for (Outcome outcome : list) {
            total += outcome.getConfidence();
        }
        // Guard against 0/0, which would turn every confidence into NaN.
        if (total > 0.0d) {
            for (Outcome outcome : list) {
                outcome.setConfidence(outcome.getConfidence() / total);
            }
        }
    }

    /**
     * Sorts {@code outcomes} by descending confidence and makes them the only contents of
     * {@code list}.
     */
    private static void replaceWithSorted(List<Outcome> list, Outcome[] outcomes) {
        sort(outcomes);
        list.clear();
        for (Outcome outcome : outcomes) {
            list.add(outcome);
        }
    }

    /**
     * Sorts {@code outcomeArr} in place by descending confidence, using a simple O(n^2)
     * exchange sort.
     *
     * <p>The sort is not stable. The order of equal-confidence outcomes is whatever this
     * exchange pattern produces.
     */
    private static void sort(Outcome[] outcomeArr) {
        for (int i = 0; i < outcomeArr.length; i++) {
            for (int j = i + 1; j < outcomeArr.length; j++) {
                if (outcomeArr[i].getConfidence() < outcomeArr[j].getConfidence()) {
                    Outcome tmp = outcomeArr[i];
                    outcomeArr[i] = outcomeArr[j];
                    outcomeArr[j] = tmp;
                }
            }
        }
    }

    /**
     * Trains a libsvm model on {@code list}.
     *
     * <p>If {@code list2} is non-null, the model is then evaluated on it and the result of
     * {@code test} is logged.
     *
     * @param list               training samples
     * @param list2              optional evaluation samples, may be {@code null}
     * @param trainingParameters must be an {@code svm_parameter}; it is cast directly
     */
    @Override // mill.common.Classifier
    public void train(List<Sample> list, List<Sample> list2, TrainingParameters trainingParameters) {
        svm_problem problem = new svm_problem();
        problem.l = list.size();
        problem.y = new double[problem.l];
        // x is jagged: each row is one sample's Node[] feature vector, assigned below.
        problem.x = new Node[problem.l][];
        int row = 0;
        for (Sample sample : list) {
            problem.y[row] = sample.getLabel();
            problem.x[row] = sample.getNodes().getAll();
            row++;
        }
        this.mModel = svm.svm_train(problem, (svm_parameter) trainingParameters);
        if (list2 != null) {
            mLog.error(test(list2, trainingParameters.mNilCategory, null));
        }
    }

    /** Writes the current model to the file {@code str} via {@code svm.svm_save_model}. */
    @Override // mill.common.Classifier
    public void saveModel(String str) throws IOException {
        svm.svm_save_model(str, this.mModel);
    }

    /** Loads the model from file {@code str}; {@code nodesFactory} is unused. */
    @Override // mill.common.Classifier
    public void loadModel(String str, NodesFactory nodesFactory) throws IOException {
        this.mModel = svm.svm_load_model(str);
    }

    /** Prints command-line usage to standard error and exits with status 1. */
    public static void usage() {
        System.err.println("java mill.libsvm.svm_classifier \\\n"
                + "\t--log4j=<log4j properties file> \\\n"
                + "\t--train=<training samples> \\\n"
                + "\t--test=<testing samples> \\\n"
                + "\t--model=<model file> \\\n"
                + "\t--gamma=<kernel gamma> \\\n"
                + "\t--cost=<penalty cost> \\\n"
                + "\t--kernel=<kernel type> (same as libsvm) \\\n"
                + "\t--degree=<kernel degree> \\\n"
                + "\t--nil=<number of the NIL category>");
        System.exit(1);
    }

    /**
     * Command-line entry point.
     *
     * <p>If training samples are given, trains a model, evaluates it on any test samples, and
     * saves it to {@code --model}. Otherwise it loads {@code --model} and evaluates it on the
     * test samples.
     */
    public static void main(String[] strArr) throws Exception {
        CommandLineParameters.read(strArr);
        String log4jConfig = CommandLineParameters.getString("log4j");
        if (log4jConfig == null) {
            usage();
        }
        PropertyConfigurator.configure(log4jConfig);
        String trainFile = CommandLineParameters.getString("train");
        String testFile = CommandLineParameters.getString("test");
        if (trainFile == null && testFile == null) {
            usage();
        }
        String modelFile = CommandLineParameters.getString("model");
        if (modelFile == null) {
            usage();
        }
        Integer nilCategory = CommandLineParameters.getInteger(StringDictionary.NIL_VALUE);
        Integer kernelType = CommandLineParameters.getInteger("kernel");
        Integer degree = CommandLineParameters.getInteger("degree");
        Double gamma = CommandLineParameters.getDouble("gamma");
        Double cost = CommandLineParameters.getDouble("cost");

        NodesFactory nodesFactory = new NodesFactory();
        List<Sample> trainSamples = null;
        if (trainFile != null) {
            trainSamples = Sample.readSamples(trainFile, nodesFactory);
        }
        List<Sample> testSamples = null;
        if (testFile != null) {
            testSamples = Sample.readSamples(testFile, nodesFactory);
        }

        svm_classifier classifier = new svm_classifier();
        if (trainSamples != null) {
            classifier.train(trainSamples, testSamples,
                    new svm_parameter(nilCategory, kernelType, degree, gamma, cost));
            classifier.saveModel(modelFile);
            return;
        }
        if (testSamples != null) {
            classifier.loadModel(modelFile, null);
            mLog.error("Started testing...");
            mLog.error(classifier.test(testSamples, nilCategory, null));
            mLog.error("Done testing.");
        }
    }
}
