#' Mixed Membership Visualization
#'
#' Plots \eqn{\theta}, the parameters which govern the distributions of
#' variables in a mixed membership model. \eqn{\theta_{j,k}} governs the
#' distribution of variable j for sub-population k. The model fit is shown in
#' black, and the comparison (if available) is shown in red.
#'
#' One row of panels is drawn per variable and one column per sub-population.
#' When more variables are requested than fit in \code{nrow} rows, additional
#' pages are started; each page gets its own title and legend.
#'
#' This is the function called by the plot generic function for mixedMemModel
#' objects.
#'
#' @param model the \code{mixedMemModel} object that will be plotted
#' @param compare estimates to compare against. Should be an array with the
#'   same dimensions as \code{model$theta}
#' @param main title of plot
#' @param varNames vector specifying labels for each variable
#' @param groupNames vector specifying labels for each sub-population
#' @param nrow the number of rows in each plot. If the argument is not
#'   specified, all plotted variables will appear in one plot
#' @param fitNames the names of the models plotted
#' @param indices a vector which indicates specific variables to plot. If the
#'   argument is not specified, all variables will be plotted
#' @return Called for its plotting side effect; returns \code{NULL} invisibly.
#'   The graphical parameters \code{oma}, \code{mfrow} and \code{mar} are
#'   restored to their previous values on exit.
vizTheta <- function(model, compare = NULL, main = "Estimated Theta",
                     varNames = NULL, groupNames = NULL, nrow = NULL, fitNames = NULL, indices = NULL) {

  # Padding around the category positions (horizontal) and the unit interval
  # (vertical) so points are not drawn on the panel border.
  h.space <- .25
  v.space <- .1

  if (is.null(varNames)) {
    varNames <- paste("Var", seq_len(model$J))
  }

  if (is.null(groupNames)) {
    groupNames <- paste("Group", seq_len(model$K))
  }

  if (is.null(indices)) {
    indices <- seq_len(model$J)
  }
  n_vars <- length(indices)
  if (n_vars == 0L) {
    stop("'indices' must select at least one variable", call. = FALSE)
  }

  # Default to one page holding exactly the variables being plotted, so that a
  # subset of variables does not leave empty rows on the page.
  if (is.null(nrow)) {
    nrow <- n_vars
  }

  if (is.null(fitNames)) {
    fitNames <- if (is.null(compare)) paste("Model", 1) else paste("Model", 1:2)
  }

  # Sets the panel grid for one page; setting mfrow also restarts the page.
  page_layout <- function() {
    par(oma = c(3, 5, 3, 1), mfrow = c(nrow, model$K), mar = rep(.1, 4))
  }

  # Adds the title and legend to the current page, then re-arms the panel grid
  # so that the next panel (if any) starts a fresh page.
  finish_page <- function() {
    title(main = main, outer = TRUE, cex = 1.2)
    # Overlay an empty full-device plot so the legend can sit at the bottom of
    # the whole page rather than inside a single panel.
    par(fig = c(0, 1, 0, 1), oma = c(0, 5, 0, 1), mar = rep(0, 4), new = TRUE)
    plot(0, 0, type = "n", bty = "n", xaxt = "n", yaxt = "n")
    if (is.null(compare)) {
      legend("bottom", legend = fitNames, pch = 19,
             col = "black", cex = .8)
    } else {
      legend("bottom", legend = fitNames,
             pch = c(19, 4), col = c("black", "red"), ncol = 2, cex = .8)
    }
    page_layout()
  }

  # par() returns the previous values of the parameters it sets; restore them
  # when the function exits (normally or through an error).
  old_par <- page_layout()
  on.exit(par(old_par), add = TRUE)

  count <- 0L
  for (j in indices) {
    count <- count + 1L
    is_bernoulli <- model$dist[j] == "bernoulli"

    # Variables with any other distribution are not drawn, but they still
    # count towards the page position (same as the historical behaviour).
    if (is_bernoulli || model$dist[j] %in% c("multinomial", "rank")) {
      # Only the first n_cat entries of theta[j, k, ] are meaningful for this
      # variable; a Bernoulli variable has a single parameter.
      n_cat <- if (is_bernoulli) 1 else model$Vj[j]
      cats <- seq_len(n_cat)
      # Group labels go under the last row of the page: either the page is
      # full or this is the last variable to be plotted.
      last_row <- count == n_vars || (count %% nrow) == 0

      for (k in seq_len(model$K)) {
        plot(model$theta[j, k, cats], type = "p", lwd = 2, col = "black",
             ylim = c(-v.space, 1 + v.space), xlim = c(h.space, n_cat + h.space),
             yaxt = "n", xaxt = "n", pch = 16, xlab = "", ylab = "")
        if (!is.null(compare)) {
          points(cats, compare[j, k, cats], col = "red", pch = 4, lwd = 1.5)
        }
        if (k == 1) {
          mtext(varNames[j], line = 3, side = 2, cex = .7)
          axis(side = 2, at = c(0, .5, 1), labels = c(0, .5, 1))
        }
        if (last_row) {
          mtext(groupNames[k], line = .2, side = 1, cex = .8)
        }
      }
    }

    if ((count %% nrow) == 0) {
      finish_page()
    }
  }

  # A partially filled last page has not been decorated yet.
  if ((count %% nrow) != 0) {
    finish_page()
  }

  invisible(NULL)
}


#' Mixed Membership Visualization
#'
#' Plots estimates for individual group membership. The estimates used are the
#' normalized phi, which are the posterior means from the variational
#' distribution. The estimated model is shown in black, and the comparison (if
#' available) is shown in red.
#'
#' One panel is drawn per individual, laid out in an \code{nrow} by
#' \code{ncol} grid. When more individuals are requested than fit on one page,
#' additional pages are started; each page gets its own title, axis labels and
#' legend.
#'
#' This is the function called by the plot generic function for mixedMemModel
#' objects.
#'
#' @param model the \code{mixedMemModel} object that will be plotted
#' @param compare estimates to compare against. This should be a matrix with
#'   the same dimensions as \code{model$phi}
#' @param main title of plot
#' @param nrow the number of rows in each plot. Defaults to the floor of the
#'   square root of the number of individuals plotted
#' @param ncol the number of columns in each plot. Defaults to the number of
#'   columns needed to fit all plotted individuals in \code{nrow} rows
#' @param indices the specific individuals to plot. If the argument is left
#'   blank, all individuals will be plotted
#' @param groupNames vector specifying labels for each sub-population
#' @param fitNames the names of the models plotted
#' @return Called for its plotting side effect; returns \code{NULL} invisibly.
#'   The graphical parameters \code{oma}, \code{mfrow} and \code{mar} are
#'   restored to their previous values on exit.
#' @export
vizMem <- function(model, compare = NULL, main = "Estimated Membership",
                   nrow = NULL, ncol = NULL,
                   indices = NULL, groupNames = NULL, fitNames = NULL) {
  # Plotting symbols (one per group) and the padding around the group
  # positions (horizontal) and the unit interval (vertical).
  pch.list <- 0:14
  h.space <- .25
  v.space <- .1

  if (is.null(indices)) {
    indices <- seq_len(model$Total)
  }
  if (length(indices) == 0L) {
    stop("'indices' must select at least one individual", call. = FALSE)
  }

  if (is.null(nrow)) {
    nrow <- floor(sqrt(length(indices)))
  }

  if (is.null(ncol)) {
    ncol <- ceiling(length(indices) / nrow)
  }

  if (is.null(groupNames)) {
    groupNames <- paste("Group", seq_len(model$K))
  }

  if (is.null(fitNames)) {
    fitNames <- if (is.null(compare)) paste("Model", 1) else paste("Model", 1:2)
  }

  # Normalise each individual's phi so the memberships sum to one.
  mem.est <- model$phi / rowSums(model$phi)

  groups <- seq_len(model$K)
  # One symbol per group, recycled when there are more groups than symbols so
  # that the legend matches what the panels draw.
  group_pch <- rep_len(pch.list, model$K)
  per_page <- nrow * ncol

  # Sets the panel grid for one page; setting mfrow also restarts the page.
  page_layout <- function() {
    par(oma = c(3, 5, 3, 1), mfrow = c(nrow, ncol), mar = rep(.1, 4))
  }

  # Adds title, outer axis labels and legend to the current page, then re-arms
  # the panel grid so that the next panel (if any) starts a fresh page.
  finish_page <- function() {
    title(main = main, outer = TRUE, cex = 1.2)
    mtext("Group Membership", side = 2, outer = TRUE, line = 3)
    mtext("Groups", side = 1, outer = TRUE, line = 0)

    # Overlay an empty full-device plot so the legend can sit at the bottom of
    # the whole page rather than inside a single panel.
    par(fig = c(0, 1, 0, 1), oma = c(0, 5, 0, 1), mar = rep(0, 4), new = TRUE)
    plot(0, 0, type = "n", bty = "n", xaxt = "n", yaxt = "n")

    if (is.null(compare)) {
      legend("bottom", legend = paste(fitNames[1], groupNames), pch = group_pch,
             col = rep("black", model$K), cex = .8, ncol = model$K)
    } else {
      legend("bottom",
             legend = c(paste(fitNames[1], groupNames), paste(fitNames[2], groupNames)),
             pch = rep(group_pch, times = 2),
             col = rep(c("black", "red"), each = model$K), ncol = model$K, cex = .8)
    }
    page_layout()
  }

  # par() returns the previous values of the parameters it sets; restore them
  # when the function exits (normally or through an error).
  old_par <- page_layout()
  on.exit(par(old_par), add = TRUE)

  count <- 0L
  for (i in indices) {
    count <- count + 1L

    plot(mem.est[i, ], type = "p", lwd = 2, col = "black",
         ylim = c(-v.space, 1 + v.space), xlim = c(h.space, model$K + h.space),
         yaxt = "n", xaxt = "n", pch = group_pch)
    # Label the panel with the individual's index.
    text(0, 1, labels = i, cex = .9, adj = c(0, 1), pos = 4)

    # Panels are filled row by row, so the first column is every ncol-th
    # panel starting from the first. (Testing `count %% ncol == 1` would never
    # match when ncol is 1.)
    if (((count - 1L) %% ncol) == 0) {
      axis(2, at = c(0, .5, 1), labels = c(0, .5, 1), cex = 1)
    }

    if (!is.null(compare)) {
      points(groups, compare[i, ], col = "red", pch = group_pch)
    }

    if ((count %% per_page) == 0) {
      finish_page()
    }
  }

  # A partially filled last page has not been decorated yet.
  if ((count %% per_page) != 0) {
    finish_page()
  }

  invisible(NULL)
}
