Loads a CISSVAE model previously saved by
save_cissvae_model. When the model was saved with
method = "state_dict" (the default), the architecture is
automatically reconstructed from the paired .config.rds file and
the weights are loaded via load_state_dict. When the model was saved
with method = "full", the entire Python object is deserialised
directly.
Usage
load_cissvae_model(
  file,
  method = c("state_dict", "full"),
  device = "cpu",
  python_env = NULL
)
Arguments
- file
Character string. Path to the saved model file (.pt). For method = "state_dict", a paired <file>.config.rds must exist in the same location (written automatically by save_cissvae_model).
- method
Character string. One of "state_dict" (default) or "full". Must match the method used when saving.
- device
Character string. PyTorch device string passed to torch.load as map_location. Defaults to "cpu". Use "cuda" to load directly onto GPU.
- python_env
Optional character string. Name of a virtualenv or conda environment to activate before loading. If NULL (default), the currently active Python session is used.
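The file requirements described above can be checked before loading. The following is a minimal sketch in base R, not part of the package: check_saved_model is a hypothetical helper, and it assumes the paired config is named by appending ".config.rds" to the full model path, which is one reading of the <file>.config.rds notation.

```r
# Hypothetical helper (not provided by the package): verifies that the files
# load_cissvae_model() is documented to need are present on disk.
check_saved_model <- function(file, method = c("state_dict", "full")) {
  method <- match.arg(method)
  if (!file.exists(file)) {
    stop("Model file not found: ", file)
  }
  if (method == "state_dict") {
    # Assumption: the config sits next to the model as "<file>.config.rds".
    config <- paste0(file, ".config.rds")
    if (!file.exists(config)) {
      stop("Paired config not found: ", config)
    }
  }
  invisible(TRUE)
}
```

Calling check_saved_model(pt_file) before load_cissvae_model(pt_file) would surface a missing config as a plain R error rather than a failure during loading.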
Details
The CISSVAE class is imported from the installed
ciss_vae.classes module using reticulate::import.
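Because the class is imported from Python, it can help to confirm that the module is importable in the active environment first. A small sketch, assuming the top-level module name is ciss_vae (taken from ciss_vae.classes above). cissvae_python_ready is a hypothetical helper, not a package function.

```r
# Sketch: confirm the Python module is importable before loading a model.
# Returns FALSE when reticulate is missing or the module cannot be imported.
cissvae_python_ready <- function(module = "ciss_vae") {
  if (!requireNamespace("reticulate", quietly = TRUE)) {
    return(FALSE)
  }
  reticulate::py_module_available(module)
}
```

Note that this check initialises Python for the session, so it is best called after the intended environment has been selected (for example via python_env or reticulate::use_virtualenv).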
Examples
# \donttest{
try({
  reticulate::use_virtualenv("cissvae_environment", required = TRUE)
  pt_file <- tempfile(fileext = ".pt")
  # save first; `dat$model` is assumed to be a trained CISSVAE model
  # created in an earlier step (not shown here)
  save_cissvae_model(dat$model, file = pt_file)
  # reload in the same or a new session
  model <- load_cissvae_model(file = pt_file, device = "cpu")
})
#> Error in reticulate::use_virtualenv("cissvae_environment", required = TRUE) :
#> Directory ~/.virtualenvs/cissvae_environment is not a Python virtualenv
# }