package org.nltk.mallet;

import edu.umass.cs.mallet.base.fst.CRF4;
import edu.umass.cs.mallet.base.fst.SimpleTagger;
import edu.umass.cs.mallet.base.fst.TransducerEvaluator;
import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.pipe.iterator.LineGroupIterator;
import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.FeatureSelection;
import edu.umass.cs.mallet.base.types.FeatureSequence;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.util.CommandOption;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectOutputStream;
import java.util.logging.ConsoleHandler;
import java.util.logging.LogRecord;
import java.util.regex.Pattern;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
import org.nltk.mallet.CRFInfo;

/* loaded from: input_file:org/nltk/mallet/TrainCRF.class */
/**
 * Command-line entry point that trains a MALLET {@link CRF4} from a training
 * file and stores the serialized model next to its configuration.
 *
 * <p>The model file is a zip archive that must already contain a
 * {@code crf-info.xml} entry. {@link #main(String[])} reads that entry, trains
 * a CRF as described by it, and rewrites the archive so that it contains the
 * original {@code crf-info.xml} bytes plus a {@code crf-model.ser} entry
 * holding the Java-serialized {@link CRF4}.
 */
public class TrainCRF {
    /** Name of the zip entry that holds the CRF configuration. */
    private static final String CRF_INFO_ENTRY = "crf-info.xml";

    /** Name of the zip entry that receives the serialized, trained CRF. */
    private static final String CRF_MODEL_ENTRY = "crf-model.ser";

    /** Size of the scratch buffer used when copying the configuration entry. */
    private static final int COPY_BUFFER_SIZE = 8192;

    private static final CommandOption.File trainFileOption = new CommandOption.File(TrainCRF.class, "train-file", "FILENAME", true, (File) null, "The filename for the training data.", (String) null);
    private static final CommandOption.File modelFileOption = new CommandOption.File(TrainCRF.class, "model-file", "FILENAME", true, (File) null, "The CRF model file, a zip file containing crf-info.xml.TrainCRF will add crf-model.ser to this file.", (String) null);
    private static final CommandOption.List commandOptions = new CommandOption.List("Train a CRF tagger.", new CommandOption[]{trainFileOption, modelFileOption});

    /**
     * Console handler that ignores the content of every log record and prints
     * a fixed marker line to {@code System.err} instead.
     */
    private static class MyHandler extends ConsoleHandler {
        private MyHandler() {
        }

        @Override
        public void publish(LogRecord logRecord) {
            System.err.println("DEBUG: OO");
            System.err.flush();
        }
    }

    /**
     * Builds and trains a CRF from the given training file.
     *
     * <p>The file is read through a {@link LineGroupIterator} whose group
     * boundary pattern is {@code ^\s*$}. States are taken from
     * {@code cRFInfo.stateInfoList} when it is non-null; otherwise they are
     * derived from the training data according to
     * {@code cRFInfo.stateStructure}. Optional weight groups from
     * {@code cRFInfo.weightGroupInfoList} restrict features by regular
     * expression.
     *
     * @param file    training data file
     * @param cRFInfo configuration (default label, prior variance,
     *                transduction type, states, weight groups, iterations)
     * @return the trained CRF
     * @throws FileNotFoundException if {@code file} cannot be opened
     * @throws RuntimeException      if {@code cRFInfo.stateStructure} is not
     *                               one of the known structure constants and
     *                               no explicit state list is given
     */
    public static CRF4 createCRF(File file, CRFInfo cRFInfo) throws FileNotFoundException {
        FileReader trainingReader = new FileReader(file);
        try {
            SimpleTagger.SimpleTaggerSentence2FeatureVectorSequence pipe = new SimpleTagger.SimpleTaggerSentence2FeatureVectorSequence();
            pipe.setTargetProcessing(true);
            // Look the default label up before any training data is piped through
            // (presumably so it is present in the target alphabet - confirm).
            pipe.getTargetAlphabet().lookupIndex(cRFInfo.defaultLabel);

            InstanceList trainingData = new InstanceList(pipe);
            trainingData.add(new LineGroupIterator(trainingReader, Pattern.compile("^\\s*$"), true));

            CRF4 crf = new CRF4(pipe, (Pipe) null);
            crf.setGaussianPriorVariance(cRFInfo.gaussianVariance);
            crf.setTransductionType(cRFInfo.transductionType);
            addStates(crf, cRFInfo, trainingData);

            if (cRFInfo.weightGroupInfoList != null) {
                for (CRFInfo.WeightGroupInfo weightGroupInfo : cRFInfo.weightGroupInfoList) {
                    crf.setFeatureSelection(crf.getWeightsIndex(weightGroupInfo.name), FeatureSelection.createFromRegex(crf.getInputAlphabet(), Pattern.compile(weightGroupInfo.featureSelectionRegex)));
                }
            }

            crf.train(trainingData, (InstanceList) null, (InstanceList) null, (TransducerEvaluator) null, cRFInfo.maxIterations);
            return crf;
        } finally {
            // The reader is only needed while this method runs. A failure to close
            // a read-only stream must not mask the training result or an earlier
            // exception, and the declared throws clause cannot carry IOException.
            try {
                trainingReader.close();
            } catch (IOException ignored) {
                // best-effort close
            }
        }
    }

    /** Adds states to {@code crf}: explicit ones if configured, else derived from the data. */
    private static void addStates(CRF4 crf, CRFInfo cRFInfo, InstanceList trainingData) {
        if (cRFInfo.stateInfoList != null) {
            for (CRFInfo.StateInfo stateInfo : cRFInfo.stateInfoList) {
                crf.addState(stateInfo.name, stateInfo.initialCost, stateInfo.finalCost, stateInfo.destinationNames, stateInfo.labelNames, stateInfo.weightNames);
            }
        } else if (cRFInfo.stateStructure == CRFInfo.FULLY_CONNECTED_STRUCTURE) {
            crf.addStatesForLabelsConnectedAsIn(trainingData);
        } else if (cRFInfo.stateStructure == CRFInfo.HALF_CONNECTED_STRUCTURE) {
            crf.addStatesForHalfLabelsConnectedAsIn(trainingData);
        } else if (cRFInfo.stateStructure == CRFInfo.THREE_QUARTERS_CONNECTED_STRUCTURE) {
            crf.addStatesForThreeQuarterLabelsConnectedAsIn(trainingData);
        } else if (cRFInfo.stateStructure == CRFInfo.BILABELS_STRUCTURE) {
            crf.addStatesForBiLabelsConnectedAsIn(trainingData);
        } else {
            throw new RuntimeException("Unexpected state structure " + cRFInfo.stateStructure);
        }
    }

    /**
     * Computes which label-to-label transitions occur in the targets of the
     * given instances.
     *
     * <p>Entry {@code [a][b]} of the result is {@code true} when label index
     * {@code b} immediately follows label index {@code a} in at least one
     * target sequence. If {@code str} is non-null, every entry in the row of
     * that label is additionally set to {@code true}.
     *
     * @param alphabet     alphabet used to map labels to indices; its size at
     *                     call time determines the matrix dimensions
     * @param instanceList instances whose targets are {@link FeatureSequence}s
     * @param str          optional label whose row is fully connected, or
     *                     {@code null}
     * @return a square boolean connection matrix
     */
    public boolean[][] labelConnectionsIn(Alphabet alphabet, InstanceList instanceList, String str) {
        int numLabels = alphabet.size();
        boolean[][] connections = new boolean[numLabels][numLabels];
        // NOTE(review): if lookupIndex can grow the alphabet for an unseen label,
        // the indices below could exceed numLabels - confirm against Alphabet.
        for (int i = 0; i < instanceList.size(); i++) {
            FeatureSequence target = (FeatureSequence) instanceList.getInstance(i).getTarget();
            for (int pos = 1; pos < target.size(); pos++) {
                int from = alphabet.lookupIndex(target.get(pos - 1));
                int to = alphabet.lookupIndex(target.get(pos));
                assert from >= 0 && to >= 0;
                connections[from][to] = true;
            }
        }
        if (str != null) {
            int startIndex = alphabet.lookupIndex(str);
            for (int j = 0; j < alphabet.size(); j++) {
                connections[startIndex][j] = true;
            }
        }
        return connections;
    }

    /**
     * Parses the command line, trains a CRF and rewrites the model archive.
     *
     * <p>The configuration entry is read completely into memory and the
     * archive is closed before training starts, so the archive is not held
     * open while it is being overwritten. The archive is only rewritten after
     * training has succeeded.
     *
     * @param strArr command-line arguments; {@code --train-file} and
     *               {@code --model-file} are required
     * @throws IllegalArgumentException if an argument is unexpected, a
     *                                  required option is missing, or the
     *                                  model archive has no
     *                                  {@code crf-info.xml} entry
     * @throws Exception                on I/O, parsing or training failures
     */
    public static void main(String[] strArr) throws Exception {
        int firstUnparsed = commandOptions.processOptions(strArr);
        if (firstUnparsed != strArr.length) {
            commandOptions.printUsage(true);
            throw new IllegalArgumentException("Unexpected arg " + strArr[firstUnparsed]);
        }
        if (trainFileOption.value == null) {
            commandOptions.printUsage(true);
            throw new IllegalArgumentException("Expected --train-file FILE");
        }
        if (modelFileOption.value == null) {
            commandOptions.printUsage(true);
            throw new IllegalArgumentException("Expected --model-file FILE");
        }

        byte[] infoBytes = readInfoEntry(modelFileOption.value);
        InputStream infoStream = new ByteArrayInputStream(infoBytes);
        CRFInfo cRFInfo = new CRFInfo(infoStream);

        CRF4 trainedCrf = createCRF(trainFileOption.value, cRFInfo);
        writeModel(modelFileOption.value, infoBytes, trainedCrf);
    }

    /** Reads the complete {@code crf-info.xml} entry of the model archive and closes the archive. */
    private static byte[] readInfoEntry(File modelFile) throws IOException {
        ZipFile zipFile = new ZipFile(modelFile);
        try {
            ZipEntry entry = zipFile.getEntry(CRF_INFO_ENTRY);
            if (entry == null) {
                throw new IllegalArgumentException("Model file " + modelFile + " does not contain " + CRF_INFO_ENTRY);
            }
            InputStream in = zipFile.getInputStream(entry);
            try {
                // Copy until end of stream: a single read() may return fewer bytes
                // than requested, and the entry size may be reported as unknown.
                ByteArrayOutputStream collected = new ByteArrayOutputStream();
                byte[] chunk = new byte[COPY_BUFFER_SIZE];
                int count;
                while ((count = in.read(chunk)) != -1) {
                    collected.write(chunk, 0, count);
                }
                return collected.toByteArray();
            } finally {
                in.close();
            }
        } finally {
            zipFile.close();
        }
    }

    /** Rewrites the model archive with the configuration bytes and the serialized CRF. */
    private static void writeModel(File modelFile, byte[] infoBytes, CRF4 crf) throws IOException {
        ZipOutputStream zipOut = new ZipOutputStream(new FileOutputStream(modelFile));
        try {
            zipOut.putNextEntry(new ZipEntry(CRF_INFO_ENTRY));
            zipOut.write(infoBytes);
            zipOut.closeEntry();

            zipOut.putNextEntry(new ZipEntry(CRF_MODEL_ENTRY));
            // Flush rather than close the object stream: closing it would also
            // close the underlying zip stream before the entry is finished.
            ObjectOutputStream objectOut = new ObjectOutputStream(zipOut);
            objectOut.writeObject(crf);
            objectOut.flush();
            zipOut.closeEntry();
        } finally {
            zipOut.close();
        }
    }
}
