Skip to content
Snippets Groups Projects
Commit fc85aa0a authored by mm2320's avatar mm2320
Browse files

Implement hyperparameter tuning

parent 197a902f
No related branches found
No related tags found
No related merge requests found
import os
import time
from typing import Any, Dict, ForwardRef

import evaluate
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, \
    DataCollatorWithPadding
Env = ForwardRef('Env')
class Metric(Enum):
F1 = "f1"
......@@ -21,18 +25,45 @@ class ModelType(Enum):
DEBERTA = "microsoft/deberta-base"
class TuningToggle:
    """Context manager that temporarily swaps an Env's TrainingArguments for a
    lightweight, single-epoch configuration suited to hyperparameter search,
    and restores the original arguments on exit.

    Usage: ``with env.tuning_toggle: ...`` — any Trainer built inside the
    block sees the tuning arguments via ``env.training_args``.
    """

    def __init__(self, env: "Env"):
        self.env = env

    def __enter__(self):
        # Capture the current arguments at *entry* time, not construction
        # time: the toggle is built once in Env.__init__, so snapshotting
        # there would restore a stale object if env.training_args was ever
        # reassigned before the `with` block runs.
        self.old_training_args = self.env.training_args
        self.env.training_args = self.training_args = TrainingArguments(
            output_dir="./model",
            learning_rate=2e-5,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            num_train_epochs=1,  # keep trials cheap: one epoch per trial
            weight_decay=0.01,
            evaluation_strategy="epoch",
            save_strategy="no",  # no checkpoints during search
            load_best_model_at_end=False,
            push_to_hub=False,
            gradient_checkpointing=True,
            gradient_checkpointing_kwargs={"use_reentrant": False},
        )
        return self

    def __exit__(self, *args):
        # Always put the caller's original arguments back, even on error.
        self.env.training_args = self.old_training_args
class Env:
def __init__(self, metric: Metric = Metric.F1, model_name: ModelType = ModelType.DEBERTA):
self.device = "cuda:0" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
def __init__(self, model_name: ModelType, metric: Metric = Metric.F1):
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
id2label = {0: "False", 1: "True"}
label2id = {"False": 0, "True": 1}
self.tokeniser = AutoTokenizer.from_pretrained(model_name.value)
self.metric = evaluate.load(metric.value)
self.model_name = model_name.value
self.model = AutoModelForSequenceClassification.from_pretrained(
model_name.value, num_labels=2, id2label=id2label, label2id=label2id)
self.data_collator = DataCollatorWithPadding(tokenizer=self.tokeniser)
self.training_args = TrainingArguments(
output_dir=f"model/{model_name.value}_{metric.value}_{int(time.time())}",
learning_rate=2e-5,
......@@ -50,28 +81,55 @@ class Env:
self.model = self.model.to(self.device)
self.tuning_toggle = TuningToggle(env=self)
def model_init(self, trial):
    """Fresh-model factory handed to ``Trainer(model_init=...)``.

    Each hyperparameter-search trial must start from the same pretrained
    checkpoint. Recreate the classifier with the same 2-label configuration
    used for the initial model in ``Env.__init__`` so every trial trains an
    identically-shaped head. ``trial`` is unused but required by the
    Trainer API.
    """
    return AutoModelForSequenceClassification.from_pretrained(
        self.model_name,
        num_labels=2,
        id2label={0: "False", 1: "True"},
        label2id={"False": 0, "True": 1},
    )
def compute_metrics(self, eval_pred):
    """Turn raw Trainer predictions into the configured evaluation metric.

    ``eval_pred`` is a (logits, labels) pair; the arg-max over the class
    axis yields the predicted label ids fed to ``self.metric``.
    """
    logits, references = eval_pred
    predicted = np.argmax(logits, axis=1)
    return self.metric.compute(predictions=predicted, references=references)
def compute_objective(self, metrics: Dict[str, float]) -> float:
    """Objective maximised during hyperparameter search: the eval-set F1.

    NOTE(review): hardcodes the "eval_f1" key even though Env's metric is
    configurable — confirm F1 is always the tuning metric.
    """
    objective_key = "eval_f1"
    return metrics[objective_key]
def optuna_hp_space(self, trial: Any) -> dict[str, Any]:
    """Optuna search space for Trainer.hyperparameter_search.

    Samples learning rate and weight decay on a log scale and the batch
    size from a small categorical set. (Fix: annotate with ``typing.Any``
    rather than the builtin ``any`` function.)
    """
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32, 64]),
        "weight_decay": trial.suggest_float("weight_decay", 0.005, 0.05, log=True),
    }
def initialise(model_name: ModelType = ModelType.DEBERTA) -> Env:
    """Configure process-wide settings and build the training environment.

    The PyTorch MPS environment variables are set before any model/tensor
    allocation (they only take effect if read early — presumably for Apple
    silicon runs; confirm if MPS is still a target). pandas is told to print
    all dataframe columns for easier inspection.
    """
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
    os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
    pd.set_option('display.max_columns', None)
    return Env(model_name=model_name)
def preprocess_train_data(path: str, env: Env, upscale_factor: int = 7) -> tuple[Dataset, Dataset]:
    """Load the labelled CSV at `path` and build tokenised train/test splits.

    The two label classes are split 80/20 independently (a stratified
    split), then the positive-class (label 1) *training* examples are
    repeated `upscale_factor` times to counter class imbalance. The test
    split keeps the natural class distribution.

    Returns a (train, test) pair of tokenised datasets.
    """
    df = pd.read_csv(path, sep=",", escapechar="\\")
    # is_patronising -> binary integer label expected by the model.
    df['label'] = df['is_patronising'].apply(lambda x: 1 if x else 0)
    dataset = Dataset.from_pandas(df[['label', 'text']]).map(lambda d: env.tokeniser(str(d['text']), truncation=True),
                                                             batched=False)
    dataset_0_split = dataset.filter(lambda x: x['label'] == 0).train_test_split(test_size=0.2)
    dataset_1_split = dataset.filter(lambda x: x['label'] == 1).train_test_split(test_size=0.2)
    dataset_train = ds.concatenate_datasets([dataset_0_split['train']] + [dataset_1_split['train']] * upscale_factor)
    dataset_test = ds.concatenate_datasets([dataset_0_split['test'], dataset_1_split['test']])
    return dataset_train, dataset_test
def preprocess_test_data(path: str, env: Env) -> Dataset:
    """Load the labelled CSV at `path`, derive the binary label, and return
    the tokenised dataset (no resampling or splitting)."""
    frame = pd.read_csv(path, sep=",", escapechar="\\")
    # is_patronising -> 0/1 integer label.
    frame['label'] = frame['is_patronising'].apply(lambda x: 1 if x else 0)

    def tokenise(row):
        # One example at a time (batched=False below).
        return env.tokeniser(str(row['text']), truncation=True)

    return Dataset.from_pandas(frame[['label', 'text']]).map(tokenise, batched=False)
def train(env: Env) -> None:
    """Fine-tune env.model on the upsampled training split, evaluating on the
    held-out test split according to env.training_args.

    (Removed an unused `validate_data` local that read data/dev.csv and
    discarded the result.)
    """
    print(f"Train device = {env.device}")
    train_data, test_data = preprocess_train_data("data/train.csv", env, upscale_factor=7)
    trainer = Trainer(
        model=env.model,
        args=env.training_args,
        train_dataset=train_data,
        eval_dataset=test_data,
        tokenizer=env.tokeniser,
        data_collator=env.data_collator,
        compute_metrics=env.compute_metrics,
    )
    trainer.train()
# Hyperparameter tuning
def get_best_hyperparams(env: Env) -> dict[str, Any]:
    """Run an Optuna hyperparameter search over env.optuna_hp_space and return
    the best trial's hyperparameters.

    The search runs inside env.tuning_toggle so each trial uses the cheap
    single-epoch TrainingArguments; the Env's normal arguments are restored
    afterwards. `model=None` + `model_init` makes the Trainer build a fresh
    model per trial.
    """
    print(f"Train device = {env.device}")
    train_data, test_data = preprocess_train_data("data/train.csv", env, upscale_factor=7)
    with env.tuning_toggle:
        trainer = Trainer(
            model=None,
            args=env.training_args,
            train_dataset=train_data,
            eval_dataset=test_data,
            tokenizer=env.tokeniser,
            data_collator=env.data_collator,
            compute_metrics=env.compute_metrics,
            model_init=env.model_init,
        )
        best_run = trainer.hyperparameter_search(
            # Single-objective search takes a plain string; a list selects
            # multi-objective handling and returns a list of BestRun,
            # breaking the .hyperparameters access below.
            direction="maximize",
            backend="optuna",
            hp_space=env.optuna_hp_space,
            n_trials=20,
            compute_objective=env.compute_objective,
        )
    return best_run.hyperparameters
if __name__ == '__main__':
    env: Env = initialise()
    # train(env)  # full fine-tuning run; disabled while tuning hyperparameters
    print(get_best_hyperparams(env))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment