#' @title Locate the next design point for a (D)GP emulator or a bundle of (D)GP emulators using PEI
#'
#' @description This function searches from a candidate set to locate the next design point(s) to be added to a (D)GP emulator
#'     or a bundle of (D)GP emulators using the Pseudo Expected Improvement (PEI), see the reference below.
#'
#' @param object can be one of the following:
#' * the S3 class `gp`.
#' * the S3 class `dgp`.
#' * the S3 class `bundle`.
#' @param x_cand a matrix (with each row being a design point and column being an input dimension) that gives a candidate set
#'     from which the next design point(s) are determined.
#' @param pseudo_points an optional matrix (with columns being input dimensions) that gives the pseudo input points for PEI calculations. See the reference below
#'    for further details about the pseudo points. When `object` is an instance of the `bundle` class, `pseudo_points` can also be a list with the length
#'    equal to the number of emulators in the bundle. Each element in the list is a matrix that gives the pseudo input points for the corresponding
#'    emulator in the bundle. Defaults to `NULL`. When [pei()] is used in [design()], `pseudo_points` will be automatically generated by [design()].
#' @param batch_size an integer that gives the number of design points to be chosen.
#'     Defaults to `1`.
#' @param workers the number of workers/cores to be used for the criterion calculation. If set to `NULL`,
#'     the number of workers is set to `(max physical cores available - 1)`. Defaults to `1`.
#' @param threading a bool indicating whether to use multi-threading to accelerate the criterion calculation for a DGP emulator.
#'     Turning this option on could improve the speed of criterion calculations when the DGP emulator is built with a moderately large number of
#'     training data points and the Matérn-2.5 kernel.
#' @param aggregate an R function that aggregates scores of the PEI across different output dimensions (if `object` is an instance
#'     of the `dgp` class) or across different emulators (if `object` is an instance of the `bundle` class). The function should be specified in the
#'     following basic form:
#' * the first argument is a matrix representing scores. The rows of the matrix correspond to different design points. The number of columns
#'   of the matrix equals:
#'   - the emulator output dimension if `object` is an instance of the `dgp` class; or
#'   - the number of emulators contained in `object` if `object` is an instance of the `bundle` class.
#' * the output should be a vector that gives aggregations of scores at different design points.
#'
#' Set to `NULL` to disable the aggregation. Defaults to `NULL`.
#' @param ... any arguments (with names different from those of arguments used in [pei()]) that are used by `aggregate` or [gp()] (for emulating the
#' ES-LOO errors) can be passed here.
#'
#' @return
#' * If `object` is an instance of the `gp` class, a vector is returned with the length equal to `batch_size`, giving the positions (i.e., row numbers)
#'   of next design points from `x_cand`.
#' * If `object` is an instance of the `dgp` class, a matrix is returned with row number equal to `batch_size` and column number equal to one (if `aggregate`
#'   is not `NULL`) or the output dimension (if `aggregate` is `NULL`), giving positions (i.e., row numbers) of next design points from `x_cand` to be added
#'   to the DGP emulator across different outputs.
#' * If `object` is an instance of the `bundle` class, a matrix is returned with row number equal to `batch_size` and column number equal to the number of
#'   emulators in the bundle, giving positions (i.e., row numbers) of next design points from `x_cand` to be added to individual emulators.
#'
#' @note
#' * The column order of the first argument of `aggregate` must be consistent with the order of emulator output dimensions (if `object` is an instance of the
#'     `dgp` class), or the order of emulators placed in `object` if `object` is an instance of the `bundle` class;
#' * The function is only applicable to DGP emulators without likelihood layers.
#' * Any R vector detected in `x_cand` and `pseudo_points` will be treated as a column vector and automatically converted into a single-column
#'   R matrix.
#' @references
#' Mohammadi, H., Challenor, P., Williamson, D., & Goodfellow, M. (2022). Cross-validation-based adaptive sampling for Gaussian process models. *SIAM/ASA Journal on Uncertainty Quantification*, **10(1)**, 294-316.
#'
#' @details See further examples and tutorials at <https://mingdeyu.github.io/dgpsi-R/>.
#' @examples
#' \dontrun{
#'
#' # load packages and the Python env
#' library(lhs)
#' library(dgpsi)
#' init_py()
#'
#' # construct a 1D non-stationary function
#' f <- function(x) {
#'  sin(30*((2*x-1)/2-0.4)^5)*cos(20*((2*x-1)/2-0.4))
#' }
#'
#' # generate the initial design
#' X <- maximinLHS(10,1)
#' Y <- f(X)
#'
#' # training a 2-layered DGP emulator with the global connection off
#' m <- dgp(X, Y, connect = F)
#'
#' # generate a candidate set
#' x_cand <- maximinLHS(200,1)
#'
#' # locate the next design point using PEI
#' next_point <- pei(m, x_cand = x_cand)
#' X_new <- x_cand[next_point,,drop = F]
#'
#' # obtain the corresponding output at the located design point
#' Y_new <- f(X_new)
#'
#' # combine the new input-output pair to the existing data
#' X <- rbind(X, X_new)
#' Y <- rbind(Y, Y_new)
#'
#' # update the DGP emulator with the new input and output data and refit with 500 training iterations
#' m <- update(m, X, Y, refit = TRUE, N = 500)
#'
#' # plot the LOO validation
#' plot(m)
#' }
#' @md
#' @name pei
#' @export
pei <- function(object, x_cand, ...) {
  # S3 generic: control passes to the method matching class(object), e.g.
  # pei.gp(), pei.dgp() or pei.bundle(); `...` is forwarded unchanged.
  UseMethod("pei")
}

#' @rdname pei
#' @method pei gp
#' @export
pei.gp <- function(object, x_cand, pseudo_points = NULL, batch_size = 1, ...) {
  # PEI for a GP emulator:
  #   1. fit a GP to the log ES-LOO errors;
  #   2. compute the expected improvement (EI) of each candidate in x_cand
  #      over the largest observed log error;
  #   3. greedily pick batch_size candidates that maximise EI times a
  #      repulsion factor built from the training inputs, the optional
  #      pseudo points and the candidates already picked.
  # Returns an integer vector (length batch_size) of row positions in x_cand.

  #check class
  if ( !inherits(object, "gp") ) stop("'object' must be an instance of the 'gp' class.", call. = FALSE)
  training_input <- object$data$X
  n_dim_X <- ncol(training_input)
  #check x_cand (a bare vector is treated as a single-column matrix)
  if ( !is.matrix(x_cand) && !is.vector(x_cand) ) stop("'x_cand' must be a vector or a matrix.", call. = FALSE)
  if ( is.vector(x_cand) ) x_cand <- as.matrix(x_cand)
  if ( ncol(x_cand) != n_dim_X ) stop("'x_cand' and the training input have different number of dimensions.", call. = FALSE)
  #check pseudo_points
  if ( !is.null(pseudo_points) ){
    if ( !is.matrix(pseudo_points) && !is.vector(pseudo_points) ) stop("'pseudo_points' must be a vector or a matrix.", call. = FALSE)
    if ( is.vector(pseudo_points) ) pseudo_points <- as.matrix(pseudo_points)
    if ( ncol(pseudo_points) != n_dim_X ) stop("'pseudo_points' and the training input have different number of dimensions.", call. = FALSE)
  }
  #check batch size (as.integer() maps non-numeric input to NA)
  batch_size <- as.integer(batch_size)
  if ( length(batch_size) != 1L || is.na(batch_size) || batch_size < 1 ) stop("'batch_size' must be >= 1.", call. = FALSE)
  #set gp params: defaults for the GP fitted to the log ES-LOO errors; any
  #argument of gp() supplied through ... overrides them.
  default_param <- list(name = 'matern2.5', nugget = 1e-12, verb = FALSE)
  add_arg <- list(...)
  gpnames <- methods::formalArgs(gp)
  gpparam <- add_arg[gpnames[gpnames %in% names(add_arg)]]
  final_gp_arg <- utils::modifyList(default_param, gpparam)

  #EI of each candidate w.r.t. the largest observed log ES-LOO error
  ei <- rep(0, nrow(x_cand))
  score <- log(object$emulator_obj$esloo())
  score_gp <- do.call(gp, c(list(training_input, score), final_gp_arg))
  score_gp <- predict.gp(score_gp, x_cand)
  score_mean <- score_gp$results$mean
  score_var <- score_gp$results$var
  score_std <- sqrt(score_var)
  score_scale <- as.numeric(score_gp$emulator_obj$kernel$scale)
  #candidates whose predictive variance is negligible relative to the kernel
  #scale keep EI = 0 (this also avoids dividing by a near-zero std below)
  mask <- (score_var / score_scale) > 1e-6
  error <- score_mean[mask] - max(score)
  st_error <- error / score_std[mask]
  ei[mask] <- error * stats::pnorm(st_error) + score_std[mask] * stats::dnorm(st_error)

  #points that repel new picks: training inputs plus optional pseudo points;
  #every pick is appended so that later picks are pushed away from it
  pseudo_training_points <- if ( is.null(pseudo_points) ) training_input else rbind(training_input, pseudo_points)
  kernel_length <- score_gp$emulator_obj$kernel$length
  kernel_name <- score_gp$emulator_obj$kernel$name
  idx <- integer(batch_size)
  for ( i in seq_len(batch_size) ){
    #repulsion factor: product over the repelling points (axis 0) of
    #1 - k_one_vec(...); k_one_vec presumably gives kernel correlations
    #between repelling points (rows) and candidates (columns) - TODO confirm
    rf <- pkg.env$np$prod(1 - pkg.env$dgpsi$functions$k_one_vec(pseudo_training_points,
                    x_cand, kernel_length, kernel_name), axis = 0L)
    idx_i <- which.max(rf * ei)
    idx[i] <- idx_i
    pseudo_training_points <- rbind(pseudo_training_points, x_cand[idx_i, , drop = FALSE])
  }
  #trigger Python- and R-side garbage collection before returning
  pkg.env$py_gc$collect()
  gc(full = TRUE)
  idx
}


#' @rdname pei
#' @method pei dgp
#' @export
pei.dgp <- function(object, x_cand, pseudo_points = NULL, batch_size = 1, workers = 1, threading = FALSE, aggregate = NULL, ...) {
  # PEI for a DGP emulator:
  #   1. fit one GP per output dimension to the log ES-LOO errors;
  #   2. compute, per output, the expected improvement (EI) of each candidate
  #      in x_cand over the largest observed log error;
  #   3. greedily pick batch_size rows that maximise EI times a repulsion
  #      factor built from the training inputs, the optional pseudo points and
  #      all candidates already picked (shared across outputs).
  # Returns a matrix with batch_size rows of row positions in x_cand: one
  # column per output when aggregate is NULL, otherwise a single column.

  #check class; PEI is not applicable with a likelihood layer
  if ( !inherits(object, "dgp") ) stop("'object' must be an instance of the 'dgp' class.", call. = FALSE)
  if ( object$constructor_obj$all_layer[[object$constructor_obj$n_layer]][[1]]$type == 'likelihood' ){
    stop("The function is only applicable to DGP emulators without likelihood layers.", call. = FALSE)
  }
  object$emulator_obj$set_nb_parallel(threading)
  training_input <- object$data$X
  training_output <- object$data$Y
  n_dim_X <- ncol(training_input)
  n_dim_Y <- ncol(training_output)
  #check x_cand (a bare vector is treated as a single-column matrix)
  if ( !is.matrix(x_cand) && !is.vector(x_cand) ) stop("'x_cand' must be a vector or a matrix.", call. = FALSE)
  if ( is.vector(x_cand) ) x_cand <- as.matrix(x_cand)
  if ( ncol(x_cand) != n_dim_X ) stop("'x_cand' and the training input have different number of dimensions.", call. = FALSE)
  #check pseudo_points
  if ( !is.null(pseudo_points) ){
    if ( !is.matrix(pseudo_points) && !is.vector(pseudo_points) ) stop("'pseudo_points' must be a vector or a matrix.", call. = FALSE)
    if ( is.vector(pseudo_points) ) pseudo_points <- as.matrix(pseudo_points)
    if ( ncol(pseudo_points) != n_dim_X ) stop("'pseudo_points' and the training input have different number of dimensions.", call. = FALSE)
  }
  #check core number (NULL is passed on to pesloo() as core_num)
  if ( !is.null(workers) ) {
    workers <- as.integer(workers)
    if ( length(workers) != 1L || is.na(workers) || workers < 1 ) stop("The worker number must be >= 1.", call. = FALSE)
  }
  add_arg <- list(...)
  #check aggregate and resolve, once, the arguments forwarded to it: all of
  #... if it accepts ..., otherwise only those matching its formal names
  if ( !is.null(aggregate) ){
    gnames <- methods::formalArgs(aggregate)
    agg_arg <- if ( "..." %in% gnames ) add_arg else add_arg[gnames[gnames %in% names(add_arg)]]
  }
  #set gp params: defaults for the GPs fitted to the log ES-LOO errors; any
  #argument of gp() supplied through ... overrides them
  default_param <- list(name = 'matern2.5', nugget = 1e-12, verb = FALSE)
  gpnames <- methods::formalArgs(gp)
  gpparam <- add_arg[gpnames[gpnames %in% names(add_arg)]]
  final_gp_arg <- utils::modifyList(default_param, gpparam)
  #check batch size (as.integer() maps non-numeric input to NA)
  batch_size <- as.integer(batch_size)
  if ( length(batch_size) != 1L || is.na(batch_size) || batch_size < 1 ) stop("'batch_size' must be >= 1.", call. = FALSE)

  #log ES-LOO errors, one column per output dimension
  if ( identical(workers, 1L) ){
    score <- log(object$emulator_obj$esloo(training_input, training_output))
  } else {
    score <- log(object$emulator_obj$pesloo(training_input, training_output, core_num = workers))
  }
  #per-output EI; kernel settings are kept for the repulsion factor below
  ei <- matrix(0, nrow(x_cand), n_dim_Y)
  length_list <- list()
  name_list <- list()
  for ( i in seq_len(n_dim_Y) ){
    score_gp <- do.call(gp, c(list(training_input, score[, i]), final_gp_arg))
    length_list[[i]] <- score_gp$emulator_obj$kernel$length
    name_list[[i]] <- score_gp$emulator_obj$kernel$name
    score_gp <- predict.gp(score_gp, x_cand)
    score_mean <- score_gp$results$mean
    score_var <- score_gp$results$var
    score_std <- sqrt(score_var)
    score_scale <- as.numeric(score_gp$emulator_obj$kernel$scale)
    #candidates with negligible predictive variance (relative to the kernel
    #scale) keep EI = 0, which also avoids dividing by a near-zero std
    mask <- (score_var / score_scale) > 1e-6
    error <- score_mean[mask] - max(score[, i])
    st_error <- error / score_std[mask]
    ei[mask, i] <- error * stats::pnorm(st_error) + score_std[mask] * stats::dnorm(st_error)
  }

  #points that repel new picks: training inputs plus optional pseudo points
  pseudo_training_points <- if ( is.null(pseudo_points) ) training_input else rbind(training_input, pseudo_points)
  idx <- matrix(0L, nrow = batch_size, ncol = if ( is.null(aggregate) ) n_dim_Y else 1L)
  pei_score <- matrix(0, nrow(x_cand), n_dim_Y)
  for ( i in seq_len(batch_size) ){
    for ( j in seq_len(n_dim_Y) ){
      #repulsion factor: product over the repelling points (axis 0) of
      #1 - k_one_vec(...); k_one_vec presumably gives kernel correlations
      #between repelling points and candidates - TODO confirm
      rf <- pkg.env$np$prod(1 - pkg.env$dgpsi$functions$k_one_vec(pseudo_training_points,
                                                                  x_cand, length_list[[j]], name_list[[j]]), axis = 0L)
      pei_score[, j] <- rf * ei[, j]
    }
    if ( is.null(aggregate) || n_dim_Y == 1L ){
      #column-wise argmax (0-based in numpy, hence + 1)
      idx_i <- pkg.env$np$argmax(pei_score, axis = 0L) + 1
    } else {
      agg_pei <- do.call(aggregate, c(list(pei_score), agg_arg))
      idx_i <- which.max(agg_pei)
    }
    idx[i, ] <- idx_i
    #all points picked in this round repel subsequent picks
    pseudo_training_points <- rbind(pseudo_training_points, x_cand[idx_i, , drop = FALSE])
  }
  #trigger Python- and R-side garbage collection before returning
  pkg.env$py_gc$collect()
  gc(full = TRUE)
  idx
}


#' @rdname pei
#' @method pei bundle
#' @export
pei.bundle <- function(object, x_cand, pseudo_points = NULL, batch_size = 1, workers = 1, threading = FALSE, aggregate = NULL, ...) {
  # PEI for a bundle of (D)GP emulators:
  #   1. for each emulator, fit a GP to its log ES-LOO errors and compute the
  #      expected improvement (EI) of each candidate in x_cand;
  #   2. greedily pick batch_size rows; each emulator keeps its own set of
  #      repelling points (its training inputs, its pseudo points and its
  #      earlier picks).
  # Returns a matrix with batch_size rows and one column per emulator giving
  # row positions in x_cand (all columns equal when aggregate is supplied).

  #check class
  if ( !inherits(object, "bundle") ) stop("'object' must be an instance of the 'bundle' class.", call. = FALSE)
  #number of emulators: all elements except the optional 'data'/'design' slots
  n_emulators <- length(object) - sum(c("data", "design") %in% names(object))
  emulator_list <- lapply(seq_len(n_emulators), function(k) object[[paste0('emulator', k)]])
  training_input <- lapply(emulator_list, function(em) em$data$X)
  training_output <- lapply(emulator_list, function(em) em$data$Y)
  n_dim_X <- ncol(training_input[[1]])
  #check x_cand (a bare vector is treated as a single-column matrix)
  if ( !is.matrix(x_cand) && !is.vector(x_cand) ) stop("'x_cand' must be a vector or a matrix.", call. = FALSE)
  if ( is.vector(x_cand) ) x_cand <- as.matrix(x_cand)
  if ( ncol(x_cand) != n_dim_X ) stop("'x_cand' and the training input have different number of dimensions.", call. = FALSE)
  #check pseudo_points: either one set per emulator (a list) or one shared set;
  #a data.frame is a list in R but is not a per-emulator list, so exclude it
  if ( !is.null(pseudo_points) ){
    if ( is.list(pseudo_points) && !is.data.frame(pseudo_points) ){
      if ( length(pseudo_points) != n_emulators ) stop("'pseudo_points' should be a list of length equal to the number of emulators in the bundle.", call. = FALSE)
      for ( i in seq_len(n_emulators) ){
        if ( !is.matrix(pseudo_points[[i]]) && !is.vector(pseudo_points[[i]]) ) stop(sprintf("Element %i in 'pseudo_points' must be a vector or a matrix.", i), call. = FALSE)
        if ( is.vector(pseudo_points[[i]]) ) pseudo_points[[i]] <- as.matrix(pseudo_points[[i]])
        if ( ncol(pseudo_points[[i]]) != n_dim_X ) stop(sprintf("Element %i in 'pseudo_points' and the training input have different number of dimensions.", i), call. = FALSE)
      }
    } else {
      if ( !is.matrix(pseudo_points) && !is.vector(pseudo_points) ) stop("'pseudo_points' must be a vector or a matrix.", call. = FALSE)
      if ( is.vector(pseudo_points) ) pseudo_points <- as.matrix(pseudo_points)
      if ( ncol(pseudo_points) != n_dim_X ) stop("'pseudo_points' and the training input have different number of dimensions.", call. = FALSE)
      pseudo_points <- replicate(n_emulators, pseudo_points, simplify = FALSE)
    }
  }
  #check core number (NULL is passed on to pesloo() as core_num)
  if ( !is.null(workers) ) {
    workers <- as.integer(workers)
    if ( length(workers) != 1L || is.na(workers) || workers < 1 ) stop("The worker number must be >= 1.", call. = FALSE)
  }
  add_arg <- list(...)
  #check aggregate and resolve, once, the arguments forwarded to it: all of
  #... if it accepts ..., otherwise only those matching its formal names
  if ( !is.null(aggregate) ){
    gnames <- methods::formalArgs(aggregate)
    agg_arg <- if ( "..." %in% gnames ) add_arg else add_arg[gnames[gnames %in% names(add_arg)]]
  }
  #set gp params: defaults for the GPs fitted to the log ES-LOO errors; any
  #argument of gp() supplied through ... overrides them
  default_param <- list(name = 'matern2.5', nugget = 1e-12, verb = FALSE)
  gpnames <- methods::formalArgs(gp)
  gpparam <- add_arg[gpnames[gpnames %in% names(add_arg)]]
  final_gp_arg <- utils::modifyList(default_param, gpparam)
  #check batch size (as.integer() maps non-numeric input to NA)
  batch_size <- as.integer(batch_size)
  if ( length(batch_size) != 1L || is.na(batch_size) || batch_size < 1 ) stop("'batch_size' must be >= 1.", call. = FALSE)

  #per-emulator EI; kernel settings are kept for the repulsion factor below
  ei <- matrix(0, nrow(x_cand), n_emulators)
  length_list <- list()
  name_list <- list()
  for ( i in seq_len(n_emulators) ){
    obj_i <- emulator_list[[i]]
    if ( inherits(obj_i, "gp") ){
      score <- log(obj_i$emulator_obj$esloo())
    } else {
      obj_i$emulator_obj$set_nb_parallel(threading)
      if ( identical(workers, 1L) ){
        score <- log(obj_i$emulator_obj$esloo(training_input[[i]], training_output[[i]]))
      } else {
        score <- log(obj_i$emulator_obj$pesloo(training_input[[i]], training_output[[i]], core_num = workers))
      }
    }
    score_gp <- do.call(gp, c(list(training_input[[i]], score), final_gp_arg))
    length_list[[i]] <- score_gp$emulator_obj$kernel$length
    name_list[[i]] <- score_gp$emulator_obj$kernel$name
    score_gp <- predict.gp(score_gp, x_cand)
    score_mean <- score_gp$results$mean
    score_var <- score_gp$results$var
    score_std <- sqrt(score_var)
    score_scale <- as.numeric(score_gp$emulator_obj$kernel$scale)
    #candidates with negligible predictive variance (relative to the kernel
    #scale) keep EI = 0, which also avoids dividing by a near-zero std
    mask <- (score_var / score_scale) > 1e-6
    error <- score_mean[mask] - max(score)
    st_error <- error / score_std[mask]
    ei[mask, i] <- error * stats::pnorm(st_error) + score_std[mask] * stats::dnorm(st_error)
  }

  #per-emulator repelling points: training inputs plus optional pseudo points
  pseudo_training_points <- training_input
  if ( !is.null(pseudo_points) ){
    for ( i in seq_len(n_emulators) ){
      pseudo_training_points[[i]] <- rbind(pseudo_training_points[[i]], pseudo_points[[i]])
    }
  }

  idx <- matrix(0L, nrow = batch_size, ncol = n_emulators)
  pei_score <- matrix(0, nrow(x_cand), n_emulators)
  for ( i in seq_len(batch_size) ){
    for ( j in seq_len(n_emulators) ){
      #repulsion factor: product over the repelling points (axis 0) of
      #1 - k_one_vec(...); k_one_vec presumably gives kernel correlations
      #between repelling points and candidates - TODO confirm
      rf <- pkg.env$np$prod(1 - pkg.env$dgpsi$functions$k_one_vec(pseudo_training_points[[j]],
                                                                  x_cand, length_list[[j]], name_list[[j]]), axis = 0L)
      pei_score[, j] <- rf * ei[, j]
    }
    if ( is.null(aggregate) ){
      #column-wise argmax (0-based in numpy, hence + 1)
      idx_i <- pkg.env$np$argmax(pei_score, axis = 0L) + 1
    } else {
      agg_pei <- do.call(aggregate, c(list(pei_score), agg_arg))
      #the same aggregated pick is used for every emulator
      idx_i <- rep(which.max(agg_pei), n_emulators)
    }
    idx[i, ] <- idx_i
    #each emulator's pick repels that emulator's subsequent picks
    for ( j in seq_len(n_emulators) ){
      pseudo_training_points[[j]] <- rbind(pseudo_training_points[[j]], x_cand[idx_i[j], , drop = FALSE])
    }
  }
  #trigger Python- and R-side garbage collection before returning
  pkg.env$py_gc$collect()
  gc(full = TRUE)
  idx
}


# Build boundary points of the box described by `limits` (presumably used by
# design() to generate `pseudo_points` for pei() - verify against caller).
#
# Arguments:
#   x      - numeric matrix of points (rows), one column per input dimension.
#   limits - numeric matrix with one row per input dimension; column 1 holds
#            the lower bound and column 2 the upper bound.
#
# Returns a matrix with ncol(x) columns that stacks, in this order:
#   * the 2^N vertices of the box (omitted when N == 1), where vertex v
#     (0-based) sits at the upper bound in dimension d iff bit d of v is set;
#   * N points: for each dimension d, the first row of x attaining the
#     maximum of column d, with its d-th coordinate moved to the upper bound;
#   * N points built the same way from the column minima, moved to the lower
#     bound.
# Uses base R only, so no Python session is needed.
pp <- function(x, limits){
  N <- nrow(limits)
  if ( ncol(x) != N ) stop("'x' and 'limits' have different number of dimensions.", call. = FALSE)
  if ( N != 1 ){
    # cube[v + 1, d + 1] is bit d of v, i.e. the vertices of the unit cube
    cube <- outer(0:(2^N - 1), 0:(N - 1),
                  function(v, d) (bitwAnd(v, bitwShiftL(1L, d)) > 0) * 1)
    # scale each dimension from [0, 1] to [lower, upper]
    vertex_coord <- t((limits[, 2] - limits[, 1]) * t(cube) + limits[, 1])
  } else {
    # in 1-D the two vertices coincide with the max/min points built below
    vertex_coord <- NULL
  }

  # column-wise argmax/argmin; which.max/which.min return the first hit
  idx_max <- vapply(seq_len(N), function(d) which.max(x[, d]), integer(1))
  pp_max <- x[idx_max, , drop = FALSE]
  diag(pp_max) <- limits[, 2]

  idx_min <- vapply(seq_len(N), function(d) which.min(x[, d]), integer(1))
  pp_min <- x[idx_min, , drop = FALSE]
  diag(pp_min) <- limits[, 1]

  rbind(vertex_coord, pp_max, pp_min)
}


