
Loads a CISSVAE model previously saved by save_cissvae_model. When the model was saved with method = "state_dict" (the default), the architecture is automatically reconstructed from the paired .config.rds file and the weights are loaded via load_state_dict. When the model was saved with method = "full", the entire Python object is deserialised directly.

Usage

load_cissvae_model(
  file,
  method = c("state_dict", "full"),
  device = "cpu",
  python_env = NULL
)

Arguments

file

Character string. Path to the saved model file (.pt). For method = "state_dict", a paired <file>.config.rds must exist in the same location (written automatically by save_cissvae_model).
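You can check for the paired file before loading. The sketch below assumes the config name is formed by appending `.config.rds` to the full path, so `model.pt` pairs with `model.pt.config.rds`. Both helper names are hypothetical and are not part of the package.

```r
# Hypothetical helper (not exported by the package): build the paired config
# path by appending ".config.rds" to the full .pt path, as "<file>.config.rds"
# in the argument description suggests.
cissvae_config_path <- function(file) {
  paste0(file, ".config.rds")
}

# Hypothetical check: TRUE only when both the weights file and its paired
# config file exist, i.e. when a state_dict load has what it needs.
has_state_dict_files <- function(file) {
  file.exists(file) && file.exists(cissvae_config_path(file))
}

cissvae_config_path("model.pt")  # "model.pt.config.rds"
```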

method

Character string. One of "state_dict" (default) or "full". Must match the method used when saving.
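The `c("state_dict", "full")` default in Usage follows a common base R pattern: `match.arg()` returns the first element when the caller supplies nothing and rejects any value not in the list. The illustration below assumes `load_cissvae_model()` resolves `method` this way and uses a hypothetical function name. This kind of check only validates the string. It cannot detect a mismatch with the method used at save time.

```r
# Hypothetical illustration of resolving a c("state_dict", "full") default
# with base R's match.arg(); not taken from the package source.
resolve_load_method <- function(method = c("state_dict", "full")) {
  match.arg(method)
}

resolve_load_method()        # "state_dict" (first choice is the default)
resolve_load_method("full")  # "full"
# resolve_load_method("pickle") raises an error: 'arg' should be one of ...
```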

device

Character string. PyTorch device string passed to torch.load as map_location. Defaults to "cpu". Use "cuda" to load the model directly onto the GPU.
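A script may run on machines both with and without a GPU. In that case, you can choose the device at run time. The sketch below assumes torch is importable through reticulate and falls back to "cpu" otherwise. `pick_cissvae_device()` is a hypothetical helper and is not part of the package.

```r
# Hypothetical helper: return "cuda" when reticulate can import torch and
# torch reports a usable CUDA device, otherwise "cpu".
pick_cissvae_device <- function() {
  if (requireNamespace("reticulate", quietly = TRUE) &&
      reticulate::py_module_available("torch")) {
    torch <- reticulate::import("torch")
    if (isTRUE(torch$cuda$is_available())) return("cuda")
  }
  "cpu"
}
```

Usage: `model <- load_cissvae_model(file = pt_file, device = pick_cissvae_device())`.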

python_env

Optional character string. Name of a virtualenv or conda environment to activate before loading. If NULL (default), the currently active Python session is used.

Value

A CISSVAE Python object in evaluation mode (model.eval() has been called).

Details

The CISSVAE class is imported from the installed ciss_vae.classes module using reticulate::import().
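For orientation, the state_dict path described above can be approximated with reticulate. This is a hypothetical sketch and not the package's implementation. It assumes the `.config.rds` file holds a named list of `CISSVAE` constructor arguments, which this page does not state. The torch calls used (`torch.load` with `map_location`, `load_state_dict()`, `to()`, `eval()`) are standard PyTorch.

```r
# Hypothetical sketch of the "state_dict" loading path; NOT the package code.
# Assumption: the .config.rds file stores a named list of CISSVAE constructor
# arguments.
load_state_dict_sketch <- function(file, device = "cpu") {
  config_file <- paste0(file, ".config.rds")
  if (!file.exists(config_file)) {
    stop("Missing paired config file: ", config_file)
  }
  config <- readRDS(config_file)

  # convert = FALSE keeps the loaded state dict as a Python object
  torch   <- reticulate::import("torch", convert = FALSE)
  classes <- reticulate::import("ciss_vae.classes")

  # Rebuild the architecture (assumed 1:1 mapping of config to constructor)
  model <- do.call(classes$CISSVAE, config)

  # Load tensors onto the requested device and copy them into the model
  state <- torch$load(file, map_location = device)
  model$load_state_dict(state)
  model$to(device)

  # Switch to evaluation mode before returning
  model$eval()
  model
}
```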

Examples

# \donttest{
try({
  reticulate::use_virtualenv("cissvae_environment", required = TRUE)

  pt_file <- tempfile(fileext = ".pt")

  # save first (assumes `dat$model` holds a trained CISSVAE model from an earlier run)
  save_cissvae_model(dat$model, file = pt_file)

  # reload in same or new session
  model <- load_cissvae_model(file = pt_file, device = "cpu")
})
# }