/*
 * Decompiled with CFR 0.152.
 */
package net.myrrix.online.factorizer.als;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import net.myrrix.common.LangUtils;
import net.myrrix.common.collection.FastByIDFloatMap;
import net.myrrix.common.collection.FastByIDMap;
import net.myrrix.common.math.MatrixUtils;
import net.myrrix.common.math.SimpleVectorMath;
import net.myrrix.common.parallel.ExecutorUtils;
import net.myrrix.common.random.RandomManager;
import net.myrrix.common.random.RandomUtils;
import net.myrrix.common.stats.DoubleWeightedMean;
import net.myrrix.common.stats.JVMEnvironment;
import net.myrrix.online.factorizer.MatrixFactorizer;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Pair;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Factors a sparse matrix R into two dense low-rank matrices: X, with one row per row ID of R,
 * and Y, with one row per column ID of R. It does this by alternating least squares: each
 * iteration re-solves every row of X with Y held fixed, then every row of Y with X held fixed.
 *
 * <p>Each row is solved from a weighted normal equation. An observed value r contributes with
 * confidence weight {@code 1 + alpha * |r|}, and only positive values contribute to the
 * right-hand side. The exception is when {@code model.reconstructRMatrix} is set: then the raw
 * values are used directly as the right-hand side.</p>
 *
 * <p>Tuning is read from these system properties:</p>
 * <ul>
 *   <li>{@code model.threads}</li>
 *   <li>{@code model.als.iterate}</li>
 *   <li>{@code model.als.alpha}</li>
 *   <li>{@code model.als.lambda}</li>
 *   <li>{@code model.reconstructRMatrix}</li>
 *   <li>{@code model.lossIgnoresUnspecified}</li>
 * </ul>
 */
public final class AlternatingLeastSquares
implements MatrixFactorizer {
    private static final Logger log = LoggerFactory.getLogger(AlternatingLeastSquares.class);
    public static final double DEFAULT_ALPHA = 1.0;
    public static final double DEFAULT_LAMBDA = 0.1;
    public static final double DEFAULT_CONVERGENCE_THRESHOLD = 0.001;
    public static final int DEFAULT_MAX_ITERATIONS = 30;
    /** Number of R rows handed to one {@link Worker} task. */
    private static final int WORK_UNIT_SIZE = 100;
    /** Approximate number of row IDs and column IDs sampled to measure convergence. */
    private static final int NUM_USER_ITEMS_TO_TEST_CONVERGENCE = 100;
    /** Rows processed between progress log lines. */
    private static final long LOG_INTERVAL = 100000L;
    /** Cap on how many existing Y vectors new random vectors are kept "far from". */
    private static final int MAX_FAR_FROM_VECTORS = 100000;
    private static final boolean RECONSTRUCT_R_MATRIX = Boolean.parseBoolean(System.getProperty("model.reconstructRMatrix", "false"));
    private static final boolean LOSS_IGNORES_UNSPECIFIED = Boolean.parseBoolean(System.getProperty("model.lossIgnoresUnspecified", "false"));
    /** Heap usage percentage above which a low-memory warning is logged. */
    private static final int LOW_MEMORY_PERCENT = 95;
    /** Constructor argument passed to every FastByIDMap created here. */
    private static final float MAP_SIZE_FACTOR = 1.25f;

    private final FastByIDMap<FastByIDFloatMap> RbyRow;
    private final FastByIDMap<FastByIDFloatMap> RbyColumn;
    private final int features;
    private final double estimateErrorConvergenceThreshold;
    private final int maxIterations;
    private FastByIDMap<float[]> X;
    private FastByIDMap<float[]> Y;
    private FastByIDMap<float[]> previousY;

    /**
     * Creates a factorizer with 30 features and the default convergence threshold and
     * iteration limit.
     *
     * @param RbyRow the input matrix, indexed by row ID then column ID
     * @param RbyColumn the same matrix, indexed by column ID then row ID
     */
    public AlternatingLeastSquares(FastByIDMap<FastByIDFloatMap> RbyRow, FastByIDMap<FastByIDFloatMap> RbyColumn) {
        // 30 = default number of features
        this(RbyRow, RbyColumn, 30);
    }

    /**
     * Creates a factorizer with the given feature count and the default convergence threshold
     * and iteration limit.
     *
     * @param RbyRow the input matrix, indexed by row ID then column ID
     * @param RbyColumn the same matrix, indexed by column ID then row ID
     * @param features number of latent features; must be positive
     */
    public AlternatingLeastSquares(FastByIDMap<FastByIDFloatMap> RbyRow, FastByIDMap<FastByIDFloatMap> RbyColumn, int features) {
        this(RbyRow, RbyColumn, features, DEFAULT_CONVERGENCE_THRESHOLD, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a fully configured factorizer.
     *
     * @param RbyRow the input matrix, indexed by row ID then column ID
     * @param RbyColumn the same matrix, indexed by column ID then row ID
     * @param features number of latent features; must be positive
     * @param estimateErrorConvergenceThreshold iteration stops once the weighted average absolute
     *     change in sampled estimates drops below this value; must be in (0,1)
     * @param maxIterations iteration limit; a value {@code <= 0} means no limit
     * @throws NullPointerException if either matrix is null
     * @throws IllegalArgumentException if {@code features} or the threshold is out of range
     */
    public AlternatingLeastSquares(FastByIDMap<FastByIDFloatMap> RbyRow, FastByIDMap<FastByIDFloatMap> RbyColumn, int features, double estimateErrorConvergenceThreshold, int maxIterations) {
        Preconditions.checkNotNull(RbyRow);
        Preconditions.checkNotNull(RbyColumn);
        Preconditions.checkArgument(features > 0, "features must be positive: %s", features);
        Preconditions.checkArgument(estimateErrorConvergenceThreshold > 0.0 && estimateErrorConvergenceThreshold < 1.0, "threshold must be in (0,1): %s", estimateErrorConvergenceThreshold);
        this.RbyRow = RbyRow;
        this.RbyColumn = RbyColumn;
        this.features = features;
        this.estimateErrorConvergenceThreshold = estimateErrorConvergenceThreshold;
        this.maxIterations = maxIterations;
    }

    /** @return the row-factor matrix computed by the last {@link #call()}, or null before it */
    @Override
    public FastByIDMap<float[]> getX() {
        return this.X;
    }

    /** @return the column-factor matrix computed by the last {@link #call()}, or null before it */
    @Override
    public FastByIDMap<float[]> getY() {
        return this.Y;
    }

    /**
     * Ignored: X is always recomputed from Y in {@link #call()}.
     */
    @Override
    public void setPreviousX(FastByIDMap<float[]> previousX) {
        // intentionally a no-op
    }

    /**
     * Supplies a previous generation's Y to start from instead of a random one. If the feature
     * count matches, this map itself becomes the starting Y and is added to and overwritten
     * during factorization.
     */
    @Override
    public void setPreviousY(FastByIDMap<float[]> previousY) {
        this.previousY = previousY;
    }

    /**
     * Runs the factorization, after which {@link #getX()} and {@link #getY()} hold the result.
     *
     * <p>If {@code model.als.iterate} is {@code false}, only X is solved, once, against the
     * initial Y. Otherwise X and Y are re-solved alternately until convergence, the iteration
     * limit, or a non-finite convergence value.</p>
     *
     * @return null
     * @throws ExecutionException if a worker task fails
     * @throws InterruptedException if interrupted while waiting for workers
     */
    @Override
    public Void call() throws ExecutionException, InterruptedException {
        this.X = new FastByIDMap<float[]>(this.RbyRow.size(), MAP_SIZE_FACTOR);
        boolean randomY = this.previousY == null || this.previousY.isEmpty();
        this.Y = this.constructInitialY(this.previousY);
        ExecutorService executor = newWorkerPool();
        // Everything that runs after the pool exists is inside this try, so the pool is always
        // shut down.
        try {
            if (Boolean.parseBoolean(System.getProperty("model.als.iterate", "true"))) {
                this.iterateUntilConverged(executor, randomY);
            } else {
                this.iterateXFromY(executor);
            }
        } finally {
            ExecutorUtils.shutdownNowAndAwait(executor);
        }
        return null;
    }

    /** Creates the daemon worker pool, sized by {@code model.threads} or the CPU count. */
    private static ExecutorService newWorkerPool() {
        String threadsString = System.getProperty("model.threads");
        int numThreads = threadsString == null ? Runtime.getRuntime().availableProcessors() : Integer.parseInt(threadsString);
        Preconditions.checkArgument(numThreads > 0, "model.threads must be positive: %s", numThreads);
        ExecutorService executor = Executors.newFixedThreadPool(numThreads, new ThreadFactoryBuilder().setNameFormat("ALS-%d").setDaemon(true).build());
        log.info("Iterating using {} threads", numThreads);
        return executor;
    }

    /**
     * Alternates X and Y solves until the sampled estimates stop changing, the iteration
     * limit is hit, or the convergence value becomes non-finite.
     */
    private void iterateUntilConverged(ExecutorService executor, boolean randomY) throws ExecutionException, InterruptedException {
        RandomGenerator random = RandomManager.getRandom();
        long[] testUserIDs = RandomUtils.chooseAboutNFromStream(NUM_USER_ITEMS_TO_TEST_CONVERGENCE, this.RbyRow.keySetIterator(), this.RbyRow.size(), random);
        long[] testItemIDs = RandomUtils.chooseAboutNFromStream(NUM_USER_ITEMS_TO_TEST_CONVERGENCE, this.RbyColumn.keySetIterator(), this.RbyColumn.size(), random);
        // X was replaced by an empty map in call(), so there are no prior estimates: all start
        // at 0.
        // NOTE(review): the first iteration's "difference" is therefore measured against zeros
        // even when starting from a previous Y.
        double[][] estimates = new double[testUserIDs.length][testItemIDs.length];

        int iterationNumber = 0;
        while (true) {
            this.iterateXFromY(executor);
            this.iterateYFromX(executor);

            // Weighted mean of |new - old| over the sample, weighted by the (non-negative) new
            // estimate.
            DoubleWeightedMean averageAbsoluteEstimateDiff = new DoubleWeightedMean();
            for (int i = 0; i < testUserIDs.length; ++i) {
                float[] userVector = this.X.get(testUserIDs[i]);
                for (int j = 0; j < testItemIDs.length; ++j) {
                    double newValue = SimpleVectorMath.dot(userVector, this.Y.get(testItemIDs[j]));
                    double oldValue = estimates[i][j];
                    estimates[i][j] = newValue;
                    averageAbsoluteEstimateDiff.increment(FastMath.abs(newValue - oldValue), FastMath.max(0.0, newValue));
                }
            }

            iterationNumber++;
            log.info("Finished iteration {}", iterationNumber);
            if (this.maxIterations > 0 && iterationNumber >= this.maxIterations) {
                log.info("Reached iteration limit");
                return;
            }
            log.info("Avg absolute difference in estimate vs prior iteration: {}", averageAbsoluteEstimateDiff);
            double convergenceValue = averageAbsoluteEstimateDiff.getResult();
            if (!LangUtils.isFinite(convergenceValue)) {
                log.warn("Invalid convergence value, aborting iteration! {}", convergenceValue);
                return;
            }
            // Starting from a random Y, at least two iterations are always run.
            boolean forcedToContinue = randomY && iterationNumber == 1;
            if (!forcedToContinue && convergenceValue < this.estimateErrorConvergenceThreshold) {
                log.info("Converged");
                return;
            }
        }
    }

    /**
     * Builds the starting Y. It is either a fresh map, or previousY adapted to the current
     * feature count. It is then completed with a random unit vector for every column ID of R
     * that has no vector yet.
     */
    private FastByIDMap<float[]> constructInitialY(FastByIDMap<float[]> previousY) {
        RandomGenerator random = RandomManager.getRandom();
        FastByIDMap<float[]> initialY;
        if (previousY == null || previousY.isEmpty()) {
            log.info("Starting from new, random Y matrix");
            initialY = new FastByIDMap<float[]>(this.RbyColumn.size(), MAP_SIZE_FACTOR);
        } else {
            initialY = this.adaptPreviousY(previousY, random);
        }
        this.addMissingYRows(initialY, random);
        log.info("Constructed initial Y");
        return initialY;
    }

    /**
     * Adapts previousY to the current feature count. When the count is unchanged, previousY
     * itself is returned. Otherwise each vector is truncated, or padded with Gaussian noise, to
     * {@code features} entries and then renormalized.
     */
    private FastByIDMap<float[]> adaptPreviousY(FastByIDMap<float[]> previousY, RandomGenerator random) {
        int oldFeatureCount = previousY.entrySet().iterator().next().getValue().length;
        if (oldFeatureCount == this.features) {
            log.info("Starting from previous generation's Y matrix");
            return previousY;
        }
        if (oldFeatureCount > this.features) {
            log.info("Feature count has decreased to {}, projecting down previous generation's Y matrix", this.features);
        } else {
            log.info("Feature count has increased to {}, using previous generation's Y matrix as subspace", this.features);
        }
        FastByIDMap<float[]> resized = new FastByIDMap<float[]>(previousY.size(), MAP_SIZE_FACTOR);
        for (FastByIDMap.MapEntry<float[]> entry : previousY.entrySet()) {
            float[] oldVector = entry.getValue();
            float[] newVector = new float[this.features];
            int kept = FastMath.min(oldVector.length, this.features);
            System.arraycopy(oldVector, 0, newVector, 0, kept);
            for (int i = kept; i < newVector.length; ++i) {
                newVector[i] = (float) random.nextGaussian();
            }
            SimpleVectorMath.normalize(newVector);
            resized.put(entry.getKey(), newVector);
        }
        return resized;
    }

    /**
     * Gives every column ID of R that lacks a vector in initialY a random unit vector. The new
     * vector is chosen far from up to {@link #MAX_FAR_FROM_VECTORS} vectors already present.
     */
    private void addMissingYRows(FastByIDMap<float[]> initialY, RandomGenerator random) {
        List<float[]> recentVectors = Lists.newArrayList();
        for (FastByIDMap.MapEntry<float[]> entry : initialY.entrySet()) {
            if (recentVectors.size() >= MAX_FAR_FROM_VECTORS) {
                break;
            }
            recentVectors.add(entry.getValue());
        }
        LongPrimitiveIterator it = this.RbyColumn.keySetIterator();
        long processed = 0L;
        while (it.hasNext()) {
            long id = it.nextLong();
            if (!initialY.containsKey(id)) {
                float[] vector = RandomUtils.randomUnitVectorFarFrom(this.features, recentVectors, random);
                initialY.put(id, vector);
                if (recentVectors.size() < MAX_FAR_FROM_VECTORS) {
                    recentVectors.add(vector);
                }
            }
            if (++processed % LOG_INTERVAL == 0L) {
                log.info("Computed {} initial Y rows", processed);
            }
        }
    }

    /** Re-solves every row of X against the current Y. */
    private void iterateXFromY(ExecutorService executor) throws ExecutionException, InterruptedException {
        RealMatrix YTY = MatrixUtils.transposeTimesSelf(this.Y);
        List<Future<?>> futures = Lists.newArrayList();
        this.addWorkers(this.RbyRow, this.Y, YTY, this.X, executor, futures);
        awaitWorkers(futures, "{} X/tag rows computed ({}MB heap)");
    }

    /** Re-solves every row of Y against the current X. */
    private void iterateYFromX(ExecutorService executor) throws ExecutionException, InterruptedException {
        RealMatrix XTX = MatrixUtils.transposeTimesSelf(this.X);
        List<Future<?>> futures = Lists.newArrayList();
        this.addWorkers(this.RbyColumn, this.X, XTX, this.Y, executor, futures);
        awaitWorkers(futures, "{} Y/tag rows computed ({}MB heap)");
    }

    /**
     * Waits for every submitted work unit and logs progress about every {@link #LOG_INTERVAL}
     * rows. Progress counts each future as a full {@link #WORK_UNIT_SIZE} rows, so it is
     * approximate.
     *
     * @param progressMessage SLF4J format taking (rows so far, used heap MB)
     */
    private static void awaitWorkers(Iterable<Future<?>> futures, String progressMessage) throws ExecutionException, InterruptedException {
        int count = 0;
        long total = 0L;
        for (Future<?> future : futures) {
            future.get();
            count += WORK_UNIT_SIZE;
            if (count < LOG_INTERVAL) {
                continue;
            }
            total += count;
            count = 0;
            JVMEnvironment env = new JVMEnvironment();
            log.info(progressMessage, total, env.getUsedMemoryMB());
            if (env.getPercentUsedMemory() > LOW_MEMORY_PERCENT) {
                log.warn("Memory is low. Increase heap size with -Xmx, decrease new generation size with larger -XX:NewRatio value, and/or use -XX:+UseCompressedOops");
            }
        }
    }

    /**
     * Splits the rows of R into units of {@link #WORK_UNIT_SIZE} and submits one {@link Worker}
     * per unit. Each worker solves its rows against M (whose Gram matrix is MTM) and writes the
     * results into MTags. The resulting futures are appended to {@code futures}.
     */
    private void addWorkers(FastByIDMap<FastByIDFloatMap> R, FastByIDMap<float[]> M, RealMatrix MTM, FastByIDMap<float[]> MTags, ExecutorService executor, Collection<Future<?>> futures) {
        if (R == null) {
            return;
        }
        List<Pair<Long, FastByIDFloatMap>> workUnit = Lists.newArrayListWithCapacity(WORK_UNIT_SIZE);
        for (FastByIDMap.MapEntry<FastByIDFloatMap> entry : R.entrySet()) {
            workUnit.add(new Pair<Long, FastByIDFloatMap>(entry.getKey(), entry.getValue()));
            if (workUnit.size() == WORK_UNIT_SIZE) {
                futures.add(executor.submit(new Worker(this.features, M, MTM, MTags, workUnit)));
                workUnit = Lists.newArrayListWithCapacity(WORK_UNIT_SIZE);
            }
        }
        if (!workUnit.isEmpty()) {
            futures.add(executor.submit(new Worker(this.features, M, MTM, MTags, workUnit)));
        }
    }

    /**
     * Solves one unit of rows. For each row u it builds the system
     * {@code (YTY + Y^T (Cu - I) Y + lambda*|ru| I) x = Y^T Cu pu} and writes the solution x
     * into X under u's ID.
     */
    private static final class Worker
    implements Callable<Void> {
        private final int features;
        private final FastByIDMap<float[]> Y;
        private final RealMatrix YTY;
        private final FastByIDMap<float[]> X;
        private final List<Pair<Long, FastByIDFloatMap>> workUnit;

        private Worker(int features, FastByIDMap<float[]> Y, RealMatrix YTY, FastByIDMap<float[]> X, List<Pair<Long, FastByIDFloatMap>> workUnit) {
            this.features = features;
            this.Y = Y;
            this.YTY = YTY;
            this.X = X;
            this.workUnit = workUnit;
        }

        @Override
        public Void call() {
            double alpha = getAlpha();
            // Regularization is scaled by alpha as well as by each row's number of entries.
            double lambda = getLambda() * alpha;
            int features = this.features;
            for (Pair<Long, FastByIDFloatMap> work : this.workUnit) {
                FastByIDFloatMap ru = work.getSecond();
                RealMatrix Wu = LOSS_IGNORES_UNSPECIFIED
                    ? partialTransposeTimesSelf(this.Y, this.YTY.getRowDimension(), ru.keySetIterator())
                    : this.YTY.copy();
                double[][] WuData = MatrixUtils.accessMatrixDataDirectly(Wu);
                double[] YTCupu = new double[features];
                for (FastByIDFloatMap.MapEntry entry : ru.entrySet()) {
                    double xu = entry.getValue();
                    float[] vector = this.Y.get(entry.getKey());
                    if (vector == null) {
                        log.warn("No vector for {}. This should not happen. Continuing...", entry.getKey());
                        continue;
                    }
                    if (RECONSTRUCT_R_MATRIX) {
                        for (int row = 0; row < features; ++row) {
                            YTCupu[row] += xu * vector[row];
                        }
                    } else {
                        double cu = 1.0 + alpha * FastMath.abs(xu);
                        for (int row = 0; row < features; ++row) {
                            float vectorAtRow = vector[row];
                            double rowValue = vectorAtRow * (cu - 1.0);
                            double[] WuDataRow = WuData[row];
                            for (int col = 0; col < features; ++col) {
                                WuDataRow[col] += rowValue * vector[col];
                            }
                            // Only positive observations count as "preferred" in the RHS.
                            if (xu > 0.0) {
                                YTCupu[row] += vectorAtRow * cu;
                            }
                        }
                    }
                }
                double lambdaTimesCount = lambda * ru.size();
                for (int i = 0; i < features; ++i) {
                    WuData[i][i] += lambdaTimesCount;
                }
                float[] solution = MatrixUtils.getSolver(Wu).solveDToF(YTCupu);
                // X is shared by all workers of this pass.
                synchronized (this.X) {
                    this.X.put(work.getFirst(), solution);
                }
            }
            return null;
        }

        /** @return {@code model.als.alpha}, or {@link #DEFAULT_ALPHA} when unset */
        private static double getAlpha() {
            String alphaProperty = System.getProperty("model.als.alpha");
            return alphaProperty == null ? DEFAULT_ALPHA : LangUtils.parseDouble(alphaProperty);
        }

        /** @return {@code model.als.lambda}, or {@link #DEFAULT_LAMBDA} when unset */
        private static double getLambda() {
            String lambdaProperty = System.getProperty("model.als.lambda");
            return lambdaProperty == null ? DEFAULT_LAMBDA : LangUtils.parseDouble(lambdaProperty);
        }

        /**
         * Computes M'^T M', where M' contains only the rows of M whose IDs come from keys. IDs
         * with no vector in M are skipped; call() also skips them, and logs a warning for each.
         */
        private static RealMatrix partialTransposeTimesSelf(FastByIDMap<float[]> M, int dimension, LongPrimitiveIterator keys) {
            Array2DRowRealMatrix result = new Array2DRowRealMatrix(dimension, dimension);
            double[][] resultData = result.getDataRef();
            while (keys.hasNext()) {
                float[] vector = M.get(keys.nextLong());
                if (vector == null) {
                    continue;
                }
                for (int row = 0; row < dimension; ++row) {
                    float rowValue = vector[row];
                    double[] resultRow = resultData[row];
                    for (int col = 0; col < dimension; ++col) {
                        // float product widened on accumulate, matching addToEntry's arithmetic
                        resultRow[col] += rowValue * vector[col];
                    }
                }
            }
            return result;
        }
    }
}

