
### Restricted mean survival time based on a hazard (or cumulative hazard) function
#rmstFit = function(tau, h0 = NULL, H0 = function(x){x}) {
#  rms = integrate(psurv, 0, tau, h0=h0, H0 = H0, low.tail = FALSE)$value
#  return(rms)
#}

rmst.deepSurv = function(object, newdata = NULL, risk = NULL, tau = NULL, ...) {
  ### Restricted mean survival time (RMST) for a fitted deepSurv model.
  ###
  ### Arguments:
  ###   object  - fitted deepSurv object; survfit(object) must supply $time
  ###             and $cumhaz, and object$model is the network passed to
  ###             predict() when `newdata` is given.
  ###   newdata - optional covariates. A matrix is used as-is, a data frame
  ###             (or other object with dims) goes through as.matrix(), and a
  ###             plain vector is treated as ONE observation (a single row).
  ###   risk    - optional numeric vector of relative risks; at most one of
  ###             `newdata` and `risk` may be non-NULL.
  ###   tau     - truncation time; defaults to max(survfit(object)$time).
  ###   ...     - unused; kept for S3 method compatibility.
  ###
  ### Returns:
  ###   - neither newdata nor risk: rmst(object$y, tau = tau), the RMST of
  ###     the observed response;
  ###   - otherwise: a numeric vector with one RMST per risk value (or per
  ###     row of newdata), from the cumulative hazard risk * H0(t).
  sfit = survfit(object)
  time = sfit$time
  chaz = sfit$cumhaz
  if (is.null(tau)) tau = max(time)

  ### RMST for a single relative risk r: H(t) = r * H0(t), with H0 linearly
  ### interpolated between the survfit time points.
  ### NOTE(review): rule = 2 holds H0 at chaz[1] for t < time[1] instead of
  ### starting from 0 at t = 0; confirm this is intended.
  rmsfunRisk = function(r) {
    Ht = approxfun(time, chaz * r, rule = 2)
    rmstFit(tau, H0 = Ht)
  }
  ### vapply guarantees a numeric vector (numeric(0) for empty input),
  ### unlike sapply, which returns list() on empty input.
  rmstForRisks = function(rs) vapply(rs, rmsfunRisk, numeric(1))

  ### no new data: use the supplied risks, or the observed response
  if (is.null(newdata)) {
    if (is.null(risk)) return(rmst(object$y, tau = tau))
    return(rmstForRisks(risk))
  }

  ### new data provided
  if (!is.null(risk)) {
    stop("Only one of the newdata or the risk can be non-NULL.", call. = FALSE)
  }

  mdl = object$model
  ### Number of input features. unlist() + last element leaves a scalar
  ### unchanged and also handles a shape stored as list(NULL, p).
  ### NOTE(review): confirm the form of mdl$input_shape for this model type.
  p = unlist(mdl$input_shape)
  p = if (length(p) > 0L) p[[length(p)]] else NULL

  if (!is.matrix(newdata)) {
    if (is.null(dim(newdata))) {
      ### as.matrix() on a bare vector gives a one-COLUMN matrix, so build
      ### a single row explicitly.
      newdata = matrix(newdata, nrow = 1L)
    } else {
      newdata = as.matrix(newdata)
    }
  }
  ### validate after conversion so every input form is checked
  if (!is.null(p) && ncol(newdata) != p) {
    stop("newdata must be a matrix with ", p, " columns.", call. = FALSE)
  }

  risk = exp(predict(mdl, newdata))
  rmstForRisks(risk)
}

### example
#h = function(t) ht0 = exp(-5 + 2.5*cos(pi*t)+0.5*t^0.3)
#print(qsurv(0.5))
# H1 = function(x) x^3
# qsurv(seq(0.1, 0.9, 0.2), H0 = H1) ### should match Weibull(shape = 3, scale = 1) quantiles:
# qweibull(seq(0.1, 0.9, 0.2), 3)
