#' Block-wise Rank In Similarity graph Edge-count (BRISE) Test
#'
#' @description
#' \code{BRISE} implements the Two-Sample Test that handles block-wise missingness.
#' It identifies missing-data patterns, constructs a (blockwise) dissimilarity matrix,
#' induces ranks via a k-nearest neighbor style graph, and computes a quadratic statistic under two versions:
#' the congregated form (‘con’) and vectorized form (‘vec’). Permutation p-values are optionally available.
#'
#' @param X Numeric matrix (m × p) of observations for X (Sample 1); missing entries must be coded as \code{NA}. Optional if \code{D}, \code{ptn_list} and \code{m} are provided.
#' @param Y Numeric matrix (n × p) of observations for Y (Sample 2); missing entries must be coded as \code{NA}. Optional if \code{D}, \code{ptn_list} and \code{m} are provided.
#' @param D Numeric square dissimilarity matrix (N × N), where N = m + n and the first m rows/columns belong to Sample 1. Required when \code{X} and \code{Y} are not given.
#' @param ptn_list List of integer vectors. Each element contains indices (1…N) of observations that share the same missing-data pattern.
#' @param k Positive integer. Neighborhood size offset for rank truncation in nearest-neighbor ranking. Default is 10.
#' @param perm Integer. Number of permutations for computing permutation p-value. Default is 0 (no permutation).
#' @param skip Integer (0 or 1). When set to 1 (default), skip rank-based dissimilarity for modality pairs with no shared observed variables; setting to 0 computes them (slower).
#' @param ver Character. Version of the test statistic: \code{"con"} (congregated form, default) or \code{"vec"} (vectorized form).
#' @param m Integer. Number of Sample 1 observations, i.e. indices \code{1…m} of \code{D} are Sample 1 and the rest are Sample 2. Required only when \code{D} and \code{ptn_list} are supplied without \code{X} and \code{Y}; ignored (taken from \code{nrow(X)}) otherwise.
#'
#' @details
#' If both \code{X} and \code{Y} are supplied, \code{Identify_mods} is used to detect missing patterns (via \code{NA} entries) and reorganize variables by modality. The dissimilarity matrix \code{D} is then constructed via \code{Blockdist}. Patterns with fewer than 2 observations in either sample, or patterns smaller than one tenth of the largest pattern, are filtered out for robustness; an error is raised if no pattern survives. A symmetric rank matrix is built based on truncated nearest-neighbor ranks. Under \code{ver="con"} the contrast statistic (two degrees of freedom) is used; under \code{ver="vec"} a higher-dimensional vector statistic is used. Asymptotic p-values use chi-square approximations; if \code{perm > 0}, empirical permutation p-values are also computed.
#'
#' @return A list with elements:
#' \describe{
#'   \item{test.statistic}{Numeric. The computed test statistic.}
#'   \item{pval.approx}{Numeric. Asymptotic p-value (chi-square based).}
#'   \item{Cov}{Covariance matrix used in computing the test statistic.}
#'   \item{pval.perm}{(Optional) Permutation p-value if \code{perm > 0}.}
#' }
#'
#' @references
#' Zhang, K., Liang, M., Maile, R. & Zhou, D. (2025).
#' \emph{Two-Sample Testing with Block-wise Missingness in Multi-source Data.}
#' \emph{arXiv preprint arXiv:2508.17411.}
#'
#' @examples
#' set.seed(1)
#' X <- matrix(rnorm(50*200, mean = 0), nrow=50)
#' Y <- matrix(rnorm(50*200, mean = 0.3), nrow=50)
#' # Block-wise missingness must be coded as NA (not 0) to be detected.
#' X[1:20, 1:100] <- NA
#' X[30:50, 101:200] <- NA
#' Y[1:10, 1:100] <- NA
#' Y[30:40, 101:200] <- NA
#' out <- BRISE(X = X, Y = Y, k = 5, perm = 1000, ver = "con")
#' print(out$test.statistic)
#' print(out$pval.approx)
#'
#'
#' @seealso \code{\link{BRISE_Rank}}, \code{\link{Cov_mu.c}}, \code{\link{Cov_mu.v}}
#' @importFrom stats dist pchisq setNames
#' @export
BRISE <- function(X = NULL, Y = NULL, D = NULL, ptn_list = NULL, k = 10,
                  perm = 0, skip = 1, ver = 'con', m = NULL) {
  # Fail early on an unknown version instead of reaching the end of the
  # function with no result object.
  ver <- match.arg(ver, c("con", "vec"))

  have_xy <- !is.null(X) && !is.null(Y)
  if (!have_xy && (is.null(D) || is.null(ptn_list))) {
    stop("Either `X`,`Y` must be provided together, or `D`,`ptn_list` must be provided together.")
  }

  if (have_xy) {
    if (ncol(X) != ncol(Y)) {
      stop("`X` and `Y` must have the same number of columns.")
    }
    m <- nrow(X)
    n <- nrow(Y)
    d <- ncol(X)

    ## Identify modality and missing patterns
    outs <- Identify_mods(list(X = X, Y = Y), m, n, d)
    data <- outs[[1]]
    modality <- outs[[2]]
    mod_bound <- outs[[3]]
    mod_id <- outs[[4]]

    # Encode each observation's set of observed modalities as a binary number
    # (modality j contributes 2^(j - 1)); equal codes <=> same pattern.
    ptn <- as.vector(mod_id %*% 2^(seq_len(modality) - 1))
    ptn_list <- split(seq_along(ptn), ptn)
    D <- Blockdist(data, m, n, d, ptn_list, mod_id, modality, mod_bound,
                   skip = skip)
  }

  if (!is.matrix(D) || nrow(D) != ncol(D)) {
    stop("`D` must be a square dissimilarity matrix.")
  }

  if (!have_xy) {
    # With only `D`, the sample split cannot be inferred: it must be given.
    if (is.null(m)) {
      stop("`m` (number of Sample 1 observations, i.e. the first `m` rows of `D`) must be provided when `D` and `ptn_list` are used without `X` and `Y`.")
    }
    if (!is.numeric(m) || length(m) != 1L || is.na(m) || m < 1 || m >= nrow(D)) {
      stop("`m` must be a single integer between 1 and nrow(D) - 1.")
    }
  }

  if (length(ptn_list) == 0L) {
    stop("`ptn_list` must contain at least one pattern.")
  }

  ## Per-pattern sample sizes (indices <= m belong to Sample 1)
  m_ <- vapply(ptn_list, function(idx) sum(idx <= m), numeric(1),
               USE.NAMES = FALSE)
  n_ <- vapply(ptn_list, function(idx) sum(idx > m), numeric(1),
               USE.NAMES = FALSE)
  N_ <- m_ + n_

  ## Filter out rare patterns for robustness
  keep_ptn <- m_ >= 2 & n_ >= 2 & N_ >= max(N_) / 10
  if (!any(keep_ptn)) {
    stop("No missing-data pattern has at least 2 observations in each sample; the test cannot be computed.")
  }
  m_ <- m_[keep_ptn]
  n_ <- n_[keep_ptn]
  N_ <- N_[keep_ptn]
  ptn_list <- ptn_list[keep_ptn]

  # Restrict D to the retained observations and renumber them 1..N. Because
  # `idx_keep` is sorted, Sample 1 observations still come first. When nothing
  # was dropped and `ptn_list` covers all of D this is the identity mapping.
  idx_keep <- sort(unique(unlist(ptn_list)))
  D <- D[idx_keep, idx_keep, drop = FALSE]
  ptn_list <- lapply(ptn_list, match, table = idx_keep)
  n_ptn <- length(ptn_list)

  m <- sum(m_)
  n <- sum(n_)
  N <- sum(N_)

  ## Rank induced by k nearest neighbor graph
  # BRISE_Rank returns row-wise ranks minus one of the similarity block, so
  # after subtracting (block size - 1 - k) the most similar neighbour gets
  # weight k, the next k - 1, ... and everything beyond becomes <= 0.
  R <- matrix(0, nrow = N, ncol = N)
  for (i in seq_len(n_ptn)) {
    obs_i <- ptn_list[[i]]
    for (j in seq_len(n_ptn)[-seq_len(i)]) {
      obs_j <- ptn_list[[j]]
      Dij <- D[obs_i, obs_j, drop = FALSE]
      Sij <- max(Dij) - Dij
      R[obs_i, obs_j] <- BRISE_Rank(Sij, method = 'rowij') - N_[j] + 1 + k
      R[obs_j, obs_i] <- BRISE_Rank(t(Sij), method = 'rowij') - N_[i] + 1 + k
    }
    Dii <- D[obs_i, obs_i, drop = FALSE]
    Sii <- max(Dii) - Dii
    R[obs_i, obs_i] <- BRISE_Rank(Sii, method = 'row') - N_[i] + 1 + k
  }
  R[R < 0] <- 0
  R <- (R + t(R)) / 2
  diag(R) <- 0

  x_ids <- seq_len(m)
  y_ids <- m + seq_len(n)

  ## Observed statistic, null moments and degrees of freedom per version
  if (ver == "vec") {
    obs_stat <- BRISE_v.stat(R, x_ids, y_ids, ptn_list)
    moments <- Cov_mu.v(R, m_, n_, ptn_list)
    Cov <- moments[[1]]
    mu <- unlist(moments[[2]])
    # Drop all-zero rows/columns of Cov so that it stays invertible.
    nonzero_indices <- which(rowSums(abs(Cov)) >= 1e-6)
    Cov <- Cov[nonzero_indices, nonzero_indices, drop = FALSE]
    obs_stat <- obs_stat[nonzero_indices]
    mu <- mu[nonzero_indices]
    stat_fun <- function(id1, id2) {
      BRISE_v.stat(R, id1, id2, ptn_list)[nonzero_indices]
    }
    df <- length(mu)
  } else {
    obs_stat <- BRISE_c.stat(R, x_ids, y_ids)
    moments <- Cov_mu.c(R, m_, n_, ptn_list)
    Cov <- moments[[1]]
    mu <- moments[[2]]
    stat_fun <- function(id1, id2) BRISE_c.stat(R, id1, id2)
    df <- 2
  }

  Cov_inv <- solve(Cov)
  dev <- obs_stat - mu
  # Parenthesise the whole quadratic form before extracting the scalar.
  TR <- (t(dev) %*% Cov_inv %*% dev)[1, 1]
  test.asy <- list(
    'test.statistic' = TR,
    'pval.approx' = 1 - pchisq(TR, df),
    'Cov' = Cov
  )

  if (perm > 0) {
    # Permute sample labels within each pattern, keeping (m_b, n_b) fixed.
    perm_stats <- lapply(seq_len(perm), function(i) {
      ID1 <- c()
      ID2 <- c()
      for (b in seq_along(ptn_list)) {
        block <- ptn_list[[b]]
        m_b <- m_[b]
        n_b <- n_[b]
        stopifnot(length(block) == m_b + n_b)  # safety check
        permuted_block <- sample(block, length(block), replace = FALSE)
        ID1 <- c(ID1, permuted_block[1:m_b])
        ID2 <- c(ID2, permuted_block[(m_b + 1):(m_b + n_b)])
      }
      stat_fun(ID1, ID2)
    })
    # One row per permutation, regardless of `perm` or statistic length.
    BR <- do.call(rbind, perm_stats)
    BR0 <- scale(BR, center = mu, scale = FALSE)
    T2B.all <- rowSums((BR0 %*% Cov_inv) * BR0)
    test.asy$pval.perm <- mean(TR < T2B.all)
  }

  return(test.asy)
}


#' Identify Data Modalities
#'
#' @description
#' Detects modalities across the combined data (samples X and Y), rearranges variables/columns by modality, and produces identification structures used downstream for blockwise operations.
#'
#' @param data List with components \code{X} and \code{Y} (numeric matrices).
#' @param m Integer. Number of rows in \code{X}.
#' @param n Integer. Number of rows in \code{Y}.
#' @param d Integer. Number of features (columns) in \code{X} (and \code{Y}).
#'
#' @return List with components:
#' \describe{
#'   \item{rearr_data}{List with rearranged \code{X}, \code{Y} after grouping features by modality.}
#'   \item{modality}{Integer. Number of distinct missing-data modalities.}
#'   \item{mod_bound}{Integer vector. Cumulative boundaries of modalities among the features.}
#'   \item{mod_id}{Binary matrix (N × modality) indicating, for each observation, whether each modality is observed (1) or missing (0).}
#' }
#' @keywords internal
Identify_mods <- function(data, m, n, d) {
  # Groups features into modalities (features that are missing for exactly the
  # same observations), reorders the columns so each modality is contiguous,
  # and records which modalities every observation has observed.
  #
  # data: list with matrices X (m rows) and Y (n rows); missing entries are NA.
  # d:    number of features; accepted for interface compatibility, not used.
  #
  # Returns an unnamed list:
  #   [[1]] list(X, Y) with columns reordered by modality,
  #   [[2]] number of modalities,
  #   [[3]] cumulative column boundaries c(0, b1, b2, ...),
  #   [[4]] N x modality 0/1 matrix (1 = modality observed).
  N <- m + n
  Z <- rbind(data$X, data$Y)

  # Rows of `patterns` are the distinct observation-level NA patterns; after
  # transposing, each row is one feature's missingness signature across them.
  feature_signature <- t(unique(is.na(Z)))
  sig_keys <- apply(feature_signature, 1, paste, collapse = ",")

  # Features with identical signatures form one modality (first-seen order).
  modality_list <- lapply(unique(sig_keys), function(key) which(sig_keys == key))
  modality <- length(modality_list)

  # Rearrange the columns so that each modality occupies a contiguous range.
  new_order <- unlist(modality_list, use.names = FALSE)
  rearr_Z <- Z[, new_order, drop = FALSE]
  mod_bound <- c(0, cumsum(lengths(modality_list)))

  # All features of a modality share one signature, so the first feature of
  # each modality tells whether the modality is observed for an observation.
  first_cols <- vapply(modality_list, function(idx) idx[1], integer(1))
  observed <- !is.na(Z[seq_len(N), first_cols, drop = FALSE])
  mod_id <- matrix(as.numeric(observed), nrow = N, ncol = modality)

  # drop = FALSE keeps a single-row sample as a 1 x d matrix.
  rearr_data <- list(
    X = as.matrix(rearr_Z[seq_len(m), , drop = FALSE]),
    Y = as.matrix(rearr_Z[m + seq_len(n), , drop = FALSE])
  )
  return(list(rearr_data, modality, mod_bound, mod_id))
}


#' Block-wise Distance Matrix Construction
#'
#' @description
#' Constructs a symmetric dissimilarity matrix that accounts for missing-data patterns. Within blocks where both observations share a modality, standard Euclidean distances are used. Optionally, for observations without shared observed features (based on modality), a rank-based dissimilarity is computed (if \code{skip = 0}).
#'
#' @param data List with \code{X} and \code{Y} matrices.
#' @param m Integer. Number of rows (observations) in \code{X}.
#' @param n Integer. Number of rows in \code{Y}.
#' @param d Integer. Number of features (columns).
#' @param ptn_list List of integer vectors: each element indexes observations sharing the same missing pattern.
#' @param mod_id Binary matrix (N × modality) indicating modality membership per observation.
#' @param modality Integer. Number of modalities.
#' @param mod_bound Integer vector. Feature indices boundaries per modality block.
#' @param skip Integer (0 or 1). If set to 1, dissimilarity for modality-disjoint pairs is skipped. If 0, computed rank-based distances are used.
#'
#' @return Numeric symmetric matrix (N × N) of pairwise dissimilarities.
#' @keywords internal
Blockdist <- function(data, m, n, d, ptn_list, mod_id, modality, mod_bound, skip = 1) {
  # Builds the N x N dissimilarity matrix for the pooled sample rbind(X, Y).
  #
  # Step 1: for every pair (i, j), the Euclidean distance over the features of
  #         all modalities observed by BOTH observations (0 if none shared).
  # Step 2 (only when skip == 0): for pattern pairs that share no modality,
  #         replace that 0 by 1 - Spearman correlation between the two
  #         observations' distance-rank profiles to "anchor" observations that
  #         have every modality seen by either pattern.
  #
  # `d` is accepted for interface compatibility and is not used.
  N <- m + n
  Z <- rbind(data$X, data$Y)

  # Column range of each modality, computed once.
  mod_cols <- lapply(seq_len(modality), function(k) {
    (mod_bound[k] + 1):mod_bound[k + 1]
  })

  D <- matrix(0, nrow = N, ncol = N)
  for (i in seq_len(N)) {
    for (j in seq_len(i)) {
      sq_dist <- 0
      for (k in seq_len(modality)) {
        if (mod_id[i, k] + mod_id[j, k] == 2) {
          # both observations have modality k observed
          Diff <- Z[i, mod_cols[[k]]] - Z[j, mod_cols[[k]]]
          sq_dist <- sq_dist + sum(Diff ^ 2)
        }
      }
      D[i, j] <- sqrt(sq_dist)
      D[j, i] <- D[i, j]
    }
  }

  ## Rank-based dissimilarity for pattern pairs with no shared modality;
  ## skipped by default for faster execution.
  if (skip == 0) {
    for (Pi in seq_along(ptn_list)) {
      for (Pj in seq_len(Pi)) {
        mods <- mod_id[ptn_list[[Pi]][1], ] + mod_id[ptn_list[[Pj]][1], ]
        if (max(mods) >= 2) {
          next  # the two patterns share a modality: step 1 already applies
        }

        # Anchors: observations having every modality seen by either pattern.
        is_anchor <- apply(mod_id, 1, function(row) all(row >= mods))
        obs_list <- unname(which(is_anchor))
        N_obs <- length(obs_list)
        if (N_obs < 2) {
          # The Spearman formula divides by N_obs^2 - 1, so at least two
          # anchors are needed; keep the step-1 value (as with skip = 1).
          warning("Fewer than 2 observations cover all modalities of a modality-disjoint pattern pair; rank-based dissimilarity skipped for that pair.",
                  call. = FALSE)
          next
        }

        # l_list[i, ] holds the ranks of observation i's squared distances to
        # the anchors, using only the modalities i has observed.
        l_list <- matrix(0, nrow = N, ncol = N_obs)
        for (i in unique(c(ptn_list[[Pi]], ptn_list[[Pj]]))) {
          sq_to_anchor <- numeric(N_obs)
          for (ii in seq_len(N_obs)) {
            s <- obs_list[ii]
            for (k in seq_len(modality)) {
              if (mod_id[i, k] == 1) {
                # observation i has modality k observed
                Diff <- Z[i, mod_cols[[k]]] - Z[s, mod_cols[[k]]]
                sq_to_anchor[ii] <- sq_to_anchor[ii] + sum(Diff ^ 2)
              }
            }
          }
          l_list[i, ] <- rank(sq_to_anchor)
        }

        for (i in ptn_list[[Pi]]) {
          for (j in ptn_list[[Pj]]) {
            l1 <- l_list[i, ]
            l2 <- l_list[j, ]
            # Spearman's rank correlation via the squared rank differences
            spearman_corr <- 1 - 6 / N_obs / (N_obs^2 - 1) * sum((l1 - l2)^2)
            D[i, j] <- 1 - spearman_corr
            D[j, i] <- D[i, j]
          }
        }
      }
    }
  }
  return(D)
}


#' Rank Induction within- and cross-pattern similarity blocks
#'
#' @description
#' Compute row-wise ranks of a similarity matrix for two cases:
#' \itemize{
#'   \item \code{method = "row"}: within-pattern block (Sii) (square). Because self-pairs exist, the diagonal
#'         (self-similarity) is first forced below the minimum entry of (S) so that self-neighbors are always ranked last
#'         and thus excluded when top-(k) truncation is applied downstream.
#'   \item \code{method = "rowij"}: cross-pattern block (Sij) (rectangular, i!=j). There are no self-pairs,
#'         so no diagonal adjustment is needed.
#' }
#' Ranks are computed row-wise with \code{rank()} and then shifted by 1 (i.e., the function returns \code{rank - 1}).
#'
#' @param S Numeric similarity matrix: (Sii) (square) when \code{method = "row"}; (Sij) (rectangular) when
#'   \code{method = "rowij"}. Larger values indicate greater similarity.
#' @param method Character, either \code{"row"} (within-pattern (Sii), diagonal suppressed) or \code{"rowij"}
#'   (cross-pattern (Sij), no diagonal to suppress).
#' @return A numeric matrix with the same dimensions as \code{S} containing row-wise ranks minus one.
#' @keywords internal
BRISE_Rank <- function(S, method = 'row') {
  # S: an Na x Nb similarity matrix (larger = more similar).
  # Returns the Na x Nb matrix of row-wise ranks minus one.
  method <- match.arg(method, c("row", "rowij"))
  if (method == "row") {
    # Within-pattern block: push self-similarity below every other entry so
    # that each observation ranks itself last (rank - 1 == 0).
    diag(S) <- min(S) - 100
  }
  # apply() returns the ranks of row i as its i-th column (or a plain vector
  # when S has a single column); refilling by row restores S's shape in both
  # cases, which t(apply(...)) does not for a single-column S.
  ranks <- apply(S, 1, rank)
  R <- matrix(ranks, nrow = nrow(S), ncol = ncol(S), byrow = TRUE,
              dimnames = dimnames(S)) - 1
  return(R)
}


#' Block-wise Statistic (Vectorized Form)
#'
#' @description
#' For the vectorized version of BRISE, computes the within-sample rank sums for both samples across all pattern pairs. Returns a concatenated vector of all Ux_ab followed by all Uy_ab, over pattern pairs (a, b) with a >= b (diagonal pairs a = b included).
#'
#' @param R Numeric symmetric rank matrix (N × N).
#' @param sample1ID Integer vector. Indices of observations in X.
#' @param sample2ID Integer vector. Indices of observations in Y.
#' @param ptn_list List of integer vectors that indexes observations sharing the same missing pattern.
#'
#' @return Numeric vector containing the sums of R entries within X and Y, for each pattern pair.
#' @keywords internal
BRISE_v.stat <- function(R, sample1ID, sample2ID, ptn_list) {
  # For every pattern pair (a, b) with b <= a, sums the entries of R between
  # the sample-1 members of pattern a and the sample-1 members of pattern b
  # (Ux_ab), and likewise for sample 2 (Uy_ab).
  # Returns c(all Ux_ab, all Uy_ab), pairs ordered (1,1), (2,1), (2,2), ...
  n_ptn <- length(ptn_list)

  # Membership of each sample in each pattern does not depend on the pair,
  # so compute it once per pattern rather than once per pair.
  ids_x <- lapply(ptn_list, function(p) sample1ID[sample1ID %in% p])
  ids_y <- lapply(ptn_list, function(p) sample2ID[sample2ID %in% p])

  n_pairs <- n_ptn * (n_ptn + 1) / 2
  Ux_ab <- numeric(n_pairs)
  Uy_ab <- numeric(n_pairs)
  pos <- 0
  for (a in seq_len(n_ptn)) {
    for (b in seq_len(a)) {
      pos <- pos + 1
      Ux_ab[pos] <- sum(R[ids_x[[a]], ids_x[[b]]])
      Uy_ab[pos] <- sum(R[ids_y[[a]], ids_y[[b]]])
    }
  }
  return(c(Ux_ab, Uy_ab))
}


#' Block-wise Statistic (Congregated Form)
#'
#' @description
#' For the contrast version of BRISE (“con”), computes within-sample sums of the rank matrix R (i.e. Ux, Uy) over all observations in X and Y, for congregated BRISE test.
#'
#' @param R Numeric symmetric rank matrix with zero diagonal.
#' @param sample1ID Integer vector of indices for X.
#' @param sample2ID Integer vector of indices for Y.
#'
#' @return Numeric vector \code{c(Ux, Uy)}, the within-sample rank sums for the two samples.
#' @keywords internal
BRISE_c.stat <- function(R, sample1ID, sample2ID) {
  # R: rank matrix whose diagonal is expected to be zero.
  # sample1ID: indices of the observations in population X.
  # sample2ID: indices of the observations in population Y.
  # Returns c(Ux, Uy): the sum of R over each sample's own sub-block.
  within_sum <- function(ids) {
    sum(R[ids, ids])
  }
  c(within_sum(sample1ID), within_sum(sample2ID))
}


#' Covariance and Expectation (Vectorized Form)
#'
#' @description
#' Computes the asymptotic covariance matrix and expectation (mu) vector for the vectorized BRISE statistic under the pattern-wise permutation null distribution, based on rank matrix R and the list of pattern indicator. Used to form the quadratic statistic and its chi-square approximation.
#'
#' @param R Numeric symmetric rank matrix (N × N).
#' @param m_ Integer vector. X's sample sizes in each pattern.
#' @param n_ Integer vector. Y's sample sizes in each pattern.
#' @param ptn_list List of integer vectors that indexes observations sharing the same missing pattern.
#'
#' @return List with two elements:
#' \describe{
#'   \item{Cov}{Covariance matrix corresponding to the vector of pair-wise statistics.}
#'   \item{mu}{Expectation vector for those pair-wise statistics under the null.}
#' }
#' @keywords internal
Cov_mu.v <- function(R, m_, n_, ptn_list) {
  # Builds the mean vector and covariance matrix of the block-wise statistics
  # (Ux_ab, Uy_ab), one entry per pattern pair (a, b) with b <= a, from
  # block-level moments of the rank matrix R.
  # NOTE(review): the closed-form coefficients below presumably implement the
  # pattern-wise permutation-null moments of the referenced paper; they cannot
  # be verified from this file alone.
  n_ptn = length(m_)
  N_ = m_ + n_
  # r_0[i, j]    : mean of R over block (pattern i, pattern j)
  # Ri.[i, j, s] : row mean of the s-th observation of pattern i over pattern j
  #                (third dimension zero-padded up to max(N_))
  # r_1[i, j, k] : average over pattern i of the product of its row means
  #                towards patterns j and k
  # r_2[i, j]    : mean of R^2 over block (i, j)
  r_0 = matrix(0, nrow = n_ptn, ncol = n_ptn)
  Ri. = array(0, dim = c(n_ptn, n_ptn, max(N_)))
  r_1 = array(0, dim = c(n_ptn, n_ptn, n_ptn))
  r_2 = matrix(0, nrow = n_ptn, ncol = n_ptn)

  # Centered versions: V_1 = r_1 - r_0 * r_0 and V_2 = r_2 - r_0^2.
  V_1 = array(0, dim = c(n_ptn, n_ptn, n_ptn))
  V_2 = matrix(0, nrow = n_ptn, ncol = n_ptn)

  for (i in 1:n_ptn) {
    # Diagonal block: averages use N_i * (N_i - 1) ordered pairs, and row
    # means divide by N_i - 1.
    Rii = R[ptn_list[[i]], ptn_list[[i]], drop = FALSE]
    r_0[i, i] = sum(Rii) / (N_[i] * (N_[i] - 1))
    r_2[i, i] = sum(Rii^2) / (N_[i] * (N_[i] - 1))
    Ri.[i, i, 1:N_[i]] = rowSums(Rii) / ((N_[i] - 1))

    # Off-diagonal blocks j > i. When i == n_ptn the range (i + 1):n_ptn
    # counts downwards, so the guard below skips both of its iterations.
    for (j in (i + 1):n_ptn) {
      if (i == n_ptn) {
        next
      }
      Rij = R[ptn_list[[i]], ptn_list[[j]], drop = FALSE]
      r_0[i, j] = sum(Rij) / (N_[i] * N_[j])
      r_0[j, i] = r_0[i, j]
      r_2[i, j] = sum(Rij^2) / (N_[i] * N_[j])
      r_2[j, i] = r_2[i, j]
      Ri.[i, j, 1:N_[i]] = rowSums(Rij) / (N_[j])
      Rji = R[ptn_list[[j]], ptn_list[[i]], drop = FALSE]
      Ri.[j, i, 1:N_[j]] = rowSums(Rji) / (N_[i])
    }
  }

  for (i in 1:n_ptn) {
    for (j in 1:n_ptn) {
      for (k in 1:n_ptn) {
        # The zero padding of Ri. contributes nothing to the inner product,
        # so this sums over the N_i observations of pattern i only.
        r_1[i, j, k] = Ri.[i, j, ] %*% Ri.[i, k, ] / N_[i]
        V_1[i, j, k] = r_1[i, j, k] - r_0[i, j] * r_0[i, k]
      }
      V_2[i, j] = r_2[i, j] - r_0[i, j]^2
    }
  }

  ## E(Ux)
  # mx[a, b]: (number of ordered X pairs between patterns a and b) * r_0[a, b].
  mx = matrix(0, nrow = n_ptn, ncol = n_ptn)
  for (a in 1:n_ptn) {
    mx[a, a] = m_[a] * (m_[a] - 1) * r_0[a, a]
  }

  for (a in 1:n_ptn) {
    for (b in 1:n_ptn) {
      if (a == b) {
        next
      }
      mx[a, b] = m_[a] * m_[b] * r_0[a, b]
    }
  }

  ## E(Uy)
  # Same as mx with the Y sample sizes n_.
  my = matrix(0, nrow = n_ptn, ncol = n_ptn)
  for (a in 1:n_ptn) {
    my[a, a] = n_[a] * (n_[a] - 1) * r_0[a, a]
  }

  for (a in 1:n_ptn) {
    for (b in 1:n_ptn) {
      if (a == b) {
        next
      }
      my[a, b] = n_[a] * n_[b] * r_0[a, b]
    }
  }

  ## Covariances
  # Working layout: index (a - 1) * n_ptn + b addresses the X statistic of
  # pair (a, b); adding n_ptn^2 addresses the matching Y statistic. Only one
  # triangle is filled here; the matrix is symmetrized before returning.
  Cov = matrix(0,
               ncol = 2 * n_ptn * n_ptn,
               nrow = 2 * n_ptn * n_ptn)

  ## type 1: {a,a,a,a}
  # Variances of, and X-Y covariance between, the within-pattern statistics.
  for (a in 1:n_ptn) {
    id_1x = (a - 1) * n_ptn + a
    id_1y = id_1x + n_ptn * n_ptn
    Cov[id_1x, id_1x] = 2 * m_[a] * n_[a] * (m_[a] - 1) / ((N_[a] - 2) * (N_[a] - 3)) *
      ((n_[a] - 1) * V_2[a, a] + 2 * (m_[a] - 2) * (N_[a] - 1) * V_1[a, a, a])
    Cov[id_1y, id_1y] = 2 * m_[a] * n_[a] * (n_[a] - 1) / ((N_[a] - 2) * (N_[a] - 3)) *
      ((m_[a] - 1) * V_2[a, a] + 2 * (n_[a] - 2) * (N_[a] - 1) * V_1[a, a, a])
    Cov[id_1x, id_1y] = 2 * m_[a] * n_[a] * (m_[a] - 1) * (n_[a] - 1) / ((N_[a] - 2) *
                                                                           (N_[a] - 3)) * (V_2[a, a] - 2 * (N_[a] - 1) * V_1[a, a, a])
  }

  ## type 2: {a,b,a,b}
  # Variances of, and X-Y covariance between, the cross-pattern statistics of
  # the same pair (a, b), b < a.
  for (a in 1:n_ptn) {
    for (b in 1:a) {
      if (a == b) {
        next
      }
      id_1x = (a - 1) * n_ptn + b
      id_1y = id_1x + n_ptn * n_ptn
      Cov[id_1x, id_1x] = m_[a] * m_[b] / (N_[a] - 1) / (N_[b] - 1) *
        (n_[a] * n_[b] * V_2[a, b] + n_[a] * (m_[b] - 1) * N_[b] * V_1[a, b, b] +
           n_[b] * (m_[a] - 1) * N_[a] * V_1[b, a, a])
      Cov[id_1y, id_1y] = n_[a] * n_[b] / (N_[a] - 1) / (N_[b] - 1) *
        (m_[a] * m_[b] * V_2[a, b] + m_[a] * (n_[b] - 1) * N_[b] * V_1[a, b, b] +
           m_[b] * (n_[a] - 1) * N_[a] * V_1[b, a, a])
      Cov[id_1x, id_1y] = m_[a] * m_[b] * n_[a] * n_[b] / (N_[a] - 1) / (N_[b] - 1) *
        (V_2[a, b] - N_[b] * V_1[a, b, b] - N_[a] * V_1[b, a, a])
    }
  }

  ## type 3: {a,a,a,b}
  # Covariances between a within-pattern statistic and a cross-pattern
  # statistic that involves the same pattern.
  for (a in 1:n_ptn) {
    for (b in 1:a) {
      if (a == b) {
        next
      }
      # {a,b,a,a}
      # index 1 = pair (a, b), index 2 = pair (a, a)
      id_1x = (a - 1) * n_ptn + b
      id_1y = id_1x + n_ptn * n_ptn
      id_2x = (a - 1) * n_ptn + a
      id_2y = id_2x + n_ptn * n_ptn
      Cov[id_1x, id_2x] = 2 * m_[a] * n_[a] * m_[b] * (m_[a] - 1) / (N_[a] - 2) * V_1[a, a, b]
      Cov[id_1y, id_2y] = 2 * m_[a] * n_[a] * n_[b] * (n_[a] - 1) / (N_[a] - 2) * V_1[a, a, b]
      Cov[id_1x, id_2y] = -2 * m_[a] * n_[a] * (n_[a] - 1) * m_[b] / (N_[a] - 2) * V_1[a, a, b]
      Cov[id_2x, id_1y] = -2 * m_[a] * n_[a] * (m_[a] - 1) * n_[b] / (N_[a] - 2) * V_1[a, a, b]
      # {b,b,a,b}
      # index 1 = pair (b, b), index 2 = pair (a, b)
      id_1x = (b - 1) * n_ptn + b
      id_1y = id_1x + n_ptn * n_ptn
      id_2x = (a - 1) * n_ptn + b
      id_2y = id_2x + n_ptn * n_ptn
      Cov[id_1x, id_2x] = 2 * m_[b] * n_[b] * m_[a] * (m_[b] - 1) / (N_[b] - 2) * V_1[b, b, a]
      Cov[id_1y, id_2y] = 2 * m_[b] * n_[b] * n_[a] * (n_[b] - 1) / (N_[b] - 2) * V_1[b, b, a]
      Cov[id_2x, id_1y] = -2 * m_[b] * n_[b] * (n_[b] - 1) * m_[a] / (N_[b] - 2) * V_1[b, b, a]
      Cov[id_1x, id_2y] = -2 * m_[b] * n_[b] * (m_[b] - 1) * n_[a] / (N_[b] - 2) * V_1[b, b, a]
    }
  }

  ## type 4: {a,b,a,c}
  # Covariances between two cross-pattern statistics sharing exactly one
  # pattern. The guard leaves only strictly ordered triples c < b < a.
  for (a in 1:n_ptn) {
    for (b in 1:a) {
      for (c in 1:b) {
        if (any(a == b, a == c, b == c)) {
          next
        }
        # {a,c,a,b}
        # shared pattern: a
        id_1x = (a - 1) * n_ptn + c
        id_1y = id_1x + n_ptn * n_ptn
        id_2x = (a - 1) * n_ptn + b
        id_2y = id_2x + n_ptn * n_ptn
        Cov[id_1x, id_2x] = m_[a] * n_[a] * m_[b] * m_[c] / (N_[a] - 1) * V_1[a, b, c]
        Cov[id_1y, id_2y] = m_[a] * n_[a] * n_[b] * n_[c] / (N_[a] - 1) * V_1[a, b, c]
        Cov[id_1x, id_2y] = -m_[a] * n_[a] * n_[b] * m_[c] / (N_[a] - 1) * V_1[a, b, c]
        Cov[id_2x, id_1y] = -m_[a] * n_[a] * m_[b] * n_[c] / (N_[a] - 1) * V_1[a, b, c]
        # {b,c,a,b}
        # shared pattern: b
        id_1x = (b - 1) * n_ptn + c
        id_1y = id_1x + n_ptn * n_ptn
        id_2x = (a - 1) * n_ptn + b
        id_2y = id_2x + n_ptn * n_ptn
        Cov[id_1x, id_2x] = m_[b] * n_[b] * m_[a] * m_[c] / (N_[b] - 1) * V_1[b, a, c]
        Cov[id_1y, id_2y] = m_[b] * n_[b] * n_[a] * n_[c] / (N_[b] - 1) * V_1[b, a, c]
        Cov[id_1x, id_2y] = -m_[b] * n_[b] * n_[a] * m_[c] / (N_[b] - 1) * V_1[b, a, c]
        Cov[id_2x, id_1y] = -m_[b] * n_[b] * m_[a] * n_[c] / (N_[b] - 1) * V_1[b, a, c]
        # {b,c,a,c}
        # shared pattern: c
        id_1x = (b - 1) * n_ptn + c
        id_1y = id_1x + n_ptn * n_ptn
        id_2x = (a - 1) * n_ptn + c
        id_2y = id_2x + n_ptn * n_ptn
        Cov[id_1x, id_2x] = m_[c] * n_[c] * m_[a] * m_[b] / (N_[c] - 1) * V_1[c, a, b]
        Cov[id_1y, id_2y] = m_[c] * n_[c] * n_[a] * n_[b] / (N_[c] - 1) * V_1[c, a, b]
        Cov[id_1x, id_2y] = -m_[c] * n_[c] * n_[a] * m_[b] / (N_[c] - 1) * V_1[c, a, b]
        Cov[id_2x, id_1y] = -m_[c] * n_[c] * m_[a] * n_[b] / (N_[c] - 1) * V_1[c, a, b]
      }
    }
  }

  # Keep only the working indices of pairs (a, b) with b <= a, in the order
  # (1,1), (2,1), (2,2), ...; the X entries come first, then the Y entries.
  Lines = c()
  for (a in 1:n_ptn) {
    for (b in 1:a) {
      Lines = append(Lines, (a - 1) * n_ptn + b)
    }
  }
  # mx / my are indexed linearly here; both are symmetric because r_0 is.
  mu = list(mx[Lines], my[Lines])
  Lines = c(Lines, Lines + n_ptn * n_ptn)
  Cov = Cov[Lines, Lines]
  # Mirror the filled triangle; subtract the diagonal so it is not doubled.
  Cov = Cov + t(Cov) - diag(diag(Cov))
  return(list(Cov, mu))
}


#' Covariance and Expectation (Congregated Form)
#'
#' @description
#' Computes the 2×2 covariance matrix and expectation vector (mu) for the congregated BRISE statistic (Ux, Uy), under the pattern-wise permutation null distribution.
#'
#' @param R Numeric symmetric rank matrix (N × N).
#' @param m_ Integer vector. X's sample sizes in each pattern.
#' @param n_ Integer vector. Y's sample sizes in each pattern.
#' @param ptn_list List of integer vectors that indexes observations sharing the same missing pattern.
#'
#' @return List with two elements:
#' \describe{
#'   \item{Cov}{2×2 covariance matrix for (Ux, Uy).}
#'   \item{mu}{Numeric vector length-2 giving expected values of (Ux, Uy) under null.}
#' }
#' @keywords internal
Cov_mu.c <- function(R, m_, n_, ptn_list) {
  # Builds the length-2 mean vector and 2 x 2 covariance matrix of the
  # congregated statistic (Ux, Uy) from block-level moments of the rank
  # matrix R.
  # NOTE(review): the closed-form coefficients below presumably implement the
  # pattern-wise permutation-null moments of the referenced paper; they cannot
  # be verified from this file alone.
  N_ = m_ + n_
  n_ptn = length(m_)
  # r_0[i, j]    : mean of R over block (pattern i, pattern j)
  # Ri.[i, j, s] : row mean of the s-th observation of pattern i over pattern j
  #                (third dimension zero-padded up to max(N_))
  # r_1[i, j, k] : average over pattern i of the product of its row means
  #                towards patterns j and k
  # r_2[i, j]    : mean of R^2 over block (i, j)
  r_0 = matrix(0, nrow = n_ptn, ncol = n_ptn)
  Ri. = array(0, dim = c(n_ptn, n_ptn, max(N_)))
  r_1 = array(0, dim = c(n_ptn, n_ptn, n_ptn))
  r_2 = matrix(0, nrow = n_ptn, ncol = n_ptn)

  # Centered versions: V_1 = r_1 - r_0 * r_0 and V_2 = r_2 - r_0^2.
  V_1 = array(0, dim = c(n_ptn, n_ptn, n_ptn))
  V_2 = matrix(0, nrow = n_ptn, ncol = n_ptn)

  for (i in 1:n_ptn) {
    # Diagonal block: averages use N_i * (N_i - 1) ordered pairs, and row
    # means divide by N_i - 1.
    Rii = R[ptn_list[[i]], ptn_list[[i]], drop = FALSE]
    r_0[i, i] = sum(Rii) / (N_[i] * (N_[i] - 1))
    r_2[i, i] = sum(Rii^2) / (N_[i] * (N_[i] - 1))
    Ri.[i, i, 1:N_[i]] = rowSums(Rii) / ((N_[i] - 1))

    # Off-diagonal blocks j > i. When i == n_ptn the range (i + 1):n_ptn
    # counts downwards, so the guard below skips both of its iterations.
    for (j in (i + 1):n_ptn) {
      if (i == n_ptn) {
        next
      }
      Rij = R[ptn_list[[i]], ptn_list[[j]], drop = FALSE]
      r_0[i, j] = sum(Rij) / (N_[i] * N_[j])
      r_0[j, i] = r_0[i, j]
      r_2[i, j] = sum(Rij^2) / (N_[i] * N_[j])
      r_2[j, i] = r_2[i, j]
      Ri.[i, j, 1:N_[i]] = rowSums(Rij) / (N_[j])
      Rji = R[ptn_list[[j]], ptn_list[[i]], drop = FALSE]
      Ri.[j, i, 1:N_[j]] = rowSums(Rji) / (N_[i])
    }
  }

  for (i in 1:n_ptn) {
    for (j in 1:n_ptn) {
      for (k in 1:n_ptn) {
        # The zero padding of Ri. contributes nothing to the inner product,
        # so this sums over the N_i observations of pattern i only.
        r_1[i, j, k] = Ri.[i, j, ] %*% Ri.[i, k, ] / N_[i]
        V_1[i, j, k] = r_1[i, j, k] - r_0[i, j] * r_0[i, k]
      }
      V_2[i, j] = r_2[i, j] - r_0[i, j]^2
    }
  }

  ## Var(Ux)
  # Accumulated in three passes: single-pattern terms, terms over ordered
  # pairs of distinct patterns, and terms over ordered triples of distinct
  # patterns.
  vx = 0
  for (a in 1:n_ptn) {
    vx = vx + 2 * m_[a] * n_[a] * (m_[a] - 1) / ((N_[a] - 2) * (N_[a] - 3)) *
      ((n_[a] - 1) * V_2[a, a] + 2 * (m_[a] - 2) * (N_[a] - 1) * V_1[a, a, a])
  }

  for (a in 1:n_ptn) {
    for (b in 1:n_ptn) {
      if (a == b) {
        next
      }
      vx = vx + 2 * m_[a] * n_[a] * m_[b] / ((N_[a] - 1) * (N_[b] - 1)) *
        (n_[b] * V_2[a, b] + 4 * (m_[a] - 1) / (N_[a] - 2) * (N_[a] - 1) *
           (N_[b] - 1) * V_1[a, a, b] +
           2 * (m_[b] - 1) * N_[b] * V_1[a, b, b])
    }
  }

  for (a in 1:n_ptn) {
    for (b in 1:n_ptn) {
      for (c in 1:n_ptn) {
        # only triples of three distinct patterns contribute
        if (any(a == b, a == c, b == c)) {
          next
        }
        vx = vx + 4 * m_[a] * n_[a] * m_[b] * m_[c] / (N_[a] - 1) * V_1[a, b, c]
      }
    }
  }

  ## E(Ux)
  # Sum over all blocks of (number of ordered X pairs) * (block mean of R).
  mx = 0
  for (a in 1:n_ptn) {
    mx = mx + m_[a] * (m_[a] - 1) * r_0[a, a]
  }

  for (a in 1:n_ptn) {
    for (b in 1:n_ptn) {
      if (a == b) {
        next
      }
      mx = mx + m_[a] * m_[b] * r_0[a, b]
    }
  }

  ## Var(Uy)
  # Mirrors Var(Ux) with the roles of m_ and n_ exchanged.
  vy = 0
  for (a in 1:n_ptn) {
    vy = vy + 2 * m_[a] * n_[a] * (n_[a] - 1) / ((N_[a] - 2) * (N_[a] - 3)) *
      ((m_[a] - 1) * V_2[a, a] + 2 * (n_[a] - 2) * (N_[a] - 1) * V_1[a, a, a])
  }

  for (a in 1:n_ptn) {
    for (b in 1:n_ptn) {
      if (a == b) {
        next
      }
      vy = vy + 2 * m_[a] * n_[a] * n_[b] / ((N_[a] - 1) * (N_[b] - 1)) *
        (m_[b] * V_2[a, b] + 4 * (n_[a] - 1) / (N_[a] - 2) * (N_[a] - 1) *
           (N_[b] - 1) * V_1[a, a, b] +
           2 * (n_[b] - 1) * N_[b] * V_1[a, b, b])
    }
  }

  for (a in 1:n_ptn) {
    for (b in 1:n_ptn) {
      for (c in 1:n_ptn) {
        # only triples of three distinct patterns contribute
        if (any(a == b, a == c, b == c)) {
          next
        }
        vy = vy + 4 * m_[a] * n_[a] * n_[b] * n_[c] / (N_[a] - 1) * V_1[a, b, c]
      }
    }
  }

  ## E(Uy)
  # Same as E(Ux) with the Y sample sizes n_.
  my = 0
  for (a in 1:n_ptn) {
    my = my + n_[a] * (n_[a] - 1) * r_0[a, a]
  }

  for (a in 1:n_ptn) {
    for (b in 1:n_ptn) {
      if (a == b) {
        next
      }
      my = my + n_[a] * n_[b] * r_0[a, b]
    }
  }

  ## Cov(Ux,Uy)
  # Same three-pass accumulation as the variances.
  vxy = 0
  for (a in 1:n_ptn) {
    vxy = vxy + 2 * m_[a] * n_[a] * (m_[a] - 1) * (n_[a] - 1) / ((N_[a] - 2) * (N_[a] - 3)) *
      (V_2[a, a] - 2 * (N_[a] - 1) * V_1[a, a, a])
  }

  for (a in 1:n_ptn) {
    for (b in 1:n_ptn) {
      if (a == b) {
        next
      }
      vxy = vxy + 2 * m_[a] * n_[a] * m_[b] * n_[b] / ((N_[a] - 1) * (N_[b] - 1)) *
        (V_2[a, b] + 2 * ((m_[a] - 1) / (N_[a] - 2) * (1 - n_[b] / m_[b]) -
                            1) *
           (N_[a] - 1) * (N_[b] - 1) / n_[b] * V_1[a, a, b] -
           2 * N_[b] * V_1[a, b, b])
    }
  }

  for (a in 1:n_ptn) {
    for (b in 1:n_ptn) {
      for (c in 1:n_ptn) {
        # only triples of three distinct patterns contribute
        if (any(a == b, a == c, b == c)) {
          next
        }
        vxy = vxy - 4 * m_[a] * n_[a] * m_[b] * n_[c] / (N_[a] - 1) * V_1[a, b, c]
      }
    }
  }
  # Assemble the symmetric 2 x 2 covariance [[vx, vxy], [vxy, vy]] and the
  # mean vector (E(Ux), E(Uy)).
  Cov = matrix(c(vx, vxy, vxy, vy), nrow = 2)
  mu = c(mx, my)
  return(list(Cov, mu))
}
