/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Boostable;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.NaiveBayes;
import cc.mallet.pipe.Noop;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AlphabetCarrying;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
import cc.mallet.types.Multinomial;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

/**
 * Trains a {@link NaiveBayes} classifier by accumulating weighted feature counts per label
 * (one cloned {@link Multinomial.Estimator} per label index) together with weighted label
 * counts for the class prior.
 *
 * <p>{@link #train(InstanceList)} discards any previously accumulated counts and trains from
 * scratch. The {@code trainIncremental} methods keep adding to the existing counts and return a
 * fresh classifier estimated from everything seen so far.
 */
public class NaiveBayesTrainer
extends ClassifierTrainer<NaiveBayes>
implements ClassifierTrainer.ByInstanceIncrements<NaiveBayes>,
Boostable,
AlphabetCarrying,
Serializable {
    /** Prototype that is cloned once per label to create the feature-count estimators. */
    Multinomial.Estimator featureEstimator = new Multinomial.LaplaceEstimator();
    /** Prototype that is cloned to create the label-prior estimator. */
    Multinomial.Estimator priorEstimator = new Multinomial.LaplaceEstimator();
    /** Per-label feature estimators, indexed by label index; {@code null} until training starts. */
    Multinomial.Estimator[] me;
    /** Label-prior estimator; {@code null} until training starts. */
    Multinomial.Estimator pe;
    /**
     * When positive, each instance's weight is scaled by {@code docLengthNormalization / oneNorm}.
     * Its feature values then contribute as if they summed to this value. Non-positive disables it.
     */
    double docLengthNormalization = -1.0;
    NaiveBayes classifier;
    Pipe instancePipe;
    Alphabet dataAlphabet;
    Alphabet targetAlphabet;
    private static final long serialVersionUID = 1L;
    /** Version 2 appended {@link #docLengthNormalization} to the serialized form. */
    private static final int CURRENT_SERIAL_VERSION = 2;

    /**
     * Creates a trainer that adopts the pipe and alphabets of an existing classifier.
     *
     * @param initialClassifier classifier whose pipe/alphabets to reuse; may be {@code null}
     */
    public NaiveBayesTrainer(NaiveBayes initialClassifier) {
        if (initialClassifier != null) {
            this.instancePipe = initialClassifier.getInstancePipe();
            this.dataAlphabet = initialClassifier.getAlphabet();
            this.targetAlphabet = initialClassifier.getLabelAlphabet();
            this.classifier = initialClassifier;
        }
    }

    /**
     * Creates a trainer bound to the given pipe and the pipe's data/target alphabets.
     *
     * @param instancePipe pipe that training lists must share
     */
    public NaiveBayesTrainer(Pipe instancePipe) {
        this.instancePipe = instancePipe;
        this.dataAlphabet = instancePipe.getDataAlphabet();
        this.targetAlphabet = instancePipe.getTargetAlphabet();
    }

    /** Creates a trainer that takes its pipe and alphabets from the first training data it sees. */
    public NaiveBayesTrainer() {
    }

    /** Returns the most recently trained classifier, or the initial one if no training has happened. */
    @Override
    public NaiveBayes getClassifier() {
        return this.classifier;
    }

    /**
     * Sets the document length normalization.
     *
     * @param d target total feature mass per document; a value &lt;= 0 disables normalization
     * @return this trainer, for chaining
     */
    public NaiveBayesTrainer setDocLengthNormalization(double d) {
        this.docLengthNormalization = d;
        return this;
    }

    /** Returns the document length normalization; values &lt;= 0 mean disabled. */
    public double getDocLengthNormalization() {
        return this.docLengthNormalization;
    }

    /** Returns the prototype estimator cloned for each label's feature distribution. */
    public Multinomial.Estimator getFeatureMultinomialEstimator() {
        return this.featureEstimator;
    }

    /**
     * Sets the prototype estimator cloned for each label's feature distribution.
     *
     * @param me the prototype estimator
     * @return this trainer, for chaining
     * @throws IllegalStateException if training has already created the per-label estimators
     */
    public NaiveBayesTrainer setFeatureMultinomialEstimator(Multinomial.Estimator me) {
        this.checkEstimatorsNotCreated();
        this.featureEstimator = me;
        return this;
    }

    /** Returns the prototype estimator cloned for the label prior. */
    public Multinomial.Estimator getPriorMultinomialEstimator() {
        return this.priorEstimator;
    }

    /**
     * Sets the prototype estimator cloned for the label prior.
     *
     * @param me the prototype estimator
     * @return this trainer, for chaining
     * @throws IllegalStateException if training has already created the estimators
     */
    public NaiveBayesTrainer setPriorMultinomialEstimator(Multinomial.Estimator me) {
        this.checkEstimatorsNotCreated();
        this.priorEstimator = me;
        return this;
    }

    /**
     * Rejects prototype changes once training has started.
     *
     * <p>The prototypes are only cloned when {@link #setup} first creates {@link #me} and
     * {@link #pe}. Changing them afterwards would not affect the live estimators. Constructing
     * the trainer with a pipe or initial classifier does not count as starting training.
     */
    private void checkEstimatorsNotCreated() {
        if (this.me != null) {
            throw new IllegalStateException("Can't set after incrementalTrain() is called");
        }
    }

    /**
     * Trains from scratch, discarding any counts accumulated by earlier calls.
     *
     * @param trainingList labeled training instances
     * @return the newly trained classifier
     */
    @Override
    public NaiveBayes train(InstanceList trainingList) {
        this.me = null;
        this.pe = null;
        this.classifier = this.trainIncremental(trainingList);
        return this.classifier;
    }

    /**
     * Adds every instance of the list, weighted by its instance weight, to the accumulated counts.
     *
     * @param trainingInstancesToAdd instances to add; must share this trainer's pipe and alphabets
     * @return a classifier estimated from all counts accumulated so far
     * @throws IllegalArgumentException if the pipe or alphabets do not match, or if no alphabets
     *     can be determined
     */
    @Override
    public NaiveBayes trainIncremental(InstanceList trainingInstancesToAdd) {
        this.setup(trainingInstancesToAdd, null);
        for (Instance instance : trainingInstancesToAdd) {
            this.incorporateOneInstance(instance, trainingInstancesToAdd.getInstanceWeight(instance));
        }
        // The list may carry no pipe. Give the classifier a Noop over our alphabets, as the
        // single-instance path does. It is not stored, so later pipe-less lists still pass the
        // pipe-identity check in setup().
        Pipe pipe = this.instancePipe != null
                ? this.instancePipe
                : new Noop(this.dataAlphabet, this.targetAlphabet);
        return this.buildClassifier(pipe);
    }

    /**
     * Adds a single instance with weight 1.0 to the accumulated counts.
     *
     * @param instance instance to add; must share this trainer's alphabets
     * @return a classifier estimated from all counts accumulated so far
     */
    @Override
    public NaiveBayes trainIncremental(Instance instance) {
        this.setup(null, instance);
        this.incorporateOneInstance(instance, 1.0);
        if (this.instancePipe == null) {
            this.instancePipe = new Noop(this.dataAlphabet, this.targetAlphabet);
        }
        return this.buildClassifier(this.instancePipe);
    }

    /** Estimates prior and per-label feature multinomials and stores the resulting classifier. */
    private NaiveBayes buildClassifier(Pipe pipe) {
        this.classifier = new NaiveBayes(pipe, this.pe.estimate(), this.estimateFeatureMultinomials());
        return this.classifier;
    }

    /**
     * Validates the training data against this trainer and lazily creates or grows the estimators.
     *
     * <p>Alphabets come from the first available instance when the trainer has none yet. An empty
     * list is accepted only when the trainer already knows its alphabets.
     */
    private void setup(InstanceList instances, Instance instance) {
        assert (instances != null || instance != null);
        if (instance == null && instances != null && !instances.isEmpty()) {
            instance = (Instance)instances.get(0);
        }
        if (instance != null) {
            if (this.dataAlphabet == null) {
                this.dataAlphabet = instance.getDataAlphabet();
                this.targetAlphabet = instance.getTargetAlphabet();
            } else if (!Alphabet.alphabetsMatch(instance, this)) {
                throw new IllegalArgumentException("Training set alphabets do not match those of NaiveBayesTrainer.");
            }
        }
        if (this.dataAlphabet == null || this.targetAlphabet == null) {
            throw new IllegalArgumentException("NaiveBayesTrainer needs both a data and a target alphabet, "
                    + "but could not obtain them from the trainer or the training data.");
        }
        if (instances != null) {
            if (this.instancePipe == null) {
                this.instancePipe = instances.getPipe();
            } else if (this.instancePipe != instances.getPipe()) {
                throw new IllegalArgumentException("Training set pipe does not match that of NaiveBayesTrainer.");
            }
        }
        if (this.me == null) {
            this.me = new Multinomial.Estimator[0];
            this.pe = (Multinomial.Estimator)this.priorEstimator.clone();
        }
        // The label alphabet may have grown since the last call; add one estimator per new label.
        int numLabels = this.targetAlphabet.size();
        if (numLabels > this.me.length) {
            Multinomial.Estimator[] grown = new Multinomial.Estimator[numLabels];
            System.arraycopy(this.me, 0, grown, 0, this.me.length);
            for (int i = this.me.length; i < numLabels; ++i) {
                grown[i] = this.newFeatureEstimator();
            }
            this.me = grown;
        }
    }

    /** Clones the feature-estimator prototype and binds the clone to the data alphabet. */
    private Multinomial.Estimator newFeatureEstimator() {
        Multinomial.Estimator estimator = (Multinomial.Estimator)this.featureEstimator.clone();
        estimator.setAlphabet(this.dataAlphabet);
        return estimator;
    }

    /**
     * Adds one instance's feature counts to each label it carries, weighted by label weight times
     * instance weight. Skips unlabeled instances, instances with no feature mass, and
     * zero-weight instances.
     */
    private void incorporateOneInstance(Instance instance, double instanceWeight) {
        Labeling labeling = instance.getLabeling();
        if (labeling == null) {
            return;
        }
        FeatureVector fv = (FeatureVector)instance.getData();
        double oneNorm = fv.oneNorm();
        if (oneNorm <= 0.0) {
            return;
        }
        if (this.docLengthNormalization > 0.0) {
            // Rescale so the document contributes as if its feature values summed to
            // docLengthNormalization.
            instanceWeight *= this.docLengthNormalization / oneNorm;
        }
        if (instanceWeight == 0.0) {
            // A zero-weight instance contributes nothing; skip it instead of tripping the
            // positivity assertion below.
            return;
        }
        assert (instanceWeight > 0.0 && !Double.isInfinite(instanceWeight));
        for (int lpos = 0; lpos < labeling.numLocations(); ++lpos) {
            int li = labeling.indexAtLocation(lpos);
            double labelWeight = labeling.valueAtLocation(lpos);
            if (labelWeight == 0.0) continue;
            this.me[li].increment(fv, labelWeight * instanceWeight);
            this.pe.increment(li, labelWeight * instanceWeight);
        }
    }

    /** Estimates one feature multinomial per label in the target alphabet. */
    private Multinomial[] estimateFeatureMultinomials() {
        int numLabels = this.targetAlphabet.size();
        Multinomial[] m = new Multinomial[numLabels];
        for (int li = 0; li < numLabels; ++li) {
            m[li] = this.me[li].estimate();
        }
        return m;
    }

    @Override
    public String toString() {
        return "NaiveBayesTrainer";
    }

    /** Returns whether {@code object} carries alphabets matching this trainer's. */
    public boolean alphabetsMatch(AlphabetCarrying object) {
        return Alphabet.alphabetsMatch(this, object);
    }

    @Override
    public Alphabet getAlphabet() {
        return this.dataAlphabet;
    }

    @Override
    public Alphabet[] getAlphabets() {
        return new Alphabet[]{this.dataAlphabet, this.targetAlphabet};
    }

    /**
     * Writes the serialized form as {@link #CURRENT_SERIAL_VERSION} (2). Version 2 is the
     * version-1 fields followed by {@code docLengthNormalization}.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(this.featureEstimator);
        out.writeObject(this.priorEstimator);
        out.writeObject(this.me);
        out.writeObject(this.pe);
        out.writeObject(this.instancePipe);
        out.writeObject(this.dataAlphabet);
        out.writeObject(this.targetAlphabet);
        out.writeDouble(this.docLengthNormalization);
    }

    /**
     * Reads serialized versions 1 and 2. Version-1 streams did not store
     * {@code docLengthNormalization}, so it is restored to its -1.0 default (disabled).
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt();
        if (version < 1 || version > CURRENT_SERIAL_VERSION) {
            throw new ClassNotFoundException("Mismatched NaiveBayesTrainer versions: wanted 1.."
                    + CURRENT_SERIAL_VERSION + ", got " + version);
        }
        this.featureEstimator = (Multinomial.Estimator)in.readObject();
        this.priorEstimator = (Multinomial.Estimator)in.readObject();
        this.me = (Multinomial.Estimator[])in.readObject();
        this.pe = (Multinomial.Estimator)in.readObject();
        this.instancePipe = (Pipe)in.readObject();
        this.dataAlphabet = (Alphabet)in.readObject();
        this.targetAlphabet = (Alphabet)in.readObject();
        // Field initializers do not run on deserialization, so set the default explicitly.
        this.docLengthNormalization = version >= 2 ? in.readDouble() : -1.0;
    }

    /**
     * Creates {@link NaiveBayesTrainer}s configured with this factory's estimator prototypes and
     * document length normalization.
     */
    public static class Factory
    extends ClassifierTrainer.Factory<NaiveBayesTrainer> {
        Multinomial.Estimator featureEstimator = new Multinomial.LaplaceEstimator();
        Multinomial.Estimator priorEstimator = new Multinomial.LaplaceEstimator();
        double docLengthNormalization = -1.0;

        @Override
        public NaiveBayesTrainer newClassifierTrainer(Classifier initialClassifier) {
            NaiveBayesTrainer trainer = new NaiveBayesTrainer((NaiveBayes)initialClassifier);
            // Clone the prototypes so trainers made by one factory do not share mutable
            // estimator objects with the factory or with each other. A fresh trainer has no
            // estimators yet, so the setters accept these values.
            trainer.setFeatureMultinomialEstimator((Multinomial.Estimator)this.featureEstimator.clone());
            trainer.setPriorMultinomialEstimator((Multinomial.Estimator)this.priorEstimator.clone());
            trainer.setDocLengthNormalization(this.docLengthNormalization);
            return trainer;
        }

        public Factory setDocLengthNormalization(double docLengthNormalization) {
            this.docLengthNormalization = docLengthNormalization;
            return this;
        }

        public Factory setFeatureMultinomialEstimator(Multinomial.Estimator featureEstimator) {
            this.featureEstimator = featureEstimator;
            return this;
        }

        public Factory setPriorMultinomialEstimator(Multinomial.Estimator priorEstimator) {
            this.priorEstimator = priorEstimator;
            return this;
        }
    }
}

