import time
from enum import Enum

import evaluate
import numpy as np
import pandas as pd
from datasets import Dataset
from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                          TrainingArguments, Trainer, DataCollatorWithPadding)


class Metric(Enum):
    F1 = "f1"
    ACCURACY = "accuracy"


class Env:
    def __init__(self, metric: Metric = Metric.F1):
        id2label = {0: "False", 1: "True"}
        label2id = {"False": 0, "True": 1}
        self.tokeniser = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        self.metric = evaluate.load(metric.value)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id)
        # Pads each batch to the length of its longest sequence, rather than
        # padding everything to the model maximum up front.
        self.data_collator = DataCollatorWithPadding(tokenizer=self.tokeniser)
        self.training_args = TrainingArguments(
            output_dir=f"model/test_model_{metric.value}_{int(time.time())}",
            learning_rate=2e-5,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            num_train_epochs=2,
            weight_decay=0.01,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            push_to_hub=False,
        )

    def compute_metrics(self, eval_pred):
        # Reduce the raw logits to class predictions before scoring.
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return self.metric.compute(predictions=predictions, references=labels)


def initialise() -> Env:
    pd.set_option('display.max_columns', None)
    return Env()


def preprocess_data(path: str, env: Env) -> Dataset:
    df = pd.read_csv(path, sep=",", escapechar="\\")
    # Map the boolean is_patronising column onto integer labels (True -> 1, False -> 0).
    df['label'] = df['is_patronising'].apply(lambda x: 1 if x else 0)
    dataset = Dataset.from_pandas(df[['label', 'text']])
    # Tokenise one example at a time; padding is deferred to the data collator.
    return dataset.map(lambda d: env.tokeniser(str(d['text']), truncation=True), batched=False)


def train(env: Env) -> None:
    trainer = Trainer(
        model=env.model,
        args=env.training_args,
        train_dataset=preprocess_data("data/train.csv", env),
        eval_dataset=preprocess_data("data/dev.csv", env),
        tokenizer=env.tokeniser,
        data_collator=env.data_collator,
        compute_metrics=env.compute_metrics,
    )
    trainer.train()
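The listing defines the pieces but no entry point, so here is a minimal sketch of how it might be run. The `if __name__` guard is an assumed convention, not part of the original, and it assumes `data/train.csv` and `data/dev.csv` exist with the `text` and `is_patronising` columns that `preprocess_data` expects.

# A hypothetical entry point for the script above (not in the original).
# Assumes data/train.csv and data/dev.csv exist with 'text' and
# 'is_patronising' columns, as preprocess_data requires.
if __name__ == "__main__":
    env = initialise()  # defaults to Metric.F1
    train(env)

Because `load_best_model_at_end=True` is set with per-epoch evaluation and saving, the checkpoint with the best evaluation result is reloaded into `env.model` after training, so anything done with the model afterwards uses the best epoch rather than the last one.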