/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.rexp.DecorationUtil;
import org.jpmml.rexp.RBooleanVector;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RNumberVector;
import org.jpmml.rexp.RStringVector;
import org.jpmml.rexp.TreeModelConverter;

/**
 * Converts an R {@code ranger} model object into a PMML {@link MiningModel} whose segments are
 * binary-split {@link TreeModel}s, one per tree stored in the model's {@code forest} element.
 *
 * <p>Supported {@code treetype} values are {@code "Regression"}, {@code "Classification"} and
 * {@code "Probability estimation"}.
 */
public class RangerConverter
extends TreeModelConverter<RGenericVector> {
    /**
     * {@code true} when the forest's {@code is.ordered} vector has exactly one more entry than
     * {@code independent.variable.names}. In that layout, {@code is.ordered} entries and
     * {@code split.varIDs} values are shifted by one relative to the independent variables.
     * NOTE(review): presumably the extra leading entry is the dependent variable — confirm
     * against the ranger versions that produce this layout.
     *
     * <p>Assigned by {@link #encodeSchema(RExpEncoder)} and read when encoding split nodes.
     */
    boolean hasDependentVar = false;

    public RangerConverter(RGenericVector ranger) {
        super(ranger);
    }

    /**
     * Declares the {@code _target} label field and one input field per independent variable.
     *
     * <p>Variables listed in the {@code variable.levels} decoration become categorical string
     * fields with those levels; all others become continuous double fields.
     *
     * @param encoder the encoder that receives the label and the features
     * @throws IllegalArgumentException if the {@code forest} element is missing, if the
     *     {@code treetype} is not supported, or if an independent variable is not ordered
     */
    @Override
    public void encodeSchema(RExpEncoder encoder) {
        RGenericVector ranger = this.getObject();

        RGenericVector forest = ranger.getGenericElement("forest", false);
        if (forest == null) {
            throw new IllegalArgumentException("Missing 'forest' element. Please re-train the model object with 'write.forest' argument set to TRUE");
        }

        String treeType = ranger.getStringElement("treetype").asScalar();
        RGenericVector variableLevels = DecorationUtil.getGenericElement(ranger, "variable.levels");

        FieldName targetName = FieldName.create("_target");
        DataField targetField;
        switch (treeType) {
            case "Regression": {
                targetField = encoder.createDataField(targetName, OpType.CONTINUOUS, DataType.DOUBLE);
                break;
            }
            case "Classification":
            case "Probability estimation": {
                RStringVector levels = forest.getStringElement("levels");
                targetField = encoder.createDataField(targetName, OpType.CATEGORICAL, null, levels.getValues());
                break;
            }
            default: {
                throw new IllegalArgumentException("Unsupported tree type '" + treeType + "'");
            }
        }
        encoder.setLabel(targetField);

        RBooleanVector isOrdered = forest.getBooleanElement("is.ordered");
        RStringVector independentVariableNames = forest.getStringElement("independent.variable.names");
        this.hasDependentVar = (isOrdered.size() == independentVariableNames.size() + 1);
        int orderedOffset = this.hasDependentVar ? 1 : 0;

        for (int i = 0; i < independentVariableNames.size(); i++) {
            String variableName = independentVariableNames.getValue(i);

            // encodeNode only translates threshold splits (over level indices for categorical
            // features), so variables that are not flagged as ordered cannot be converted
            if (!isOrdered.getValue(i + orderedOffset).booleanValue()) {
                throw new IllegalArgumentException("Independent variable '" + variableName + "' is not ordered; unordered variables are not supported");
            }

            FieldName fieldName = FieldName.create(variableName);
            DataField dataField;
            if (variableLevels.hasElement(variableName)) {
                RStringVector levels = variableLevels.getStringElement(variableName);
                dataField = encoder.createDataField(fieldName, OpType.CATEGORICAL, DataType.STRING, levels.getValues());
            } else {
                dataField = encoder.createDataField(fieldName, OpType.CONTINUOUS, DataType.DOUBLE);
            }
            encoder.addFeature(dataField);
        }
    }

    /**
     * Encodes the forest as a {@link MiningModel}, dispatching on the {@code treetype} element.
     *
     * @param schema the schema produced from {@link #encodeSchema(RExpEncoder)}
     * @return the ensemble model
     * @throws IllegalArgumentException if the tree type is not supported
     */
    public MiningModel encodeModel(Schema schema) {
        RGenericVector ranger = this.getObject();

        String treeType = ranger.getStringElement("treetype").asScalar();
        switch (treeType) {
            case "Regression":
                return encodeRegression(ranger, schema);
            case "Classification":
                return encodeClassification(ranger, schema);
            case "Probability estimation":
                return encodeProbabilityForest(ranger, schema);
            default:
                throw new IllegalArgumentException("Unsupported tree type '" + treeType + "'");
        }
    }

    /**
     * Regression forest: each leaf's split value is its numeric prediction, and tree outputs
     * are averaged.
     */
    private MiningModel encodeRegression(RGenericVector ranger, Schema schema) {
        RGenericVector forest = ranger.getGenericElement("forest");

        ScoreEncoder scoreEncoder = (node, splitValue, terminalClassCount) -> {
            node.setScore(splitValue);
            return node;
        };

        List<TreeModel> treeModels = encodeForest(forest, MiningFunction.REGRESSION, scoreEncoder, schema);

        return new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()))
            .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, treeModels));
    }

    /**
     * Classification forest: each leaf's split value is a 1-based index into the forest's
     * {@code levels}, and tree outputs are combined by majority vote.
     *
     * @throws IllegalArgumentException if a leaf carries class counts or an out-of-range index
     */
    private MiningModel encodeClassification(RGenericVector ranger, Schema schema) {
        RGenericVector forest = ranger.getGenericElement("forest");
        RStringVector levels = forest.getStringElement("levels");

        ScoreEncoder scoreEncoder = (node, splitValue, terminalClassCount) -> {
            int index = ValueUtil.asInt(splitValue);

            if (terminalClassCount != null) {
                throw new IllegalArgumentException("Unexpected terminal class counts in a classification forest");
            }
            if (index < 1 || index > levels.size()) {
                throw new IllegalArgumentException("Class index " + index + " is outside the range [1, " + levels.size() + "]");
            }

            node.setScore(levels.getValue(index - 1));
            return node;
        };

        List<TreeModel> treeModels = encodeForest(forest, MiningFunction.CLASSIFICATION, scoreEncoder, schema);

        return new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()))
            .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, treeModels));
    }

    /**
     * Probability forest: each leaf carries one value per class level in
     * {@code terminal.class.counts}. These become score distributions, and the level with the
     * largest value becomes the leaf score. Tree outputs are averaged, and a probability
     * output is attached.
     *
     * @throws IllegalArgumentException if a leaf has a non-zero split value, or lacks class
     *     counts matching the number of levels
     */
    private MiningModel encodeProbabilityForest(RGenericVector ranger, Schema schema) {
        RGenericVector forest = ranger.getGenericElement("forest");
        RStringVector levels = forest.getStringElement("levels");

        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

        ScoreEncoder scoreEncoder = (node, splitValue, terminalClassCount) -> {
            if (splitValue.doubleValue() != 0.0 || terminalClassCount == null || terminalClassCount.size() != levels.size()) {
                throw new IllegalArgumentException("Expected a terminal node with a zero split value and " + levels.size() + " class counts");
            }

            Node classifierNode = new ClassifierNode(node);
            List<ScoreDistribution> scoreDistributions = classifierNode.getScoreDistributions();

            Number maxProbability = null;
            for (int i = 0; i < terminalClassCount.size(); i++) {
                String value = levels.getValue(i);
                Number probability = terminalClassCount.getValue(i);

                // Compare numerically rather than via Comparable, so that mixed Number subtypes
                // cannot throw; the strict '<' keeps the first level on ties
                if (maxProbability == null || Double.compare(maxProbability.doubleValue(), probability.doubleValue()) < 0) {
                    classifierNode.setScore(value);
                    maxProbability = probability;
                }

                scoreDistributions.add(new ScoreDistribution(value, probability));
            }

            return classifierNode;
        };

        List<TreeModel> treeModels = encodeForest(forest, MiningFunction.CLASSIFICATION, scoreEncoder, schema);

        return new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel))
            .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, treeModels))
            .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
    }

    /**
     * Encodes every tree of the forest, in order, against an anonymous copy of the schema.
     *
     * <p>The per-tree vectors are read from the forest's {@code child.nodeIDs},
     * {@code split.varIDs} and {@code split.values} lists. {@code terminal.class.counts} is
     * optional; when it is absent, leaves receive {@code null} class counts.
     */
    private List<TreeModel> encodeForest(RGenericVector forest, MiningFunction miningFunction, ScoreEncoder scoreEncoder, Schema schema) {
        RNumberVector<?> numTrees = forest.getNumericElement("num.trees");
        RGenericVector childNodeIDs = forest.getGenericElement("child.nodeIDs");
        RGenericVector splitVarIDs = forest.getGenericElement("split.varIDs");
        RGenericVector splitValues = forest.getGenericElement("split.values");
        RGenericVector terminalClassCounts = forest.getGenericElement("terminal.class.counts", false);

        Schema segmentSchema = schema.toAnonymousSchema();

        int treeCount = ValueUtil.asInt(numTrees.asScalar());

        List<TreeModel> treeModels = new ArrayList<>();
        for (int i = 0; i < treeCount; i++) {
            TreeModel treeModel = encodeTreeModel(miningFunction, scoreEncoder,
                (RGenericVector)childNodeIDs.getValue(i),
                (RNumberVector<?>)splitVarIDs.getValue(i),
                (RNumberVector<?>)splitValues.getValue(i),
                (terminalClassCounts != null ? (RGenericVector)terminalClassCounts.getValue(i) : null),
                segmentSchema);

            treeModels.add(treeModel);
        }

        return treeModels;
    }

    /**
     * Encodes a single tree, starting from node index 0 with an always-true predicate.
     *
     * <p>{@code childNodeIDs} holds two vectors indexed by node: the left child IDs at
     * position 0 and the right child IDs at position 1.
     */
    private TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder scoreEncoder, RGenericVector childNodeIDs, RNumberVector<?> splitVarIDs, RNumberVector<?> splitValues, RGenericVector terminalClassCounts, Schema schema) {
        RNumberVector<?> leftChildIDs = (RNumberVector<?>)childNodeIDs.getValue(0);
        RNumberVector<?> rightChildIDs = (RNumberVector<?>)childNodeIDs.getValue(1);

        Node root = encodeNode(True.INSTANCE, 0, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, new CategoryManager(), schema);

        return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
    }

    /**
     * Recursively encodes the node at {@code index} and its subtree.
     *
     * <p>A node whose left and right child IDs are both 0 is a leaf; its score is assigned by
     * {@code scoreEncoder}. Otherwise it is split as follows:
     * <ul>
     *   <li>on a categorical feature: the first {@code floor(splitValue)} levels go left and
     *       the remaining levels go right;</li>
     *   <li>on a continuous feature: {@code <= splitValue} goes left and
     *       {@code > splitValue} goes right.</li>
     * </ul>
     */
    private Node encodeNode(org.dmg.pmml.Predicate predicate, int index, ScoreEncoder scoreEncoder, RNumberVector<?> leftChildIDs, RNumberVector<?> rightChildIDs, RNumberVector<?> splitVarIDs, RNumberVector<?> splitValues, RGenericVector terminalClassCounts, CategoryManager categoryManager, Schema schema) {
        int leftIndex = ValueUtil.asInt(leftChildIDs.getValue(index));
        int rightIndex = ValueUtil.asInt(rightChildIDs.getValue(index));

        Number splitValue = splitValues.getValue(index);

        if (leftIndex == 0 && rightIndex == 0) {
            // Class counts are only consumed at leaves, so only look them up here
            RNumberVector<?> terminalClassCount = (terminalClassCounts != null ? (RNumberVector<?>)terminalClassCounts.getValue(index) : null);

            return scoreEncoder.encode(new LeafNode(null, predicate), splitValue, terminalClassCount);
        }

        int splitVarIndex = ValueUtil.asInt(splitVarIDs.getValue(index));
        Feature feature = schema.getFeature(this.hasDependentVar ? splitVarIndex - 1 : splitVarIndex);

        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;

        org.dmg.pmml.Predicate leftPredicate;
        org.dmg.pmml.Predicate rightPredicate;

        if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature)feature;

            FieldName name = categoricalFeature.getName();
            List<?> values = categoricalFeature.getValues();

            int splitLevelIndex = ValueUtil.asInt(Math.floor(splitValue.doubleValue()));

            // NOTE(review): presumably restricts the levels to those still reachable on this
            // path, given the CategoryManager forks made by ancestor splits — see CategoryManager
            Predicate<Object> valueFilter = categoryManager.getValueFilter(name);

            List<Object> leftValues = filterValues(values.subList(0, splitLevelIndex), valueFilter);
            List<Object> rightValues = filterValues(values.subList(splitLevelIndex, values.size()), valueFilter);

            leftCategoryManager = categoryManager.fork(name, leftValues);
            rightCategoryManager = categoryManager.fork(name, rightValues);

            leftPredicate = createSimpleSetPredicate(categoricalFeature, leftValues);
            rightPredicate = createSimpleSetPredicate(categoricalFeature, rightValues);
        } else {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();

            leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, splitValue);
            rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, splitValue);
        }

        Node leftChild = encodeNode(leftPredicate, leftIndex, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, leftCategoryManager, schema);
        Node rightChild = encodeNode(rightPredicate, rightIndex, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, rightCategoryManager, schema);

        return new BranchNode(null, predicate).addNodes(leftChild, rightChild);
    }

    /**
     * Returns a new list with the elements of {@code values} that satisfy {@code valueFilter},
     * in their original order.
     */
    private static List<Object> filterValues(List<?> values, Predicate<Object> valueFilter) {
        return values.stream().filter(valueFilter).collect(Collectors.toList());
    }

    /**
     * Assigns the prediction of a terminal (leaf) node.
     */
    @FunctionalInterface
    private interface ScoreEncoder {
        /**
         * @param node the leaf node to score
         * @param splitValue the leaf's entry in the tree's {@code split.values} vector
         * @param terminalClassCount the leaf's entry in {@code terminal.class.counts}, or
         *     {@code null} when the forest has no such element
         * @return the node to place in the tree; it may replace {@code node}
         */
        Node encode(Node node, Number splitValue, RNumberVector<?> terminalClassCount);
    }
}

