To use the code in this article, you will need to install the following packages: discrim, klaR, readr, ROSE, themis, and tidymodels.
Subsampling a training set, either undersampling or oversampling the appropriate class or classes, can be a helpful approach to dealing with classification data where one or more classes occur very infrequently. In such a situation (without compensating for it), most models will overfit to the majority class and produce very good statistics for the frequently occurring class while the minority classes have poor performance.
This article describes subsampling for dealing with class imbalances. For better understanding, some knowledge of classification metrics like sensitivity, specificity, and receiver operating characteristic curves is required. See Section 3.2.2 in Kuhn and Johnson (2019) for more information on these metrics.
Consider a two-class problem where the first class has a very low rate of occurrence. The data were simulated and can be imported into R using the code below:
library(tidymodels)  # attaches dplyr, which provides %>% and mutate()

imbal_data <- 
  readr::read_csv("https://bit.ly/imbal_data") %>% 
  mutate(Class = factor(Class))

dim(imbal_data)
#> [1] 1200   16

table(imbal_data$Class)
#> 
#> Class1 Class2 
#>     60   1140
If “Class1” is the event of interest, it is very likely that a classification model would be able to achieve very good specificity since almost all of the data are of the second class. Sensitivity, however, would likely be poor since the models will optimize accuracy (or other loss functions) by predicting everything to be the majority class.
One result of class imbalance when there are two classes is that the default probability cutoff of 50% is inappropriate; a different cutoff that is more extreme might be able to achieve good performance.
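To make the cutoff effect concrete, here is a minimal base R sketch. The labels, the predicted probabilities, and the helper `sens_spec()` are all invented for illustration; they are not taken from the simulated data above.

```r
# Hypothetical truth labels and predicted probabilities of the event (Class1).
# These ten values are made up purely for illustration.
truth <- factor(
  c("Class1", "Class1", "Class2", "Class2", "Class2",
    "Class2", "Class2", "Class2", "Class2", "Class2"),
  levels = c("Class1", "Class2")
)
prob_class1 <- c(0.40, 0.15, 0.30, 0.05, 0.02, 0.10, 0.01, 0.08, 0.03, 0.04)

# Hypothetical helper: sensitivity and specificity at a given cutoff,
# treating "Class1" as the event of interest.
sens_spec <- function(truth, prob, cutoff) {
  pred_event <- prob >= cutoff
  is_event <- truth == "Class1"
  c(
    sensitivity = sum(pred_event & is_event) / sum(is_event),
    specificity = sum(!pred_event & !is_event) / sum(!is_event)
  )
}

res_default <- sens_spec(truth, prob_class1, cutoff = 0.50)
res_lowered <- sens_spec(truth, prob_class1, cutoff = 0.12)
```

With these made-up values, the 50% cutoff never predicts the event (sensitivity 0, specificity 1), while a cutoff of 0.12 finds both events at the cost of one false positive (sensitivity 1, specificity 0.875).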
Subsampling the data
One way to alleviate this issue is to subsample the data. There are a number of ways to do this but the simplest one is to sample down (undersample) the majority class data until it occurs with the same frequency as the minority class. While it may seem counterintuitive, throwing out a large percentage of your data can be effective at producing a useful model that can recognize both the majority and minority classes. In some cases, this even means that the overall performance of the model is better (e.g. improved area under the ROC curve). More consistently, subsampling almost always produces models that are better calibrated, meaning that the distributions of the class probabilities are more well behaved. As a result, the default 50% cutoff is much more likely to produce better sensitivity and specificity values than it would otherwise.
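As a rough illustration of sampling down, the base R sketch below keeps a random subset of each class equal in size to the smallest class. The data frame `dat`, its predictor `x`, and the helper `downsample()` are hypothetical stand-ins; only the class counts (60 and 1140) mirror the simulated data.

```r
# Stand-in data with the same class counts as the simulated example;
# the predictor `x` is invented for illustration.
set.seed(1)
dat <- data.frame(
  x = rnorm(1200),
  Class = factor(rep(c("Class1", "Class2"), times = c(60, 1140)))
)

# Hypothetical helper: randomly keep n_min rows from every class, where
# n_min is the number of rows in the least frequent class.
downsample <- function(data, class_col) {
  n_min <- min(table(data[[class_col]]))
  rows_by_class <- split(seq_len(nrow(data)), data[[class_col]])
  keep <- unlist(lapply(
    rows_by_class,
    function(idx) idx[sample.int(length(idx), n_min)]
  ))
  data[sort(keep), , drop = FALSE]
}

dat_down <- downsample(dat, "Class")
table(dat_down$Class)
```

After this step both classes have 60 rows, so 1080 of the 1200 original rows have been discarded.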
Let’s explore subsampling using themis::step_rose() in a recipe for the simulated data. It uses the ROSE (random over sampling examples) method from Menardi, G. and Torelli, N. (2014). This is an example of an oversampling strategy, rather than undersampling.
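The general idea behind ROSE is a smoothed bootstrap: pick a class with equal probability, pick an observed row from that class, and draw a new synthetic row from a kernel centered on it. The base R sketch below is a deliberately simplified version of that idea and is not the ROSE package's implementation. In particular, the helper `rose_like()`, the stand-in data, and the fixed Gaussian `bandwidth` are assumptions made for illustration, and the noise here is not scaled to each predictor's spread.

```r
# Stand-in imbalanced data (20 events, 380 non-events), invented for illustration.
set.seed(2)
dat <- data.frame(
  x1 = c(rnorm(20, mean = 2), rnorm(380)),
  x2 = c(rnorm(20, mean = 2), rnorm(380)),
  Class = factor(rep(c("Class1", "Class2"), times = c(20, 380)))
)

# Hypothetical, simplified smoothed bootstrap (numeric predictors only).
rose_like <- function(data, class_col, n_out, bandwidth = 0.25) {
  predictors <- setdiff(names(data), class_col)
  classes <- levels(data[[class_col]])

  # 1. Choose the class of each synthetic row, each class equally likely.
  new_class <- sample(classes, n_out, replace = TRUE)

  # 2. For each synthetic row, pick one observed row of the chosen class.
  seed_rows <- vapply(new_class, function(cl) {
    pool <- which(data[[class_col]] == cl)
    pool[sample.int(length(pool), 1)]
  }, integer(1))

  # 3. Jitter the picked rows with Gaussian noise (the "smoothing" step).
  centers <- as.matrix(data[seed_rows, predictors, drop = FALSE])
  noise <- matrix(rnorm(length(centers), sd = bandwidth), nrow = nrow(centers))
  out <- as.data.frame(centers + noise)
  rownames(out) <- NULL
  out[[class_col]] <- factor(new_class, levels = classes)
  out
}

synth <- rose_like(dat, "Class", n_out = 400)
table(synth$Class)
```

The synthetic data set has roughly equal class frequencies, and every row is a perturbed copy of an observed row rather than an exact duplicate.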
In terms of workflow:
- It is extremely important that subsampling occurs inside of resampling. Otherwise, the resampling process can produce poor estimates of model performance.
- The subsampling process should only be applied to the analysis set. The assessment set should reflect the event rates seen “in the wild” and, for this reason, the skip argument to step_downsample() and other subsampling recipe steps has a default of TRUE.
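The two workflow points above can be sketched by hand in base R. This is only an illustration of the order of operations, not how the resampling functions used later are implemented: the stand-in data, the manual fold assignment, and the use of downsampling (instead of ROSE) are assumptions, and the model fit itself is omitted.

```r
# Stand-in data with the article's class counts; `x` is invented.
set.seed(3)
dat <- data.frame(
  x = rnorm(1200),
  Class = factor(rep(c("Class1", "Class2"), times = c(60, 1140)))
)

# Assign stratified fold ids so each fold has the same class balance.
n_folds <- 10
fold_id <- integer(nrow(dat))
for (cl in levels(dat$Class)) {
  idx <- which(dat$Class == cl)
  fold_id[idx] <- sample(rep_len(seq_len(n_folds), length(idx)))
}

fold_summary <- do.call(rbind, lapply(seq_len(n_folds), function(k) {
  analysis <- dat[fold_id != k, ]
  assessment <- dat[fold_id == k, ]  # left untouched on purpose

  # Subsample the analysis set only, inside the resampling loop.
  n_min <- min(table(analysis$Class))
  keep <- unlist(lapply(
    split(seq_len(nrow(analysis)), analysis$Class),
    function(idx) idx[sample.int(length(idx), n_min)]
  ))
  analysis_down <- analysis[keep, ]

  # A model would be fit to `analysis_down` and evaluated on `assessment` here.
  data.frame(
    fold = k,
    analysis_event_rate = mean(analysis_down$Class == "Class1"),
    assessment_event_rate = mean(assessment$Class == "Class1")
  )
}))

fold_summary
```

In every fold the subsampled analysis set is balanced (event rate 0.5) while the assessment set keeps the original 5% event rate, which is what the performance estimates should reflect.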
Here is a simple recipe implementing oversampling:
library(tidymodels)
library(themis)

imbal_rec <- 
  recipe(Class ~ ., data = imbal_data) %>%
  step_rose(Class)
For a model, let’s use a quadratic discriminant analysis (QDA) model. From the discrim package, this model can be specified using:
library(discrim)

qda_mod <- 
  discrim_regularized(frac_common_cov = 0, frac_identity = 0) %>% 
  set_engine("klaR")
To keep these objects bound together, they can be combined in a workflow:
qda_rose_wflw <- 
  workflow() %>% 
  add_model(qda_mod) %>% 
  add_recipe(imbal_rec)

qda_rose_wflw
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: discrim_regularized()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 1 Recipe Step
#> 
#> ● step_rose()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Regularized Discriminant Model Specification (classification)
#> 
#> Main Arguments:
#>   frac_common_cov = 0
#>   frac_identity = 0
#> 
#> Computational engine: klaR
Stratified, repeated 10-fold cross-validation is used to resample the model:
set.seed(5732)
cv_folds <- vfold_cv(imbal_data, strata = "Class", repeats = 5)
To measure model performance, let’s use two metrics:
- The area under the ROC curve is an overall assessment of performance across all cutoffs. Values near one indicate very good results while values near 0.5 would imply that the model is very poor.
- The J index (a.k.a. Youden’s J statistic) is sensitivity + specificity - 1. Values near one are once again best.
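For intuition, both metrics can be computed by hand in base R. The confusion-matrix counts, the probabilities, and the helper `auc_by_rank()` below are hypothetical and chosen only to keep the arithmetic easy to follow; the article itself relies on yardstick for these calculations.

```r
# Hypothetical confusion-matrix counts with "Class1" as the event.
tp <- 45    # events predicted as events
fn <- 15    # events predicted as non-events
fp <- 114   # non-events predicted as events
tn <- 1026  # non-events predicted as non-events

sensitivity <- tp / (tp + fn)  # 45 / 60
specificity <- tn / (tn + fp)  # 1026 / 1140
j_index_by_hand <- sensitivity + specificity - 1

# Hypothetical helper: area under the ROC curve via the rank-sum
# (Mann-Whitney) identity, i.e. the proportion of event/non-event pairs
# in which the event receives the higher predicted probability.
auc_by_rank <- function(prob_event, is_event) {
  r <- rank(prob_event)
  n_event <- sum(is_event)
  n_non_event <- sum(!is_event)
  (sum(r[is_event]) - n_event * (n_event + 1) / 2) / (n_event * n_non_event)
}

# Four made-up predictions: three of the four event/non-event pairs
# are ordered correctly.
auc_example <- auc_by_rank(
  prob_event = c(0.9, 0.8, 0.3, 0.2),
  is_event = c(TRUE, FALSE, TRUE, FALSE)
)
```

Here the sensitivity is 0.75 and the specificity is 0.9, so the J index is 0.65, and the small ranking example gives an area under the ROC curve of 0.75. Note that the J index depends on the cutoff used to make the hard predictions, while the area under the ROC curve depends only on how the probabilities are ordered.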
If a model is poorly calibrated, the ROC curve value might not show diminished performance. However, the J index would be lower for models with pathological distributions for the class probabilities. The yardstick package will be used to compute these metrics.
cls_metrics <- metric_set(roc_auc, j_index)
Now, we train the models and generate the results using fit_resamples():
set.seed(2180)
qda_rose_res <- fit_resamples(
  qda_rose_wflw, 
  resamples = cv_folds, 
  metrics = cls_metrics
)

collect_metrics(qda_rose_res)
#> # A tibble: 2 x 5
#>   .metric .estimator  mean     n std_err
#>   <chr>   <chr>      <dbl> <int>   <dbl>
#> 1 j_index binary     0.780    50 0.0212 
#> 2 roc_auc binary     0.950    50 0.00528
What do the results look like without using ROSE? We can create another workflow and fit the QDA model across the same resamples:
qda_wflw <- 
  workflow() %>% 
  add_model(qda_mod) %>% 
  add_formula(Class ~ .)

set.seed(2180)
qda_only_res <- fit_resamples(qda_wflw, resamples = cv_folds, metrics = cls_metrics)

collect_metrics(qda_only_res)
#> # A tibble: 2 x 5
#>   .metric .estimator  mean     n std_err
#>   <chr>   <chr>      <dbl> <int>   <dbl>
#> 1 j_index binary     0.250    50 0.0288 
#> 2 roc_auc binary     0.953    50 0.00479
It looks like ROSE helped a lot with the J index (0.780 with ROSE versus 0.250 without), while the area under the ROC curve is essentially unchanged (0.950 versus 0.953). Class imbalance sampling methods tend to greatly improve metrics based on the hard class predictions (i.e., the categorical predictions) because the default cutoff tends to be a better balance of sensitivity and specificity.
Let’s plot the metrics for each resample to see how the individual results changed.
# Per-resample metrics for the model fit without subsampling
no_sampling <- 
  qda_only_res %>% 
  collect_metrics(summarize = FALSE) %>% 
  dplyr::select(-.estimator) %>% 
  mutate(sampling = "no_sampling")

# Per-resample metrics for the model fit with ROSE
with_sampling <- 
  qda_rose_res %>% 
  collect_metrics(summarize = FALSE) %>% 
  dplyr::select(-.estimator) %>% 
  mutate(sampling = "rose")

# One line per resample (repeat + fold), connecting its two results
bind_rows(no_sampling, with_sampling) %>% 
  mutate(label = paste(id2, id)) %>% 
  ggplot(aes(x = sampling, y = .estimate, group = label)) + 
  geom_line(alpha = .4) + 
  facet_wrap(~ .metric, scales = "free_y")
This visually demonstrates that the subsampling mostly affects metrics that use the hard class predictions.
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 3.6.2 (2019-12-12)
#>  os       macOS Mojave 10.14.6        
#>  system   x86_64, darwin15.6.0        
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       America/Denver              
#>  date     2020-04-17                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package    * version date       lib source        
#>  broom      * 0.5.5   2020-02-29 [1] CRAN (R 3.6.0)
#>  dials      * 0.0.6   2020-04-03 [1] CRAN (R 3.6.2)
#>  discrim    * 0.0.2   2020-04-09 [1] CRAN (R 3.6.2)
#>  dplyr      * 0.8.5   2020-03-07 [1] CRAN (R 3.6.0)
#>  ggplot2    * 3.3.0   2020-03-05 [1] CRAN (R 3.6.0)
#>  infer      * 0.5.1   2019-11-19 [1] CRAN (R 3.6.0)
#>  klaR       * 0.6-15  2020-02-19 [1] CRAN (R 3.6.0)
#>  parsnip    * 0.1.0   2020-04-09 [1] CRAN (R 3.6.2)
#>  purrr      * 0.3.3   2019-10-18 [1] CRAN (R 3.6.0)
#>  readr      * 1.3.1   2018-12-21 [1] CRAN (R 3.6.0)
#>  recipes    * 0.1.10  2020-03-18 [1] CRAN (R 3.6.0)
#>  rlang        0.4.5   2020-03-01 [1] CRAN (R 3.6.0)
#>  ROSE       * 0.0-3   2014-07-15 [1] CRAN (R 3.6.0)
#>  rsample    * 0.0.6   2020-03-31 [1] CRAN (R 3.6.2)
#>  themis     * 0.1.0   2020-01-13 [1] CRAN (R 3.6.0)
#>  tibble     * 2.1.3   2019-06-06 [1] CRAN (R 3.6.2)
#>  tidymodels * 0.1.0   2020-02-16 [1] CRAN (R 3.6.0)
#>  tune       * 0.1.0   2020-04-02 [1] CRAN (R 3.6.2)
#>  workflows  * 0.1.1   2020-03-17 [1] CRAN (R 3.6.0)
#>  yardstick  * 0.0.6   2020-03-17 [1] CRAN (R 3.6.0)
#> 
#> [1] /Library/Frameworks/R.framework/Versions/3.6/Resources/library