Compute per-cluster and per-group performance metrics (MSE, BCE)
Source: R/performance.R
Calculates mean squared error (MSE) for continuous features and binary cross-entropy (BCE) for features explicitly marked as binary, comparing model-imputed validation values against ground-truth validation data.
Usage
performance_by_cluster(
  res,
  clusters = NULL,
  group_col = NULL,
  feature_cols = NULL,
  binary_features = character(0),
  by_group = TRUE,
  by_cluster = TRUE,
  cols_ignore = NULL,
  eps = 1e-07
)
Arguments
- res
A list containing CISS-VAE run outputs. Must include:
  - res$val_data: validation data frame (with NA for non-validation cells)
  - res$val_imputed: model-imputed validation predictions
  - res$clusters: cluster labels for each row
- clusters
Optional vector of cluster labels, one per row of val_data. If NULL, res$clusters will be used.
- group_col
Optional character string naming a column in val_data used for grouping (e.g., sex or treatment group). If supplied, summaries can be computed per group and group-by-cluster.
- feature_cols
Character vector specifying which feature columns to evaluate. Defaults to all numeric columns except group_col and those in cols_ignore.
- binary_features
Character vector naming the columns (a subset of feature_cols) that should be scored with BCE instead of MSE.
- by_group
Logical; if TRUE, compute summaries by group_col. Ignored if group_col is NULL.
- by_cluster
Logical; if TRUE, compute summaries by cluster.
- cols_ignore
Character vector of column names to exclude from scoring (e.g., IDs).
- eps
Numeric. Small constant used for clipping probabilities in the BCE calculation. Default is 1e-7.
Value
A named list containing:
- overall: overall validation metrics (MSE, BCE, total)
- per_cluster: metrics summarized by cluster (if by_cluster = TRUE)
- per_group: metrics summarized by group (if by_group = TRUE)
- group_by_cluster: metrics summarized by group and cluster (if both by_group and by_cluster are TRUE)
Each summary contains:
- mse: mean squared error across continuous validation cells
- bce: mean binary cross-entropy across binary validation cells
- imputation_error: mse + bce
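In the example output further down, bce is NA (no binary features were supplied) while imputation_error equals mse, which suggests that a missing component is skipped rather than propagated. The following is a minimal sketch of that behavior, inferred from the printed output only; combine_error is a hypothetical helper and not part of the package API.

```r
# Hypothetical helper (an assumption, not the package's implementation):
# add the two components, treating a missing one as absent.
combine_error <- function(mse, bce) {
  parts <- c(mse, bce)
  # If neither metric is available, there is nothing to report
  if (all(is.na(parts))) {
    return(NA_real_)
  }
  sum(parts, na.rm = TRUE)
}

combine_error(0.7640603, NA)  # only continuous features were scored
combine_error(0.5, 0.25)      # both components present
```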
Details
Losses are computed at the individual cell level using only validation entries (non-NA in val_data), then averaged to produce overall, per-cluster, per-group, and group-by-cluster summaries.
For features listed in binary_features, the per-cell loss is binary cross-entropy (BCE):
$$-[y\log(p) + (1-y)\log(1-p)]$$
where \(y\) is the observed value (0 or 1) and \(p\) is the predicted probability, clipped using eps before the logarithms are taken.
For all other numeric features, the per-cell loss is the squared error, whose mean over cells is the MSE:
$$(y - \hat{y})^2$$
where \(\hat{y}\) is the imputed value.
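The cell-level computation described above can be sketched in base R as follows. This is an illustration under stated assumptions rather than the package's code: bce_cell and se_cell are hypothetical helpers, clipping to the interval [eps, 1 - eps] is assumed, and the toy data has one continuous feature (x1) and one binary feature (b1).

```r
# Hypothetical helper: per-cell BCE. Clipping to [eps, 1 - eps] is an
# assumption about how eps is applied; it keeps log() finite.
bce_cell <- function(y, p, eps = 1e-7) {
  p <- pmin(pmax(p, eps), 1 - eps)
  -(y * log(p) + (1 - y) * log(1 - p))
}

# Hypothetical helper: per-cell squared error
se_cell <- function(y, y_hat) (y - y_hat)^2

# Toy inputs: NA in val_data marks cells that are not validation entries
val_data    <- data.frame(x1 = c(1.0, NA, 3.0, NA), b1 = c(1, 0, NA, 1))
val_imputed <- data.frame(x1 = c(1.5, 2.0, 2.0, 4.0),
                          b1 = c(0.9, 0.2, 0.5, 0.6))
clusters    <- c(0, 0, 1, 1)

# Score validation cells only
is_val_x1 <- !is.na(val_data$x1)
is_val_b1 <- !is.na(val_data$b1)
se  <- se_cell(val_data$x1[is_val_x1], val_imputed$x1[is_val_x1])
bce <- bce_cell(val_data$b1[is_val_b1], val_imputed$b1[is_val_b1])

# Aggregate the cell-level losses: overall, then per cluster
mse_overall    <- mean(se)
bce_overall    <- mean(bce)
mse_by_cluster <- tapply(se,  clusters[is_val_x1], mean)
bce_by_cluster <- tapply(bce, clusters[is_val_b1], mean)
```

Per-group and group-by-cluster summaries follow the same pattern, with the grouping column (or both labels) passed as the index to tapply.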
Examples
# Toy complete data: an ID, a grouping column, and two continuous features
data_complete <- data.frame(
  id = 1:10,
  group = sample(c("A", "B"), 10, replace = TRUE),
  x1 = rnorm(10),
  x2 = rnorm(10)
)
feature_cols <- c("x1", "x2")
# TRUE marks non-validation cells; one mask column per feature column
missing_mask <- matrix(
  sample(c(TRUE, FALSE), 20, replace = TRUE),
  nrow = 10
)
# Apply the mask to the feature columns only, so id and group stay intact
val_data <- data_complete
val_data[feature_cols][missing_mask] <- NA
# Stand-in for model output: impute every cell with the column mean
val_imputed <- data.frame(
  id = data_complete$id,
  group = data_complete$group,
  x1 = mean(data_complete$x1),
  x2 = mean(data_complete$x2)
)
val_imputed[feature_cols][missing_mask] <- NA
result <- list(
  val_data = val_data,
  val_imputed = val_imputed,
  clusters = sample(c(0, 1), 10, replace = TRUE)
)
# No seed is set, so the printed values vary between runs
performance_by_cluster(
  res = result,
  group_col = "group",
  binary_features = character(0),
  by_group = TRUE,
  by_cluster = TRUE,
  cols_ignore = "id"
)
#> $overall
#> mse bce imputation_error
#> 1 0.7640603 NA 0.7640603
#>
#> $per_cluster
#> cluster mse bce imputation_error
#> 1 0 1.1924671 NA 1.1924671
#> 2 1 0.4784558 NA 0.4784558
#>
#> $per_group
#> group mse bce imputation_error
#> 1 A 0.4385655 NA 0.4385655
#> 2 B 0.9086065 NA 0.9086065
#>