Title: | Higher Level 'API' for 'torch' |
---|---|
Description: | A high level interface for 'torch' providing utilities to reduce the amount of code needed for common tasks, abstract away torch details and make the same code work on both the 'CPU' and 'GPU'. It's flexible enough to support expressing a large range of models. It's heavily inspired by 'fastai' by Howard et al. (2020) <arXiv:2002.04688>, 'Keras' by Chollet et al. (2015) and 'PyTorch Lightning' by Falcon et al. (2019) <doi:10.5281/zenodo.3828935>. |
Authors: | Daniel Falbel [aut, cre, cph], Christophe Regouby [ctb], RStudio [cph] |
Maintainer: | Daniel Falbel <[email protected]> |
License: | MIT + file LICENSE |
Version: | 0.4.0.9002 |
Built: | 2025-01-15 04:03:06 UTC |
Source: | https://github.com/mlverse/luz |
Create an accelerator
accelerator( device_placement = TRUE, cpu = FALSE, cuda_index = torch::cuda_current_device() )
device_placement |
(logical) whether the |
cpu |
(logical) whether the training procedure should run on the CPU. |
cuda_index |
(integer) index of the CUDA device to use if multiple GPUs are available. Default: the result of torch::cuda_current_device(). |
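As a minimal sketch of how the arguments above might be used, the following forces training onto the CPU by passing an accelerator to fit(). The data, the model and the hyperparameter values are illustrative assumptions and are not part of the original documentation.

```r
if (torch::torch_is_installed()) {
  library(torch)
  library(luz)

  # Toy regression data (illustrative only).
  x <- torch_randn(100, 10)
  y <- torch_randn(100, 1)

  model <- nn_linear %>%
    setup(loss = nnf_mse_loss, optimizer = optim_sgd) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  # cpu = TRUE asks luz to run the training procedure on the CPU,
  # even if a CUDA device is available.
  fitted <- model %>% fit(
    list(x, y),
    epochs = 1,
    accelerator = accelerator(cpu = TRUE),
    verbose = FALSE
  )
}
```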
as_dataloader is used internally by luz to convert input data and valid_data as passed to fit.luz_module_generator() to a torch::dataloader
as_dataloader(x, ...)

## S3 method for class 'dataset'
as_dataloader(x, ..., batch_size = 32)

## S3 method for class 'iterable_dataset'
as_dataloader(x, ..., batch_size = 32)

## S3 method for class 'list'
as_dataloader(x, ...)

## S3 method for class 'dataloader'
as_dataloader(x, ...)

## S3 method for class 'matrix'
as_dataloader(x, ...)

## S3 method for class 'numeric'
as_dataloader(x, ...)

## S3 method for class 'array'
as_dataloader(x, ...)

## S3 method for class 'torch_tensor'
as_dataloader(x, ...)
x |
the input object. |
... |
Passed to |
batch_size |
(int, optional): how many samples per batch to load
(default: |
as_dataloader methods should have sensible defaults for batch_size, parallel workers, etc.
It allows users to quickly experiment with fit.luz_module_generator() by not requiring them to create a torch::dataset and a torch::dataloader in simple experiments.
as_dataloader(dataset): Converts a torch::dataset() to a torch::dataloader().
as_dataloader(iterable_dataset): Converts a torch::iterable_dataset() into a torch::dataloader()
as_dataloader(list): Converts a list of tensors or arrays with the same size in the first dimension to a torch::dataloader()
as_dataloader(dataloader): Returns the same dataloader
as_dataloader(matrix): Converts the matrix to a dataloader
as_dataloader(numeric): Converts the numeric vector to a dataloader
as_dataloader(array): Converts the array to a dataloader
as_dataloader(torch_tensor): Converts the tensor to a dataloader
You can implement your own as_dataloader S3 method if you want your data structure to be automatically supported by luz's fit.luz_module_generator().
The method must satisfy the following conditions:
The method should return a torch::dataloader().
The only required argument is x. All other arguments must have good defaults.
It's better to avoid implementing as_dataloader methods for common S3 classes like data.frames. In this case, it's better to assign a different class to the inputs and implement as_dataloader for it.
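The conditions above can be sketched as follows. The class name my_table and the helper as_my_table() are hypothetical, introduced only for illustration; the method wraps a data.frame in a dedicated class instead of defining as_dataloader for data.frame itself.

```r
if (torch::torch_is_installed()) {
  library(torch)
  library(luz)

  # Hypothetical wrapper class around a data.frame ("my_table" is not part of luz).
  as_my_table <- function(df, target) {
    structure(list(df = df, target = target), class = "my_table")
  }

  # x is the only required argument; batch_size has a default.
  as_dataloader.my_table <- function(x, ..., batch_size = 32) {
    feature_cols <- setdiff(names(x$df), x$target)
    features <- unname(as.matrix(x$df[, feature_cols, drop = FALSE]))
    target <- unname(as.matrix(x$df[, x$target, drop = FALSE]))
    ds <- tensor_dataset(
      torch_tensor(features, dtype = torch_float()),
      torch_tensor(target, dtype = torch_float())
    )
    # The method must return a torch::dataloader().
    dataloader(ds, batch_size = batch_size, ...)
  }

  tbl <- as_my_table(mtcars, target = "mpg")
  dl <- as_dataloader(tbl, batch_size = 8)
}
```

An object of this class could then be passed directly as the data argument of fit.luz_module_generator(), since luz calls as_dataloader on its inputs.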
Context object storing information about the model training context. See also ctx.
buffers
A list of buffers that callbacks can use to write temporary information into ctx.
records
Stores information about values logged with self$log.
device
allows querying the current accelerator device
callbacks
list of callbacks that will be called.
iter
current iteration
batch
The current batch data: a list with input data and targets.
input
a shortcut for ctx$batch[[1]]
target
a shortcut for ctx$batch[[2]]
min_epochs
the minimum number of epochs that the model will run on.
max_epochs
the maximum number of epochs that the model will run.
hparams
a list of hyperparameters that were used to initialize ctx$model.
opt_hparams
a list of hyperparameters used to initialize the ctx$optimizers.
train_data
a dataloader that is used for training the model
valid_data
a dataloader used during model validation
accelerator
an accelerator() used to move data, model, etc. to the correct device.
optimizers
a named list of optimizers that will be used during model training.
verbose
bool indicating whether the process is in verbose mode or not.
handlers
List of error handlers that can be used. See rlang::try_fetch() for more info.
epoch_handlers
List of error handlers that can be used. See rlang::try_fetch() for more info.
training
A bool indicating if the model is in training or validation mode.
model
The model being trained.
pred
Last predicted values.
opt
Current optimizer.
opt_name
Current optimizer name.
data
Current dataloader in use.
loss_fn
Loss function used to train the model
loss
Last computed loss values. Detached from the graph.
loss_grad
Last computed loss value, not detached, so you can do additional transformations.
epoch
Current epoch.
metrics
List of metrics that are tracked by the process.
step_opt
Defines how step is called for the optimizer. It must be a function taking an optimizer as argument.
new()
Initializes the context object with minimal necessary information.
context$new(verbose, accelerator, callbacks, training)
verbose
Whether the context should be in verbose mode or not.
accelerator
A luz accelerator() that configures device placement and others.
callbacks
A list of callbacks used by the model. See luz_callback().
training
A boolean that indicates if the context is in training mode or not.
log()
Allows logging arbitrary information in the ctx.
context$log(what, set, value, index = NULL, append = TRUE)
what
(string) What you are logging.
set
(string) Usually 'train' or 'valid' indicating the set you want to log to. But can be arbitrary info.
value
Arbitrary value to log.
index
Index at which this value should be logged. If NULL the value is added to the end of the list, otherwise the index is used.
append
If TRUE and a value in the corresponding index already exists, then value is appended to the current value. If FALSE the existing value is overwritten in favor of the new value.
log_metric()
Log a metric by its name and value. Metric values are indexed by epoch.
context$log_metric(name, value)
name
name of the metric
value
Arbitrary value to log.
get_log()
Get a specific value from the log.
context$get_log(what, set, index = NULL)
what
(string) What you are logging.
set
(string) Usually 'train' or 'valid' indicating the set you want to log to. But can be arbitrary info.
index
Index that this value should be logged. If NULL the value is added to the end of list, otherwise the index is used.
get_metrics()
Get all metric given an epoch and set.
context$get_metrics(set, epoch = NULL)
set
(string) Usually 'train' or 'valid' indicating the set you want to log to. But can be arbitrary info.
epoch
The epoch you want to extract metrics from.
get_metric()
Get the value of a metric given its name, epoch and set.
context$get_metric(name, set, epoch = NULL)
name
name of the metric
set
(string) Usually 'train' or 'valid' indicating the set you want to log to. But can be arbitrary info.
epoch
The epoch you want to extract metrics from.
get_formatted_metrics()
Get formatted metrics values
context$get_formatted_metrics(set, epoch = NULL)
set
(string) Usually 'train' or 'valid' indicating the set you want to log to. But can be arbitrary info.
epoch
The epoch you want to extract metrics from.
get_metrics_df()
Get a data.frame containing all metrics.
context$get_metrics_df()
set_verbose()
Allows setting the verbose attribute.
context$set_verbose(verbose = NULL)
verbose
boolean. If TRUE verbose mode is used. If FALSE non verbose. If NULL we use the result of interactive().
clean()
Removes unnecessary information from the context object.
context$clean()
call_callbacks()
Call the selected callbacks, where name is the callback type to call, e.g. 'on_epoch_begin'.
context$call_callbacks(name)
name
name of the callback type to call, e.g. 'on_epoch_begin'
state_dict()
Returns a list containing minimal information from the context. Used to create the returned values.
context$state_dict()
unsafe_set_records()
Are you sure you know what you are doing?
context$unsafe_set_records(records)
records
New set of records to be set.
clone()
The objects of this class are cloneable with this method.
context$clone(deep = FALSE)
deep
Whether to make a deep clone.
Context objects used in luz to share information between model methods, metrics and callbacks.
The ctx object is used in luz to share information between the training loop and callbacks, model methods, and metrics. The table below describes information available in the ctx by default. Other callbacks could potentially modify these attributes or add new ones.
Attribute | Description |
verbose | The value (TRUE or FALSE) attributed to the verbose argument in fit. |
accelerator | Accelerator object used to query the correct device to place models, data, etc. It assumes the value passed to the accelerator parameter in fit. |
model | Initialized nn_module object that will be trained during the fit procedure. |
optimizers | A named list of optimizers used during training. |
data | The currently in-use dataloader. When training it’s ctx$train_data, when doing validation it’s ctx$valid_data. It can also be the prediction dataset when in predict. |
train_data | Dataloader passed to the data argument in fit. Modified to yield data in the selected device. |
valid_data | Dataloader passed to the valid_data argument in fit. Modified to yield data in the selected device. |
min_epochs | Minimum number of epochs the model will be trained for. |
max_epochs | Maximum number of epochs the model will be trained for. |
epoch | Current training epoch. |
iter | Current training iteration. It’s reset every epoch and when going from training to validation. |
training | Whether the model is in training or validation mode. See also help("luz_callback_train_valid") |
callbacks | List of callbacks that will be called during the training procedure. It’s the union of the list passed to the callbacks parameter and the default callbacks. |
step | Closure that will be used to do one step of the model. It’s used for both training and validation. Takes no argument, but can access the ctx object. |
call_callbacks | Call callbacks by name. For example call_callbacks("on_train_begin") will call all callbacks that provide methods for this point. |
batch | Last batch obtained by the dataloader. A batch is a list() with 2 elements, one that is used as input and the other as target. |
input | First element of the last batch obtained by the current dataloader. |
target | Second element of the last batch obtained by the current dataloader. |
pred | Last predictions obtained by ctx$model$forward. Note: can be potentially modified by previously run callbacks. Also note that this might not be available if you used a custom training step. |
loss_fn | The active loss function that will be minimized during training. |
loss | Last computed loss from the model. Note: this might not be available if you modified the training or validation step. |
opt | Current optimizer, i.e. the optimizer that will be used to do the next step to update parameters. |
opt_nm | Current optimizer name. By default it’s opt, but can change if your model uses more than one optimizer depending on the set of parameters being optimized. |
metrics | list() with current metric objects that are updated at every on_train_batch_end() or on_valid_batch_end(). See also help("luz_callback_metrics") |
records | list() recording metric values for training and validation for each epoch. See also help("luz_callback_metrics"). Also records profiling metrics. See help("luz_callback_profile") for more information. |
handlers | A named list() of handlers that is passed to rlang::with_handlers() during the training loop and can be used to handle errors or conditions that might be raised by other callbacks. |
epoch_handlers | A named list of handlers that is used with rlang::with_handlers(). Those handlers are used inside the epochs loop, thus you can handle epoch specific conditions, that won’t necessarily end training. |
Context attributes
Context object: context
Evaluates a fitted model on a dataset
evaluate(
  object,
  data,
  ...,
  metrics = NULL,
  callbacks = list(),
  accelerator = NULL,
  verbose = NULL,
  dataloader_options = NULL
)
object |
A fitted model to evaluate. |
data |
(dataloader, dataset or list) A dataloader created with
|
... |
Currently unused. |
metrics |
A list of luz metrics to be tracked during evaluation. If |
callbacks |
(list, optional) A list of callbacks defined with
|
accelerator |
(accelerator, optional) An optional |
verbose |
(logical, optional) An optional boolean value indicating if
the fitting procedure should emit output to the console during training.
By default, it will produce output if |
dataloader_options |
Options used when creating a dataloader. See
|
Once a model has been trained you might want to evaluate its performance on a different dataset. For that reason, luz provides the ?evaluate function that takes a fitted model and a dataset and computes the metrics attached to the model.
Evaluate returns a luz_module_evaluation object that you can query for metrics using the get_metrics function or simply print to see the results.
For example:
evaluation <- fitted %>% evaluate(data = valid_dl)
metrics <- get_metrics(evaluation)
print(evaluation)
## A `luz_module_evaluation`
## -- Results ---------------------------------------------------------------------
## loss: 1.5146
## mae: 1.0251
## mse: 1.5159
## rmse: 1.2312
Other training: fit.luz_module_generator(), predict.luz_module_fitted(), setup()
nn_module
Fit a nn_module
## S3 method for class 'luz_module_generator'
fit(
  object,
  data,
  epochs = 10,
  callbacks = NULL,
  valid_data = NULL,
  accelerator = NULL,
  verbose = NULL,
  ...,
  dataloader_options = NULL
)
object |
An |
data |
(dataloader, dataset or list) A dataloader created with
|
epochs |
(int) The maximum number of epochs for training the model. If a
single value is provided, this is taken to be the |
callbacks |
(list, optional) A list of callbacks defined with
|
valid_data |
(dataloader, dataset, list or scalar value; optional) A
dataloader created with |
accelerator |
(accelerator, optional) An optional |
verbose |
(logical, optional) An optional boolean value indicating if
the fitting procedure should emit output to the console during training.
By default, it will produce output if |
... |
Currently unused. |
dataloader_options |
Options used when creating a dataloader. See
|
A fitted object that can be saved with luz_save() and can be printed with print() and plotted with plot().
predict.luz_module_fitted() for how to create predictions.
setup() to find out how to create modules that can be trained with fit.
Other training: evaluate(), predict.luz_module_fitted(), setup()
Get metrics from the object
get_metrics(object, ...)

## S3 method for class 'luz_module_fitted'
get_metrics(object, ...)
object |
The object to query for metrics. |
... |
Currently unused. |
A data.frame containing the metric values.
get_metrics(luz_module_fitted): Extract metrics from a luz fitted model.
Learning Rate Finder
lr_finder( object, data, steps = 100, start_lr = 1e-07, end_lr = 0.1, log_spaced_intervals = TRUE, ..., verbose = NULL )
object |
An nn_module that has been setup(). |
data |
(dataloader) A dataloader created with torch::dataloader() used for learning rate finding. |
steps |
(integer) The number of steps to iterate over in the learning rate finder. Default: 100. |
start_lr |
(float) The smallest learning rate. Default: 1e-7. |
end_lr |
(float) The highest learning rate. Default: 1e-1. |
log_spaced_intervals |
(logical) Whether to divide the range between start_lr and end_lr into log-spaced intervals (alternative: uniform intervals). Default: TRUE |
... |
Other arguments passed to |
verbose |
Whether to show a progress bar during the process. |
A dataframe with two columns: learning rate and loss
if (torch::torch_is_installed()) {
  library(torch)

  ds <- torch::tensor_dataset(x = torch_randn(100, 10), y = torch_randn(100, 1))
  dl <- torch::dataloader(ds, batch_size = 32)

  model <- torch::nn_linear
  model <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam
    ) %>%
    set_hparams(in_features = 10, out_features = 1)

  records <- lr_finder(model, dl, verbose = FALSE)
  plot(records)
}
Create a new callback
luz_callback( name = NULL, ..., private = NULL, active = NULL, parent_env = parent.frame(), inherit = NULL )
name |
name of the callback |
... |
Public methods of the callback. The name of the methods is used to know how they should be called. See the details section. |
private |
An optional list of private members, which can be functions and non-functions. |
active |
An optional list of active binding functions. |
parent_env |
An environment to use as the parent of newly-created objects. |
inherit |
A R6ClassGenerator object to inherit from; in other words, a
superclass. This is captured as an unevaluated expression which is
evaluated in |
Let’s implement a callback that prints ‘Iteration n’ (where n is the iteration number) for every batch in the training set and ‘Done’ when an epoch is finished. For that task we use the luz_callback function:
print_callback <- luz_callback(
  name = "print_callback",
  initialize = function(message) {
    self$message <- message
  },
  on_train_batch_end = function() {
    cat("Iteration ", ctx$iter, "\n")
  },
  on_epoch_end = function() {
    cat(self$message, "\n")
  }
)
luz_callback() takes named functions as ... arguments, where the name indicates the moment at which the callback should be called. For instance on_train_batch_end() is called at the end of every training batch, and on_epoch_end() is called at the end of every epoch.
The returned value of luz_callback() is a function that initializes an instance of the callback. Callbacks can have initialization parameters, like the name of a file where you want to log the results. In that case, you can pass an initialize method when creating the callback definition, and save these parameters to the self object. In the above example, the callback has a message parameter that is printed at the end of each epoch.
Once a callback is defined it can be passed to the fit function via the callbacks parameter:
fitted <- net %>%
  setup(...) %>%
  fit(..., callbacks = list(
    print_callback(message = "Done!")
  ))
Callbacks can be called in many different positions of the training loop, including combinations of them. Here’s an overview of possible callback breakpoints:
Start Fit
   - on_fit_begin
  Start Epoch Loop
     - on_epoch_begin
    Start Train
       - on_train_begin
      Start Batch Loop
         - on_train_batch_begin
          Start Default Training Step
            - on_train_batch_after_pred
            - on_train_batch_after_loss
            - on_train_batch_before_backward
            - on_train_batch_before_step
            - on_train_batch_after_step
          End Default Training Step:
         - on_train_batch_end
      End Batch Loop
       - on_train_end
    End Train
    Start Valid
       - on_valid_begin
      Start Batch Loop
         - on_valid_batch_begin
          Start Default Validation Step
            - on_valid_batch_after_pred
            - on_valid_batch_after_loss
          End Default Validation Step
         - on_valid_batch_end
      End Batch Loop
       - on_valid_end
    End Valid
      - on_epoch_end
  End Epoch Loop
   - on_fit_end
End Fit
Every step marked with on_* is a point in the training procedure that is available for callbacks to be called.
The other important part of callbacks is the ctx (context) object. See help("ctx") for details.
By default, callbacks are called in the same order as they were passed to fit (or predict or evaluate), but you can provide a weight attribute that will control the order in which it will be called. For example, if one callback has weight = 10 and another has weight = 1, then the first one is called after the second one. Callbacks that don’t specify a weight attribute are considered weight = 0. A few built-in callbacks in luz already provide a weight value. For example, ?luz_callback_early_stopping has a weight of Inf, since in general we want to run it as the last thing in the loop.
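The ordering rule can be sketched as below. The sketch assumes that the weight attribute is supplied as a public field through the ... arguments of luz_callback(); the callback names, data and model are illustrative only.

```r
if (torch::torch_is_installed()) {
  library(torch)
  library(luz)

  x <- torch_randn(100, 10)
  y <- torch_randn(100, 1)

  model <- nn_linear %>%
    setup(loss = nnf_mse_loss, optimizer = optim_sgd) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  # Assumption: weight is declared as a public field of the callback.
  light_cb <- luz_callback(
    name = "light_cb",
    weight = 1,
    on_epoch_end = function() {
      cat("callback with weight 1\n")
    }
  )

  heavy_cb <- luz_callback(
    name = "heavy_cb",
    weight = 10,
    on_epoch_end = function() {
      cat("callback with weight 10\n")
    }
  )

  # heavy_cb is listed first, but its larger weight should make it run
  # after light_cb at each breakpoint.
  fitted <- model %>% fit(
    list(x, y),
    epochs = 1,
    callbacks = list(heavy_cb(), light_cb()),
    verbose = FALSE
  )
}
```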
A luz_callback that can be passed to fit.luz_module_generator().
You can also use callbacks when using predict(). In this case the supported callback methods are detailed below:
Start predict
 - on_predict_begin
 Start prediction loop
  - on_predict_batch_begin
  - on_predict_batch_end
 End prediction loop
 - on_predict_end
End predict
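A minimal sketch of a callback that hooks into the prediction breakpoints listed above. The callback name count_batches and its field n are hypothetical; the example counts how many batches are processed during predict().

```r
if (torch::torch_is_installed()) {
  library(torch)
  library(luz)

  x <- torch_randn(100, 10)
  y <- torch_randn(100, 1)

  fitted <- nn_linear %>%
    setup(loss = nnf_mse_loss, optimizer = optim_sgd) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01) %>%
    fit(list(x, y), epochs = 1, verbose = FALSE)

  # Hypothetical callback: counts prediction batches.
  count_batches <- luz_callback(
    name = "count_batches",
    n = 0,
    on_predict_begin = function() {
      self$n <- 0
    },
    on_predict_batch_end = function() {
      self$n <- self$n + 1
    },
    on_predict_end = function() {
      cat("Predicted", self$n, "batches\n")
    }
  )

  preds <- predict(
    fitted,
    x,
    callbacks = list(count_batches()),
    verbose = FALSE
  )
}
```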
Callbacks can also be used with evaluate(). In this case, the callbacks that are used are equivalent to those of the validation loop when using fit():
Start Valid
 - on_valid_begin
 Start Batch Loop
  - on_valid_batch_begin
  Start Default Validation Step
   - on_valid_batch_after_pred
   - on_valid_batch_after_loss
  End Default Validation Step
  - on_valid_batch_end
 End Batch Loop
 - on_valid_end
End Valid
Other luz_callbacks: luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()
print_callback <- luz_callback(
  name = "print_callback",
  on_train_batch_end = function() {
    cat("Iteration ", ctx$iter, "\n")
  },
  on_epoch_end = function() {
    cat("Done!\n")
  }
)
This callback allows you to resume training a model.
luz_callback_auto_resume(path = "./state.pt")
path |
Path to save state files for the model. |
When using it, model weights and optimizer state are serialized at the end of each epoch. If something fails during training, simply re-running the same script will restart the model training from the epoch right after the last epoch that was serialized.
By default model, optimizer state and records are serialized. Callbacks can be used to customize serialization by implementing the state_dict() and load_state_dict() methods.
If those methods are implemented, then state_dict() is called at the end of each epoch and load_state_dict() is called when the model is resumed.
In general you will want to add this callback as the last in the callbacks list, this way, the serialized state is likely to contain all possible changes that other callbacks could have made at 'on_epoch_end'. The default weight attribute of this callback is Inf.
Read the checkpointing article in the pkgdown website for more information.
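As a sketch of the customization described above, the following callback keeps a counter and exposes it through state_dict() and load_state_dict(). The callback name counter_cb, the field n_epochs and the argument name state are hypothetical; the sketch assumes load_state_dict() receives the value previously returned by state_dict().

```r
if (torch::torch_is_installed()) {
  library(torch)
  library(luz)

  x <- torch_randn(100, 10)
  y <- torch_randn(100, 1)

  model <- nn_linear %>%
    setup(loss = nnf_mse_loss, optimizer = optim_sgd) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  # Hypothetical stateful callback: counts finished epochs.
  counter_cb <- luz_callback(
    name = "counter_cb",
    n_epochs = 0,
    on_epoch_end = function() {
      self$n_epochs <- self$n_epochs + 1
    },
    # Called at the end of each epoch when auto resume is active.
    state_dict = function() {
      list(n_epochs = self$n_epochs)
    },
    # Called when training is resumed from a serialized state.
    load_state_dict = function(state) {
      self$n_epochs <- state$n_epochs
    }
  )

  fitted <- model %>% fit(
    list(x, y),
    epochs = 2,
    callbacks = list(
      counter_cb(),
      luz_callback_auto_resume(path = tempfile())
    ),
    verbose = FALSE
  )
}
```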
Other luz_callbacks: luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid(), luz_callback()
if (torch::torch_is_installed()) {
  library(torch)
  library(luz)

  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  model <- nn_linear %>%
    setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  # simulate a failure in the middle of epoch 5 happening only once.
  callback_stop <- luz_callback(
    "interrupt",
    failed = FALSE,
    on_epoch_end = function() {
      if (ctx$epoch == 5 && !self$failed) {
        self$failed <- TRUE
        stop("Error on epoch 5")
      }
    }
  )

  path <- tempfile()
  autoresume <- luz_callback_auto_resume(path = path)
  interrupt <- callback_stop()

  # try once and the model fails
  try({
    results <- model %>% fit(
      list(x, y),
      callbacks = list(autoresume, interrupt),
      verbose = FALSE
    )
  })

  # model resumes and completes
  results <- model %>% fit(
    list(x, y),
    callbacks = list(autoresume, interrupt),
    verbose = FALSE
  )

  get_metrics(results)
}
Logs metrics obtained during training to a file on disk. The file will have 1 line for each epoch/validation.
luz_callback_csv_logger(path)
path |
path to a file on disk. |
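A minimal usage sketch, assuming a toy linear model and random data (both illustrative): the logger writes to a temporary file, which is then read back as plain text so no assumption is made about its exact column layout.

```r
if (torch::torch_is_installed()) {
  library(torch)
  library(luz)

  x <- torch_randn(100, 10)
  y <- torch_randn(100, 1)

  model <- nn_linear %>%
    setup(loss = nnf_mse_loss, optimizer = optim_sgd) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  log_path <- tempfile(fileext = ".csv")

  fitted <- model %>% fit(
    list(x, y),
    epochs = 3,
    callbacks = list(luz_callback_csv_logger(log_path)),
    verbose = FALSE
  )

  # Inspect the raw log lines written during training.
  log_lines <- readLines(log_path)
  print(log_lines)
}
```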
Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()
Stops training when a monitored metric stops improving
luz_callback_early_stopping( monitor = "valid_loss", min_delta = 0, patience = 0, mode = "min", baseline = NULL )
monitor |
A string in the format |
min_delta |
Minimum improvement to reset the patience counter. |
patience |
Number of epochs without improvement before training is stopped. |
mode |
Specifies the direction that is considered an improvement. By default 'min' is used. Can also be 'max' (higher is better) and 'zero' (closer to zero is better). |
baseline |
An initial value that will be used as the best seen value
at the beginning. The model will stop training if no value better than the baseline
is found in the first |
A luz_callback that does early stopping.
This callback adds an on_early_stopping callback breakpoint that can be used to call callbacks as soon as the model stops training.
If verbose=TRUE in fit.luz_module_generator() a message is printed when early stopping is triggered.
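The on_early_stopping breakpoint can be sketched as follows. The callback name notify_cb is hypothetical, and the sketch assumes that "train_loss" is a valid monitor string when no validation data is given. A deliberately large min_delta is used so that no epoch counts as an improvement and training should stop well before the maximum number of epochs.

```r
if (torch::torch_is_installed()) {
  library(torch)
  library(luz)

  x <- torch_randn(100, 10)
  y <- torch_randn(100, 1)

  model <- nn_linear %>%
    setup(loss = nnf_mse_loss, optimizer = optim_sgd) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  # Hypothetical callback that reacts to the on_early_stopping breakpoint.
  notify_cb <- luz_callback(
    name = "notify_cb",
    on_early_stopping = function() {
      cat("Stopped early at epoch", ctx$epoch, "\n")
    }
  )

  fitted <- model %>% fit(
    list(x, y),
    epochs = 20,
    callbacks = list(
      # min_delta = 10 is unrealistically large on purpose (illustration only).
      luz_callback_early_stopping(
        monitor = "train_loss",
        min_delta = 10,
        patience = 1
      ),
      notify_cb()
    ),
    verbose = FALSE
  )
}
```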
Other luz_callbacks: luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid(), luz_callback()
cb <- luz_callback_early_stopping()
By adding the GradientClip callback, the gradient norm_type (default: 2) norm is clipped to at most max_norm (default: 1) using torch::nn_utils_clip_grad_norm_(), which can avoid loss divergence.
luz_callback_gradient_clip(max_norm = 1, norm_type = 2)
max_norm |
(float or int): max norm of the gradients |
norm_type |
(float or int): type of the used p-norm. Can be |
See FastAI documentation for the GradientClip callback.
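A minimal usage sketch, with an illustrative model, data and max_norm value: the callback is simply added to the callbacks list of fit().

```r
if (torch::torch_is_installed()) {
  library(torch)
  library(luz)

  x <- torch_randn(100, 10)
  y <- torch_randn(100, 1)

  model <- nn_linear %>%
    setup(loss = nnf_mse_loss, optimizer = optim_sgd) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  # Clip the 2-norm of the gradients to at most 0.5 (value chosen for illustration).
  fitted <- model %>% fit(
    list(x, y),
    epochs = 1,
    callbacks = list(
      luz_callback_gradient_clip(max_norm = 0.5, norm_type = 2)
    ),
    verbose = FALSE
  )
}
```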
Adds a handler that allows interrupting the training loop using ctrl + C.
Also registers an on_interrupt breakpoint so users can register callbacks to be run on training loop interruption.
luz_callback_interrupt()
A luz_callback
In general you don't need to use this callback yourself because it's always included by default in fit.luz_module_generator().
Other luz_callbacks: luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid(), luz_callback()
interrupt_callback <- luz_callback_interrupt()
Each epoch, if there's improvement in the monitored metric we serialize the model weights to a temp file. When training is done, we reload weights from the best model.
luz_callback_keep_best_model( monitor = "valid_loss", mode = "min", min_delta = 0 )
monitor |
A string in the format |
mode |
Specifies the direction that is considered an improvement. By default 'min' is used. Can also be 'max' (higher is better) and 'zero' (closer to zero is better). |
min_delta |
Minimum improvement to reset the patience counter. |
Other luz_callbacks: luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid(), luz_callback()
cb <- luz_callback_keep_best_model()
Initializes and runs torch::lr_scheduler()s.
luz_callback_lr_scheduler( lr_scheduler, ..., call_on = "on_epoch_end", opt_name = NULL )
lr_scheduler |
A |
... |
Additional arguments passed to |
call_on |
The callback breakpoint that |
opt_name |
name of the optimizer that will be affected by this callback.
Should match the name given in |
A luz_callback() generator.
Other luz_callbacks: luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid(), luz_callback()
if (torch::torch_is_installed()) { cb <- luz_callback_lr_scheduler(torch::lr_step, step_size = 30) }
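To see what a step scheduler with step_size = 30 does when it is stepped once per epoch (the default call_on = "on_epoch_end"), the schedule can be sketched in base R. The function step_decay() is hypothetical, and the decay factor gamma = 0.1 is an assumption about the scheduler's default; neither is taken from luz itself.

```r
# Hypothetical base R sketch of a step learning-rate schedule:
# the rate is multiplied by `gamma` once every `step_size` completed epochs.
step_decay <- function(initial_lr, epochs_completed, step_size, gamma = 0.1) {
  initial_lr * gamma^(epochs_completed %/% step_size)
}

epochs <- c(0, 29, 30, 59, 60)
lrs <- step_decay(0.1, epochs, step_size = 30)
data.frame(epochs_completed = epochs, lr = lrs)
```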
Tracks metrics passed to setup()
during training and validation.
luz_callback_metrics()
This callback takes care of 2 ctx attributes:
ctx$metrics
: stores the current metric objects, which are initialized once per epoch,
and are further update()
d and compute()
d every batch. You will rarely need
to work with these metrics.
ctx$records$metrics
: Stores metrics per training/validation and epoch. The
structure is very similar to ctx$losses
.
A luz_callback
In general you won't need to explicitly use the metrics callback as it's
used by default in fit.luz_module_generator()
.
Other luz_callbacks:
luz_callback_auto_resume()
,
luz_callback_csv_logger()
,
luz_callback_early_stopping()
,
luz_callback_interrupt()
,
luz_callback_keep_best_model()
,
luz_callback_lr_scheduler()
,
luz_callback_mixed_precision()
,
luz_callback_mixup()
,
luz_callback_model_checkpoint()
,
luz_callback_profile()
,
luz_callback_progress()
,
luz_callback_resume_from_checkpoint()
,
luz_callback_train_valid()
,
luz_callback()
This callback will enable torch::local_autocast()
during the model's forward pass
and during loss computation. It will then disable autocast and scale the loss
before backward()
and opt$step()
. See here
for more information.
luz_callback_mixed_precision(...)
... |
Passed to |
A luz_callback
Other luz_callbacks:
luz_callback_auto_resume()
,
luz_callback_csv_logger()
,
luz_callback_early_stopping()
,
luz_callback_interrupt()
,
luz_callback_keep_best_model()
,
luz_callback_lr_scheduler()
,
luz_callback_metrics()
,
luz_callback_mixup()
,
luz_callback_model_checkpoint()
,
luz_callback_profile()
,
luz_callback_progress()
,
luz_callback_resume_from_checkpoint()
,
luz_callback_train_valid()
,
luz_callback()
Implementation of 'mixup: Beyond Empirical Risk Minimization'.
As of today, tested only for categorical data,
where targets are expected to be integers, not one-hot encoded vectors.
This callback is supposed to be used together with nn_mixup_loss()
.
luz_callback_mixup(alpha = 0.4, ..., run_valid = FALSE, auto_loss = FALSE)
alpha |
parameter for the beta distribution used to sample mixing coefficients |
... |
currently unused. Just to force named arguments. |
run_valid |
Whether the callback should also run during validation. |
auto_loss |
Should it automatically modify the loss function? This will wrap
the loss function to create the mixup loss. If |
Overall, we follow the fastai implementation described here. Namely,
We work with a single dataloader only, randomly mixing two observations from the same batch.
We linearly combine losses computed for both targets:
loss(output, new_target) = weight * loss(output, target1) + (1-weight) * loss(output, target2)
We draw different mixing coefficients for every pair.
We replace weight
with weight = max(weight, 1-weight)
to avoid duplicates.
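The loss combination described above can be illustrated with a small base R sketch. It assumes the mixing coefficients come from a symmetric Beta(alpha, alpha) distribution, and the per-observation loss values are made up; this is not the luz implementation, which operates on torch tensors.

```r
set.seed(42)

alpha <- 0.4
batch_size <- 4

# One mixing coefficient per pair (assumed here to be Beta(alpha, alpha)).
weight <- rbeta(batch_size, alpha, alpha)

# Replace weight with max(weight, 1 - weight), so the first observation
# of each pair always dominates and mirrored pairs are not duplicated.
weight <- pmax(weight, 1 - weight)

# Made-up per-observation losses against each of the two targets.
loss_target1 <- c(0.2, 1.5, 0.7, 0.1)
loss_target2 <- c(1.1, 0.3, 0.9, 2.0)

# Linear combination of both losses, then the mean over the batch.
mixed_loss <- weight * loss_target1 + (1 - weight) * loss_target2
batch_loss <- mean(mixed_loss)
batch_loss
```

After the pmax() step every weight lies in [0.5, 1], so each mixed loss sits between the two individual losses and is never further from the first target's loss than from the second's.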
A luz_callback
Other luz_callbacks:
luz_callback_auto_resume()
,
luz_callback_csv_logger()
,
luz_callback_early_stopping()
,
luz_callback_interrupt()
,
luz_callback_keep_best_model()
,
luz_callback_lr_scheduler()
,
luz_callback_metrics()
,
luz_callback_mixed_precision()
,
luz_callback_model_checkpoint()
,
luz_callback_profile()
,
luz_callback_progress()
,
luz_callback_resume_from_checkpoint()
,
luz_callback_train_valid()
,
luz_callback()
if (torch::torch_is_installed()) { mixup_callback <- luz_callback_mixup() }
This saves checkpoints of the model according to the specified metric and behavior.
luz_callback_model_checkpoint( path, monitor = "valid_loss", save_best_only = FALSE, mode = "min", min_delta = 0 )
path |
Path to save the model on disk. The path is interpolated with |
monitor |
A string in the format |
save_best_only |
if |
mode |
Specifies the direction that is considered an improvement. By default 'min' is used. Can also be 'max' (higher is better) and 'zero' (closer to zero is better). |
min_delta |
Minimum difference to consider as improvement. Only used when
|
mode
and min_delta
are only used when save_best_only=TRUE
.
save_best_only
will overwrite the saved models if the path
parameter
doesn't differentiate by epoch.
Read the checkpointing article in the pkgdown website for more information.
Other luz_callbacks:
luz_callback_auto_resume()
,
luz_callback_csv_logger()
,
luz_callback_early_stopping()
,
luz_callback_interrupt()
,
luz_callback_keep_best_model()
,
luz_callback_lr_scheduler()
,
luz_callback_metrics()
,
luz_callback_mixed_precision()
,
luz_callback_mixup()
,
luz_callback_profile()
,
luz_callback_progress()
,
luz_callback_resume_from_checkpoint()
,
luz_callback_train_valid()
,
luz_callback()
luz_callback_model_checkpoint(path= "path/to/dir")
luz_callback_model_checkpoint(path= "path/to/dir/epoch-{epoch:02d}/model.pt")
luz_callback_model_checkpoint(path= "path/to/dir/epoch-{epoch:02d}/model-{monitor:.2f}.pt")
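The placeholders in the last two examples read like format specifications: {epoch:02d} suggests a zero-padded two-digit epoch, and {monitor:.2f} the monitored value with two decimals. Assuming that reading, the resulting file name can be previewed with base R's sprintf(). The epoch and metric values below are made up, and the snippet does not call luz's own path interpolation.

```r
# Made-up values standing in for the current epoch and monitored metric.
epoch <- 3L
monitor <- 0.4567

# Base R equivalent of the assumed formatting of
# "path/to/dir/epoch-{epoch:02d}/model-{monitor:.2f}.pt".
rendered <- sprintf("path/to/dir/epoch-%02d/model-%.2f.pt", epoch, monitor)
print(rendered)
# [1] "path/to/dir/epoch-03/model-0.46.pt"
```

Because the epoch is part of the path, each epoch writes to its own directory and earlier checkpoints are not overwritten.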
Computes the times for high-level operations in the training loops.
luz_callback_profile()
Records are saved in ctx$records$profile
. Times are stored as seconds.
Data is stored in the following structure:
fit time for the entire fit procedure.
epoch times per epoch
A luz_callback
In general you don't need to use this callback yourself because it's always
included by default in fit.luz_module_generator()
.
Other luz_callbacks:
luz_callback_auto_resume()
,
luz_callback_csv_logger()
,
luz_callback_early_stopping()
,
luz_callback_interrupt()
,
luz_callback_keep_best_model()
,
luz_callback_lr_scheduler()
,
luz_callback_metrics()
,
luz_callback_mixed_precision()
,
luz_callback_mixup()
,
luz_callback_model_checkpoint()
,
luz_callback_progress()
,
luz_callback_resume_from_checkpoint()
,
luz_callback_train_valid()
,
luz_callback()
profile_callback <- luz_callback_profile()
Responsible for printing progress during training.
luz_callback_progress()
A luz_callback
In general you don't need to use this callback yourself because it's always
included by default in fit.luz_module_generator()
.
Printing can be disabled by passing verbose=FALSE
to fit.luz_module_generator()
.
Other luz_callbacks:
luz_callback_auto_resume()
,
luz_callback_csv_logger()
,
luz_callback_early_stopping()
,
luz_callback_interrupt()
,
luz_callback_keep_best_model()
,
luz_callback_lr_scheduler()
,
luz_callback_metrics()
,
luz_callback_mixed_precision()
,
luz_callback_mixup()
,
luz_callback_model_checkpoint()
,
luz_callback_profile()
,
luz_callback_resume_from_checkpoint()
,
luz_callback_train_valid()
,
luz_callback()
Allows resuming model training from a specific checkpoint.
luz_callback_resume_from_checkpoint( path, ..., restore_model_state = TRUE, restore_records = FALSE, restore_optimizer_state = FALSE, restore_callbacks_state = FALSE )
path |
Path to the checkpoint that you want to resume. |
... |
currently unused. |
restore_model_state |
Whether to restore the model state from the checkpoint. |
restore_records |
Whether to restore records from the checkpoint. |
restore_optimizer_state |
Whether to restore the optimizer state from the checkpoint. |
restore_callbacks_state |
Whether to restore the callbacks state from the checkpoint. |
Read the checkpointing article in the pkgdown website for more information.
luz_callback_model_checkpoint()
Other luz_callbacks:
luz_callback()
,
luz_callback_auto_resume()
,
luz_callback_csv_logger()
,
luz_callback_early_stopping()
,
luz_callback_interrupt()
,
luz_callback_keep_best_model()
,
luz_callback_lr_scheduler()
,
luz_callback_metrics()
,
luz_callback_mixed_precision()
,
luz_callback_mixup()
,
luz_callback_model_checkpoint()
,
luz_callback_profile()
,
luz_callback_progress()
,
luz_callback_train_valid()
Logs metrics and other model information in the tfevents file format. Assuming tensorboard is installed, the results can be visualized with the tensorboard command given below.
luz_callback_tfevents(logdir = "logs", histograms = FALSE, ...)
logdir |
A directory where the logs will be written. |
histograms |
A boolean specifying if histograms of model weights should
be logged. It can also be a character vector specifying the name of the parameters
that should be logged (names are the same as |
... |
Currently not used. For future expansion. |
tensorboard --logdir=logs
if (torch::torch_is_installed()) {
  library(torch)
  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  model <- nn_linear %>%
    setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 1e-4)

  tmp <- tempfile()

  model %>% fit(list(x, y), valid_data = 0.2, callbacks = list(
    luz_callback_tfevents(tmp, histograms = TRUE)
  ))
}
Switches important flags for training and evaluation modes.
luz_callback_train_valid()
It takes care of the three ctx
attributes:
ctx$model
: Responsible for calling ctx$model$train()
and ctx$model$eval()
,
when appropriate.
ctx$training
: Sets this flag to TRUE
when training and FALSE
when in
validation mode.
ctx$loss
: Resets the loss
attribute to list()
when finished training or
validating.
A luz_callback
In general you won't need to explicitly use the train_valid callback as it's
used by default in fit.luz_module_generator()
.
Other luz_callbacks:
luz_callback()
,
luz_callback_auto_resume()
,
luz_callback_csv_logger()
,
luz_callback_early_stopping()
,
luz_callback_interrupt()
,
luz_callback_keep_best_model()
,
luz_callback_lr_scheduler()
,
luz_callback_metrics()
,
luz_callback_mixed_precision()
,
luz_callback_mixup()
,
luz_callback_model_checkpoint()
,
luz_callback_profile()
,
luz_callback_progress()
,
luz_callback_resume_from_checkpoint()
Loads a fitted model. See documentation in luz_save()
.
luz_load(path)
path |
path in file system to the object. |
Other luz_save:
luz_save()
Works with checkpoints created typically with luz_callback_model_checkpoint()
.
luz_load_checkpoint(obj, path, ...)
obj |
Object to which we want to load the checkpoint. |
path |
Path of the checkpoint on disk. |
... |
unused. Is there to allow future extensions. |
This can be useful when you have saved model checkpoints during training and want to reload the best checkpoint in the end.
luz_load_model_weights(obj, path, ...)
luz_save_model_weights(obj, path)
obj |
luz object to which you want to copy the new weights. |
path |
path to saved model in disk. |
... |
other arguments passed to |
Returns NULL
invisibly.
luz_load_model_weights
operates in place, i.e. it modifies the model object to contain the
new weights.
Creates a new luz metric
luz_metric( name = NULL, ..., private = NULL, active = NULL, parent_env = parent.frame(), inherit = NULL )
name |
string naming the new metric. |
... |
named list of public methods. You should implement at least
|
private |
An optional list of private members, which can be functions and non-functions. |
active |
An optional list of active binding functions. |
parent_env |
An environment to use as the parent of newly-created objects. |
inherit |
A R6ClassGenerator object to inherit from; in other words, a
superclass. This is captured as an unevaluated expression which is
evaluated in |
In order to implement a new luz_metric
we need to implement 3 methods:
initialize
: defines the metric initial state. This function is
called for each epoch for both training and validation loops.
update
: updates the metric internal state. This function is called
at every training and validation step with the predictions obtained by
the model and the target values obtained from the dataloader.
compute
: uses the internal state to compute metric values. This
function is called whenever we need to obtain the current metric
value. For example, it's called every training step for metrics displayed in
the progress bar, but only once per epoch to record its value
when the progress bar is not displayed.
Optionally, you can implement an abbrev
field that gives the metric an
abbreviation that will be used when displaying metric information in the
console or tracking record. If no abbrev
is passed, the class name
will be used.
Let’s take a look at the implementation of luz_metric_accuracy
so you
can see how to implement a new one:
luz_metric_accuracy <- luz_metric(
  # An abbreviation to be shown in progress bars, or
  # when printing progress
  abbrev = "Acc",
  # Initial setup for the metric. Metrics are initialized
  # every epoch, for both training and validation
  initialize = function() {
    self$correct <- 0
    self$total <- 0
  },
  # Run at every training or validation step and updates
  # the internal state. The update function takes `preds`
  # and `target` as parameters.
  update = function(preds, target) {
    pred <- torch::torch_argmax(preds, dim = 2)
    self$correct <- self$correct + (pred == target)$
      to(dtype = torch::torch_float())$
      sum()$
      item()
    self$total <- self$total + pred$numel()
  },
  # Use the internal state to query the metric value
  compute = function() {
    self$correct/self$total
  }
)
Note: It's good practice for the compute
method to return regular R
values instead of torch tensors; other parts of luz expect that.
Returns new luz metric.
Other luz_metrics:
luz_metric_accuracy()
,
luz_metric_binary_accuracy_with_logits()
,
luz_metric_binary_accuracy()
,
luz_metric_binary_auroc()
,
luz_metric_mae()
,
luz_metric_mse()
,
luz_metric_multiclass_auroc()
,
luz_metric_rmse()
Computes accuracy for multi-class classification problems.
luz_metric_accuracy()
This metric expects to take logits or probabilities at every update. It then takes the argmax across columns (dim = 2), i.e. the highest-scoring class for each row, and compares it to the target.
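The same computation can be sketched without torch, using a plain matrix of made-up class scores with one row per observation and one column per class. This is only an illustration of the per-row argmax comparison, not the metric's actual code, which works on torch tensors.

```r
# Made-up scores: 4 observations (rows) x 3 classes (columns).
preds <- matrix(c(
  0.1, 0.7, 0.2,
  0.8, 0.1, 0.1,
  0.3, 0.3, 0.4,
  0.2, 0.5, 0.3
), ncol = 3, byrow = TRUE)

# Integer targets, 1-based as in R torch.
target <- c(2, 1, 1, 2)

# Index of the highest-scoring class in each row.
predicted_class <- max.col(preds, ties.method = "first")

accuracy <- mean(predicted_class == target)
accuracy
# [1] 0.75
```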
Returns new luz metric.
Other luz_metrics:
luz_metric_binary_accuracy_with_logits()
,
luz_metric_binary_accuracy()
,
luz_metric_binary_auroc()
,
luz_metric_mae()
,
luz_metric_mse()
,
luz_metric_multiclass_auroc()
,
luz_metric_rmse()
,
luz_metric()
if (torch::torch_is_installed()) {
  library(torch)
  metric <- luz_metric_accuracy()
  metric <- metric$new()
  metric$update(torch_randn(100, 10), torch::torch_randint(1, 10, size = 100))
  metric$compute()
}
Computes the accuracy for binary classification problems where the
model returns probabilities. Commonly used when the loss is torch::nn_bce_loss()
.
luz_metric_binary_accuracy(threshold = 0.5)
threshold |
value used to classify observations between 0 and 1. |
Returns new luz metric.
Other luz_metrics:
luz_metric_accuracy()
,
luz_metric_binary_accuracy_with_logits()
,
luz_metric_binary_auroc()
,
luz_metric_mae()
,
luz_metric_mse()
,
luz_metric_multiclass_auroc()
,
luz_metric_rmse()
,
luz_metric()
if (torch::torch_is_installed()) {
  library(torch)
  metric <- luz_metric_binary_accuracy(threshold = 0.5)
  metric <- metric$new()
  metric$update(torch_rand(100), torch::torch_randint(0, 1, size = 100))
  metric$compute()
}
Computes accuracy for binary classification problems where the model
returns logits. Commonly used together with torch::nn_bce_with_logits_loss()
.
luz_metric_binary_accuracy_with_logits(threshold = 0.5)
threshold |
value used to classify observations between 0 and 1. |
Probabilities are generated using torch::nnf_sigmoid()
and threshold
is used to
classify between 0 or 1.
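A minimal base R sketch of this two-step rule, using plogis() as the sigmoid and made-up logits and targets. Whether the comparison against the threshold is strict (>) or inclusive (>=) is an assumption here.

```r
# Made-up logits and binary targets.
logits <- c(-2, 0.1, 3, -0.5)
target <- c(0, 1, 1, 1)
threshold <- 0.5

# Sigmoid maps logits to probabilities in (0, 1).
probs <- plogis(logits)

# Classify as 1 when the probability exceeds the threshold (assumed strict).
predicted <- as.numeric(probs > threshold)

accuracy <- mean(predicted == target)
accuracy
# [1] 0.75
```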
Returns new luz metric.
Other luz_metrics:
luz_metric_accuracy()
,
luz_metric_binary_accuracy()
,
luz_metric_binary_auroc()
,
luz_metric_mae()
,
luz_metric_mse()
,
luz_metric_multiclass_auroc()
,
luz_metric_rmse()
,
luz_metric()
if (torch::torch_is_installed()) {
  library(torch)
  metric <- luz_metric_binary_accuracy_with_logits(threshold = 0.5)
  metric <- metric$new()
  metric$update(torch_randn(100), torch::torch_randint(0, 1, size = 100))
  metric$compute()
}
To avoid storing all predictions and targets for an epoch we compute confusion matrices across a range of pre-established thresholds.
luz_metric_binary_auroc( num_thresholds = 200, thresholds = NULL, from_logits = FALSE )
num_thresholds |
Number of thresholds used to compute confusion matrices.
In that case, thresholds are created by getting |
thresholds |
(optional) If thresholds are passed, then those are used to compute the
confusion matrices and |
from_logits |
Boolean indicating whether predictions are logits; if so, a sigmoid is applied to map them to the unit interval. |
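The idea of accumulating confusion counts over a fixed grid of thresholds, instead of storing every prediction, can be sketched in base R. The grid, the strict > comparison, and the trapezoidal integration are assumptions made for illustration; the metric itself may differ in these details.

```r
# Fixed grid of thresholds, chosen up front.
thresholds <- seq(0, 1, length.out = 200)

actual    <- c(1, 1, 1, 0, 0, 0)
predicted <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)

# Running counts per threshold; only these are kept between batches.
tp <- numeric(length(thresholds))
fp <- numeric(length(thresholds))
n_pos <- 0
n_neg <- 0

# Process the data in three batches of two observations each.
for (idx in list(1:2, 3:4, 5:6)) {
  y <- actual[idx]
  p <- predicted[idx]
  above <- outer(p, thresholds, ">")  # rows: observations, cols: thresholds
  tp <- tp + colSums(above & (y == 1))
  fp <- fp + colSums(above & (y == 0))
  n_pos <- n_pos + sum(y == 1)
  n_neg <- n_neg + sum(y == 0)
}

tpr <- tp / n_pos
fpr <- fp / n_neg

# Trapezoidal area under the ROC curve; fpr decreases as thresholds grow.
auc <- sum(-diff(fpr) * (head(tpr, -1) + tail(tpr, -1)) / 2)
auc
```

For this data the grid is fine enough to separate all predictions, so the result matches the exact AUROC of 8/9 (8 of the 9 positive/negative pairs are ranked correctly).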
Other luz_metrics:
luz_metric_accuracy()
,
luz_metric_binary_accuracy_with_logits()
,
luz_metric_binary_accuracy()
,
luz_metric_mae()
,
luz_metric_mse()
,
luz_metric_multiclass_auroc()
,
luz_metric_rmse()
,
luz_metric()
if (torch::torch_is_installed()) {
  library(torch)
  actual <- c(1, 1, 1, 0, 0, 0)
  predicted <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)

  y_true <- torch_tensor(actual)
  y_pred <- torch_tensor(predicted)

  m <- luz_metric_binary_auroc(thresholds = predicted)
  m <- m$new()

  m$update(y_pred[1:2], y_true[1:2])
  m$update(y_pred[3:4], y_true[3:4])
  m$update(y_pred[5:6], y_true[5:6])

  m$compute()
}
Computes the mean absolute error.
luz_metric_mae()
Returns new luz metric.
Other luz_metrics:
luz_metric_accuracy()
,
luz_metric_binary_accuracy_with_logits()
,
luz_metric_binary_accuracy()
,
luz_metric_binary_auroc()
,
luz_metric_mse()
,
luz_metric_multiclass_auroc()
,
luz_metric_rmse()
,
luz_metric()
if (torch::torch_is_installed()) {
  library(torch)
  metric <- luz_metric_mae()
  metric <- metric$new()
  metric$update(torch_randn(100), torch_randn(100))
  metric$compute()
}
Computes the mean squared error
luz_metric_mse()
A luz_metric object.
Other luz_metrics:
luz_metric_accuracy()
,
luz_metric_binary_accuracy_with_logits()
,
luz_metric_binary_accuracy()
,
luz_metric_binary_auroc()
,
luz_metric_mae()
,
luz_metric_multiclass_auroc()
,
luz_metric_rmse()
,
luz_metric()
The same definition as Keras
is used by default. This is equivalent to the 'micro'
method in scikit-learn
too. See docs.
luz_metric_multiclass_auroc( num_thresholds = 200, thresholds = NULL, from_logits = FALSE, average = c("micro", "macro", "weighted", "none") )
num_thresholds |
Number of thresholds used to compute confusion matrices.
In that case, thresholds are created by getting |
thresholds |
(optional) If thresholds are passed, then those are used to compute the
confusion matrices and |
from_logits |
If |
average |
The averaging method:
|
Note that class imbalance can affect this metric, unlike the AUC for binary classification.
Currently the AUC is approximated using the 'interpolation' method described in Keras.
Other luz_metrics:
luz_metric_accuracy()
,
luz_metric_binary_accuracy_with_logits()
,
luz_metric_binary_accuracy()
,
luz_metric_binary_auroc()
,
luz_metric_mae()
,
luz_metric_mse()
,
luz_metric_rmse()
,
luz_metric()
if (torch::torch_is_installed()) {
  library(torch)
  actual <- c(1, 1, 1, 0, 0, 0) + 1L
  predicted <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)
  predicted <- cbind(1-predicted, predicted)

  y_true <- torch_tensor(as.integer(actual))
  y_pred <- torch_tensor(predicted)

  m <- luz_metric_multiclass_auroc(thresholds = as.numeric(predicted),
                                   average = "micro")
  m <- m$new()

  m$update(y_pred[1:2,], y_true[1:2])
  m$update(y_pred[3:4,], y_true[3:4])
  m$update(y_pred[5:6,], y_true[5:6])
  m$compute()
}
Computes the root mean squared error.
luz_metric_rmse()
Returns new luz metric.
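Because metrics are updated batch by batch, an epoch-level RMSE has to be built from running totals rather than by averaging per-batch RMSE values. The base R sketch below assumes the state is a running sum of squared errors plus an observation count; that is one natural way to do it, not necessarily how luz stores its state.

```r
set.seed(1)
preds  <- rnorm(100)
target <- rnorm(100)

# Running state, updated once per batch of 25 observations.
sum_sq <- 0
n <- 0
for (idx in split(seq_along(preds), rep(1:4, each = 25))) {
  sum_sq <- sum_sq + sum((preds[idx] - target[idx])^2)
  n <- n + length(idx)
}

rmse_streaming <- sqrt(sum_sq / n)

# Same quantity computed in one pass over all observations.
rmse_direct <- sqrt(mean((preds - target)^2))

c(streaming = rmse_streaming, direct = rmse_direct)
```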
Other luz_metrics:
luz_metric_accuracy()
,
luz_metric_binary_accuracy_with_logits()
,
luz_metric_binary_accuracy()
,
luz_metric_binary_auroc()
,
luz_metric_mae()
,
luz_metric_mse()
,
luz_metric_multiclass_auroc()
,
luz_metric()
A metric set can be used to specify metrics that are only evaluated during training, validation or both.
luz_metric_set(metrics = NULL, train_metrics = NULL, valid_metrics = NULL)
metrics |
A list of luz_metrics that are meant to be used in both training and validation. |
train_metrics |
A list of luz_metrics that are only used during training. |
valid_metrics |
A list of luz_metrics that are only used for validation. |
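A hedged usage sketch: MAE is tracked in both phases, while RMSE is only computed during validation. It assumes the resulting set is passed through the metrics argument of setup(); the flag configured is a hypothetical variable that only records whether the guarded block ran.

```r
configured <- FALSE

# Guarded so the snippet is a no-op when luz or torch is unavailable.
if (requireNamespace("luz", quietly = TRUE) &&
    requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed()) {

  metrics <- luz::luz_metric_set(
    metrics = list(luz::luz_metric_mae()),        # training and validation
    valid_metrics = list(luz::luz_metric_rmse())  # validation only
  )

  # Assumption: the metric set is supplied via the `metrics` argument.
  model <- luz::setup(
    torch::nn_linear,
    loss = torch::nnf_mse_loss,
    optimizer = torch::optim_adam,
    metrics = metrics
  )

  configured <- TRUE
}
```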
Allows saving luz fitted models to the disk. Objects can be loaded back with
luz_load()
.
luz_save(obj, path, ...)
obj |
an object of class 'luz_module_fitted' as returned by
|
path |
path in file system to the object. |
... |
currently unused. |
The ctx is naively serialized; that is, we only use saveRDS()
to serialize it.
Don't expect luz_save
to work correctly if you have unserializable objects
in the ctx like torch_tensor
s and external pointers in general.
Objects are saved as plain .rds
files but obj$model
is serialized
with torch_save
before saving it.
Other luz_save:
luz_load()
callbacks_mixup()
.
In the training phase, computes individual losses with regard to two targets, weights them item-wise, and averages the linear combinations to yield the mean batch loss. For validation and testing, defers to the passed-in loss.
nn_mixup_loss(loss)
loss |
the underlying loss |
It should be used together with luz_callback_mixup()
.
Logic underlying luz_callback_mixup()
.
nnf_mixup(x, y, weight)
x |
an input batch |
y |
a target batch |
weight |
weighting coefficient to be used by |
Based on the passed-in input and target batches, as well as applicable mixing weights, we return new tensors intended to replace the current batch. The new input batch is a weighted linear combination of input batch items, while the new target batch bundles the original targets, as well as the mixing weights, in a nested list.
A list
of:
x
, the new, mixed-up input batch
y
, a list
of:
ys
, a list
of:
y1
, the original target y1
y2
, the mixed-in target y2
weight
, the mixing weights
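The returned structure can be mimicked with plain R matrices. The function mixup_sketch() is hypothetical, and pairing each observation with a random permutation of the same batch is an assumption about how partners are chosen; nnf_mixup() itself works on torch tensors.

```r
# Hypothetical base R analogue of the mixing step, not the luz function.
mixup_sketch <- function(x, y, weight) {
  # Pair each observation with another one from the same batch.
  perm <- sample.int(nrow(x))
  # `weight` has one entry per row, so it recycles down the columns
  # and scales each observation (row) by its own coefficient.
  x_mixed <- weight * x + (1 - weight) * x[perm, , drop = FALSE]
  list(
    x = x_mixed,
    y = list(
      ys = list(y1 = y, y2 = y[perm]),
      weight = weight
    )
  )
}

set.seed(123)
x <- matrix(rnorm(12), nrow = 4)  # 4 observations, 3 features
y <- c(1L, 2L, 3L, 1L)
weight <- rep(0.9, 4)

out <- mixup_sketch(x, y, weight)
str(out)
```

With a weight of 1 the mixed input equals the original batch, which is a quick sanity check for the weighting direction.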
if (torch::torch_is_installed()) {
  batch_x <- torch::torch_randn(c(10, 768))
  batch_y <- torch::torch_randn(10)
  weight <- torch::torch_tensor(rep(0.9, 10))$view(c(10, 1))
  nnf_mixup(batch_x, batch_y, weight)
}
Create predictions for a fitted model
## S3 method for class 'luz_module_fitted' predict( object, newdata, ..., callbacks = list(), accelerator = NULL, verbose = NULL, dataloader_options = NULL )
object |
(fitted model) the fitted model object returned from |
newdata |
(dataloader, dataset, list or array) returning a list with at least 1 element. The other elements aren't used. |
... |
Currently unused. |
callbacks |
(list, optional) A list of callbacks defined with
|
accelerator |
(accelerator, optional) An optional |
verbose |
(logical, optional) An optional boolean value indicating if
the fitting procedure should emit output to the console during training.
By default, it will produce output if |
dataloader_options |
Options used when creating a dataloader. See
|
Other training:
evaluate()
,
fit.luz_module_generator()
,
setup()
This function is used to define hyper-parameters before calling fit
for
luz_modules
.
set_hparams(module, ...)
module |
An |
... |
The parameters set here will be used to initialize the |
The same luz module
Other set_hparam:
set_opt_hparams()
This function is used to define hyper-parameters for the optimizer initialization method.
set_opt_hparams(module, ...)
module |
An |
... |
The parameters passed here will be used to initialize the optimizers.
For example, if your optimizer is |
The same luz module
Other set_hparam:
set_hparams()
nn_module
to use with luz
The setup function is used to set important attributes and methods for nn_modules
to be used with luz.
setup(module, loss = NULL, optimizer = NULL, metrics = NULL, backward = NULL)
module |
( |
loss |
( |
optimizer |
( |
metrics |
( |
backward |
( |
It makes sure the module has all the necessary ingredients in order to be fitted.
A luz module that can be trained with fit()
.
It also adds a device
active field that can be used to query the current
module device
within methods, with eg self$device
. This is useful when
ctx()
is not available, eg, when calling methods from outside the luz
wrappers. Users can override the default by implementing a device
active
method in the input module
.
Other training:
evaluate()
,
fit.luz_module_generator()
,
predict.luz_module_fitted()