% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/hardhat.R
\name{tabnet_fit}
\alias{tabnet_fit}
\alias{tabnet_fit.default}
\alias{tabnet_fit.data.frame}
\alias{tabnet_fit.formula}
\alias{tabnet_fit.recipe}
\alias{tabnet_fit.Node}
\title{Tabnet model}
\usage{
tabnet_fit(x, ...)

\method{tabnet_fit}{default}(x, ...)

\method{tabnet_fit}{data.frame}(
  x,
  y,
  tabnet_model = NULL,
  config = tabnet_config(),
  ...,
  from_epoch = NULL,
  weights = NULL
)

\method{tabnet_fit}{formula}(
  formula,
  data,
  tabnet_model = NULL,
  config = tabnet_config(),
  ...,
  from_epoch = NULL,
  weights = NULL
)

\method{tabnet_fit}{recipe}(
  x,
  data,
  tabnet_model = NULL,
  config = tabnet_config(),
  ...,
  from_epoch = NULL,
  weights = NULL
)

\method{tabnet_fit}{Node}(
  x,
  tabnet_model = NULL,
  config = tabnet_config(),
  ...,
  from_epoch = NULL
)
}
\arguments{
\item{x}{Depending on the context:
\itemize{
\item A \strong{data frame} of predictors.
\item A \strong{matrix} of predictors.
\item A \strong{recipe} specifying a set of preprocessing steps
created from \code{\link[recipes:recipe]{recipes::recipe()}}.
\item A \strong{Node} where tree will be used as hierarchical outcome,
and attributes will be used as predictors.
}

The predictor data should be standardized (e.g. centered or scaled).
The model handles categorical predictors internally, so you don't need to
apply any specific treatment to them.
The model handles missing values internally, so you don't need to apply any
specific treatment to them either.}

\item{...}{Model hyperparameters.
Any hyperparameters set here will update those set by the config argument.
See \code{\link[=tabnet_config]{tabnet_config()}} for a list of all possible hyperparameters.}

\item{y}{When \code{x} is a \strong{data frame} or \strong{matrix}, \code{y} is the outcome
specified as:
\itemize{
\item A \strong{data frame} with one or more numeric columns (regression) or one or more categorical columns (classification).
\item A \strong{matrix} with 1 column.
\item A \strong{vector}, either numeric or categorical.
}}

\item{tabnet_model}{A previously fitted \code{tabnet_model} object to continue the fitting on.
If \code{NULL} (the default), a brand new model is initialized.}

\item{config}{A set of hyperparameters created using the \code{tabnet_config} function.
If no argument is supplied, this will use the default values in \code{\link[=tabnet_config]{tabnet_config()}}.}

\item{from_epoch}{When a \code{tabnet_model} is provided, restore the network weights from a specific epoch.
Defaults to the last available checkpoint for a restored model, or the last epoch for an in-memory model.}

\item{weights}{Unused. Placeholder for \code{hardhat::importance_weight()} variables.}

\item{formula}{A formula specifying the outcome terms on the left-hand side,
and the predictor terms on the right-hand side.}

\item{data}{When a \strong{recipe} or \strong{formula} is used, \code{data} is specified as:
\itemize{
\item A \strong{data frame} containing both the predictors and the outcome.
}}
}
\value{
A TabNet model object. It can be used for serialization, predictions, or further fitting.
}
\description{
Fits the \href{https://arxiv.org/abs/1908.07442}{TabNet: Attentive Interpretable Tabular Learning} model.
}
\section{Fitting a pre-trained model}{


When providing a parent \code{tabnet_model} parameter, the model fitting resumes from that model's weights
at the following epoch:
\itemize{
\item the last fitted epoch, for a model already in the torch context;
\item the last model checkpoint epoch, for a model loaded from file;
\item the epoch of the checkpoint matching or preceding the \code{from_epoch} value, if provided.
}
The model fitting metrics are appended on top of the parent metrics in the returned TabNet model.
}

\section{Multi-outcome}{


TabNet allows multi-outcome prediction, which is usually named \href{https://en.wikipedia.org/wiki/Multi-label_classification}{multi-label classification}
or multi-output regression when outcomes are numerical.
Multi-outcome fitting currently expects outcomes to be either all numeric or all categorical.
}

\section{Threading}{


TabNet uses \code{torch} as its backend for computation and \code{torch} uses all
available threads by default.

You can control the number of threads used by \code{torch} with:

\if{html}{\out{<div class="sourceCode">}}\preformatted{torch::torch_set_num_threads(1)
torch::torch_set_num_interop_threads(1)
}\if{html}{\out{</div>}}
}

\examples{
\dontshow{if ((torch::torch_is_installed() && require("modeldata"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf}
\dontrun{
data("ames", package = "modeldata")
data("attrition", package = "modeldata")

## Single-outcome regression using formula specification
fit <- tabnet_fit(Sale_Price ~ ., data = ames, epochs = 4)

## Single-outcome classification using data-frame specification:
## keep every row and every column except the `Attrition` outcome as predictors.
attrition_x <- attrition[, names(attrition) != "Attrition"]
fit <- tabnet_fit(attrition_x, attrition$Attrition, epochs = 4, verbose = TRUE)

## Multi-outcome regression on `Sale_Price` and `Pool_Area` in `ames` dataset using formula
ames_fit <- tabnet_fit(Sale_Price + Pool_Area ~ ., data = ames, epochs = 4, valid_split = 0.2)

## Multi-label classification on `Attrition` and `JobSatisfaction` in
## `attrition` dataset using recipe
library(recipes)
rec <- recipe(Attrition + JobSatisfaction ~ ., data = attrition) \%>\%
  step_normalize(all_numeric(), -all_outcomes())

attrition_fit <- tabnet_fit(rec, data = attrition, epochs = 4, valid_split = 0.2)

## Hierarchical classification on `acme`
## (requires the `data.tree` package to be installed)
data(acme, package = "data.tree")

acme_fit <- tabnet_fit(acme, epochs = 4, verbose = TRUE)

# Note: Model's number of epochs should be increased for publication-level results.
}
\dontshow{\}) # examplesIf}
}
