Skip to content
Snippets Groups Projects
main.py 3.22 KiB
Newer Older
  • Learn to ignore specific revisions
  • mmzk1526's avatar
    mmzk1526 committed
    import evaluate
    
    from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, \
        DataCollatorWithPadding
    from datasets import Dataset
    
    mmzk1526's avatar
    mmzk1526 committed
    import datasets as ds
    
    import numpy as np
    import pandas as pd
    from enum import Enum
    
    mmzk1526's avatar
    mmzk1526 committed
    import os
    
    import time
    
    import torch
    
    
    
    class Metric(Enum):
        """Evaluation metrics selectable for training/eval."""
        # Values are identifiers passed to ``evaluate.load`` (see Env.__init__).
        F1 = "f1"
        ACCURACY = "accuracy"
    
    
    
    mmzk1526's avatar
    mmzk1526 committed
    class ModelType(Enum):
        """Pretrained checkpoints selectable for fine-tuning."""
        # Values are Hugging Face Hub model ids passed to ``from_pretrained``
        # (see Env.__init__).
        DISTILBERT = "distilbert-base-uncased"
        DEBERTA = "microsoft/deberta-base"
    
    
    
    class Env:
        """Training environment: device, tokenizer, model, metric and trainer args.

        Bundles everything ``train`` needs for one fine-tuning run of a binary
        sequence classifier ("True"/"False" labels).
        """

        def __init__(self, metric: Metric = Metric.F1, model_name: ModelType = ModelType.DEBERTA):
            # Prefer CUDA, then Apple-Silicon MPS, then CPU.
            # BUGFIX: the MPS branch previously re-tested torch.cuda.is_available(),
            # so "mps" could never be selected; it must test the MPS backend.
            self.device = "cuda:0" if torch.cuda.is_available() else (
                "mps" if torch.backends.mps.is_available() else "cpu")

            # Binary classification label mappings (index <-> human-readable name).
            id2label = {0: "False", 1: "True"}
            label2id = {"False": 0, "True": 1}

            self.tokeniser = AutoTokenizer.from_pretrained(model_name.value)
            self.metric = evaluate.load(metric.value)
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_name.value, num_labels=2, id2label=id2label, label2id=label2id)

            # Pads each batch dynamically to its longest sequence.
            self.data_collator = DataCollatorWithPadding(tokenizer=self.tokeniser)
            self.training_args = TrainingArguments(
                # Timestamped output dir so repeated runs never clobber each other.
                output_dir=f"model/{model_name.value}_{metric.value}_{int(time.time())}",
                learning_rate=2e-5,
                per_device_train_batch_size=16,
                per_device_eval_batch_size=16,
                num_train_epochs=1,
                weight_decay=0.01,
                evaluation_strategy="epoch",
                save_strategy="epoch",
                load_best_model_at_end=True,
                push_to_hub=False,
                # use_mps_device=(self.device == "mps"),
                # Trade extra compute for lower memory during backprop.
                gradient_checkpointing=True,
                gradient_checkpointing_kwargs={"use_reentrant": False},
            )  # BUGFIX: this closing parenthesis was missing (syntax error).

            self.model = self.model.to(self.device)

        def compute_metrics(self, eval_pred):
            """Argmax the logits and score predictions with the configured metric.

            ``eval_pred`` is the (predictions, labels) pair the Trainer passes in.
            Returns the metric's result dict.
            """
            predictions, labels = eval_pred
            predictions = np.argmax(predictions, axis=1)
            return self.metric.compute(predictions=predictions, references=labels)
    
    
    def initialise() -> Env:
        """Apply process-wide settings and construct the default training Env."""
        # Disable the MPS allocator's high-watermark limit (0.0 = no limit).
        os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
        # Show every column when printing dataframes for inspection.
        pd.set_option('display.max_columns', None)
        return Env()
    
    
    
    def preprocess_data(path: str, env: Env, is_train: bool = False, upscale_factor: int = 7) -> Dataset:
        """Load a CSV, tokenise its text column, and oversample positives for training.

        The CSV must contain ``is_patronising`` (truthy/falsy) and ``text`` columns.
        When ``is_train`` is set, label-1 rows are repeated ``upscale_factor`` times
        to counter class imbalance.
        """
        frame = pd.read_csv(path, sep=",", escapechar="\\")
        frame['label'] = frame['is_patronising'].apply(lambda flag: 1 if flag else 0)

        tokenised = Dataset.from_pandas(frame[['label', 'text']]).map(
            lambda row: env.tokeniser(str(row['text']), truncation=True), batched=False)

        if not is_train:
            return tokenised

        # Oversample the minority (label == 1) class.
        negatives = tokenised.filter(lambda row: row['label'] == 0)
        positives = tokenised.filter(lambda row: row['label'] == 1)
        return ds.concatenate_datasets([negatives] + [positives] * upscale_factor)
    
    
    
    def train(env: Env) -> None:
        """Fine-tune ``env.model`` on data/train.csv, evaluating on data/dev.csv."""
        print(f"Train device = {env.device}")

        # Training data is oversampled; the dev split is left untouched.
        train_split = preprocess_data("data/train.csv", env, is_train=True)
        dev_split = preprocess_data("data/dev.csv", env)

        Trainer(
            model=env.model,
            args=env.training_args,
            train_dataset=train_split,
            eval_dataset=dev_split,
            tokenizer=env.tokeniser,
            data_collator=env.data_collator,
            compute_metrics=env.compute_metrics,
        ).train()
    
    mmzk1526's avatar
    mmzk1526 committed
    
    
    if __name__ == '__main__':
        # Build the environment once, then run a full fine-tuning pass.
        environment: Env = initialise()
        train(environment)