Skip to content
Snippets Groups Projects
main.py 2.33 KiB
Newer Older
  • Learn to ignore specific revisions
  • mmzk1526's avatar
    mmzk1526 committed
    import evaluate
    
    from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, \
        DataCollatorWithPadding
    from datasets import Dataset
    import numpy as np
    import pandas as pd
    from enum import Enum
    import time
    
    
    class Metric(Enum):
        """Names of evaluation metrics loadable via ``evaluate.load``.

        The enum value is the exact string the ``evaluate`` hub expects.
        """

        F1 = "f1"
        ACCURACY = "accuracy"
    
    
    class Env:
        """Training environment: tokenizer, model, metric and trainer config.

        Bundles everything needed to fine-tune a sequence classifier on a
        binary ("False"/"True") labelling task.

        Args:
            metric: evaluation metric to load from the ``evaluate`` hub
                (defaults to F1).
            checkpoint: Hugging Face model checkpoint to fine-tune. Defaults
                to the previously hard-coded ``distilbert-base-uncased``, so
                existing callers are unaffected.
        """

        def __init__(self, metric: Metric = Metric.F1,
                     checkpoint: str = "distilbert-base-uncased"):
            # Binary label mapping; the string names mirror Python bools.
            id2label = {0: "False", 1: "True"}
            label2id = {"False": 0, "True": 1}
            # Single checkpoint name drives both tokenizer and model so they
            # can never get out of sync (was duplicated as a literal before).
            self.tokeniser = AutoTokenizer.from_pretrained(checkpoint)
            self.metric = evaluate.load(metric.value)
            self.model = AutoModelForSequenceClassification.from_pretrained(
                checkpoint, num_labels=2, id2label=id2label, label2id=label2id)
            self.data_collator = DataCollatorWithPadding(tokenizer=self.tokeniser)
            # Timestamped output dir keeps repeated runs from clobbering
            # each other's checkpoints.
            self.training_args = TrainingArguments(
                output_dir=f"model/test_model_{metric.value}_{int(time.time())}",
                learning_rate=2e-5,
                per_device_train_batch_size=16,
                per_device_eval_batch_size=16,
                num_train_epochs=2,
                weight_decay=0.01,
                evaluation_strategy="epoch",
                save_strategy="epoch",
                load_best_model_at_end=True,
                push_to_hub=False
            )

        def compute_metrics(self, eval_pred):
            """Convert raw logits into the configured metric's score dict.

            Args:
                eval_pred: ``(predictions, labels)`` pair as supplied by the
                    ``Trainer`` evaluation loop.

            Returns:
                The dict produced by ``self.metric.compute``.
            """
            predictions, labels = eval_pred
            # argmax over the class axis turns logits into label ids.
            predictions = np.argmax(predictions, axis=1)
            return self.metric.compute(predictions=predictions, references=labels)
    
    
    def initialise() -> Env:
        """Configure pandas display options and return a default environment."""
        # Show every column when DataFrames are printed during debugging.
        pd.set_option('display.max_columns', None)
        environment = Env()
        return environment
    
    
    def preprocess_data(path: str, env: Env) -> Dataset:
        """Read a labelled CSV and tokenise it for the trainer.

        The file must provide ``is_patronising`` and ``text`` columns; the
        boolean flag is converted to the integer ``label`` the model expects.
        """
        frame = pd.read_csv(path, sep=",", escapechar="\\")
        frame['label'] = frame['is_patronising'].apply(lambda flag: 1 if flag else 0)

        def tokenise(row):
            # str(...) guards against non-string cells (e.g. NaN) reaching
            # the tokeniser.
            return env.tokeniser(str(row['text']), truncation=True)

        dataset = Dataset.from_pandas(frame[['label', 'text']])
        return dataset.map(tokenise, batched=False)
    
    
    def train(env: Env) -> None:
        """Fine-tune the model on train.csv, evaluating on dev.csv each epoch."""
        train_set = preprocess_data("data/train.csv", env)
        dev_set = preprocess_data("data/dev.csv", env)
        Trainer(
            model=env.model,
            args=env.training_args,
            train_dataset=train_set,
            eval_dataset=dev_set,
            tokenizer=env.tokeniser,
            data_collator=env.data_collator,
            compute_metrics=env.compute_metrics,
        ).train()
    
    mmzk1526's avatar
    mmzk1526 committed
    
    
    if __name__ == '__main__':
        # Script entry point: build the environment, then run fine-tuning.
        train(initialise())