/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.meta;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.io.Serializable;
import java.util.Arrays;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.trees.HoeffdingTree;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.options.ClassOption;
import moa.tasks.TaskMonitor;

/**
 * Online Accuracy Updated Ensemble: an ensemble of incrementally trained classifiers whose members
 * are weighted by comparing their windowed mean square error with a reference error ({@code mse_r})
 * derived from the class frequencies observed in the most recent {@code windowSize} instances.
 *
 * <p>A candidate classifier is trained alongside the ensemble. Every {@code windowSize} instances the
 * candidate either joins the ensemble (while it has fewer than {@code memberCount} members) or
 * replaces the member with the lowest weight, if the candidate's weight is higher.
 */
public class OnlineAccuracyUpdatedEnsemble
extends AbstractClassifier
implements MultiClassClassifier {
    private static final long serialVersionUID = 1L;
    public ClassOption learnerOption = new ClassOption("learner", 'l', "Classifier to train.", Classifier.class, "trees.HoeffdingTree -e 2000000 -g 100 -c 0.01");
    public IntOption memberCountOption = new IntOption("memberCount", 'n', "The maximum number of classifiers in an ensemble.", 10, 1, Integer.MAX_VALUE);
    public FloatOption windowSizeOption = new FloatOption("windowSize", 'w', "The window size used for classifier creation and evaluation.", 500.0, 1.0, 2.147483647E9);
    public IntOption maxByteSizeOption = new IntOption("maxByteSize", 'm', "Maximum memory consumed by ensemble.", 0x2000000, 0, Integer.MAX_VALUE);
    public FlagOption verboseOption = new FlagOption("verbose", 'v', "When checked the algorithm outputs additional information about component classifier weights.");
    public FlagOption linearOption = new FlagOption("linearFunction", 'f', "When checked the algorithm uses a linear weighting function.");
    /**
     * One row per ensemble member: {@code [i][0]} is the member's current weight, {@code [i][1]} is
     * the index into {@link #ensemble} of the member the row describes (set to the row index by
     * {@link #addToStored}). {@code null} until the first member is added.
     */
    protected double[][] weights;
    /** Per-class label counts over the current sliding window of {@code windowSize} instances. */
    protected long[] classDistributions;
    /** Current ensemble members. */
    protected ClassifierWithMemory[] ensemble;
    /** Number of instances seen since the last reset. */
    protected int processedInstances;
    /** Classifier being trained for possible promotion at the next window boundary. */
    protected ClassifierWithMemory candidate;
    /** Ring buffer holding the class labels of the last {@code windowSize} instances. */
    protected int[] currentWindow;
    /** Reference error computed from {@link #classDistributions} by {@link #computeMseR()}. */
    protected double mse_r = 0.0;
    /** Integer value of {@link #windowSizeOption}, cached at preparation/reset time. */
    protected int windowSize = 0;

    @Override
    public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
        this.windowSize = (int) this.windowSizeOption.getValue();
        this.candidate = this.createCandidate();
        super.prepareForUseImpl(monitor, repository);
    }

    @Override
    public void resetLearningImpl() {
        this.currentWindow = null;
        this.windowSize = (int) this.windowSizeOption.getValue();
        this.classDistributions = null;
        this.processedInstances = 0;
        this.ensemble = new ClassifierWithMemory[0];
        // Clear member weights too, so no stale weights are reported for the now-empty ensemble.
        this.weights = null;
        this.candidate = this.createCandidate();
    }

    /**
     * Builds a fresh, reset copy of the configured learner wrapped with an empty error window.
     * Reads {@link #windowSize}, so callers must set it first.
     */
    private ClassifierWithMemory createCandidate() {
        Classifier learner = ((Classifier) this.getPreparedClassOption(this.learnerOption)).copy();
        ClassifierWithMemory fresh = new ClassifierWithMemory(learner, this.windowSize);
        fresh.classifier.resetLearning();
        return fresh;
    }

    /**
     * Processes one training instance.
     *
     * <p>Updates the sliding-window class counts and {@link #mse_r}. At a window boundary the
     * candidate is promoted via {@link #createNewClassifier}; otherwise the candidate is trained and
     * every member's weight (and error window) is refreshed. Finally all members, including a
     * just-promoted candidate, are trained on the instance.
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        this.initVariables();
        int label = (int) inst.classValue();
        int slot = this.processedInstances % this.windowSize;
        if (this.processedInstances >= this.windowSize) {
            // Window is full: the label being overwritten in the ring buffer leaves the counts.
            this.classDistributions[this.currentWindow[slot]]--;
        }
        this.classDistributions[label]++;
        this.currentWindow[slot] = label;
        this.processedInstances++;
        this.computeMseR();

        if (this.processedInstances % this.windowSize == 0) {
            this.createNewClassifier(inst);
        } else {
            this.candidate.classifier.trainOnInstance(inst);
            for (int i = 0; i < this.ensemble.length; i++) {
                this.weights[i][0] = this.computeWeight(i, inst);
            }
        }
        for (ClassifierWithMemory member : this.ensemble) {
            member.classifier.trainOnInstance(inst);
        }
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }

    /**
     * Combines the normalized votes of all members with positive weight, each scaled by
     * {@code weight / (ensemble.length + 1)}. Members whose votes sum to zero are skipped; the
     * candidate is not consulted. Returns an empty vote before any training weight has been seen.
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        DoubleVector combinedVote = new DoubleVector();
        if (this.trainingWeightSeenByModel > 0.0) {
            double divisor = this.ensemble.length + 1.0;
            for (int i = 0; i < this.ensemble.length; i++) {
                double weight = this.weights[i][0];
                // Written as !(w > 0) so that NaN weights are skipped as well.
                if (!(weight > 0.0)) {
                    continue;
                }
                DoubleVector vote = new DoubleVector(
                        this.ensemble[(int) this.weights[i][1]].classifier.getVotesForInstance(inst));
                if (!(vote.sumOfValues() > 0.0)) {
                    continue;
                }
                vote.normalize();
                vote.scaleValues(weight / divisor);
                combinedVote.addValues(vote);
            }
        }
        return combinedVote.getArrayRef();
    }

    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        // No textual model description is produced for this ensemble.
    }

    @Override
    public Classifier[] getSubClassifiers() {
        Classifier[] subClassifiers = new Classifier[this.ensemble.length];
        for (int i = 0; i < this.ensemble.length; i++) {
            subClassifiers[i] = this.ensemble[i].classifier;
        }
        return subClassifiers;
    }

    /**
     * Handles a window boundary: refreshes member weights, then adds the candidate to the ensemble
     * (if not yet full) or lets it replace the lowest-weighted member when its weight is higher.
     * A new candidate is then created and the memory limit re-applied.
     *
     * @param inst the instance that completed the window, used to refresh member weights
     */
    protected void createNewClassifier(Instance inst) {
        // The candidate's weight equals computeWeight()'s formula evaluated with zero member error.
        double candidateClassifierWeight = this.linearOption.isSet()
                ? Math.max(this.mse_r, Double.MIN_VALUE)
                : 1.0 / (this.mse_r + Double.MIN_VALUE);
        for (int i = 0; i < this.ensemble.length; i++) {
            this.weights[i][0] = this.computeWeight(i, inst);
        }
        this.candidate.birthday = this.processedInstances;
        if (this.ensemble.length < this.memberCountOption.getValue()) {
            this.addToStored(this.candidate, candidateClassifierWeight);
        } else {
            int poorestClassifier = this.getPoorestClassifierIndex();
            if (this.weights[poorestClassifier][0] < candidateClassifierWeight) {
                this.weights[poorestClassifier][0] = candidateClassifierWeight;
                this.ensemble[(int) this.weights[poorestClassifier][1]] = this.candidate;
            }
        }
        this.candidate = this.createCandidate();
        this.enforceMemoryLimit();
    }

    /**
     * Splits {@link #maxByteSizeOption} evenly over the members plus the candidate and applies the
     * per-member budget to every {@link HoeffdingTree} member. Members of other types are left
     * untouched, because the learner option accepts any {@link Classifier} and an unconditional cast
     * would throw {@code ClassCastException}.
     */
    protected void enforceMemoryLimit() {
        double memoryLimit = (double) this.maxByteSizeOption.getValue() / (double) (this.ensemble.length + 1);
        int perMemberLimit = (int) Math.round(memoryLimit);
        for (int i = 0; i < this.ensemble.length; i++) {
            Classifier member = this.ensemble[(int) this.weights[i][1]].classifier;
            if (member instanceof HoeffdingTree) {
                HoeffdingTree tree = (HoeffdingTree) member;
                tree.maxByteSizeOption.setValue(perMemberLimit);
                tree.enforceTrackerLimit();
            }
        }
    }

    /**
     * Recomputes {@link #mse_r} as {@code sum_c p_c * (1 - p_c)^2}, where {@code p_c} is the count of
     * class {@code c} in the window divided by {@code windowSize}.
     */
    protected void computeMseR() {
        this.mse_r = 0.0;
        for (long count : this.classDistributions) {
            double p_c = (double) count / (double) this.windowSize;
            this.mse_r += p_c * ((1.0 - p_c) * (1.0 - p_c));
        }
    }

    /**
     * Updates member {@code i}'s windowed mean square error with its error on {@code example} and
     * returns the member's resulting weight.
     *
     * <p>For the first {@code windowSize} instances of a member's life, the error is a running mean.
     * After that, it is a sliding-window mean that drops the error recorded {@code windowSize} steps
     * earlier.
     *
     * @param i index of the member in {@link #ensemble}
     * @param example the instance the member is evaluated on
     * @return {@code max(mse_r - mse_it, MIN_VALUE)} when the linear option is set, otherwise
     *     {@code 1 / (mse_r + mse_it + MIN_VALUE)}
     */
    protected double computeWeight(int i, Instance example) {
        int d = this.windowSize;
        ClassifierWithMemory member = this.ensemble[i];
        int t = this.processedInstances - member.birthday;
        double e_it = squaredError(member.classifier, example);
        double mse_it;
        if (t > d) {
            // Sliding window: add the newest error and drop the one recorded d steps ago.
            mse_it = member.mse_it + e_it / (double) d - member.squareErrors[t % d] / (double) d;
        } else {
            // Fewer than d errors recorded so far: incremental mean over t values.
            mse_it = member.mse_it * (double) (t - 1) / (double) t + e_it / (double) t;
        }
        member.squareErrors[t % d] = e_it;
        member.mse_it = mse_it;
        if (this.linearOption.isSet()) {
            return Math.max(this.mse_r - mse_it, Double.MIN_VALUE);
        }
        return 1.0 / (this.mse_r + mse_it + Double.MIN_VALUE);
    }

    /**
     * Returns {@code (1 - v_y / sum(v))^2}, where {@code v} are the classifier's votes and
     * {@code v_y} is the vote for the true class. A class index outside the vote array counts as a
     * zero vote. Returns 1.0 when the votes do not sum to a positive value or the classifier fails to
     * produce votes.
     */
    private static double squaredError(Classifier classifier, Instance example) {
        try {
            double[] votes = classifier.getVotesForInstance(example);
            double voteSum = 0.0;
            for (double vote : votes) {
                voteSum += vote;
            }
            if (!(voteSum > 0.0)) {
                return 1.0;
            }
            int label = (int) example.classValue();
            // Vote arrays may be shorter than the number of classes; missing entries mean "no vote".
            double trueClassVote = (label >= 0 && label < votes.length) ? votes[label] : 0.0;
            double f_it = 1.0 - trueClassVote / voteSum;
            return f_it * f_it;
        } catch (Exception ignored) {
            // Deliberate best effort: a member that cannot predict is treated as completely wrong.
            return 1.0;
        }
    }

    /**
     * When the verbose flag is set, reports one "Member weight k" measurement per ensemble slot
     * (-1.0 for slots not yet filled); otherwise returns {@code null}.
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        if (!this.verboseOption.isSet()) {
            return null;
        }
        int storedMembers = this.weights == null ? 0 : this.weights.length;
        // Never index past the array, even if more weights exist than the configured member count.
        Measurement[] measurements =
                new Measurement[Math.max(this.memberCountOption.getValue(), storedMembers)];
        for (int m = 0; m < measurements.length; m++) {
            double value = m < storedMembers ? this.weights[m][0] : -1.0;
            measurements[m] = new Measurement("Member weight " + (m + 1), value);
        }
        return measurements;
    }

    /**
     * Appends {@code newClassifier} to the ensemble with the given weight, growing
     * {@link #ensemble} and {@link #weights} by one entry.
     */
    protected void addToStored(ClassifierWithMemory newClassifier, double newClassifiersWeight) {
        int oldLength = this.ensemble.length;
        ClassifierWithMemory[] newStored = Arrays.copyOf(this.ensemble, oldLength + 1);
        newStored[oldLength] = newClassifier;
        double[][] newStoredWeights = new double[oldLength + 1][];
        for (int i = 0; i < oldLength; i++) {
            newStoredWeights[i] = new double[] {this.weights[i][0], this.weights[i][1]};
        }
        newStoredWeights[oldLength] = new double[] {newClassifiersWeight, oldLength};
        this.ensemble = newStored;
        this.weights = newStoredWeights;
    }

    /** Returns the row index in {@link #weights} with the smallest weight (first one on ties). */
    private int getPoorestClassifierIndex() {
        int minIndex = 0;
        for (int i = 1; i < this.weights.length; i++) {
            if (this.weights[i][0] < this.weights[minIndex][0]) {
                minIndex = i;
            }
        }
        return minIndex;
    }

    /** Lazily allocates the label ring buffer and the per-class counts (sized from the model context). */
    private void initVariables() {
        if (this.currentWindow == null) {
            this.currentWindow = new int[this.windowSize];
        }
        if (this.classDistributions == null) {
            this.classDistributions = new long[this.getModelContext().classAttribute().numValues()];
        }
    }

    /**
     * Wraps a classifier with its creation time and a sliding window of squared errors.
     *
     * <p>Serializable so that an ensemble holding these wrappers can itself be serialized. As an
     * inner class it also serializes its reference to the enclosing ensemble.
     */
    protected class ClassifierWithMemory implements Serializable {
        private static final long serialVersionUID = 1L;
        /** The wrapped member classifier. */
        private Classifier classifier;
        /** Value of {@code processedInstances} when this wrapper was promoted into the ensemble. */
        private int birthday;
        /** Ring buffer of the last {@code windowSize} squared errors. */
        private double[] squareErrors;
        /** Current windowed mean of {@link #squareErrors}. */
        private double mse_it;

        protected ClassifierWithMemory(Classifier classifier, int windowSize) {
            this.classifier = classifier;
            this.squareErrors = new double[windowSize];
            this.mse_it = 0.0;
        }
    }
}

