Compute per-cluster and per-group performance metrics (MSE, BCE)
Source: R/performance.R
performance_by_cluster.Rd

Calculates mean squared error (MSE) for continuous features and binary cross-entropy (BCE) for features you explicitly mark as binary, comparing model-imputed validation values against ground-truth validation data.
Usage
performance_by_cluster(
  res,
  clusters = NULL,
  group_col = NULL,
  feature_cols = NULL,
  binary_features = character(0),
  by_group = TRUE,
  by_cluster = TRUE,
  cols_ignore = NULL
)

Arguments
- res
A list containing CISS-VAE run outputs. Must include:
  - res$val_data: validation data frame (with NA for non-validation cells)
  - res$val_imputed: model-imputed validation predictions
  - res$clusters: cluster labels for each row
- clusters
Optional vector (same length as the number of rows in val_data) of cluster labels. If NULL, res$clusters is used.
- group_col
Optional character; name of the column in val_data to group by.
- feature_cols
Character vector specifying which feature columns to evaluate. Defaults to all numeric columns except group_col and those in cols_ignore.
- binary_features
Character vector naming the columns (a subset of feature_cols) that should be scored with BCE instead of MSE.
- by_group
Logical; if TRUE (default), summarize by group_col.
- by_cluster
Logical; if TRUE (default), summarize by cluster.
- cols_ignore
Character vector of column names to exclude from scoring (e.g., "id").
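For instance, if val_data contained a 0/1 indicator column, naming it in binary_features scores it with BCE while the remaining numeric columns use MSE. A hypothetical call sketch (the column name "status" and the id column are illustrative, not part of the example data below):

```r
performance_by_cluster(
  res = res,                    # list of CISS-VAE run outputs
  binary_features = "status",   # hypothetical 0/1 column, scored with BCE
  cols_ignore = "id"            # hypothetical id column, excluded from scoring
)
```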
Value
A named list containing:
  - overall: overall average metric (MSE for continuous features, BCE for binary)
  - per_cluster: summaries by cluster
  - per_group: summaries by group
  - group_by_cluster: summaries by group and cluster
  - per_feature_overall: average per-feature metric
Details
For features listed in binary_features, performance is binary cross-entropy (BCE):
$$-[y\log(p) + (1-y)\log(1-p)]$$.
For other numeric features, performance is mean squared error (MSE).
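As a quick illustration of the two formulas (base R only, toy values, not part of the package), the per-feature losses could be computed by hand as:

```r
## Continuous feature: mean squared error
y_cont <- c(1.0, 2.5, -0.3)   # held-out ground-truth values
p_cont <- c(0.8, 2.1, 0.0)    # imputed values
mse <- mean((y_cont - p_cont)^2)

## Binary feature: binary cross-entropy
y_bin <- c(1, 0, 1)           # held-out 0/1 values
p_bin <- c(0.9, 0.2, 0.6)     # imputed probabilities
eps <- 1e-12                  # clamp to avoid log(0)
p_bin <- pmin(pmax(p_bin, eps), 1 - eps)
bce <- -mean(y_bin * log(p_bin) + (1 - y_bin) * log(1 - p_bin))
```

The function averages these per-cell losses over the validation cells, i.e. the cells that are non-missing in both val_data and val_imputed.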
Examples
library(tidyverse)
library(reticulate)
library(rCISSVAE)
library(kableExtra)
#>
#> Attaching package: ‘kableExtra’
#> The following object is masked from ‘package:dplyr’:
#>
#> group_rows
library(gtsummary)
## Make example results
data_complete = data.frame(
  index = 1:10,
  x1 = rnorm(10),
  x2 = rnorm(10) * rnorm(10, mean = 50, sd = 10)
)
missing_mask = matrix(data = c(rep(FALSE, 10),
                               sample(c(TRUE, FALSE),
                                      size = 20, replace = TRUE,
                                      prob = c(0.7, 0.3))),
                      nrow = 10)
## Example validation dataset
val_data = data_complete
val_data[missing_mask] <- NA
## Example 'imputed' validation dataset
val_imputed = data.frame(index = 1:10,
                         x1 = mean(data_complete$x1),
                         x2 = mean(data_complete$x2))
val_imputed[missing_mask] <- NA
## Example result list
result = list("val_data" = val_data, "val_imputed" = val_imputed)
clusters = sample(c(0, 1), size = 10, replace = TRUE)
## Run the function
performance_by_cluster(res = result,
                       group_col = NULL,
                       clusters = clusters,
                       feature_cols = NULL,
                       by_group = FALSE,
                       by_cluster = TRUE,
                       cols_ignore = c("index"))
#> $overall
#> metric n
#> 1 1480.499 5
#>
#> $per_cluster
#> cluster mean_imputation_loss n
#> 1 0 148.3505 1
#> 2 1 1813.5365 4
#>
#> $per_feature_overall
#> feature type mean_imputation_loss n
#> 1 x1 continuous 1.291553 2
#> 2 x2 continuous 2466.637818 3
#>
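The printed elements above are plain data frames, so they can be stored and passed to the table formatters loaded earlier. A possible follow-up, reusing the `result` and `clusters` objects from the example:

```r
## Store the result, then render the per-cluster summary with kableExtra
perf <- performance_by_cluster(res = result, clusters = clusters,
                               by_group = FALSE, cols_ignore = "index")
perf$per_cluster |>
  kableExtra::kbl(digits = 2) |>
  kableExtra::kable_styling()
```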