#include <RcppArmadillo.h>
#include "ieTest.h"

using namespace Rcpp;


// Tests the null hypothesis of no indirect (mediation) effect through two
// ordered mediators using the maxp test. The function needs either the
// estimates together with their (co)variances, or the cumulative
// probabilities (U values) directly. If estimates and variances are passed,
// the user must specify which distribution to use (normal and t are currently
// available). The test is performed on the given or calculated U values and
// the p-value is returned.

//' MaxP-Test for the indirect effect - Ordered Mediators
//'
//' This function takes 
//' estimates and covariances, or 3 U values, for the maxp-test 
//' for an ordered mediation scenario.
//' If estimates are passed to the function, the user must specify 
//' what distribution is to be used to find the cumulative probabilities.
//' The p-value of the maxp-test is returned. 
//'
//' @param u1,u2,u3 The U values to be used in the test. Given priority over estimates, but all must be supplied.
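//' @param V1Dist String value specifying the distribution of the estimate of the independent variable on the first mediator: "Normal" (also "normal", "N", "n") or "t" (also "T"). If only one of V1Dist, V2Dist, and V3Dist is supplied, it is used for all three effects. Ignored if u1, u2, and u3 are supplied.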
//' @param V1 Value of the estimate of the independent variable on the first mediator. Ignored if u1, u2, and u3 are supplied.
//' @param V1_VAR Value of the variance of the estimate of the independent variable on the first mediator. Ignored if u1, u2, and u3 are supplied.
//' @param V1_DF Degrees of freedom for V1. Only needed if t-distribution is used.
//' @param V2Dist String value specifying the distribution of the estimate of the first mediator (and interaction term) on the second mediator. Ignored if u1, u2, and u3 are supplied.
//' @param V2 Value of the estimate of the first mediator on the second mediator. Ignored if u1, u2, and u3 are supplied.
//' @param V2_VAR Value of the variance of the estimate of the first mediator on the second mediator. Ignored if u1, u2, and u3 are supplied.
//' @param V2_DF Degrees of freedom for V2. Only needed if t-distribution is used.
//' @param V2b Value of the estimate of the effect of the interaction of the independent and first mediator variable on the second mediator. Ignored if u1, u2, and u3 are supplied.
//' @param V2b_VAR Value of the variance of the estimate of the effect of the interaction of the independent and first mediator variable on the second mediator. Ignored if u1, u2, and u3 are supplied.
//' @param V2bmult Value of the independent variable used for the interactions (applied to both the second-mediator and the response interaction terms). Typically 1.
//' @param V3Dist String value specifying the distribution of the estimate of the second mediator (and interaction term) on the response. Ignored if u1, u2, and u3 are supplied.
//' @param V3 Value of the estimate of the second mediator on the response. Ignored if u1, u2, and u3 are supplied.
//' @param V3_VAR Value of the variance of the estimate of the second mediator on the response. Ignored if u1, u2, and u3 are supplied.
//' @param V3_DF Degrees of freedom for V3. Only needed if t-distribution is used.
//' @param V3b Value of the estimate of the effect of the interaction of the independent and second mediator variable on the response. Ignored if u1, u2, and u3 are supplied.
//' @param V3b_VAR Value of the variance of the estimate of the effect of the interaction of the independent and second mediator variable on the response. Ignored if u1, u2, and u3 are supplied.
//' @param V1_V2_cov Value of the covariance between V1 and V2. Typically 0 for fully observed data. 
//' @param V1_V2b_cov Value of the covariance between V1 and V2b. Typically 0 for fully observed data.
//' @param V1_V3_cov Value of the covariance between V1 and V3. Typically 0 for fully observed data. 
//' @param V1_V3b_cov Value of the covariance between V1 and V3b. Typically 0 for fully observed data.
//' @param V2_V2b_cov Value of the covariance between V2 and V2b.
//' @param V2_V3_cov Value of the covariance between V2 and V3. Typically 0 for fully observed data.
//' @param V2_V3b_cov Value of the covariance between V2 and V3b. Typically 0 for fully observed data.
//' @param V2b_V3_cov Value of the covariance between V2b and V3. Typically 0 for fully observed data.
//' @param V2b_V3b_cov Value of the covariance between V2b and V3b. Typically 0 for fully observed data.
//' @param V3_V3b_cov Value of the covariance between V3 and V3b.
//' @param V1_0 Null value for V1.
//' @param V2_0 Null value for V2.
//' @param V2b_0 Null value for V2b.
//' @param V3_0 Null value for V3.
//' @param V3b_0 Null value for V3b.
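//' @returns The p-value of the test, i.e. the largest of the p-values for the individual parameters. Returns -1 (after printing a message) if only some of the U values are supplied, if the distributions are specified ambiguously, or if effects or their variances are missing.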
//' @export
//' @examples
//' maxp_ord( u1 = .02, u2= .015, u3 = .995)
// [[Rcpp::export]]
double maxp_ord(Rcpp::Nullable<Rcpp::NumericVector> u1 = R_NilValue,
               Rcpp::Nullable<Rcpp::NumericVector> u2 = R_NilValue,
                Rcpp::Nullable<Rcpp::NumericVector> u3 = R_NilValue,
               Nullable<CharacterVector> V1Dist = R_NilValue,
               Rcpp::Nullable<Rcpp::NumericVector> V1 = R_NilValue, 
               Rcpp::Nullable<Rcpp::NumericVector> V1_VAR = R_NilValue,
               Nullable<int> V1_DF = R_NilValue,
               Nullable<CharacterVector> V2Dist = R_NilValue,
               Rcpp::Nullable<Rcpp::NumericVector> V2 = R_NilValue, 
               Rcpp::Nullable<Rcpp::NumericVector> V2_VAR = R_NilValue,
               Rcpp::Nullable<Rcpp::NumericVector> V2_DF = R_NilValue,
               double V2b = 0, double V2b_VAR = 0,
               const int V2bmult = 1,
               Nullable<CharacterVector> V3Dist = R_NilValue,
               Rcpp::Nullable<Rcpp::NumericVector> V3 = R_NilValue, 
               Rcpp::Nullable<Rcpp::NumericVector> V3_VAR = R_NilValue,
               Rcpp::Nullable<Rcpp::NumericVector> V3_DF = R_NilValue,
               double V3b = 0, double V3b_VAR = 0,
               double V1_V2_cov = 0, double V1_V2b_cov = 0,
                double V1_V3_cov = 0, double V1_V3b_cov = 0,
               double V2_V2b_cov = 0,
                double V2_V3_cov = 0, double V2_V3b_cov = 0,
               double V2b_V3_cov = 0, double V2b_V3b_cov = 0,
               double V3_V3b_cov = 0,
               double V1_0 = 0, double V2_0 = 0, double V2b_0 = 0,
               double V3_0 = 0, double V3b_0 = 0){
  
  const int V3bmult = V2bmult;
  
  // Figure out which distribution is being used
  // Use given u1, u2, and u3 values if passed
  
  // Check if calculations need to be done (or if U values are passed in)
  // Then check if correlation between variables and do transformation
  // if necessary.
  
  /* Check if calculations */
  // If one or two u values are specified, return with error;
  if( ( u1.isNull() + u2.isNull() + u3.isNull() == 1) || 
        ( u1.isNull() + u2.isNull() + u3.isNull() == 2) ){
    Rcout << "All U values must be specified." << "\n";
    return -1;
  }
  
  // Declare u1, u2, and u3 to either be filled in with inputs or calculated;
  double u1_ = .5;
  double u2_ = .5;
  double u3_ = .5;
  
  // Calculate u values if neither specified. Otherwise just use what is given down below.
  if(u1.isNull() && u2.isNull() && u3.isNull()){
    
    double V1_ = 0;
    double V2_ = 0;
    double V3_ = 0;
    double V1_VAR_ = 0;
    double V2_VAR_ = 0;
    double V3_VAR_ = 0;
    CharacterVector V1Dist_;
    CharacterVector V2Dist_;
    CharacterVector V3Dist_;

    // If only one distributrion is specified, assign to all three. If none are specified, or 
    // two are specified, return with error due to ambiguity. 
    if( ( V1Dist.isNull() + V2Dist.isNull() + V3Dist.isNull() == 3 ) ||
          ( V1Dist.isNull() + V2Dist.isNull() + V3Dist.isNull() == 1 )){
      Rcout << "Only one distribution, or three distributions must be specified." << "\n";
      return -1;
    }

    // If only one distribution is specified, assign distribution to all effects
    if( V1Dist.isNull() + V2Dist.isNull() + V3Dist.isNull() == 2 ){
      if(V1Dist.isNotNull()){
        V1Dist_ = CharacterVector(V1Dist);
        V2Dist_ = CharacterVector(V1Dist);
        V3Dist_ = CharacterVector(V1Dist);
      }else if(V2Dist.isNotNull()){
        V1Dist_ = CharacterVector(V2Dist);
        V2Dist_ = CharacterVector(V2Dist);
        V3Dist_ = CharacterVector(V2Dist);
      }else if(V3Dist.isNotNull()){
        V1Dist_ = CharacterVector(V3Dist);
        V2Dist_ = CharacterVector(V3Dist);
        V3Dist_ = CharacterVector(V3Dist);
      }
    }
    
    if( V1Dist.isNotNull() && V2Dist.isNotNull() && V3Dist.isNotNull() ){
      V1Dist_ = CharacterVector(V1Dist);
      V2Dist_ = CharacterVector(V2Dist);
      V3Dist_ = CharacterVector(V3Dist);
    }

    // If V1 or V2 (including variances) are missing, return with error;
    if(V1.isNull() || V2.isNull() || V3.isNull() || 
        V1_VAR.isNull() || V2_VAR.isNull() || V3_VAR.isNull() ){
      Rcout << "Error: Effects and their variances must be specified" << "\n";
      return -1;
    }
  
    // If correlation between v1 and either v2 or V2b, then transform;
    if( V1_V2_cov != 0 || V1_V2b_cov != 0 || V1_V3_cov != 0 || V1_V3b_cov != 0 || 
        V2_V3_cov != 0 || V1_V3b_cov != 0 ){
      arma::vec tParam(5);
      tParam(0) = NumericVector(V1)[0]; tParam(1) = NumericVector(V2)[0]; 
      tParam(2) = V2b ; tParam(3) = NumericVector(V3)[0]; 
      tParam(4) = V3b ;
      arma::vec tNull(5);
      tNull(0) = V1_0; tNull(1) = V2_0; tNull(2) = V2b_0;
      tNull(3) = V3_0; tNull(4) = V3b_0;
      
      arma::mat tSig(5, 5);
      
      tSig(0, 0) = NumericVector(V1_VAR)[0]; 
      tSig(1, 1) = NumericVector(V2_VAR)[0]; 
      tSig(2, 2) = V2b_VAR;
      tSig(3, 3) = NumericVector(V3_VAR)[0]; 
      tSig(4, 4) = V3b_VAR;
      tSig(0, 1) = V1_V2_cov; tSig(1, 0) = V1_V2_cov;
      tSig(0, 2) = V1_V2b_cov; tSig(2, 0) = V1_V2b_cov;
      tSig(0, 3) = V1_V3_cov; tSig(3, 0) = V1_V3_cov;
      tSig(0, 4) = V1_V3b_cov; tSig(4, 0) = V1_V3b_cov;
      tSig(1, 2) = V2_V2b_cov; tSig(2, 1) = V2_V2b_cov;
      tSig(1, 3) = V2_V3_cov; tSig(3, 1) = V2_V3_cov;
      tSig(1, 4) = V2_V3b_cov; tSig(4, 1) = V2_V3b_cov;
      tSig(2, 3) = V2b_V3_cov; tSig(3, 2) = V2b_V3_cov;
      tSig(2, 4) = V2b_V3b_cov; tSig(4, 2) = V2b_V3b_cov;
      tSig(3, 4) = V3_V3b_cov; tSig(4, 3) = V3_V3b_cov;
      
      if(tSig(2, 2) == 0){tSig(2, 2) = 1;}
      if(tSig(4, 4) == 0){tSig(4, 4) = 1;}
      
      arma::mat transMeans = corrTrans(tParam, tNull, tSig);
      
      // Update values with transformation;
      V1_ = transMeans(0, 0); 
      V2_ = transMeans(1, 0); 
      V2b = transMeans(2, 0);
      V3_ = transMeans(3, 0); 
      V3b = transMeans(4, 0);
      
      V1_0 = transMeans(0, 1); 
      V2_0 = transMeans(1, 1); 
      V2b_0 = transMeans(2, 1);
      V3_0 = transMeans(3, 1); 
      V3b_0 = transMeans(4, 1);
      
      V1_VAR_ = 1; V2_VAR_ = 1; V3_VAR_ = 1; 
      if(V2b_VAR != 0){V2b_VAR = 1;}
      if(V3b_VAR != 0){V3b_VAR = 1;}
      V1_V2_cov = 0; V1_V2b_cov = 0; V1_V3_cov = 0; V1_V3b_cov = 0;
      V2_V2b_cov = 0; V2_V3_cov = 0; V2_V3b_cov = 0;
      V2b_V3_cov = 0; V2b_V3b_cov = 0;
      V3_V3b_cov = 0;
    }else{
      V1_ = NumericVector(V1)[0];
      V2_ = NumericVector(V2)[0];
      V3_ = NumericVector(V3)[0];
      V1_VAR_ = NumericVector(V1_VAR)[0];
      V2_VAR_ = NumericVector(V2_VAR)[0];
      V3_VAR_ = NumericVector(V3_VAR)[0];
    }

    // Now calculate U values if possible;
    // u1;
    if( (V1Dist_[0] == "Normal") ||
        (V1Dist_[0] == "normal") ||
        (V1Dist_[0] == "N") ||
        (V1Dist_[0] == "n")){
      u1_ = normU(V1_, V1_VAR_, 0, 0, 1, 0, V1_0, 0);
    }else if((V1Dist_[0] == "T") ||
      (V1Dist_[0] == "t")){
      if(V1_DF.isNull()){
        Rcout << "Degrees of Freedom must be specified for t-distribution" << "\n";
        return 1;
      }
      u1_ = tU(V1_, V1_VAR_, NumericVector(V1_DF)[0], 0, 0, 1, 0, V1_0, 0, 1);
      
      // Rcout << "u1 = " << u1_ << ", V1 = " << V1_<< ", V1_VAR = " << V1_VAR_ << 
      //           ", DF = " << NumericVector(V1_DF)[0] << ", Null = " << V1_0 << "\n";
    }
    
    // u2;
    if( (V2Dist_[0] == "Normal") ||
        (V2Dist_[0] == "normal") ||
        (V2Dist_[0] == "N") ||
        (V2Dist_[0] == "n")){
      u2_ = normU(V2_, V2_VAR_, V2b, V2b_VAR, V2bmult, V2_V2b_cov, V2_0, V2b_0);
    }else if((V2Dist_[0] == "T") ||
      (V2Dist_[0] == "t")){
      if(V2_DF.isNull()){
        Rcout << "Degrees of Freedom must be specified for t-distribution" << "\n";
        return 1;
      }
      u2_ = tU(V2_, V2_VAR_, NumericVector(V2_DF)[0], V2b, V2b_VAR, V2bmult, V2_V2b_cov, V2_0, V2b_0, 1e5);
      // Rcout << "u2 = " << u2_ << ", V2 = " << V2_<< ", V2_VAR = " << V2_VAR_ << 
      //   ", DF = " << NumericVector(V2_DF)[0] << ", Null = " << V2_0 << 
      //     ", Interaction = " << V2b << ", Inter Var = " << V2b_VAR << 
      //       ", Cov: " << V2_V2b_cov << ", Int Null = " << V2b_0 << "\n";
    }
    
    // u3;
    if( (V3Dist_[0] == "Normal") ||
        (V3Dist_[0] == "normal") ||
        (V3Dist_[0] == "N") ||
        (V3Dist_[0] == "n")){
      u3_ = normU(V3_, V3_VAR_, V3b, V3b_VAR, V3bmult, V3_V3b_cov, V3_0, V3b_0);
    }else if((V3Dist_[0] == "T") ||
      (V3Dist_[0] == "t")){
      if(V3_DF.isNull()){
        Rcout << "Degrees of Freedom must be specified for t-distribution" << "\n";
        return 1;
      }
      u3_ = tU(V3_, V3_VAR_, NumericVector(V3_DF)[0], V3b, V3b_VAR, V3bmult, V3_V3b_cov, V3_0, V3b_0, 1e5);
      // Rcout << "u3 = " << u3_ << ", V3 = " << V3_<< ", V3_VAR = " << V3_VAR_ << 
      //   ", DF = " << NumericVector(V3_DF)[0] << ", Null = " << V3_0 << 
      //     ", Interaction = " << V3b << ", Inter Var = " << V3b_VAR << 
      //       ", Cov: " << V3_V3b_cov << ", Int Null = " << V3b_0 << "\n";
    }
    
  }else{
    u1_ = NumericVector(u1)[0];
    u2_ = NumericVector(u2)[0];
    u3_ = NumericVector(u3)[0];
  }

  // Put u values in lower left corner of region.
  u1_ = std::min(u1_, 1.0 - u1_);
  u2_ = std::min(u2_, 1.0 - u2_);
  u3_ = std::min(u3_, 1.0 - u3_);
  
//Rcout << "maxp: u1: " << u1_ << ", u2: " << u2_ << ", u3: " << u3_ << "\n";
  
  return std::max(u1_*2.0, std::max(u2_*2.0, u3_*2.0));

}
