
Calculates mean squared error (MSE) for continuous features and binary cross-entropy (BCE) for features you explicitly mark as binary, comparing model-imputed validation values against ground-truth validation data.

Usage

performance_by_cluster(
  res,
  clusters = NULL,
  group_col = NULL,
  feature_cols = NULL,
  binary_features = character(0),
  by_group = TRUE,
  by_cluster = TRUE,
  cols_ignore = NULL
)

Arguments

res

A list containing CISS-VAE run outputs. Must include:

  • res$val_data: validation data frame (with NA for non-validation cells)

  • res$val_imputed: model-imputed validation predictions

  • res$clusters: cluster labels for each row

clusters

Optional vector (same length as the number of rows in val_data) of cluster labels. If NULL, res$clusters is used.

group_col

Optional character scalar; the name of the column in val_data to group by.

feature_cols

Character vector specifying which feature columns to evaluate. Defaults to all numeric columns except group_col and those in cols_ignore.

binary_features

Character vector naming the columns (a subset of feature_cols) that should be scored with BCE instead of MSE.

by_group

Logical; if TRUE (default), summarize by group_col.

by_cluster

Logical; if TRUE (default), summarize by cluster.

cols_ignore

Character vector of column names to exclude from scoring (e.g., “id”).

Value

A named list containing:

  • overall: overall average metric (MSE for continuous, BCE for binary)

  • per_cluster: summaries by cluster

  • per_group: summaries by group

  • group_by_cluster: summaries by group and cluster

  • per_feature_overall: average per-feature metric

Details

For features listed in binary_features, performance is binary cross-entropy (BCE): $$-[y\log(p) + (1-y)\log(1-p)]$$. For other numeric features, performance is mean squared error (MSE).
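As a quick illustration of the two loss definitions (not part of the package API; the vectors below are made up), the per-cell losses can be computed directly in base R:

```r
# Continuous feature: mean squared error between truth y and imputation p
y_cont <- c(1.2, -0.5, 3.0)
p_cont <- c(1.0,  0.0, 2.5)
mse <- mean((y_cont - p_cont)^2)

# Binary feature: binary cross-entropy, with p the predicted probability
y_bin <- c(1, 0, 1)
p_bin <- c(0.9, 0.2, 0.6)
bce <- mean(-(y_bin * log(p_bin) + (1 - y_bin) * log(1 - p_bin)))
```

The function averages these per-cell losses over the non-NA validation cells within each feature, cluster, and group.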

Examples

library(tidyverse)
library(reticulate)
library(rCISSVAE)
library(kableExtra)
#> 
#> Attaching package: ‘kableExtra’
#> The following object is masked from ‘package:dplyr’:
#> 
#>     group_rows
library(gtsummary)

## Make example results
data_complete <- data.frame(
  index = 1:10,
  x1 = rnorm(10),
  x2 = rnorm(10) * rnorm(10, mean = 50, sd = 10)
)

## TRUE marks cells excluded from scoring (set to NA in both data frames);
## the index column is never masked
missing_mask <- matrix(
  data = c(
    rep(FALSE, 10),
    sample(c(TRUE, FALSE), size = 20, replace = TRUE, prob = c(0.7, 0.3))
  ),
  nrow = 10
)

## Example validation dataset (ground truth, NA at non-validation cells)
val_data <- data_complete
val_data[missing_mask] <- NA

## Example 'imputed' validation dataset (column means as stand-in predictions)
val_imputed <- data.frame(
  index = 1:10,
  x1 = mean(data_complete$x1),
  x2 = mean(data_complete$x2)
)
val_imputed[missing_mask] <- NA

## Example result list
result <- list("val_data" = val_data, "val_imputed" = val_imputed)
clusters <- sample(c(0, 1), size = 10, replace = TRUE)

## Run the function
performance_by_cluster(res = result, 
  group_col = NULL, 
  clusters = clusters,
  feature_cols = NULL, 
  by_group = FALSE,
  by_cluster = TRUE,
  cols_ignore = c("index") 
)
#> $overall
#>     metric n
#> 1 1480.499 5
#> 
#> $per_cluster
#>   cluster mean_imputation_loss n
#> 1       0             148.3505 1
#> 2       1            1813.5365 4
#> 
#> $per_feature_overall
#>   feature       type mean_imputation_loss n
#> 1      x1 continuous             1.291553 2
#> 2      x2 continuous          2466.637818 3
#>
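To work with the summaries programmatically rather than printing them, assign the return value and index into the list by the element names given under Value (a sketch reusing the objects built above; `perf` is just an illustrative name):

```r
## Same call as above, captured in a variable
perf <- performance_by_cluster(
  res = result,
  clusters = clusters,
  by_group = FALSE,
  by_cluster = TRUE,
  cols_ignore = c("index")
)

perf$overall              # overall average metric and validation-cell count
perf$per_cluster          # one summary row per cluster
perf$per_feature_overall  # one summary row per evaluated feature
```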