#' Prints "shapviz" Object
#'
#' Prints a short summary of a "shapviz" object: the dimensions of the SHAP
#' matrix and of the feature data, the baseline value, and the first rows of
#' both the SHAP values and the feature values.
#'
#' @param x An object of class "shapviz".
#' @param n Maximum number of rows of SHAP values and feature values to show.
#' @param ... Further arguments passed from other methods.
#' @return Invisibly, the input is returned.
#' @export
#' @examples
#' S <- matrix(c(1, -1, -1, 1), ncol = 2, dimnames = list(NULL, c("x", "y")))
#' X <- data.frame(x = c("a", "b"), y = c(100, 10))
#' shapviz(S, X, baseline = 4)
#' @seealso \code{\link{shapviz}}.
print.shapviz <- function(x, n = 2L, ...) {
  S <- get_shap_values(x)
  X <- get_feature_values(x)
  # Never announce more rows than are actually available
  n <- min(n, nrow(S))
  cat(
    "'shapviz' object representing \n  - SHAP matrix of dimension",
    nrow(S), "x", ncol(S),
    # Report the dimensions of the feature data itself, not those of S
    "\n  - feature data.frame of dimension", nrow(X), "x", ncol(X),
    "\n  - baseline value of", get_baseline(x)
  )
  cat("\n\n")
  cat("SHAP values of first", n, "observations:\n")
  print(utils::head(S, n))
  cat("\n Corresponding feature values:\n")
  print(utils::head(X, n))
  cat("\n")
  # Returning the input invisibly keeps the method usable in pipelines
  invisible(x)
}

#' Dimensions of "shapviz" Object
#'
#' @param x An object of class "shapviz".
#' @return A numeric vector of length two providing the number of rows and columns
#' of the SHAP matrix (or the feature dataset) stored in \code{x}.
#' @export
#' @examples
#' S <- matrix(c(1, -1, -1, 1), ncol = 2, dimnames = list(NULL, c("x", "y")))
#' X <- data.frame(x = c("a", "b"), y = c(100, 10))
#' dim(shapviz(S, X))
#' @seealso \code{\link{shapviz}}.
dim.shapviz <- function(x) {
  # The dimensions are taken from the SHAP matrix stored in the object
  shap_matrix <- get_shap_values(x)
  dim(shap_matrix)
}

#' Check for shapviz
#'
#' Tests whether an object inherits from class "shapviz".
#'
#' @param object An R object.
#' @return Returns \code{TRUE} if \code{object} has "\code{shapviz}" among its classes, and \code{FALSE} otherwise.
#' @export
#' @examples
#' S <- matrix(c(1, -1, -1, 1), ncol = 2, dimnames = list(NULL, c("x", "y")))
#' X <- data.frame(x = c("a", "b"), y = c(100, 10))
#' shp <- shapviz(S, X)
#' is.shapviz(shp)
#' is.shapviz("a")
is.shapviz <- function(object) {
  inherits(object, what = "shapviz", which = FALSE)
}


#' Extractor Functions
#'
#' Functions to extract SHAP values, feature values, or the baseline from a "shapviz" object.
#'
#' Each extractor is an S3 generic with a method for class "shapviz". The default
#' methods raise an error.
#'
#' @name extractors
#' @param object Object to extract something from.
#' @param ... Currently unused.
#' @return `get_shap_values()` returns the matrix of SHAP values, `get_feature_values()` the \code{data.frame} of feature values, and `get_baseline()` the numeric baseline value of the input.
NULL

#' @rdname extractors
#' @export
get_shap_values <- function(object, ...){
  # S3 generic: dispatches to the method matching the class of `object`
  UseMethod("get_shap_values")
}

#' @rdname extractors
#' @export
#' @examples
#' S <- matrix(c(1, -1, -1, 1), ncol = 2, dimnames = list(NULL, c("x", "y")))
#' X <- data.frame(x = c("a", "b"), y = c(100, 10))
#' shp <- shapviz(S, X, baseline = 4)
#' get_shap_values(shp)
get_shap_values.shapviz <- function(object, ...) {
  # The SHAP values are stored in the element named "S"
  object[["S"]]
}

#' @rdname extractors
#' @export
get_shap_values.default <- function(object, ...) {
  # Reached only when no class-specific method exists for `object`
  stop("No default method available.")
}

#' @rdname extractors
#' @export
get_feature_values <- function(object, ...){
  # S3 generic: dispatches to the method matching the class of `object`
  UseMethod("get_feature_values")
}

#' @rdname extractors
#' @export
get_feature_values.shapviz <- function(object, ...) {
  # The feature values are stored in the element named "X"
  object[["X"]]
}

#' @rdname extractors
#' @export
get_feature_values.default <- function(object, ...) {
  # Reached only when no class-specific method exists for `object`
  stop("No default method available.")
}

#' @rdname extractors
#' @export
get_baseline <- function(object, ...){
  # S3 generic: dispatches to the method matching the class of `object`
  UseMethod("get_baseline")
}

#' @rdname extractors
#' @export
get_baseline.shapviz <- function(object, ...) {
  # The baseline is stored in the element named "baseline"
  object[["baseline"]]
}

#' @rdname extractors
#' @export
get_baseline.default <- function(object, ...) {
  # Reached only when no class-specific method exists for `object`
  stop("No default method available.")
}

#' Collapse SHAP values
#'
#' Collapses groups of columns in the SHAP matrix by rowwise summation.
#' A typical application is when the matrix of SHAP values is generated by a model with
#' one or multiple one-hot encoded variables and the explanations should be done using
#' the original variables.
#'
#' @param object An object containing SHAP values.
#' @param collapse A named list of character vectors. Each vector specifies a group of
#' column names in the SHAP matrix that should be collapsed to a single column by summation.
#' The name of the new column equals the name of the vector in \code{collapse}.
#' @param ... Currently unused.
#' @return A matrix with collapsed columns. Columns not mentioned in \code{collapse}
#' come first, followed by one new column per element of \code{collapse}.
#' @export
#' @seealso \code{\link{shapviz}}.
#' @examples
#' S <- cbind(
#'   x = c(0.1, 0.1, 0.1),
#'   `age low` = c(0.2, -0.1, 0.1),
#'   `age mid` = c(0, 0.2, -0.2),
#'   `age high` = c(1, -1, 0)
#' )
#' collapse <- list(age = c("age low", "age mid", "age high"))
#' collapse_shap(S, collapse)
collapse_shap <- function(object, ...) {
  # S3 generic: dispatches to the method matching the class of `object`
  UseMethod("collapse_shap")
}

#' @describeIn collapse_shap Default method. Always raises an error because
#'   collapsing is implemented for matrices only.
#' @export
collapse_shap.default <- function(object, ...) {
  # The message is split across arguments (stop() pastes them together), so the
  # source line break and indentation do not end up in the error text.
  stop(
    "No default method available. collapse_shap() is available for objects ",
    "of class 'matrix' only."
  )
}

#' @describeIn collapse_shap Collapse method for object of class "matrix".
#'   Every element of \code{collapse} must have a non-empty name. Missing values
#'   are ignored when summing the columns of a group.
#' @export
collapse_shap.matrix <- function(object, collapse = NULL, ...) {
  # Nothing to do without collapse groups (NULL has length 0) or without columns
  if (length(collapse) == 0L || ncol(object) == 0L) {
    return(object)
  }
  group_names <- names(collapse)
  stopifnot(
    "'object' must have column names" = !is.null(colnames(object)),
    # A partially named list would silently create a column with an empty name,
    # so every element needs a non-missing, non-empty name.
    "'collapse' must be a named list" =
      is.list(collapse) && !is.null(group_names) &&
        !anyNA(group_names) && all(nzchar(group_names)),
    "'collapse' can't have duplicated names" = !anyDuplicated(group_names)
  )
  # All columns that will be absorbed into a group, and those left untouched
  u <- unlist(collapse, use.names = FALSE, recursive = FALSE)
  keep <- setdiff(colnames(object), u)
  stopifnot(
    "'collapse' cannot have overlapping vectors." = !anyDuplicated(u),
    "Values of 'collapse' should be in colnames(object)" = all(u %in% colnames(object)),
    "Names of 'collapse' must be different from untouched column names" =
      !any(group_names %in% keep)
  )
  # One new column per group: rowwise sum over the group's columns
  add <- do.call(
    cbind,
    lapply(collapse, function(z) rowSums(object[, z, drop = FALSE], na.rm = TRUE))
  )
  # Untouched columns first, collapsed columns appended in the order of 'collapse'
  cbind(object[, keep, drop = FALSE], add)
}

