/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.learning.parametric.bayesian;

import eu.amidst.core.datastream.Attribute;
import eu.amidst.core.datastream.DataInstance;
import eu.amidst.core.datastream.DataOnMemory;
import eu.amidst.core.datastream.DataStream;
import eu.amidst.core.distribution.UnivariateDistribution;
import eu.amidst.core.exponentialfamily.EF_LearningBayesianNetwork;
import eu.amidst.core.exponentialfamily.EF_UnivariateDistribution;
import eu.amidst.core.learning.parametric.bayesian.BayesianParameterLearningAlgorithm;
import eu.amidst.core.learning.parametric.bayesian.utils.DataPosterior;
import eu.amidst.core.learning.parametric.bayesian.utils.DataPosteriorAssignment;
import eu.amidst.core.learning.parametric.bayesian.utils.PlateuIIDReplication;
import eu.amidst.core.learning.parametric.bayesian.utils.PlateuStructure;
import eu.amidst.core.learning.parametric.bayesian.utils.TransitionMethod;
import eu.amidst.core.models.BayesianNetwork;
import eu.amidst.core.models.DAG;
import eu.amidst.core.utils.CompoundVector;
import eu.amidst.core.utils.Serialization;
import eu.amidst.core.variables.HashMapAssignment;
import eu.amidst.core.variables.Variable;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Bayesian parameter learning that processes data in batches of {@code windowsSize} instances,
 * running variational message passing (VMP) on a {@link PlateuStructure} (by default a
 * {@link PlateuIIDReplication}) that replicates the model {@code windowsSize} times.
 *
 * <p>Two update modes are supported:
 * <ul>
 *   <li>sequential (default): after inference on each batch, the plateau posterior becomes the new
 *       prior, so knowledge accumulates batch after batch;</li>
 *   <li>non-sequential ({@link #setNonSequentialModel(boolean)}): the prior is left untouched and
 *       each batch contributes {@code posterior - prior}; these contributions are summed into
 *       {@link #getNaturalParameterPosterior()}.</li>
 * </ul>
 *
 * <p>No synchronization is performed: every operation mutates {@code plateuStructure} in place.
 */
public class SVB implements BayesianParameterLearningAlgorithm, Serializable {
    private static final long serialVersionUID = 4107783324901370839L;

    /** Optional transition applied at {@link #initLearning()} and after every {@link #updateModel}. */
    TransitionMethod transitionMethod = null;
    /** Learning network taken from the plateau; refreshed after prior updates and transitions. */
    protected EF_LearningBayesianNetwork ef_extendedBN;
    /** Replicated model on which VMP inference is run. */
    protected PlateuStructure plateuStructure = new PlateuIIDReplication();
    /** Structure of the model being learnt. */
    protected DAG dag;
    /** Data consumed by {@link #runLearning()}; not serialized. */
    transient DataStream<DataInstance> dataStream;
    /** Sum of the per-batch log-probabilities of evidence from the last {@link #runLearning()}. */
    double elbo;
    /** When true, batches are combined additively instead of chaining posterior -> prior. */
    boolean nonSequentialModel = false;
    /** When true, the plateau Q distributions are reset before each {@link #updateModel} call. */
    boolean randomRestart = false;
    /** Batch size, also used as the number of plateau repetitions. */
    int windowsSize = 100;
    /** Seed forwarded to the plateau structure in {@link #initLearning()}. */
    int seed = 0;
    /** Number of batches processed since the last {@link #initLearning()}. */
    int nBatches = 0;
    /** Total VMP iterations since the last {@link #initLearning()}. */
    int nIterTotal = 0;
    /** Prior vector cached by {@link #updateNaturalParameterPrior}; not read within this class. */
    CompoundVector naturalVectorPrior = null;
    /** Accumulated non-sequential batch contributions; lazily initialised to the prior. */
    BatchOutput naturalVectorPosterior = null;
    /** Whether VMP output and ELBO testing are enabled. */
    private boolean activateOutput = false;

    @Override
    public int getWindowsSize() {
        return this.windowsSize;
    }

    /**
     * Enables or disables resetting the plateau Q distributions before each model update.
     *
     * @param randomRestart {@code true} to call {@code resetQs()} before every batch update
     */
    public void setRandomRestart(boolean randomRestart) {
        this.randomRestart = randomRestart;
    }

    /** @return the plateau structure used for inference */
    public PlateuStructure getPlateuStructure() {
        return this.plateuStructure;
    }

    @Override
    public void setPlateuStructure(PlateuStructure plateuStructure) {
        this.plateuStructure = plateuStructure;
    }

    /** @return the seed forwarded to the plateau structure */
    public int getSeed() {
        return this.seed;
    }

    @Override
    public void setSeed(int seed) {
        this.seed = seed;
    }

    /** @return the summed per-batch log-probability of evidence from the last {@link #runLearning()} */
    @Override
    public double getLogMarginalProbability() {
        return this.elbo;
    }

    @Override
    public void setWindowsSize(int windowsSize) {
        this.windowsSize = windowsSize;
    }

    /**
     * Sets the transition applied to the learning network between batches.
     *
     * @param transitionMethod the transition, or {@code null} for none
     */
    public void setTransitionMethod(TransitionMethod transitionMethod) {
        this.transitionMethod = transitionMethod;
    }

    /**
     * Returns the configured transition method cast to the caller's expected subtype.
     *
     * @param <E> expected concrete type; a wrong choice fails with a ClassCastException at the call site
     * @return the transition method, or {@code null} if none is set
     */
    @SuppressWarnings("unchecked")
    public <E extends TransitionMethod> E getTransitionMethod() {
        return (E) this.transitionMethod;
    }

    /** @return the current natural-parameter prior of the plateau structure */
    public CompoundVector getNaturalParameterPrior() {
        return this.computeNaturalParameterVectorPrior();
    }

    /**
     * Returns the accumulated non-sequential posterior, initialising it to the current prior
     * (with zero ELBO) on first access.
     *
     * @return the accumulated batch output (shared, mutable instance)
     */
    protected BatchOutput getNaturalParameterPosterior() {
        if (this.naturalVectorPosterior == null) {
            this.naturalVectorPosterior = new BatchOutput(this.computeNaturalParameterVectorPrior(), 0.0);
        }
        return this.naturalVectorPosterior;
    }

    /** @return the plateau's natural-parameter prior vector */
    protected CompoundVector computeNaturalParameterVectorPrior() {
        return this.getPlateuStructure().getPlateauNaturalParameterPrior();
    }

    /**
     * Initialises learning and processes the whole data stream in batches of {@code windowsSize},
     * storing the summed per-batch ELBO.
     *
     * <p>NOTE(review): in non-sequential mode batches go straight to {@code updateModelParallel},
     * so neither random restarts nor the transition method are applied here — confirm intended.
     */
    @Override
    public void runLearning() {
        this.initLearning();
        if (this.nonSequentialModel) {
            this.elbo = this.dataStream.streamOfBatches(this.windowsSize).mapToDouble(this::updateModelParallel).sum();
        } else {
            this.elbo = this.dataStream.streamOfBatches(this.windowsSize).mapToDouble(this::updateModel).sum();
        }
    }

    /**
     * Selects between sequential (posterior becomes prior) and non-sequential (additive) updates.
     *
     * @param nonSequentialModel_ {@code true} for non-sequential updates
     */
    public void setNonSequentialModel(boolean nonSequentialModel_) {
        this.nonSequentialModel = nonSequentialModel_;
    }

    @Override
    public void setOutput(boolean activateOutput) {
        this.activateOutput = activateOutput;
        this.getPlateuStructure().getVMP().setOutput(activateOutput);
        this.getPlateuStructure().getVMP().setTestELBO(activateOutput);
    }

    /**
     * Scores a batch against the current parameter posterior.
     *
     * <p>The non-replicated nodes (presumably the global parameter nodes) are deactivated for
     * the duration of the inference and always reactivated afterwards. Like any batch update,
     * this increments the batch and iteration counters.
     *
     * @param batch the data to score
     * @return the log-probability of evidence (ELBO) of the batch
     */
    @Override
    public double predictedLogLikelihood(DataOnMemory<DataInstance> batch) {
        this.getPlateuStructure().getNonReplictedNodes().forEach(node -> node.setActive(false));
        try {
            return this.updateModelOnBatchParallel(batch).getElbo();
        } finally {
            this.getPlateuStructure().getNonReplictedNodes().forEach(node -> node.setActive(true));
        }
    }

    /**
     * Updates the model with one batch using the configured mode, optionally resetting the Q
     * distributions first, and then applies the transition method (if any).
     *
     * @param batch the data to learn from
     * @return the log-probability of evidence of the batch
     */
    @Override
    public double updateModel(DataOnMemory<DataInstance> batch) {
        if (this.randomRestart) {
            this.getPlateuStructure().resetQs();
        }
        double elboBatch = this.nonSequentialModel
                ? this.updateModelParallel(batch)
                : this.updateModelSequential(batch);
        this.applyTransition();
        return elboBatch;
    }

    /** Applies the transition method to the learning network, if one is configured. */
    public void applyTransition() {
        if (this.transitionMethod != null) {
            this.ef_extendedBN = this.transitionMethod.transitionModel(this.ef_extendedBN, this.plateuStructure);
        }
    }

    /** Runs inference on a batch and makes the resulting posterior the new prior. */
    private double updateModelSequential(DataOnMemory<DataInstance> batch) {
        ++this.nBatches;
        this.plateuStructure.setEvidence(batch.getList());
        this.plateuStructure.runInference();
        this.nIterTotal += this.plateuStructure.getVMP().getNumberOfIterations();
        this.updateNaturalParameterPrior(this.plateuStructure.getPlateauNaturalParameterPosterior());
        return this.plateuStructure.getLogProbabilityOfEvidence();
    }

    /**
     * Runs inference on a batch without touching the prior and returns the batch's contribution.
     *
     * @param batch the data to process
     * @return {@code posterior - prior} for this batch together with its log-probability of evidence
     */
    public BatchOutput updateModelOnBatchParallel(DataOnMemory<DataInstance> batch) {
        ++this.nBatches;
        this.plateuStructure.setEvidence(batch.getList());
        this.plateuStructure.runInference();
        this.nIterTotal += this.plateuStructure.getVMP().getNumberOfIterations();
        CompoundVector compoundVectorEnd = this.plateuStructure.getPlateauNaturalParameterPosterior();
        compoundVectorEnd.substract(this.getNaturalParameterPrior());
        return new BatchOutput(compoundVectorEnd, this.plateuStructure.getLogProbabilityOfEvidence());
    }

    /**
     * Computes per-instance posteriors of all latent variables, i.e. the DAG variables that are
     * not bound to a data attribute.
     *
     * @param batch the data; must contain a seq_id attribute
     * @return one {@link DataPosterior} per data instance
     * @throws IllegalArgumentException if the batch has no seq_id attribute
     */
    @Override
    public List<DataPosterior> computePosterior(DataOnMemory<DataInstance> batch) {
        List<Variable> latentVariables = this.dag.getVariables().getListOfVariables().stream()
                .filter(variable -> variable.getAttribute() == null)
                .collect(Collectors.toList());
        return this.computePosterior(batch, latentVariables);
    }

    /**
     * Computes, for each instance in the batch, a copy of the posterior of every given variable,
     * with the parameter nodes held fixed during inference.
     *
     * @param batch the data; must contain a seq_id attribute
     * @param latentVariables the variables whose posteriors are returned
     * @return one {@link DataPosterior} per data instance, keyed by its seq_id
     * @throws IllegalArgumentException if the batch has no seq_id attribute
     */
    @Override
    public List<DataPosterior> computePosterior(DataOnMemory<DataInstance> batch, List<Variable> latentVariables) {
        Attribute seqId = requireSeqId(batch);
        this.runInferenceWithFixedParameters(batch);
        List<DataPosterior> posteriors = new ArrayList<DataPosterior>();
        for (int i = 0; i < batch.getNumberOfDataInstances(); ++i) {
            List<UnivariateDistribution> posteriorsQ = new ArrayList<UnivariateDistribution>();
            for (Variable latentVariable : latentVariables) {
                posteriorsQ.add(copyPosterior(this.plateuStructure.getEFVariablePosterior(latentVariable, i)));
            }
            posteriors.add(new DataPosterior((int) batch.getDataInstance(i).getValue(seqId), posteriorsQ));
        }
        return posteriors;
    }

    /**
     * Like {@link #computePosterior(DataOnMemory, List)}, but variables without a plateau
     * posterior (the plateau returns {@code null} for them) are recorded in an assignment with
     * their value from the data instance instead.
     *
     * @param batch the data; must contain a seq_id attribute
     * @param variables the variables to report
     * @return one {@link DataPosteriorAssignment} per data instance
     * @throws IllegalArgumentException if the batch has no seq_id attribute
     */
    public List<DataPosteriorAssignment> computePosteriorAssignment(DataOnMemory<DataInstance> batch, List<Variable> variables) {
        Attribute seqId = requireSeqId(batch);
        this.runInferenceWithFixedParameters(batch);
        List<DataPosteriorAssignment> posteriors = new ArrayList<DataPosteriorAssignment>();
        for (int i = 0; i < batch.getNumberOfDataInstances(); ++i) {
            List<UnivariateDistribution> posteriorsQ = new ArrayList<UnivariateDistribution>();
            HashMapAssignment assignment = new HashMapAssignment();
            for (Variable variable : variables) {
                Object dist = this.plateuStructure.getEFVariablePosterior(variable, i);
                if (dist != null) {
                    posteriorsQ.add(copyPosterior(dist));
                } else {
                    assignment.setValue(variable, batch.getDataInstance(i).getValue(variable));
                }
            }
            DataPosterior dataPosterior = new DataPosterior((int) batch.getDataInstance(i).getValue(seqId), posteriorsQ);
            posteriors.add(new DataPosteriorAssignment(dataPosterior, assignment));
        }
        return posteriors;
    }

    /** Returns the batch's seq_id attribute or fails if the data set has none. */
    private static Attribute requireSeqId(DataOnMemory<DataInstance> batch) {
        Attribute seqId = batch.getAttributes().getSeq_id();
        if (seqId == null) {
            throw new IllegalArgumentException("Functionality only available for data sets with a seq_id attribute");
        }
        return seqId;
    }

    /** Runs inference on the batch with parameter nodes deactivated, always reactivating them. */
    private void runInferenceWithFixedParameters(DataOnMemory<DataInstance> batch) {
        this.plateuStructure.desactiveParametersNodes();
        try {
            this.plateuStructure.setEvidence(batch.getList());
            this.plateuStructure.runInference();
        } finally {
            this.plateuStructure.activeParametersNodes();
        }
    }

    /** Converts a plateau EF posterior into an independent {@link UnivariateDistribution} copy. */
    private static UnivariateDistribution copyPosterior(Object efPosterior) {
        return (UnivariateDistribution) ((EF_UnivariateDistribution) efPosterior).deepCopy().toUnivariateDistribution();
    }

    /** Processes a batch non-sequentially and adds its contribution to the accumulated posterior. */
    private double updateModelParallel(DataOnMemory<DataInstance> batch) {
        BatchOutput out = this.updateModelOnBatchParallel(batch);
        this.naturalVectorPosterior = BatchOutput.sumNonStateless(out, this.getNaturalParameterPosterior());
        return out.getElbo();
    }

    /** @return the number of batches processed since the last {@link #initLearning()} */
    public int getNumberOfBatches() {
        return this.nBatches;
    }

    /**
     * Returns the mean number of VMP iterations per processed batch.
     *
     * @return the average, or {@code NaN} if no batch has been processed yet
     */
    public double getAverageNumOfIterations() {
        return (double) this.nIterTotal / (double) this.nBatches;
    }

    /** @return the DAG of the model being learnt */
    public DAG getDAG() {
        return this.dag;
    }

    @Override
    public void setDAG(DAG dag) {
        this.dag = dag;
    }

    /**
     * Resets counters and (re)builds the plateau: configures VMP output, replication count,
     * seed and DAG, replicates the model, resets the Q distributions, and initialises the
     * transition method if one is set.
     */
    @Override
    public void initLearning() {
        this.plateuStructure.initTransientDataStructure();
        this.getPlateuStructure().getVMP().setOutput(this.activateOutput);
        this.getPlateuStructure().getVMP().setTestELBO(this.activateOutput);
        this.plateuStructure.setNRepetitions(this.windowsSize);
        this.nBatches = 0;
        this.nIterTotal = 0;
        this.plateuStructure.setSeed(this.seed);
        this.plateuStructure.setDAG(this.dag);
        this.plateuStructure.replicateModel();
        this.plateuStructure.resetQs();
        this.ef_extendedBN = this.plateuStructure.getEFLearningBN();
        if (this.transitionMethod != null) {
            this.ef_extendedBN = this.transitionMethod.initModel(this.ef_extendedBN, this.plateuStructure);
        }
    }

    @Override
    public void setDataStream(DataStream<DataInstance> data) {
        this.dataStream = data;
    }

    /**
     * Replaces the plateau prior and refreshes the learning network and cached prior vector.
     *
     * @param parameterVector the new natural-parameter prior
     */
    public void updateNaturalParameterPrior(CompoundVector parameterVector) {
        this.plateuStructure.updateNaturalParameterPrior(parameterVector);
        this.ef_extendedBN = this.plateuStructure.getEFLearningBN();
        this.naturalVectorPrior = this.computeNaturalParameterVectorPrior();
    }

    /**
     * Replaces the plateau parameter posteriors.
     *
     * @param parameterVector the new natural-parameter posterior
     */
    public void updateNaturalParameterPosteriors(CompoundVector parameterVector) {
        this.plateuStructure.updateNaturalParameterPosteriors(parameterVector);
    }

    /**
     * Builds the learnt Bayesian network. In non-sequential mode the posterior is temporarily
     * installed as prior to build the network; the original prior is always restored.
     *
     * @return the learnt network
     */
    @Override
    public BayesianNetwork getLearntBayesianNetwork() {
        if (!this.nonSequentialModel) {
            return new BayesianNetwork(this.dag, this.ef_extendedBN.toConditionalDistribution());
        }
        CompoundVector prior = this.plateuStructure.getPlateauNaturalParameterPrior();
        this.updateNaturalParameterPrior(this.plateuStructure.getPlateauNaturalParameterPosterior());
        try {
            return new BayesianNetwork(this.dag, this.ef_extendedBN.toConditionalDistribution());
        } finally {
            this.updateNaturalParameterPrior(prior);
        }
    }

    /** Always throws: this implementation has no parallel mode. */
    @Override
    public void setParallelMode(boolean parallelMode) {
        throw new UnsupportedOperationException("Non Parallel Mode Supported. Use class ParallelSVB");
    }

    @Override
    public <E extends UnivariateDistribution> E getParameterPosterior(Variable parameter) {
        return ((EF_UnivariateDistribution) this.getPlateuStructure().getEFParameterPosterior(parameter)).toUnivariateDistribution();
    }

    /** A natural-parameter vector paired with the ELBO that produced it. */
    public static class BatchOutput implements Serializable {
        private static final long serialVersionUID = 4107783324901370839L;
        CompoundVector vector;
        double elbo;

        public BatchOutput(CompoundVector vector_, double elbo_) {
            this.vector = vector_;
            this.elbo = elbo_;
        }

        public CompoundVector getVector() {
            return this.vector;
        }

        public double getElbo() {
            return this.elbo;
        }

        public void setElbo(double elbo) {
            this.elbo = elbo;
        }

        /**
         * Adds {@code batchOutput1} into {@code batchOutput2} in place (vector and ELBO).
         *
         * @return {@code batchOutput2}, mutated
         */
        public static BatchOutput sumNonStateless(BatchOutput batchOutput1, BatchOutput batchOutput2) {
            batchOutput2.getVector().sum(batchOutput1.getVector());
            batchOutput2.setElbo(batchOutput2.getElbo() + batchOutput1.getElbo());
            return batchOutput2;
        }

        /**
         * Returns the sum of both outputs as a new object; neither argument is modified.
         *
         * @return a deep copy of {@code batchOutput2} with {@code batchOutput1} added
         */
        public static BatchOutput sumStateless(BatchOutput batchOutput1, BatchOutput batchOutput2) {
            BatchOutput sum = Serialization.deepCopy(batchOutput2);
            sum.getVector().sum(batchOutput1.getVector());
            sum.setElbo(batchOutput2.getElbo() + batchOutput1.getElbo());
            return sum;
        }
    }
}

