/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.types.FeatureVector;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.CommandOption;
import cc.mallet.util.FeatureCountTool;
import cc.mallet.util.Randoms;
import com.google.errorprone.annotations.Var;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Arrays;

/**
 * Non-negative matrix factorization over the {@code FeatureVector} data of an
 * {@link InstanceList}.
 *
 * <p>The (optionally IDF-weighted) instance-by-feature values are approximated by the product
 * of two non-negative weight matrices: {@link #instanceFactorWeights} (instances x factors) and
 * {@link #featureFactorWeights} (features x factors). {@link #updateWeights()} performs one
 * round of multiplicative updates and {@link #getDivergence()} reports the current
 * reconstruction error. NOTE(review): the update and error formulas match the Lee-Seung
 * multiplicative rules for the generalized KL divergence; confirm against the intended
 * reference before relying on that name.
 *
 * <p>Every instance's data is cast to {@code FeatureVector}; other data types will fail with a
 * {@code ClassCastException}.
 */
public class NonNegativeMatrixFactorization {
    static CommandOption.String inputFile = new CommandOption.String(NonNegativeMatrixFactorization.class, "input", "FILENAME", true, null, "The filename from which to read the list of training instances.  Use - for stdin.  The instances must be FeatureVectors, not FeatureSequences", null);
    static CommandOption.String outputWordsFile = new CommandOption.String(NonNegativeMatrixFactorization.class, "output-words", "FILENAME", true, "word-weights.txt", "The filename to write weights for each word.", null);
    static CommandOption.String outputDocsFile = new CommandOption.String(NonNegativeMatrixFactorization.class, "output-docs", "FILENAME", true, "doc-weights.txt", "The filename to write weights for each document.", null);
    static CommandOption.Integer numDimensions = new CommandOption.Integer(NonNegativeMatrixFactorization.class, "num-dimensions", "INTEGER", true, 50, "The number of dimensions to fit.", null);
    static CommandOption.Integer clusterSize = new CommandOption.Integer(NonNegativeMatrixFactorization.class, "init-cluster-size", "INTEGER", true, 0, "Select this number of random instances to initialize each dimension. 0 = off.", null);
    static CommandOption.Boolean useIDFOption = new CommandOption.Boolean(NonNegativeMatrixFactorization.class, "use-idf", "TRUE/FALSE", true, true, "Whether to use IDF weighting.", null);
    static CommandOption.Integer numIterationsOption = new CommandOption.Integer(NonNegativeMatrixFactorization.class, "num-iters", "INTEGER", true, 1000, "The number of passes through the training data.", null);

    /** Training data; each instance's data must be a {@code FeatureVector}. */
    InstanceList instances;
    /** Number of latent factors (columns of both weight matrices). */
    int numFactors;
    /** Size of the data alphabet. */
    int numFeatures;
    /** Number of training instances. */
    int numInstances;
    /** Not read or written anywhere in this class. */
    int numIterations;
    /** Whether feature values are multiplied by {@link #featureWeights} before use. */
    boolean idfWeighting;
    /** Per-feature IDF weights; {@code null} until {@link #calculateIDFWeights()} has run. */
    double[] featureWeights = null;
    /** Feature-by-factor weight matrix. */
    double[][] featureFactorWeights;
    /** Instance-by-factor weight matrix. */
    double[][] instanceFactorWeights;
    /** Column sums of {@link #featureFactorWeights}, one entry per factor. */
    double[] featureSums;
    /** Column sums of {@link #instanceFactorWeights}, one entry per factor. */
    double[] instanceSums;
    /** Random source used for initialization. */
    Randoms random;

    /** Glyphs used to draw one-character histograms, from empty to full block. */
    public static final String[] BARS = new String[]{" ", "\u2581", "\u2582", "\u2583", "\u2584", "\u2585", "\u2586", "\u2587", "\u2588"};

    /**
     * Creates a factorization using a freshly constructed {@link Randoms}.
     *
     * @param instances    training instances whose data are {@code FeatureVector}s
     * @param numFactors   number of factors to fit
     * @param idfWeighting whether to scale feature values by IDF weights
     */
    public NonNegativeMatrixFactorization(InstanceList instances, int numFactors, boolean idfWeighting) {
        this(instances, numFactors, idfWeighting, new Randoms());
    }

    /**
     * Creates a factorization and initializes both weight matrices.
     *
     * <p>Feature weights start as small random values
     * ({@code 0.001 * uniform / numFeatures}); instance weights start uniform at
     * {@code 1 / numFactors}. The per-factor column sums are accumulated alongside.
     *
     * @param instances    training instances whose data are {@code FeatureVector}s
     * @param numFactors   number of factors to fit
     * @param idfWeighting whether to scale feature values by IDF weights
     * @param random       random source for the initial feature weights
     */
    public NonNegativeMatrixFactorization(InstanceList instances, int numFactors, boolean idfWeighting, Randoms random) {
        this.instances = instances;
        this.numFactors = numFactors;
        this.idfWeighting = idfWeighting;
        this.random = random;
        this.numFeatures = instances.getDataAlphabet().size();
        this.numInstances = instances.size();
        this.featureFactorWeights = new double[this.numFeatures][numFactors];
        this.instanceFactorWeights = new double[this.numInstances][numFactors];
        this.featureSums = new double[numFactors];
        this.instanceSums = new double[numFactors];
        if (idfWeighting) {
            this.calculateIDFWeights();
        }

        for (int feature = 0; feature < this.numFeatures; ++feature) {
            double[] row = this.featureFactorWeights[feature];
            for (int factor = 0; factor < numFactors; ++factor) {
                row[factor] = 0.001 * random.nextUniform() / (double) this.numFeatures;
                this.featureSums[factor] += row[factor];
            }
        }

        for (int instance = 0; instance < this.numInstances; ++instance) {
            double[] row = this.instanceFactorWeights[instance];
            for (int factor = 0; factor < numFactors; ++factor) {
                row[factor] = 1.0 / (double) numFactors;
                this.instanceSums[factor] += row[factor];
            }
        }
    }

    /**
     * Enables IDF weighting and computes {@code log(numInstances / documentFrequency)} for every
     * feature. Features with a document frequency of zero keep a weight of 0.
     */
    public void calculateIDFWeights() {
        this.idfWeighting = true;
        System.out.println("Counting word features");
        FeatureCountTool counter = new FeatureCountTool(this.instances);
        counter.count();
        int[] documentFrequencies = counter.getDocumentFrequencies();
        this.featureWeights = new double[this.numFeatures];
        for (int feature = 0; feature < this.numFeatures; ++feature) {
            if (documentFrequencies[feature] > 0) {
                this.featureWeights[feature] = Math.log((double) this.numInstances / (double) documentFrequencies[feature]);
            }
        }
    }

    /**
     * Seeds each factor with the average of {@code clusterSize} randomly chosen instances
     * (sampled with replacement), adding to the existing feature weights and feature sums.
     *
     * <p>Does nothing when there are no instances to sample from.
     *
     * @param clusterSize number of random instances mixed into each factor
     */
    public void initialize(int clusterSize) {
        if (this.numInstances == 0) {
            // Nothing to sample; avoids asking the random source for an index below a zero bound.
            return;
        }
        for (int factor = 0; factor < this.numFactors; ++factor) {
            for (int sample = 0; sample < clusterSize; ++sample) {
                FeatureVector data = (FeatureVector) ((Instance) this.instances.get(this.random.nextInt(this.numInstances))).getData();
                for (int location = 0; location < data.numLocations(); ++location) {
                    int feature = data.indexAtLocation(location);
                    double increment = this.scaledValue(data, location, feature) / (double) clusterSize;
                    this.featureFactorWeights[feature][factor] += increment;
                    this.featureSums[factor] += increment;
                }
            }
        }
    }

    /**
     * Maps a value to one of the {@link #BARS} glyphs.
     *
     * @param x   value to display; clamped into {@code [min, max]}
     * @param min value drawn as the empty glyph
     * @param max value drawn as the full glyph
     * @return the glyph for {@code x}; the empty glyph when the range is empty or invalid
     */
    public static String getBar(@Var double x, double min, double max) {
        if (!(max > min)) {
            // Empty, inverted or NaN range: the scaled position is 0/0 or -0, which rounds to 0.
            return BARS[0];
        }
        if (x > max) {
            x = max;
        }
        if (x < min) {
            x = min;
        }
        double top = BARS.length - 1;
        return BARS[(int) Math.round(top * (x - min) / (max - min))];
    }

    /**
     * Renders a sequence as a string of bar glyphs on a fixed scale.
     *
     * @param sequence values to draw
     * @param min      value drawn as the empty glyph
     * @param max      value drawn as the full glyph
     * @return one glyph per element of {@code sequence}
     */
    public static String getBars(double[] sequence, double min, double max) {
        StringBuilder out = new StringBuilder();
        for (double x : sequence) {
            out.append(getBar(x, min, max));
        }
        return out.toString();
    }

    /**
     * Renders a sequence as a string of bar glyphs, scaled from 0 to the largest element.
     *
     * @param sequence values to draw
     * @return one glyph per element of {@code sequence}
     */
    public static String getBars(double[] sequence) {
        double max = Double.NEGATIVE_INFINITY;
        for (double x : sequence) {
            if (x > max) {
                max = x;
            }
        }
        return getBars(sequence, 0.0, max);
    }

    /**
     * Computes the reconstruction error over the stored (sparse) locations of every instance.
     *
     * <p>Each location contributes {@code v * log(v / p) - v + p}, where {@code v} is the
     * (optionally IDF-weighted) value and {@code p} the inner product of the instance and
     * feature factor weights. Locations with {@code p == 0} are skipped. A value of exactly 0
     * contributes {@code p}, the limit of the expression, rather than {@code 0 * log(0)} = NaN.
     *
     * @return the summed divergence; smaller means a closer fit
     */
    public double getDivergence() {
        double divergence = 0.0;
        for (int instance = 0; instance < this.numInstances; ++instance) {
            FeatureVector data = (FeatureVector) ((Instance) this.instances.get(instance)).getData();
            double[] currentInstanceFactorWeights = this.instanceFactorWeights[instance];
            for (int location = 0; location < data.numLocations(); ++location) {
                int feature = data.indexAtLocation(location);
                double value = this.scaledValue(data, location, feature);
                double innerProduct = this.innerProduct(currentInstanceFactorWeights, this.featureFactorWeights[feature]);
                if (innerProduct == 0.0) {
                    continue;
                }
                if (value == 0.0) {
                    // v*log(v/p) -> 0 as v -> 0; evaluating it directly would give NaN.
                    divergence += innerProduct;
                } else {
                    divergence += value * Math.log(value / innerProduct) - value + innerProduct;
                }
            }
        }
        return divergence;
    }

    /**
     * Runs one training iteration: updates the instance weights against the current feature
     * weights, then the feature weights against the new instance weights. Both sets of
     * per-factor sums are refreshed along the way.
     */
    public void updateWeights() {
        this.updateInstanceFactorWeights();
        this.updateFeatureFactorWeights();
    }

    /** Multiplicatively updates every instance's factor weights, then recomputes instanceSums. */
    private void updateInstanceFactorWeights() {
        for (int instance = 0; instance < this.numInstances; ++instance) {
            FeatureVector data = (FeatureVector) ((Instance) this.instances.get(instance)).getData();
            double[] currentInstanceFactorWeights = this.instanceFactorWeights[instance];
            double[] updateRatios = new double[this.numFactors];
            double valueSum = 0.0;

            for (int location = 0; location < data.numLocations(); ++location) {
                int feature = data.indexAtLocation(location);
                double value = this.scaledValue(data, location, feature);
                valueSum += value;
                double[] currentFeatureFactorWeights = this.featureFactorWeights[feature];
                double innerProduct = this.innerProduct(currentInstanceFactorWeights, currentFeatureFactorWeights);
                if (innerProduct == 0.0) {
                    continue;
                }
                double ratio = value / innerProduct;
                for (int factor = 0; factor < this.numFactors; ++factor) {
                    updateRatios[factor] += currentFeatureFactorWeights[factor] * ratio;
                }
            }

            if (valueSum > 0.0) {
                for (int factor = 0; factor < this.numFactors; ++factor) {
                    if (this.featureSums[factor] == 0.0) {
                        // A factor with no feature mass would give 0/0; it cannot explain anything.
                        currentInstanceFactorWeights[factor] = 0.0;
                    } else {
                        currentInstanceFactorWeights[factor] *= updateRatios[factor] / this.featureSums[factor];
                    }
                    assert (!Double.isNaN(currentInstanceFactorWeights[factor]));
                }
            } else {
                // Instances with no positive total value get zero weight on every factor.
                Arrays.fill(currentInstanceFactorWeights, 0.0);
            }
        }

        Arrays.fill(this.instanceSums, 0.0);
        for (int instance = 0; instance < this.numInstances; ++instance) {
            double[] row = this.instanceFactorWeights[instance];
            for (int factor = 0; factor < this.numFactors; ++factor) {
                this.instanceSums[factor] += row[factor];
            }
        }
    }

    /** Multiplicatively updates every feature's factor weights, then recomputes featureSums. */
    private void updateFeatureFactorWeights() {
        double[][] featureFactorUpdateRatios = new double[this.numFeatures][this.numFactors];

        for (int instance = 0; instance < this.numInstances; ++instance) {
            FeatureVector data = (FeatureVector) ((Instance) this.instances.get(instance)).getData();
            double[] currentInstanceFactorWeights = this.instanceFactorWeights[instance];
            for (int location = 0; location < data.numLocations(); ++location) {
                int feature = data.indexAtLocation(location);
                double value = this.scaledValue(data, location, feature);
                if (value == 0.0) {
                    continue;
                }
                double[] currentFeatureFactorWeights = this.featureFactorWeights[feature];
                double innerProduct = 0.0;
                for (int factor = 0; factor < this.numFactors; ++factor) {
                    assert (currentInstanceFactorWeights[factor] >= 0.0);
                    assert (currentFeatureFactorWeights[factor] >= 0.0);
                    innerProduct += currentInstanceFactorWeights[factor] * currentFeatureFactorWeights[factor];
                }
                if (innerProduct == 0.0) {
                    // Same guard as the instance pass: value / 0 would turn 0 * Infinity into NaN.
                    continue;
                }
                double ratio = value / innerProduct;
                double[] ratios = featureFactorUpdateRatios[feature];
                for (int factor = 0; factor < this.numFactors; ++factor) {
                    assert (!Double.isNaN(currentInstanceFactorWeights[factor]));
                    assert (!Double.isNaN(ratio)) : value + " / " + innerProduct;
                    ratios[factor] += currentInstanceFactorWeights[factor] * ratio;
                }
            }
        }

        for (int feature = 0; feature < this.numFeatures; ++feature) {
            double[] currentFeatureFactorWeights = this.featureFactorWeights[feature];
            for (int factor = 0; factor < this.numFactors; ++factor) {
                assert (!Double.isNaN(featureFactorUpdateRatios[feature][factor]));
                assert (!Double.isNaN(this.instanceSums[factor])) : this.instanceSums[factor];
                if (this.instanceSums[factor] == 0.0) {
                    // No instance uses this factor; avoid 0/0 and retire the factor.
                    currentFeatureFactorWeights[factor] = 0.0;
                } else {
                    currentFeatureFactorWeights[factor] *= featureFactorUpdateRatios[feature][factor] / this.instanceSums[factor];
                }
                assert (!Double.isNaN(currentFeatureFactorWeights[factor]));
            }
        }

        Arrays.fill(this.featureSums, 0.0);
        for (int feature = 0; feature < this.numFeatures; ++feature) {
            double[] row = this.featureFactorWeights[feature];
            for (int factor = 0; factor < this.numFactors; ++factor) {
                this.featureSums[factor] += row[factor];
            }
        }
    }

    /** Returns the value stored at a location, multiplied by the feature's IDF weight if enabled. */
    private double scaledValue(FeatureVector data, int location, int feature) {
        double value = data.valueAtLocation(location);
        if (this.idfWeighting) {
            value *= this.featureWeights[feature];
        }
        return value;
    }

    /** Returns the dot product of two factor-weight rows over the first numFactors entries. */
    private double innerProduct(double[] left, double[] right) {
        double sum = 0.0;
        for (int factor = 0; factor < this.numFactors; ++factor) {
            sum += left[factor] * right[factor];
        }
        return sum;
    }

    /**
     * Prints, for every factor, the first {@code limit} features in {@code IDSorter} order
     * (presumably highest weight first; confirm against {@code IDSorter.compareTo}).
     *
     * @param limit maximum number of features per factor; capped at the alphabet size
     */
    public void printFactorFeatures(int limit) {
        IDSorter[] sortedIDs = new IDSorter[this.numFeatures];
        // The alphabet may hold fewer than `limit` features; never index past the array.
        int shown = Math.min(limit, this.numFeatures);
        StringBuilder output = new StringBuilder();
        for (int factor = 0; factor < this.numFactors; ++factor) {
            for (int feature = 0; feature < this.numFeatures; ++feature) {
                sortedIDs[feature] = new IDSorter(feature, this.featureFactorWeights[feature][factor]);
            }
            Arrays.sort(sortedIDs);
            output.append(factor).append("\t");
            for (int i = 0; i < shown; ++i) {
                output.append(this.instances.getDataAlphabet().lookupObject(sortedIDs[i].getID())).append(" ");
            }
            output.append("\n");
        }
        System.out.println(output);
    }

    /**
     * Writes one line per feature: the feature's alphabet entry followed by its tab-separated
     * factor weights.
     *
     * @param out destination writer; not closed by this method
     * @throws IOException declared for callers; the visible body does not throw it itself
     */
    public void writeFeatureFactors(PrintWriter out) throws IOException {
        for (int feature = 0; feature < this.numFeatures; ++feature) {
            out.print(this.instances.getDataAlphabet().lookupObject(feature));
            for (double weight : this.featureFactorWeights[feature]) {
                out.format("\t%f", weight);
            }
            out.println();
        }
    }

    /**
     * Writes one line per instance: the instance name followed by its tab-separated factor
     * weights.
     *
     * @param out destination writer; not closed by this method
     * @throws IOException declared for callers; the visible body does not throw it itself
     */
    public void writeInstanceFactors(PrintWriter out) throws IOException {
        for (int instance = 0; instance < this.numInstances; ++instance) {
            out.print(((Instance) this.instances.get(instance)).getName());
            for (double weight : this.instanceFactorWeights[instance]) {
                out.format("\t%f", weight);
            }
            out.println();
        }
    }

    /**
     * Command-line entry point: loads instances, trains, and writes both weight files.
     *
     * <p>Every 100 iterations the top 15 features per factor are printed. Every 10 iterations
     * the factor-size histograms and divergence are printed, and training stops early once the
     * divergence fails to drop below 99.99% of the previous checkpoint.
     *
     * @param args command-line options handled by {@code CommandOption}
     * @throws Exception if loading the instances or writing the output fails
     */
    public static void main(String[] args) throws Exception {
        CommandOption.setSummary(NonNegativeMatrixFactorization.class, "Train non-negative matrix factorization.");
        CommandOption.process(NonNegativeMatrixFactorization.class, args);
        InstanceList instances = InstanceList.load(new File(inputFile.value));
        NonNegativeMatrixFactorization nmf = new NonNegativeMatrixFactorization(instances, numDimensions.value, useIDFOption.value);
        if (clusterSize.value > 0) {
            nmf.initialize(clusterSize.value);
        }
        System.out.println("Finding " + numDimensions.value + " factors.");
        System.out.println("Histograms show relative factor sizes, the number measures factorization error (smaller is better).");

        double previousDivergence = Double.POSITIVE_INFINITY;
        for (int iteration = 1; iteration <= numIterationsOption.value; ++iteration) {
            nmf.updateWeights();
            if (iteration % 100 == 0) {
                nmf.printFactorFeatures(15);
            }
            if (iteration % 10 == 0) {
                double divergence = nmf.getDivergence();
                System.out.println(getBars(nmf.featureSums) + "\t" + getBars(nmf.instanceSums) + "\t" + divergence);
                if (divergence / previousDivergence > 0.9999) {
                    break;
                }
                previousDivergence = divergence;
            }
        }

        if (outputWordsFile.value != null) {
            System.out.println("Writing to " + outputWordsFile.value);
            // try-with-resources closes the file even if writing fails part-way.
            try (PrintWriter out = new PrintWriter(new File(outputWordsFile.value))) {
                nmf.writeFeatureFactors(out);
            }
        }
        if (outputDocsFile.value != null) {
            System.out.println("Writing to " + outputDocsFile.value);
            try (PrintWriter out = new PrintWriter(new File(outputDocsFile.value))) {
                nmf.writeInstanceFactors(out);
            }
        }
    }
}

