/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.LinearRegression;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.UnsupportedAttributeTypeException;
import weka.core.UnsupportedClassTypeException;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.instance.RemoveRange;

/**
 * Least median of squares linear regression.
 *
 * <p>Training fits an ordinary {@link LinearRegression} to many small random
 * subsamples of the training data, keeps the fit whose median squared residual
 * over the whole training set is smallest, uses that fit to give every training
 * instance a weight of 1 (kept) or 0 (rejected as an outlier), and finally fits
 * one more {@link LinearRegression} on the kept instances. That last regression
 * is the model used by {@link #classifyInstance(Instance)}.
 *
 * <p>Options: {@code -S <sample size>} (default 4), {@code -G <seed>}
 * (default 0) and {@code -D} (debug output).
 */
public class LeastMedSq
extends Classifier
implements OptionHandler {
    /** Squared residuals of the current regression, one per training instance. */
    private double[] m_Residuals;
    /** 0/1 weights; 0 marks an instance that is left out of the final regression. */
    private double[] m_weight;
    /** Sum of the squared residuals computed by the last call to findResiduals(). */
    private double m_SSR;
    /** Scale estimate used to standardise residuals when computing the weights. */
    private double m_scalefactor;
    /** Smallest median squared residual seen so far. */
    private double m_bestMedian = Double.POSITIVE_INFINITY;
    /** Regression currently being evaluated. */
    private LinearRegression m_currentRegression;
    /** Regression that produced {@link #m_bestMedian}. */
    private LinearRegression m_bestRegression;
    /** Final regression, fitted on the instances with weight 1. */
    private LinearRegression m_ls;
    /** Filtered training data. */
    private Instances m_Data;
    /** Training data with the zero-weight instances removed. */
    private Instances m_RLSData;
    /** Random subsample the current regression is fitted to. */
    private Instances m_SubSample;
    private ReplaceMissingValues m_MissingFilter;
    private NominalToBinary m_TransformFilter;
    private RemoveRange m_SplitFilter;
    /** Number of instance numbers drawn per subsample. */
    private int m_samplesize = 4;
    /** Number of subsamples to evaluate. */
    private int m_samples;
    /** Not read anywhere in this class; retained as-is. */
    private boolean m_israndom = false;
    private boolean m_debug = false;
    private Random m_random;
    private long m_randomseed = 0L;

    /**
     * Returns a description of this classifier.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Implements a least median sqaured linear regression utilising the existing weka LinearRegression class to form predictions. Least squared regression functions are generated from random subsamples of the data. The least squared regression with the lowest meadian squared error is chosen as the final model.\n\nThe basis of the algorithm is \n\nRobust regression and outlier detection Peter J. Rousseeuw, Annick M. Leroy. c1987";
    }

    /**
     * Builds the model from the given training data. The data is copied, so the
     * caller's {@code Instances} object is not modified.
     *
     * @param instances the training data
     * @throws Exception if the class is not numeric, there are no instances with
     *                   a class value, string attributes are present, the sample
     *                   size is invalid, or any underlying step fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        Instances trainingData = new Instances(instances);
        trainingData.deleteWithMissingClass();
        if (!trainingData.classAttribute().isNumeric()) {
            throw new UnsupportedClassTypeException("Class attribute has to be numeric for regression!");
        }
        if (trainingData.numInstances() == 0) {
            throw new Exception("No instances in training file!");
        }
        if (trainingData.checkForStringAttributes()) {
            throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
        }
        this.cleanUpData(trainingData);
        this.getSamples();
        this.findBestRegression();
        this.buildRLSRegression();
    }

    /**
     * Predicts the class value of an instance. The instance is passed through
     * the same two filters, in the same order, as the training data.
     *
     * @param instance the instance to predict
     * @return the prediction of the final regression
     * @throws Exception if filtering or prediction fails
     */
    public double classifyInstance(Instance instance) throws Exception {
        this.m_TransformFilter.input(instance);
        Instance filtered = this.m_TransformFilter.output();
        this.m_MissingFilter.input(filtered);
        filtered = this.m_MissingFilter.output();
        return this.m_ls.classifyInstance(filtered);
    }

    /**
     * Stores the training data after running it through a NominalToBinary
     * filter and then a ReplaceMissingValues filter. Both filters are kept so
     * that they can be applied again at prediction time.
     *
     * @param instances the training data
     * @throws Exception if a filter cannot be set up or applied
     */
    private void cleanUpData(Instances instances) throws Exception {
        this.m_Data = instances;
        this.m_TransformFilter = new NominalToBinary();
        this.m_TransformFilter.setInputFormat(this.m_Data);
        this.m_Data = Filter.useFilter(this.m_Data, this.m_TransformFilter);
        this.m_MissingFilter = new ReplaceMissingValues();
        this.m_MissingFilter.setInputFormat(this.m_Data);
        this.m_Data = Filter.useFilter(this.m_Data, this.m_MissingFilter);
        this.m_Data.deleteWithMissingClass();
    }

    /**
     * Decides how many subsamples to evaluate. For sample sizes below 7 a
     * per-size instance-count limit is consulted: below the limit the number of
     * possible combinations is used, otherwise 500 times the sample size. For
     * sample sizes of 7 or more, 3000 subsamples are used.
     *
     * @throws Exception if the sample size is smaller than 1 or larger than the
     *                   number of instances in a small data set
     */
    private void getSamples() throws Exception {
        if (this.m_samplesize < 1) {
            // Without this check the table lookup below fails with an
            // uninformative ArrayIndexOutOfBoundsException.
            throw new IllegalArgumentException("Sample size must be at least 1 but is " + this.m_samplesize);
        }
        // Instance-count limits for sample sizes 1 to 6.
        int[] instanceLimits = new int[]{500, 50, 22, 17, 15, 14};
        int numInstances = this.m_Data.numInstances();
        if (this.m_samplesize >= 7) {
            this.m_samples = 3000;
        } else if (numInstances < instanceLimits[this.m_samplesize - 1]) {
            this.m_samples = LeastMedSq.combinations(numInstances, this.m_samplesize);
        } else {
            this.m_samples = this.m_samplesize * 500;
        }
        if (this.m_debug) {
            System.out.println("m_samplesize: " + this.m_samplesize);
            System.out.println("m_samples: " + this.m_samples);
            System.out.println("m_randomseed: " + this.m_randomseed);
        }
    }

    /** Creates the random number generator from the configured seed. */
    private void setRandom() {
        this.m_random = new Random(this.getRandomSeed());
    }

    /**
     * Fits one regression per subsample and keeps the one with the smallest
     * median squared residual. On return the current regression is the best one.
     *
     * @throws Exception if a regression cannot be built or no candidate yields a
     *                   comparable median
     */
    private void findBestRegression() throws Exception {
        this.setRandom();
        this.m_bestMedian = Double.POSITIVE_INFINITY;
        // Forget the winner of any earlier build so it can never leak into this one.
        this.m_bestRegression = null;
        if (this.m_debug) {
            System.out.println("Starting:");
        }
        // With fewer than 100 samples the integer division yields 0, which
        // would make the modulo below throw; clamp the step to at least 1.
        int progressStep = Math.max(1, this.m_samples / 100);
        for (int sample = 0; sample < this.m_samples; ++sample) {
            if (this.m_debug && sample % progressStep == 0) {
                System.out.print("*");
            }
            this.genRegression();
            this.getMedian();
        }
        if (this.m_debug) {
            System.out.println("");
        }
        if (this.m_bestRegression == null) {
            // Only reachable when every median failed the "< best" comparison (e.g. NaN).
            throw new Exception("Could not find a regression with a valid median squared residual.");
        }
        this.m_currentRegression = this.m_bestRegression;
    }

    /**
     * Draws a fresh subsample and fits the current regression to it.
     *
     * @throws Exception if sampling or fitting fails
     */
    private void genRegression() throws Exception {
        this.m_currentRegression = new LinearRegression();
        // NOTE(review): "-S 1" presumably selects LinearRegression's
        // "no attribute selection" mode - confirm against its option list.
        this.m_currentRegression.setOptions(new String[]{"-S", "1"});
        this.selectSubSample(this.m_Data);
        this.m_currentRegression.buildClassifier(this.m_SubSample);
    }

    /**
     * Computes the squared residual of the current regression for every
     * training instance and accumulates their sum.
     *
     * @throws Exception if a prediction fails
     */
    private void findResiduals() throws Exception {
        int numInstances = this.m_Data.numInstances();
        this.m_SSR = 0.0;
        this.m_Residuals = new double[numInstances];
        for (int i = 0; i < numInstances; ++i) {
            Instance current = this.m_Data.instance(i);
            double error = this.m_currentRegression.classifyInstance(current)
                    - current.value(this.m_Data.classAttribute());
            this.m_Residuals[i] = error * error;
            this.m_SSR += this.m_Residuals[i];
        }
    }

    /**
     * Finds the median squared residual of the current regression and records
     * the regression as the best one if that median is a new minimum.
     *
     * @throws Exception if the residuals cannot be computed
     */
    private void getMedian() throws Exception {
        this.findResiduals();
        int middle = this.m_Residuals.length / 2;
        LeastMedSq.select(this.m_Residuals, 0, this.m_Residuals.length - 1, middle);
        double median = this.m_Residuals[middle];
        if (median < this.m_bestMedian) {
            this.m_bestMedian = median;
            this.m_bestRegression = this.m_currentRegression;
        }
    }

    /**
     * Returns the textual form of the final regression.
     *
     * @return the model description, or a notice if no model has been built
     */
    public String toString() {
        if (this.m_ls == null) {
            return "model has not been built";
        }
        return this.m_ls.toString();
    }

    /**
     * Computes the scale estimate and the 0/1 weight of every training
     * instance: an instance gets weight 1 when its absolute residual divided by
     * the scale estimate is below 2.5, and weight 0 otherwise.
     *
     * @throws Exception if the residuals cannot be computed
     */
    private void buildWeight() throws Exception {
        this.findResiduals();
        int degreesOfFreedom = this.m_Data.numInstances() - this.m_Data.numAttributes();
        // The correction term must be evaluated in floating point: with integer
        // division it is truncated (to 0 once the difference exceeds 5) and a
        // difference of 0 raises an ArithmeticException. In floating point a
        // difference of 0 simply yields an infinite (or NaN) scale factor.
        double correction = 1.0 + 5.0 / (double)degreesOfFreedom;
        this.m_scalefactor = 1.4826 * correction * Math.sqrt(this.m_bestMedian);
        this.m_weight = new double[this.m_Residuals.length];
        for (int i = 0; i < this.m_Residuals.length; ++i) {
            // m_Residuals holds squared residuals, hence the square root.
            double standardised = Math.sqrt(this.m_Residuals[i]) / this.m_scalefactor;
            this.m_weight[i] = standardised < 2.5 ? 1.0 : 0.0;
        }
    }

    /**
     * Fits the final regression on the training instances with weight 1. If no
     * instance is left, the best subsample regression is used instead.
     *
     * @throws Exception if the weights or the regression cannot be built
     */
    private void buildRLSRegression() throws Exception {
        this.buildWeight();
        this.m_RLSData = new Instances(this.m_Data);
        // Delete from the back so the positions of the instances still to be
        // inspected are not shifted by earlier deletions.
        for (int i = this.m_weight.length - 1; i >= 0; --i) {
            if (this.m_weight[i] == 0.0) {
                this.m_RLSData.delete(i);
            }
        }
        if (this.m_RLSData.numInstances() == 0) {
            System.err.println("rls regression unbuilt");
            this.m_ls = this.m_currentRegression;
        } else {
            this.m_ls = new LinearRegression();
            this.m_ls.setOptions(new String[]{"-S", "1"});
            this.m_ls.buildClassifier(this.m_RLSData);
            this.m_currentRegression = this.m_ls;
        }
    }

    /**
     * Rearranges {@code values[lo..hi]} in place so that the element that would
     * be at index {@code k} in sorted order ends up at index {@code k}.
     * Implemented as a loop so the stack depth does not grow with the input.
     *
     * @param values the array to rearrange
     * @param lo     first index of the range (inclusive)
     * @param hi     last index of the range (inclusive)
     * @param k      index of the wanted element
     */
    private static void select(double[] values, int lo, int hi, int k) {
        while (hi > lo) {
            int pivotIndex = LeastMedSq.partition(values, lo, hi);
            if (pivotIndex > k) {
                hi = pivotIndex - 1;
            } else if (pivotIndex < k) {
                lo = pivotIndex + 1;
            } else {
                return;
            }
        }
    }

    /**
     * Partitions {@code values[lo..hi]} around the pivot {@code values[hi]}.
     * Afterwards nothing left of the returned index is greater than the pivot
     * and nothing right of it is smaller.
     *
     * @param values the array to partition
     * @param lo     first index of the range (inclusive)
     * @param hi     last index of the range (inclusive); holds the pivot
     * @return the final index of the pivot
     */
    private static int partition(double[] values, int lo, int hi) {
        int left = lo - 1;
        int right = hi;
        double pivot = values[hi];
        double swap;
        while (true) {
            // The pivot itself at values[hi] stops this scan at the latest.
            while (values[++left] < pivot) {
            }
            while (pivot < values[--right]) {
                if (right == lo) {
                    break;
                }
            }
            if (left >= right) {
                break;
            }
            swap = values[left];
            values[left] = values[right];
            values[right] = swap;
        }
        swap = values[left];
        values[left] = values[hi];
        values[hi] = swap;
        return left;
    }

    /**
     * Stores in {@code m_SubSample} the instances of the given data whose
     * numbers were drawn by {@link #selectIndices(Instances)}.
     *
     * @param instances the data to sample from
     * @throws Exception if the filter cannot be set up or applied
     */
    private void selectSubSample(Instances instances) throws Exception {
        this.m_SplitFilter = new RemoveRange();
        // Inverted selection: keep the listed instances instead of removing them.
        this.m_SplitFilter.setInvertSelection(true);
        this.m_SubSample = instances;
        this.m_SplitFilter.setInputFormat(this.m_SubSample);
        this.m_SplitFilter.setInstancesIndices(this.selectIndices(this.m_SubSample));
        this.m_SubSample = Filter.useFilter(this.m_SubSample, this.m_SplitFilter);
    }

    /**
     * Draws {@code m_samplesize} random instance numbers and returns them as a
     * comma-separated list terminated by a newline. Numbers are drawn
     * independently, so the same number may occur more than once.
     *
     * @param instances the data the numbers refer to; must not be empty
     * @return the list of instance numbers
     */
    private String selectIndices(Instances instances) {
        StringBuffer indices = new StringBuffer();
        int numInstances = instances.numInstances();
        for (int i = 0; i < this.m_samplesize; ++i) {
            // Range strings count instances from 1 (see weka.core.Range), so draw
            // from 1..numInstances. Redrawing until a 0..numInstances-1 draw is
            // non-zero can never pick the last instance and never terminates
            // when there is only one instance.
            int instanceNumber = this.m_random.nextInt(numInstances) + 1;
            indices.append(Integer.toString(instanceNumber));
            indices.append(i < this.m_samplesize - 1 ? "," : "\n");
        }
        return indices.toString();
    }

    /**
     * Returns the tip text for the sample size property.
     *
     * @return the tip text
     */
    public String sampleSizeTipText() {
        return "Set the size of the random samples used to generate the least sqaured regression functions.";
    }

    /**
     * Sets the size of each random subsample.
     *
     * @param n the sample size; values below 1 are rejected when the classifier is built
     */
    public void setSampleSize(int n) {
        this.m_samplesize = n;
    }

    /**
     * Returns the size of each random subsample.
     *
     * @return the sample size
     */
    public int getSampleSize() {
        return this.m_samplesize;
    }

    /**
     * Returns the tip text for the random seed property.
     *
     * @return the tip text
     */
    public String randomSeedTipText() {
        return "Set the seed for selecting random subsamples of the training data.";
    }

    /**
     * Sets the seed used when drawing subsamples.
     *
     * @param l the seed
     */
    public void setRandomSeed(long l) {
        this.m_randomseed = l;
    }

    /**
     * Returns the seed used when drawing subsamples.
     *
     * @return the seed
     */
    public long getRandomSeed() {
        return this.m_randomseed;
    }

    /**
     * Switches debugging output on or off.
     *
     * @param bl true to print progress information to standard output
     */
    public void setDebug(boolean bl) {
        this.m_debug = bl;
    }

    /**
     * Returns whether debugging output is enabled.
     *
     * @return true if debugging output is enabled
     */
    public boolean getDebug() {
        return this.m_debug;
    }

    /**
     * Lists the command-line options understood by this classifier.
     *
     * @return an enumeration of {@link Option} objects
     */
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>(3);
        // The third constructor argument is the number of arguments the option
        // takes: -S and -G each take exactly one value, -D is a plain flag.
        options.addElement(new Option("\tSet sample size\n\t(default: 4)\n", "S", 1, "-S <sample size>"));
        options.addElement(new Option("\tSet the seed used to generate samples\n\t(default: 0)\n", "G", 1, "-G <seed>"));
        options.addElement(new Option("\tProduce debugging output\n\t(default no debugging output)\n", "D", 0, "-D"));
        return options.elements();
    }

    /**
     * Parses the given options. Options that are absent are reset to their
     * defaults (sample size 4, seed 0, debugging off).
     *
     * @param stringArray the options, e.g. {@code -S 4 -G 0 -D}
     * @throws Exception if an option value cannot be parsed
     */
    public void setOptions(String[] stringArray) throws Exception {
        String sampleSize = Utils.getOption('S', stringArray);
        this.setSampleSize(sampleSize.length() != 0 ? Integer.parseInt(sampleSize) : 4);
        String seed = Utils.getOption('G', stringArray);
        this.setRandomSeed(seed.length() != 0 ? Long.parseLong(seed) : 0L);
        this.setDebug(Utils.getFlag('D', stringArray));
    }

    /**
     * Returns the current settings as an option array. The array always has
     * nine entries; unused trailing entries are empty strings.
     *
     * @return the current options
     */
    public String[] getOptions() {
        String[] options = new String[9];
        int next = 0;
        options[next++] = "-S";
        options[next++] = "" + this.getSampleSize();
        options[next++] = "-G";
        options[next++] = "" + this.getRandomSeed();
        if (this.getDebug()) {
            options[next++] = "-D";
        }
        while (next < options.length) {
            options[next++] = "";
        }
        return options;
    }

    /**
     * Computes the number of ways to choose {@code n2} items out of {@code n}.
     * The value is built up one factor at a time in {@code long} arithmetic so
     * that intermediate products cannot overflow when the result fits an int.
     *
     * @param n  the number of items to choose from
     * @param n2 the number of items to choose (r)
     * @return the binomial coefficient
     * @throws Exception           if {@code n2} is greater than {@code n}
     * @throws ArithmeticException if the result does not fit into an int
     */
    public static int combinations(int n, int n2) throws Exception {
        if (n2 > n) {
            throw new Exception("r must be less that or equal to n.");
        }
        // C(n, r) == C(n, n - r); use the smaller of the two to shorten the loop.
        int r = Math.min(n2, n - n2);
        long result = 1L;
        for (int i = 1; i <= r; ++i) {
            // After this step result == C(n, i), so the division is always exact.
            result = result * (long)(n - i + 1) / (long)i;
            if (result > (long)Integer.MAX_VALUE) {
                throw new ArithmeticException("Number of combinations of " + n2 + " out of " + n + " does not fit into an int.");
            }
        }
        return (int)result;
    }

    /**
     * Evaluates this classifier from the command line.
     *
     * @param stringArray the evaluation options
     */
    public static void main(String[] stringArray) {
        try {
            System.out.println(Evaluation.evaluateModel(new LeastMedSq(), stringArray));
        }
        catch (Exception exception) {
            System.out.println(exception.getMessage());
            exception.printStackTrace();
        }
    }
}

