Saving and Using Saved Models
2026-04-08
Source: vignettes/impute_with_saved_model.rmd
This vignette shows how to save a trained CISS-VAE model, reload it in a later R session, and use it to impute new data.
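As of rCISSVAE version 0.0.6, the recommended workflow is to save the
model as a PyTorch state_dict together with an
automatically generated architecture config file. This is more portable
across Python and PyTorch versions and more robust than saving the full live
Python object: a state_dict holds only the model's parameter tensors, whereas
a full save pickles the Python object itself and can only be loaded when
compatible class definitions are importable.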
Saving a trained CISS-VAE model
library(reticulate)
library(rCISSVAE)
# Train a model
res <- run_cissvae(data, return_model = TRUE)
# Save the trained model to disk.
# With method = "state_dict", this writes two files:
#   - trained_vae.pt: the model weights
#   - trained_vae.pt.config.rds: the model configuration (architecture) needed to rebuild the model when loading
# IMPORTANT: the Python environment must be active so that 'torch' can be imported.
save_cissvae_model(res$model, "trained_vae.pt", method = "state_dict")
To save the full model object as a single .pt file instead, set method = "full".
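Because a state_dict save produces two files, both have to travel together when you move a model to another machine or project: the loader needs the config file to rebuild the architecture before it can apply the weights. The sketch below is a hypothetical helper (`copy_cissvae_model()` is not part of rCISSVAE). It assumes the config file is always named after the weights file with `.config.rds` appended, as in the save example above.

```r
# Hypothetical helper, not part of rCISSVAE: copy a state_dict-saved model
# (weights file plus its ".config.rds" companion) into another directory.
# Assumes the config file is named "<weights file>.config.rds".
copy_cissvae_model <- function(file, dest_dir) {
  files <- c(file, paste0(file, ".config.rds"))

  # Refuse to copy a half-saved model: both files are needed to reload it
  missing <- files[!file.exists(files)]
  if (length(missing) > 0) {
    stop("Missing model file(s): ", paste(missing, collapse = ", "))
  }

  dir.create(dest_dir, showWarnings = FALSE, recursive = TRUE)
  ok <- file.copy(files, dest_dir, overwrite = TRUE)
  if (!all(ok)) {
    stop("Failed to copy: ", paste(files[!ok], collapse = ", "))
  }

  # Return the new paths; the first one is the weights file to pass to
  # load_cissvae_model(file = ...)
  file.path(dest_dir, basename(files))
}
```

For example, `new_paths <- copy_cissvae_model("trained_vae.pt", "models")` followed by `load_cissvae_model(file = new_paths[1], method = "state_dict", device = "cpu")`. Checking for both files up front avoids discovering a missing config only at load time.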
Loading a saved model and imputing data
library(rCISSVAE)
library(reticulate)
## Activate your Python environment
reticulate::use_virtualenv("cissvae_environment", required = TRUE)
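## Optional check (a sketch, not an rCISSVAE function): confirm that 'torch'
## can be imported from the environment activated above. reticulate's
## py_module_available() returns TRUE or FALSE; tryCatch() also falls back to
## FALSE if the check itself errors. A FALSE result only raises a warning here.
torch_ok <- tryCatch(
  reticulate::py_module_available("torch"),
  error = function(e) FALSE
)
if (!isTRUE(torch_ok)) {
  warning("'torch' could not be imported; check the active Python environment")
}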
## Load the saved model (rebuilt from the weights and the .config.rds file)
model <- load_cissvae_model(
file = "trained_vae.pt",
method = "state_dict", ## use "full" if the model was saved with method = "full"
device = "cpu"
)
## Perform imputation on new data
## Make sure missing values in `data` are coded as NA and the `clusters` vector is ready
## `val_proportion`, `categorical_column_map`, and `replacement_value` are not needed here because we are only imputing, not training
imputed_df <- impute_with_cissvae(
model_py = model,
data = data,
index_col = "index",
columns_ignore = NULL,
clusters = clusters,
imputable_matrix = NULL,
binary_feature_mask = NULL,
batch_size = 4000L,
seed = 42
)
# `imputed_df` is returned to R as a data.frame
If your dataset contains binary variables, define binary_feature_mask in
your impute_with_cissvae() call. After imputation, convert the imputed
probabilities for those variables to {0, 1} values using a threshold of
your choice.
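A minimal sketch of that post-processing step follows. It assumes the binary columns of `imputed_df` hold probabilities in [0, 1] and that you know their names; `threshold_binary()`, the toy data, and the 0.5 cutoff are illustrative choices, not part of rCISSVAE. Values that are already 0 or 1 are unchanged by the threshold.

```r
# Hypothetical helper (not part of rCISSVAE): convert imputed probabilities
# in the named binary columns to 0/1 integers using a cutoff.
threshold_binary <- function(df, binary_cols, cutoff = 0.5) {
  stopifnot(all(binary_cols %in% names(df)))
  df[binary_cols] <- lapply(df[binary_cols], function(p) {
    as.integer(p >= cutoff)  # >= cutoff becomes 1, otherwise 0; NA stays NA
  })
  df
}

# Toy example with made-up values; only the "smoker" column is thresholded
toy <- data.frame(index = 1:3, smoker = c(0.91, 0.12, 0.5), age = c(40, 52, 61))
threshold_binary(toy, "smoker")  # smoker becomes c(1L, 0L, 1L); age is untouched
```

A cutoff of 0.5 treats both outcomes symmetrically; if one class is rare or misclassification costs differ, a different cutoff may be more appropriate.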