
Performs hyperparameter optimization for CISS-VAE using Optuna with support for both tunable and fixed parameters.

Usage

autotune_cissvae(
  data,
  index_col = NULL,
  val_proportion = 0.1,
  replacement_value = 0,
  columns_ignore = NULL,
  imputable_matrix = NULL,
  binary_feature_mask = NULL,
  clusters,
  save_model_path = NULL,
  save_search_space_path = NULL,
  n_trials = 20,
  study_name = "vae_autotune",
  device_preference = "cuda",
  show_progress = FALSE,
  optuna_dashboard_db = NULL,
  load_if_exists = TRUE,
  seed = 42,
  verbose = FALSE,
  constant_layer_size = FALSE,
  evaluate_all_orders = FALSE,
  max_exhaustive_orders = 100,
  num_hidden_layers = c(1, 4),
  hidden_dims = c(64, 512),
  latent_dim = c(10, 100),
  latent_shared = c(TRUE, FALSE),
  output_shared = c(TRUE, FALSE),
  lr = c(1e-04, 0.001),
  decay_factor = c(0.9, 0.999),
  weight_decay = 0.001,
  beta = 0.01,
  num_epochs = 500,
  batch_size = 4000,
  num_shared_encode = c(0, 1, 3),
  num_shared_decode = c(0, 1, 3),
  encoder_shared_placement = c("at_end", "at_start", "alternating", "random"),
  decoder_shared_placement = c("at_start", "at_end", "alternating", "random"),
  refit_patience = 2,
  refit_loops = 100,
  epochs_per_loop = 500,
  reset_lr_refit = c(TRUE, FALSE),
  debug = FALSE
)

Arguments

data

Data frame or matrix containing the input data

index_col

String name of index column to preserve (optional)

val_proportion

Proportion of non-missing data to hold out for validation.

replacement_value

Numeric value used to replace missing entries before model input.

columns_ignore

Character vector of column names to exclude from imputation scoring.

imputable_matrix

Logical matrix indicating entries allowed to be imputed; a short sketch follows this argument list.

binary_feature_mask

Logical vector marking which columns are binary.

clusters

Integer vector specifying cluster assignments for each row.

save_model_path

Optional path to save the best model's state_dict

save_search_space_path

Optional path to save search space configuration

n_trials

Number of Optuna trials to run

study_name

Name identifier for the Optuna study

device_preference

Preferred device ("cuda", "mps", "cpu")

show_progress

Whether to display Rich progress bars during training

optuna_dashboard_db

RDB storage URL/file for Optuna dashboard

load_if_exists

Whether to load existing study from storage

seed

Base random seed for reproducible results

verbose

Whether to print detailed diagnostic information

constant_layer_size

Whether all hidden layers use the same dimension

evaluate_all_orders

Whether to test all possible layer arrangements

max_exhaustive_orders

Maximum number of arrangements to test when evaluate_all_orders = TRUE

num_hidden_layers

Numeric(2) vector: (min, max) for number of hidden layers

hidden_dims

Numeric vector: hidden layer dimensions to test

latent_dim

Numeric(2) vector: (min, max) for latent dimension

latent_shared

Logical vector: whether the latent space is shared across clusters

output_shared

Logical vector: whether the output layer is shared across clusters

lr

Numeric(2) vector: (min, max) learning rate range

decay_factor

Numeric(2) vector: (min, max) LR decay factor range

weight_decay

Weight decay (L2 penalty) used in Adam optimizer.

beta

Numeric: KL divergence weight (fixed or range)

num_epochs

Integer: number of initial training epochs (fixed or range)

batch_size

Integer: mini-batch size (fixed or range)

num_shared_encode

Numeric vector: numbers of shared encoder layers to test

num_shared_decode

Numeric vector: numbers of shared decoder layers to test

encoder_shared_placement

Character vector: placement strategies for encoder shared layers

decoder_shared_placement

Character vector: placement strategies for decoder shared layers

refit_patience

Integer: early stopping patience for refit loops

refit_loops

Integer: maximum number of refit loops

epochs_per_loop

Integer: epochs per refit loop

reset_lr_refit

Logical vector: whether to reset the learning rate before refitting

debug

Logical; if TRUE, additional metadata is returned for debugging.
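
A rough sketch of the imputable_matrix argument is given below. The dimensions (one row per data row, one column per feature, index_col excluded) and the convention that TRUE marks entries that may be imputed are assumptions read off the description above, not a documented contract.

## Hypothetical sketch: an imputable_matrix for the df_missing example data,
## assuming it mirrors the data's dimensions (index_col excluded) and that
## TRUE means "this entry may be imputed".
feat_cols <- setdiff(colnames(df_missing), "index")
imputable_matrix <- matrix(
  TRUE,
  nrow = nrow(df_missing), ncol = length(feat_cols),
  dimnames = list(NULL, feat_cols)
)
imputable_matrix[, 1] <- FALSE  # e.g., never impute the first feature column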

Value

A named list with the following components:

imputed_dataset

A data frame containing the imputed values.

model

The fitted CISS-VAE model object

cluster_dataset

The ClusterDataset object used

clusters

The vector of cluster assignments

study

An Optuna study object containing the trial results

results

A data frame of trial results

val_data

The held-out validation dataset

val_imputed

Imputed values for the validation dataset
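
The sketch below shows how these components might be inspected after a successful run; aut is the object created in the Examples, and the assumption that aut$study exposes Optuna attributes such as best_params through reticulate's $ operator should be checked against your setup.

## Assuming `aut` holds the list returned by autotune_cissvae():
head(aut$imputed_dataset)  # imputed data frame
aut$results                # data frame of Optuna trial results
table(aut$clusters)        # cluster assignments that were used
aut$study$best_params      # Optuna attribute access via reticulate (assumed)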

Tips

  • Use cluster_on_missing() or cluster_on_missing_prop() for cluster assignments.

  • Use GPU computation when available; call check_devices() to see available devices.

  • Adjust batch_size based on available memory (larger batches are faster but use more memory).

  • Set verbose = TRUE or show_progress = TRUE to monitor training.

  • Explore optuna-dashboard (see the optunadb vignette) to inspect hyperparameter importance.

  • For binary features, set names(binary_feature_mask) <- colnames(data); a short sketch follows this list.
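
A small sketch for the device and binary-feature tips above; the "all observed values are 0/1" rule is an assumption, so adapt it to how your binary columns are coded.

## See which compute devices are available before setting device_preference.
check_devices()

## Flag columns whose observed values are all 0/1 as binary (assumed rule),
## then name the mask as recommended above.
binary_feature_mask <- vapply(
  df_missing,
  function(x) all(na.omit(x) %in% c(0, 1)),
  logical(1)
)
names(binary_feature_mask) <- colnames(df_missing)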

Examples

# \donttest{
## Requires a working Python environment via reticulate
## Examples are wrapped in try() to avoid failures on CRAN check systems
try({
reticulate::use_virtualenv("cissvae_environment", required = TRUE)


data(df_missing)
data(clusters)

## Run autotuning
aut <- autotune_cissvae(
  data = df_missing,
  index_col = "index",
  clusters = clusters$clusters,
  n_trials = 3,
  study_name = "comprehensive_vae_autotune",
  device_preference = "cpu",
  seed = 42,

  ## Hyperparameter search space
  num_hidden_layers = c(2, 5),
  hidden_dims = c(64, 512),
  latent_dim = c(10, 100),
  latent_shared = c(TRUE, FALSE),
  output_shared = c(TRUE, FALSE),
  lr = c(0.01, 0.1),
  decay_factor = c(0.99, 1.0),
  beta = c(0.01, 0.1),
  num_epochs = c(5, 20),
  batch_size = c(1000, 4000),
  num_shared_encode = c(0, 1, 2),
  num_shared_decode = c(0, 1, 2),

  ## Placement strategies
  encoder_shared_placement = c(
    "at_end", "at_start",
    "alternating", "random"
  ),
  decoder_shared_placement = c(
    "at_start", "at_end",
    "alternating", "random"
  ),

  refit_patience = 2,
  refit_loops = 10,
  epochs_per_loop = 5,
  reset_lr_refit = c(TRUE, FALSE)
)

## Visualize architecture
plot_vae_architecture(
  aut$model,
  title = "Optimized CISSVAE Architecture"
)
})
#> Error in py_module_import(module, convert = convert) : 
#>   ModuleNotFoundError: No module named 'ciss_vae'
#> Run `reticulate::py_last_error()` for details.
# }
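
## Sketch of a variation on the example above (same reticulate setup assumed):
## persist the study so it can be opened in optuna-dashboard, and save the best
## model and search space. The sqlite URL follows Optuna's usual RDB storage
## convention; the file names are placeholders.
try({
aut_db <- autotune_cissvae(
  data = df_missing,
  index_col = "index",
  clusters = clusters$clusters,
  n_trials = 3,
  study_name = "vae_autotune_db",
  device_preference = "cpu",
  optuna_dashboard_db = "sqlite:///vae_autotune.db",
  load_if_exists = TRUE,
  save_model_path = "best_cissvae_model.pt",
  save_search_space_path = "cissvae_search_space.json"
)
})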