Comment Classification with BERT (Sequential Transfer Learning)¶

For educational purposes, consider a hypothetical online retailer launching an innovative service. This service enables users to collaboratively edit and enhance product descriptions, similar to the contributions found in wiki communities. Customers can suggest changes and comment on others’ edits, and to maintain a respectful and safe environment, the retailer requires a tool that detects toxic comments and routes them for moderation.

Project Overview¶

Objective:

  • Develop a model that classifies comments as toxic or non-toxic, using a dataset labeled for toxicity.
  • The model must achieve an F1 score of at least 0.75.

Data Description:

The dataset is stored in the file /datasets/toxic_comments.csv.

  • The column text contains the comment text.
  • The column toxic is the target label indicating toxicity.

Approach:

  1. Load and prepare the data.
  2. Train a BERT classifier with sequential transfer learning.
  3. Evaluate the model and draw conclusions.

Preparation¶

Dependencies¶

In [1]:
import os
import hashlib

from dataclasses import dataclass, field
from enum import Enum

import nltk
import numpy as np
import pandas as pd
import seaborn as sns
import torch


from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from matplotlib import pyplot as plt
from sklearn.metrics import (
    classification_report, confusion_matrix, f1_score, roc_auc_score, roc_curve
)
from sklearn.model_selection import train_test_split
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertForSequenceClassification
from tqdm import tqdm
In [2]:
from google.colab import drive

mount_files = True
if mount_files:
    drive.mount('/content/drive')
Mounted at /content/drive

Constants¶

In [3]:
STL_ID = 3  # Sequential Transfer Learning subset ID (starts at 0)

SUBSETS = 4
VERSION = 2
ID_STEP = 1
SEED = 42
STOPWORDS_STL_ID = 0

TEST_SIZE = 0.3
PRODUCTION_TEST_SIZE = 0.2

MODEL_CACHE = 'drive/MyDrive/datasets/.cache'
DATASETS_PATH = 'https://code.s3.yandex.net/datasets/toxic_comments.csv'
DATASETS_PATH_LOCAL = 'drive/MyDrive/datasets/toxic_comments.csv'

NLTK_DATA_PATH = ('drive', 'MyDrive', 'datasets', 'nltk_data')

INFERENCE_ONLY = False

PRODUCTION_N_SAMPLES = 2000
PRODUCTION_N_SAMPLES_BOOTSTRAP = 800
BOOTSTRAP_N_SAMPLES = 600


BR = '\n'


f'Version {VERSION}.{ID_STEP*STL_ID}.{SUBSETS}'
Out[3]:
'Version 2.3.4'

Hyperparameters¶

In [4]:
class InfraParameter(Enum):
    BATCH_SIZE = 128
    MAX_LENGTH = 256 + 64
    EPOCHS = 8


class MetaParameter(Enum):
    L_RATE_MIN = 1e-6
    L_RATE = 1e-5
    L_RATE_MAX = 5e-5
    L_RATE_SUBSET_ID_DECREASE = -1e-6


@dataclass
class HyperParameter:
    infra: type[InfraParameter] = InfraParameter
    meta: type[MetaParameter] = MetaParameter
    subset_id: int = ID_STEP * STL_ID
    warm_up: tuple[int, int] = (0, 10)

    def lr(self) -> float:
        # Lower the learning rate by one fixed step per subset ID; the first
        # (warm-up) subset instead gets a temporarily raised rate.
        adapt = self.meta.L_RATE_SUBSET_ID_DECREASE.value * self.subset_id
        if self.subset_id <= self.warm_up[0]:
            adapt = self.meta.L_RATE_SUBSET_ID_DECREASE.value * self.warm_up[1] * -1
        lr_ = max(self.meta.L_RATE.value + adapt, self.meta.L_RATE_MIN.value)
        return min(lr_, self.meta.L_RATE_MAX.value)

    def epochs(self, threshold: int = 1,  add: int = 1) -> int:
        hp_epochs = self.infra.EPOCHS.value
        return hp_epochs + add if self.subset_id < threshold else hp_epochs

    def lr_development(self) -> None:
        lrs = [HyperParameter(subset_id=i).lr() for i in range(SUBSETS)]
        plt.figure(figsize=(12, 3))
        bar_colors = ['gray'] * SUBSETS
        bar_colors[self.subset_id] = 'green'
        plt.bar(range(SUBSETS), lrs, color=bar_colors, alpha=0.4, width=0.5)
        act_lr = f'{lrs[self.subset_id]:.2e}'
        plt.text(self.subset_id, lrs[self.subset_id], act_lr, va='top')
        plt.title('Learning Rate Development')
        plt.xlabel('Subset ID')
        plt.ylabel('Learning Rate')
        plt.show()

Sequential Transfer Learning¶

In [5]:
hp = HyperParameter()


hp.lr_development()
[Figure: Learning Rate Development bar chart (learning rate per subset ID)]

This project develops a toxic comment classification model using Sequential Transfer Learning (STL), in which training is adapted across successive data subsets.

Sequential Transfer Learning (STL) lets the model learn iteratively on smaller, manageable subsets of the data. This is useful for large datasets and helps account for potential data drift over time.

The process begins with a warm start: the first subset is trained without stop-words and at a warm-up learning rate, allowing the model to adapt to the basic structure of the data while preserving its generalization capability. For each subsequent subset, the learning rate is gradually decreased, enabling finer tuning and reducing the risk of overfitting. Each subset builds on the knowledge acquired in the previous stage, creating an iterative improvement process that mimics pre-training while adding controlled fine-tuning at every step.

Beyond giving large-dataset training a clear structure, STL offers a systematic way to adapt the model to changing data distributions (data drift), and the subset boundaries provide natural points for adjusting hyperparameters such as the learning rate.
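For reference, the schedule implied by the HyperParameter class above can be inspected directly. A minimal sketch (not an executed cell; the exact values follow from the constants defined earlier):

# Sketch: print the STL learning-rate schedule and epoch budget per subset ID.
for subset_id in range(SUBSETS):
    schedule = HyperParameter(subset_id=subset_id)
    print(f'subset {subset_id}: lr={schedule.lr():.1e}, epochs={schedule.epochs()}')

With the constants above, subset 0 (the warm start) trains at 2e-5 for one extra epoch, while later subsets decay from 9e-6 down to 7e-6 for subset 3.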

Auxiliary Classes¶

In [6]:
@dataclass
class Log:
    time_start: pd.Timestamp = None

    def set_time(self):
        self.time_start = pd.Timestamp.now()

    def get_time_h(self, print_: bool = False) -> float:
        t_ = round((pd.Timestamp.now() - self.time_start).seconds / 3600, 2)
        if not print_:
            return t_
        print(f'Elapsed time: {t_} hours')

    @staticmethod
    def value_counts(
        df_cnt: pd.DataFrame, column: str = 'target', cnt: str = 'cnt'
    ) -> None:
        df_cnt = df_cnt[column].value_counts().to_frame()
        df_cnt.columns = [cnt]
        display(df_cnt)

    @staticmethod
    def censored(log_df: pd.DataFrame, cols: str = 'target-text', hd: int = 8) -> None:
        # Replace comments of the positive class (toxic = 1) with a SHA-256 hash
        # so that offensive text is not displayed in the notebook output.
        prefix = '*** [ positive ] *** | HASH: '
        log_df_, c0, c1 = log_df.copy(), cols.split('-')[0], cols.split('-')[1]
        log_df_.loc[log_df_[c0] == 1, c1] = log_df_.loc[log_df_[c0] == 1, c1].apply(
            lambda x: prefix + hashlib.sha256(str(x).encode()).hexdigest()
        )
        display(log_df_.head(hd))
In [7]:
@dataclass
class DataPreparation:
    target: str
    seed: int
    size: float
    is_under_sample: bool = False
    schema: tuple[str, str] = ('text', 'target')
    cache: dict = field(default_factory=dict)

    def key_value_structure(self, dfm: pd.DataFrame, km: str, vm: str) -> pd.DataFrame:
        schema = {km: self.schema[0], vm: self.schema[1]}
        self.target, self.cache['target_name'] = self.schema[1], vm
        return dfm[[km, vm]].copy().rename(columns=schema)

    def split(self, *series: pd.Series, **conf) -> tuple[pd.Series, ...]:
        conf['test_size'] = conf.get('test_size', self.size)
        return train_test_split(*series, random_state=self.seed, **conf)

    def sub_samples(self, dfm: pd.DataFrame, n_subsamples: int) -> list[pd.DataFrame]:
        dfm = dfm.copy()
        split_size = len(dfm) // n_subsamples
        subsamples = [
            dfm.iloc[i * split_size: (i + 1) * split_size].reset_index(drop=True)
            for i in range(n_subsamples)
        ]
        if len(dfm) % n_subsamples != 0:
            remaining = dfm.iloc[n_subsamples * split_size:].reset_index(drop=True)
            subsamples[-1] = pd.concat([subsamples[-1], remaining]).reset_index(
                drop=True
            )
        return subsamples

    @staticmethod
    def join(x_: pd.Series, y_: pd.Series) -> pd.DataFrame:
        return pd.concat([x_, y_], axis=1)

    @staticmethod
    def text_target(dfm: pd.DataFrame) -> tuple[pd.Series, pd.Series]:
        dfm = dfm.copy()
        return dfm['text'], dfm['target']

EDA (Unstructured Data)¶

In [8]:
df_raw = pd.read_csv(DATASETS_PATH)


df_raw.info()  # no missing values
Log.censored(df_raw, 'toxic-text')
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 159292 entries, 0 to 159291
Data columns (total 3 columns):
 #   Column      Non-Null Count   Dtype 
---  ------      --------------   ----- 
 0   Unnamed: 0  159292 non-null  int64 
 1   text        159292 non-null  object
 2   toxic       159292 non-null  int64 
dtypes: int64(2), object(1)
memory usage: 3.6+ MB
Unnamed: 0 text toxic
0 0 Explanation\nWhy the edits made under my usern... 0
1 1 D'aww! He matches this background colour I'm s... 0
2 2 Hey man, I'm really not trying to edit war. It... 0
3 3 "\nMore\nI can't make any real suggestions on ... 0
4 4 You, sir, are my hero. Any chance you remember... 0
5 5 "\n\nCongratulations from me as well, use the ... 0
6 6 *** [ positive ] *** | HASH: 6e4d3584d34a8a9e2... 1
7 7 Your vandalism to the Matt Shirvington article... 0
In [9]:
vc_target = df_raw['toxic'].value_counts().to_frame()
vc_target['percent'] =  round(vc_target / vc_target.sum() * 100, 2)


vc_target  # imbalanced target (90% non-toxic)
Out[9]:
count percent
toxic
0 143106 89.84
1 16186 10.16
In [10]:
plt.figure(figsize=(12, 4))
for value, color in [(1, 'red'), (0, 'gray')]:
    label = f'Toxic = {value}'
    # Character length is used here as a rough proxy for sequence length.
    df_raw[df_raw['toxic'] == value]['text'].apply(lambda x: len(str(x))).plot(
        kind='hist', bins=200, range=(0, 3000), alpha=0.5, label=label, color=color
    )


plt.title('Distribution of Comment Length (Characters, Token Approximation)')
plt.xlabel('Character Count')
plt.ylabel('Frequency')
plt.legend()

plt.show()
[Figure: Overlaid histograms of comment character length for toxic vs non-toxic comments]

Data Splitting I¶

In [11]:
dp = DataPreparation(target='toxic', seed=SEED, size=TEST_SIZE)


df = dp.key_value_structure(df_raw, 'text', 'toxic')


Log.censored(df)
text target
0 Explanation\nWhy the edits made under my usern... 0
1 D'aww! He matches this background colour I'm s... 0
2 Hey man, I'm really not trying to edit war. It... 0
3 "\nMore\nI can't make any real suggestions on ... 0
4 You, sir, are my hero. Any chance you remember... 0
5 "\n\nCongratulations from me as well, use the ... 0
6 *** [ positive ] *** | HASH: 6e4d3584d34a8a9e2... 1
7 Your vandalism to the Matt Shirvington article... 0
In [12]:
production: dict[str, pd.Series] = {}
df_X_temp, production['X_test'], df_y_temp, production['y_test'] = dp.split(
    *dp.text_target(df), test_size=PRODUCTION_TEST_SIZE
)


# Training, Validation, Test
# ==========================
df = dp.join(df_X_temp, df_y_temp)


# Production Test
# ===============
dp.cache['production'] = production
dp.cache['df_production'] = dp.join(production['X_test'], production['y_test'])


Log.value_counts(df, cnt='count_train_val_test')
Log.censored(df, hd=5)
df.shape, production['X_test'].shape, production['y_test'].shape
count_train_val_test
target
0 114448
1 12985
text target
45155 "\nYou claimed to have ""scavenged the UN and ... 0
60904 "\n\n Please do not vandalize pages, as you di... 0
92242 "\n\n ""largest moon"" \n\nShouldn't it say la... 0
74757 "\n\n Isn't baking cooking? \n\nAccording to t... 0
7198 I am sure the judges smiled too.\n\nWhen you c... 0
Out[12]:
((127433, 2), (31859,), (31859,))

Preprocessing¶

In [13]:
@dataclass
class StopWordsProcessor:
    nltk_path: tuple[str, ...]
    stop_words: set = field(init=False)

    def __post_init__(self):
        nltk_data_path = os.path.join(*self.nltk_path)
        os.makedirs(nltk_data_path, exist_ok=True)
        nltk.data.path.append(nltk_data_path)
        nltk.download('stopwords', download_dir=nltk_data_path)
        nltk.download('punkt_tab', download_dir=nltk_data_path)
        self.stop_words = set(stopwords.words('english'))

    def remove_stop_words(self, text: str) -> str:
        words = word_tokenize(text)
        filtered_words = [word for word in words if word.lower() not in self.stop_words]
        return ' '.join(filtered_words)

    def preprocess_dataframe(self, df: pd.DataFrame, rm: bool = True) -> pd.DataFrame:
        if rm:
            df = df.copy()
            df['text'] = df['text'].apply(self.remove_stop_words)
        return df
In [14]:
stop_words_processor = StopWordsProcessor(NLTK_DATA_PATH)


df = stop_words_processor.preprocess_dataframe(df, rm=(STL_ID == STOPWORDS_STL_ID))

Log.censored(df)
[nltk_data] Downloading package stopwords to
[nltk_data]     drive/MyDrive/datasets/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     drive/MyDrive/datasets/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
text target
45155 "\nYou claimed to have ""scavenged the UN and ... 0
60904 "\n\n Please do not vandalize pages, as you di... 0
92242 "\n\n ""largest moon"" \n\nShouldn't it say la... 0
74757 "\n\n Isn't baking cooking? \n\nAccording to t... 0
7198 I am sure the judges smiled too.\n\nWhen you c... 0
22361 this is a government IP used by roughly 3,000 ... 0
132882 12:13, 4 May 2012‎ User:217.217.197.24 0
154232 gentlemen and gentlemen, ProKo has been revert... 0
In [15]:
df_subsets = dp.sub_samples(df, SUBSETS)
df = df_subsets[hp.subset_id]
df = df.sample(frac=1, random_state=SEED).reset_index(drop=True)


Log.value_counts(df, cnt=f'count_train_val_test_with_subset_id_{hp.subset_id}')

Log.censored(df)
count_train_val_test_with_subset_id_3
target
0 28616
1 3243
text target
0 Yes, the Mustang GT500 laptime is real \n\nThe... 0
1 User Sitush \n\nYou seem to be a genuine crusa... 0
2 RfAr notice \n\nYou are mentioned in Wikipedia... 0
3 In popular culture \n\nMossad is seen in many ... 0
4 2010 (UTC)\n\nWelcome\n86.29.137.111 03:50, 1... 0
5 OK, Ngo Thanh Nhan just wrote to me after read... 0
6 *** [ positive ] *** | HASH: 9565b007350481f51... 1
7 Are you brooding about colours nuances, violet... 0

Training¶

Data Splitting II¶

In [16]:
test: dict[str, pd.Series] = {}
X_train_val, test['X'], y_train_val, test['y'] = dp.split(*dp.text_target(df))


print(' ' * 3, 'train/val', 'test', sep='  ')
display(('X', X_train_val.shape, test['X'].shape))
'y', y_train_val.shape, test['y'].shape
     train/val  test
('X', (22301,), (9558,))
Out[16]:
('y', (22301,), (9558,))
In [17]:
train: dict[str, pd.Series] = {}
val: dict[str, pd.Series] = {}
train['X'], val['X'], train['y'], val['y'] = dp.split(X_train_val, y_train_val)


print(' ' * 4, 'train', ' val', sep='  ')
display(('X', train['X'].shape, val['X'].shape))
'y', train['y'].shape, val['y'].shape
      train   val
('X', (15610,), (6691,))
Out[17]:
('y', (15610,), (6691,))

Tokenization and Models¶

In [18]:
# Subset Determination
# ====================
name_or_path = 'bert-base-uncased'
if hp.subset_id > 0:
    prev_subset_id = hp.subset_id - ID_STEP
    name_or_path = f'{MODEL_CACHE}/model_v{VERSION}.{prev_subset_id}.{SUBSETS}'


# Tokenization and Load Model
# ===========================
tokenizer: BertTokenizer = BertTokenizer.from_pretrained(
    name_or_path,
    cache_dir=MODEL_CACHE
)
model = BertForSequenceClassification.from_pretrained(
    name_or_path,
    num_labels=2,
    cache_dir=MODEL_CACHE
)


ml_info = f'Model "{name_or_path}" ready for SUBSET with ID: {hp.subset_id}'
print(BR, BR + ml_info, BR, sep=BR + '*' * 64)
'Total subsets: ' + str(SUBSETS)

****************************************************************
Model "drive/MyDrive/datasets/.cache/model_v2.2.4" ready for SUBSET with ID: 3
****************************************************************

Out[18]:
'Total subsets: 4'
In [19]:
def tokenize_function(text: list[str], hp_: HyperParameter) -> dict[str, torch.Tensor]:
    kw_args = {
        'padding': 'max_length',
        'truncation': True,
        'max_length': hp_.infra.MAX_LENGTH.value,
        'return_tensors': 'pt'
    }
    return tokenizer(text, **kw_args)

Encoding¶

In [20]:
train_encodings = tokenize_function(train['X'].tolist(), hp)
val_encodings = tokenize_function(val['X'].tolist(), hp)
test_encodings = tokenize_function(test['X'].tolist(), hp)


sequence_lengths = [
    len(tokenizer(text, truncation=True)['input_ids']) for text in train['X']
]


print(f'Average sequence length: {sum(sequence_lengths) / len(sequence_lengths)}')
print(BR, 'Percentiles', '=' * 20, sep=BR)
for percentile in [80, 90, 95]:
    print(f'{percentile}th: {int(np.percentile(sequence_lengths, percentile))}')
'Sequence (truncated) length: ' + str(train_encodings['input_ids'].shape[1])
Average sequence length: 90.0457399103139


Percentiles
====================
80th: 128
90th: 211
95th: 325
Out[20]:
'Sequence (truncated) length: 320'

Dataset Setup¶

In [21]:
class ToxicDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}, self.labels[idx]


train_dataset = ToxicDataset(train_encodings, train['y'].tolist())
train_loader = DataLoader(
    train_dataset, batch_size=hp.infra.BATCH_SIZE.value, shuffle=True
)

val_dataset = ToxicDataset(val_encodings, val['y'].tolist())
val_loader = DataLoader(
    val_dataset, batch_size=hp.infra.BATCH_SIZE.value, shuffle=False
)
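Each element the loaders yield is a tuple of a dict of stacked tokenizer tensors and a label tensor. A quick sketch of the expected shapes (not an executed cell; the last batch may be smaller than BATCH_SIZE):

# Sketch: inspect one training batch produced by the DataLoader.
inputs, labels = next(iter(train_loader))
print(inputs['input_ids'].shape)       # (BATCH_SIZE, MAX_LENGTH), i.e. (128, 320) here
print(inputs['attention_mask'].shape)  # same shape as input_ids
print(labels.shape)                    # (BATCH_SIZE,)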

Training (BERT)¶

Training Functions¶

In [22]:
def train_one_epoch(
    model: torch.nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: torch.nn.Module,
    device: torch.device
) -> tuple[list[int], list[int], float]:

    model.train()
    all_preds, all_labels = [], []
    progress_bar = tqdm(train_loader, desc="Training", leave=False)

    for batch in progress_bar:
        inputs, labels = batch
        inputs = {key: value.to(device) for key, value in inputs.items()}
        labels = labels.to(device)
        optimizer.zero_grad()
        # The model's internal (unweighted) loss is ignored; the class-weighted
        # loss_fn passed in is applied to the logits instead.
        outputs = model(**inputs, labels=labels)
        loss = loss_fn(outputs.logits, labels)
        loss.backward()
        optimizer.step()

        progress_bar.set_postfix({'loss': loss.item()})
        preds = outputs.logits.argmax(dim=-1).detach().cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.cpu().numpy())

    train_f1 = f1_score(all_labels, all_preds, average='binary')
    return all_preds, all_labels, train_f1
In [23]:
def validate_model(
    model: torch.nn.Module, val_loader: DataLoader, device: torch.device
) -> tuple[list[int], list[int], float]:

    model.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            inputs, labels = batch
            inputs = {key: value.to(device) for key, value in inputs.items()}
            labels = labels.to(device)
            outputs = model(**inputs)
            preds = outputs.logits.argmax(dim=-1).detach().cpu().numpy()
            val_preds.extend(preds)
            val_labels.extend(labels.cpu().numpy())

    val_f1 = f1_score(val_labels, val_preds, average='binary')
    return val_preds, val_labels, val_f1

Device¶

In [24]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


if torch.cuda.is_available():
    print(torch.cuda.device_count())
    print(torch.cuda.get_device_name(0))


device
1
NVIDIA A100-SXM4-40GB
Out[24]:
device(type='cuda')

Loss Function¶

In [25]:
class_counts = np.bincount(train['y'])  # 90% non-toxic, 10% toxic
total_samples = class_counts.sum()
class_weights = torch.tensor(total_samples / class_counts, dtype=torch.float32)

loss_fn = CrossEntropyLoss(weight=class_weights).to(device)
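With roughly 90% non-toxic and 10% toxic comments, this weighting makes errors on the minority class count about nine times more in the loss. A minimal numeric sketch with hypothetical counts (not taken from this notebook's data):

# Sketch: class weights for a hypothetical 90/10 split of 10,000 labels.
counts = np.bincount([0] * 9_000 + [1] * 1_000)     # array([9000, 1000])
weights = torch.tensor(counts.sum() / counts, dtype=torch.float32)
print(weights)                                       # tensor([ 1.1111, 10.0000])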

Training Configuration¶

In [26]:
optimizer = torch.optim.AdamW(model.parameters(), lr=hp.lr())

patience = 2  # Number of epochs to wait before stopping
f1_epoch_factor = 0.01
wait = 0
best_val_f1 = float('-inf')
best_model_state = None

log = Log()

Early Stopping Training¶

In [27]:
log.set_time()


for epoch in range(hp.epochs()):
    if INFERENCE_ONLY:
        print('Set "INFERENCE_ONLY = False" for model training')
        break

    print(f'Epoch {epoch + 1} of {hp.epochs()}')
    train_preds, train_labels, train_f1 = train_one_epoch(
        model, train_loader, optimizer, loss_fn, device=device
    )
    print(f'Epoch {epoch + 1} Training F1 Score: {train_f1:.4f}')
    _, _, val_f1 = validate_model(model, val_loader, device=device)
    print(f'Epoch {epoch + 1} Validation F1 Score: {val_f1:.4f}')

    # Snapshot the best weights by raw validation F1; clone so the snapshot is
    # not overwritten by later optimizer steps.
    if val_f1 > best_val_f1:
        best_model_state = {
            k: v.detach().clone() for k, v in model.state_dict().items()
        }
        print(f'Model (state) with validation F1 {val_f1} is updated.')
    # The patience counter only resets when the improvement clears a margin
    # that grows with the epoch number (f1_epoch_factor * epoch).
    if val_f1 >= (best_val_f1 + (f1_epoch_factor * epoch)):
        wait = 0
    else:
        wait += 1
        print(f'No significant improvement in validation F1 for {wait} epoch(s).')
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1

    if wait >= patience:
        print(f'Early stopping triggered after {epoch + 1} epochs.')
        break


log.get_time_h(print_=True)


if best_model_state:
    model.load_state_dict(best_model_state)
    print(f'Loaded best model with Validation F1: {best_val_f1:.4f}')
else:
    print('(!) - No improvement during training. Model not updated.')
Epoch 1 of 8

Epoch 1 Training F1 Score: 0.7505
Epoch 1 Validation F1 Score: 0.7700
Model (state) with validation F1 0.76996996996997 is updated.
Epoch 2 of 8

Epoch 2 Training F1 Score: 0.8171
Epoch 2 Validation F1 Score: 0.7538
No significant improvement in validation F1 for 1 epoch(s).
Epoch 3 of 8

Epoch 3 Training F1 Score: 0.8533
Epoch 3 Validation F1 Score: 0.7903
Model (state) with validation F1 0.7903123008285532 is updated.
Epoch 4 of 8

Epoch 4 Training F1 Score: 0.8926
Epoch 4 Validation F1 Score: 0.7908
Model (state) with validation F1 0.7908455181182454 is updated.
No significant improvement in validation F1 for 1 epoch(s).
Epoch 5 of 8

Epoch 5 Training F1 Score: 0.9173
Epoch 5 Validation F1 Score: 0.8088
Model (state) with validation F1 0.8088235294117647 is updated.
No significant improvement in validation F1 for 2 epoch(s).
Early stopping triggered after 5 epochs.
Elapsed time: 0.26 hours
Loaded best model with Validation F1: 0.8088

Saving the Model¶

In [28]:
path_pretrained = f'{MODEL_CACHE}/model_v{VERSION}.{hp.subset_id}.{SUBSETS}'
if not INFERENCE_ONLY:
    model.save_pretrained(path_pretrained)
    tokenizer.save_pretrained(path_pretrained)


path_pretrained
Out[28]:
'drive/MyDrive/datasets/.cache/model_v2.3.4'

Testing¶

Testing Functions¶

In [29]:
def inference(
    encodings: dict[str, torch.Tensor],
    model: BertForSequenceClassification,
    target: pd.Series,
    batch_size: int = 8,
    threshold: float = 0.5,
) -> tuple[list[int], list[int], list[float]]:

    model.eval()
    device = next(model.parameters()).device
    dataset = ToxicDataset(encodings, target.tolist())
    probabilities, predictions, true_labels = [], [], []

    with torch.no_grad():
        for batch in DataLoader(dataset, batch_size=batch_size):
            inputs, labels = batch
            inputs = {key: value.to(device) for key, value in inputs.items()}
            labels = labels.to(device)
            outputs = model(**inputs)
            logits = outputs.logits
            probs = torch.softmax(logits, dim=1)[:, 1].tolist()
            probabilities.extend(probs)
            batch_predictions = [1 if prob >= threshold else 0 for prob in probs]
            predictions.extend(batch_predictions)
            true_labels.extend(labels.tolist())

    return predictions, true_labels, probabilities
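The threshold argument exposes the precision/recall trade-off explicitly: raising it flags fewer comments, but with higher confidence. A usage sketch (not an executed cell) on the subset test split defined above:

# Sketch: stricter cut-off -> higher precision, lower recall on toxic comments.
strict_preds, strict_labels, _ = inference(
    test_encodings, model, target=test['y'],
    batch_size=hp.infra.BATCH_SIZE.value, threshold=0.7
)
print(f1_score(strict_labels, strict_preds))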
In [30]:
def report(true_labels, proba, threshold: float = 0.5):
    proba_pred = (np.array(proba) >= threshold).astype(int)
    cm = confusion_matrix(true_labels, proba_pred)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('Predicted')
    plt.ylabel('True')

    plt.show()

    print(classification_report(true_labels, proba_pred))
In [31]:
def production_preparation(
    text: pd.Series,
    target: pd.Series,
    seed: int,
    sw: StopWordsProcessor,
    n: int = None,
    rm: bool = True
) -> tuple[pd.Series, pd.Series, pd.DataFrame]:

    df_production = pd.concat([text, target], axis=1)
    df_production = sw.preprocess_dataframe(df_production, rm)
    if n is not None:
        df_production = df_production.sample(n=n, random_state=seed)
    df_production = df_production.reset_index(drop=True)
    return df_production['text'], df_production['target'], df_production.copy()


def production_pipeline(
    text: pd.Series,
    target: pd.Series,
    seed: int,
    hp_: HyperParameter,
    sw: StopWordsProcessor,
    n: int = None,
    rm: bool = True,
) -> tuple[list[int], list[int], pd.DataFrame, list[float]]:

    X_text, y_target, df_prod = production_preparation(text, target, seed, sw, n, rm)
    predictions, true_labels, probabilities = inference(
        tokenize_function(X_text.tolist(), hp_),
        model,
        y_target,
        batch_size=hp_.infra.BATCH_SIZE.value,
    )
    return predictions, true_labels, df_prod, probabilities
In [32]:
def f1_distribution(production_pipeline, n_samples, **kwargs):
    np.random.seed(kwargs.get('seed', 42))
    f1_scores = []
    for i in range(n_samples):
        # Advance the seed so every iteration scores a different random subsample.
        kwargs['seed'] = kwargs['seed'] + i
        predictions, true_labels, _, _ = production_pipeline(**kwargs)
        f1_scores.append(f1_score(true_labels, predictions))
    # Empirical 95% interval from the 2.5th and 97.5th percentiles.
    lower_ci, upper_ci = np.percentile(f1_scores, 2.5), np.percentile(f1_scores, 97.5)
    return f1_scores, np.mean(f1_scores), np.std(f1_scores), lower_ci, upper_ci

Subset Testing¶

In [33]:
predictions_, true_labels_, probas_ = inference(
    test_encodings, model, target=test['y'], batch_size=hp.infra.BATCH_SIZE.value
)

report(true_labels_, probas_)


f1_score(true_labels_, predictions_), hp.subset_id, SUBSETS
[Figure: Confusion matrix heatmap for the subset test split]
              precision    recall  f1-score   support

           0       0.99      0.97      0.98      8590
           1       0.75      0.90      0.82       968

    accuracy                           0.96      9558
   macro avg       0.87      0.93      0.90      9558
weighted avg       0.96      0.96      0.96      9558

Out[33]:
(0.8185654008438819, 3, 4)

Pipeline-Based Predictions (Test)¶

In [34]:
predictions, true_labels, df_production, proba = production_pipeline(
    text=dp.cache['production']['X_test'],
    target=dp.cache['production']['y_test'],
    seed=SEED+hp.subset_id,
    hp_=hp,
    n=PRODUCTION_N_SAMPLES,
    sw=stop_words_processor,
)


report(true_labels, proba)
[Figure: Confusion matrix heatmap for the production test sample]
              precision    recall  f1-score   support

           0       0.98      0.97      0.97      1802
           1       0.74      0.81      0.77       198

    accuracy                           0.95      2000
   macro avg       0.86      0.89      0.87      2000
weighted avg       0.95      0.95      0.95      2000

In [35]:
fpr, tpr, _ = roc_curve(true_labels, proba)
roc_auc = roc_auc_score(true_labels, proba)


plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='blue', label=f'ROC Curve (AUC = {roc_auc:.4f})')
plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.grid(alpha=0.3)


plt.show()
[Figure: ROC curve for the production test sample (AUC shown in legend)]

Bootstrap Test¶

In [36]:
f1_scores, mean_f1, std_f1, lower_ci, upper_ci = f1_distribution(
    production_pipeline=production_pipeline,
    n_samples=BOOTSTRAP_N_SAMPLES,
    text=dp.cache['production']['X_test'],
    target=dp.cache['production']['y_test'],
    hp_=hp,
    seed=SEED,
    sw=stop_words_processor,
    n=PRODUCTION_N_SAMPLES_BOOTSTRAP,
    rm=STL_ID == STOPWORDS_STL_ID,
)


plt.hist(f1_scores, bins=30, alpha=0.7, edgecolor='black')
plt.axvline(mean_f1, color='red', linestyle='dashed', linewidth=2, label='Mean F1')
plt.axvline(
    lower_ci, color='green', linestyle='dashed', linewidth=2, label='2.5th Percentile'
)
plt.axvline(
    upper_ci, color='green', linestyle='dashed', linewidth=2, label='97.5th Percentile'
)
plt.title("F1 Score Distribution")
plt.xlabel("F1 Score")
plt.ylabel("Frequency")
plt.legend()


plt.show()
[Figure: Histogram of bootstrap F1 scores with mean and 2.5th/97.5th percentile markers]

Conclusions¶

The project trained a toxic comment classification model with sequential transfer learning; the model exceeded the required F1 score of 0.75 on both the subset test split and the held-out production sample, demonstrating solid generalization.

Sequential Transfer Learning (STL) is well suited to data drift because it lets the model adapt iteratively to smaller, representative data subsets, helping it maintain accuracy as the data distribution changes over time. Splitting the dataset into manageable parts also makes training more computationally efficient and allows hyperparameters to be tuned at each stage. In addition, STL supports pre-training on the initial subset followed by targeted learning rate adjustments on subsequent subsets, giving a controlled and systematic process for improving model performance.

The production pipeline integrates preprocessing, inference, and evaluation, making it suitable for real-world deployment. Its design is scalable and modular, so new subsets or new data can be integrated, supporting adaptability and long-term viability in dynamic environments.

Model v2.3.4¶

After training on the fourth subset (subset ID 3) in the sequential transfer learning process, the model reached an F1 score of 0.82 on the subset test split, balancing precision and recall in toxic comment classification. The F1 distribution from the bootstrap analysis was stable and consistent, with a mean of about 0.82 and an interval from approximately 0.75 (2.5th percentile) to 0.87 (97.5th percentile). The model's performance is therefore both strong and reliable across resampled conditions, supporting the robustness of the sequential transfer learning approach at this stage.

On the production test sample the model achieved a ROC-AUC of 0.96, showing that it separates toxic from non-toxic comments clearly. The bootstrap analysis of the F1 metric likewise indicates stable, consistent performance, and the 2.5th and 97.5th percentiles provide a quantitative measure of uncertainty, further supporting the model's robustness under varying conditions.

In [ ]: