Custom loops with luz

library(torch)
library(luz)

Luz is a higher-level API for torch. It is designed to be highly flexible: its layered API stays useful no matter how much control you need over your training loop.

In the getting started vignette we have seen the basics of luz and how to quickly modify parts of the training loop using callbacks and custom metrics. In this document we will describe how luz allows the user to get fine-grained control of the training loop.

Apart from the use of callbacks, there are three more ways that you can use luz (depending on how much control you need):

  • Multiple optimizers or losses: You might be optimizing two loss functions, each with its own optimizer, but you still don’t want to modify the backward(), zero_grad() and step() calls. This is common in models like GANs (Generative Adversarial Networks), where competing neural networks are trained with different losses and optimizers.

  • Fully flexible steps: You might want to be in control of how backward(), zero_grad() and step() are called. You might also want more control over gradient computation. For example, you might want to use ‘virtual batch sizes’, where you accumulate the gradients for a few steps before updating the weights.

  • Completely flexible loops: Your training loop can be anything you want but you still want to use luz to handle device placement of the dataloaders, optimizers and models. See vignette("accelerator").

Let’s consider a simplified version of the net that we implemented in the getting started vignette:

net <- nn_module(
  "Net",
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc2 <- nn_linear(50, 10)
  },
  forward = function(x) {
    x %>% 
      self$fc1() %>% 
      nnf_relu() %>% 
      self$fc2()
  }
)

Using the highest-level luz API, we would fit it with:

fitted <- net %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam,
    metrics = list(
      luz_metric_accuracy
    )
  ) %>%
  fit(train_dl, epochs = 10, valid_data = test_dl)
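The fit() calls in this document take train_dl and test_dl, which are not defined here. As a stand-in, a minimal sketch with random data could look like the following. The helper name make_dl, the sample counts, and the batch size of 32 are assumptions, chosen only to match the 100 input features and 10 classes of the network above.

```r
library(torch)

# Hypothetical helper (not part of luz or torch): random features with
# 100 columns and integer class labels in 1..10. torch for R expects
# 1-based class indices for cross-entropy targets.
make_dl <- function(n, batch_size = 32) {
  x <- torch_randn(n, 100)
  y <- torch_tensor(sample(1:10, n, replace = TRUE), dtype = torch_long())
  dataloader(tensor_dataset(x, y), batch_size = batch_size)
}

train_dl <- make_dl(800)
test_dl <- make_dl(200)

# Peek at one batch: a list with the inputs first and the targets second.
batch <- dataloader_next(dataloader_make_iter(train_dl))
dim(batch[[1]])
```

Random labels carry no signal, so a model fitted on this data will not learn anything useful; the sketch only makes the later calls executable.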

Multiple optimizers

Suppose we want to do an experiment where we train the first fully connected layer using a learning rate of 0.1 and the second one using a learning rate of 0.01. We will minimize the same nn_cross_entropy_loss() for both, but for the first layer we want to add L1 regularization on the weights.

In order to use luz for this, we will implement two methods in the net module:

  • set_optimizers: returns a named list of optimizers depending on the ctx.

  • loss: computes the loss depending on the selected optimizer.

Let’s go to the code:

net <- nn_module(
  "Net",
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc2 <- nn_linear(50, 10)
  },
  forward = function(x) {
    x %>% 
      self$fc1() %>% 
      nnf_relu() %>% 
      self$fc2()
  },
  set_optimizers = function(lr_fc1 = 0.1, lr_fc2 = 0.01) {
    list(
      opt_fc1 = optim_adam(self$fc1$parameters, lr = lr_fc1),
      opt_fc2 = optim_adam(self$fc2$parameters, lr = lr_fc2)
    )
  },
  loss = function(input, target) {
    pred <- ctx$model(input)
  
    if (ctx$opt_name == "opt_fc1") 
      nnf_cross_entropy(pred, target) + torch_norm(self$fc1$weight, p = 1)
    else if (ctx$opt_name == "opt_fc2")
      nnf_cross_entropy(pred, target)
  }
)

Notice that the model optimizers will be initialized according to the set_optimizers() method’s return value (a list). In this case, we are initializing the optimizers using different model parameters and learning rates.
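To see what such a method returns, it can be called directly on an instantiated module, outside of luz. The sketch below assumes a module named two_layer (a hypothetical name) that defines only the layers and set_optimizers(); luz itself normally makes this call for you during fit().

```r
library(torch)

# Hypothetical module used only to inspect the optimizers list.
two_layer <- nn_module(
  "TwoLayer",
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc2 <- nn_linear(50, 10)
  },
  forward = function(x) {
    self$fc2(nnf_relu(self$fc1(x)))
  },
  set_optimizers = function(lr_fc1 = 0.1, lr_fc2 = 0.01) {
    list(
      opt_fc1 = optim_adam(self$fc1$parameters, lr = lr_fc1),
      opt_fc2 = optim_adam(self$fc2$parameters, lr = lr_fc2)
    )
  }
)

model <- two_layer()
opts <- model$set_optimizers()

# The list names are what later show up as optimizer names.
names(opts)

# Each optimizer holds only its own layer's weight and bias,
# together with the learning rate it was created with.
length(opts$opt_fc1$param_groups[[1]]$params)
opts$opt_fc1$param_groups[[1]]$lr
opts$opt_fc2$param_groups[[1]]$lr
```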

The loss() method is responsible for computing the loss that will then be back-propagated to compute gradients and update the weights. This loss() method can access the ctx object, which contains an opt_name field naming the optimizer currently in use. Note that this function will be called once per optimizer in every training and validation step. See help("ctx") for complete information about the context object.

Finally, we can set up and fit this module. We no longer need to specify optimizers and loss functions in setup(), because the module now provides them.
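The regularization term in the loss() method above is torch_norm(..., p = 1), which is the sum of the absolute values of the weights. A small sketch on a hand-made tensor shows the equivalence. The scaling coefficient lambda is a hypothetical addition for illustration; the model above adds the unscaled norm.

```r
library(torch)

# A 2 x 2 "weight" matrix with known entries: 1, -2, 3, -4.
w <- torch_tensor(matrix(c(1, -2, 3, -4), nrow = 2))

# L1 norm via torch_norm() and via an explicit sum of absolute values.
l1_a <- torch_norm(w, p = 1)$item()
l1_b <- w$abs()$sum()$item()

# Hypothetical: scale the penalty so it does not dominate the
# cross-entropy term. The value 0.01 is an assumption.
lambda <- 0.01
penalty <- lambda * torch_norm(w, p = 1)
```

Both l1_a and l1_b equal 10 here, since |1| + |-2| + |3| + |-4| = 10.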

fitted <- net %>% 
  setup(metrics = list(luz_metric_accuracy)) %>% 
  fit(train_dl, epochs = 10, valid_data = test_dl)

Now let’s re-implement this same model using the slightly more flexible approach of overriding the training and validation step.

Fully flexible step

Instead of implementing the loss() method, we can implement the step() method. This allows us to flexibly modify what happens when training and validating for each batch in the dataset. You are now responsible for updating the weights by stepping the optimizers and back-propagating the loss.

net <- nn_module(
  "Net",
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc2 <- nn_linear(50, 10)
  },
  forward = function(x) {
    x %>% 
      self$fc1() %>% 
      nnf_relu() %>% 
      self$fc2()
  },
  set_optimizers = function(lr_fc1 = 0.1, lr_fc2 = 0.01) {
    list(
      opt_fc1 = optim_adam(self$fc1$parameters, lr = lr_fc1),
      opt_fc2 = optim_adam(self$fc2$parameters, lr = lr_fc2)
    )
  },
  step = function() {
    ctx$loss <- list()
    for (opt_name in names(ctx$optimizers)) {
    
      # store predictions in ctx$pred so that luz metrics can use them
      ctx$pred <- ctx$model(ctx$input)
      opt <- ctx$optimizers[[opt_name]]
      loss <- nnf_cross_entropy(ctx$pred, ctx$target)
      
      if (opt_name == "opt_fc1") {
        # we have L1 regularization in layer 1
        loss <- loss + torch_norm(self$fc1$weight, p = 1)
      }
        
      if (ctx$training) {
        opt$zero_grad()
        loss$backward()
        opt$step()  
      }
      
      ctx$loss[[opt_name]] <- loss$detach()
    }
  }
)

The important things to notice here are:

  • The step() method is used for both training and validation. You need to be careful to only modify the weights when training. Again, you can get complete information regarding the context object using help("ctx").

  • ctx$optimizers is a named list holding each optimizer that was created when the set_optimizers() method was called.

  • You need to manually track the losses by saving them in a named list in ctx$loss. By convention, each entry uses the name of the optimizer it refers to. It is good practice to detach() the losses before saving them, to reduce memory usage.

  • Callbacks that would be called inside the default step() method, like on_train_batch_after_pred, on_train_batch_after_loss, etc., won’t be called automatically. You can still call them manually by adding ctx$call_callbacks("<callback name>") inside your training step. See the code for fit_one_batch() and valid_one_batch() to find all the callbacks that won’t be called.

  • If you want luz metrics to work with your custom step() method, you must assign ctx$pred with the model predictions as metrics will always be called with metric$update(ctx$pred, ctx$target).
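One use of a custom step() mentioned at the start of this document is ‘virtual batch sizes’, that is, gradient accumulation. The sketch below shows the underlying mechanism in plain torch, outside of luz, and checks that accumulating gradients over two micro-batches reproduces the gradient of the full batch. The tiny linear model, the random data, the SGD optimizer and the choice of two micro-batches are all assumptions made for illustration.

```r
library(torch)
torch_manual_seed(1)

model <- nn_linear(4, 1)
opt <- optim_sgd(model$parameters, lr = 0.1)
x <- torch_randn(8, 4)
y <- torch_randn(8)

# Reference: gradient of the mean loss over the full batch of 8.
opt$zero_grad()
nnf_mse_loss(model(x)$squeeze(2), y)$backward()
full_grad <- model$weight$grad$clone()

# Accumulation: two micro-batches of 4. backward() adds to the existing
# gradients, so we only zero them once, before the first micro-batch.
# Dividing each loss by the number of micro-batches keeps the scale
# equal to the full-batch mean.
opt$zero_grad()
n_accum <- 2
for (idx in list(1:4, 5:8)) {
  loss <- nnf_mse_loss(model(x[idx, ])$squeeze(2), y[idx]) / n_accum
  loss$backward()
}
acc_grad <- model$weight$grad$clone()

# The two gradients agree up to floating point error.
max_diff <- max(abs(as_array(full_grad - acc_grad)))

# Only now update the weights, once per virtual batch.
opt$step()
```

Inside a luz step() the same pattern would mean calling loss$backward() on every batch but calling opt$step() and opt$zero_grad() only every few batches; how the batch counter is tracked is left open here.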

Next steps

In this article you learned how to customize the step() of your training loop using luz’s layered functionality.

Luz also allows even more flexible modifications of the training loop, which are described in the accelerator vignette (vignette("accelerator")).

You should now be able to follow the examples marked with the ‘intermediate’ and ‘advanced’ category in the examples gallery.