Class Imbalance, Outliers, and Distribution Shift

Click here to view the raw lecture video on Panopto.
An edited version of this video will be posted after the course is over.

This lecture covers three common problems in real-world ML data: class imbalance, outliers, and distribution shift.

Class imbalance

Figure: Class imbalance in credit card fraud.

Many real-world classification problems have the property that certain classes are more prevalent than others. For example:

Question: what is the difference between class imbalance and underperforming subpopulations, a topic covered in the previous lecture?

Evaluation metrics

If you’re splitting a dataset into train/test splits, make sure to use stratified data splitting to ensure that the train distribution matches the test distribution (otherwise, you’re creating a distribution shift problem).
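A minimal stratified split can be sketched in plain Python. It shuffles the indices of each class separately and sends the same fraction of every class to the test set, so both splits keep the original class proportions. The function name `stratified_split` and its parameters are illustrative. In scikit-learn, the same behavior is available through the `stratify` argument of `train_test_split`.

```python
import random
from collections import defaultdict

def stratified_split(labels, test_fraction=0.2, seed=0):
    """Return (train_indices, test_indices) with per-class proportions preserved."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)
    train, test = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        # Take the same fraction of each class (rounded per class) for the test set.
        n_test = round(len(indices) * test_fraction)
        test.extend(indices[:n_test])
        train.extend(indices[n_test:])
    return sorted(train), sorted(test)

# Hypothetical labels: 2% positives (e.g., fraud), 98% negatives.
labels = [1] * 20 + [0] * 980
train_idx, test_idx = stratified_split(labels, test_fraction=0.25)
print(sum(labels[i] for i in test_idx), len(test_idx))  # → 5 250
```

The test set keeps the 2% positive rate (5 of 250), so the test distribution matches the training distribution in terms of class balance.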

With imbalanced data, standard metrics like accuracy might not make sense. For example, if 99.8% of credit card transactions are genuine, a classifier that always predicts “NOT FRAUD” achieves 99.8% accuracy while detecting no fraud at all.

There is no one-size-fits-all solution for choosing an evaluation metric: the choice should depend on the problem. For example, an evaluation metric for credit card fraud detection might be the F-beta score, a weighted harmonic mean of precision and recall, where the weight \(\beta\) is chosen by weighing the relative costs of failing to block a fraudulent transaction and incorrectly blocking a genuine transaction (values of \(\beta > 1\) put more emphasis on recall):

Precision is the proportion of positive identifications that were actually positive: \(\mathrm{precision} = \frac{TP}{TP + FP}\).

Recall is the proportion of actual positives that were identified correctly: \(\mathrm{recall} = \frac{TP}{TP + FN}\).

\[ F_\beta = \left(1 + \beta^2 \right) \cdot \frac{\mathrm{precision} \cdot \mathrm{recall}}{\beta^2 \cdot \mathrm{precision} + \mathrm{recall}} \]

When \(\beta = 1\), this reduces to the F1 score, the harmonic mean of precision and recall.
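The formulas above can be checked with a small sketch that computes precision, recall, and F-beta from confusion-matrix counts. The counts are hypothetical, and the function name `precision_recall_fbeta` is illustrative (scikit-learn provides `sklearn.metrics.fbeta_score` for label arrays).

```python
def precision_recall_fbeta(tp, fp, fn, beta=1.0):
    """Compute precision, recall, and F-beta from confusion-matrix counts."""
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    if precision == 0.0 and recall == 0.0:
        return precision, recall, 0.0
    b2 = beta ** 2
    fbeta = (1 + b2) * precision * recall / (b2 * precision + recall)
    return precision, recall, fbeta

# Hypothetical test set: 1,000 transactions, 2 of which are fraudulent.
# Always predicting "NOT FRAUD" gives tp=0, fp=0, fn=2, tn=998.
accuracy = 998 / 1000
print(accuracy, precision_recall_fbeta(tp=0, fp=0, fn=2))  # → 0.998 (0.0, 0.0, 0.0)

# A classifier that catches both frauds but also flags 8 genuine transactions.
p, r, f2 = precision_recall_fbeta(tp=2, fp=8, fn=0, beta=2.0)
print(p, r, round(f2, 3))  # → 0.2 1.0 0.556
```

The always-negative classifier looks excellent on accuracy but scores 0 on F-beta. With \(\beta = 2\), the second classifier's perfect recall outweighs its low precision more than it would under F1 (which is 1/3 here).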

Training models on imbalanced data

Once an evaluation metric has been chosen, you can try training a model in the standard way. If training a model on the true distribution works well, i.e., the model scores highly on the evaluation metric over a held-out test set that matches the real-world distribution, then you’re done!

If not, there are techniques you can use to try to improve model performance on the minority classes.

Sample weights. Many models can be fit to a dataset with per-sample weights. Instead of optimizing an objective function that’s a uniform average of per-datapoint losses, this optimizes a weighted average of losses, putting more emphasis on certain datapoints. While simple and conceptually appealing, this often does not work well in practice. For classifiers trained using mini-batches, using sample weights results in varying the effective learning rate between mini-batches, which can make learning unstable.
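As a sketch of per-sample weighting, the code below assigns each datapoint a weight inversely proportional to its class frequency, using the heuristic n_samples / (n_classes * count[class]) that scikit-learn documents for `class_weight="balanced"`, and then takes a weighted average of per-datapoint losses. The function names are hypothetical.

```python
import numpy as np

def balanced_class_weights(y):
    """Per-sample weights inversely proportional to class frequency."""
    y = np.asarray(y)
    classes, counts = np.unique(y, return_counts=True)  # classes come back sorted
    class_weight = len(y) / (len(classes) * counts)
    # Map each sample's label to the weight of its class.
    return class_weight[np.searchsorted(classes, y)]

def weighted_mean_loss(per_sample_losses, weights):
    """Weighted average of per-datapoint losses (instead of a uniform average)."""
    return np.sum(weights * per_sample_losses) / np.sum(weights)

y = np.array([0] * 8 + [1] * 2)  # hypothetical 80/20 imbalance
w = balanced_class_weights(y)
print(w[0], w[-1])  # → 0.625 2.5
```

Each class now carries the same total weight (5.0 here). If a training loop instead divides each mini-batch's weighted loss by the batch size, the total weight per batch changes from batch to batch, which is the varying effective learning rate described above.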

Over-sampling. Related to sample weights, you can simply replicate datapoints in the minority class, even multiple times, to make the dataset more balanced. In simpler settings (e.g., least-squares regression), this might be equivalent to sample weights; in other settings (e.g., training a neural network with mini-batch gradient descent), it is not equivalent and often performs better than sample weights. This solution is often unstable, and it can result in overfitting.
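A minimal sketch of random over-sampling, assuming NumPy arrays: minority-class rows are drawn with replacement until every class has as many rows as the largest class. The function name `random_oversample` is hypothetical.

```python
import numpy as np

def random_oversample(X, y, seed=0):
    """Replicate rows of smaller classes (with replacement) to match the largest class."""
    rng = np.random.default_rng(seed)
    X, y = np.asarray(X), np.asarray(y)
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max()
    keep = []
    for c, n in zip(classes, counts):
        idx = np.flatnonzero(y == c)
        # Keep every original row, then add target - n duplicated rows.
        extra = rng.choice(idx, size=target - n, replace=True)
        keep.append(np.concatenate([idx, extra]))
    keep = np.concatenate(keep)
    return X[keep], y[keep]

X = np.arange(10).reshape(-1, 1)
y = np.array([0] * 8 + [1] * 2)
X_bal, y_bal = random_oversample(X, y)
print(np.bincount(y_bal))  # → [8 8]
```

The duplicated rows are exact copies of the two minority examples, which is why over-sampling can lead to overfitting on them.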

Under-sampling. Another way to balance a dataset is to remove datapoints from the majority class. While discarding data might seem unintuitive, this approach can work surprisingly well in practice. However, with highly imbalanced datasets, it can throw away a lot of data, resulting in poor performance.
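Random under-sampling can be sketched in a few lines, assuming NumPy arrays: each class keeps a random subset whose size equals the smallest class count, and the rest of the majority class is discarded. The function name `random_undersample` is hypothetical.

```python
import numpy as np

def random_undersample(X, y, seed=0):
    """Keep a random subset of each class, sized to match the smallest class."""
    rng = np.random.default_rng(seed)
    X, y = np.asarray(X), np.asarray(y)
    classes, counts = np.unique(y, return_counts=True)
    target = counts.min()
    keep = np.concatenate([
        rng.choice(np.flatnonzero(y == c), size=target, replace=False)
        for c in classes
    ])
    return X[keep], y[keep]

X = np.arange(10).reshape(-1, 1)
y = np.array([0] * 8 + [1] * 2)
X_bal, y_bal = random_undersample(X, y)
print(np.bincount(y_bal))  # → [2 2]
```

Here 6 of the 10 rows are discarded; with a 1000:1 imbalance, the same procedure would discard almost the entire majority class.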

SMOTE (Synthetic Minority Oversampling TEchnique). Rather than over-sampling by copying datapoints, you can use dataset augmentation to create new examples of minority classes by combining or perturbing minority examples. The SMOTE algorithm is sensible for certain data types, where interpolation in feature space makes sense, but doesn’t make sense for certain other data types: averaging pixel values of one picture of a dog with another picture of a dog is unlikely to produce a picture of a dog. Depending on the application, other data augmentation methods could work better.
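The interpolation idea behind SMOTE can be sketched as follows, assuming numeric feature vectors and Euclidean distance: pick a minority example, pick one of its k nearest minority-class neighbors, and create a new point at a random position on the segment between them. This is a simplified illustration, not the reference implementation (the imbalanced-learn package provides a maintained `SMOTE`); the function name `smote` here is hypothetical.

```python
import numpy as np

def smote(X_minority, n_new, k=5, seed=0):
    """Generate n_new synthetic minority examples by interpolating between a
    sampled minority point and one of its k nearest minority-class neighbors."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X_minority, dtype=float)
    # Pairwise Euclidean distances within the minority class.
    dists = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
    np.fill_diagonal(dists, np.inf)  # a point is not its own neighbor
    k = min(k, len(X) - 1)
    neighbors = np.argsort(dists, axis=1)[:, :k]
    base = rng.integers(len(X), size=n_new)                   # points to start from
    chosen = neighbors[base, rng.integers(k, size=n_new)]     # one neighbor each
    lam = rng.random((n_new, 1))                              # interpolation factor in [0, 1)
    return X[base] + lam * (X[chosen] - X[base])

X_min = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
synthetic = smote(X_min, n_new=50, k=2)
```

Every synthetic point is a convex combination of two real minority points, so here all of them fall inside the unit square. This is exactly the assumption that breaks for raw image pixels.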

Balanced mini-batch training. For models trained with mini-batches, like neural networks, when assembling the random subset of data for each mini-batch, you can include datapoints from minority classes with higher probability, such that the mini-batch is balanced. This approach is similar to over-sampling, and it does not throw away data.
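One way to assemble balanced mini-batches is sketched below, assuming integer class labels in a NumPy array: each batch draws an equal number of indices from every class (with replacement), so minority examples are included with higher probability than under uniform sampling, and no data is removed from the dataset. The generator name `balanced_batches` is hypothetical.

```python
import numpy as np

def balanced_batches(y, batch_size, n_batches, seed=0):
    """Yield index arrays for mini-batches with equal counts per class."""
    rng = np.random.default_rng(seed)
    y = np.asarray(y)
    classes = np.unique(y)
    per_class = batch_size // len(classes)
    pools = [np.flatnonzero(y == c) for c in classes]
    for _ in range(n_batches):
        # Draw the same number of indices from each class, then mix them.
        batch = np.concatenate([rng.choice(pool, size=per_class, replace=True)
                                for pool in pools])
        rng.shuffle(batch)
        yield batch

y = np.array([0] * 990 + [1] * 10)  # hypothetical 99:1 imbalance
batches = list(balanced_batches(y, batch_size=32, n_batches=5))
print(np.bincount(y[batches[0]], minlength=2))  # → [16 16]
```

Across many batches, every majority example still gets sampled eventually, which is the difference from under-sampling.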

These techniques can be combined. For example, the SMOTE authors note that the combination of SMOTE and under-sampling performs better than plain under-sampling.

References

Outliers

Figure: An example of an outlier. Two classes in two-dimensional feature space are shown, labeled as "+" and "-", and an outlier is circled in red.

Outliers are datapoints that differ significantly from other datapoints. Causes include errors in measurement (e.g., a damaged air quality sensor), bad data collection (e.g., missing fields in a tabular dataset), malicious inputs (e.g., adversarial examples), and rare events (statistical outliers, e.g., an albino animal in an image classification dataset).

Outlier identification is of interest because outliers can cause issues during model training, at inference time, or when applying statistical techniques to a dataset. Outliers can harm model training, and certain machine learning models (e.g., vanilla SVM) can be particularly sensitive to outliers in the training set. A model, at deployment time, may not produce reasonable output if given outlier data as input (a form of distribution shift). If data has outliers, data analysis techniques might yield bad results.

In 6.036, you learned some model-centric techniques to deal with outliers, for example, using L1 loss instead of L2 loss, which is less sensitive to outliers. In this course, taking a data-centric view, we'll focus on identifying outliers.

Once found, what do you do with outliers? It depends. For example, if you find outliers in the training set, you don’t want to blindly discard them: they might be rare events rather than invalid data points. You could, for example, have a domain expert manually review outliers to check whether they are rare data or bad data.

Problem setup

Being a bit more formal with terminology, here are two tasks of interest:

Outlier detection. In this task, we are not given a clean dataset containing only in-distribution examples. Instead, we get a single unlabeled dataset, and the goal is to detect outliers in the dataset: datapoints that are unlike the others. This task comes up, for example, when cleaning a training dataset that is to be used for ML.

Anomaly detection. In this task, we are given an unlabeled dataset of only in-distribution examples. Given a new datapoint not in the dataset, the goal is to identify whether it belongs to the same distribution as the dataset. This task comes up, for example, when trying to identify whether a datapoint, at inference time, is drawn from the same distribution as a model’s training set.

Question: what makes anomaly detection different from a standard supervised learning classification problem (classify as anomaly or not)?

Identifying outliers

Outlier detection is a heavily studied field, with many algorithms and lots of published research. Here, we cover a few selected techniques.

Tukey’s fences. A simple method for scalar real-valued data. If \(Q_1\) and \(Q_3\) are the lower and upper quartiles, then this test says that any observation outside the following range is considered an outlier: \([Q_1 - k(Q_3 - Q_1), Q_3 + k(Q_3 - Q_1)]\). A multiplier of \(k=1.5\) was proposed by John Tukey.
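Tukey’s fences translate directly into NumPy, as in the minimal sketch below. It uses `np.percentile` with its default linear interpolation for the quartiles; other quartile conventions can shift the fences slightly. The function name `tukey_fences` is illustrative.

```python
import numpy as np

def tukey_fences(x, k=1.5):
    """Return a boolean mask marking values outside [Q1 - k*IQR, Q3 + k*IQR]."""
    x = np.asarray(x, dtype=float)
    q1, q3 = np.percentile(x, [25, 75])
    iqr = q3 - q1  # interquartile range, Q3 - Q1
    return (x < q1 - k * iqr) | (x > q3 + k * iqr)

values = np.array([2, 3, 3, 4, 4, 4, 5, 5, 6, 40])
print(values[tukey_fences(values)])  # → [40]
```

Here \(Q_1 = 3.25\) and \(Q_3 = 5\), so the fences are \([0.625, 7.625]\) and only 40 falls outside. Because quartiles are barely affected by a single extreme value, the fences stay tight even with the outlier present.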

Z-score. The Z-score is the number of standard deviations by which a value is above or below the mean. For one-dimensional or low-dimensional data that are assumed to follow a Gaussian distribution, calculate the Z-score as \(z_i = \frac{x_i - \mu}{\sigma}\), where \(\mu\) is the mean of all the data and \(\sigma\) is the standard deviation. An outlier is a data point that has a high-magnitude Z-score, \(| z_i | > z_{thr}\). A commonly used threshold is \(z_{thr} = 3\). You can apply this technique to individual features as well.
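A minimal Z-score sketch, assuming NumPy: computing the mean and (population) standard deviation along `axis=0` scores a 1-D array directly, or each feature (column) separately for 2-D data. The function name and the planted outlier are illustrative.

```python
import numpy as np

def zscore_outliers(x, z_thr=3.0):
    """Flag entries whose Z-score magnitude exceeds z_thr (per column for 2-D input)."""
    x = np.asarray(x, dtype=float)
    z = (x - x.mean(axis=0)) / x.std(axis=0)
    return np.abs(z) > z_thr

rng = np.random.default_rng(0)
data = np.append(rng.normal(loc=0.0, scale=1.0, size=1000), 10.0)  # one planted outlier
mask = zscore_outliers(data)
```

Note that \(\mu\) and \(\sigma\) are computed from data that include the outliers, so a very large outlier inflates \(\sigma\) and can hide smaller ones.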

Isolation forest. This technique is related to decision trees. Intuitively, the method builds many “random decision trees” and scores data points according to how many splits are required to isolate them, averaged over the trees. Each tree recursively divides (a subset of) the dataset by randomly selecting a feature and a split value until the subset has only one instance. The idea is that outlier data points will require fewer splits to become isolated.
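The from-scratch sketch below shows the mechanics, assuming NumPy and small datasets. Each tree is grown on a random subsample with random features and split values. Following the original isolation forest formulation, depth is capped at \(\lceil \log_2 \psi \rceil\) for subsample size \(\psi\), and leaf depths are adjusted by \(c(n)\), the average path length for \(n\) points. The score is \(2^{-E[h(x)]/c(\psi)}\), where higher means more outlying. All names are hypothetical; for real use, `sklearn.ensemble.IsolationForest` is a maintained implementation.

```python
import numpy as np

EULER_GAMMA = 0.5772156649

def _c(n):
    # Average path length of an unsuccessful search in a binary search tree
    # with n points; used to normalize path lengths.
    if n <= 1:
        return 0.0
    return 2.0 * (np.log(n - 1) + EULER_GAMMA) - 2.0 * (n - 1) / n

def _build_tree(X, depth, max_depth, rng):
    n = X.shape[0]
    if depth >= max_depth or n <= 1:
        return {"size": n}  # leaf
    feature = rng.integers(X.shape[1])  # random feature
    lo, hi = X[:, feature].min(), X[:, feature].max()
    if lo == hi:
        return {"size": n}  # cannot split on a constant feature
    split = rng.uniform(lo, hi)  # random split value within the feature's range
    mask = X[:, feature] < split
    return {"feature": feature, "split": split,
            "left": _build_tree(X[mask], depth + 1, max_depth, rng),
            "right": _build_tree(X[~mask], depth + 1, max_depth, rng)}

def _path_length(x, node, depth=0):
    if "size" in node:
        # Leaf: add the expected extra depth needed to separate its remaining points.
        return depth + _c(node["size"])
    child = node["left"] if x[node["feature"]] < node["split"] else node["right"]
    return _path_length(x, child, depth + 1)

def isolation_forest_scores(X, n_trees=100, sample_size=256, seed=0):
    """Anomaly score in (0, 1]; higher means easier to isolate (more outlying)."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    psi = min(sample_size, X.shape[0])
    max_depth = int(np.ceil(np.log2(psi)))
    trees = [_build_tree(X[rng.choice(X.shape[0], size=psi, replace=False)],
                         0, max_depth, rng)
             for _ in range(n_trees)]
    mean_depth = np.array([np.mean([_path_length(x, t) for t in trees]) for x in X])
    return 2.0 ** (-mean_depth / _c(psi))

rng = np.random.default_rng(0)
X = np.vstack([rng.normal(size=(200, 2)), [[8.0, 8.0]]])  # last row is a planted outlier
scores = isolation_forest_scores(X)
```

The planted point at (8, 8) is usually separated by the first one or two random splits, so its average path length is short and its score is the highest.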

KNN distance. In-distribution data is likely to be closer to its neighbors. You can use the mean distance (choosing an appropriate distance metric, like cosine distance) to a datapoint’s k nearest neighbors as a score. For high-dimensional data like images, you can use embeddings from a trained model and do KNN in the embedding space.

Suppose we want to use KNN for outlier detection. What's the setup? How does it change for anomaly detection?
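One possible setup is sketched below, assuming NumPy, brute-force Euclidean distances, and small datasets (the lecture mentions cosine distance as another choice of metric). For outlier detection, a single dataset is scored against itself while each point's distance to itself is ignored. For anomaly detection, new points are scored against a clean reference set. The function name `knn_scores` is hypothetical.

```python
import numpy as np

def knn_scores(reference, queries, k=5, exclude_self=False):
    """Mean Euclidean distance from each query point to its k nearest reference points.

    Set exclude_self=True only when `queries` and `reference` are the same array
    (outlier detection), so that each point's zero distance to itself is ignored.
    """
    reference = np.asarray(reference, dtype=float)
    queries = np.asarray(queries, dtype=float)
    # Brute-force pairwise distances, shape (n_queries, n_reference).
    dists = np.linalg.norm(queries[:, None, :] - reference[None, :, :], axis=-1)
    if exclude_self:
        np.fill_diagonal(dists, np.inf)
    nearest = np.sort(dists, axis=1)[:, :k]
    return nearest.mean(axis=1)

rng = np.random.default_rng(0)
clean = rng.normal(size=(200, 2))

# Outlier detection: one unlabeled dataset, each point scored against the others.
dataset = np.vstack([clean, [[6.0, 6.0]]])
outlier_scores = knn_scores(dataset, dataset, k=5, exclude_self=True)

# Anomaly detection: new points scored against a clean reference set.
new_points = np.array([[0.0, 0.0], [6.0, 6.0]])
anomaly_scores = knn_scores(clean, new_points, k=5)
```

For large or high-dimensional datasets (e.g., image embeddings), an approximate nearest-neighbor index would replace the quadratic distance matrix used here.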

Reconstruction-based methods. Autoencoders are neural networks that are trained to compress high-dimensional data into a low-dimensional representation and then reconstruct the original data. If an autoencoder learns a data distribution, then it should be able to encode and then decode an in-distribution data point back into a data point that is close to the original input data. For out-of-distribution data, however, the reconstruction is likely to be worse, so you can use reconstruction loss as a score for detecting outliers.
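Training a neural autoencoder does not fit in a short example, so the sketch below swaps in principal component analysis (PCA), which can be viewed as a linear autoencoder. It encodes by projecting onto the top principal directions and decodes by mapping back. The scoring idea is the same: points far from the learned low-dimensional structure reconstruct poorly. The function names and toy data are hypothetical.

```python
import numpy as np

def fit_pca(X, n_components):
    """Fit a linear encoder/decoder: the mean and top principal directions of X."""
    X = np.asarray(X, dtype=float)
    mean = X.mean(axis=0)
    # Rows of vt are principal directions, sorted by explained variance.
    _, _, vt = np.linalg.svd(X - mean, full_matrices=False)
    return mean, vt[:n_components]

def reconstruction_error(X, mean, components):
    """Squared error after encoding to the low-dimensional space and decoding back."""
    X = np.asarray(X, dtype=float)
    codes = (X - mean) @ components.T  # encode
    recon = codes @ components + mean  # decode
    return np.sum((X - recon) ** 2, axis=1)

rng = np.random.default_rng(0)
t = rng.normal(size=200)
train = np.column_stack([t, 2 * t + 0.05 * rng.normal(size=200)])  # points near y = 2x
mean, components = fit_pca(train, n_components=1)
errors = reconstruction_error(np.array([[1.0, 2.0], [1.0, -2.0]]), mean, components)
```

The point (1, 2) lies on the training data's line and reconstructs almost perfectly, while (1, -2) is far off it and gets a much larger error.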

You’ll notice that many outlier detection techniques involve computing a score for every datapoint and then thresholding to select outliers. Outlier detection methods can be evaluated by looking at the ROC curve, or if you want a single summary number to compare methods, looking at the AUROC.
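Given ground-truth outlier labels, AUROC can be computed without drawing the curve, since it equals the probability that a randomly chosen outlier receives a higher score than a randomly chosen inlier (ties count as one half). The sketch below assumes NumPy and small inputs. `sklearn.metrics.roc_auc_score` computes the same quantity, and the function name `auroc` is illustrative.

```python
import numpy as np

def auroc(scores, is_outlier):
    """Probability that a random outlier outscores a random inlier (ties count 1/2)."""
    scores = np.asarray(scores, dtype=float)
    is_outlier = np.asarray(is_outlier, dtype=bool)
    pos, neg = scores[is_outlier], scores[~is_outlier]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))

print(auroc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))  # → 0.75
```

Three of the four (outlier, inlier) pairs are ranked correctly, giving 0.75. Because AUROC depends only on the ranking of scores, it compares methods without committing to a threshold.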

References

Distribution shift

Figure: An example of extreme distribution shift (in particular, covariate shift / data shift) in a handwritten digit classification task. A classifier is trained on Arabic numerals labeled 0–9 but evaluated on Roman numerals, so it is likely to perform extremely poorly.

Distribution shift is a challenging problem that occurs when the joint distribution of inputs and outputs differs between training and test stages, i.e., \(p_\mathrm{train}(\mathbf{x}, y) \neq p_\mathrm{test}(\mathbf{x}, y)\). This issue is present, to varying degrees, in nearly every practical ML application, in part because it is hard to perfectly reproduce testing conditions at training time.

Types of distribution shift

Covariate shift / data shift

Covariate shift occurs when \(p(\mathbf{x})\) changes between train and test, but \(p(y \mid \mathbf{x})\) does not. In other words, the distribution of inputs changes between train and test, but the relationship between inputs and outputs does not change.

Figure: Covariate shift. When the distributions of training data and test data differ significantly, a learned model can fit the training data well but perform poorly on test data.

Examples of covariate shift:

Concept shift

Concept shift occurs when \(p(y \mid \mathbf{x})\) changes between train and test, but \(p(\mathbf{x})\) does not. In other words, the input distribution does not change, but the relationship between inputs and outputs does. This can be one of the most difficult types of distribution shift to detect and correct.

Figure: Concept shift in a two-class dataset with two-dimensional features. Data points, drawn as "x"s, are color-coded by class (red/green), and the decision boundary is shown in purple. The input distribution is exactly identical between train and test, but the relationship between input and output has changed.

It is tricky to come up with real-world examples of concept shift where there is absolutely no change in \(p(\mathbf{x})\). Here are some examples of concept shift in the real world:

Prior probability shift / label shift

Prior probability shift appears only in \(y \rightarrow \mathbf{x}\) problems (when we believe \(y\) causes \(\mathbf{x}\)). It occurs when \(p(y)\) changes between train and test, but \(p(\mathbf{x} \mid y)\) does not. You can think of it as the converse of covariate shift.

To understand prior probability shift, consider the example of spam classification, where a commonly used model is Naive Bayes. If the model is trained on a balanced dataset of 50% spam and 50% non-spam emails, and then it’s deployed in a real-world setting where 90% of emails are spam, that is an example of prior probability shift.
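A small worked example shows why the prior matters even when \(p(\mathbf{x} \mid y)\) is unchanged. The likelihood values below are hypothetical, chosen only for the arithmetic. Applying Bayes' rule with the training prior (50% spam) and then with the deployment prior (90% spam) flips the decision for the same email.

```python
def posterior_spam(lik_spam, lik_ham, prior_spam):
    """Bayes' rule: p(spam | x) from class-conditional likelihoods and a prior."""
    joint_spam = lik_spam * prior_spam
    joint_ham = lik_ham * (1.0 - prior_spam)
    return joint_spam / (joint_spam + joint_ham)

# Hypothetical likelihoods p(x | y) for one email; these stay fixed under prior shift.
lik_spam, lik_ham = 0.01, 0.03
print(round(posterior_spam(lik_spam, lik_ham, prior_spam=0.5), 3))  # training prior → 0.25
print(round(posterior_spam(lik_spam, lik_ham, prior_spam=0.9), 3))  # deployment prior → 0.75
```

Since Naive Bayes models \(p(\mathbf{x} \mid y)\) and \(p(y)\) separately, this kind of shift can in principle be handled by replacing only the prior, provided the deployment prior is known or can be estimated.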

Another example is training a classifier to predict diagnoses from symptoms when the relative prevalence of diseases changes over time. Prior probability shift (rather than covariate shift) is the appropriate assumption to make here, because diseases cause symptoms.

Detecting and addressing distribution shift

Some ways you can detect distribution shift in deployments:

At a high level, distribution shift can be addressed by fixing the data and re-training the model. In some situations, the best solution is to collect a better training set.

If unlabeled testing data are available while training, then one way to address covariate shift is to assign individual sample weights to training datapoints to weigh their feature-distribution such that the weighted distribution resembles the feature-distribution of test data. In this setting, even though test labels are unknown, label shift can similarly be addressed by employing shared sample weights for all training examples with the same class label, in order to make the weighted feature-distribution in training data resemble the feature distribution in the test data. However, concept shift cannot be addressed without knowledge of its form in this setting, because there is no way to quantify it from unlabeled test data.
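Both reweighting schemes can be sketched with NumPy under simplifying assumptions. For label shift, every training example of class \(y\) gets the shared weight \(p_\mathrm{test}(y) / p_\mathrm{train}(y)\). This assumes an estimate of the test label distribution is already available, since test labels are unknown. For covariate shift, each training point gets \(p_\mathrm{test}(\mathbf{x}) / p_\mathrm{train}(\mathbf{x})\), estimated here with histograms on a single 1-D feature, a deliberately crude density-ratio estimate chosen only for illustration. Function names are hypothetical.

```python
import numpy as np

def label_shift_weights(y_train, p_test):
    """Shared weight per class: w(y) = p_test(y) / p_train(y).

    p_test is an assumed estimate of the test label distribution, indexed by class;
    every class must appear in the training labels.
    """
    y_train = np.asarray(y_train)
    p_train = np.bincount(y_train, minlength=len(p_test)) / len(y_train)
    class_w = np.asarray(p_test) / p_train
    return class_w[y_train]

def covariate_shift_weights(x_train, x_test, bins=10):
    """Per-sample weight w(x) = p_test(x) / p_train(x) for a 1-D feature,
    estimated with histograms on shared bin edges."""
    x_train, x_test = np.asarray(x_train, float), np.asarray(x_test, float)
    edges = np.histogram_bin_edges(np.concatenate([x_train, x_test]), bins=bins)
    p_tr, _ = np.histogram(x_train, bins=edges, density=True)
    p_te, _ = np.histogram(x_test, bins=edges, density=True)
    # Bin index of each training point (clip keeps the right edge in the last bin).
    idx = np.clip(np.searchsorted(edges, x_train, side="right") - 1, 0, bins - 1)
    return p_te[idx] / p_tr[idx]

# Label shift: train is 80/20, test is assumed to be 50/50.
y_train = np.array([0] * 8 + [1] * 2)
w_label = label_shift_weights(y_train, p_test=[0.5, 0.5])

# Covariate shift: the test inputs are shifted to the right.
rng = np.random.default_rng(0)
x_train = rng.normal(0.0, 1.0, size=2000)
x_test = rng.normal(1.0, 1.0, size=2000)
w_cov = covariate_shift_weights(x_train, x_test)
```

After weighting, the training classes contribute equally (5.0 each), and the weighted mean of `x_train` moves toward the test mean of about 1. In higher dimensions, density ratios are usually estimated differently, for example by training a classifier to distinguish training inputs from test inputs. Concept shift is not addressed by either scheme.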


Lab

The lab assignment for this lecture is to implement and compare different methods for identifying outliers.


For this lab, we’ve focused on anomaly detection. You are given a clean training dataset consisting of many pictures of dogs, and an evaluation dataset that contains outliers (non-dogs). Your task is to implement and compare various methods for detecting these outliers. You may implement some of the ideas presented in today’s lecture, or you can look up other outlier detection algorithms in the linked references or online.

The lab assignments for the course are available in the dcai-lab repository.

Remember to run a git pull before doing every lab, to make sure you have the latest version of the labs.

The lab assignment for this class is in outliers/Lab - Outliers.ipynb.



Licensed under CC BY-NC-SA.