Class Imbalance, Outliers, and Distribution Shift
This lecture covers three common problems in real-world ML data: class imbalance, outliers, and distribution shift.
Class imbalance
Many real-world classification problems have the property that certain classes are more prevalent than others. For example:
- COVID infection: among all patients, only 10% might have COVID
- Fraud detection: among all credit card transactions, fraud might make up 0.2% of the transactions
- Manufacturing defect classification: different types of manufacturing defects might have different prevalence
- Self-driving car object detection: different types of objects have different prevalence (cars vs trucks vs pedestrians)
Evaluation metrics
If you’re splitting a dataset into train/test splits, make sure to use stratified data splitting to ensure that the train distribution matches the test distribution (otherwise, you’re creating a distribution shift problem).
With imbalanced data, standard metrics like accuracy might not make sense. For example, a classifier that always predicts “NOT FRAUD” would have 99.8% accuracy in detecting credit card fraud.
There is no one-size-fits-all solution for choosing an evaluation metric: the choice should depend on the problem. For example, an evaluation metric for credit card fraud detection might be a weighted average of the precision and recall scores (the F-beta score), with the weights determined by weighing the relative costs of failing to block a fraudulent transaction and incorrectly blocking a genuine transaction:
\[ F_\beta = \left(1 + \beta^2 \right) \cdot \frac{\mathrm{precision} \cdot \mathrm{recall}}{\beta^2 \cdot \mathrm{precision} + \mathrm{recall}} \]
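As a concrete illustration, the sketch below computes precision, recall, and the F-beta score with scikit-learn on made-up labels for a toy fraud-detection setting; the choice \(\beta = 2\) (weighting recall more heavily than precision) is only an assumption for illustration.

```python
# Minimal sketch: computing the F-beta score with scikit-learn on made-up labels.
from sklearn.metrics import fbeta_score, precision_score, recall_score

y_true = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]  # 1 = fraud (minority class)
y_pred = [0, 0, 0, 0, 0, 1, 0, 1, 1, 0]  # hypothetical classifier output

precision = precision_score(y_true, y_pred)  # 2/3
recall = recall_score(y_true, y_pred)        # 2/3

# beta=2 weights recall twice as heavily as precision, reflecting a setting
# where missing a fraudulent transaction is costlier than a false alarm.
f2 = fbeta_score(y_true, y_pred, beta=2)
print(precision, recall, f2)
```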
Training models on imbalanced data
Once an evaluation metric has been chosen, you can try training a model in the standard way. If training a model on the true distribution works well, i.e., the model scores highly on the evaluation metric over a held-out test set that matches the real-world distribution, then you’re done!
If not, there are techniques you can use to try to improve model performance on the minority classes.
Sample weights. Many models can be fit to a dataset with per-sample weights. Instead of optimizing an objective function that’s a uniform average of per-datapoint losses, this optimizes a weighted average of losses, putting more emphasis on certain datapoints. While simple and conceptually appealing, this often does not work well in practice. For classifiers trained using mini-batches, using sample weights results in varying the effective learning rate between mini-batches, which can make learning unstable.
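For instance, many scikit-learn estimators accept a sample_weight argument in fit. The sketch below is a minimal example on synthetic data; the model and the "balanced" weighting scheme are illustrative choices, not a recommendation.

```python
# Minimal sketch: fitting a scikit-learn model with per-sample weights that
# up-weight the minority class (synthetic data for illustration).
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.utils.class_weight import compute_sample_weight

X, y = make_classification(n_samples=1000, weights=[0.95, 0.05], random_state=0)

# "balanced" assigns each sample a weight inversely proportional to its class frequency.
sample_weights = compute_sample_weight(class_weight="balanced", y=y)

model = LogisticRegression(max_iter=1000).fit(X, y, sample_weight=sample_weights)
```

Many estimators also expose an equivalent class_weight="balanced" constructor argument that achieves the same effect without computing weights by hand.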
Over-sampling. Related to sample weights, you can simply replicate datapoints in the minority class, even multiple times, to make the dataset more balanced. In simpler settings (e.g., least-squares regression), this can be equivalent to sample weights; in other settings (e.g., training a neural network with mini-batch gradient descent), it is not equivalent and often performs better than sample weights. This approach is often unstable, and it can result in overfitting.
Under-sampling. Another way to balance a dataset is to remove datapoints from the majority class. While discarding data might seem unintuitive, this approach can work surprisingly well in practice. However, with highly imbalanced datasets, it can mean throwing away a large fraction of the data, which can hurt performance.
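The two resampling approaches above can be sketched with the imbalanced-learn package (listed in the references below); the dataset here is synthetic and the sampler settings are illustrative.

```python
# Minimal sketch of random over- and under-sampling with imbalanced-learn
# (synthetic data; sampler settings are illustrative).
from collections import Counter
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=0)
print("original:", Counter(y))

# Over-sampling: replicate minority-class datapoints until classes are balanced.
X_over, y_over = RandomOverSampler(random_state=0).fit_resample(X, y)
print("over-sampled:", Counter(y_over))

# Under-sampling: discard majority-class datapoints until classes are balanced.
X_under, y_under = RandomUnderSampler(random_state=0).fit_resample(X, y)
print("under-sampled:", Counter(y_under))
```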
SMOTE (Synthetic Minority Oversampling TEchnique). Rather than over-sampling by copying datapoints, you can use dataset augmentation to create new examples of minority classes by combining or perturbing minority examples. The SMOTE algorithm is sensible for data types where interpolation in feature space makes sense, but not for others: averaging the pixel values of one picture of a dog with those of another is unlikely to produce a picture of a dog. Depending on the application, other data augmentation methods could work better.
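A minimal sketch of SMOTE itself, again using imbalanced-learn on synthetic tabular data (the neighbor count is just the library default, shown explicitly):

```python
# Minimal sketch of SMOTE with imbalanced-learn: new minority-class points are
# synthesized by interpolating between a minority point and one of its nearest
# minority-class neighbors (synthetic data; parameters illustrative).
from collections import Counter
from imblearn.over_sampling import SMOTE
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=0)
X_res, y_res = SMOTE(k_neighbors=5, random_state=0).fit_resample(X, y)
print(Counter(y), Counter(y_res))
```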
Balanced mini-batch training. For models trained with mini-batches, like neural networks, when assembling the random subset of data for each mini-batch, you can include datapoints from minority classes with higher probability, such that the mini-batch is balanced. This approach is similar to over-sampling, and it does not throw away data.
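One way to approximate this in PyTorch, shown as a sketch on toy tensors, is a WeightedRandomSampler whose per-sample weights are inversely proportional to class frequency, so mini-batches are roughly balanced in expectation (not exactly balanced per batch).

```python
# Sketch: approximately class-balanced mini-batches in PyTorch using a
# WeightedRandomSampler (assumes `labels` is a 1-D tensor of integer class ids
# aligned with `dataset`).
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

features = torch.randn(1000, 16)              # toy features
labels = (torch.rand(1000) < 0.05).long()     # roughly 5% minority class
dataset = TensorDataset(features, labels)

class_counts = torch.bincount(labels)
sample_weights = 1.0 / class_counts[labels].float()  # rarer classes get larger weights

sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset), replacement=True)
loader = DataLoader(dataset, batch_size=64, sampler=sampler)

xb, yb = next(iter(loader))
print(yb.float().mean())  # fraction of minority class per batch is ~0.5 in expectation
```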
These techniques can be combined. For example, the SMOTE authors note that the combination of SMOTE and under-sampling performs better than plain under-sampling.
References
- imbalanced-learn Python package
- SMOTE tutorial
- Experimental perspectives on learning from imbalanced data (paper)
- Tour of evaluation metrics for imbalanced classification
Outliers
An example of an outlier. Two classes in two-dimensional feature space are shown, labeled as "+" and "-", and an outlier is circled in red.
Outliers are datapoints that differ significantly from other datapoints. Causes include errors in measurement (e.g., a damaged air quality sensor), bad data collection (e.g., missing fields in a tabular dataset), malicious inputs (e.g., adversarial examples), and rare events (statistical outliers, e.g., an albino animal in an image classification dataset).
Outlier identification is of interest because outliers can cause issues during model training, at inference time, or when applying statistical techniques to a dataset. Outliers can harm model training, and certain machine learning models (e.g., vanilla SVM) can be particularly sensitive to outliers in the training set. A model, at deployment time, may not produce reasonable output if given outlier data as input (a form of distribution shift). If data has outliers, data analysis techniques might yield bad results.
Once found, what do you do with outliers? It depends. For example, if you find outliers in the training set, you don’t want to blindly discard them: they might be rare events rather than invalid data points. You could, for example, have a domain expert manually review outliers to check whether they are rare data or bad data.
Problem setup
Being a bit more formal with terminology, here are two tasks of interest:
Outlier detection. In this task, we are not given a clean dataset containing only in-distribution examples. Instead, we get a single unlabeled dataset, and the goal is to detect outliers in the dataset: datapoints that are unlike the others. This task comes up, for example, when cleaning a training dataset that is to be used for ML.
Anomaly detection. In this task, we are given an unlabeled dataset of only in-distribution examples. Given a new datapoint not in the dataset, the goal is to identify whether it belongs to the same distribution as the dataset. This task comes up, for example, when trying to identify whether a datapoint, at inference time, is drawn from the same distribution as a model’s training set.
Identifying outliers
Outlier detection is a heavily studied field, with many algorithms and lots of published research. Here, we cover a few selected techniques.
Tukey’s fences. A simple method for scalar real-valued data. If \(Q_1\) and \(Q_3\) are the lower and upper quartiles, then this test says that any observation outside the following range is considered an outlier: \([Q_1 - k(Q_3 - Q_1), Q_3 + k(Q_3 - Q_1)]\), where \(Q_3 - Q_1\) is the interquartile range. A multiplier of \(k=1.5\) was proposed by John Tukey.
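A minimal numpy sketch of Tukey’s fences on made-up scalar data:

```python
# Minimal sketch of Tukey's fences for scalar data (made-up values).
import numpy as np

x = np.array([2.1, 2.5, 2.3, 2.7, 2.4, 2.6, 2.2, 9.8])  # 9.8 is a suspicious value
q1, q3 = np.percentile(x, [25, 75])
iqr = q3 - q1
k = 1.5  # Tukey's conventional multiplier
lower, upper = q1 - k * iqr, q3 + k * iqr
outliers = x[(x < lower) | (x > upper)]
print(outliers)  # expect [9.8]
```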
Z-score. The Z-score is the number of standard deviations by which a value is above or below the mean. For one-dimensional or low-dimensional data, assuming an approximately Gaussian data distribution, calculate the Z-score as \(z_i = \frac{x_i - \mu}{\sigma}\), where \(\mu\) is the mean of all the data and \(\sigma\) is the standard deviation. An outlier is a data point that has a high-magnitude Z-score, \(|z_i| > z_{\mathrm{thr}}\). A commonly used threshold is \(z_{\mathrm{thr}} = 3\). You can also apply this technique to individual features.
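A minimal numpy sketch of Z-score-based outlier detection on synthetic one-dimensional data:

```python
# Minimal sketch of Z-score outlier detection on synthetic one-dimensional data.
import numpy as np

rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(loc=10.0, scale=0.5, size=100), [25.0]])  # 25.0 is anomalous

z = (x - x.mean()) / x.std()
z_thr = 3  # commonly used threshold
print(x[np.abs(z) > z_thr])  # flags the injected outlier
```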
Isolation forest. This technique is related to decision trees. Intuitively, the method creates a “random decision tree” and scores data points according to how many nodes are required to isolate them. The algorithm recursively divides (a subset of) a dataset by randomly selecting a feature and a split value until the subset has only one instance. The idea is that outlier data points will require fewer splits to become isolated.
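scikit-learn provides an implementation; the sketch below scores synthetic two-dimensional data (the hyperparameters are defaults shown for clarity, not tuned values):

```python
# Sketch: isolation forest outlier scoring with scikit-learn (synthetic data).
import numpy as np
from sklearn.ensemble import IsolationForest

rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 1, size=(500, 2)),   # in-distribution cluster
               [[8.0, 8.0], [-9.0, 7.5]]])        # two obvious outliers

forest = IsolationForest(n_estimators=100, random_state=0).fit(X)
scores = -forest.score_samples(X)   # higher score = more anomalous
labels = forest.predict(X)          # -1 for predicted outliers, +1 for inliers
print(X[labels == -1][:5])          # a few of the flagged points
```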
KNN distance. In-distribution data is likely to be closer to its neighbors. You can use the mean distance (choosing an appropriate distance metric, like cosine distance) to a datapoint’s k nearest neighbors as a score. For high-dimensional data like images, you can use embeddings from a trained model and do KNN in the embedding space.
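A minimal sketch of KNN-distance scoring with scikit-learn on synthetic two-dimensional data (Euclidean distance here; for embeddings, cosine distance is a common alternative):

```python
# Sketch: mean distance to k nearest neighbors as an outlier score (synthetic data).
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 1, size=(500, 2)), [[10.0, 10.0]]])  # last point is an outlier

k = 10
nn = NearestNeighbors(n_neighbors=k + 1).fit(X)   # +1 because each point is its own nearest neighbor
distances, _ = nn.kneighbors(X)
scores = distances[:, 1:].mean(axis=1)            # drop the zero self-distance
print(np.argmax(scores))                          # expected to be index 500, the injected outlier
```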
Reconstruction-based methods. Autoencoders are neural networks trained to compress high-dimensional data into a low-dimensional representation and then reconstruct the original data. If an autoencoder learns a data distribution, then it should be able to encode and then decode an in-distribution data point back into a data point that is close to the original input. For out-of-distribution data, however, the reconstruction will be worse, so you can use the reconstruction loss as a score for detecting outliers.
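A compact PyTorch sketch of this idea on synthetic vectors that lie near a low-dimensional subspace; a real image setup would typically use a convolutional autoencoder and far more training, so treat this only as an illustration of using reconstruction error as a score.

```python
# Sketch: autoencoder reconstruction error as an outlier score. Synthetic
# in-distribution data lies on a low-dimensional subspace; out-of-distribution
# points do not, so they reconstruct poorly. (Illustrative only.)
import torch
from torch import nn

torch.manual_seed(0)
proj = torch.randn(8, 64)
train = torch.randn(2000, 8) @ proj            # in-distribution: 8-dim subspace of R^64

model = nn.Sequential(
    nn.Linear(64, 16), nn.ReLU(),              # encoder
    nn.Linear(16, 64),                         # decoder
)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
for _ in range(1000):                          # short full-batch training loop
    loss = ((model(train) - train) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

def score(x):
    """Per-example reconstruction error; higher suggests out-of-distribution."""
    with torch.no_grad():
        return ((model(x) - x) ** 2).mean(dim=1)

in_dist = torch.randn(100, 8) @ proj
out_of_dist = torch.randn(100, 64) * proj.std() * (8 ** 0.5)  # matched scale, off the subspace
print(score(in_dist).mean(), score(out_of_dist).mean())       # the latter is typically larger
```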
You’ll notice that many outlier detection techniques involve computing a score for every datapoint and then thresholding to select outliers. Outlier detection methods can be evaluated by looking at the ROC curve, or if you want a single summary number to compare methods, looking at the AUROC.
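For example, given ground-truth outlier labels for an evaluation set, scikit-learn can compute both the ROC curve and the AUROC from the scores (the scores and labels below are made up):

```python
# Sketch: evaluating outlier scores with the ROC curve and AUROC, given
# ground-truth outlier labels (made-up scores and labels for illustration).
from sklearn.metrics import roc_auc_score, roc_curve

is_outlier = [0, 0, 0, 0, 1, 0, 0, 1, 0, 1]                       # ground truth (1 = outlier)
scores     = [0.1, 0.3, 0.2, 0.4, 0.9, 0.2, 0.1, 0.7, 0.5, 0.8]   # higher = more anomalous

print(roc_auc_score(is_outlier, scores))              # single-number summary
fpr, tpr, thresholds = roc_curve(is_outlier, scores)  # full ROC curve for plotting
```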
Distribution shift
An example of extreme distribution shift (in particular, covariate shift / data shift) in a hand-written digit classification task. A classifier is trained on Arabic numerals labeled 0–9, while it is evaluated on Roman numerals. It is likely to have extremely poor performance.
Distribution shift is a challenging problem that occurs when the joint distribution of inputs and outputs differs between training and test stages, i.e., \(p_\mathrm{train}(\mathbf{x}, y) \neq p_\mathrm{test}(\mathbf{x}, y)\). This issue is present, to varying degrees, in nearly every practical ML application, in part because it is hard to perfectly reproduce testing conditions at training time.
Types of distribution shift
Covariate shift / data shift
Covariate shift occurs when \(p(\mathbf{x})\) changes between train and test, but \(p(y \mid \mathbf{x})\) does not. In other words, the distribution of inputs changes between train and test, but the relationship between inputs and outputs does not change.
When the distribution of training data and test data differ significantly, a learned model can fit training data well but perform poorly on test data.
Examples of covariate shift:
- Self-driving car trained on the sunny streets of San Francisco and deployed in the snowy streets of Boston
- Speech recognition model trained on native English speakers and then deployed for all English speakers
- Diabetes prediction model trained on hospital data from Boston and deployed in India
Concept shift
Concept shift occurs when \(p(y \mid \mathbf{x})\) changes between train and test, but \(p(\mathbf{x})\) does not. In other words, the input distribution does not change, but the relationship between inputs and outputs does. This can be one of the most difficult types of distribution shift to detect and correct.
Concept shift in a two-class dataset with two-dimensional features. Data points, drawn as "x"s, are color-coded by class (red/green), and the decision boundary is shown in purple. The input distribution is identical between train and test, but the relationship between input and output has changed.
It is tricky to come up with real-world examples of concept shift where there is absolutely no change in \(p(\mathbf{x})\). Here are some examples of concept shift in the real world:
- Predicting a stock price based on company fundamentals, trained on data from 1975 and deployed in 2023. Company fundamentals include statistics like earnings per share. While these numbers, and hence \(p(\mathbf{x})\), did change over time, the relationship between these numbers and valuation changed as well. The P/E ratio (the ratio of stock price, \(y\), to earnings per share, \(\mathbf{x}\)) changed significantly over time: the S&P 500 P/E ratio was 8.30 in 1975, while by 2023 it had risen to about 20. This is concept shift, where \(p(y \mid \mathbf{x})\) has changed: people value a company more highly (by more than a factor of 2x) for the same earnings per share.
- Making purchase recommendations based on web browsing behavior, trained on pre-pandemic data and deployed in March 2020. While web browsing behavior (\(\mathbf{x}\)) did not change much (e.g., most individuals browsed the same websites, watched the same YouTube videos, etc.) before the pandemic vs during the pandemic, the relationship between browsing behavior and purchases did (e.g., someone who watched lots of travel videos on YouTube before the pandemic might buy plane or hotel tickets, while during the pandemic they might pay for nature documentary movies).
Prior probability shift / label shift
Prior probability shift appears only in \(y \rightarrow \mathbf{x}\) problems (when we believe \(y\) causes \(\mathbf{x}\)). It occurs when \(p(y)\) changes between train and test, but \(p(\mathbf{x} \mid y)\) does not. You can think of it as the converse of covariate shift.
To understand prior probability shift, consider the example of spam classification, where a commonly-used model is Naive Bayes. If the model is trained on a balanced dataset of 50% spam and 50% non-spam emails, and then it’s deployed in a real-world setting where 90% of emails are spam, that is an example of prior probability shift.
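If the deployment-time class priors are known or can be estimated, one standard adjustment under label shift, sketched below on made-up numbers, is to re-weight the model’s predicted probabilities by the ratio of new to old priors and re-normalize, since \(p_\mathrm{test}(y \mid \mathbf{x}) \propto p_\mathrm{train}(y \mid \mathbf{x}) \, p_\mathrm{test}(y) / p_\mathrm{train}(y)\) when \(p(\mathbf{x} \mid y)\) is unchanged.

```python
# Sketch: adjusting a classifier's predicted probabilities for new class priors
# under label shift, i.e. p_test(y|x) is proportional to
# p_train(y|x) * p_test(y) / p_train(y). (Made-up numbers for illustration.)
import numpy as np

p_train_y = np.array([0.5, 0.5])           # training priors: 50% spam, 50% not spam
p_test_y = np.array([0.9, 0.1])            # deployment priors: 90% spam

p_train_y_given_x = np.array([0.4, 0.6])   # model output for one email

unnormalized = p_train_y_given_x * p_test_y / p_train_y
p_test_y_given_x = unnormalized / unnormalized.sum()
print(p_test_y_given_x)                    # spam probability rises well above 0.4
```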
Another example is training a classifier to predict diagnoses given symptoms, as the relative prevalence of diseases changes over time. Prior probability shift (rather than covariate shift) is the appropriate assumption to make here, because diseases cause symptoms.
Detecting and addressing distribution shift
Some ways you can detect distribution shift in deployments:
- Monitor the performance of your model. Monitor accuracy, precision, statistical measures, or other evaluation metrics. If these change over time, it may be due to distribution shift.
- Monitor your data. You can detect data shift by comparing statistical properties of training data and data seen in a deployment (one simple approach is sketched below).
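For a single numeric feature, you can compare training values against recently observed deployment values with a two-sample statistical test; the sketch below uses the Kolmogorov–Smirnov test on synthetic data, and the significance threshold is an illustrative choice.

```python
# Sketch: flagging drift in a single numeric feature with a two-sample
# Kolmogorov-Smirnov test (synthetic data; the 0.01 threshold is illustrative).
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(0)
train_feature = rng.normal(loc=0.0, scale=1.0, size=5000)     # feature values at training time
deployed_feature = rng.normal(loc=0.5, scale=1.0, size=5000)  # same feature in deployment (shifted)

result = ks_2samp(train_feature, deployed_feature)
if result.pvalue < 0.01:
    print(f"Possible data shift (KS statistic={result.statistic:.3f}, p-value={result.pvalue:.2e})")
```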
At a high level, distribution shift can be addressed by fixing the data and re-training the model. In some situations, the best solution is to collect a better training set.
If unlabeled test data are available at training time, then one way to address covariate shift is to assign individual sample weights to training datapoints so that the weighted feature distribution of the training data resembles the feature distribution of the test data. Even though test labels are unknown in this setting, label shift can be addressed similarly, by using a shared sample weight for all training examples with the same class label, again chosen so that the weighted feature distribution of the training data resembles that of the test data. Concept shift, however, cannot be addressed in this setting without knowledge of its form, because there is no way to quantify it from unlabeled test data.
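One common way to estimate such sample weights from unlabeled test inputs (a sketch of one method, not the only one) is density-ratio estimation with a "domain classifier": train a probabilistic classifier to distinguish training inputs from test inputs, and weight each training point by the estimated odds, which approximates \(p_\mathrm{test}(\mathbf{x}) / p_\mathrm{train}(\mathbf{x})\).

```python
# Sketch: estimating importance weights for covariate shift with a "domain
# classifier" trained to distinguish train inputs from (unlabeled) test inputs.
# The weight for a training point x estimates p_test(x) / p_train(x).
# (Synthetic data; model choices are illustrative.)
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X_train = rng.normal(loc=0.0, scale=1.0, size=(2000, 5))
X_test = rng.normal(loc=0.7, scale=1.0, size=(2000, 5))   # shifted inputs, labels unknown

# Label 0 = "came from train", 1 = "came from test".
X_domain = np.vstack([X_train, X_test])
d = np.concatenate([np.zeros(len(X_train)), np.ones(len(X_test))])
domain_clf = LogisticRegression(max_iter=1000).fit(X_domain, d)

p_test_given_x = domain_clf.predict_proba(X_train)[:, 1]
weights = p_test_given_x / (1.0 - p_test_given_x)   # odds ratio ~ p_test(x) / p_train(x)

# `weights` can now be passed as sample_weight when re-fitting the actual model
# on the labeled training data.
```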
Lab
The lab assignment for this lecture is to implement and compare different methods for identifying outliers.
For this lab, we’ve focused on anomaly detection. You are given a clean training dataset consisting of many pictures of dogs, and an evaluation dataset that contains outliers (non-dogs). Your task is to implement and compare various methods for detecting these outliers. You may implement some of the ideas presented in today’s lecture, or you can look up other outlier detection algorithms in the linked references or online.
The lab assignments for the course are available in the dcai-lab repository.
Remember to run a `git pull` before doing every lab, to make sure you have the latest version of the labs.
The lab assignment for this class is in `outliers/Lab - Outliers.ipynb`.
Licensed under CC BY-NC-SA.