
Calculates mean squared error (MSE) for continuous features and binary cross-entropy (BCE) for features explicitly marked as binary, comparing model-imputed validation values against ground-truth validation data.

Usage

performance_by_cluster(
  res,
  clusters = NULL,
  group_col = NULL,
  feature_cols = NULL,
  binary_features = character(0),
  by_group = TRUE,
  by_cluster = TRUE,
  cols_ignore = NULL,
  eps = 1e-07
)

Arguments

res

A list containing CISS-VAE run outputs. Must include:

  • res$val_data: validation data frame (with NA for non-validation cells)

  • res$val_imputed: model-imputed validation predictions

  • res$clusters: cluster labels for each row

clusters

Optional vector of cluster labels, one per row of val_data. If NULL, res$clusters is used.
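The requirements on res and clusters can be sketched as a small pre-flight check. This is a hypothetical helper (resolve_clusters is not part of the package API) that only mirrors the documented rules: res must carry val_data, val_imputed and clusters; an explicit clusters argument takes precedence over res$clusters; and there must be one label per row of val_data.

```r
# Hypothetical helper, not part of the package: mirrors the documented
# requirements on `res` and the `clusters` argument.
resolve_clusters <- function(res, clusters = NULL) {
  required <- c("val_data", "val_imputed", "clusters")
  missing_parts <- setdiff(required, names(res))
  if (length(missing_parts) > 0) {
    stop("res is missing: ", paste(missing_parts, collapse = ", "))
  }
  # Fall back to the labels stored in the run output
  if (is.null(clusters)) clusters <- res$clusters
  # One label per row of the validation data
  if (length(clusters) != nrow(res$val_data)) {
    stop("clusters must have one label per row of val_data")
  }
  clusters
}

res <- list(
  val_data = data.frame(x1 = c(1, NA, 3)),
  val_imputed = data.frame(x1 = c(1.1, 2, 2.7)),
  clusters = c(0, 1, 1)
)
resolve_clusters(res)              # falls back to res$clusters
resolve_clusters(res, c(2, 2, 3))  # explicit labels take precedence
```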

group_col

Optional character string naming a column in val_data used for grouping (e.g., sex or treatment group). If supplied, summaries can be computed per group and per group-by-cluster combination.

feature_cols

Character vector specifying which feature columns to evaluate. If NULL (the default), all numeric columns except group_col and those in cols_ignore are used.

binary_features

Character vector naming the columns (a subset of feature_cols) to score with BCE instead of MSE. Defaults to character(0), so every feature is scored with MSE.
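How feature_cols, group_col, cols_ignore and binary_features interact can be sketched as follows. split_features is a hypothetical helper, not the package's code; it assumes "numeric column" means is.numeric() is TRUE, so integer columns such as an ID count as numeric, which is why they need to be listed in cols_ignore.

```r
# Hypothetical helper, not part of the package: applies the documented
# default for feature_cols, then splits features into BCE and MSE sets.
split_features <- function(val_data, group_col = NULL, feature_cols = NULL,
                           binary_features = character(0), cols_ignore = NULL) {
  if (is.null(feature_cols)) {
    # Default: every numeric column except the grouping and ignored columns
    is_num <- vapply(val_data, is.numeric, logical(1))
    feature_cols <- setdiff(names(val_data)[is_num], c(group_col, cols_ignore))
  }
  if (!all(binary_features %in% feature_cols)) {
    stop("binary_features must be a subset of feature_cols")
  }
  list(
    binary = binary_features,                          # scored with BCE
    continuous = setdiff(feature_cols, binary_features) # scored with MSE
  )
}

df <- data.frame(id = 1:4, group = c("A", "B", "A", "B"),
                 x1 = rnorm(4), smoker = c(0, 1, 1, 0))
split_features(df, group_col = "group", binary_features = "smoker",
               cols_ignore = "id")
```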

by_group

Logical; if TRUE, compute summaries by group_col. Ignored if group_col is NULL.

by_cluster

Logical; if TRUE, compute summaries by cluster.

cols_ignore

Character vector of column names to exclude from scoring (e.g., IDs).

eps

Numeric. Small constant used to clip predicted probabilities away from 0 and 1 in the BCE calculation, so that the logarithms stay finite. Default is 1e-7.

Value

A named list containing:

  • overall: overall validation metrics (MSE, BCE, total)

  • per_cluster: metrics summarized by cluster (if by_cluster = TRUE)

  • per_group: metrics summarized by group (if by_group = TRUE and group_col is supplied)

  • group_by_cluster: metrics summarized by group and cluster (if group_col is supplied and both by_group and by_cluster are TRUE)
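The conditions above can be summarized in a few lines. expected_components is a hypothetical helper that encodes only the rules stated on this page (overall is always returned; by_group is ignored when group_col is NULL); it is not taken from the package source.

```r
# Hypothetical helper, not part of the package: which summaries the
# documented rules say the returned list should contain.
expected_components <- function(group_col = NULL, by_group = TRUE,
                                 by_cluster = TRUE) {
  use_group <- by_group && !is.null(group_col)  # by_group needs a group_col
  c("overall",                                   # always present
    if (by_cluster) "per_cluster",
    if (use_group) "per_group",
    if (use_group && by_cluster) "group_by_cluster")
}

expected_components(group_col = "group")
expected_components(group_col = NULL)
expected_components(group_col = "group", by_cluster = FALSE)
```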

Each summary contains:

  • mse: mean squared error across continuous validation cells

  • bce: mean binary cross-entropy across binary validation cells

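  • imputation_error: mse + bce. When there are no binary features, bce is NA and imputation_error equals mse, as in the example below.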

Details

Validation loss is computed at the cell level and then aggregated to produce overall, per-cluster, per-group, and group-by-cluster summaries.

For features listed in binary_features, the cell-level loss is binary cross-entropy (BCE): $$-[y\log(p) + (1-y)\log(1-p)]$$ where \(y \in \{0, 1\}\) is the observed value and \(p\) is the predicted probability, clipped away from 0 and 1 using eps.
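A minimal sketch of the cell-level BCE, assuming eps clips predictions to the interval [eps, 1 - eps] (a common convention; the exact clipping rule used by the package is not stated on this page). bce_cell is a hypothetical name.

```r
# Hypothetical cell-level BCE; assumes clipping to [eps, 1 - eps].
bce_cell <- function(y, p, eps = 1e-7) {
  p <- pmin(pmax(p, eps), 1 - eps)  # keep log() finite when p is exactly 0 or 1
  -(y * log(p) + (1 - y) * log(1 - p))
}

bce_cell(y = c(1, 0, 1), p = c(0.9, 0.2, 0))
```

Without clipping, the third cell (an observed 1 predicted with probability 0) would give an infinite loss; with clipping it contributes -log(eps), about 16.1 for the default eps.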

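For all other numeric features, the cell-level loss is the squared error $$(y - \hat{y})^2$$ where \(y\) is the observed value and \(\hat{y}\) is the imputed value; averaging these over cells gives the MSE.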
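For a single continuous feature, the squared errors over validation cells can be sketched like this. se_cells is a hypothetical helper; it treats NA in the ground-truth vector as "not a validation cell", matching the description of val_data above.

```r
# Squared error for every validation cell of one continuous feature.
# Cells that are NA in the ground truth are not validation cells and are dropped.
se_cells <- function(observed, imputed) {
  is_val <- !is.na(observed)
  (observed[is_val] - imputed[is_val])^2
}

observed <- c(1.0, NA, 3.0, 0.5)  # ground truth; NA = not a validation cell
imputed  <- c(1.5, 2.0, 2.0, 0.5) # model predictions
se <- se_cells(observed, imputed)
se        # 0.25 1.00 0.00
mean(se)  # MSE for this feature
```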

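Only validation entries (cells that are non-NA in val_data) contribute. Within each summary, mse is the mean of the squared errors over all continuous validation cells in that subset, and bce is the mean BCE over all binary validation cells.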
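The cell-then-aggregate scheme can be sketched for the per-cluster summary as below. This is an illustration on toy data, not the package's implementation: it pools squared errors from all continuous features before averaging within each cluster, and it assumes an NA component is skipped when forming imputation_error, which is consistent with the example output where imputation_error equals mse when bce is NA.

```r
# Hypothetical sketch: pool cell-level squared errors across continuous
# features, then average within each cluster.
val_data    <- data.frame(x1 = c(1, NA, 3, 0), x2 = c(NA, 2, 2, 1))
val_imputed <- data.frame(x1 = c(2, 0, 3, 0), x2 = c(9, 1, 2, 2))
clusters    <- c(0, 0, 1, 1)

cells <- do.call(rbind, lapply(names(val_data), function(col) {
  keep <- !is.na(val_data[[col]])  # validation cells only
  data.frame(
    cluster = clusters[keep],
    loss = (val_data[[col]][keep] - val_imputed[[col]][keep])^2
  )
}))

per_cluster <- aggregate(loss ~ cluster, data = cells, FUN = mean)
names(per_cluster)[2] <- "mse"
per_cluster$bce <- NA_real_  # no binary features in this toy data
# Assumption: NA components are skipped when forming imputation_error
per_cluster$imputation_error <- unname(
  rowSums(per_cluster[c("mse", "bce")], na.rm = TRUE)
)
per_cluster
```

Here cluster 0 has two validation cells with squared error 1 each (mse 1), and cluster 1 has four cells with errors 0, 0, 0 and 1 (mse 0.25).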

Examples

data_complete <- data.frame(
  id = 1:10,
  group = sample(c("A", "B"), 10, replace = TRUE),
  x1 = rnorm(10),
  x2 = rnorm(10)
)

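features <- c("x1", "x2")

# TRUE marks feature cells to hide; the remaining non-NA cells in
# val_data are the validation cells that get scored
missing_mask <- matrix(
  sample(c(TRUE, FALSE), 20, replace = TRUE),
  nrow = 10
)

# Apply the mask to the feature columns only, not to id or group
val_data <- data_complete
val_data[features][missing_mask] <- NA

# Baseline "imputation": predict each feature's mean for every row
val_imputed <- data.frame(
  id = data_complete$id,
  group = data_complete$group,
  x1 = mean(data_complete$x1),
  x2 = mean(data_complete$x2)
)

val_imputed[features][missing_mask] <- NA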

result <- list(
  val_data = val_data,
  val_imputed = val_imputed,
  clusters = sample(c(0, 1), 10, replace = TRUE)
)

performance_by_cluster(
  res = result,
  group_col = "group",
  binary_features = character(0),
  by_group = TRUE,
  by_cluster = TRUE,
  cols_ignore = "id"
)
#> $overall
#>         mse bce imputation_error
#> 1 0.7640603  NA        0.7640603
#> 
#> $per_cluster
#>   cluster       mse bce imputation_error
#> 1       0 1.1924671  NA        1.1924671
#> 2       1 0.4784558  NA        0.4784558
#> 
#> $per_group
#>   group       mse bce imputation_error
#> 1     A 0.4385655  NA        0.4385655
#> 2     B 0.9086065  NA        0.9086065
#>