Performs hyperparameter optimization for CISS-VAE using Optuna with support for both tunable and fixed parameters.
Usage
autotune_cissvae(
data,
index_col = NULL,
val_proportion = 0.1,
replacement_value = 0,
columns_ignore = NULL,
imputable_matrix = NULL,
binary_feature_mask = NULL,
clusters,
save_model_path = NULL,
save_search_space_path = NULL,
n_trials = 20,
study_name = "vae_autotune",
device_preference = "cuda",
show_progress = FALSE,
optuna_dashboard_db = NULL,
load_if_exists = TRUE,
seed = 42,
verbose = FALSE,
constant_layer_size = FALSE,
evaluate_all_orders = FALSE,
max_exhaustive_orders = 100,
num_hidden_layers = c(1, 4),
hidden_dims = c(64, 512),
latent_dim = c(10, 100),
latent_shared = c(TRUE, FALSE),
output_shared = c(TRUE, FALSE),
lr = c(1e-04, 0.001),
decay_factor = c(0.9, 0.999),
weight_decay = 0.001,
beta = 0.01,
num_epochs = 500,
batch_size = 4000,
num_shared_encode = c(0, 1, 3),
num_shared_decode = c(0, 1, 3),
encoder_shared_placement = c("at_end", "at_start", "alternating", "random"),
decoder_shared_placement = c("at_start", "at_end", "alternating", "random"),
refit_patience = 2,
refit_loops = 100,
epochs_per_loop = 500,
reset_lr_refit = c(TRUE, FALSE),
debug = FALSE
)
Arguments
- data
Data frame or matrix containing the input data
- index_col
String name of index column to preserve (optional)
- val_proportion
Proportion of non-missing data to hold out for validation.
- replacement_value
Numeric value used to replace missing entries before model input.
- columns_ignore
Character vector of column names to exclude from imputation scoring.
- imputable_matrix
Logical matrix indicating which entries are allowed to be imputed (see the construction sketch after this argument list).
- binary_feature_mask
Logical vector marking which columns are binary (see the construction sketch after this argument list).
- clusters
Integer vector specifying cluster assignments for each row.
- save_model_path
Optional path to save the best model's state_dict
- save_search_space_path
Optional path to save search space configuration
- n_trials
Number of Optuna trials to run
- study_name
Name identifier for the Optuna study
- device_preference
Preferred device ("cuda", "mps", "cpu")
- show_progress
Whether to display Rich progress bars during training
- optuna_dashboard_db
RDB storage URL/file for Optuna dashboard
- load_if_exists
Whether to load existing study from storage
- seed
Base random seed for reproducible results
- verbose
Whether to print detailed diagnostic information
- constant_layer_size
Whether all hidden layers use same dimension
- evaluate_all_orders
Whether to test all possible layer arrangements
- max_exhaustive_orders
Max arrangements to test when evaluate_all_orders = TRUE
- num_hidden_layers
Numeric(2) vector: (min, max) for number of hidden layers
- hidden_dims
Numeric vector: hidden layer dimensions to test
- latent_dim
Numeric(2) vector: (min, max) for latent dimension
- latent_shared
Logical vector: whether latent space is shared across clusters
- output_shared
Logical vector: whether output layer is shared across clusters
- lr
Numeric(2) vector: (min, max) learning rate range
- decay_factor
Numeric(2) vector: (min, max) LR decay factor range
- weight_decay
Weight decay (L2 penalty) used in the Adam optimizer.
- beta
Numeric: KL divergence weight (fixed or range)
- num_epochs
Integer: number of initial training epochs (fixed or range)
- batch_size
Integer: mini-batch size (fixed or range)
- num_shared_encode
Numeric vector: numbers of shared encoder layers to test
- num_shared_decode
Numeric vector: numbers of shared decoder layers to test
- encoder_shared_placement
Character vector: placement strategies for encoder shared layers
- decoder_shared_placement
Character vector: placement strategies for decoder shared layers
- refit_patience
Integer: early stopping patience for refit loops
- refit_loops
Integer: maximum number of refit loops
- epochs_per_loop
Integer: epochs per refit loop
- reset_lr_refit
Logical vector: whether to reset LR before refit
- debug
Logical; if TRUE, additional metadata is returned for debugging.
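The two optional mask arguments are easiest to see in code. Below is a minimal sketch, not the package's documented recipe: it assumes imputable_matrix shares the dimensions of data, treats the packaged df_missing as a data frame, and builds binary_feature_mask as a named logical vector over its columns (see the package vignettes for the authoritative layout).
data(df_missing)
## Allow every entry to be imputed except those in the index column
imputable <- matrix(
  TRUE,
  nrow = nrow(df_missing), ncol = ncol(df_missing),
  dimnames = list(NULL, colnames(df_missing))
)
imputable[, "index"] <- FALSE
## Flag columns whose non-missing values are all 0/1 as binary (illustrative rule)
binary_mask <- vapply(
  df_missing,
  function(x) all(na.omit(x) %in% c(0, 1)),
  logical(1)
)
names(binary_mask) <- colnames(df_missing)   ## as recommended in Tips below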
Value
A named list with the following components (see the access sketch after this list):
- imputed_dataset
A data frame containing the imputed values.
- model
The fitted CISS-VAE model object
- cluster_dataset
The ClusterDataset object used
- clusters
The vector of cluster assignments
- study
An Optuna study object containing the trial results
- results
A data frame of trial results
- val_data
Validation dataset used
- val_imputed
Imputed values of the validation dataset
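For orientation, a short sketch of inspecting these components, assuming a finished call stored in aut as in the example below:
head(aut$imputed_dataset)   ## imputed data frame
aut$results                 ## data frame of per-trial results
aut$study                   ## Optuna study with the full trial history
head(aut$val_data)          ## held-out validation entries
head(aut$val_imputed)       ## their imputed counterparts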
Tips
- Use cluster_on_missing() or cluster_on_missing_prop() for cluster assignments.
- Use GPU computation when available; call check_devices() to see available devices.
- Adjust batch_size based on memory (larger is faster but uses more memory).
- Set verbose = TRUE or show_progress = TRUE to monitor training.
- Explore the optuna-dashboard (see the optunadb vignette) for hyperparameter importance.
- For binary features, set names(binary_feature_mask) <- colnames(data).
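A brief, hedged sketch of the helper calls mentioned above; their exact argument names are assumptions, so consult their individual help pages:
data(df_missing)
check_devices()                        ## list devices visible to the backend
cl <- cluster_on_missing(df_missing)   ## assumed call form; yields cluster assignments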
Examples
# \donttest{
## Requires a working Python environment via reticulate
## Examples are wrapped in try() to avoid failures on CRAN check systems
try({
reticulate::use_virtualenv("cissvae_environment", required = TRUE)
data(df_missing)
data(clusters)
## Run autotuning
aut <- autotune_cissvae(
data = df_missing,
index_col = "index",
clusters = clusters$clusters,
n_trials = 3,
study_name = "comprehensive_vae_autotune",
device_preference = "cpu",
seed = 42,
## Hyperparameter search space
num_hidden_layers = c(2, 5),
hidden_dims = c(64, 512),
latent_dim = c(10, 100),
latent_shared = c(TRUE, FALSE),
output_shared = c(TRUE, FALSE),
lr = c(0.01, 0.1),
decay_factor = c(0.99, 1.0),
beta = c(0.01, 0.1),
num_epochs = c(5, 20),
batch_size = c(1000, 4000),
num_shared_encode = c(0, 1, 2),
num_shared_decode = c(0, 1, 2),
## Placement strategies
encoder_shared_placement = c(
"at_end", "at_start",
"alternating", "random"
),
decoder_shared_placement = c(
"at_start", "at_end",
"alternating", "random"
),
refit_patience = 2,
refit_loops = 10,
epochs_per_loop = 5,
reset_lr_refit = c(TRUE, FALSE)
)
## Visualize architecture
plot_vae_architecture(
aut$model,
title = "Optimized CISSVAE Architecture"
)
})
# }