#include "scalablebayesm.h"
#include <RcppArmadillo.h>

using namespace Rcpp;
using namespace arma;




//[[Rcpp::export]]
List rhierMnlDPParallel_rcpp_loop(int R, int keep, int nprint, List const& lgtdata, 
                                  arma::mat const& Z, arma::vec const& deltabar, 
                                  arma::mat const& Ad, List const& PrioralphaList,
                                  List const& lambda_hyper, bool drawdelta, int nvar,
                                  arma::mat oldbetas, double s, int maxuniq, int gridsize,
                                  double BayesmConstantA, int BayesmConstantnuInc,
                                  double BayesmConstantDPalpha, bool verbose) {
  // MCMC loop for a hierarchical MNL with a Dirichlet-process prior on the
  // unit-level coefficients. Each of the R iterations:
  //   1. draws the DP state (indicators, unique components, alpha, lambda)
  //      given the current betas via rDPGibbs1 (after removing Z*Delta' when
  //      drawdelta is true),
  //   2. optionally draws Delta via drawDeltaDP,
  //   3. takes one Metropolis step (mnlMetropOnce) for every unit's beta.
  //
  // Arguments (as used below):
  //   R, keep        number of iterations and thinning interval (keep > 0).
  //   nprint/verbose progress is reported every nprint draws when nprint > 0
  //                  and verbose is true.
  //   lgtdata        list with one element per unit, each holding "y", "X"
  //                  and "hess".
  //   Z              unit covariates; row lgt belongs to unit lgt.
  //   oldbetas       starting betas, one row (length nvar) per unit; updated
  //                  in place as the chain runs.
  //   s              forwarded to mnlMetropOnce (presumably the proposal
  //                  scale - confirm against mnlMetropOnce).
  //
  // Returns list(compdraw, probdraw[, Deltadraw]):
  //   compdraw   R/keep elements, each list(list(mu=, rooti=)) built from
  //              thetaNp1_vector[0] of the rDPGibbs1 output at that draw.
  //   probdraw   R/keep values, always 1.0 (a single component is reported).
  //   Deltadraw  (R/keep) x (nz*nvar) matrix, present only if drawdelta.
  //
  // Per-iteration traces of the betas, alpha, the number of unique components,
  // lambda and the log-likelihood are deliberately NOT stored: they are not
  // returned, and the beta trace alone needed nlgt*nvar*(R/keep) doubles.
  //
  // Wayne Taylor 2/21/2015
  // Boyang Yu 07/2023

  // Guard against integer division by zero (R/keep) and nonsensical sizes.
  if (keep <= 0) Rcpp::stop("keep must be a positive integer");
  if (R < 0) Rcpp::stop("R must be non-negative");

  const int nkeep = R / keep;   // number of stored draws (integer division floors)
  const int nz = Z.n_cols;
  const int nlgt = lgtdata.size();

  // Convert the per-unit R lists to a std::vector of structs once so the
  // inner loop avoids repeated Rcpp list lookups.
  std::vector<moments> lgtdata_vector;
  lgtdata_vector.reserve(nlgt);
  for (int lgt = 0; lgt < nlgt; lgt++) {
    List lgtdatai = lgtdata[lgt];
    moments lgtdatai_struct;
    lgtdatai_struct.y = as<vec>(lgtdatai["y"]);
    lgtdatai_struct.X = as<mat>(lgtdatai["X"]);
    lgtdatai_struct.hess = as<mat>(lgtdatai["hess"]);
    lgtdata_vector.push_back(lgtdatai_struct);
  }

  // Initial state: every unit has indicator 1, a single component with
  // mu = 0 and rooti = I, and Delta = 0 when it is being drawn.
  ivec indic = ones<ivec>(nlgt);
  mat olddelta;
  if (drawdelta) olddelta = zeros<vec>(nz*nvar);

  std::vector<murooti> thetaStar_vector(1);
  thetaStar_vector[0].mu = zeros<vec>(nvar);
  thetaStar_vector[0].rooti = eye(nvar,nvar);

  double alpha = BayesmConstantDPalpha;

  // Only one component is reported per stored draw, so its probability is 1.
  const double oldprob = 1.0;

  // Prior on the DP concentration parameter alpha, read from the R list.
  priorAlpha priorAlpha_struct;
  priorAlpha_struct.power = PrioralphaList["power"];
  priorAlpha_struct.alphamin = PrioralphaList["alphamin"];
  priorAlpha_struct.alphamax = PrioralphaList["alphamax"];
  priorAlpha_struct.n = PrioralphaList["n"];

  // Starting values of lambda; rDPGibbs1 returns an updated lambda each draw.
  lambda lambda_struct;
  lambda_struct.mubar = zeros<vec>(nvar);
  lambda_struct.Amu = BayesmConstantA;
  lambda_struct.nu = nvar + BayesmConstantnuInc;
  lambda_struct.V = lambda_struct.nu*eye(nvar,nvar);

  // Storage for the returned draws only.
  List compdraw(nkeep);
  vec probdraw = zeros<vec>(nkeep);
  mat Deltadraw(1,1);
  if (drawdelta) Deltadraw.zeros(nkeep, nz*nvar);   // allocate only when needed

  // Working state reused across iterations.
  vec betabar, q0v;
  mat rootpi, ucholinv, incroot;
  DPOut mgout_struct;
  mnlMetropOnceOut metropout_struct;

  // Log-likelihood of each unit at its starting beta. llmnl is deterministic,
  // so evaluating it before the loop gives the same values as doing it lazily
  // on the first iteration.
  vec oldll = zeros<vec>(nlgt);
  for (int lgt = 0; lgt < nlgt; lgt++) {
    oldll[lgt] = llmnl(vectorise(oldbetas(lgt,span::all)),lgtdata_vector[lgt].y,lgtdata_vector[lgt].X);
  }

  const bool showProgress = (nprint > 0) && verbose;
  if (showProgress) startMcmcTimer();

  for (int rep = 0; rep < R; rep++) {

    // Draw indicators, components, alpha and lambda given {beta_i} (and Delta).
    if (drawdelta) {
      olddelta.reshape(nvar,nz);
      mgout_struct = rDPGibbs1(oldbetas-Z*trans(olddelta),lambda_struct,thetaStar_vector,maxuniq,indic,q0v,alpha,priorAlpha_struct,gridsize,lambda_hyper);
    } else {
      mgout_struct = rDPGibbs1(oldbetas,lambda_struct,thetaStar_vector,maxuniq,indic,q0v,alpha,priorAlpha_struct,gridsize,lambda_hyper);
    }

    indic = mgout_struct.indic;
    lambda_struct = mgout_struct.lambda_struct;
    alpha = mgout_struct.alpha;
    thetaStar_vector = mgout_struct.thetaStar_vector;

    if (drawdelta) {
      olddelta = drawDeltaDP(Z,oldbetas,indic,thetaStar_vector,deltabar,Ad);
      // Delta is nvar x nz so that beta_i's mean is mu + Delta*z_i. Reshaping
      // once here replaces the per-unit reshape (column-major order is kept).
      olddelta.reshape(nvar,nz);
    }

    // One Metropolis step per unit: beta_i | indic[i], z_i, mu, rooti.
    for (int lgt = 0; lgt < nlgt; lgt++) {
      const murooti& thetaStarLgt = thetaStar_vector[indic[lgt]-1];   // indic is 1-based
      rootpi = thetaStarLgt.rooti;
      if (drawdelta) {
        betabar = thetaStarLgt.mu + olddelta * trans(Z(lgt,span::all));
      } else {
        betabar = thetaStarLgt.mu;
      }

      // incroot = upper Cholesky factor of (hess + rootpi*rootpi')^{-1};
      // trimatu lets solve() exploit the triangular Cholesky factor.
      ucholinv = solve(trimatu(chol(lgtdata_vector[lgt].hess+rootpi*trans(rootpi))), eye(nvar,nvar));
      incroot = chol(ucholinv*trans(ucholinv));

      metropout_struct = mnlMetropOnce(lgtdata_vector[lgt].y,lgtdata_vector[lgt].X,vectorise(oldbetas(lgt,span::all)),
                                       oldll[lgt],s,incroot,betabar,rootpi);

      oldbetas(lgt,span::all) = trans(metropout_struct.betadraw);
      oldll[lgt] = metropout_struct.oldll;
    }

    // showProgress implies nprint > 0, so the modulo is safe.
    if (showProgress && (rep+1) % nprint == 0) infoMcmcTimer(rep, R);

    if ((rep+1) % keep == 0) {
      const int mkeep = (rep+1)/keep;
      probdraw[mkeep-1] = oldprob;
      if (drawdelta) Deltadraw(mkeep-1, span::all) = trans(vectorise(olddelta));
      const murooti& thetaNp10 = mgout_struct.thetaNp1_vector[0];
      // NumericVector (not arma::vec) so the R-side plotting functions work.
      compdraw[mkeep-1] = List::create(List::create(Named("mu") = NumericVector(thetaNp10.mu.begin(),thetaNp10.mu.end()),
                                                    Named("rooti") = thetaNp10.rooti));
    }
  }

  if (showProgress) endMcmcTimer();

  if (drawdelta) {
    return Rcpp::List::create(Rcpp::Named("compdraw") = compdraw,
                              Rcpp::Named("probdraw") = probdraw,
                              Rcpp::Named("Deltadraw") = Deltadraw);
  }
  return Rcpp::List::create(Rcpp::Named("compdraw") = compdraw,
                            Rcpp::Named("probdraw") = probdraw);
}
