
Consider a dataframe that contains two columns, text and label. I can very easily create a stratified train-test split using sklearn.model_selection.train_test_split. All I have to do is pass the column I want to stratify on (in this case label) to the stratify parameter.

Now, consider a dataframe that contains three columns, text, subreddit, and label. I would like to make a stratified train-test split using the label column, but I also want to make sure that the split is not biased with respect to the subreddit column. For example, the test set could end up with far more comments from subreddit X than the train set.

How can I do this in Python?

Aventinus

1 Answer


One option is to pass an array containing both variables to the stratify parameter, which accepts multidimensional arrays too. Here's the description from the scikit-learn documentation:

stratify array-like, default=None

If not None, data is split in a stratified fashion, using this as the class labels.


Here is an example:

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# create dummy data with unbalanced feature value distribution

X = pd.DataFrame(
    np.concatenate((np.random.randint(0, 3, 500), np.random.randint(0, 10, 500)), axis=0).reshape((500, 2)),
    columns=["text", "subreddit"],
)
y = pd.DataFrame(np.random.randint(0, 2, 500).reshape((500, 1)), columns=["label"])

# split stratified to target variable and subreddit col

X_train, X_test, y_train, y_test = train_test_split(
    X,
    pd.concat([X["subreddit"], y], axis=1),
    stratify=pd.concat([X["subreddit"], y], axis=1),
)

# remove subreddit cols from target variable arrays

y_train = y_train.drop(["subreddit"], axis=1)
y_test = y_test.drop(["subreddit"], axis=1)
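A variation on the same idea, sketched under the assumption that the data lives in a single dataframe with the three columns from the question: build one combined key per row and pass it to stratify, so nothing has to be dropped afterwards. The toy data, the name df, test_size and random_state are all illustrative choices, not part of the original answer.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical toy data with the three columns from the question.
rng = np.random.default_rng(0)
n = 1000
df = pd.DataFrame({
    "text": [f"comment {i}" for i in range(n)],
    "subreddit": rng.choice(["a", "b", "c"], size=n, p=[0.6, 0.3, 0.1]),
    "label": rng.integers(0, 2, size=n),
})

# One stratum per (subreddit, label) combination, e.g. "a_0" or "c_1".
strata = df["subreddit"].astype(str) + "_" + df["label"].astype(str)

# Split the whole dataframe at once; the key is only used for stratification.
train_df, test_df = train_test_split(
    df, test_size=0.25, stratify=strata, random_state=0
)
```

Because the key is a separate Series, train_df and test_df keep exactly the original columns.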

As you can see, the split is stratified on subreddit too:

# Train data shares for subreddits

X_train.groupby("subreddit").count()/len(X_train)

gives

text
subreddit   
0   0.232000
1   0.232000
2   0.213333
3   0.034667
4   0.037333
5   0.045333
6   0.056000
7   0.056000
8   0.048000
9   0.045333

# Test data shares for subreddits

X_test.groupby("subreddit").count()/len(X_test)

gives

text
subreddit   
0   0.232
1   0.240
2   0.208
3   0.032
4   0.032
5   0.048
6   0.056
7   0.056
8   0.048
9   0.048
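To compare the two tables without reading them row by row, the shares can be put side by side. This is only a sketch: compare_shares is a hypothetical helper name, and the two tiny frames exist just to show the call.

```python
import pandas as pd

def compare_shares(train, test, column):
    """Return per-category shares in train and test plus their absolute gap."""
    shares = pd.concat(
        [
            train[column].value_counts(normalize=True).rename("train"),
            test[column].value_counts(normalize=True).rename("test"),
        ],
        axis=1,
    ).fillna(0.0)  # a category missing from one side gets a share of 0
    shares["abs_diff"] = (shares["train"] - shares["test"]).abs()
    return shares.sort_index()

# Tiny hand-made frames, only to demonstrate the helper.
demo_train = pd.DataFrame({"subreddit": [0, 0, 1, 1]})
demo_test = pd.DataFrame({"subreddit": [0, 1]})
report = compare_shares(demo_train, demo_test, "subreddit")
```

Calling it as compare_shares(X_train, X_test, "subreddit") on the split above would give one row per subreddit with its train share, test share, and the gap between them.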

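Naturally, this only works if you have enough data to stratify on subreddit and the target variable at the same time: every (subreddit, label) combination needs at least two rows, and the train and test sets must each be large enough to hold one row per combination. Otherwise scikit-learn raises a ValueError.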
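If some combinations are too rare, one possible workaround (an assumption on my part, not something scikit-learn does for you) is to stratify rare rows on label alone. In the sketch below, make_strata, its parameters, and the demo data are hypothetical; note that the fallback groups can themselves end up too small, and that subreddit balance is no longer enforced for the rows that fall back.

```python
import pandas as pd

def make_strata(df, cols=("subreddit", "label"), fallback="label", min_count=2):
    """Combined key per row; strata smaller than min_count use `fallback` only."""
    # Join the chosen columns into one string key per row, e.g. "a_0".
    key = df[list(cols)].astype(str).agg("_".join, axis=1)
    # Size of the stratum each row belongs to.
    counts = key.map(key.value_counts())
    # Keep the combined key where the stratum is big enough, else fall back.
    return key.where(counts >= min_count, df[fallback].astype(str))

# Hypothetical data: subreddits "b" and "c" each appear only once.
demo = pd.DataFrame({
    "subreddit": ["a", "a", "a", "a", "b", "c"],
    "label": [0, 0, 1, 1, 1, 1],
})
strata = make_strata(demo)
```

The resulting Series can be passed directly as the stratify argument of train_test_split.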

Jonathan