/*
 * Decompiled with CFR 0.152.
 */
package weka.clusterers;

import java.io.Serializable;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.clusterers.ClusterEvaluation;
import weka.clusterers.DensityBasedClusterer;
import weka.clusterers.NumberOfClustersRequestable;
import weka.clusterers.SimpleKMeans;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.estimators.DiscreteEstimator;
import weka.estimators.Estimator;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

public class EM
extends DensityBasedClusterer
implements NumberOfClustersRequestable,
OptionHandler,
WeightedInstancesHandler {
    static final long serialVersionUID = 8348181483812829475L;
    private Estimator[][] m_model;
    private double[][][] m_modelNormal;
    private double m_minStdDev = 1.0E-6;
    private double[] m_minStdDevPerAtt;
    private double[][] m_weights;
    private double[] m_priors;
    private double m_loglikely;
    private Instances m_theInstances = null;
    private int m_num_clusters;
    private int m_initialNumClusters;
    private int m_num_attribs;
    private int m_num_instances;
    private int m_max_iterations;
    private double[] m_minValues;
    private double[] m_maxValues;
    private Random m_rr;
    private int m_rseed;
    private boolean m_verbose;
    private ReplaceMissingValues m_replaceMissing;
    private static double m_normConst = Math.log(Math.sqrt(Math.PI * 2));

    public String globalInfo() {
        return "Simple EM (expectation maximisation) class.\n\nEM assigns a probability distribution to each instance which indicates the probability of it belonging to each of the clusters. EM can decide how many clusters to create by cross validation, or you may specify apriori how many clusters to generate.\n\nThe cross validation performed to determine the number of clusters is done in the following steps:\n1. the number of clusters is set to 1\n2. the training set is split randomly into 10 folds.\n3. EM is performed 10 times using the 10 folds the usual CV way.\n4. the loglikelihood is averaged over all 10 results.\n5. if loglikelihood has increased the number of clusters is increased by 1 and the program continues at step 2. \n\nThe number of folds is fixed to 10, as long as the number of instances in the training set is not smaller 10. If this is the case the number of folds is set equal to the number of instances.";
    }

    public Enumeration listOptions() {
        Vector<Option> vector = new Vector<Option>(6);
        vector.addElement(new Option("\tnumber of clusters. If omitted or\n\t-1 specified, then cross validation is used to\n\tselect the number of clusters.", "N", 1, "-N <num>"));
        vector.addElement(new Option("\tmax iterations.\n(default 100)", "I", 1, "-I <num>"));
        vector.addElement(new Option("\trandom number seed.\n(default 1)", "S", 1, "-S <num>"));
        vector.addElement(new Option("\tverbose.", "V", 0, "-V"));
        vector.addElement(new Option("\tminimum allowable standard deviation for normal density computation \n\t(default 1e-6)", "M", 1, "-M <num>"));
        return vector.elements();
    }

    public void setOptions(String[] stringArray) throws Exception {
        this.resetOptions();
        this.setDebug(Utils.getFlag('V', stringArray));
        String string = Utils.getOption('I', stringArray);
        if (string.length() != 0) {
            this.setMaxIterations(Integer.parseInt(string));
        }
        if ((string = Utils.getOption('N', stringArray)).length() != 0) {
            this.setNumClusters(Integer.parseInt(string));
        }
        if ((string = Utils.getOption('S', stringArray)).length() != 0) {
            this.setSeed(Integer.parseInt(string));
        }
        if ((string = Utils.getOption('M', stringArray)).length() != 0) {
            this.setMinStdDev(new Double(string));
        }
    }

    public String minStdDevTipText() {
        return "set minimum allowable standard deviation";
    }

    public void setMinStdDev(double d) {
        this.m_minStdDev = d;
    }

    public void setMinStdDevPerAtt(double[] dArray) {
        this.m_minStdDevPerAtt = dArray;
    }

    public double getMinStdDev() {
        return this.m_minStdDev;
    }

    public String seedTipText() {
        return "random number seed";
    }

    public void setSeed(int n) {
        this.m_rseed = n;
    }

    public int getSeed() {
        return this.m_rseed;
    }

    public String numClustersTipText() {
        return "set number of clusters. -1 to select number of clusters automatically by cross validation.";
    }

    public void setNumClusters(int n) throws Exception {
        if (n == 0) {
            throw new Exception("Number of clusters must be > 0. (or -1 to select by cross validation).");
        }
        if (n < 0) {
            this.m_num_clusters = -1;
            this.m_initialNumClusters = -1;
        } else {
            this.m_num_clusters = n;
            this.m_initialNumClusters = n;
        }
    }

    public int getNumClusters() {
        return this.m_initialNumClusters;
    }

    public String maxIterationsTipText() {
        return "maximum number of iterations";
    }

    public void setMaxIterations(int n) throws Exception {
        if (n < 1) {
            throw new Exception("Maximum number of iterations must be > 0!");
        }
        this.m_max_iterations = n;
    }

    public int getMaxIterations() {
        return this.m_max_iterations;
    }

    public void setDebug(boolean bl) {
        this.m_verbose = bl;
    }

    public boolean getDebug() {
        return this.m_verbose;
    }

    public String[] getOptions() {
        String[] stringArray = new String[9];
        int n = 0;
        if (this.m_verbose) {
            stringArray[n++] = "-V";
        }
        stringArray[n++] = "-I";
        stringArray[n++] = "" + this.m_max_iterations;
        stringArray[n++] = "-N";
        stringArray[n++] = "" + this.getNumClusters();
        stringArray[n++] = "-S";
        stringArray[n++] = "" + this.m_rseed;
        stringArray[n++] = "-M";
        stringArray[n++] = "" + this.getMinStdDev();
        while (n < stringArray.length) {
            stringArray[n++] = "";
        }
        return stringArray;
    }

    private void EM_Init(Instances instances) throws Exception {
        int n;
        Serializable serializable;
        int n2;
        SimpleKMeans simpleKMeans = null;
        double d = Double.MAX_VALUE;
        for (n2 = 0; n2 < 10; ++n2) {
            serializable = new SimpleKMeans();
            ((SimpleKMeans)serializable).setSeed(this.m_rr.nextInt());
            ((SimpleKMeans)serializable).setNumClusters(this.m_num_clusters);
            ((SimpleKMeans)serializable).buildClusterer(instances);
            if (!(((SimpleKMeans)serializable).getSquaredError() < d)) continue;
            d = ((SimpleKMeans)serializable).getSquaredError();
            simpleKMeans = serializable;
        }
        this.m_num_clusters = simpleKMeans.numberOfClusters();
        this.m_weights = new double[instances.numInstances()][this.m_num_clusters];
        this.m_model = new DiscreteEstimator[this.m_num_clusters][this.m_num_attribs];
        this.m_modelNormal = new double[this.m_num_clusters][this.m_num_attribs][3];
        this.m_priors = new double[this.m_num_clusters];
        serializable = simpleKMeans.getClusterCentroids();
        Instances instances2 = simpleKMeans.getClusterStandardDevs();
        int[][][] nArray = simpleKMeans.getClusterNominalCounts();
        int[] nArray2 = simpleKMeans.getClusterSizes();
        for (n2 = 0; n2 < this.m_num_clusters; ++n2) {
            Instance instance = ((Instances)serializable).instance(n2);
            for (n = 0; n < this.m_num_attribs; ++n) {
                double d2;
                double d3;
                if (instances.attribute(n).isNominal()) {
                    this.m_model[n2][n] = new DiscreteEstimator(this.m_theInstances.attribute(n).numValues(), true);
                    for (int i = 0; i < instances.attribute(n).numValues(); ++i) {
                        this.m_model[n2][n].addValue(i, nArray[n2][n][i]);
                    }
                    continue;
                }
                double d4 = this.m_minStdDevPerAtt != null ? this.m_minStdDevPerAtt[n] : this.m_minStdDev;
                this.m_modelNormal[n2][n][0] = d3 = instance.isMissing(n) ? instances.meanOrMode(n) : instance.value(n);
                double d5 = d2 = instances2.instance(n2).isMissing(n) ? (this.m_maxValues[n] - this.m_minValues[n]) / (double)(2 * this.m_num_clusters) : instances2.instance(n2).value(n);
                if (d2 < d4) {
                    d2 = instances.attributeStats((int)n).numericStats.stdDev;
                    if (Double.isInfinite(d2)) {
                        d2 = d4;
                    }
                    if (d2 < d4) {
                        d2 = d4;
                    }
                }
                if (d2 <= 0.0) {
                    d2 = this.m_minStdDev;
                }
                this.m_modelNormal[n2][n][1] = d2;
                this.m_modelNormal[n2][n][2] = 1.0;
            }
        }
        for (n = 0; n < this.m_num_clusters; ++n) {
            this.m_priors[n] = nArray2[n];
        }
        Utils.normalize(this.m_priors);
    }

    private void estimate_priors(Instances instances) throws Exception {
        int n;
        for (n = 0; n < this.m_num_clusters; ++n) {
            this.m_priors[n] = 0.0;
        }
        for (n = 0; n < instances.numInstances(); ++n) {
            for (int i = 0; i < this.m_num_clusters; ++i) {
                int n2 = i;
                this.m_priors[n2] = this.m_priors[n2] + instances.instance(n).weight() * this.m_weights[n][i];
            }
        }
        Utils.normalize(this.m_priors);
    }

    private double logNormalDens(double d, double d2, double d3) {
        double d4 = d - d2;
        return -(d4 * d4 / (2.0 * d3 * d3)) - m_normConst - Math.log(d3);
    }

    private void new_estimators() {
        for (int i = 0; i < this.m_num_clusters; ++i) {
            for (int j = 0; j < this.m_num_attribs; ++j) {
                if (this.m_theInstances.attribute(j).isNominal()) {
                    this.m_model[i][j] = new DiscreteEstimator(this.m_theInstances.attribute(j).numValues(), true);
                    continue;
                }
                this.m_modelNormal[i][j][2] = 0.0;
                this.m_modelNormal[i][j][1] = 0.0;
                this.m_modelNormal[i][j][0] = 0.0;
            }
        }
    }

    private void M(Instances instances) throws Exception {
        int n;
        int n2;
        this.new_estimators();
        for (n2 = 0; n2 < this.m_num_clusters; ++n2) {
            for (n = 0; n < this.m_num_attribs; ++n) {
                for (int i = 0; i < instances.numInstances(); ++i) {
                    Instance instance = instances.instance(i);
                    if (instance.isMissing(n)) continue;
                    if (instances.attribute(n).isNominal()) {
                        this.m_model[n2][n].addValue(instance.value(n), instance.weight() * this.m_weights[i][n2]);
                        continue;
                    }
                    double[] dArray = this.m_modelNormal[n2][n];
                    dArray[0] = dArray[0] + instance.value(n) * instance.weight() * this.m_weights[i][n2];
                    double[] dArray2 = this.m_modelNormal[n2][n];
                    dArray2[2] = dArray2[2] + instance.weight() * this.m_weights[i][n2];
                    double[] dArray3 = this.m_modelNormal[n2][n];
                    dArray3[1] = dArray3[1] + instance.value(n) * instance.value(n) * instance.weight() * this.m_weights[i][n2];
                }
            }
        }
        for (n = 0; n < this.m_num_attribs; ++n) {
            if (instances.attribute(n).isNominal()) continue;
            for (n2 = 0; n2 < this.m_num_clusters; ++n2) {
                if (this.m_modelNormal[n2][n][2] <= 0.0) {
                    this.m_modelNormal[n2][n][1] = Double.MAX_VALUE;
                    this.m_modelNormal[n2][n][0] = this.m_minStdDev;
                    continue;
                }
                this.m_modelNormal[n2][n][1] = (this.m_modelNormal[n2][n][1] - this.m_modelNormal[n2][n][0] * this.m_modelNormal[n2][n][0] / this.m_modelNormal[n2][n][2]) / this.m_modelNormal[n2][n][2];
                if (this.m_modelNormal[n2][n][1] < 0.0) {
                    this.m_modelNormal[n2][n][1] = 0.0;
                }
                double d = this.m_minStdDevPerAtt != null ? this.m_minStdDevPerAtt[n] : this.m_minStdDev;
                this.m_modelNormal[n2][n][1] = Math.sqrt(this.m_modelNormal[n2][n][1]);
                if (this.m_modelNormal[n2][n][1] <= d) {
                    this.m_modelNormal[n2][n][1] = instances.attributeStats((int)n).numericStats.stdDev;
                    if (this.m_modelNormal[n2][n][1] <= d) {
                        this.m_modelNormal[n2][n][1] = d;
                    }
                }
                if (this.m_modelNormal[n2][n][1] <= 0.0) {
                    this.m_modelNormal[n2][n][1] = this.m_minStdDev;
                }
                if (Double.isInfinite(this.m_modelNormal[n2][n][1])) {
                    this.m_modelNormal[n2][n][1] = this.m_minStdDev;
                }
                double[] dArray = this.m_modelNormal[n2][n];
                dArray[0] = dArray[0] / this.m_modelNormal[n2][n][2];
            }
        }
    }

    private double E(Instances instances, boolean bl) throws Exception {
        double d = 0.0;
        double d2 = 0.0;
        for (int i = 0; i < instances.numInstances(); ++i) {
            Instance instance = instances.instance(i);
            d += instance.weight() * this.logDensityForInstance(instance);
            d2 += instance.weight();
            if (!bl) continue;
            this.m_weights[i] = this.distributionForInstance(instance);
        }
        if (bl) {
            this.estimate_priors(instances);
        }
        return d / d2;
    }

    public EM() {
        this.resetOptions();
    }

    protected void resetOptions() {
        this.m_minStdDev = 1.0E-6;
        this.m_max_iterations = 100;
        this.m_rseed = 100;
        this.m_num_clusters = -1;
        this.m_initialNumClusters = -1;
        this.m_verbose = false;
    }

    public double[][][] getClusterModelsNumericAtts() {
        return this.m_modelNormal;
    }

    public double[] getClusterPriors() {
        return this.m_priors;
    }

    public String toString() {
        if (this.m_priors == null) {
            return "No clusterer built yet!";
        }
        StringBuffer stringBuffer = new StringBuffer();
        stringBuffer.append("\nEM\n==\n");
        if (this.m_initialNumClusters == -1) {
            stringBuffer.append("\nNumber of clusters selected by cross validation: " + this.m_num_clusters + "\n");
        } else {
            stringBuffer.append("\nNumber of clusters: " + this.m_num_clusters + "\n");
        }
        for (int i = 0; i < this.m_num_clusters; ++i) {
            stringBuffer.append("\nCluster: " + i + " Prior probability: " + Utils.doubleToString(this.m_priors[i], 4) + "\n\n");
            for (int j = 0; j < this.m_num_attribs; ++j) {
                stringBuffer.append("Attribute: " + this.m_theInstances.attribute(j).name() + "\n");
                if (this.m_theInstances.attribute(j).isNominal()) {
                    if (this.m_model[i][j] == null) continue;
                    stringBuffer.append(this.m_model[i][j].toString());
                    continue;
                }
                stringBuffer.append("Normal Distribution. Mean = " + Utils.doubleToString(this.m_modelNormal[i][j][0], 4) + " StdDev = " + Utils.doubleToString(this.m_modelNormal[i][j][1], 4) + "\n");
            }
        }
        return stringBuffer.toString();
    }

    private void EM_Report(Instances instances) {
        int n;
        System.out.println("======================================");
        for (n = 0; n < this.m_num_clusters; ++n) {
            for (int i = 0; i < this.m_num_attribs; ++i) {
                System.out.println("Clust: " + n + " att: " + i + "\n");
                if (this.m_theInstances.attribute(i).isNominal()) {
                    if (this.m_model[n][i] == null) continue;
                    System.out.println(this.m_model[n][i].toString());
                    continue;
                }
                System.out.println("Normal Distribution. Mean = " + Utils.doubleToString(this.m_modelNormal[n][i][0], 8, 4) + " StandardDev = " + Utils.doubleToString(this.m_modelNormal[n][i][1], 8, 4) + " WeightSum = " + Utils.doubleToString(this.m_modelNormal[n][i][2], 8, 4));
            }
        }
        for (int i = 0; i < instances.numInstances(); ++i) {
            int n2 = Utils.maxIndex(this.m_weights[i]);
            System.out.print("Inst " + Utils.doubleToString(i, 5, 0) + " Class " + n2 + "\t");
            for (n = 0; n < this.m_num_clusters; ++n) {
                System.out.print(Utils.doubleToString(this.m_weights[i][n], 7, 5) + "  ");
            }
            System.out.println();
        }
    }

    private void CVClusters() throws Exception {
        double d = -1.7976931348623157E308;
        boolean bl = true;
        int n = this.m_num_clusters = 1;
        int n2 = this.m_theInstances.numInstances() < 10 ? this.m_theInstances.numInstances() : 10;
        boolean bl2 = true;
        int n3 = this.m_rseed;
        int n4 = 0;
        block4: while (bl) {
            bl = false;
            Random random = new Random(this.m_rseed);
            Instances instances = new Instances(this.m_theInstances);
            instances.randomize(random);
            double d2 = 0.0;
            for (int i = 0; i < n2; ++i) {
                double d3;
                Instances instances2 = instances.trainCV(n2, i, random);
                if (n > instances2.numInstances()) break block4;
                Instances instances3 = instances.testCV(n2, i);
                this.m_rr = new Random(n3);
                for (int j = 0; j < 10; ++j) {
                    this.m_rr.nextDouble();
                }
                this.m_num_clusters = n;
                this.EM_Init(instances2);
                try {
                    this.iterate(instances2, false);
                }
                catch (Exception exception) {
                    exception.printStackTrace();
                    ++n3;
                    bl2 = false;
                    if (++n4 <= 5) break;
                    break block4;
                }
                try {
                    d3 = this.E(instances3, false);
                }
                catch (Exception exception) {
                    exception.printStackTrace();
                    ++n3;
                    bl2 = false;
                    if (++n4 <= 5) break;
                    break block4;
                }
                if (this.m_verbose) {
                    System.out.println("# clust: " + n + " Fold: " + i + " Loglikely: " + d3);
                }
                d2 += d3;
            }
            if (!bl2) continue;
            n4 = 0;
            n3 = this.m_rseed;
            d2 /= (double)n2;
            if (this.m_verbose) {
                System.out.println("=================================================\n# clust: " + n + " Mean Loglikely: " + d2 + "\n================================" + "=================");
            }
            if (!(d2 > d)) continue;
            d = d2;
            bl = true;
            ++n;
        }
        if (this.m_verbose) {
            System.out.println("Number of clusters: " + (n - 1));
        }
        this.m_num_clusters = n - 1;
    }

    public int numberOfClusters() throws Exception {
        if (this.m_num_clusters == -1) {
            throw new Exception("Haven't generated any clusters!");
        }
        return this.m_num_clusters;
    }

    private void updateMinMax(Instance instance) {
        for (int i = 0; i < this.m_theInstances.numAttributes(); ++i) {
            if (instance.isMissing(i)) continue;
            if (Double.isNaN(this.m_minValues[i])) {
                this.m_minValues[i] = instance.value(i);
                this.m_maxValues[i] = instance.value(i);
                continue;
            }
            if (instance.value(i) < this.m_minValues[i]) {
                this.m_minValues[i] = instance.value(i);
                continue;
            }
            if (!(instance.value(i) > this.m_maxValues[i])) continue;
            this.m_maxValues[i] = instance.value(i);
        }
    }

    public Capabilities getCapabilities() {
        return new SimpleKMeans().getCapabilities();
    }

    public void buildClusterer(Instances instances) throws Exception {
        int n;
        this.getCapabilities().testWithFail(instances);
        this.m_replaceMissing = new ReplaceMissingValues();
        Instances instances2 = new Instances(instances);
        instances2.setClassIndex(-1);
        this.m_replaceMissing.setInputFormat(instances2);
        instances = Filter.useFilter(instances2, this.m_replaceMissing);
        instances2 = null;
        this.m_theInstances = instances;
        this.m_minValues = new double[this.m_theInstances.numAttributes()];
        this.m_maxValues = new double[this.m_theInstances.numAttributes()];
        for (n = 0; n < this.m_theInstances.numAttributes(); ++n) {
            this.m_maxValues[n] = Double.NaN;
            this.m_minValues[n] = Double.NaN;
        }
        for (n = 0; n < this.m_theInstances.numInstances(); ++n) {
            this.updateMinMax(this.m_theInstances.instance(n));
        }
        this.doEM();
        this.m_theInstances = new Instances(this.m_theInstances, 0);
    }

    public double[] clusterPriors() {
        double[] dArray = new double[this.m_priors.length];
        System.arraycopy(this.m_priors, 0, dArray, 0, dArray.length);
        return dArray;
    }

    public double[] logDensityPerClusterForInstance(Instance instance) throws Exception {
        double[] dArray = new double[this.m_num_clusters];
        this.m_replaceMissing.input(instance);
        instance = this.m_replaceMissing.output();
        for (int i = 0; i < this.m_num_clusters; ++i) {
            double d = 0.0;
            for (int j = 0; j < this.m_num_attribs; ++j) {
                if (instance.isMissing(j)) continue;
                if (instance.attribute(j).isNominal()) {
                    d += Math.log(this.m_model[i][j].getProbability(instance.value(j)));
                    continue;
                }
                d += this.logNormalDens(instance.value(j), this.m_modelNormal[i][j][0], this.m_modelNormal[i][j][1]);
            }
            dArray[i] = d;
        }
        return dArray;
    }

    private void doEM() throws Exception {
        int n;
        if (this.m_verbose) {
            System.out.println("Seed: " + this.m_rseed);
        }
        this.m_rr = new Random(this.m_rseed);
        for (n = 0; n < 10; ++n) {
            this.m_rr.nextDouble();
        }
        this.m_num_instances = this.m_theInstances.numInstances();
        this.m_num_attribs = this.m_theInstances.numAttributes();
        if (this.m_verbose) {
            System.out.println("Number of instances: " + this.m_num_instances + "\nNumber of atts: " + this.m_num_attribs + "\n");
        }
        if (this.m_initialNumClusters == -1) {
            if (this.m_theInstances.numInstances() > 9) {
                this.CVClusters();
                this.m_rr = new Random(this.m_rseed);
                for (n = 0; n < 10; ++n) {
                    this.m_rr.nextDouble();
                }
            } else {
                this.m_num_clusters = 1;
            }
        }
        this.EM_Init(this.m_theInstances);
        this.m_loglikely = this.iterate(this.m_theInstances, this.m_verbose);
    }

    private double iterate(Instances instances, boolean bl) throws Exception {
        double d = 0.0;
        double d2 = 0.0;
        if (bl) {
            this.EM_Report(instances);
        }
        boolean bl2 = false;
        int n = this.m_rseed;
        int n2 = 0;
        while (!bl2) {
            try {
                for (int i = 0; i < this.m_max_iterations; ++i) {
                    d = d2;
                    d2 = this.E(instances, true);
                    if (bl) {
                        System.out.println("Loglikely: " + d2);
                    }
                    if (i > 0 && d2 - d < 1.0E-6) break;
                    this.M(instances);
                }
                bl2 = true;
            }
            catch (Exception exception) {
                exception.printStackTrace();
                ++n2;
                this.m_rr = new Random(++n);
                for (int i = 0; i < 10; ++i) {
                    this.m_rr.nextDouble();
                    this.m_rr.nextInt();
                }
                if (n2 > 5) {
                    --this.m_num_clusters;
                    n2 = 0;
                }
                this.EM_Init(this.m_theInstances);
            }
        }
        if (bl) {
            this.EM_Report(instances);
        }
        return d2;
    }

    public static void main(String[] stringArray) {
        try {
            System.out.println(ClusterEvaluation.evaluateClusterer(new EM(), stringArray));
        }
        catch (Exception exception) {
            System.out.println(exception.getMessage());
            exception.printStackTrace();
        }
    }
}

