/*
 * Decompiled with CFR 0.152.
 */
package opennlp.tools.ml.naivebayes;

import java.io.IOException;
import opennlp.tools.ml.AbstractEventTrainer;
import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Context;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.EvalParameters;
import opennlp.tools.ml.model.MutableContext;
import opennlp.tools.ml.naivebayes.NaiveBayesEvalParameters;
import opennlp.tools.ml.naivebayes.NaiveBayesModel;
import opennlp.tools.util.Parameters;
import opennlp.tools.util.TrainingParameters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trains a {@link NaiveBayesModel} from indexed events.
 * <p>
 * Training is a single counting pass over the indexed events. For every predicate and outcome,
 * the trainer accumulates how often the predicate occurred in events labelled with that outcome.
 * When the indexer supplies feature values, each occurrence contributes its value instead of 1.
 * The accumulated counts become the model parameters, and the accuracy that they achieve on the
 * training events is logged.
 * <p>
 * {@link #trainModel(DataIndexer)} stores the indexed data in instance fields, so one instance
 * should not be used for concurrent training runs.
 */
public class NaiveBayesTrainer
extends AbstractEventTrainer<TrainingParameters> {
    private static final Logger logger = LoggerFactory.getLogger(NaiveBayesTrainer.class);

    /** The constant {@code "NAIVEBAYES"}; presumably the trainer-type name used in training parameters. */
    public static final String NAIVE_BAYES_VALUE = "NAIVEBAYES";

    // Training data copied from the DataIndexer by trainModel().
    private int numUniqueEvents;      // number of rows in contexts
    private int numEvents;            // DataIndexer#getNumEvents(); denominator of the training accuracy
    private int numPreds;             // number of predicate labels
    private int numOutcomes;          // number of outcome labels
    private int[][] contexts;         // per unique event: indices of its predicates
    private float[][] values;         // per unique event: predicate values; may be null
    private int[] outcomeList;        // per unique event: index of its outcome
    private int[] numTimesEventsSeen; // per unique event: how many times it occurs
    private String[] outcomeLabels;
    private String[] predLabels;

    /** Creates a trainer without explicit training parameters. */
    public NaiveBayesTrainer() {
    }

    /**
     * Creates a trainer configured with the given parameters.
     *
     * @param parameters the training parameters handed to the superclass
     */
    public NaiveBayesTrainer(TrainingParameters parameters) {
        super(parameters);
    }

    /**
     * Returns {@code false}, so events are not requested to be sorted and merged.
     * The counting pass still honours {@code numTimesEventsSeen}, so it is correct either way.
     */
    public boolean isSortAndMerge() {
        return false;
    }

    /**
     * Trains a model from the given indexer by delegating to {@link #trainModel(DataIndexer)}.
     *
     * @param indexer the indexed training events
     * @return the trained naive Bayes model
     * @throws IOException declared by the trainer contract; not thrown by this implementation
     */
    public AbstractModel doTrain(DataIndexer<TrainingParameters> indexer) throws IOException {
        return this.trainModel(indexer);
    }

    /**
     * Copies the indexed data into this trainer, computes the per-predicate outcome counts and
     * wraps them in a {@link NaiveBayesModel}.
     *
     * @param di the indexed training events
     * @return the trained model
     */
    public AbstractModel trainModel(DataIndexer<TrainingParameters> di) {
        logger.info("Incorporating indexed data for training...");
        this.contexts = di.getContexts();
        this.values = di.getValues();
        this.numTimesEventsSeen = di.getNumTimesEventsSeen();
        this.numEvents = di.getNumEvents();
        this.numUniqueEvents = this.contexts.length;
        this.outcomeLabels = di.getOutcomeLabels();
        this.outcomeList = di.getOutcomeList();
        this.predLabels = di.getPredLabels();
        this.numPreds = this.predLabels.length;
        this.numOutcomes = this.outcomeLabels.length;
        logger.info("done.");
        logger.info("\tNumber of Event Tokens: {} \n\t Number of Outcomes: {} \n\t Number of Predicates: {}",
            this.numUniqueEvents, this.numOutcomes, this.numPreds);
        logger.info("Computing model parameters...");
        MutableContext[] finalParameters = this.findParameters();
        logger.info("...done.");
        return new NaiveBayesModel(finalParameters, this.predLabels, this.outcomeLabels);
    }

    /**
     * Accumulates, for every predicate, the (optionally value-weighted) number of times it was
     * seen with each outcome, then logs the training accuracy of those counts.
     *
     * @return one context per predicate, holding a count for every outcome
     */
    private MutableContext[] findParameters() {
        // Every predicate gets a parameter slot for every outcome, in outcome-index order.
        int[] allOutcomesPattern = new int[this.numOutcomes];
        for (int oi = 0; oi < this.numOutcomes; ++oi) {
            allOutcomesPattern[oi] = oi;
        }
        MutableContext[] params = new MutableContext[this.numPreds];
        for (int pi = 0; pi < this.numPreds; ++pi) {
            // A freshly allocated double[] is zero-filled, so every count starts at 0.0.
            params[pi] = new MutableContext(allOutcomesPattern, new double[this.numOutcomes]);
        }

        final double stepSize = 1.0;
        for (int ei = 0; ei < this.numUniqueEvents; ++ei) {
            int targetOutcome = this.outcomeList[ei];
            int[] eventContext = this.contexts[ei];
            for (int ni = 0; ni < this.numTimesEventsSeen[ei]; ++ni) {
                for (int ci = 0; ci < eventContext.length; ++ci) {
                    double increment = this.values == null
                        ? stepSize
                        : stepSize * this.values[ei][ci];
                    params[eventContext[ci]].updateParameter(targetOutcome, increment);
                }
            }
        }

        // Outcome totals must be summed after counting; before it every parameter is still 0.0,
        // which would give the evaluation an all-zero outcome distribution.
        double[] outcomeTotals = this.computeOutcomeTotals(params);
        NaiveBayesEvalParameters evalParams = new NaiveBayesEvalParameters(
            params, this.numOutcomes, outcomeTotals, this.numPreds);
        this.trainingStats(evalParams);
        return params;
    }

    /**
     * Sums, for each outcome, the counts stored for that outcome across all predicates.
     *
     * @param params the per-predicate counts
     * @return an array indexed by outcome holding the summed counts
     */
    private double[] computeOutcomeTotals(MutableContext[] params) {
        double[] totals = new double[this.numOutcomes];
        for (MutableContext context : params) {
            int[] outcomes = context.getOutcomes();
            double[] counts = context.getParameters();
            for (int j = 0; j < outcomes.length; ++j) {
                totals[outcomes[j]] += counts[j];
            }
        }
        return totals;
    }

    /**
     * Evaluates the training events against {@code evalParams} and logs the fraction whose
     * highest-scoring outcome matches the labelled outcome.
     *
     * @param evalParams the parameters to evaluate with
     * @return the training accuracy (correct predictions / {@code numEvents})
     */
    private double trainingStats(EvalParameters evalParams) {
        int numCorrect = 0;
        for (int ei = 0; ei < this.numUniqueEvents; ++ei) {
            float[] eventValues = this.values == null ? null : this.values[ei];
            double[] modelDistribution = new double[this.numOutcomes];
            NaiveBayesModel.eval(this.contexts[ei], eventValues, modelDistribution, evalParams, false);
            if (ArrayMath.argmax(modelDistribution) == this.outcomeList[ei]) {
                // Repeated occurrences of an event have identical input, hence identical
                // predictions: credit them all at once instead of re-evaluating each one.
                numCorrect += this.numTimesEventsSeen[ei];
            }
        }
        double trainingAccuracy = (double) numCorrect / (double) this.numEvents;
        logger.info("Stats: ({}/{}) {}", numCorrect, this.numEvents, trainingAccuracy);
        return trainingAccuracy;
    }
}

