/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.parser.lexparser.AbstractDependencyGrammar;
import edu.stanford.nlp.parser.lexparser.BasicCategoryTagProjection;
import edu.stanford.nlp.parser.lexparser.IntDependency;
import edu.stanford.nlp.parser.lexparser.IntTaggedWord;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.parser.lexparser.TagProjection;
import edu.stanford.nlp.parser.lexparser.Test;
import edu.stanford.nlp.parser.lexparser.TestTagProjection;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Numberer;
import edu.stanford.nlp.util.StringUtils;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A maximum-likelihood dependency grammar whose argument and stop
 * probabilities are smoothed by interpolating lexicalized counts with
 * tag-level back-off estimates.
 *
 * <p>Training events are accumulated in two counters:
 * <ul>
 *   <li>{@link #argCounter}: argument-attachment events, recorded at several
 *       granularities (head word+tag or head tag only, argument word+tag or
 *       argument tag only, wildcard-argument totals, and head-independent
 *       argument totals);</li>
 *   <li>{@link #stopCounter}: stop/continue events keyed by head and binned
 *       distance.</li>
 * </ul>
 * Only the most specific entries are persisted by {@code writeObject} and
 * {@link #writeData}; the coarser back-off entries are rebuilt on load.
 *
 * <p>NOTE(review): within this class a word index of {@code -2} is used for
 * STOP arguments and {@code -1} for tag-only (word-unspecified) entries. The
 * authoritative definitions presumably live in
 * {@code AbstractDependencyGrammar}; confirm there.
 */
public class MLEDependencyGrammar
extends AbstractDependencyGrammar {
    private static final boolean useSmoothTagProjection = false;
    private static final boolean useUnigramWordSmoothing = false;
    /** Number of argument events recorded through {@link #expandArg}. */
    protected int numWordTokens;
    /** Argument-attachment counts at every back-off granularity. */
    protected ClassicCounter<IntDependency> argCounter = new ClassicCounter<IntDependency>();
    /** Stop/continue counts keyed by head (word+tag or tag only) and distance bin. */
    protected ClassicCounter<IntDependency> stopCounter = new ClassicCounter<IntDependency>();
    /** Smoothing weight for P(argTag | head word+tag, direction, distance). */
    public double smooth_aT_hTWd = 32.0;
    /** Smoothing weight for P(argWord+tag | head word+tag, direction, distance). */
    public double smooth_aTW_hTWd = 16.0;
    /** Smoothing weight used by {@link #getStopProb}. */
    public double smooth_stop = 4.0;
    /** Interpolation weight between the fully lexicalized and the tag-factored estimate. */
    public double interp = 0.6;
    // The following weights are not read by the scoring code in this class;
    // tune() only reports them.
    public double smooth_aTW_aT = 96.0;
    public double smooth_aTW_hTd = 32.0;
    public double smooth_aT_hTd = 32.0;
    public double smooth_aPTW_aPT = 16.0;
    /**
     * Scratch result shared by every call of {@link #treeToDependencyHelper}.
     * Because it is a static mutable field, concurrent conversions are unsafe.
     */
    static transient EndHead tempEndHead = new EndHead();
    /** Lazily built cache of tag-only IntTaggedWords, indexed by tag bin + 2. */
    protected transient List<IntTaggedWord> tagITWList = null;
    private TagProjection smoothTP;
    private Numberer smoothTPNumberer;
    private static final String TP_PREFIX = ".*TP*.";
    private static final boolean verbose = false;
    /** Probabilities below this threshold are reported as 0 by {@link #probTB}. */
    protected static final double MIN_PROBABILITY = 1.0E-40;
    private static final long serialVersionUID = 1L;

    /**
     * Creates a grammar whose tag projection is chosen by
     * {@code LexicalizedParser.basicCategoryTagsInDependencyGrammar}.
     */
    public MLEDependencyGrammar(TreebankLangParserParams tlpParams, boolean directional, boolean distance, boolean coarseDistance) {
        this(LexicalizedParser.basicCategoryTagsInDependencyGrammar ? new BasicCategoryTagProjection(tlpParams.treebankLanguagePack()) : new TestTagProjection(), tlpParams, directional, distance, coarseDistance);
    }

    /**
     * Creates a grammar with an explicit tag projection. The first four
     * smoothing parameters are taken from
     * {@code tlpParams.MLEDependencyGrammarSmoothingParams()}.
     */
    public MLEDependencyGrammar(TagProjection tagProjection, TreebankLangParserParams tlpParams, boolean directional, boolean useDistance, boolean useCoarseDistance) {
        super(tlpParams.treebankLanguagePack(), tagProjection, directional, useDistance, useCoarseDistance);
        double[] smoothParams = tlpParams.MLEDependencyGrammarSmoothingParams();
        this.smooth_aT_hTWd = smoothParams[0];
        this.smooth_aTW_hTWd = smoothParams[1];
        this.smooth_stop = smoothParams[2];
        this.interp = smoothParams[3];
        this.smoothTP = new BasicCategoryTagProjection(tlpParams.treebankLanguagePack());
    }

    /** Returns a short summary: simple class name, tag-bin count and word-token count. */
    @Override
    public String toString() {
        String cl = this.getClass().getName();
        StringBuilder sb = new StringBuilder(2000);
        sb.append(cl.substring(cl.lastIndexOf('.') + 1)).append("[tagbins=");
        sb.append(this.numTagBins).append(",wordTokens=").append(this.numWordTokens).append("; head -> arg\n");
        sb.append("]");
        return sb.toString();
    }

    /**
     * Returns true when the tag of {@code argTW} is one of the treebank
     * language pack's punctuation tags.
     */
    public boolean pruneTW(IntTaggedWord argTW) {
        for (String punctTag : this.tlp.punctuationTags()) {
            if (argTW.tag == tagNumberer().number(punctTag)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Recursively converts a (binarized, head-annotated) tree span into
     * dependencies. For every binary node it appends one head-argument
     * dependency plus left and right STOP dependencies for the argument.
     *
     * @param tree    the subtree to convert; labels are expected to implement
     *                {@link HasTag} and {@link HasWord}
     * @param depList receives the generated dependencies
     * @param loc     index of the first word covered by {@code tree}
     * @return the shared {@link #tempEndHead}, holding the span's end (exclusive)
     *         and its head position
     */
    protected static EndHead treeToDependencyHelper(Tree tree, List<IntDependency> depList, int loc) {
        if (tree.isLeaf() || tree.isPreTerminal()) {
            tempEndHead.head = loc;
            tempEndHead.end = loc + 1;
            return tempEndHead;
        }
        Tree[] kids = tree.children();
        if (kids.length == 1) {
            return treeToDependencyHelper(kids[0], depList, loc);
        }
        // Left child: remember its head and where it ends (the split point).
        tempEndHead = treeToDependencyHelper(kids[0], depList, loc);
        int lHead = tempEndHead.head;
        int split = tempEndHead.end;
        // Right child starts at the split point.
        tempEndHead = treeToDependencyHelper(kids[1], depList, tempEndHead.end);
        int end = tempEndHead.end;
        int rHead = tempEndHead.head;

        String hTag = ((HasTag) tree.label()).tag();
        String lTag = ((HasTag) kids[0].label()).tag();
        String rTag = ((HasTag) kids[1].label()).tag();
        String hWord = ((HasWord) tree.label()).word();
        String lWord = ((HasWord) kids[0].label()).word();
        String rWord = ((HasWord) kids[1].label()).word();
        // The head child is the one whose word matches the parent's head word.
        boolean leftHeaded = hWord.equals(lWord);
        String aTag = leftHeaded ? rTag : lTag;
        String aWord = leftHeaded ? rWord : lWord;
        int hT = tagNumberer().number(hTag);
        int aT = tagNumberer().number(aTag);
        // Words never seen by the word numberer are mapped to "UNK".
        int hW = wordNumberer().hasSeen(hWord) ? wordNumberer().number(hWord) : wordNumberer().number("UNK");
        int aW = wordNumberer().hasSeen(aWord) ? wordNumberer().number(aWord) : wordNumberer().number("UNK");
        int head = leftHeaded ? lHead : rHead;
        int arg = leftHeaded ? rHead : lHead;

        depList.add(new IntDependency(hW, hT, aW, aT, leftHeaded, leftHeaded ? split - head - 1 : head - split));
        depList.add(new IntDependency(aW, aT, -2, -2, false, leftHeaded ? arg - split : arg - loc));
        depList.add(new IntDependency(aW, aT, -2, -2, true, leftHeaded ? end - arg - 1 : split - arg - 1));
        tempEndHead.head = head;
        return tempEndHead;
    }

    /** Prints the number of entries in the argument and stop counters to stdout. */
    public void dumpSizes() {
        System.out.println("arg counter " + this.argCounter.size());
        System.out.println("stop counter " + this.stopCounter.size());
    }

    /**
     * Converts a tree into its list of dependencies (including STOP events).
     * Not thread-safe: relies on the shared static {@link #tempEndHead}.
     */
    public static List<IntDependency> treeToDependencyList(Tree tree) {
        ArrayList<IntDependency> depList = new ArrayList<IntDependency>();
        treeToDependencyHelper(tree, depList, 0);
        return depList;
    }

    /**
     * Sums {@code score(d)} over {@code deps}, skipping scores that are
     * negative infinity or NaN.
     */
    public double scoreAll(Collection<IntDependency> deps) {
        double totalScore = 0.0;
        for (IntDependency d : deps) {
            double score = this.score(d);
            // Also rejects NaN, since NaN > x is always false.
            if (score > Double.NEGATIVE_INFINITY) {
                totalScore += score;
            }
        }
        return totalScore;
    }

    /**
     * Grid-searches the smoothing parameters against held-out trees.
     * First {@link #smooth_stop} is tuned on stop/continue decisions; then
     * {@link #smooth_aTW_hTWd}, {@link #smooth_aT_hTWd} and {@link #interp}
     * are tuned jointly on argument attachments via {@link #scoreAll}.
     * The best values found are left in the fields.
     */
    @Override
    public void tune(Collection<Tree> trees) {
        ArrayList<IntDependency> deps = new ArrayList<IntDependency>();
        for (Tree tree : trees) {
            deps.addAll(treeToDependencyList(tree));
        }

        // Phase 1: smooth_stop, maximizing log-likelihood of stop/continue choices.
        System.err.println("Tuning smooth_stop...");
        double bestScore = Double.NEGATIVE_INFINITY;
        double bestSmooth_stop = 0.0;
        for (this.smooth_stop = 0.01; this.smooth_stop < 100.0; this.smooth_stop *= 1.25) {
            double totalScore = 0.0;
            for (IntDependency dep : deps) {
                if (rootTW(dep.head)) {
                    continue;
                }
                double stopProb = this.getStopProb(dep);
                if (!dep.arg.equals(stopTW)) {
                    // A real argument means the head chose to continue.
                    stopProb = 1.0 - stopProb;
                }
                if (stopProb > 0.0) {
                    totalScore += Math.log(stopProb);
                }
            }
            if (totalScore > bestScore) {
                bestScore = totalScore;
                bestSmooth_stop = this.smooth_stop;
            }
        }
        this.smooth_stop = bestSmooth_stop;
        System.err.println("Tuning selected smooth_stop: " + this.smooth_stop);

        // Phase 2 only scores argument attachments, so drop the STOP events.
        for (Iterator<IntDependency> it = deps.iterator(); it.hasNext(); ) {
            if (it.next().arg.equals(stopTW)) {
                it.remove();
            }
        }

        System.err.println("Tuning other parameters...");
        bestScore = Double.NEGATIVE_INFINITY;
        double bestSmooth_aTW_hTWd = 0.0;
        double bestSmooth_aT_hTWd = 0.0;
        double bestInterp = 0.0;
        for (this.smooth_aTW_hTWd = 0.5; this.smooth_aTW_hTWd < 100.0; this.smooth_aTW_hTWd *= 1.25) {
            System.err.print(".");
            for (this.smooth_aT_hTWd = 0.5; this.smooth_aT_hTWd < 100.0; this.smooth_aT_hTWd *= 1.25) {
                for (this.interp = 0.02; this.interp < 1.0; this.interp += 0.02) {
                    double totalScore = this.scoreAll(deps);
                    if (totalScore > bestScore) {
                        bestScore = totalScore;
                        bestInterp = this.interp;
                        bestSmooth_aTW_hTWd = this.smooth_aTW_hTWd;
                        bestSmooth_aT_hTWd = this.smooth_aT_hTWd;
                        System.err.println("Current best interp: " + this.interp + " with score " + totalScore);
                    }
                }
            }
        }
        this.smooth_aTW_hTWd = bestSmooth_aTW_hTWd;
        this.smooth_aT_hTWd = bestSmooth_aT_hTWd;
        this.interp = bestInterp;
        System.err.println("\nTuning selected smooth_aTW_hTWd: " + this.smooth_aTW_hTWd + " smooth_aT_hTWd: " + this.smooth_aT_hTWd + " interp: " + this.interp + " smooth_aTW_aT: " + this.smooth_aTW_aT + " smooth_aTW_hTd: " + this.smooth_aTW_hTd + " smooth_aT_hTd: " + this.smooth_aT_hTd);
    }

    /**
     * Records {@code count} observations of {@code dependency}. For a
     * non-directional grammar the dependency's direction is forced to
     * {@code false} first (this mutates the argument).
     */
    public void addRule(IntDependency dependency, double count) {
        if (!this.directional) {
            dependency.leftHeaded = false;
        }
        this.expandDependency(dependency, count);
    }

    /**
     * Returns a cached IntTaggedWord with word {@code -1} and the binned
     * {@code tag}, creating the cache on first use.
     */
    private IntTaggedWord getCachedITW(short tag) {
        if (this.tagITWList == null) {
            int size = this.numTagBins + 2;
            this.tagITWList = new ArrayList<IntTaggedWord>(size);
            for (int i = 0; i < size; ++i) {
                this.tagITWList.add(null);
            }
        }
        // Offset by 2 so that negative (special) tag bins still index the list.
        int slot = this.tagBin(tag) + 2;
        IntTaggedWord cached = this.tagITWList.get(slot);
        if (cached == null) {
            cached = new IntTaggedWord(-1, this.tagBin(tag));
            this.tagITWList.set(slot, cached);
        }
        return cached;
    }

    /**
     * Adds argument counts (for non-STOP arguments) and stop counts for one
     * dependency. Dependencies with a null head or argument are ignored.
     */
    protected void expandDependency(IntDependency dependency, double count) {
        if (dependency.head == null || dependency.arg == null) {
            return;
        }
        if (dependency.arg.word != -2) {
            this.expandArg(dependency, this.valenceBin(dependency.distance), count);
        }
        this.expandStop(dependency, this.distanceBin(dependency.distance), count, true);
    }

    /** Maps a tag through the smoothing tag projection; negative tags pass through. */
    private short tagProject(short tag) {
        if (this.smoothTPNumberer == null) {
            this.smoothTPNumberer = new Numberer(tagNumberer());
        }
        if (tag < 0) {
            return tag;
        }
        String tagStr = (String) this.smoothTPNumberer.object(tag);
        String binStr = TP_PREFIX + this.smoothTP.project(tagStr);
        return (short) this.smoothTPNumberer.number(binStr);
    }

    /**
     * Adds {@code count} to every back-off granularity of an argument event:
     * head word+tag / head tag only crossed with argument word+tag / argument
     * tag only / wildcard argument, plus two head-independent entries
     * (wildcard head, direction false, distance -1).
     */
    private void expandArg(IntDependency dependency, short valBinDist, double count) {
        IntTaggedWord headT = this.getCachedITW(dependency.head.tag);
        IntTaggedWord argT = this.getCachedITW(dependency.arg.tag);
        IntTaggedWord head = new IntTaggedWord(dependency.head.word, this.tagBin(dependency.head.tag));
        IntTaggedWord arg = new IntTaggedWord(dependency.arg.word, this.tagBin(dependency.arg.tag));
        boolean leftHeaded = dependency.leftHeaded;
        this.argCounter.incrementCount(this.intern(head, arg, leftHeaded, valBinDist), count);
        this.argCounter.incrementCount(this.intern(headT, arg, leftHeaded, valBinDist), count);
        this.argCounter.incrementCount(this.intern(head, argT, leftHeaded, valBinDist), count);
        this.argCounter.incrementCount(this.intern(headT, argT, leftHeaded, valBinDist), count);
        this.argCounter.incrementCount(this.intern(head, wildTW, leftHeaded, valBinDist), count);
        this.argCounter.incrementCount(this.intern(headT, wildTW, leftHeaded, valBinDist), count);
        this.argCounter.incrementCount(this.intern(wildTW, arg, false, (short) -1), count);
        this.argCounter.incrementCount(this.intern(wildTW, argT, false, (short) -1), count);
        ++this.numWordTokens;
    }

    /**
     * Adds {@code count} to the stop counter. STOP arguments (word -2) are
     * recorded against the head word+tag and head tag. Totals with a wildcard
     * argument are recorded when {@code wildForStop} is true or the argument is
     * not STOP.
     */
    private void expandStop(IntDependency dependency, short distBinDist, double count, boolean wildForStop) {
        IntTaggedWord headT = this.getCachedITW(dependency.head.tag);
        IntTaggedWord head = new IntTaggedWord(dependency.head.word, this.tagBin(dependency.head.tag));
        IntTaggedWord arg = new IntTaggedWord(dependency.arg.word, this.tagBin(dependency.arg.tag));
        boolean leftHeaded = dependency.leftHeaded;
        if (arg.word == -2) {
            this.stopCounter.incrementCount(this.intern(head, arg, leftHeaded, distBinDist), count);
            this.stopCounter.incrementCount(this.intern(headT, arg, leftHeaded, distBinDist), count);
        }
        if (wildForStop || arg.word != -2) {
            this.stopCounter.incrementCount(this.intern(head, wildTW, leftHeaded, distBinDist), count);
            this.stopCounter.incrementCount(this.intern(headT, wildTW, leftHeaded, distBinDist), count);
        }
    }

    /**
     * Returns the argument-counter total for this dependency's head (binned
     * tag) and direction at its valence-binned distance, with a wildcard
     * argument. The dependency is temporarily mutated and then restored.
     */
    public double countHistory(IntDependency dependency) {
        short hTBackup = dependency.head.tag;
        IntTaggedWord aTWBackup = dependency.arg;
        short dist = dependency.distance;
        dependency.head.tag = (short) this.tagBin(dependency.head.tag);
        dependency.distance = this.valenceBin(dist);
        dependency.arg = wildTW;
        double s = this.argCounter.getCount(dependency);
        dependency.head.tag = hTBackup;
        dependency.arg = aTWBackup;
        dependency.distance = dist;
        return s;
    }

    /** Returns {@code Test.depWeight * log(probTB(dependency))}. */
    @Override
    public double scoreTB(IntDependency dependency) {
        return Test.depWeight * Math.log(this.probTB(dependency));
    }

    /**
     * Returns the smoothed probability of {@code dependency}.
     *
     * <p>For a STOP argument this is P(stop | head, direction, distance). Otherwise
     * it is P(go) times an interpolation (weight {@link #interp}) of
     * <ul>
     *   <li>P(argWord+tag | head word+tag), backed off to the head-tag-only estimate, and</li>
     *   <li>P(argWord | argTag) * P(argTag | head word+tag), the latter backed off likewise.</li>
     * </ul>
     * Roots never stop. When {@code Test.prunePunc} is set and the argument is
     * punctuation, 1.0 is returned. NaN results and results below
     * {@link #MIN_PROBABILITY} are returned as 0.
     *
     * <p>The dependency (and its head/argument IntTaggedWords) is temporarily
     * mutated to probe the counters; its fields are restored before returning,
     * except that a non-directional grammar forces {@code leftHeaded} to false.
     */
    protected double probTB(IntDependency dependency) {
        if (!this.directional) {
            dependency.leftHeaded = false;
        }
        boolean leftHeaded = dependency.leftHeaded;
        short distance = dependency.distance;
        int hW = dependency.head.word;
        int aW = dependency.arg.word;
        IntTaggedWord aTW = dependency.arg;
        IntTaggedWord hTW = dependency.head;

        double pb_stop_hTWds = rootTW(dependency.head) ? 0.0 : this.getStopProb(dependency);
        if (dependency.arg.word == -2) {
            return pb_stop_hTWds;
        }
        double pb_go_hTWds = 1.0 - pb_stop_hTWds;

        short binDistance = this.valenceBin(distance);
        dependency.distance = binDistance;
        // Snapshot used to verify that every probe below restores the dependency.
        IntDependency copy = new IntDependency(dependency.head, dependency.arg, dependency.leftHeaded, dependency.distance);

        // Counts conditioned on the full head (word + tag).
        double c_aTW_hTWd = this.argCounter.getCount(dependency);
        dependency.arg.word = -1;
        double c_aT_hTWd = this.argCounter.getCount(dependency);
        dependency.arg.word = aW;
        dependency.arg = wildTW;
        double c_hTWd = this.argCounter.getCount(dependency);
        dependency.arg = aTW;
        requireRestored(dependency, copy);

        // The same counts, backing the head off to its tag only.
        dependency.head.word = -1;
        double c_aTW_hTd = this.argCounter.getCount(dependency);
        dependency.arg.word = -1;
        double c_aT_hTd = this.argCounter.getCount(dependency);
        dependency.arg.word = aW;
        dependency.arg = wildTW;
        double c_hTd = this.argCounter.getCount(dependency);
        dependency.arg = aTW;
        dependency.head.word = hW;
        requireRestored(dependency, copy);

        // Head-independent argument counts (wildcard head, direction false, distance -1).
        dependency.head = wildTW;
        dependency.leftHeaded = false;
        dependency.distance = (short) -1;
        double c_aTW = this.argCounter.getCount(dependency);
        dependency.arg.word = -1;
        double c_aT = this.argCounter.getCount(dependency);
        dependency.arg.word = aW;
        dependency.head = hTW;
        dependency.leftHeaded = leftHeaded;
        dependency.distance = binDistance;
        requireRestored(dependency, copy);
        dependency.distance = distance;

        double p_aTW_aT;
        if (Test.useLexiconToScoreDependencyPwGt) {
            p_aTW_aT = Math.exp(this.lex.score(dependency.arg, leftHeaded ? 1 : -1));
        } else {
            p_aTW_aT = c_aTW > 0.0 ? c_aTW / c_aT : 1.0;
        }
        double p_aTW_hTd = c_hTd > 0.0 ? c_aTW_hTd / c_hTd : 0.0;
        double p_aT_hTd = c_hTd > 0.0 ? c_aT_hTd / c_hTd : 0.0;
        double pb_aTW_hTWd = (c_aTW_hTWd + this.smooth_aTW_hTWd * p_aTW_hTd) / (c_hTWd + this.smooth_aTW_hTWd);
        double pb_aT_hTWd = (c_aT_hTWd + this.smooth_aT_hTWd * p_aT_hTd) / (c_hTWd + this.smooth_aT_hTWd);
        double score = (this.interp * pb_aTW_hTWd + (1.0 - this.interp) * p_aTW_aT * pb_aT_hTWd) * pb_go_hTWds;
        if (Test.prunePunc && this.pruneTW(aTW)) {
            return 1.0;
        }
        if (Double.isNaN(score) || score < MIN_PROBABILITY) {
            score = 0.0;
        }
        return score;
    }

    /** Throws if a probe in {@link #probTB} failed to restore the dependency. */
    private static void requireRestored(IntDependency dependency, IntDependency copy) {
        if (!dependency.equals(copy)) {
            throw new RuntimeException("Dependencies not equal: " + dependency + " and " + copy);
        }
    }

    /**
     * Returns the smoothed probability of generating STOP given this
     * dependency's head, direction and distance bin:
     * {@code (c(stop, hTW) + smooth_stop * p(stop | hT)) / (c(hTW) + smooth_stop)}.
     * The tag-only estimate defaults to 1 when its history count is zero.
     * The dependency is temporarily mutated and restored.
     */
    protected double getStopProb(IntDependency dependency) {
        int hW = dependency.head.word;
        IntTaggedWord aTW = dependency.arg;
        short distance = dependency.distance;
        dependency.distance = this.distanceBin(distance);
        dependency.arg = stopTW;
        double c_stop_hTWds = this.stopCounter.getCount(dependency);
        dependency.head.word = -1;
        double c_stop_hTds = this.stopCounter.getCount(dependency);
        dependency.head.word = hW;
        dependency.arg = wildTW;
        double c_hTWds = this.stopCounter.getCount(dependency);
        dependency.head.word = -1;
        double c_hTds = this.stopCounter.getCount(dependency);
        dependency.head.word = hW;
        dependency.arg = aTW;
        dependency.distance = distance;
        double p_stop_hTds = c_hTds > 0.0 ? c_stop_hTds / c_hTds : 1.0;
        return (c_stop_hTWds + this.smooth_stop * p_stop_hTds) / (c_hTWds + this.smooth_stop);
    }

    /** True for argCounter entries that writeObject/writeData omit because they are rebuilt by expandArg. */
    private boolean isDerivedArgEntry(IntDependency dependency) {
        return dependency.head == wildTW || dependency.arg == wildTW || dependency.head.word == -1 || dependency.arg.word == -1;
    }

    /** True for stopCounter entries that writeObject/writeData omit because they are rebuilt by expandStop. */
    private boolean isDerivedStopEntry(IntDependency dependency) {
        return dependency.head.word == -1;
    }

    /**
     * Reads the compressed counters written by {@link #writeObject} and
     * re-expands them into all back-off granularities.
     */
    private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
        stream.defaultReadObject();
        ClassicCounter<IntDependency> compressedArgC = this.argCounter;
        this.argCounter = new ClassicCounter<IntDependency>();
        ClassicCounter<IntDependency> compressedStopC = this.stopCounter;
        this.stopCounter = new ClassicCounter<IntDependency>();
        for (IntDependency d : compressedArgC.keySet()) {
            this.expandArg(d, d.distance, compressedArgC.getCount(d));
        }
        for (IntDependency d : compressedStopC.keySet()) {
            this.expandStop(d, d.distance, compressedStopC.getCount(d), false);
        }
        this.expandDependencyMap = null;
    }

    /**
     * Serializes only the most specific counter entries; back-off entries are
     * rebuilt by {@link #readObject}. The full in-memory counters are restored
     * afterwards even if serialization fails.
     */
    private void writeObject(ObjectOutputStream stream) throws IOException {
        ClassicCounter<IntDependency> fullArgCounter = this.argCounter;
        ClassicCounter<IntDependency> fullStopCounter = this.stopCounter;
        try {
            this.argCounter = new ClassicCounter<IntDependency>();
            for (IntDependency dependency : fullArgCounter.keySet()) {
                if (!this.isDerivedArgEntry(dependency)) {
                    this.argCounter.incrementCount(dependency, fullArgCounter.getCount(dependency));
                }
            }
            this.stopCounter = new ClassicCounter<IntDependency>();
            for (IntDependency dependency : fullStopCounter.keySet()) {
                if (!this.isDerivedStopEntry(dependency)) {
                    this.stopCounter.incrementCount(dependency, fullStopCounter.getCount(dependency));
                }
            }
            stream.defaultWriteObject();
        } finally {
            // Never leave the live grammar holding the compressed counters.
            this.argCounter = fullArgCounter;
            this.stopCounter = fullStopCounter;
        }
    }

    /**
     * Reads counts in the text format produced by {@link #writeData}. Each
     * line is split on spaces (double-quote quoting, backslash escaping) into:
     * head, (unused), argument, direction ({@code "left"} or not), distance,
     * count. A {@code BEGIN_STOP} line switches from argument to stop counts.
     * Reading ends at end of stream or the first empty line.
     *
     * @throws IOException on a read failure or a malformed line; the message
     *         gives the 1-based line number and the cause is preserved
     */
    @Override
    public void readData(BufferedReader in) throws IOException {
        final String LEFT = "left";
        boolean doingStop = false;
        IntDependency tempDependency = new IntDependency(-2, -2, -2, -2, false, 0);
        int lineNum = 1;
        for (String line = in.readLine(); line != null && line.length() > 0; line = in.readLine(), ++lineNum) {
            if (line.equals("BEGIN_STOP")) {
                doingStop = true;
                continue;
            }
            try {
                String[] fields = StringUtils.splitOnCharWithQuoting(line, ' ', '\"', '\\');
                tempDependency.leftHeaded = fields[3].equals(LEFT);
                short distance = (short) Integer.parseInt(fields[4]);
                tempDependency.head = new IntTaggedWord(fields[0], '/');
                tempDependency.arg = new IntTaggedWord(fields[2], '/');
                double count = Double.parseDouble(fields[5]);
                if (doingStop) {
                    this.expandStop(tempDependency, distance, count, false);
                } else {
                    this.expandArg(tempDependency, distance, count);
                }
            } catch (Exception e) {
                throw new IOException("Error on line " + lineNum + ": " + line, e);
            }
        }
    }

    /**
     * Writes the most specific argument counts, a {@code BEGIN_STOP} line,
     * then the most specific stop counts, one {@code "<dependency> <count>"}
     * per line; back-off entries are omitted and rebuilt by {@link #readData}.
     */
    @Override
    public void writeData(PrintWriter out2) throws IOException {
        for (IntDependency dependency : this.argCounter.keySet()) {
            if (!this.isDerivedArgEntry(dependency)) {
                out2.println(dependency + " " + this.argCounter.getCount(dependency));
            }
        }
        out2.println("BEGIN_STOP");
        for (IntDependency dependency : this.stopCounter.keySet()) {
            if (!this.isDerivedStopEntry(dependency)) {
                out2.println(dependency + " " + this.stopCounter.getCount(dependency));
            }
        }
        out2.flush();
    }

    /** Mutable (end, head) pair returned by {@link #treeToDependencyHelper}. */
    static class EndHead {
        /** Index one past the last word of the span. */
        public int end;
        /** Index of the span's head word. */
        public int head;

        EndHead() {
        }
    }
}

