
Saves a trained CISSVAE PyTorch model to disk. The default method (method = "state_dict") saves only the model weights, together with an automatically captured architecture config; this is the recommended approach for portability and long-term reproducibility. A full-object save (method = "full") is also available but is less portable across Python and PyTorch versions.

Usage

save_cissvae_model(
  model,
  file,
  method = c("state_dict", "full"),
  overwrite = FALSE
)

Arguments

model

A CISSVAE Python object (a reticulate proxy), e.g., res$model from run_cissvae(..., return_model = TRUE).

file

Character string. File path for the saved model. For method = "state_dict", two files are written:

  • <file> — the PyTorch state_dict checkpoint (.pt).

  • <file>.config.rds — the architecture config needed to reconstruct the model on load.

For method = "full", only <file> is written.
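The file-naming rule above can be sketched in base R. cissvae_output_files() is a hypothetical helper written only for illustration, not part of the package; it mirrors the documented naming and writes nothing to disk.

```r
# Hypothetical helper (not part of the package API): returns the paths that
# save_cissvae_model() is documented to write for each method.
cissvae_output_files <- function(file, method = c("state_dict", "full")) {
  # match.arg() returns the first choice ("state_dict") when method is omitted
  method <- match.arg(method)
  if (method == "state_dict") {
    # weights checkpoint plus the architecture config sidecar
    c(file, paste0(file, ".config.rds"))
  } else {
    # a full-object save writes a single file
    file
  }
}

cissvae_output_files("model.pt")                   # "model.pt", "model.pt.config.rds"
cissvae_output_files("model.pt", method = "full")  # "model.pt"
```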

method

Character string. One of "state_dict" (default) or "full".

"state_dict"

Saves weights and architecture config separately. Recommended. Requires load_cissvae_model() to reload.

"full"

Saves the entire Python object with torch.save. Less portable; may break across PyTorch versions.
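Why the "state_dict" method needs a separate config can be shown with a toy base-R analogy. This is illustrative only: make_toy_model() is a hypothetical constructor standing in for the CISSVAE architecture, and saveRDS() stands in for the PyTorch checkpoint. The weights alone do not describe how to rebuild the model, so reloading first reconstructs the architecture from the config and then restores the weights.

```r
# Toy analogy only: make_toy_model() is hypothetical and unrelated to PyTorch.
make_toy_model <- function(config) {
  list(config  = config,
       weights = matrix(0, nrow = config$n_in, ncol = config$n_out))
}

model <- make_toy_model(list(n_in = 3, n_out = 2))
model$weights[] <- 1:6  # pretend these are trained weights

# "state_dict"-style save: weights and architecture config in separate files
f <- tempfile(fileext = ".rds")
saveRDS(model$weights, f)
saveRDS(model$config, paste0(f, ".config.rds"))

# Reload: rebuild the architecture from the config, then restore the weights
restored <- make_toy_model(readRDS(paste0(f, ".config.rds")))
restored$weights[] <- readRDS(f)
identical(restored, model)  # TRUE
```

A "full"-style save would instead serialise the whole object in one file, which is simpler but ties the file to the exact object layout of the software that wrote it.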

overwrite

Logical. If FALSE (default), an error is raised when output files already exist.
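The overwrite guard can be sketched as follows. check_can_write() is a hypothetical helper, not the package's implementation, and it assumes the check fails when any of the target files (for "state_dict", either the checkpoint or the .config.rds) already exists.

```r
# Hypothetical sketch, not the package's implementation: assumes the check
# fails when ANY of the target files already exists.
check_can_write <- function(paths, overwrite = FALSE) {
  existing <- paths[file.exists(paths)]
  if (!isTRUE(overwrite) && length(existing) > 0) {
    stop("Output file(s) already exist: ", paste(existing, collapse = ", "),
         "; set overwrite = TRUE to replace them.", call. = FALSE)
  }
  invisible(TRUE)
}

f <- tempfile(fileext = ".pt")
paths <- c(f, paste0(f, ".config.rds"))    # the two "state_dict" outputs

check_can_write(paths)                     # passes: nothing written yet
invisible(file.create(paths[2]))           # simulate a leftover config file
try(check_can_write(paths))                # error: the .config.rds exists
check_can_write(paths, overwrite = TRUE)   # passes: overwriting allowed
```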

Value

NULL, invisibly. Called for its side effects.

Examples

# \donttest{
try({
  reticulate::use_virtualenv("cissvae_environment", required = TRUE)
  data(df_missing)
  data(clusters)

  dat <- run_cissvae(
    data             = df_missing,
    index_col        = "index",
    val_proportion   = 0.1,
    clusters         = clusters$clusters,
    epochs           = 5,
    return_model     = TRUE,
    device           = "cpu",
    layer_order_enc  = c("unshared", "shared", "unshared"),
    layer_order_dec  = c("shared",   "unshared", "shared")
  )

  # default: saves state_dict + config
  save_cissvae_model(dat$model, file = tempfile(fileext = ".pt"))

  # alternative: full object save
  save_cissvae_model(dat$model, file = tempfile(fileext = ".pt"),
                     method = "full")
})
# }