from __future__ import annotations

from typing import Any

import evaluate
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, \
    DataCollatorWithPadding
from datasets import Dataset
import datasets as ds
import numpy as np
import pandas as pd
from enum import Enum
import os
import time
import torch


class Metric(Enum):
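    """Metric names accepted by evaluate.load()."""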
    F1 = "f1"
    ACCURACY = "accuracy"


class ModelType(Enum):
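    """Hugging Face checkpoints available for fine-tuning."""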
    DISTILBERT = "distilbert-base-uncased"
    DEBERTA = "microsoft/deberta-base"


class TuningToggle:
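    """Context manager that temporarily swaps in lightweight TrainingArguments
    (no checkpoint saving) while a hyperparameter search runs."""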
    def __init__(self, env: Env):
        self.env = env

    def __enter__(self):
        # Capture the current args here (not in __init__) so the latest
        # configuration is the one restored on exit.
        self.old_training_args = self.env.training_args
        self.env.training_args = TrainingArguments(
            output_dir="./model",
            learning_rate=2e-5,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            num_train_epochs=1,
            weight_decay=0.01,
            evaluation_strategy="epoch",
            save_strategy="no",
            load_best_model_at_end=False,
            push_to_hub=False,
            gradient_checkpointing=True,
            gradient_checkpointing_kwargs={"use_reentrant": False},
        )

    def __exit__(self, *args):
        self.env.training_args = self.old_training_args


class Env:
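    """Bundles the tokenizer, model, metric, and training configuration for one run."""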
    def __init__(self, model_name: ModelType, metric: Metric = Metric.F1):
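        # Prefer the first CUDA GPU when available; otherwise fall back to the CPU.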
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # Binary labels for patronising-language detection; stored on the
        # instance so model_init can reuse them for each search trial.
        self.id2label = {0: "False", 1: "True"}
        self.label2id = {"False": 0, "True": 1}

        self.tokeniser = AutoTokenizer.from_pretrained(model_name.value)
        self.metric = evaluate.load(metric.value)
        self.metric_name = metric.value
        self.model_name = model_name.value
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name.value, num_labels=2, id2label=self.id2label, label2id=self.label2id)
        # Dynamically pad each batch to the length of its longest sequence.
        self.data_collator = DataCollatorWithPadding(tokenizer=self.tokeniser)

        # Default fine-tuning configuration; a checkpoint is saved every epoch.
        self.training_args = TrainingArguments(
            output_dir=f"model/{model_name.value}_{metric.value}_{int(time.time())}",
            learning_rate=2e-5,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            num_train_epochs=1,
            weight_decay=0.01,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            push_to_hub=False,
            gradient_checkpointing=True,
            gradient_checkpointing_kwargs={"use_reentrant": False},
        )

        self.model = self.model.to(self.device)

        self.tuning_toggle = TuningToggle(env=self)

    def model_init(self, trial):
        # Return a fresh, identically configured model for each search trial.
        return AutoModelForSequenceClassification.from_pretrained(
            self.model_name, num_labels=2, id2label=self.id2label, label2id=self.label2id
        )

    def compute_metrics(self, eval_pred):
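        # Reduce logits to predicted class ids, then score against the references.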
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return self.metric.compute(predictions=predictions, references=labels)

    def compute_objective(self, metrics: dict[str, float]) -> float:
        # Value maximised by the search; keyed by the metric chosen at construction.
        return metrics[f"eval_{self.metric_name}"]

    def optuna_hp_space(self, trial: Any) -> dict[str, Any]:
        # Search space sampled by Optuna on each trial.
        return {
            "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
            "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32, 64]),
            "weight_decay": trial.suggest_float("weight_decay", 0.005, 0.05, log=True),
        }


def initialise(model_name: ModelType = ModelType.DEBERTA) -> Env:
    # Let ops unsupported on Apple's MPS backend fall back to the CPU.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
    # Disable the MPS allocator's high-watermark memory limit.
    os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
    pd.set_option("display.max_columns", None)
    return Env(model_name=model_name)


def preprocess_train_data(path: str, env: Env, upscale_factor: int = 7) -> tuple[Dataset, Dataset]:
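    # Load the labelled CSV and tokenise each example; label 1 marks patronising text.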
    df = pd.read_csv(path, sep=",", escapechar="\\")
    df['label'] = df['is_patronising'].apply(lambda x: 1 if x else 0)
    dataset = Dataset.from_pandas(df[['label', 'text']]).map(lambda d: env.tokeniser(str(d['text']), truncation=True),
                                                             batched=False)

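    # Split each class 80/20 separately so both splits keep the class mix, then
    # oversample the positive class by repeating its training split upscale_factor times.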
    dataset_0_split = dataset.filter(lambda x: x['label'] == 0).train_test_split(test_size=0.2)
    dataset_1_split = dataset.filter(lambda x: x['label'] == 1).train_test_split(test_size=0.2)
    dataset_train = ds.concatenate_datasets([dataset_0_split['train']] + [dataset_1_split['train']] * upscale_factor)
    dataset_test = ds.concatenate_datasets([dataset_0_split['test'], dataset_1_split['test']])

    return dataset_train, dataset_test


def preprocess_test_data(path: str, env: Env) -> Dataset:
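    # Same loading and tokenising as preprocess_train_data, without splitting or oversampling.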
    df = pd.read_csv(path, sep=",", escapechar="\\")
    df['label'] = df['is_patronising'].apply(lambda x: 1 if x else 0)
    dataset = Dataset.from_pandas(df[['label', 'text']]).map(lambda d: env.tokeniser(str(d['text']), truncation=True),
                                                             batched=False)

    return dataset


def train(env: Env) -> None:
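    # Fine-tune env.model with the default TrainingArguments, evaluating on the held-out split.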
    print(f"Train device = {env.device}")

    train_data, test_data = preprocess_train_data("data/train.csv", env, upscale_factor=7)
    validate_data = preprocess_test_data("data/dev.csv", env)  # loaded but not yet used

    trainer = Trainer(
        model=env.model,
        args=env.training_args,
        train_dataset=train_data,
        eval_dataset=test_data,
        tokenizer=env.tokeniser,
        data_collator=env.data_collator,
        compute_metrics=env.compute_metrics,
    )
    trainer.train()


# Hyperparameter search
def get_best_hyperparams(env: Env) -> dict[str, Any]:
    print(f"Train device = {env.device}")

    train_data, test_data = preprocess_train_data("data/train.csv", env, upscale_factor=7)
    # validate_data = preprocess_test_data("data/dev.csv", env)

    with env.tuning_toggle:
        trainer = Trainer(
            model=None,
            args=env.training_args,
            train_dataset=train_data,
            eval_dataset=test_data,
            tokenizer=env.tokeniser,
            data_collator=env.data_collator,
            compute_metrics=env.compute_metrics,
            model_init=env.model_init,
        )

        # Single-objective search: pass direction as a plain string so a single
        # BestRun is returned (a list of directions triggers multi-objective mode).
        best_run = trainer.hyperparameter_search(
            direction="maximize",
            backend="optuna",
            hp_space=env.optuna_hp_space,
            n_trials=20,
            compute_objective=env.compute_objective,
        )

    return best_run.hyperparameters


if __name__ == '__main__':
    env: Env = initialise()
    # train(env)
    print(get_best_hyperparams(env))