---
title: "TorchScript"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{TorchScript}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  eval = identical(Sys.getenv("TORCH_TEST", unset = "0"), "1"),
  purl = FALSE
)
```

```{r setup}
library(torch)
```

[TorchScript](https://pytorch.org/docs/stable/jit_language_reference.html#language-reference) is a statically typed subset of Python that can be interpreted by LibTorch without any Python dependency. The torch R package provides interfaces to create, serialize, load and execute TorchScript programs.

Advantages of using TorchScript are:

- TorchScript code can be invoked in its own interpreter, which is essentially a restricted Python interpreter. This interpreter does not acquire the Global Interpreter Lock, so many requests can be processed on the same instance simultaneously.

- This format allows us to save the whole model to disk and load it into another environment, such as a server written in a language other than R.

- TorchScript gives us a representation in which we can do compiler optimizations on the code to make execution more efficient.

- TorchScript allows us to interface with many backend/device runtimes that require a broader view of the program than individual operators.

## Creating TorchScript programs

### Tracing

TorchScript programs can be created from R using tracing. When using tracing, code is automatically converted into this subset of Python by recording only the actual operators on tensors and simply executing and discarding the other surrounding R code.

Currently tracing is the only supported way to create TorchScript programs from R code.

For example, let's use the `jit_trace` function to create a TorchScript program. We pass a regular R function and example inputs.


```{r}
fn <- function(x) {
  torch_relu(x)
}

traced_fn <- jit_trace(fn, torch_tensor(c(-1, 0, 1)))
```

The `jit_trace` function has executed the R function with the example input and recorded all torch operations that occurred during execution to create a *graph*. The *graph* is what we call the intermediate representation of TorchScript programs, and it can be inspected with:

```{r}
traced_fn$graph
```

The traced function can now be invoked as a regular R function:

```{r}
traced_fn(torch_randn(3))
```

It's also possible to trace an `nn_module()` defined in R, for example:

```{r}
module <- nn_module(
  initialize = function() {
    self$linear1 <- nn_linear(10, 10)
    self$linear2 <- nn_linear(10, 1)
  },
  forward = function(x) {
    x %>%
      self$linear1() %>%
      nnf_relu() %>%
      self$linear2()
  }
)
traced_module <- jit_trace(module(), torch_randn(10, 10))
```

When using `jit_trace` with an `nn_module`, only the `forward` method is traced. You can use the `jit_trace_module` function to pass example inputs to other methods. Traced modules look like normal `nn_module`s and can be called the same way:

```{r}
traced_module(torch_randn(3, 10))
```

#### Limitations of tracing

1.  Tracing does not record control flow such as if-statements or loops. When the control flow is constant across your module, this is fine: the trace simply inlines the decisions taken for the example input. But sometimes the control flow is actually part of the model itself. For instance, a recurrent network is a loop over the (possibly dynamic) length of an input sequence. For example:

    ```{r, error=TRUE}
    # fn applies an operation to each slice along the first dimension of a tensor
    fn <- function(x) {
      x %>%
        torch_unbind(dim = 1) %>%
        lapply(function(x) x$sum()) %>%
        torch_stack(dim = 1)
    }
    # trace using an example tensor of size (10, 5, 5)
    traced_fn <- jit_trace(fn, torch_randn(10, 5, 5))
    # applying it to a tensor with a different size returns an error.
    traced_fn(torch_randn(11, 5, 5))
    ```

2.  In the returned `ScriptModule`, operations that behave differently in training and eval modes will always behave as if the module were in the mode it was in during tracing, no matter which mode the `ScriptModule` is in. For example:

    ```{r}
    traced_dropout <- jit_trace(nn_dropout(), torch_ones(5,5))
    traced_dropout(torch_ones(3,3))
    traced_dropout$eval()
    # even after setting to eval mode, dropout is applied
    traced_dropout(torch_ones(3,3))
    ```

3.  Traced programs can only take tensors and lists of tensors as input and return tensors and lists of tensors. For example:

    ```{r, error = TRUE}
    fn <- function(x, y) {
      x + y
    }
    jit_trace(fn, torch_tensor(1), 1)
    ```

### Compiling TorchScript

It's also possible to create TorchScript programs by compiling TorchScript code. TorchScript code looks a lot like standard Python code. For example:

```{r}
tr <- jit_compile("
def fn (x: Tensor):
  return torch.relu(x)
")
tr$fn(torch_tensor(c(-1, 0, 1)))
```

## Serializing and loading

TorchScript programs can be serialized using the `jit_save` function and loaded back from disk with `jit_load`.

For example:

```{r}
fn <- function(x) {
  torch_relu(x)
}
tr_fn <- jit_trace(fn, torch_tensor(1))
jit_save(tr_fn, "path.pt")
loaded <- jit_load("path.pt")
```

Loaded programs can be executed as usual:

```{r}
loaded(torch_tensor(c(-1, 0, 1)))
```

**Note** You can load TorchScript programs that were created in libraries other than `torch` for R. For example, a TorchScript program can be created in PyTorch with `torch.jit.trace` or `torch.jit.script`, and run from R.

R objects are automatically converted to their TorchScript counterpart following the Types table in this document. However, sometimes it's necessary to make type annotations with `jit_tuple()` and `jit_scalar()` to disambiguate the conversion.

## Types

The following table lists all TorchScript types and how to convert them to and from R.


+---------------------------+---------------------------------------------------------------------------------------+
| TorchScript Type          | R Description                                                                         |
+===========================+=======================================================================================+
| `Tensor`                  | A `torch_tensor` with any shape, dtype or backend.                                    |
+---------------------------+---------------------------------------------------------------------------------------+
| `Tuple[T0, T1, ..., TN]`  | A `list()` containing subtypes `T0`, `T1`, etc. wrapped with `jit_tuple()`.           |
+---------------------------+---------------------------------------------------------------------------------------+
| `bool`                    | A scalar logical value created using `jit_scalar`.                                    |
+---------------------------+---------------------------------------------------------------------------------------+
| `int`                     | A scalar integer value created using `jit_scalar`.                                    |
+---------------------------+---------------------------------------------------------------------------------------+
| `float`                   | A scalar floating-point value created using `jit_scalar`.                             |
+---------------------------+---------------------------------------------------------------------------------------+
| `str`                     | A string (i.e. character vector of length 1) wrapped in `jit_scalar`.                 |
+---------------------------+---------------------------------------------------------------------------------------+
| `List[T]`                 | An R list whose elements all have type `T`, or a numeric vector, logical vector, etc. |
+---------------------------+---------------------------------------------------------------------------------------+
| `Optional[T]`             | Not yet supported.                                                                    |
+---------------------------+---------------------------------------------------------------------------------------+
| `Dict[str, V]`            | A named list with values of type `V`. Only `str` keys are currently supported.        |
+---------------------------+---------------------------------------------------------------------------------------+
| `T`                       | Not yet supported.                                                                    |
+---------------------------+---------------------------------------------------------------------------------------+
| `E`                       | Not yet supported.                                                                    |
+---------------------------+---------------------------------------------------------------------------------------+
| `NamedTuple[T0, T1, ...]` | A named list containing subtypes `T0`, `T1`, etc. wrapped in `jit_tuple()`.           |
+---------------------------+---------------------------------------------------------------------------------------+
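
As an illustration of the `float` row in the table, the following is a minimal sketch of passing an R scalar to a compiled TorchScript function. The function `scale` and its argument `factor` are hypothetical names chosen for this example, and the sketch assumes that functions returned by `jit_compile` accept `jit_scalar()` inputs in the same way as other TorchScript programs.

```{r}
library(torch)

# `scale` and `factor` are hypothetical names used only for illustration.
typed <- jit_compile("
def scale(x: Tensor, factor: float):
  return x * factor
")

x <- torch_tensor(c(1, 2, 3))

# jit_scalar() marks the length-1 numeric as a TorchScript `float`.
# Without it, a numeric vector would be treated as a `List[float]`.
out <- typed$scale(x, jit_scalar(2.5))
out
```

Marking the scalar explicitly is what keeps the conversion unambiguous, since R has no separate scalar type and a length-1 numeric vector could otherwise match more than one row of the table.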