/*
 * Decompiled with CFR 0.152.
 */
package mill.common;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.ListIterator;
import mill.common.Kernel;
import mill.common.Nodes;
import mill.common.NodesFactory;
import mill.common.Outcome;
import mill.common.Sample;
import mill.common.Score;
import mill.common.TrainingParameters;
import org.apache.log4j.Logger;

/**
 * Base class for classifiers that are trained on labelled {@link Sample}s and
 * produce {@link Outcome}s for a {@link Nodes} instance.
 *
 * <p>Besides the abstract train/predict/persist operations, this class offers
 * {@link #test(List, Integer, Kernel)}, which runs the classifier over a list
 * of samples and accumulates an overall {@link Score} plus one score per gold
 * label (the per-label scores are only logged).
 */
public abstract class Classifier {
    static Logger mLog = Logger.getLogger(Classifier.class.getName());

    /** Developer switch: when {@code true}, every prediction made in {@code test} is logged. */
    private static final boolean DISPLAY_OUTCOMES = false;

    /** Developer switch: when {@code true}, every misclassified sample in {@code test} is logged. */
    private static final boolean LOG_ERRORS = false;

    /**
     * Trains this classifier.
     *
     * <p>NOTE(review): the parameter names were lost in decompilation. Presumably
     * {@code var1} holds the training samples and {@code var2} a second
     * (held-out) sample set; confirm against the concrete implementations.
     *
     * @param var1 first list of samples
     * @param var2 second list of samples
     * @param var3 training parameters
     */
    public abstract void train(List<Sample> var1, List<Sample> var2, TrainingParameters var3);

    /**
     * Convenience overload of {@link #predict(Nodes, List, Kernel)} that passes
     * {@code null} as the kernel.
     *
     * @param sample   the instance to classify
     * @param outcomes list handed to the three-argument overload
     */
    public void predict(Nodes sample, List<Outcome> outcomes) {
        this.predict(sample, outcomes, null);
    }

    /**
     * Classifies {@code sample}.
     *
     * <p>{@link #test(List, Integer, Kernel)} passes in an empty list and
     * afterwards treats its first element, if any, as the prediction, so
     * implementations are expected to add their outcomes to {@code outcomes}
     * with the preferred one first.
     *
     * @param sample   the instance to classify
     * @param outcomes list that receives the outcomes
     * @param kernel   kernel to use; may be {@code null} (the two-argument
     *                 overload always passes {@code null})
     */
    public abstract void predict(Nodes sample, List<Outcome> outcomes, Kernel kernel);

    /**
     * Persists the model.
     *
     * @param var1 presumably the destination file name; confirm against implementations
     * @throws IOException if the model cannot be written
     */
    public abstract void saveModel(String var1) throws IOException;

    /**
     * Restores a model.
     *
     * @param var1 presumably the source file name; confirm against implementations
     * @param var2 factory handed to the implementation
     * @throws IOException if the model cannot be read
     */
    public abstract void loadModel(String var1, NodesFactory var2) throws IOException;

    /**
     * Runs the classifier over {@code testSamples} and scores the predictions.
     *
     * <p>An overall score is accumulated over all samples. In addition, one
     * score per label in {@code 0..maxLabel} (excluding {@code nilCategory}) is
     * accumulated from the samples whose gold label is that label; these
     * per-label scores are logged but not returned.
     *
     * <p>Gold labels other than {@code nilCategory} are used as array indices,
     * so a negative non-nil label raises {@link ArrayIndexOutOfBoundsException}.
     *
     * @param testSamples samples to classify
     * @param nilCategory label to be treated as "nil", or {@code null} if there is none
     * @param kernel      kernel forwarded to {@link #predict(Nodes, List, Kernel)}
     * @return the overall score, after {@code compute} has been called on it
     */
    public Score test(List<Sample> testSamples, Integer nilCategory, Kernel kernel) {
        boolean hasNil = nilCategory != null;
        Score overall = new Score();

        // One slot per label 0..maxLabel; the nil label's slot stays null.
        int labelCount = maxLabel(testSamples) + 1;
        Score[] labelScores = new Score[labelCount];
        for (int label = 0; label < labelCount; ++label) {
            if (!isNil(label, nilCategory)) {
                labelScores[label] = new Score();
            }
        }

        for (Sample sample : testSamples) {
            List<Outcome> outcomes = new ArrayList<Outcome>();
            this.predict(sample.getNodes(), outcomes, kernel);
            int goldLabel = sample.getLabel();

            this.updateScore(overall, goldLabel, outcomes, nilCategory);
            if (!isNil(goldLabel, nilCategory)) {
                this.updateScore(labelScores[goldLabel], goldLabel, outcomes, nilCategory);
            }

            if (outcomes.isEmpty()) {
                mLog.error("Warning: found 0 outcomes!");
            } else {
                logPrediction(sample, goldLabel, outcomes.get(0));
            }
        }

        StringBuilder summary = new StringBuilder();
        for (int label = 0; label < labelCount; ++label) {
            if (isNil(label, nilCategory)) {
                continue;
            }
            labelScores[label].compute(hasNil);
            summary.append(label).append('(').append(labelScores[label]).append(") ");
        }
        mLog.warn("Label scores: " + summary);

        overall.compute(hasNil);
        return overall;
    }

    /** Returns the largest label among {@code samples}, or 0 if none exceeds 0. */
    private static int maxLabel(List<Sample> samples) {
        int max = 0;
        for (Sample sample : samples) {
            int label = sample.getLabel();
            if (label > max) {
                max = label;
            }
        }
        return max;
    }

    /** Tells whether a nil category is configured and {@code label} equals it. */
    private static boolean isNil(int label, Integer nilCategory) {
        return nilCategory != null && label == nilCategory.intValue();
    }

    /** Emits the optional per-sample diagnostics controlled by the developer switches. */
    private static void logPrediction(Sample sample, int goldLabel, Outcome best) {
        int predictedLabel = best.getLabel();
        if (DISPLAY_OUTCOMES) {
            mLog.warn(goldLabel + " " + predictedLabel + " " + best.getConfidence());
        }
        if (LOG_ERRORS && predictedLabel != goldLabel) {
            mLog.warn(goldLabel + "\t" + predictedLabel + "\t" + sample.getNodes());
        }
    }

    /**
     * Adds one sample's result to {@code score}, taking the first outcome as
     * the prediction.
     *
     * @param score       counters to update
     * @param goldLabel   the sample's true label
     * @param outcomes    outcomes produced for the sample
     * @param nilCategory nil label, or {@code null} to skip the non-nil counters
     */
    private void updateScore(Score score, int goldLabel, List<Outcome> outcomes, Integer nilCategory) {
        ++score.mTotal;
        if (outcomes.isEmpty()) {
            // NOTE(review): a sample without outcomes is counted in mTotal but never in
            // mTotalNonNil, even when its gold label is non-nil. Confirm this is intended.
            return;
        }

        int predictedLabel = outcomes.get(0).getLabel();
        boolean correct = goldLabel == predictedLabel;
        if (correct) {
            ++score.mCorrect;
        }
        if (nilCategory == null) {
            return;
        }

        int nil = nilCategory.intValue();
        if (correct && goldLabel != nil) {
            ++score.mCorrectNonNil;
        }
        if (predictedLabel != nil) {
            ++score.mPredictedNonNil;
        }
        if (goldLabel != nil) {
            ++score.mTotalNonNil;
        }
    }
}

