/*
 * Decompiled with CFR 0.152.
 */
package mill.me;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.List;
import java.util.ListIterator;
import mill.common.Classifier;
import mill.common.CommandLineParameters;
import mill.common.Kernel;
import mill.common.Node;
import mill.common.Nodes;
import mill.common.NodesFactory;
import mill.common.Outcome;
import mill.common.Sample;
import mill.common.Score;
import mill.common.TrainingParameters;
import mill.me.MaxEntClassifierTrainingParameters;
import mill.me.MaxEntEventStream;
import opennlp.maxent.EventStream;
import opennlp.maxent.GIS;
import opennlp.maxent.GISModel;
import opennlp.maxent.io.SuffixSensitiveGISModelReader;
import opennlp.maxent.io.SuffixSensitiveGISModelWriter;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;

public class MaxEntClassifier
extends Classifier {
    private GISModel mModel = null;
    static Logger mLog = Logger.getLogger((String)MaxEntClassifier.class.getName());

    @Override
    public void train(List<Sample> trainSamples, List<Sample> testSamples, TrainingParameters pars) {
        try {
            File tmp = File.createTempFile("maxent", null);
            PrintStream os = new PrintStream(new FileOutputStream(tmp));
            ListIterator<Sample> it = trainSamples.listIterator();
            while (it.hasNext()) {
                Sample sample = it.next();
                if (sample.getNodes().size() <= 0) continue;
                for (int i = 0; i < sample.getNodes().size(); ++i) {
                    Node node = sample.getNodes().get(i);
                    os.print(node.getIndex() + " ");
                }
                os.println(sample.getLabel());
            }
            os.close();
            MaxEntEventStream es = new MaxEntEventStream(tmp.getAbsolutePath());
            MaxEntClassifierTrainingParameters mePars = (MaxEntClassifierTrainingParameters)pars;
            this.mModel = GIS.trainModel((EventStream)es, (int)mePars.mIterations, (int)1, (boolean)true);
            tmp.delete();
            if (testSamples != null) {
                System.err.println("Testing...");
                Score score = this.test(testSamples, pars.mNilCategory, null);
                System.err.println(score);
            }
        }
        catch (IOException e) {
            e.printStackTrace();
            throw new RuntimeException("MaxEntClassifier.train() failed!");
        }
    }

    private static void sort(Outcome[] outcomes) {
        for (int i = 0; i < outcomes.length; ++i) {
            for (int j = i + 1; j < outcomes.length; ++j) {
                if (!(outcomes[i].getConfidence() < outcomes[j].getConfidence())) continue;
                Outcome tmp = outcomes[i];
                outcomes[i] = outcomes[j];
                outcomes[j] = tmp;
            }
        }
    }

    @Override
    public void predict(Nodes sample, List<Outcome> outcomes, Kernel customKernel) {
        int i;
        String[] features = new String[sample.size()];
        for (int k = 0; k < features.length; ++k) {
            features[k] = Integer.toString(sample.get(k).getIndex());
        }
        double[] confidences = this.mModel.eval(features);
        Outcome[] sorted = new Outcome[confidences.length];
        for (i = 0; i < confidences.length; ++i) {
            sorted[i] = new Outcome(Integer.parseInt(this.mModel.getOutcome(i)), confidences[i]);
        }
        MaxEntClassifier.sort(sorted);
        outcomes.clear();
        for (i = 0; i < sorted.length; ++i) {
            outcomes.add(sorted[i]);
        }
    }

    @Override
    public void saveModel(String fileName) throws IOException {
        SuffixSensitiveGISModelWriter mw = new SuffixSensitiveGISModelWriter(this.mModel, new File(fileName));
        mw.persist();
    }

    @Override
    public void loadModel(String fileName, NodesFactory factory) throws IOException {
        SuffixSensitiveGISModelReader mr = new SuffixSensitiveGISModelReader(new File(fileName));
        this.mModel = mr.getModel();
    }

    public static void usage() {
        System.err.println("java mill.me.MaxEntClassifier \\\n\t--train=<training samples> \\\n\t--test=<testing samples> \\\n\t--model=<model file> \\\n\t--iterations=<number of training iterations> \\\n\t--nil=<number of the NIL category>");
        System.exit(1);
    }

    public static void main(String[] args) throws Exception {
        String modelFile;
        CommandLineParameters.read(args);
        String log4j = CommandLineParameters.getString("log4j");
        PropertyConfigurator.configure((String)log4j);
        String trainFile = CommandLineParameters.getString("train");
        String testFile = CommandLineParameters.getString("test");
        if (trainFile == null && testFile == null) {
            MaxEntClassifier.usage();
        }
        if ((modelFile = CommandLineParameters.getString("model")) == null) {
            MaxEntClassifier.usage();
        }
        Integer iterations = CommandLineParameters.getInteger("iterations");
        Integer nil = CommandLineParameters.getInteger("nil");
        NodesFactory factory = new NodesFactory();
        List<Sample> trainSamples = null;
        if (trainFile != null) {
            trainSamples = Sample.readSamples(trainFile, factory);
        }
        List<Sample> testSamples = null;
        if (testFile != null) {
            testSamples = Sample.readSamples(testFile, factory);
        }
        MaxEntClassifier cls = new MaxEntClassifier();
        if (trainSamples != null) {
            MaxEntClassifierTrainingParameters pars = new MaxEntClassifierTrainingParameters(nil, iterations);
            cls.train(trainSamples, testSamples, pars);
            cls.saveModel(modelFile);
        } else if (testSamples != null) {
            cls.loadModel(modelFile, null);
            mLog.error((Object)"Started testing...");
            Score score = cls.test(testSamples, nil, null);
            mLog.error((Object)score);
            mLog.error((Object)"Done testing.");
        }
    }
}

