Commit de805594 authored by mmzk1526
Can toggle multiclass

parent 8fa60afa
@@ -12,7 +12,6 @@ from evaluate import load
 from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, \
     set_seed

 Env = ForwardRef('Env')
@@ -32,12 +31,12 @@ def compute_objective(metrics: Dict[str, float]) -> float:
     return metrics["eval_f1"]


-def initialise(configs: dict, model_type: ModelType = ModelType.DEBERTA) -> Env:
+def initialise(configs: dict, model_type: ModelType = ModelType.DEBERTA, is_multiclass: bool = False) -> Env:
     os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
     os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
     pd.set_option('display.max_columns', None)
     set_seed(42)
-    return Env(configs=configs, model_type=model_type)
+    return Env(configs=configs, model_type=model_type, is_multiclass=is_multiclass)


 def label_to_pcl(label: int) -> int:
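For reference, a minimal sketch of how the new toggle is meant to be called. The config contents here are hypothetical, and the env module and ModelType import are assumed to work as in the training script at the end of this diff:

    import env
    from env import ModelType

    configs = {"learning_rate": 2e-5}  # hypothetical hyperparameters
    # Default: binary PCL detection (2-label head)
    binary_env = env.initialise(configs=configs, model_type=ModelType.DEBERTA)
    # Toggled: multiclass (5-label head); labels are taken from the CSV as-is
    multi_env = env.initialise(configs=configs, model_type=ModelType.DEBERTA, is_multiclass=True)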
@@ -68,23 +67,24 @@ def preprocess_test_data(path: str, env: Env) -> Dataset:
 def preprocess_train_data(path: str, env: Env, upsample_factor: int = 1) -> tuple[Dataset, Dataset]:
     df = pd.read_csv(path, sep=",", escapechar="\\")
-    # df['pcl'] = df['is_patronising'].apply(lambda x: 1 if x else 0)
-    df['label'] = df['is_patronising'].apply(lambda x: 1 if x else 0)
+    if not env.is_multiclass:
+        df['label'] = df['is_patronising'].apply(lambda x: 1 if x else 0)
     dataset = Dataset.from_pandas(
         df[['text', 'label']]).map(lambda d: env.tokeniser(str(d['text']), truncation=True), batched=False)
-    dataset_0_split = dataset.filter(
-        lambda x: not label_to_pcl(x['label'])).train_test_split(test_size=0.2)
-    dataset_1_split = dataset.filter(
-        lambda x: label_to_pcl(x['label'])).train_test_split(test_size=0.2)
+    dataset_split = dataset.train_test_split(test_size=0.2)
+    dataset_train = dataset_split['train']
+    dataset_test = dataset_split['test']
+    dataset_train_0 = dataset_train.filter(
+        lambda x: not label_to_pcl(x['label']))
+    dataset_train_1 = dataset_train.filter(
+        lambda x: label_to_pcl(x['label']))
     dataset_train = concatenate_datasets(
-        [dataset_0_split['train']] + [dataset_1_split['train']] * upsample_factor)
-    dataset_test = concatenate_datasets(
-        [dataset_0_split['test'], dataset_1_split['test']])
+        [dataset_train_0] + [dataset_train_1] * upsample_factor)
     dataset_train = dataset_train.map(lambda x: {'length': len(x['input_ids'])}, batched=False)
     dataset_test = dataset_test.map(lambda x: {'length': len(x['input_ids'])}, batched=False)
     return dataset_train, dataset_test


 # Environment classes
 class TuningToggle:
     def __init__(self, env: Env):
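The reworked preprocessing now performs a single train/test split and only then filters and upsamples the positive class, so duplication stays entirely inside the train split. A self-contained sketch of the same pattern on toy data (not from this repo):

    from datasets import Dataset, concatenate_datasets

    ds = Dataset.from_dict({"text": ["a", "b", "c", "d"], "label": [0, 1, 0, 1]})
    split = ds.train_test_split(test_size=0.5)          # split first
    train_pos = split["train"].filter(lambda x: x["label"] == 1)
    train_neg = split["train"].filter(lambda x: x["label"] == 0)
    # then upsample positives within the train split only
    train = concatenate_datasets([train_neg] + [train_pos] * 3)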
@@ -119,14 +119,11 @@ class TuningToggle:

 class Env:
-    def __init__(self, configs: dict, model_type: ModelType, metric: Metric = Metric.F1):
-        # id2label = {0: "False", 1: "True"}
-        # label2id = {"False": 0, "True": 1}
+    def __init__(self, configs: dict, model_type: ModelType, is_multiclass: bool, metric: Metric = Metric.F1):
         self.configs = configs
         self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
         self.tokeniser = AutoTokenizer.from_pretrained(model_type.value)
+        self.is_multiclass = is_multiclass
         self.data_collator = DataCollatorWithPadding(tokenizer=self.tokeniser)
         self.metric = load(metric.value)
@@ -134,7 +131,7 @@ class Env:
         self.model_name = model_type.value
         self.model = AutoModelForSequenceClassification.from_pretrained(
-            model_type.value, num_labels=5)
+            model_type.value, num_labels=5 if self.is_multiclass else 2)
         self.training_args = TrainingArguments(
             output_dir=f"model/{self.model_name}_{self.metric_name}_{int(time.time())}",
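The classification head now follows the flag. A sketch of the underlying Transformers call; the checkpoint name below is a stand-in, since the commit resolves it from model_type.value:

    from transformers import AutoModelForSequenceClassification

    is_multiclass = True
    model = AutoModelForSequenceClassification.from_pretrained(
        "microsoft/deberta-v3-base",            # stand-in for model_type.value
        num_labels=5 if is_multiclass else 2)   # 5-way head vs. binary head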
@@ -33,6 +33,6 @@ if __name__ == '__main__':
     # Load initial hyperparameters
     with open("configs/hyperparams_optim_yitang_3.json", mode="r") as f:
         configs: dict[str, any] = json.load(f)
-    env: Env = env.initialise(configs=configs, model_type=model_type)
+    env: Env = env.initialise(configs=configs, model_type=model_type, is_multiclass=True)
     train(env)
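A hedged end-to-end sketch of the toggled pipeline. The CSV path and upsample factor are placeholders, and preprocess_train_data is assumed to live in the same env module as initialise:

    import json
    import env
    from env import ModelType

    with open("configs/hyperparams_optim_yitang_3.json", mode="r") as f:
        configs: dict[str, any] = json.load(f)
    e = env.initialise(configs=configs, model_type=ModelType.DEBERTA, is_multiclass=True)
    # upsample_factor duplicates positive train examples; 1 leaves them untouched
    train_ds, test_ds = env.preprocess_train_data("data/train.csv", e, upsample_factor=2)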