# NOTE(review): "Newer" / "Older" below look like merge-tool / diff-viewer
# residue, not Python — commented out so the module can parse.
# Newer
# Older
import time
from enum import Enum

import datasets as ds
import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, \
DataCollatorWithPadding
class Metric(Enum):
    """Evaluation metric identifiers, as accepted by `evaluate.load`."""

    F1 = "f1"
    ACCURACY = "accuracy"
class ModelType(Enum):
    """Pretrained Hugging Face checkpoints supported for fine-tuning."""

    DISTILBERT = "distilbert-base-uncased"
    DEBERTA = "microsoft/deberta-base"
def __init__(self, metric: Metric = Metric.F1, model_name: ModelType = ModelType.DEBERTA):
    """Set up tokeniser, model, metric and training configuration.

    Args:
        metric: which `evaluate` metric to load and report.
        model_name: which pretrained checkpoint to fine-tune.
    """
    # BUG FIX: the original tested torch.cuda.is_available() in BOTH branches,
    # so "mps" was unreachable — probe the MPS backend for the Apple-GPU case.
    if torch.cuda.is_available():
        self.device = "cuda:0"
    elif torch.backends.mps.is_available():
        self.device = "mps"
    else:
        self.device = "cpu"
    # Binary classification label maps (1 = patronising, per preprocess_data).
    id2label = {0: "False", 1: "True"}
    label2id = {"False": 0, "True": 1}
    self.metric = evaluate.load(metric.value)
    # NOTE(review): the tokeniser assignment was missing from the visible
    # source even though self.tokeniser is used below — reconstructed here;
    # confirm against the original file.
    self.tokeniser = AutoTokenizer.from_pretrained(model_name.value)
    self.model = AutoModelForSequenceClassification.from_pretrained(
        model_name.value,
        num_labels=2,
        id2label=id2label,
        label2id=label2id,
    )
    self.data_collator = DataCollatorWithPadding(tokenizer=self.tokeniser)
    # NOTE(review): the TrainingArguments call was truncated in the visible
    # source (no closing paren, no epoch count) — reconstructed; verify
    # num_train_epochs and any other missing arguments against the original.
    self.training_args = TrainingArguments(
        output_dir=f"model/{model_name.value}_{metric.value}_{int(time.time())}",
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        weight_decay=0.01,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        # use_mps_device=(self.device == "mps"),
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )
    self.model = self.model.to(self.device)
def compute_metrics(self, eval_pred):
    """Score a Trainer evaluation batch with the configured metric.

    `eval_pred` is a (logits, labels) pair; logits are reduced to class
    indices via argmax before being handed to `self.metric.compute`.
    """
    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=1)
    return self.metric.compute(predictions=predicted_classes, references=labels)
def initialise() -> Env:
    """Create a fresh training Env after configuring pandas display.

    Widening the column display makes DataFrame inspection during
    debugging easier; it has no effect on training itself.
    """
    pd.set_option('display.max_columns', None)
    return Env()
def preprocess_data(path: str, env: Env, is_train: bool = False, upscale_factor: int = 7) -> Dataset:
    """Load a labelled CSV and tokenise it into a `datasets.Dataset`.

    The truthy `is_patronising` column is coerced to a 0/1 `label`.
    For training data, positive (label == 1) rows are repeated
    `upscale_factor` times to counter class imbalance.
    """
    frame = pd.read_csv(path, sep=",", escapechar="\\")
    frame['label'] = frame['is_patronising'].apply(lambda x: 1 if x else 0)
    tokenised = Dataset.from_pandas(frame[['label', 'text']]).map(
        lambda row: env.tokeniser(str(row['text']), truncation=True),
        batched=False,
    )
    if not is_train:
        return tokenised
    # Oversample the minority class by concatenating repeated copies.
    negatives = tokenised.filter(lambda x: x['label'] == 0)
    positives = tokenised.filter(lambda x: x['label'] == 1)
    return ds.concatenate_datasets([negatives] + [positives] * upscale_factor)
# NOTE(review): `env` was used below but never created in the visible source —
# the `initialise()` call was presumably lost in the merge; reconstructed here.
env = initialise()
print(f"Train device = {env.device}")
# BUG FIX: the original Trainer had no train_dataset, so trainer.train()
# would raise — TODO confirm "data/train.csv" is the intended training file.
trainer = Trainer(
    model=env.model,
    args=env.training_args,
    train_dataset=preprocess_data("data/train.csv", env, is_train=True),
    eval_dataset=preprocess_data("data/dev.csv", env),
    tokenizer=env.tokeniser,
    data_collator=env.data_collator,
    compute_metrics=env.compute_metrics,
)
trainer.train()