#' Super Learner: Cross-Validation Based Ensemble Learning
#'
#' Super learning with functional programming!
#'
#' The goal of any super learner is to use cross-validation and a
#' set of candidate learners to 1) evaluate how the learners perform
#' on held out data and 2) to use that evaluation to produce a weighted
#' average (for continuous super learner) or to pick a best learner (for
#' discrete super learner) of the specified candidate learners.
#'
#' Super learner and its statistically desirable properties have been written
#' about at length, including at least the following references:
#'
#'   * <https://biostats.bepress.com/ucbbiostat/paper222/>
#'   * <https://www.stat.berkeley.edu/users/laan/Class/Class_subpages/BASS_sec1_3.1.pdf>
#'
#' `nadir::super_learner` is built around a few user-interface design principles
#' that are useful to know when learning what it does and how it works:
#'
#'   * The specification of learners should be _very flexible_, really only
#'   constrained by the fact that candidate learners should be designed
#'   for the same prediction problem but their details can wildly vary
#'   from learner to learner.
#'   * It should be easy to specify a customized or new learner.
#'
#' `nadir::super_learner` at its core accepts `data`,
#' `formulas` (a single formula is fine, or one formula per learner),
#' and a list of `learners`.
#'
#' `learners` are taken to be lists of functions of the following specification:
#'
#'   * a learner must accept a `data` and `formula` argument,
#'   * a learner may accept more arguments, and
#'   * a learner must return a prediction function that accepts `newdata` and
#' produces a vector of prediction values given `newdata`.
#'
#' In essence, a learner is specified to be a function taking (`data`, `formula`, ...)
#' and returning a _closure_ (see <http://adv-r.had.co.nz/Functional-programming.html#closures> for an introduction to closures)
#' which is a function accepting `newdata` returning predictions.
#'
#' Since many candidate learners will have hyperparameters that should be tuned,
#' like depth of trees in random forests, or the `lambda` parameter for `glmnet`,
#' extra arguments can be passed to each learner via the `extra_learner_args`
#' argument. `extra_learner_args` should be a list of lists, one list of
#' extra arguments for each learner. If some of your learners require additional
#' arguments but others do not, put a `NULL` entry in `extra_learner_args`
#' for each learner that needs no extra arguments.
#' See the examples.
#'
#' In order to seamlessly support using features implemented by extensions
#' to the formula syntax (like random intercepts or slopes specified with the
#' `(age | strata)` syntax in `lme4`, or smooth terms like `s(age)` in
#' `mgcv`), we allow the
#' `formulas` argument to either be one fixed formula that
#' `super_learner` will use for all the models, or a vector of formulas,
#' one for each learner specified.
#'
#' Note that in the examples, any mean squared error (MSE) is calculated on
#' the same data used for training, so it is only useful as a crude diagnostic
#' to check that `super_learner()` is working. A more rigorous performance
#' metric for evaluating `super_learner()` is the cross-validated RMSE
#' produced by `cv_super_learner()`.
#'
#' @param data Data to use in training a `super_learner`.
#' @param learners A list of predictor/closure-returning-functions. See Details.
#' @param formulas Either a single regression formula or a vector of regression formulas.
#' @param y_variable Typically `y_variable` can be inferred automatically from the `formulas`, but if needed, the y_variable can be specified explicitly.
#' @param n_folds The number of cross-validation folds to use in constructing the `super_learner`.
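#' @param determine_super_learner_weights A function/method to determine the weights for each of the candidate `learners`. If not specified, it is chosen based on `outcome_type`: `determine_super_learner_weights_nnls` for `'continuous'`, `determine_weights_for_binary_outcomes` for `'binary'`, and `determine_weights_using_neg_log_loss` for `'density'` and `'multiclass'`.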
#' @param ensemble_or_discrete Defaults to `'ensemble'`, but can be set to `'discrete'`. Discrete \code{super_learner()} chooses only one of the candidate learners to have weight 1 in the resulting prediction algorithm,
#'   while \code{ensemble} \code{super_learner()} combines predictions from 1 or more candidate learners, with respective weights adding up to 1.
#' @param cv_schema A function that takes `data` and `n_folds` and returns a list containing `training_data` and `validation_data`, each of which is a list of `n_folds` data frames.
#' @param outcome_type One of 'continuous', 'binary', 'multiclass', or 'density'. \code{outcome_type} is used to infer the correct \code{determine_super_learner_weights} function if it is not explicitly passed.
#' @param extra_learner_args A list of equal length to the `learners` with additional arguments to pass to each of the specified learners.
#' @param cluster_ids (default: `NULL`) If specified, each cluster is assigned entirely to either training or validation (never both) in each cross-validation split.
#' @param strata_ids (default: `NULL`) If specified, strata are balanced across training and validation splits so that each stratum appears in both the training and validation splits.
#' @param weights If specified, (per observation) weights are used to
#'   indicate that risk minimization across models (i.e., the meta-learning
#'   step) should be targeted to higher weight observations.
#' @param use_complete_cases (default: FALSE) If the \code{data} passed have any NA or NaN missing data, restrict the \code{data} to
#'   \code{data[complete.cases(data),]}.
#' @returns An object of class inheriting from \code{nadir_sl_model}. This is an S3 object,
#' with elements including a \code{$predict(newdata)} method, and some information
#' about the fit model including \code{y_variable}, \code{outcome_type}, \code{learner_weights},
#' \code{holdout_predictions} and optionally information about any errors thrown by the
#' learner fitting process.
#'
#' @seealso predict.nadir_sl_model compare_learners
#'
#' @examples
#'
#' learners <- list(
#'      glm = lnr_glm,
#'      rf = lnr_rf,
#'      glmnet = lnr_glmnet,
#'      lmer = lnr_lmer
#'   )
#'
#' # mtcars example ---
#' formulas <- c(
#'   .default = mpg ~ cyl + hp, # first three models use same formula
#'   lmer = mpg ~ (1 | cyl) + hp # lme4 uses different language features
#'   )
#'
#' # fit a super_learner
#' sl_model <- super_learner(
#'   data = mtcars,
#'   formulas = formulas,
#'   learners = learners)
#'
#' # We recommend taking a look at this object to see what's contained inside it:
#' sl_model
#'
#' compare_learners(sl_model)
#'
#' # iris example ---
#' sl_model <- super_learner(
#'   data = iris,
#'   formulas = list(
#'     .default = Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width,
#'     lmer = Sepal.Length ~ (Sepal.Width | Species) + Petal.Length),
#'   learners = learners)
#'
#' # produce super_learner predictions and compare against the individual learners
#' compare_learners(sl_model)
#'
#' @importFrom future.apply future_lapply
#' @importFrom future plan
#' @importFrom tibble tibble
#' @importFrom tidyr pivot_wider
#' @importFrom tidyr unnest
#' @importFrom stats complete.cases
#'
#' @seealso cv_super_learner
#'
#' @export
super_learner <- function(
    data,
    learners,
    formulas,
    y_variable = NULL,
    n_folds = 5,
    determine_super_learner_weights,
    ensemble_or_discrete = 'ensemble',
    cv_schema,
    outcome_type = 'continuous',
    extra_learner_args = NULL,
    cluster_ids = NULL,
    strata_ids = NULL,
    weights = NULL,
    use_complete_cases = FALSE) {

  # Missing data handling ----

  complete_rows <- complete.cases(data)

  # error if NA or NaN appears in the data and complete-case analysis was not
  # requested
  if (!all(complete_rows) && !use_complete_cases) {
    stop(
"nadir::super_learner() does not have any missing data imputation methods builtin.
Users may pass use_complete_cases = TRUE in order to train super_learner()
on the complete cases in the data passed.")
  }

  # if use_complete_cases and there are incomplete cases, filter to only
  # complete cases.
  if (use_complete_cases && !all(complete_rows)) {
    message(
"Note that use_complete_cases = TRUE will filter out any rows from data where
missing data appears, regardless of whether or not the missing data appears in a
column referenced by the formula(s) passed. Users are advised to restrict their
data to only the columns relevant to their formula(s) if passing
use_complete_cases = TRUE.")
    # Per-observation vectors supplied for the original rows must be subset in
    # lockstep with the data. Otherwise they fall out of alignment with `data`
    # (and weights of the original length would be discarded as mismatched).
    n_rows_before_filtering <- nrow(data)
    # drop = FALSE keeps a single-column data.frame from collapsing to a vector
    data <- data[complete_rows, , drop = FALSE]
    if (length(weights) == n_rows_before_filtering) {
      weights <- weights[complete_rows]
    }
    if (length(cluster_ids) == n_rows_before_filtering) {
      cluster_ids <- cluster_ids[complete_rows]
    }
    if (length(strata_ids) == n_rows_before_filtering) {
      strata_ids <- strata_ids[complete_rows]
    }
  }

  # Argument validation ----

  if (!is.list(learners) || length(learners) == 0) {
    stop("the learners passed must be a list of learner functions. see ?learners")
  }

  if (!outcome_type %in% c('continuous', 'density', 'binary', 'multiclass')) {
    stop("The outcome_type passed to nadir::super_learner() needs to be one of 'continuous', 'density', 'binary', or 'multiclass'.")
  }

  # Validated up front so that an invalid value fails immediately, rather than
  # only after every learner has already been trained.
  if (!(length(ensemble_or_discrete) == 1 &&
        ensemble_or_discrete %in% c('ensemble', 'discrete'))) {
    stop("Argument ensemble_or_discrete must be one of 'ensemble' or 'discrete'")
  }

  # make the learners have unique names
  learners <- make_learner_names_unique(learners)

  # throw a warning if the sl_lnr_type of the learners do not match the outcome_type given
  validate_learner_types(learners, outcome_type)

  # Cross-validation schema ----
  #
  # If cv_schema is not specified and neither cluster_ids nor strata_ids are
  # given, use cv_random_schema.
  #
  # If cluster_ids and/or strata_ids are given and cv_schema was not specified,
  # call cv_origami_schema with folds_vfold and pass along the cluster / strata
  # ids. A user-supplied cv_schema is used as-is.
  if (missing(cv_schema)) {
    use_cluster_ids <- !is.null(cluster_ids)
    use_strata_ids <- !is.null(strata_ids)
    if (!use_cluster_ids && !use_strata_ids) {
      cv_schema <- cv_random_schema
    } else {
      cv_schema <- function(data, n_folds) {
        cv_origami_schema_args <- list(data = data,
                                       n_folds = n_folds,
                                       fold_fun = folds_vfold)
        if (use_cluster_ids) {
          cv_origami_schema_args$cluster_ids <- cluster_ids
        }
        if (use_strata_ids) {
          cv_origami_schema_args$strata_ids <- strata_ids
        }
        do.call(what = cv_origami_schema, args = cv_origami_schema_args)
      }
    }
  }

  # Observation weights ----
  #
  # Valid weights are stored as a column so that cv_schema splits them
  # alongside the rows they belong to.
  use_weights <- FALSE
  if (!is.null(weights)) {
    if (!is.numeric(weights) || length(weights) != nrow(data)) {
      warning("weights passed to super_learner must be a numeric vector with one entry per row of data. Weights will not be used.")
    } else if (anyNA(weights)) {
      warning("There cannot be any NA weights passed to super_learner. Weights will not be used.")
    } else {
      data[['.sl_weights']] <- weights
      use_weights <- TRUE
    }
  }

  # set up training and validation data
  #
  # the training and validation data are lists of datasets,
  # where the training data are distinct (n-1)/n subsets of the data and the
  # validation data are the corresponding other 1/n of the data.
  training_and_validation_data <- cv_schema(data, n_folds)
  training_data <- training_and_validation_data$training_data
  validation_data <- training_and_validation_data$validation_data

  # make a tibble/dataframe to hold the trained learners:
  # one row for each combination of a specific fold and a specific model,
  # ordered learner-major (all folds of learner 1, then learner 2, ...)
  trained_learners <- tibble::tibble(
    .sl_fold = rep(seq_len(n_folds), length(learners)),
    learner_name = rep(names(learners), each = n_folds))

  # Extract the Y-variable (its character name)
  #
  # This only supports simple Y variables, nothing like a survival right-hand-side or
  # a transformed right-hand-side.
  y_variable <- extract_y_variable(
    formulas = formulas,
    learner_names = names(learners),
    data_colnames = colnames(data),
    y_variable = y_variable
  )

  # handle vectorized formulas argument
  #
  # if the formulas is just a single formula, then we repeat it
  # in a vector length(learners) times to make it simple to just pass the ith
  # learner formula[[i]].
  formulas <- parse_formulas(formulas = formulas,
                             learner_names = names(learners))

  # handle named extra arguments:
  #   * extra arguments can be passed with a .default option and otherwise named
  #      entries for each learner
  #   * they can be passed as a 1:length(learners) list of extra arguments in order
  #   * they can be passed as a 1:length(learners) list of extra arguments where the names
  #      match 1-1 with the names(learners).
  extra_learner_args <- parse_extra_learner_arguments(
    extra_learner_args = extra_learner_args,
    learner_names = names(learners))

  # Add outcome_type-dependent extra arguments (stored as an attribute on each
  # learner) to extra_learner_args. Arguments the user already supplied take
  # precedence; among the defaults, only the first occurrence of a name is used.
  for (learner_i in seq_along(learners)) {
    outcome_type_dependent_args <-
      attr(learners[[learner_i]], 'outcome_type_dependent_args')[[outcome_type]]
    if (length(outcome_type_dependent_args) == 0) {
      next
    }
    new_arg_names <- names(outcome_type_dependent_args)
    keep <- !(new_arg_names %in% names(extra_learner_args[[learner_i]])) &
      !duplicated(new_arg_names)
    if (any(keep)) {
      extra_learner_args[[learner_i]] <- c(
        extra_learner_args[[learner_i]],
        outcome_type_dependent_args[keep])
    }
  }

  # Used to pick condition objects out of the result lists below.
  #
  # Errors are returned in place of results instead of being collected with
  # `<<-`: assignments made inside future workers are not propagated back to
  # this environment under non-sequential future plans.
  is_error <- function(x) inherits(x, 'error')

  # Stage 1: train each learner on each training fold ----
  #
  # following along with the structure of the trained_learners data frame,
  # for each learner (i) we train on each training fold of the data (j)
  trained_learners[['learned_predictor']] <- unlist(future_lapply(
    seq_along(learners), function(learner_i) {
      future_lapply(seq_len(n_folds), function(fold_j) {
        learner_args <- c(list(data = training_data[[fold_j]],
                               formula = formulas[[learner_i]]),
                          extra_learner_args[[learner_i]])
        if (use_weights) {
          learner_args$weights <- training_data[[fold_j]][['.sl_weights']]
        }
        tryCatch(
          expr = do.call(what = learners[[learner_i]], args = learner_args),
          error = function(e) {
            # Rewrite the call so users see something useful, like
            # lnr_lmer(training_data[[1]], formula = mpg ~ cyl), instead of the
            # internal do.call(learners[[learner_i]], ...).
            e$call <- substitute(
              learner(training_data[[fold_j]],
                      formula = formula_i,
                      extra_learner_args_i),
              list(
                fold_j = fold_j,
                formula_i = formulas[[learner_i]],
                extra_learner_args_i = extra_learner_args[[learner_i]],
                learner = as.name(paste0('lnr_', names(learners)[learner_i]))
              )
            )
            e
          }
        )
      }, future.seed = TRUE)
    }, future.seed = TRUE), recursive = FALSE)

  learner_training_errors <- Filter(is_error, trained_learners[['learned_predictor']])

  # Stage 2: predict from each fold+model combination on the held-out data ----
  trained_learners$predictions_for_testset <- future_lapply(
    seq_len(nrow(trained_learners)), function(i) {
      fold_i <- trained_learners[[i, '.sl_fold']]
      predictor <- trained_learners[[i, 'learned_predictor']]
      # A learner that failed to train holds its error object instead of a
      # predictor. Pass that error through so the learner is flagged below,
      # without also recording a spurious "attempt to apply non-function".
      if (is_error(predictor)) {
        return(predictor)
      }
      tryCatch(expr = {
        # future.apply::future_lapply and lapply can differ here: if the
        # stored predictor isn't a function, it's a list containing one.
        if (is.list(predictor)) {
          predictor <- predictor[[1]]
        }
        predictor(validation_data[[fold_i]])
      },
      # Show users things like
      # trained_learners[['lmer']][[1]](validation_data[[1]]) as the failing
      # call, rather than internal indexing.
      error = function(e) {
        e$call <- substitute(trained_learners[[lnr_name]][[fold_j]](validation_data[[fold_j]]),
                             list(
                               lnr_name = trained_learners[['learner_name']][i],
                               fold_j = fold_i
                             ))
        e
      })
    }, future.seed = TRUE
  )

  failed_training <- vapply(trained_learners$learned_predictor, is_error, logical(1))
  failed_prediction <- vapply(trained_learners$predictions_for_testset, is_error, logical(1))
  learner_prediction_errors <-
    trained_learners$predictions_for_testset[failed_prediction & !failed_training]

  # from here forward, we just need to use the split + model name + predictions on the test-set
  # to regress against the held-out (validation) data to determine the ensemble weights
  second_stage_SL_dataset <- trained_learners[, c('.sl_fold', 'learner_name', 'predictions_for_testset')]

  # pivot it into a wider format, with one column per model, with columnname model_name
  second_stage_SL_dataset <- tidyr::pivot_wider(
    second_stage_SL_dataset,
    names_from = 'learner_name',
    values_from = 'predictions_for_testset')

  # insert the validation Y data in another column next to the predictions
  second_stage_SL_dataset[[y_variable]] <- lapply(seq_len(nrow(second_stage_SL_dataset)), function(i) {
    validation_data[[second_stage_SL_dataset[[i, '.sl_fold']]]][[y_variable]]
  })
  if (use_weights) {
    second_stage_SL_weights <- unlist(lapply(seq_len(nrow(second_stage_SL_dataset)), function(i) {
      validation_data[[second_stage_SL_dataset[[i, '.sl_fold']]]][['.sl_weights']]
    }))
  }

  # determine which learners erred in the process: any column (other than the
  # fold id) containing an error object in any fold
  candidate_columns <- setdiff(colnames(second_stage_SL_dataset), '.sl_fold')
  column_has_error <- vapply(candidate_columns, function(column) {
    any(vapply(second_stage_SL_dataset[[column]], is_error, logical(1)))
  }, logical(1))
  erring_learners <- candidate_columns[column_has_error]
  erring_learner_locations <- which(colnames(second_stage_SL_dataset) %in% erring_learners)

  # with no surviving learner there is nothing to weight or refit; fail with a
  # readable message instead of an obscure downstream error
  if (all(names(learners) %in% erring_learners)) {
    all_errors <- c(learner_training_errors, learner_prediction_errors)
    first_error_message <- if (length(all_errors) > 0) {
      conditionMessage(all_errors[[1]])
    } else {
      "unknown error"
    }
    stop("Every candidate learner failed during cross-validation, so no super learner could be fit. The first error was: ",
         first_error_message, call. = FALSE)
  }

  # drop the erring learners from the meta-learning stage
  if (length(erring_learner_locations) > 0) {
    second_stage_SL_dataset <- second_stage_SL_dataset[, -erring_learner_locations]
  }

  # unnest all of the data (each cell prior to this contained a vector of either
  # predictions or the validation data)
  second_stage_SL_dataset <- tidyr::unnest(second_stage_SL_dataset, cols = colnames(second_stage_SL_dataset))

  # drop the split column so we can simplify the following regression formula
  split_col_index <- which(colnames(second_stage_SL_dataset) == '.sl_fold')

  # if determine_super_learner_weights is left unspecified, we set it based on
  # the outcome_type
  if (missing(determine_super_learner_weights)) {
    determine_super_learner_weights <- switch(
      outcome_type,
      'continuous' = determine_super_learner_weights_nnls,
      'binary' = determine_weights_for_binary_outcomes,
      'density' = determine_weights_using_neg_log_loss,
      'multiclass' = determine_weights_using_neg_log_loss
    )
  }

  # perform the meta-learning step:
  #
  # use determine_super_learner_weights on the second_stage_SL_dataset
  args_for_determining_weights <- list(
    data = second_stage_SL_dataset[, -split_col_index],
    y_variable = y_variable)
  if (use_weights) {
    args_for_determining_weights$obs_weights <- second_stage_SL_weights
  }
  learner_weights <- do.call(what = determine_super_learner_weights, args = args_for_determining_weights)
  names(learner_weights) <- setdiff(names(learners), erring_learners)

  # adjust weights for discrete super-learner (ensemble keeps them as-is);
  # ensemble_or_discrete was validated at the top of the function
  if (ensemble_or_discrete == 'discrete') {
    max_learner_weight <- which(learner_weights == max(learner_weights))
    if (length(max_learner_weight) > 1) {
      warning("Multiple learners were tied for the maximum weight. Since discrete super-learner was specified, the first learner with the maximum weight will be used.")
    }
    # `[] <-` keeps the learner names, which rep(0, ...) would drop
    learner_weights[] <- 0
    learner_weights[max_learner_weight[1]] <- 1
  }

  # we want to drop any erring learners from the super_learner(). if the learner
  # couldn't train on the training dataset, why would they be able to train on
  # the full dataset? also we couldn't assign them weight in the meta-regression
  # step, so they shouldn't be included for that reason.
  erring_learners_indicator <- names(learners) %in% erring_learners

  if (any(erring_learners_indicator)) {
    learners[erring_learners_indicator] <- NULL
    formulas[erring_learners_indicator] <- NULL
    extra_learner_args[erring_learners_indicator] <- NULL
  }

  # fit all of the surviving learners on the entire dataset
  fit_learners <- future_lapply(
    seq_along(learners), function(i) {
      learner_args <- c(list(
        data = data,
        formula = formulas[[i]]),
        extra_learner_args[[i]]
      )
      if (use_weights) {
        learner_args$weights <- data[['.sl_weights']]
      }
      tryCatch(expr = {
        do.call(what = learners[[i]], args = learner_args)
      }, error = function(e) {
        e$call <- substitute(learner(data, formula = formula_i, extra_learner_args_i),
                             list(learner = as.name(paste0('lnr_', names(learners)[[i]])),
                                  formula_i = formulas[[i]],
                                  extra_learner_args_i = extra_learner_args[[i]]))
        e
      })
    }, future.seed = TRUE)

  final_fit_errors <- Filter(is_error, fit_learners)

  # Construct a closure that predicts by combining the learners with the
  # super-learned weights. fit_learners and learner_weights are index-aligned
  # because both follow the order of the surviving learners.
  predict_from_super_learned_model <- function(newdata) {
    # learners with zero weight cannot change the prediction, so they are
    # skipped; this also keeps a zero-weight learner whose final fit failed (or
    # that predicts NA) from breaking the combined prediction
    contributing <- which(learner_weights != 0)
    if (length(contributing) == 0) {
      contributing <- seq_along(fit_learners)
    }
    weighted_predictions <- future_lapply(contributing, function(i) {
      if (is_error(fit_learners[[i]])) {
        stop("The learner '", names(learner_weights)[[i]],
             "' could not be trained on the full data, so it cannot be used for prediction: ",
             conditionMessage(fit_learners[[i]]), call. = FALSE)
      }
      fit_learners[[i]](newdata) * learner_weights[[i]]
    }, future.seed = TRUE)
    # aggregate across the weighted model predictions
    Reduce(`+`, weighted_predictions)
  }

  # construct output
  output <- list(
    predict = predict_from_super_learned_model,
    y_variable = y_variable,
    outcome_type = outcome_type,
    learner_weights = learner_weights,
    holdout_predictions = second_stage_SL_dataset
  )
  # tag the verbose output as such for use in compare_learners() and similar
  class(output) <- "nadir_sl_model"

  # if there were errors, report them to the user inside the verbose output
  if (length(learner_training_errors) > 0) {
    output$errors_from_training_cv_stage1 <- learner_training_errors
  }
  if (length(learner_prediction_errors) > 0) {
    output$errors_from_predicting_cv_stage2 <- learner_prediction_errors
  }
  if (length(final_fit_errors) > 0) {
    output$errors_from_training_on_entire_data <- final_fit_errors
  }
  if (any(erring_learners_indicator)) {
    output$erring_learners <- erring_learners
  }

  output
}


#' Predict from a \code{nadir::super_learner()} model
#'
#' @param object An object of class inheriting from \code{nadir_sl_model}.
#' @param newdata A tabular data structure (data.frame or matrix) of
#' predictor variables.
#' @param ... Ellipses, solely provided so that the \code{predict.nadir_sl_model} method
#' is compatible with the generic \code{predict}, which takes ellipses as an argument.
#'
#' @export
#' @returns a numeric vector of predicted values
#'
#' @examples
#' sl_fit <- super_learner(mtcars, mpg ~ hp,
#'   learners = list(lnr_lm, lnr_rf, lnr_earth))
#' predict(sl_fit, newdata = mtcars)
predict.nadir_sl_model <- function(object, newdata, ...) {
  # A fitted nadir_sl_model carries its own prediction closure; delegate to it.
  # Any extra arguments in `...` are accepted for S3 compatibility but unused.
  sl_predictor <- object$predict
  sl_predictor(newdata)
}

