/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.tree.TreeParameterModel;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.MachineAccuracy;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;

/**
 * Gradient (and numerical diagonal Hessian) of a {@link TreeDataLikelihood} log-likelihood with
 * respect to a hyper-parameter that determines branch-specific values.
 *
 * <p>The gradient is assembled by the chain rule. For every branch index {@code i} of
 * {@link #branchParameter}, the entry {@code i} of the supplied branch gradient is multiplied
 * by the vector returned from {@link #getDifferential(Tree, NodeRef)} for that branch's node.
 * The products are summed into the result. The diagonal Hessian is obtained numerically.
 */
public abstract class HyperParameterGradient
implements GradientWrtParameterProvider,
HessianWrtParameterProvider,
Reportable,
Loggable {
    private final TreeDataLikelihood treeDataLikelihood;
    private final GradientWrtParameterProvider gradientWrtParameterProvider;
    private final Parameter parameter;
    private final Tree tree;
    private final boolean useHessian;
    protected final TreeParameterModel branchParameter;

    /**
     * Log-likelihood viewed as a function of the hyper-parameter values, used for numerical
     * derivatives. Evaluating it writes its arguments into {@link #parameter}. Callers must
     * therefore restore the original values afterwards.
     */
    protected MultivariateFunction numeric1 = new MultivariateFunction() {

        @Override
        public double evaluate(double[] argument) {
            HyperParameterGradient.this.writeParameterValues(argument);
            return HyperParameterGradient.this.treeDataLikelihood.getLogLikelihood();
        }

        @Override
        public int getNumArguments() {
            return HyperParameterGradient.this.parameter.getDimension();
        }

        @Override
        public double getLowerBound(int n) {
            return 0.0;
        }

        @Override
        public double getUpperBound(int n) {
            return Double.POSITIVE_INFINITY;
        }
    };

    protected static final boolean COUNT_TOTAL_OPERATIONS = true;

    /** Number of calls to {@link #getGradientLogDensity()} (when counting is enabled). */
    protected long getGradientLogDensityCount = 0L;

    /**
     * @param treeDataLikelihood           likelihood whose log density is differentiated
     * @param gradientWrtParameterProvider supplies one gradient entry per non-root node
     *                                     (length {@code tree.getNodeCount() - 1})
     * @param parameter                    the hyper-parameter being differentiated
     * @param treeParameterModel           maps parameter indices to tree node numbers
     * @param useHessian                   whether {@link #getReport()} also reports Hessians
     */
    public HyperParameterGradient(TreeDataLikelihood treeDataLikelihood,
                                  GradientWrtParameterProvider gradientWrtParameterProvider,
                                  Parameter parameter,
                                  TreeParameterModel treeParameterModel,
                                  boolean useHessian) {
        this.treeDataLikelihood = treeDataLikelihood;
        this.gradientWrtParameterProvider = gradientWrtParameterProvider;
        this.parameter = parameter;
        this.useHessian = useHessian;
        this.tree = treeDataLikelihood.getTree();
        this.branchParameter = treeParameterModel;
    }

    @Override
    public Likelihood getLikelihood() {
        return this.treeDataLikelihood;
    }

    @Override
    public Parameter getParameter() {
        return this.parameter;
    }

    @Override
    public int getDimension() {
        return this.parameter.getDimension();
    }

    /**
     * Computes the gradient with respect to {@link #parameter} by the chain rule. Entry
     * {@code i} of the branch gradient is multiplied by {@code getDifferential(tree, node_i)},
     * and the products are accumulated.
     *
     * @return a new array of length {@link #getDimension()}
     * @throws RuntimeException if the branch gradient length is not {@code nodeCount - 1}
     */
    @Override
    public double[] getGradientLogDensity() {
        if (COUNT_TOTAL_OPERATIONS) {
            ++getGradientLogDensityCount;
        }

        final double[] branchGradient = gradientWrtParameterProvider.getGradientLogDensity();
        final int expectedLength = tree.getNodeCount() - 1;
        if (branchGradient.length != expectedLength) {
            throw new RuntimeException("Dimension mismatch! Expected " + expectedLength
                    + " branch gradient entries but got " + branchGradient.length);
        }

        final double[] gradient = new double[getDimension()];
        final int branchCount = branchParameter.getParameterSize();
        for (int i = 0; i < branchCount; ++i) {
            final NodeRef node = tree.getNode(branchParameter.getNodeNumberFromParameterIndex(i));
            final double[] differential = getDifferential(tree, node);
            for (int j = 0; j < gradient.length; ++j) {
                gradient[j] += branchGradient[i] * differential[j];
            }
        }
        return gradient;
    }

    /**
     * Numerically approximates the diagonal of the Hessian of the log-likelihood with respect
     * to {@link #parameter}.
     *
     * <p>The numerical routine evaluates {@link #numeric1} at perturbed points, and each
     * evaluation writes into {@link #parameter}. The original values are restored afterwards,
     * so the model is left where it started.
     */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        final double[] savedValues = parameter.getParameterValues();
        try {
            return NumericalDerivative.diagonalHessian(numeric1, savedValues.clone());
        } finally {
            writeParameterValues(savedValues);
        }
    }

    /**
     * Returns the per-branch factor for the branch above {@code node}. Its entry {@code j} is
     * multiplied by that branch's gradient and accumulated into gradient entry {@code j}.
     * Presumably this is the derivative of the branch value with respect to each hyper-parameter
     * dimension. Its length must be at least {@link #getDimension()}.
     */
    abstract double[] getDifferential(Tree var1, NodeRef var2);

    /**
     * Returns {@code false} if any value has magnitude below {@code 1.2 * SQRT_EPSILON}.
     * Near-zero values make the central-difference checks in {@link #getReport()} unreliable.
     */
    protected boolean valuesAreSufficientlyLarge(double[] dArray) {
        for (double value : dArray) {
            if (Math.abs(value) < MachineAccuracy.SQRT_EPSILON * 1.2) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds a diagnostic report comparing the analytic gradient with numerical derivatives.
     * The parameter values are restored after the numerical evaluations.
     */
    @Override
    public String getReport() {
        final double[] savedValues = parameter.getParameterValues();
        final boolean largeEnough = valuesAreSufficientlyLarge(savedValues);

        double[] numericGradient = null;
        double[] numericHessian = null;
        try {
            if (largeEnough) {
                // Each routine gets its own copy of the saved point. numeric1 leaves the
                // parameter at its last evaluated (perturbed) point, so re-reading the parameter
                // between calls would move the evaluation location.
                numericGradient = NumericalDerivative.gradient(numeric1, savedValues.clone());
                if (useHessian) {
                    numericHessian = NumericalDerivative.diagonalHessian(numeric1, savedValues.clone());
                }
            }
        } finally {
            writeParameterValues(savedValues);
        }

        final StringBuilder report = new StringBuilder();
        report.append("Gradient Peeling: ").append(new Vector(getGradientLogDensity()));
        report.append("\n");
        if (numericGradient != null) {
            report.append("Gradient numeric: ").append(new Vector(numericGradient));
        } else {
            report.append("Gradient numeric: too close to 0");
        }
        report.append("\n");

        if (useHessian) {
            if (largeEnough) {
                report.append("Hessian Peeling: ").append(new Vector(getDiagonalHessianLogDensity()));
                report.append("\n");
            }
            if (numericHessian != null) {
                report.append("Hessian numeric: ").append(new Vector(numericHessian));
            } else {
                report.append("Hessian numeric: too close to 0");
            }
            report.append("\n");
        }

        report.append("\n\tgetGradientLogDensityCount = ").append(getGradientLogDensityCount).append("\n");
        report.append(treeDataLikelihood.getReport());
        return report.toString();
    }

    /** Exposes {@link #getReport()} as a single log column, prefixed with a newline. */
    @Override
    public LogColumn[] getColumns() {
        return new LogColumn[]{new LogColumn.Default("gradient report", new Object() {
            @Override
            public String toString() {
                return "\n" + HyperParameterGradient.this.getReport();
            }
        })};
    }

    /** Writes {@code values} into {@link #parameter}, one dimension at a time. */
    private void writeParameterValues(double[] values) {
        for (int i = 0; i < values.length; ++i) {
            parameter.setParameterValue(i, values[i]);
        }
    }
}

