
Creates a horizontal schematic diagram of the CISS-VAE architecture, showing shared and cluster-specific layers. This function wraps the Python plot_vae_architecture function from the ciss_vae package.

Usage

plot_vae_architecture(
  model,
  title = NULL,
  color_shared = "skyblue",
  color_unshared = "lightcoral",
  color_latent = "gold",
  color_input = "lightgreen",
  color_output = "lightgreen",
  figsize = c(16, 8),
  save_path = NULL,
  dpi = 300,
  return_plot = FALSE,
  display_plot = TRUE
)

Arguments

model

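A trained CISSVAE model object (a Python object accessed through reticulate), such as the model element returned by run_cissvae(..., return_model = TRUE).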

title

Title of the plot. If NULL, no title is displayed. Default NULL.

color_shared

Color for shared hidden layers. Default "skyblue".

color_unshared

Color for unshared (cluster-specific) hidden layers. Default "lightcoral".

color_latent

Color for the latent layer. Default "gold".

color_input

Color for the input layer. Default "lightgreen".

color_output

Color for the output layer. Default "lightgreen".

figsize

Size of the matplotlib figure in inches, as c(width, height). Default c(16, 8).

save_path

Optional file path at which to save the plot as a PNG. If NULL, the plot is not saved and is only displayed (subject to display_plot). Default NULL.

dpi

Resolution of the saved PNG file, in dots per inch. Default 300.

return_plot

Logical; if TRUE, returns the matplotlib figure as a Python object wrapped by reticulate. Default FALSE.

display_plot

Logical; if TRUE, displays the plot. Set to FALSE when only saving. Default TRUE.
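Because figsize is measured in inches and dpi in dots per inch, the pixel dimensions of a saved PNG are approximately their product. The sketch below uses a hypothetical helper, expected_png_pixels(), which is not part of the package. The result is only approximate if the underlying Python function saves with bbox_inches = "tight" (which crops whitespace); whether it does is not documented here.

```r
# Hypothetical helper (not part of the package): approximate pixel size of
# the saved PNG, computed as figsize (inches) multiplied by dpi.
expected_png_pixels <- function(figsize = c(16, 8), dpi = 300) {
  stopifnot(is.numeric(figsize), length(figsize) == 2, dpi > 0)
  px <- round(figsize * dpi)
  names(px) <- c("width", "height")
  px
}

expected_png_pixels()                              # defaults: 4800 x 2400
expected_png_pixels(figsize = c(10, 5), dpi = 150) # 1500 x 750
```

At the defaults, this gives a 4800 by 2400 pixel image. Lowering dpi is a quick way to get a smaller image for previews or slides.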

Value

If return_plot is TRUE, returns a Python matplotlib figure object that can be further manipulated. Otherwise returns NULL invisibly.
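Because the returned value is a matplotlib Figure wrapped by reticulate, its methods can be called from R with $. The sketch below creates a blank figure with matplotlib.pyplot as a stand-in for the object returned by plot_vae_architecture(..., return_plot = TRUE), so it runs without a trained model. It assumes Python with matplotlib is available to reticulate. set_size_inches(), suptitle(), and savefig() are standard matplotlib Figure methods.

```r
library(reticulate)

# Use the non-interactive Agg backend so no display is needed
matplotlib <- import("matplotlib")
matplotlib$use("Agg")
plt <- import("matplotlib.pyplot")

# Stand-in for:
# fig <- plot_vae_architecture(model, return_plot = TRUE, display_plot = FALSE)
fig <- plt$figure(figsize = c(16, 8))

# Resize (in inches) and add an overall title via the matplotlib Figure API
fig$set_size_inches(12, 6)
invisible(fig$suptitle("CISS-VAE Architecture"))

# Save in a vector format instead of PNG, then release the figure
out_file <- tempfile(fileext = ".pdf")
fig$savefig(out_file, bbox_inches = "tight")
plt$close(fig)
```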

Tips

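  • If you get a Tcl/Tk error (typically on a headless machine or with a Python installation that lacks Tk), switch matplotlib to the non-interactive 'Agg' backend by running reticulate::py_run_string("import matplotlib; matplotlib.use('Agg')") before plotting. With Agg, plots can still be written to save_path but cannot be shown in an interactive window.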
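A minimal sketch of that workaround: switch to Agg, then confirm the active backend with matplotlib's get_backend(). It assumes matplotlib is importable from the Python environment that reticulate is using.

```r
library(reticulate)

# Switch matplotlib to the non-interactive Agg backend (no Tcl/Tk needed)
py_run_string("import matplotlib; matplotlib.use('Agg')")

# Check which backend is now active
backend <- import("matplotlib")$get_backend()
print(backend)
```

After the switch, call plot_vae_architecture() with save_path set and display_plot = FALSE, because Agg cannot open an interactive window.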

Examples

## Requires a working Python environment via reticulate
## Examples are wrapped in try() to avoid failures on CRAN check systems
# \donttest{
try({
  # Train a model first (my_data is a placeholder: substitute your own data set)
  result <- run_cissvae(my_data, return_model = TRUE)

  # Basic plot
  plot_vae_architecture(result$model)

  # Save plot to file without displaying it
  plot_vae_architecture(
    model = result$model,
    title = "CISS-VAE Architecture",
    save_path = "vae_architecture.png",
    dpi = 300,
    display_plot = FALSE
  )

  # Return plot object for further manipulation
  fig <- plot_vae_architecture(
    model = result$model,
    return_plot = TRUE,
    display_plot = FALSE
  )
})
# }