#' Evaluation Metrics for Changepoint Detection
#'
#' @description
#' Functions for evaluating changepoint detection results against ground truth.
#' Includes metrics for localization accuracy, segmentation quality, and
#' detection performance.
#'
#' @name evaluation
NULL

#' Evaluate Changepoint Detection Results
#'
#' @description
#' Comprehensive evaluation of detection results against known ground truth.
#' Computes localization metrics (Hausdorff distance, mean absolute error,
#' RMSE), tolerance-based detection metrics (precision, recall, F1 and the
#' matched count) and segmentation metrics (Rand index, adjusted Rand index,
#' covering).
#'
#' @param result A regime_result object or vector of changepoint locations
#' @param true_changepoints Vector of true changepoint locations
#' @param n Total number of observations. Required if `result` is a vector.
#'   For a regime_result the object's own `n` element takes precedence and
#'   this argument is only used as a fallback when that element is absent.
#' @param tolerance Tolerance window for matching changepoints (used by
#'   precision, recall, F1 and the matched count)
#'
#' @return A list of class "regime_evaluation" with elements `hausdorff`,
#'   `mean_abs_error`, `rmse`, `precision`, `recall`, `f1`, `rand_index`,
#'   `adjusted_rand`, `covering`, `n_true`, `n_detected` and `n_matched`.
#'
#' @examples
#' true_cp <- c(50, 100)
#' data <- c(rnorm(50), rnorm(50, mean = 2), rnorm(50))
#'
#' result <- detect_regimes(data)
#' evaluation <- evaluate(result, true_cp)
#' print(evaluation)
#'
#' @export
evaluate <- function(result, true_changepoints, n = NULL, tolerance = 5) {
  if (inherits(result, "regime_result")) {
    est_cp <- result$changepoints
    # Prefer the length recorded on the result, but never replace a
    # caller-supplied n with NULL when the object does not carry one.
    if (!is.null(result$n)) {
      n <- result$n
    } else if (is.null(n)) {
      cli::cli_abort("result has no n element; n must be provided explicitly")
    }
  } else {
    est_cp <- result
    if (is.null(n)) {
      cli::cli_abort("n must be provided when result is a vector")
    }
  }

  # Both sets are compared as sorted integer positions.
  true_changepoints <- sort(as.integer(true_changepoints))
  est_cp <- sort(as.integer(est_cp))

  metrics <- list(
    hausdorff = hausdorff_distance(est_cp, true_changepoints),
    mean_abs_error = mean_absolute_error(est_cp, true_changepoints),
    rmse = rmse_changepoints(est_cp, true_changepoints),
    precision = precision_score(est_cp, true_changepoints, tolerance),
    recall = recall_score(est_cp, true_changepoints, tolerance),
    f1 = f1_score(est_cp, true_changepoints, tolerance),
    rand_index = rand_index(est_cp, true_changepoints, n),
    adjusted_rand = adjusted_rand_index(est_cp, true_changepoints, n),
    covering = covering_metric(est_cp, true_changepoints, n),
    n_true = length(true_changepoints),
    n_detected = length(est_cp),
    n_matched = count_matched(est_cp, true_changepoints, tolerance)
  )

  class(metrics) <- c("regime_evaluation", "list")
  metrics
}

#' @export
print.regime_evaluation <- function(x, ...) {
  # Each section is assembled as a character vector whose elements carry
  # their own trailing newline; everything is emitted with one cat() call.
  banner <- c(
    "\nChangepoint Detection Evaluation\n",
    "=================================\n\n"
  )

  detection <- c(
    "Detection Performance:\n",
    sprintf("  True changepoints: %d\n", x$n_true),
    sprintf("  Detected: %d\n", x$n_detected),
    sprintf("  Matched: %d\n", x$n_matched),
    sprintf("  Precision: %.3f\n", x$precision),
    sprintf("  Recall: %.3f\n", x$recall),
    sprintf("  F1 Score: %.3f\n", x$f1)
  )

  localization <- c(
    "\nLocalization Accuracy:\n",
    sprintf("  Hausdorff Distance: %.2f\n", x$hausdorff),
    sprintf("  Mean Absolute Error: %.2f\n", x$mean_abs_error),
    sprintf("  RMSE: %.2f\n", x$rmse)
  )

  segmentation <- c(
    "\nSegmentation Quality:\n",
    sprintf("  Rand Index: %.3f\n", x$rand_index),
    sprintf("  Adjusted Rand Index: %.3f\n", x$adjusted_rand),
    sprintf("  Covering Metric: %.3f\n", x$covering)
  )

  cat(banner, detection, localization, segmentation, "\n", sep = "")

  # Return the object invisibly so the method composes in pipelines.
  invisible(x)
}

#' Hausdorff Distance Between Changepoint Sets
#'
#' @description
#' Symmetric Hausdorff distance between the detected and the true changepoint
#' sets: the largest of all nearest-neighbour distances, taken in both
#' directions (detected to true and true to detected).
#'
#' @param detected Vector of detected changepoint locations
#' @param true_cp Vector of true changepoint locations
#'
#' @return Hausdorff distance (non-negative number). `0` when both sets are
#'   empty and `Inf` when exactly one of them is empty.
#'
#' @examples
#' hausdorff_distance(c(48, 102), c(50, 100))
#'
#' @export
hausdorff_distance <- function(detected, true_cp) {
  n_detected <- length(detected)
  n_true <- length(true_cp)

  if (n_detected == 0 || n_true == 0) {
    # Two empty sets agree perfectly; one empty set is infinitely far away.
    return(if (n_detected == n_true) 0 else Inf)
  }

  # gaps[i, j] is the distance between detected[i] and true_cp[j].
  gaps <- abs(outer(detected, true_cp, "-"))

  worst_detected <- max(apply(gaps, 1, min))
  worst_true <- max(apply(gaps, 2, min))

  max(worst_detected, worst_true)
}

#' Mean Absolute Error for Changepoints
#'
#' @description
#' Pairs every detected changepoint with its nearest true changepoint and
#' averages the absolute offsets. True changepoints that are not the nearest
#' neighbour of any detection do not contribute to the result.
#'
#' @param detected Vector of detected changepoint locations
#' @param true_cp Vector of true changepoint locations
#'
#' @return Mean absolute error, or `NA` when either vector is empty
#'
#' @export
mean_absolute_error <- function(detected, true_cp) {
  nothing_to_compare <- length(detected) == 0 || length(true_cp) == 0
  if (nothing_to_compare) {
    return(NA)
  }

  nearest_truth <- match_to_nearest(detected, true_cp)
  abs_offsets <- abs(detected - nearest_truth)

  mean(abs_offsets)
}

#' RMSE for Changepoints
#'
#' @description
#' Pairs every detected changepoint with its nearest true changepoint and
#' returns the square root of the mean squared offset.
#'
#' @param detected Vector of detected changepoint locations
#' @param true_cp Vector of true changepoint locations
#'
#' @return Root mean squared error, or `NA` when either vector is empty
#'
#' @export
rmse_changepoints <- function(detected, true_cp) {
  nothing_to_compare <- length(detected) == 0 || length(true_cp) == 0
  if (nothing_to_compare) {
    return(NA)
  }

  nearest_truth <- match_to_nearest(detected, true_cp)
  squared_offsets <- (detected - nearest_truth)^2

  sqrt(mean(squared_offsets))
}

#' Pair each detected changepoint with its nearest true changepoint
#'
#' For every element of `detected`, picks the element of `true_cp` at the
#' smallest absolute distance; ties go to the first such element, following
#' `which.min()`. The lookup is done through integer indices so the result is
#' always a plain subset of `true_cp`, including an empty one when `detected`
#' is empty.
#'
#' @param detected Vector of detected changepoint locations
#' @param true_cp Non-empty vector of true changepoint locations
#'
#' @return Vector with one nearest true changepoint per detection
#'
#' @noRd
match_to_nearest <- function(detected, true_cp) {
  nearest_idx <- vapply(
    detected,
    function(d) which.min(abs(d - true_cp)),
    integer(1),
    USE.NAMES = FALSE
  )

  true_cp[nearest_idx]
}

#' Precision Score with Tolerance
#'
#' @description
#' Share of detected changepoints that are counted as matches to a true
#' changepoint within `tolerance`. Matches are counted by `count_matched()`.
#'
#' @param detected Vector of detected changepoint locations
#' @param true_cp Vector of true changepoint locations
#' @param tolerance Maximum distance for a match
#'
#' @return Precision (0 to 1). With no detections the result is 1 when there
#'   are no true changepoints either, and 0 otherwise.
#'
#' @export
precision_score <- function(detected, true_cp, tolerance = 5) {
  n_detected <- length(detected)

  if (n_detected == 0) {
    # Detecting nothing is only perfect when there was nothing to find.
    if (length(true_cp) == 0) {
      return(1)
    }
    return(0)
  }

  count_matched(detected, true_cp, tolerance) / n_detected
}

#' Recall Score with Tolerance
#'
#' @description
#' Proportion of true changepoints that are matched by a detection within
#' `tolerance`. Matches are counted with `count_matched()`, the same
#' one-to-one count used by `precision_score()`, so a single detection can
#' validate at most one true changepoint and precision, recall and F1 are all
#' based on the same number of true positives.
#'
#' @inheritParams precision_score
#'
#' @return Recall (0 to 1). Returns 1 when there are no true changepoints and
#'   0 when there are true changepoints but no detections.
#'
#' @export
recall_score <- function(detected, true_cp, tolerance = 5) {
  if (length(true_cp) == 0) {
    return(1)
  }
  if (length(detected) == 0) {
    return(0)
  }

  # Use the shared one-to-one match count rather than an any()-based test,
  # which would let one detection cover several nearby true changepoints.
  tp <- count_matched(detected, true_cp, tolerance)

  tp / length(true_cp)
}

#' F1 Score with Tolerance
#'
#' @description
#' Harmonic mean of `precision_score()` and `recall_score()` evaluated with
#' the same tolerance.
#'
#' @inheritParams precision_score
#'
#' @return F1 score (0 to 1); 0 when precision and recall are both 0
#'
#' @examples
#' f1_score(c(48, 102, 150), c(50, 100), tolerance = 5)
#'
#' @export
f1_score <- function(detected, true_cp, tolerance = 5) {
  prec <- precision_score(detected, true_cp, tolerance)
  rec <- recall_score(detected, true_cp, tolerance)

  # The harmonic mean is undefined when both components vanish; report 0.
  if (prec + rec == 0) {
    return(0)
  }

  2 * prec * rec / (prec + rec)
}

#' Count one-to-one matches between detected and true changepoints
#'
#' Walks through `detected` in the order given. Each detection claims the
#' closest true changepoint that has not been claimed yet, provided it lies
#' within `tolerance`; a claimed true changepoint cannot be matched again.
#'
#' @param detected Vector of detected changepoint locations
#' @param true_cp Vector of true changepoint locations
#' @param tolerance Maximum distance for a match
#'
#' @return Number of detections that obtained a match (0 if either vector is
#'   empty)
#'
#' @noRd
count_matched <- function(detected, true_cp, tolerance) {
  if (length(detected) == 0 || length(true_cp) == 0) {
    return(0)
  }

  claimed <- logical(length(true_cp))
  hits <- 0

  for (k in seq_along(detected)) {
    gaps <- abs(detected[k] - true_cp)
    # Already-claimed true changepoints are pushed out of reach.
    gaps[claimed] <- Inf

    nearest <- which.min(gaps)
    if (gaps[nearest] <= tolerance) {
      claimed[nearest] <- TRUE
      hits <- hits + 1
    }
  }

  hits
}

#' Rand Index for Segmentation
#'
#' @description
#' Measures agreement between detected and true segmentations as the share
#' of observation pairs on which both segmentations agree (both put the pair
#' in the same segment, or both put it in different segments).
#'
#' The agreeing pairs are counted from the contingency table of the two
#' labelings instead of enumerating all `n * (n - 1) / 2` pairs explicitly.
#'
#' @param detected Vector of detected changepoint locations
#' @param true_cp Vector of true changepoint locations
#' @param n Total number of observations
#'
#' @return Rand Index (0 to 1); 1 when `n < 2` because there are no pairs
#'
#' @export
rand_index <- function(detected, true_cp, n) {
  if (n < 2) {
    return(1)
  }

  labels_detected <- create_segment_labels(detected, n)
  labels_true <- create_segment_labels(true_cp, n)

  contingency <- table(labels_detected, labels_true)
  total_pairs <- n * (n - 1) / 2

  # Pairs sharing a segment in both / in the detected / in the true labeling.
  together_both <- sum(choose(contingency, 2))
  together_detected <- sum(choose(rowSums(contingency), 2))
  together_true <- sum(choose(colSums(contingency), 2))

  # Inclusion-exclusion: pairs separated in both labelings.
  apart_both <- total_pairs - together_detected - together_true + together_both

  (together_both + apart_both) / total_pairs
}

#' Adjusted Rand Index
#'
#' @description
#' Rand Index corrected for chance agreement, computed from the contingency
#' table of the detected and true segment labels.
#'
#' @inheritParams rand_index
#'
#' @return Adjusted Rand Index (can be negative, 1 is perfect). Returns 1 when
#'   `n < 2` and when the normalising denominator is zero.
#'
#' @export
adjusted_rand_index <- function(detected, true_cp, n) {
  if (n < 2) {
    return(1)
  }

  labels_detected <- create_segment_labels(detected, n)
  labels_true <- create_segment_labels(true_cp, n)

  contingency <- table(labels_detected, labels_true)

  # Pair counts: same segment in both labelings, in the detected labeling,
  # in the true labeling, and all pairs overall.
  pairs_both <- sum(choose(contingency, 2))
  pairs_detected <- sum(choose(rowSums(contingency), 2))
  pairs_true <- sum(choose(colSums(contingency), 2))
  pairs_total <- choose(n, 2)

  expected <- pairs_detected * pairs_true / pairs_total
  max_val <- (pairs_detected + pairs_true) / 2

  # Zero denominator: the index is undefined here, so it is reported as 1.
  if (max_val - expected == 0) {
    return(1)
  }

  (pairs_both - expected) / (max_val - expected)
}

#' Covering Metric
#'
#' @description
#' For every true segment, finds the detected segment with the highest
#' intersection-over-union (IoU), weights that IoU by the true segment's
#' length, sums over all true segments and divides by `n`.
#'
#' @inheritParams rand_index
#'
#' @return Covering metric (0 to 1)
#'
#' @export
covering_metric <- function(detected, true_cp, n) {
  found_segments <- create_segments(detected, n)
  truth_segments <- create_segments(true_cp, n)

  weighted_sum <- 0

  for (truth in truth_segments) {
    overlaps <- vapply(
      found_segments,
      function(candidate) segment_iou(truth, candidate),
      numeric(1)
    )
    # The best overlap never drops below zero.
    best <- max(0, overlaps)

    truth_length <- truth[2] - truth[1] + 1
    weighted_sum <- weighted_sum + truth_length * best
  }

  weighted_sum / n
}

#' Turn changepoints into per-observation segment labels
#'
#' A changepoint `cp` ends a segment at observation `cp`, so observation
#' `cp + 1` starts the next one. Only distinct changepoints in `1..(n - 1)`
#' split the series; values outside that range, duplicates and `NA`s are
#' ignored. Labels are consecutive, starting at 1.
#'
#' @param changepoints Vector of (integer-valued) changepoint locations
#' @param n Total number of observations
#'
#' @return Numeric vector of length `n` holding the segment label of every
#'   observation
#'
#' @noRd
create_segment_labels <- function(changepoints, n) {
  if (length(changepoints) == 0) {
    return(rep(1, n))
  }

  # sort() drops NA values; keep only cuts that actually split 1..n.
  cuts <- sort(unique(changepoints))
  cuts <- cuts[cuts >= 1 & cuts < n]

  if (length(cuts) == 0) {
    return(rep(1, n))
  }

  # Observation i has passed every cut that is <= i - 1.
  findInterval(seq_len(n) - 1, cuts) + 1
}

#' Turn changepoints into a list of segment boundaries
#'
#' Each segment is returned as `c(start, end)` with both ends inclusive. The
#' first segment starts at 1 and the last one ends at `n`. Only distinct
#' changepoints in `1..(n - 1)` split the series; values outside that range
#' and duplicates are dropped so that no empty or negative-length segments
#' are produced.
#'
#' @param changepoints Vector of changepoint locations
#' @param n Total number of observations
#'
#' @return List of length-2 numeric vectors `c(start, end)`
#'
#' @noRd
create_segments <- function(changepoints, n) {
  cuts <- sort(unique(changepoints))
  cuts <- cuts[cuts >= 1 & cuts < n]

  # Segment i runs from just after boundary i up to boundary i + 1.
  boundaries <- c(0, cuts, n)

  lapply(seq_len(length(boundaries) - 1), function(i) {
    c(boundaries[i] + 1, boundaries[i + 1])
  })
}

#' Intersection-over-union of two inclusive integer segments
#'
#' @param seg1,seg2 Segments given as `c(start, end)`, both ends inclusive
#'
#' @return Overlap length divided by the length of the union; 0 when the
#'   union has length 0
#'
#' @noRd
segment_iou <- function(seg1, seg2) {
  length_1 <- seg1[2] - seg1[1] + 1
  length_2 <- seg2[2] - seg2[1] + 1

  overlap_start <- max(seg1[1], seg2[1])
  overlap_end <- min(seg1[2], seg2[2])
  # Disjoint segments give a negative span, which is clamped to zero.
  shared <- max(0, overlap_end - overlap_start + 1)

  combined <- length_1 + length_2 - shared

  if (combined == 0) 0 else shared / combined
}

#' Compare Multiple Detection Methods
#'
#' @description
#' Runs multiple detection methods on the same data and compares results.
#' Each method is run through `detect_regimes()` with `uncertainty = FALSE`
#' and timed. A method that throws an error is reported with a warning and
#' recorded as a placeholder result with no changepoints instead of aborting
#' the whole comparison.
#'
#' @param data Numeric vector or matrix
#' @param methods Character vector of method names
#' @param true_changepoints Vector of true changepoints (for evaluation).
#'   When `NULL`, no evaluation metrics are computed.
#' @param tolerance Tolerance for evaluation metrics
#' @param ... Additional arguments passed to detect_regimes
#'
#' @return Object of class "regime_comparison": a list with elements
#'   `results` (per-method results, named by method), `evaluations`
#'   (per-method evaluations, or `NULL` without ground truth), `summary`
#'   (data frame with one row per method), `true_changepoints` and `methods`.
#'
#' @examples
#' true_cp <- c(50, 100)
#' data <- c(rnorm(50), rnorm(50, mean = 2), rnorm(50))
#'
#' comparison <- compare_methods(
#'   data,
#'   methods = c("pelt", "binseg"),
#'   true_changepoints = true_cp
#' )
#'
#' @export
compare_methods <- function(data,
                            methods = c("pelt", "binseg", "wbs", "bocpd"),
                            true_changepoints = NULL,
                            tolerance = 5,
                            ...) {
  results <- lapply(methods, function(m) {
    tryCatch({
      t_start <- Sys.time()
      result <- detect_regimes(data, method = m, uncertainty = FALSE, ...)
      t_end <- Sys.time()
      result$time <- as.numeric(difftime(t_end, t_start, units = "secs"))
      result
    }, error = function(e) {
      msg <- conditionMessage(e)
      cli::cli_warn("Method {m} failed: {msg}")
      # Placeholder so the failed method still appears in the summary.
      list(
        changepoints = integer(0),
        n_changepoints = 0,
        error = msg,
        time = NA_real_
      )
    })
  })
  names(results) <- methods

  evaluations <- NULL
  if (!is.null(true_changepoints)) {
    # NROW() is the length of a vector and the row count of a matrix.
    n <- NROW(data)

    evaluations <- lapply(results, function(r) {
      # Exact-name lookup: `$` would partially match other element names.
      if (!is.null(r[["error"]])) {
        return(list(f1 = NA, hausdorff = NA, adjusted_rand = NA))
      }
      evaluate(
        r[["changepoints"]],
        true_changepoints,
        n = n,
        tolerance = tolerance
      )
    })
    names(evaluations) <- methods
  }

  # vapply() keeps the summary columns numeric whatever `methods` contains,
  # including when it is empty.
  summary_df <- data.frame(
    method = methods,
    n_changepoints = vapply(results, function(r) {
      k <- r[["n_changepoints"]]
      as.numeric(if (is.null(k)) length(r[["changepoints"]]) else k)
    }, numeric(1)),
    time_sec = vapply(results, function(r) {
      as.numeric(r[["time"]])
    }, numeric(1)),
    stringsAsFactors = FALSE
  )

  if (!is.null(evaluations)) {
    # Pull one metric out of every evaluation; missing values become NA.
    pull_metric <- function(field) {
      vapply(evaluations, function(e) {
        value <- e[[field]]
        if (is.null(value)) NA_real_ else as.numeric(value)
      }, numeric(1))
    }

    summary_df$f1 <- pull_metric("f1")
    summary_df$hausdorff <- pull_metric("hausdorff")
    summary_df$adj_rand <- pull_metric("adjusted_rand")
  }

  structure(
    list(
      results = results,
      evaluations = evaluations,
      summary = summary_df,
      true_changepoints = true_changepoints,
      methods = methods
    ),
    class = c("regime_comparison", "list")
  )
}

#' @export
print.regime_comparison <- function(x, ...) {
  cat(
    "\nChangepoint Detection Method Comparison\n",
    "========================================\n\n",
    sep = ""
  )

  # Summary table, one row per method.
  print(x$summary, row.names = FALSE)
  cat("\n")

  truth <- x$true_changepoints
  if (!is.null(truth)) {
    cat("True changepoints:", paste(truth, collapse = ", "), "\n")
  }

  cat("\nDetected changepoints by method:\n")
  for (method_name in x$methods) {
    found <- x$results[[method_name]]$changepoints

    shown <- "(none)"
    if (length(found) > 0) {
      shown <- paste(found, collapse = ", ")
    }

    cat(sprintf("  %s: %s\n", method_name, shown))
  }

  # Return the object invisibly so the method composes in pipelines.
  invisible(x)
}

#' @export
plot.regime_comparison <- function(x, ...) {
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    cli::cli_abort("ggplot2 required for plotting")
  }

  # Without evaluations there are no F1 scores to draw.
  if (is.null(x$evaluations)) {
    return(invisible(NULL))
  }

  scores <- data.frame(
    method = x$methods,
    f1 = sapply(x$evaluations, function(ev) ev[["f1"]])
  )

  # Bars are ordered from the highest to the lowest F1 score.
  bar_mapping <- ggplot2::aes(x = stats::reorder(method, -f1), y = f1)
  bar_labels <- ggplot2::labs(
    title = "Method Comparison: F1 Score",
    x = "Method",
    y = "F1 Score"
  )

  chart <- ggplot2::ggplot(scores, bar_mapping) +
    ggplot2::geom_col(fill = "steelblue") +
    bar_labels +
    ggplot2::theme_minimal() +
    ggplot2::ylim(0, 1)

  print(chart)
  invisible(chart)
}