Calculate performance metrics by cluster and subgroups

Usage

performance_by_cluster(
  original_data,
  model,
  dataset,
  clusters,
  index_col = NULL,
  grouping_vars = NULL,
  device = "cpu",
  metrics = c("mse", "mae", "correlation"),
  only_validation = TRUE,
  verbose = FALSE
)

Arguments

original_data

Original data frame containing the true values; also the source of any grouping variables

model

Trained Python CISS-VAE model object

dataset

Python ClusterDataset object used for training (contains validation data)

clusters

Vector of cluster assignments

index_col

Optional name of the index column used to align the data

grouping_vars

Optional character vector of column names to group performance by (e.g., c("race", "sex"))

device

Device for model inference ("cpu" or "cuda")

metrics

Character vector of metrics to compute (default: c("mse", "mae", "correlation"))

only_validation

Logical; if TRUE, only compute metrics on validation (masked) entries

verbose

Logical; if TRUE, prints debug information

Value

Data frame with performance metrics by cluster and optional subgroups
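
To make the shape of the result concrete, the following base R sketch computes the three default metrics per cluster on toy data. It is an illustration only, not the package's implementation: the data, the column names (truth, imputed, is_val, cluster), and the rule of returning NA for correlation when a cluster has fewer than three validation entries are all assumptions, and the metrics are assumed to follow their standard definitions (mean squared error, mean absolute error, Pearson correlation).

```r
# Hypothetical toy data: true values, imputed values, a validation mask,
# and a cluster label for each entry.
toy <- data.frame(
  truth   = c(1.0, 2.0, 3.0, 4.0, 5.0, 6.0),
  imputed = c(1.5, 2.0, 2.0, 4.0, 7.0, 6.0),
  is_val  = c(TRUE, TRUE, TRUE, TRUE, TRUE, FALSE),
  cluster = c(0, 0, 0, 1, 1, 1)
)

# Keep only validation entries, analogous to only_validation = TRUE.
val <- toy[toy$is_val, ]

# Compute the metrics separately within each cluster.
metrics_by_cluster <- do.call(rbind, lapply(split(val, val$cluster), function(d) {
  err <- d$imputed - d$truth
  data.frame(
    cluster     = d$cluster[1],
    n           = nrow(d),
    mse         = mean(err^2),
    mae         = mean(abs(err)),
    # Assumption: skip correlation when there are too few points to be meaningful.
    correlation = if (nrow(d) > 2) cor(d$truth, d$imputed) else NA_real_
  )
}))

print(metrics_by_cluster)
```

Grouping by additional variables, as grouping_vars does, would amount to splitting on those columns together with the cluster label, for example split(val, list(val$cluster, val$sex)) for a hypothetical sex column.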