/*
 * Decompiled with CFR 0.152.
 */
package opennlp.maxent;

import gnu.trove.TDoubleFunction;
import gnu.trove.TIntDoubleProcedure;
import opennlp.maxent.DataIndexer;
import opennlp.maxent.EventStream;
import opennlp.maxent.GISModel;
import opennlp.maxent.OnePassDataIndexer;
import opennlp.maxent.TIntParamHashMap;

/**
 * Estimates the parameters of a {@link GISModel} with Generalized Iterative Scaling (GIS).
 *
 * <p>Per-iteration state (the current event, predicate and outcome indices and the
 * parameter/expectation tables) lives in instance fields because the Trove procedures
 * declared below read it while iterating over a predicate's entries. The class uses no
 * synchronization, so an instance must not be shared between threads while training.
 */
class GISTrainer {
    /** Stand-in for a zero observed count of the correction feature (avoids log(0)). */
    private static final double NEAR_ZERO = 0.01;
    /** Training stops once an iteration improves the log-likelihood by less than this. */
    private static final double LL_THRESHOLD = 1.0E-4;
    /** Maximum Newton-Raphson steps taken by {@link #updateParamsWithSmoothing}. */
    private static final int MAX_NEWTON_ITERATIONS = 50;
    /** Step size below which the Newton-Raphson iteration is considered converged. */
    private static final double NEWTON_TOLERANCE = 1.0E-6;

    /** When true, unseen (predicate, outcome) pairs also get a parameter and a pseudo-count. */
    private boolean _simpleSmoothing = false;
    /**
     * Enables the GIS correction ("slack") feature. Nothing in this class assigns it any
     * other value, so the correction code paths are currently inactive.
     */
    private boolean _useSlackParameter = false;
    /**
     * Prior scale used by {@link #updateParamsWithSmoothing}.
     * NOTE(review): that procedure is not invoked anywhere in this class.
     */
    private double sigma = 2.0;
    /** Observed-expectation pseudo-count given to unseen pairs under simple smoothing. */
    private double _smoothingObservation = 0.1;
    /** Whether progress messages are printed to standard output. */
    private boolean printMessages = false;

    /** Number of indexed events (each weighted by {@link #numTimesEventsSeen}). */
    private int numTokens;
    /** Number of predicates (context features). */
    private int numPreds;
    /** Number of outcome labels. */
    private int numOutcomes;
    /** Index of the event currently being processed; read by the Trove procedures. */
    private int TID;
    /** Index of the predicate currently being processed; read by the Trove procedures. */
    private int PID;
    /** Index of the outcome currently being processed. */
    private int OID;
    /** contexts[event] holds the indices of the predicates active in that event. */
    private int[][] contexts;
    /** outcomes[event] is the index of the event's observed outcome. */
    private int[] outcomes;
    /** numTimesEventsSeen[event] is the event's occurrence count, used as a weight. */
    private int[] numTimesEventsSeen;
    private String[] outcomeLabels;
    private String[] predLabels;
    /** Per predicate: outcome index -> observed (empirical, count-weighted) expectation. */
    private TIntParamHashMap[] observedExpects;
    /** Per predicate: outcome index -> current parameter value. */
    private TIntParamHashMap[] params;
    /** Per predicate: outcome index -> expectation under the current model. */
    private TIntParamHashMap[] modelExpects;
    /** GIS correction constant C: the largest number of active predicates in any event. */
    private int constant;
    /** 1.0 / constant; parameters are scaled by this in {@link #eval}. */
    private double constantInverse;
    /** Weight of the correction feature (only updated when the slack parameter is on). */
    private double correctionParam;
    /** Log of the correction feature's observed expectation. */
    private double cfObservedExpect;
    /** Correction feature's model expectation, accumulated during one iteration. */
    private double CFMOD;
    /** Scratch buffer that {@link #eval} fills with the normalized outcome distribution. */
    double[] modelDistribution;
    /** Scratch buffer: per outcome, how many active predicates carry a parameter for it. */
    int[] numfeats;
    /** log(1 / numOutcomes): the starting score of every outcome in {@link #eval}. */
    double iprob;

    /** Resets a value to zero; applied to the model expectations after every iteration. */
    private TDoubleFunction backToZeros = new TDoubleFunction() {
        public double execute(double value) {
            return 0.0;
        }
    };

    /**
     * For predicate {@code PID}, adds the current event's contribution to each outcome's
     * model expectation: the model probability of that outcome (from
     * {@link #modelDistribution}) times the occurrence count of event {@code TID}.
     */
    private TIntDoubleProcedure updateModelExpect = new TIntDoubleProcedure() {
        public boolean execute(int oid, double expect) {
            modelExpects[PID].put(oid, expect + modelDistribution[oid] * (double) numTimesEventsSeen[TID]);
            return true;
        }
    };

    /**
     * Applies the GIS update {@code param += log(observed) - log(expected)} to every outcome
     * of predicate {@code PID}. Because {@link #eval} scales parameters by 1/C, this is the
     * usual GIS step of (1/C) * log(observed / expected).
     * NOTE(review): an expected value of 0 (e.g. probability underflow) yields +Infinity.
     */
    private TIntDoubleProcedure updateParams = new TIntDoubleProcedure() {
        public boolean execute(int oid, double param) {
            double observed = observedExpects[PID].get(oid);
            double expected = modelExpects[PID].get(oid);
            params[PID].put(oid, param + (Math.log(observed) - Math.log(expected)));
            return true;
        }
    };

    /**
     * Alternative parameter update with a prior of scale {@link #sigma}: solves
     * {@code expected * exp(C * delta) + (param + delta) / sigma - observed = 0} for
     * {@code delta} by Newton-Raphson and stores {@code param + delta}.
     * NOTE(review): not referenced anywhere in this class.
     */
    private TIntDoubleProcedure updateParamsWithSmoothing = new TIntDoubleProcedure() {
        public boolean execute(int oid, double param) {
            double expected = modelExpects[PID].get(oid);
            double observed = observedExpects[PID].get(oid);
            double delta = 0.0;
            for (int step = 0; step < MAX_NEWTON_ITERATIONS; step++) {
                double scaled = expected * Math.exp((double) constant * delta);
                double f = scaled + (param + delta) / sigma - observed;
                double fPrime = scaled * (double) constant + 1.0 / sigma;
                if (fPrime == 0.0) {
                    break;
                }
                double next = delta - f / fPrime;
                boolean converged = Math.abs(next - delta) < NEWTON_TOLERANCE;
                delta = next;
                if (converged) {
                    break;
                }
            }
            params[PID].put(oid, param + delta);
            return true;
        }
    };

    GISTrainer() {
    }

    /**
     * @param printMessages whether training progress is printed to standard output
     */
    GISTrainer(boolean printMessages) {
        this();
        this.printMessages = printMessages;
    }

    /**
     * Enables or disables simple smoothing. When enabled, (predicate, outcome) pairs never
     * seen in training still receive a parameter, with the smoothing observation as their
     * observed expectation.
     */
    public void setSmoothing(boolean smoothing) {
        this._simpleSmoothing = smoothing;
    }

    /**
     * Sets the pseudo-count used as the observed expectation of unseen pairs when simple
     * smoothing is enabled.
     *
     * @param smoothingObservation a strictly positive pseudo-count
     * @throws IllegalArgumentException if the value is not strictly positive (or is NaN);
     *     training takes its logarithm, so such values would yield infinite/NaN parameters
     */
    public void setSmoothingObservation(double smoothingObservation) {
        if (!(smoothingObservation > 0.0)) {
            throw new IllegalArgumentException(
                    "smoothing observation must be > 0, got " + smoothingObservation);
        }
        this._smoothingObservation = smoothingObservation;
    }

    /**
     * Indexes {@code eventStream} with a {@link OnePassDataIndexer} and trains a model on it.
     *
     * @param eventStream the training events
     * @param iterations maximum number of GIS iterations
     * @param cutoff value passed to OnePassDataIndexer (presumably a minimum predicate
     *     frequency — confirm against that class)
     * @return the trained model
     */
    public GISModel trainModel(EventStream eventStream, int iterations, int cutoff) {
        return trainModel(iterations, new OnePassDataIndexer(eventStream, cutoff));
    }

    /**
     * Trains a model on already-indexed data.
     *
     * @param iterations maximum number of GIS iterations; training stops earlier when the
     *     log-likelihood improves by less than {@link #LL_THRESHOLD} or decreases
     * @param dataIndexer source of the contexts, outcomes, event counts and labels
     * @return a model built from the trained parameters, the predicate and outcome labels,
     *     the correction constant and the correction parameter
     * @throws IllegalArgumentException if the indexer contains no events
     */
    public GISModel trainModel(int iterations, DataIndexer dataIndexer) {
        display("Incorporating indexed data for training...  \n");
        contexts = dataIndexer.getContexts();
        outcomes = dataIndexer.getOutcomeList();
        numTimesEventsSeen = dataIndexer.getNumTimesEventsSeen();
        numTokens = contexts.length;
        if (numTokens == 0) {
            throw new IllegalArgumentException(
                    "Cannot train a GIS model: the data indexer contains no events");
        }

        // GIS needs a constant total feature count per event; C is the largest context size.
        constant = 0;
        for (TID = 0; TID < numTokens; TID++) {
            constant = Math.max(constant, contexts[TID].length);
        }
        constantInverse = 1.0 / (double) constant;
        display("done.\n");

        outcomeLabels = dataIndexer.getOutcomeLabels();
        numOutcomes = outcomeLabels.length;
        iprob = Math.log(1.0 / (double) numOutcomes);
        predLabels = dataIndexer.getPredLabels();
        numPreds = predLabels.length;
        display("\tNumber of Event Tokens: " + numTokens + "\n");
        display("\t    Number of Outcomes: " + numOutcomes + "\n");
        display("\t  Number of Predicates: " + numPreds + "\n");

        // Presumably dropped so the indexer can be collected during a long training run.
        dataIndexer = null;
        initParameterTables(countPredicateOutcomes());
        if (_useSlackParameter) {
            cfObservedExpect = computeCorrectionObservedExpect();
            correctionParam = 0.0;
        }
        display("...done.\n");

        modelDistribution = new double[numOutcomes];
        numfeats = new int[numOutcomes];
        display("Computing model parameters...\n");
        findParameters(iterations);
        return new GISModel(params, predLabels, outcomeLabels, constant, correctionParam);
    }

    /**
     * Returns predCount[pred][outcome]: the summed occurrence counts of the events that
     * contain predicate {@code pred} and have outcome {@code outcome}.
     */
    private int[][] countPredicateOutcomes() {
        int[][] predCount = new int[numPreds][numOutcomes];
        for (TID = 0; TID < numTokens; TID++) {
            int[] context = contexts[TID];
            int outcome = outcomes[TID];
            int seen = numTimesEventsSeen[TID];
            for (int j = 0; j < context.length; j++) {
                predCount[context[j]][outcome] += seen;
            }
        }
        return predCount;
    }

    /**
     * Allocates the parameter, model-expectation and observed-expectation tables. Each
     * (predicate, outcome) pair with a positive count gets an entry. When simple smoothing
     * is on, every other pair also gets one, with the smoothing observation as its count.
     */
    private void initParameterTables(int[][] predCount) {
        params = new TIntParamHashMap[numPreds];
        modelExpects = new TIntParamHashMap[numPreds];
        observedExpects = new TIntParamHashMap[numPreds];

        // Size each per-predicate map for the number of outcomes it is likely to hold.
        int initialCapacity;
        float loadFactor = 0.9f;
        if (numOutcomes < 3) {
            initialCapacity = 2;
            loadFactor = 1.0f;
        } else if (numOutcomes < 5) {
            initialCapacity = 2;
        } else {
            initialCapacity = numOutcomes / 2;
        }

        double smoothingObservation = _smoothingObservation;
        for (PID = 0; PID < numPreds; PID++) {
            params[PID] = new TIntParamHashMap(initialCapacity, loadFactor);
            modelExpects[PID] = new TIntParamHashMap(initialCapacity, loadFactor);
            observedExpects[PID] = new TIntParamHashMap(initialCapacity, loadFactor);
            for (OID = 0; OID < numOutcomes; OID++) {
                int count = predCount[PID][OID];
                if (count > 0) {
                    params[PID].put(OID, 0.0);
                    modelExpects[PID].put(OID, 0.0);
                    observedExpects[PID].put(OID, count);
                } else if (_simpleSmoothing) {
                    params[PID].put(OID, 0.0);
                    modelExpects[PID].put(OID, 0.0);
                    observedExpects[PID].put(OID, smoothingObservation);
                }
            }
            params[PID].compact();
            modelExpects[PID].compact();
            observedExpects[PID].compact();
        }
    }

    /**
     * Returns the log of the correction feature's observed expectation. Per event, the
     * count includes active predicates that have no entry for the event's outcome, plus
     * the slack (C - context size). Each event is weighted by its occurrence count.
     * Accumulated in a long to avoid int overflow on large corpora.
     */
    private double computeCorrectionObservedExpect() {
        long total = 0L;
        for (TID = 0; TID < numTokens; TID++) {
            int[] context = contexts[TID];
            int seen = numTimesEventsSeen[TID];
            for (int j = 0; j < context.length; j++) {
                PID = context[j];
                if (!modelExpects[PID].containsKey(outcomes[TID])) {
                    total += seen;
                }
            }
            total += (long) (constant - context.length) * seen;
        }
        return total == 0L ? Math.log(NEAR_ZERO) : Math.log((double) total);
    }

    /**
     * Runs up to {@code iterations} GIS iterations. It stops early if the log-likelihood
     * decreases (reported on stderr) or improves by less than {@link #LL_THRESHOLD}. The
     * training-only tables are released afterwards; {@link #params} is kept for the model.
     */
    private void findParameters(int iterations) {
        double prevLoglikelihood = 0.0;
        display("Performing " + iterations + " iterations.\n");
        for (int i = 1; i <= iterations; i++) {
            // Right-align the iteration number in a three-character column.
            if (i < 10) {
                display("  " + i + ":  ");
            } else if (i < 100) {
                display(" " + i + ":  ");
            } else {
                display(i + ":  ");
            }
            double loglikelihood = nextIteration();
            if (i > 1) {
                if (prevLoglikelihood > loglikelihood) {
                    System.err.println("Model Diverging: loglikelihood decreased");
                    break;
                }
                if (loglikelihood - prevLoglikelihood < LL_THRESHOLD) {
                    break;
                }
            }
            prevLoglikelihood = loglikelihood;
        }
        observedExpects = null;
        modelExpects = null;
        numTimesEventsSeen = null;
        contexts = null;
    }

    /**
     * Computes the model's outcome distribution for a context and writes it into
     * {@code outsums}. Each outcome starts at {@link #iprob}. Each active predicate adds
     * its parameter scaled by 1/C. The scores are then exponentiated, adjusted by the
     * correction feature when the slack parameter is on, and normalized to sum to 1.
     * Also records in {@link #numfeats} how many active predicates touched each outcome.
     * Requires {@link #numfeats}, which is allocated by {@code trainModel}.
     *
     * @param context indices of the active predicates
     * @param outsums output buffer of length at least numOutcomes
     */
    public void eval(int[] context, double[] outsums) {
        for (int oid = 0; oid < numOutcomes; oid++) {
            outsums[oid] = iprob;
            numfeats[oid] = 0;
        }
        for (int i = 0; i < context.length; i++) {
            TIntParamHashMap predParams = params[context[i]];
            int[] activeOutcomes = predParams.keys();
            for (int j = 0; j < activeOutcomes.length; j++) {
                int oid = activeOutcomes[j];
                numfeats[oid]++;
                outsums[oid] += constantInverse * predParams.get(oid);
            }
        }
        double sum = 0.0;
        for (int oid = 0; oid < numOutcomes; oid++) {
            outsums[oid] = Math.exp(outsums[oid]);
            if (_useSlackParameter) {
                outsums[oid] += (1.0 - (double) numfeats[oid] / (double) constant) * correctionParam;
            }
            sum += outsums[oid];
        }
        for (int oid = 0; oid < numOutcomes; oid++) {
            outsums[oid] /= sum;
        }
    }

    /**
     * Performs one GIS iteration. It first accumulates the model expectations and the
     * log-likelihood of the training data under the current parameters, then updates every
     * parameter and clears the model expectations for the next pass.
     *
     * @return the log-likelihood of the training data before this iteration's update
     */
    private double nextIteration() {
        double loglikelihood = 0.0;
        CFMOD = 0.0;
        int numEvents = 0;
        int numCorrect = 0;
        for (TID = 0; TID < numTokens; TID++) {
            int[] context = contexts[TID];
            int seen = numTimesEventsSeen[TID];
            eval(context, modelDistribution);
            for (int j = 0; j < context.length; j++) {
                PID = context[j];
                modelExpects[PID].forEachEntry(updateModelExpect);
                if (_useSlackParameter) {
                    for (OID = 0; OID < numOutcomes; OID++) {
                        if (!modelExpects[PID].containsKey(OID)) {
                            CFMOD += modelDistribution[OID] * (double) seen;
                        }
                    }
                }
            }
            if (_useSlackParameter) {
                CFMOD += (double) ((constant - context.length) * seen);
            }
            loglikelihood += Math.log(modelDistribution[outcomes[TID]]) * (double) seen;
            numEvents += seen;
            // Training accuracy is only tracked when it will be printed.
            if (printMessages) {
                int best = 0;
                for (OID = 1; OID < numOutcomes; OID++) {
                    if (modelDistribution[OID] > modelDistribution[best]) {
                        best = OID;
                    }
                }
                if (best == outcomes[TID]) {
                    numCorrect += seen;
                }
            }
        }
        display(".");
        for (PID = 0; PID < numPreds; PID++) {
            params[PID].forEachEntry(updateParams);
            modelExpects[PID].transformValues(backToZeros);
        }
        if (CFMOD > 0.0 && _useSlackParameter) {
            correctionParam += cfObservedExpect - Math.log(CFMOD);
        }
        display(". loglikelihood=" + loglikelihood + "\t" + (double) numCorrect / (double) numEvents + "\n");
        return loglikelihood;
    }

    /** Prints {@code string} to standard output when progress messages are enabled. */
    private void display(String string) {
        if (printMessages) {
            System.out.print(string);
        }
    }
}

