package mill.me;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.ListIterator;
import mill.common.Classifier;
import mill.common.CommandLineParameters;
import mill.common.Kernel;
import mill.common.Nodes;
import mill.common.NodesFactory;
import mill.common.Outcome;
import mill.common.Sample;
import mill.common.StringDictionary;
import mill.common.TrainingParameters;
import opennlp.maxent.GIS;
import opennlp.maxent.GISModel;
import opennlp.maxent.io.SuffixSensitiveGISModelReader;
import opennlp.maxent.io.SuffixSensitiveGISModelWriter;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;

/* loaded from: input_file:mill/me/MaxEntClassifier.class */
/**
 * {@link Classifier} backed by an OpenNLP maximum-entropy {@link GISModel}.
 *
 * <p>Each sample is presented to the model as the string forms of its node indices. Outcome
 * labels are stored in the model as strings and parsed back to integers when predicting.
 */
public class MaxEntClassifier extends Classifier {
    /** Trained or loaded model; {@code null} until {@link #train} or {@link #loadModel} succeeds. */
    private GISModel mModel = null;

    static Logger mLog = Logger.getLogger(MaxEntClassifier.class.getName());

    /**
     * Orders outcomes by descending confidence. It is used with a stable sort, so outcomes with
     * equal confidence keep the order in which the model enumerates them.
     */
    private static final Comparator<Outcome> BY_DESCENDING_CONFIDENCE = new Comparator<Outcome>() {
        @Override
        public int compare(Outcome a, Outcome b) {
            return Double.compare(b.getConfidence(), a.getConfidence());
        }
    };

    /**
     * Trains a GIS model on {@code list}. If {@code list2} is non-null, it then evaluates the
     * model on that set.
     *
     * <p>The training samples are first written to a temporary event file, which is deleted
     * whether or not training succeeds.
     *
     * @param list training samples; samples without any nodes are skipped
     * @param list2 optional test samples; evaluation output is printed to {@code System.err}
     * @param trainingParameters must be a {@link MaxEntClassifierTrainingParameters}
     * @throws IllegalArgumentException if {@code trainingParameters} is not a
     *     {@link MaxEntClassifierTrainingParameters}
     * @throws RuntimeException wrapping the underlying {@link IOException} if writing the event
     *     file, training, or testing fails
     */
    @Override // mill.common.Classifier
    public void train(List<Sample> list, List<Sample> list2, TrainingParameters trainingParameters) {
        // Validate before doing any (potentially large) I/O.
        if (!(trainingParameters instanceof MaxEntClassifierTrainingParameters)) {
            throw new IllegalArgumentException(
                    "MaxEntClassifier requires MaxEntClassifierTrainingParameters, got "
                            + (trainingParameters == null ? "null" : trainingParameters.getClass().getName()));
        }
        MaxEntClassifierTrainingParameters meParameters = (MaxEntClassifierTrainingParameters) trainingParameters;
        try {
            File eventFile = File.createTempFile("maxent", null);
            try {
                writeEvents(list, eventFile);
                // Arguments after the iteration count (1, true) are passed through unchanged from the
                // original code. Presumably they are the feature cutoff and a smoothing/verbosity flag;
                // confirm against the bundled opennlp.maxent GIS.trainModel signature.
                this.mModel = GIS.trainModel(
                        new MaxEntEventStream(eventFile.getAbsolutePath()), meParameters.mIterations, 1, true);
            } finally {
                // Previously the file was only removed on the success path.
                deleteTempFile(eventFile);
            }
            if (list2 != null) {
                System.err.println("Testing...");
                System.err.println(test(list2, trainingParameters.mNilCategory, null));
            }
        } catch (IOException e) {
            // Chain the cause instead of printing it and discarding it.
            throw new RuntimeException("MaxEntClassifier.train() failed!", e);
        }
    }

    /**
     * Writes one line per sample that has at least one node. Each line holds every node index
     * followed by a space, then the sample label. Assumes MaxEntEventStream parses this format
     * (TODO confirm).
     *
     * @throws IOException if the file cannot be opened or any write fails
     */
    private static void writeEvents(List<Sample> samples, File file) throws IOException {
        PrintStream out = new PrintStream(new FileOutputStream(file));
        try {
            for (Sample sample : samples) {
                if (sample.getNodes().size() > 0) {
                    for (int i = 0; i < sample.getNodes().size(); i++) {
                        out.print(sample.getNodes().get(i).getIndex() + " ");
                    }
                    out.println(sample.getLabel());
                }
            }
        } finally {
            out.close();
        }
        // PrintStream swallows IOExceptions; surface them rather than train on a truncated file.
        if (out.checkError()) {
            throw new IOException("Failed to write MaxEnt events to " + file.getAbsolutePath());
        }
    }

    /** Deletes {@code file}, or schedules deletion at JVM exit if it cannot be removed now. */
    private static void deleteTempFile(File file) {
        if (!file.delete()) {
            mLog.warn("Could not delete temporary file " + file.getAbsolutePath());
            file.deleteOnExit();
        }
    }

    /** Sorts {@code outcomeArr} in place by descending confidence (stable for ties). */
    private static void sort(Outcome[] outcomeArr) {
        Arrays.sort(outcomeArr, BY_DESCENDING_CONFIDENCE);
    }

    /**
     * Returns the current model.
     *
     * @throws IllegalStateException if no model has been trained or loaded yet
     */
    private GISModel requireModel() {
        if (this.mModel == null) {
            throw new IllegalStateException("MaxEntClassifier has no model: call train() or loadModel() first");
        }
        return this.mModel;
    }

    /**
     * Scores every outcome of the model for {@code nodes}. It replaces the contents of
     * {@code list} with the outcomes sorted by descending confidence.
     *
     * @param nodes the features to classify; each node index becomes one context predicate
     * @param list receives the result; it is cleared first
     * @param kernel unused by this classifier
     * @throws IllegalStateException if no model has been trained or loaded
     * @throws NumberFormatException if a model outcome label is not an integer
     */
    @Override // mill.common.Classifier
    public void predict(Nodes nodes, List<Outcome> list, Kernel kernel) {
        GISModel model = requireModel();
        String[] context = new String[nodes.size()];
        for (int i = 0; i < context.length; i++) {
            context[i] = Integer.toString(nodes.get(i).getIndex());
        }
        double[] scores = model.eval(context);
        Outcome[] outcomes = new Outcome[scores.length];
        for (int i = 0; i < scores.length; i++) {
            outcomes[i] = new Outcome(Integer.parseInt(model.getOutcome(i)), scores[i]);
        }
        sort(outcomes);
        list.clear();
        list.addAll(Arrays.asList(outcomes));
    }

    /**
     * Persists the current model to {@code str} using the suffix-sensitive GIS model writer.
     *
     * @throws IllegalStateException if no model has been trained or loaded
     */
    @Override // mill.common.Classifier
    public void saveModel(String str) throws IOException {
        new SuffixSensitiveGISModelWriter(requireModel(), new File(str)).persist();
    }

    /**
     * Loads a model from {@code str} using the suffix-sensitive GIS model reader. The
     * {@code nodesFactory} argument is not used by this classifier.
     */
    @Override // mill.common.Classifier
    public void loadModel(String str, NodesFactory nodesFactory) throws IOException {
        this.mModel = new SuffixSensitiveGISModelReader(new File(str)).getModel();
    }

    /** Prints command-line usage to {@code System.err} and exits with status 1. */
    public static void usage() {
        System.err.println("java mill.me.MaxEntClassifier \\\n\t--train=<training samples> \\\n\t--test=<testing samples> \\\n\t--model=<model file> \\\n\t--iterations=<number of training iterations> \\\n\t--nil=<number of the NIL category>");
        System.exit(1);
    }

    /**
     * Command-line entry point.
     *
     * <p>If {@code --train} is given, it trains a model (evaluating on {@code --test} when
     * present) and saves it to {@code --model}. Otherwise it loads {@code --model} and evaluates
     * it on {@code --test}. The {@code log4j} parameter is passed to {@link PropertyConfigurator}.
     */
    public static void main(String[] strArr) throws Exception {
        CommandLineParameters.read(strArr);
        PropertyConfigurator.configure(CommandLineParameters.getString("log4j"));
        String trainPath = CommandLineParameters.getString("train");
        String testPath = CommandLineParameters.getString("test");
        if (trainPath == null && testPath == null) {
            usage();
        }
        String modelPath = CommandLineParameters.getString("model");
        if (modelPath == null) {
            usage();
        }
        Integer iterations = CommandLineParameters.getInteger("iterations");
        Integer nilCategory = CommandLineParameters.getInteger(StringDictionary.NIL_VALUE);
        NodesFactory nodesFactory = new NodesFactory();
        List<Sample> trainSamples = null;
        if (trainPath != null) {
            trainSamples = Sample.readSamples(trainPath, nodesFactory);
        }
        List<Sample> testSamples = null;
        if (testPath != null) {
            testSamples = Sample.readSamples(testPath, nodesFactory);
        }
        MaxEntClassifier classifier = new MaxEntClassifier();
        if (trainSamples != null) {
            classifier.train(trainSamples, testSamples, new MaxEntClassifierTrainingParameters(nilCategory, iterations));
            classifier.saveModel(modelPath);
            return;
        }
        if (testSamples != null) {
            classifier.loadModel(modelPath, null);
            // Logged at ERROR level as in the original so the output shows under restrictive log4j configs.
            mLog.error("Started testing...");
            mLog.error(classifier.test(testSamples, nilCategory, null));
            mLog.error("Done testing.");
        }
    }
}
