#' @title Generate binary classification predictions via ROCR ROC curves.
#'
#' @description
#' Data is generated by calling \code{\link{asROCRPrediction}},
#' and then ROCR's \code{\link[ROCR]{performance}}.
#'
#' See these methods in ROCR for further info.
#'
#' @family roc
#' @family predict
#' @family generate_plot_data
#'
#' @template arg_plotroc_obj
#' @param meas1 [\code{character(1)}]\cr
#'   Measure on y-axis; it is passed as \code{measure} to ROCR's
#'   \code{\link[ROCR]{performance}}.
#'   Note that this is a measure name from *ROCR* and not from mlr!
#'   Default is \dQuote{tpr}.
#' @param meas2 [\code{character(1)}]\cr
#'   Measure on x-axis; it is passed as \code{x.measure} to ROCR's
#'   \code{\link[ROCR]{performance}}.
#'   Note that this is a measure name from *ROCR* and not from mlr!
#'   Default is \dQuote{fpr}.
#' @param avg [\code{character(1)}]\cr
#'   If \code{obj} is of class \code{ResampleResult} or \code{BenchmarkResult}, how are
#'   the predictions to be combined (by learner)?
#'   If \code{obj} is not one of these classes, this argument is ignored.
#'   Possibilities are \dQuote{threshold}, \dQuote{horizontal}, \dQuote{vertical}, and \dQuote{none}.
#'   Default is \dQuote{threshold}.
#' @param perf.args [named \code{list}]\cr
#'   Further arguments passed to ROCR's \code{\link[ROCR]{performance}}.
#'   Usually not needed and \code{meas1} and \code{meas2} are set internally.
#'   Default is an empty list.
#' @param task.id [\code{character(1)}]\cr
#'   Selected task in \code{\link{BenchmarkResult}} to do plots for, ignored otherwise.
#'   Default is first task.
#'
#' @return A \code{ROCRCurvesData} object, a \code{list} with elements giving the data output from
#'   \code{\link[ROCR]{performance}} and the input arguments.
#' @export
generateROCRCurvesData = function(obj, meas1 = "tpr", meas2 = "fpr", avg = "threshold",
                              perf.args = list(), task.id = NULL) {

  ## lets not check the value-names from ROCR here. they might be changed behind our back later...
  assertString(meas1)
  assertString(meas2)
  ## a single check covers both "is a scalar string" and "is a known averaging mode"
  assertChoice(avg, c("none", "threshold", "horizontal", "vertical"))
  assertList(perf.args, names = "unique")
  UseMethod("generateROCRCurvesData")
}

#' @export
generateROCRCurvesData.Prediction = function(obj, meas1 = "tpr", meas2 = "fpr", avg = "none",
                                         perf.args = list(), task.id = NULL) {

  ## put the prediction under the name "prediction" and delegate to the list
  ## method, forwarding every argument unchanged
  preds = namedList(names = "prediction", init = obj)
  generateROCRCurvesData.list(obj = preds, meas1 = meas1, meas2 = meas2, avg = avg,
                              perf.args = perf.args, task.id = task.id)
}

#' @export
generateROCRCurvesData.ResampleResult = function(obj, meas1 = "tpr", meas2 = "fpr", avg = "threshold",
                                             perf.args = list(), task.id = NULL) {

  ## put the resample result under its learner id and delegate to the list
  ## method, forwarding every argument unchanged
  results = namedList(names = obj$learner.id, init = obj)
  generateROCRCurvesData.list(obj = results, meas1 = meas1, meas2 = meas2, avg = avg,
                              perf.args = perf.args, task.id = task.id)
}

#' @export
generateROCRCurvesData.BenchmarkResult = function(obj, meas1 = "tpr", meas2 = "fpr", avg = "threshold",
                                              perf.args = list(), task.id = NULL) {

  ## resolve the task: validate a user-supplied id, otherwise take the first one
  known.ids = getBMRTaskIds(obj)
  if (!is.null(task.id)) {
    assertChoice(task.id, known.ids)
  } else {
    task.id = known.ids[1L]
  }

  ## predictions of the selected task, then delegate to the list method
  preds = getBMRPredictions(obj, task.ids = task.id, as.df = FALSE)
  generateROCRCurvesData.list(obj = preds[[1L]], meas1 = meas1, meas2 = meas2, avg = avg,
                              perf.args = perf.args, task.id = task.id)
}

#' @export
generateROCRCurvesData.list = function(obj, meas1 = "tpr", meas2 = "fpr", avg = "threshold",
                                   perf.args = list(), task.id = NULL) {

  ## Workhorse that the other generateROCRCurvesData methods delegate to.
  ##   obj:  non-empty list of Prediction or ResampleResult objects.
  ##   meas1 / meas2: passed to ROCR::performance as measure / x.measure.
  ##   avg:  how several curves of one list element are combined
  ##         ("threshold", "vertical", "horizontal") or kept apart ("none").
  ##   perf.args: further arguments for ROCR::performance.
  ##   task.id: accepted for signature compatibility, not used in this method.
  ## Returns a list of class ROCRCurvesData with elements data, meas1, meas2,
  ## avg and perf.args.

  assertList(obj, c("Prediction", "ResampleResult"), min.len = 1L)
  ## unwrap ResampleResult to Prediction and set default names
  if (inherits(obj[[1L]], "ResampleResult")) {
    ## NOTE(review): the ResampleResult method reads learner.id from the result
    ## itself, here it is read from the nested pred element - confirm that
    ## pred carries a learner.id as well.
    if (is.null(names(obj)))
      names(obj) = extractSubList(obj, c("pred", "learner.id"))
    obj = extractSubList(obj, "pred", simplify = FALSE)
  }
  assertList(obj, names = "unique")

  perf.args = insert(perf.args, list(measure = meas1, x.measure = meas2))
  plt_data = lapply(obj, function(obj_i) {
    perf.args$prediction.obj = asROCRPrediction(obj_i)
    perf = do.call(ROCR::performance, perf.args)

    curve_ids = seq_along(perf@x.values)
    ## size of the averaging grid: length of the longest curve before filtering
    max_length = max(vapply(perf@x.values, length, integer(1L)))

    ## averaging needs more than one curve; this only changes the local avg,
    ## the value stored in the returned object is the caller's
    is_resample = length(perf@x.values) > 1L
    if (!is_resample)
      avg = "none"

    is_alpha = length(perf@alpha.values) > 0L

    ## previously this case fell through to rowMeans() on an empty object and
    ## died with an unrelated error message
    if (avg == "threshold" && !is_alpha)
      stop(sprintf(paste0("avg = 'threshold' needs threshold values, but ROCR returned none ",
        "for meas1 = '%s' and meas2 = '%s'. Use avg = 'vertical', 'horizontal' or 'none'."),
        meas1, meas2), call. = FALSE)

    if (is_alpha) {
      ## replace infinite thresholds by a finite value just beyond the largest
      ## finite one (largest value plus the mean step between neighbours)
      perf@alpha.values = lapply(perf@alpha.values, function(alpha) {
        inf_idx = which(is.infinite(alpha))
        if (length(inf_idx) > 0L) {
          finite_alpha = alpha[-inf_idx]
          alpha[inf_idx] = max(finite_alpha) + mean(abs(diff(finite_alpha)))
        }
        alpha
      })

      ## common, decreasing threshold grid spanning all curves
      if (avg == "threshold") {
        all_alpha = unlist(perf@alpha.values)
        alpha = rev(seq(min(all_alpha), max(all_alpha), length.out = max_length))
      }
    }

    ## keep only points where both measures are finite. This is done for every
    ## curve before anything else, so that all curves are cleaned regardless of
    ## avg and the averaging grids below are built from cleaned data only.
    for (i in curve_ids) {
      keep = intersect(which(is.finite(perf@x.values[[i]])),
                       which(is.finite(perf@y.values[[i]])))
      perf@x.values[[i]] = perf@x.values[[i]][keep]
      perf@y.values[[i]] = perf@y.values[[i]][keep]
      if (is_alpha)
        perf@alpha.values[[i]] = perf@alpha.values[[i]][keep]
    }

    ## linear interpolation of every curve on a shared grid; rule = 2 extends
    ## the boundary values to grid points outside a curve's range
    interpolate = function(from, to, grid) {
      lapply(curve_ids, function(i) approxfun(from[[i]], to[[i]], rule = 2)(grid))
    }

    if (avg == "threshold") {
      x = interpolate(perf@alpha.values, perf@x.values, alpha)
      y = interpolate(perf@alpha.values, perf@y.values, alpha)
      out = data.frame(x = rowMeans(do.call(cbind, x)), y = rowMeans(do.call(cbind, y)))
    } else if (avg == "vertical") {
      ## average y over a common x grid
      x = unlist(perf@x.values)
      x = seq(min(x), max(x), length.out = max_length)
      y = interpolate(perf@x.values, perf@y.values, x)
      out = data.frame(x, y = rowMeans(do.call(cbind, y)))
    } else if (avg == "horizontal") {
      ## average x over a common y grid
      y = unlist(perf@y.values)
      y = seq(min(y), max(y), length.out = max_length)
      x = interpolate(perf@y.values, perf@x.values, y)
      out = data.frame(x = rowMeans(do.call(cbind, x)), y)
    } else { ## no averaging needed
      out = data.frame(x = unlist(perf@x.values), y = unlist(perf@y.values))
    }
    colnames(out)[1:2] = c(perf@x.name, perf@y.name)

    if (is_alpha && avg == "none")
      out[, perf@alpha.name] = unlist(perf@alpha.values)

    if (is_alpha && avg == "threshold")
      out[, perf@alpha.name] = alpha

    ## mark which curve every row belongs to when curves are kept apart
    if (is_resample && avg == "none")
      out$iter = rep(curve_ids, times = vapply(perf@x.values, length, integer(1L)))

    out
  })

  ## stack the per-element data; the list names end up in column "learner"
  plt_data = plyr::ldply(plt_data, .id = "learner")

  out = list("data" = plt_data,
             "meas1" = meas1,
             "meas2" = meas2,
             "avg" = avg,
             "perf.args" = perf.args)
  class(out) = append(class(out), "ROCRCurvesData")
  return(out)
}
#' @export
print.ROCRCurvesData = function(x, ...) {
  ## Print the first rows (at most five) of the curve data.
  ## Indexing with 1:5 would pad data that has fewer rows with NA rows.
  n.show = min(5L, nrow(x$data))
  print(x$data[seq_len(n.show), , drop = FALSE], ...)
  ## print methods return their argument invisibly
  invisible(x)
}
#' @title Plots results from generateROCRCurvesData using ggplot2.
#'
#' @description
#' Visualize how binary classification performs across different measures.
#'
#' @family roc
#' @family predict
#' @family plot
#'
#' @param obj [\code{ROCRCurvesData}]\cr
#'   Result of \code{\link{generateROCRCurvesData}}.
#' @param diagonal [\code{logical(1)}]\cr
#'   Whether to plot a dashed diagonal line.
#'   The line is only drawn if neither plotted measure exceeds 1.
#'   Default is \code{FALSE}.
#' @param xlab [\code{character(1)}]\cr
#'   Label for x-axis.
#'   Default is the name of the column of \code{obj$data} that holds the x-values.
#' @param ylab [\code{character(1)}]\cr
#'   Label for y-axis.
#'   Default is the name of the column of \code{obj$data} that holds the y-values.
#' @param title [\code{character(1)}]\cr
#'   Title for plot.
#'   Default is an empty string.
#' @template ret_gg2
#' @export
plotROCRCurves = function(obj, diagonal = FALSE, xlab, ylab, title = "") {

  assertClass(obj, c("ROCRCurvesData", "list"))
  ## columns 2 and 3 are the plotted x- and y-values; keep their names for the
  ## default axis labels, then rename them so aes_string can refer to them
  x_name = colnames(obj$data)[2]
  y_name = colnames(obj$data)[3]
  colnames(obj$data)[2:3] = c(obj$meas1, obj$meas2)

  single_learner = length(unique(obj$data$learner)) == 1L
  unaveraged = obj$avg == "none" && "iter" %in% colnames(obj$data)

  ## Unaveraged curves are handled first: every iteration needs its own line,
  ## also for a single learner, otherwise geom_line joins the points of all
  ## iterations into one line.
  if (unaveraged) {
    obj$data$int = interaction(obj$data$iter, obj$data$learner)
    if (single_learner) {
      mapping = ggplot2::aes_string(obj$meas1, obj$meas2, group = "int")
    } else {
      mapping = ggplot2::aes_string(obj$meas1, obj$meas2,
                                    group = "int", color = "learner")
    }
  } else if (single_learner) {
    mapping = ggplot2::aes_string(obj$meas1, obj$meas2)
  } else {
    mapping = ggplot2::aes_string(obj$meas1, obj$meas2,
                                  group = "learner", color = "learner")
  }
  plt = ggplot2::ggplot(obj$data, mapping) + ggplot2::geom_line()

  ## the diagonal is only meaningful when both measures stay within [0, 1]
  if (diagonal && all(vapply(obj$data[, 2:3], max, numeric(1L)) <= 1))
    plt = plt + ggplot2::geom_abline(intercept = 0, slope = 1,
                                     linetype = "dashed", alpha = .5)

  if (missing(xlab))
    xlab = x_name
  if (missing(ylab))
    ylab = y_name

  plt + ggplot2::labs(x = xlab, y = ylab, title = title)
}
#' @title Plots results from generateROCRCurvesData using ggvis.
#'
#' @description
#' Visualize how binary classification performs across different measures.
#'
#' @family roc
#' @family predict
#' @family plot
#'
#' @param obj [\code{ROCRCurvesData}]\cr
#'   Result of \code{\link{generateROCRCurvesData}}.
#' @param diagonal [\code{logical(1)}]\cr
#'   Whether to plot a dashed diagonal line.
#'   Default is \code{FALSE}.
#' @param cutoffs [\code{logical(1)}]\cr
#'   Whether to plot tooltips displaying threshold values using Shiny.
#'   The plot will be opened in the default browser.
#'   Default is \code{FALSE}.
#' @template ret_ggv
#' @export
plotROCRCurvesGGVIS = function(obj, diagonal = FALSE, cutoffs = FALSE) {

  assertClass(obj, c("ROCRCurvesData", "list"))
  ## columns 2 and 3 are the plotted x- and y-values; keep their names for the
  ## axis titles, then rename them so the props can refer to them by name
  x_name = colnames(obj$data)[2]
  y_name = colnames(obj$data)[3]
  colnames(obj$data)[2:3] = c(obj$meas1, obj$meas2)

  ## row id, looked up again by the tooltip function below
  obj$data$id = seq_len(nrow(obj$data))

  single_learner = length(unique(obj$data$learner)) == 1L
  unaveraged = obj$avg == "none" && "iter" %in% colnames(obj$data)

  ## Unaveraged curves are handled first: every iteration needs its own path,
  ## also for a single learner, otherwise all iterations end up in one line.
  if (unaveraged) {
    plt = ggvis::ggvis(obj$data, ggvis::prop("x", as.name(obj$meas1)),
                       ggvis::prop("y", as.name(obj$meas2)),
                       ggvis::prop("stroke", as.name("learner")))
    plt = ggvis::group_by_.ggvis(plt, as.name("iter"), as.name("learner"))
    plt = ggvis::layer_paths(plt)
  } else if (single_learner) {
    plt = ggvis::ggvis(obj$data, ggvis::prop("x", as.name(obj$meas1)),
                       ggvis::prop("y", as.name(obj$meas2)))
    plt = ggvis::layer_lines(plt)
  } else {
    plt = ggvis::ggvis(obj$data, ggvis::prop("x", as.name(obj$meas1)),
                       ggvis::prop("y", as.name(obj$meas2)),
                       ggvis::prop("stroke", as.name("learner")),
                       ggvis::prop("key", as.name("id"), scale = FALSE))
    plt = ggvis::layer_lines(plt)
  }

  ## the diagonal is only drawn when both measures stay within [0, 1]
  if (diagonal && all(vapply(obj$data[, c(obj$meas1, obj$meas2)], max, numeric(1L)) <= 1)) {
    abline_data = data.frame(dx = seq(0, 1, length.out = nrow(obj$data)),
                             dy = seq(0, 1, length.out = nrow(obj$data)),
                             learner = obj$data$learner)
    plt = ggvis::layer_paths(plt, ggvis::prop("x", as.name("dx")),
                             ggvis::prop("y", as.name("dy")),
                             ggvis::prop("stroke", "gray", scale = FALSE),
                             data = abline_data)
  }

  ## NOTE(review): the threshold column is looked up under the name "alpha",
  ## while generateROCRCurvesData names it after ROCR's alpha.name - confirm
  ## that the two agree, otherwise this branch is never entered.
  ## NOTE(review): only the averaged multi-learner plot above sets the "key"
  ## prop; verify that x$id is available to the tooltip in the other cases.
  if (cutoffs && "alpha" %in% colnames(obj$data) && obj$avg %in% c("none", "threshold")) {
    gen_tooltip = function(x) {
      if (is.null(x)) {
        NULL
      } else {
        row = obj$data[obj$data$id == x$id, ]
        paste0(prettyNum(row$alpha, scientific = FALSE))
      }
    }
    plt = ggvis::layer_points(plt, ggvis::prop("opacity", .5, scale = FALSE),
                              ggvis::prop("fill", as.name("learner")))
    plt = ggvis::add_tooltip(plt, gen_tooltip, on = "hover")
  }

  plt = ggvis::add_axis(plt, "x", title = x_name)
  plt = ggvis::add_axis(plt, "y", title = y_name)
  plt
}
