diff --git a/models/conditional_gpt2_model.py b/models/conditional_gpt2_model.py new file mode 100644 index 0000000000000000000000000000000000000000..848d1cb8eb98eb5e3950d839ff46ce65afb8319f --- /dev/null +++ b/models/conditional_gpt2_model.py @@ -0,0 +1,228 @@ +import os, sys +os.chdir('/data-imperial') +sys.path.append(os.path.abspath('/data-imperial')) +from models.gpt2_model import Manager +from transformers import GPT2Tokenizer, GPT2LMHeadModel +from tqdm import tqdm +from helpers.custom_data import * +from torch.utils.data import DataLoader + +import torch +import argparse +import copy + +class Conditional_GPT2(Manager): + + ''' + Conditional_GPT2 class uses ConditionalGPT2 model during training and inference time. + + - For training, use of the train() function + - For inference, use of the inference() function + + For interence, 4 response generation strategies are implemented: + - greedy_approach + - beam_search + - top_k_sampling + - nucleus_sampling + ''' + + def __init__(self, mode, ckpt_name=None, lr = 5e-4, decoding=None, age=None, gender=None, topic=None): + ''' + Inputs: + - mode : train or inference + - ckpt_name : the name of the checkpoint file to load the model + - lr : the learning rate (by default 5e-4) + - decoding : the decoding strategy for response generation (greedy, beam search, top_k or nucleus sampling) + - age : 'under 18' or 'over 18' + - gender : 'male', 'female' or 'other' + - topic : conversation topic (suicide, anxiety, etc.) + ''' + + self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + # Tokenizer & Vocab + self.tokenizer = GPT2Tokenizer.from_pretrained(f"data/gpt-2") + self.special_tokens = { + 'bos_token': "<bos>", + 'eos_token': "<eos>", + 'pad_token': "<pad>", + 'additional_special_tokens': ["[texter]", "[volunteer]"] + } + self.num_new_tokens = self.tokenizer.add_special_tokens(self.special_tokens) + self.vocab = self.tokenizer.get_vocab() + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab["<bos>"] + self.eos_id = self.vocab["<eos>"] + self.pad_id = self.vocab["<pad>"] + self.speaker1_id = self.vocab["[texter]"] + self.speaker2_id = self.vocab["[volunteer]"] + + self.decoding = 'nucleus' if decoding is None else decoding + self.age = 'under 18' if age is None else age + self.gender = 'female' if gender is None else gender + self.topic = 'depressed' if topic is None else topic + assert self.gender in ['male', 'female', 'other'], "Please check the gender." + assert self.age in ['over 18', 'under 18'], "Please check the age" + assert self.topic in ['other', 'isolated', 'self_harm', 'anxiety', 'suicide', + 'bereavement', 'relationship', 'gender', 'bully', 'depressed', + 'eating', 'none', 'abuse_sexual', '3rd_party', 'abuse_physical', + 'abuse_emotional', 'covid_19', 'prank', 'substance', + 'abuse_unspecified', 'abuse_domestic', 'testing', 'abuse_child', + 'abuse_other'], "Please check the topic" + + # Represent the persona part of the inputs + self.cond_labels = f'I am a {self.gender}. I am {self.age}. I want to talk about {self.topic}.' 
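+        # With the default profile (gender='female', age='under 18', topic='depressed') this renders as:
+        #   "I am a female. I am under 18. I want to talk about depressed."
+        # At inference time this persona string is encoded and prepended (after <bos>) to the dialogue history.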
+ print('Decoding: ', self.decoding) + print('Persona profile: ', self.cond_labels) + + # Represent the number of text messages considered (Number of previous messages considered + the reply) + self.max_time = 2 + + # The number of tokens to be fed to the model + self.max_len = 512 + + # The maximum length of a text message + self.utter_len = (self.max_len-self.max_time-2) // self.max_time + + # Load model + print("Loading the model...", flush=True) + self.model = GPT2LMHeadModel.from_pretrained('gpt2').to(self.device) + self.model.resize_token_embeddings(len(self.tokenizer)) + + if mode == 'train': + # Load optimizer + print("Loading the optimizer...", flush=True) + self.optim = torch.optim.AdamW(self.model.parameters(), lr=lr) + self.best_loss = sys.float_info.max + + # Load train & valid dataset + print("Loading train & valid data...", flush=True) + train_set = ConditionalDataset('train') + valid_set = ConditionalDataset('valid') + + batch_size = 4 + self.train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size) + self.valid_loader = DataLoader(valid_set, shuffle=True, batch_size=batch_size) + + if not os.path.exists("saved_models"): + os.mkdir("saved_models") + + if ckpt_name is not None: + if os.path.exists(f"{'saved_models'}/{ckpt_name}.tar"): + print("Loading the trained checkpoint...", flush=True) + checkpoint = torch.load(f"{'saved_models'}/{ckpt_name}.tar") + self.model.load_state_dict(checkpoint['model_state_dict']) + + if mode == 'train': + print("The training restarts with the specifed checkpoint.", flush=True) + self.optim.load_state_dict(checkpoint['optim_state_dict']) + self.best_loss = checkpoint['loss'] + self.ckpt_name = ckpt_name + else: + assert mode == 'train', "Please check if the checkpoint name exists." + + print(f"The checkpoint named '{ckpt_name}' does not exist. 
This becomes the best checkpoint name from now on.", flush=True) + self.ckpt_name = ckpt_name + else: + print("You did not specify the checkpoint name.", flush=True) + print(f"The default name '{'best_ckpt'}' is set.", flush=True) + self.ckpt_name = "best_ckpt" + + print("Setting finished.") + + def inference(self): + print("Let's start!") + print(f"If you want to quit the conversation, please type Abort!") + self.model.eval() + + with torch.no_grad(): + cur_speaker = 2 + input_ids_list = [] + token_type_ids_list = [] + t = 0 + output_id = None + + while True: + if t == 0: + cond_labels = [self.bos_id] + self.tokenizer.encode(self.cond_labels) + cond_len = len(cond_labels) + token_type_id = [self.speaker1_id] * cond_len + input_ids_list.append(cond_labels) + token_type_ids_list.append(token_type_id) + + if cur_speaker == 2: + cur_speaker_id = self.speaker2_id + utter = input("You: ") + + if utter == "Abort!": + print("Bot: Good bye.") + break + + input_id = [cur_speaker_id] + self.tokenizer.encode(utter) + + else: + cur_speaker_id = self.speaker1_id + input_id = copy.deepcopy(output_id) + + token_type_id = [cur_speaker_id] * len(input_id) + + if input_id[-1] == self.eos_id: + input_id = input_id[:-1] + token_type_id = token_type_id[:-1] + + input_ids_list.append(input_id) + token_type_ids_list.append(token_type_id) + + if t >= self.max_time: + input_ids_list = input_ids_list[:1] + input_ids_list[2:] + token_type_ids_list = token_type_ids_list[:1] + token_type_ids_list[2:] + + next_speaker = (cur_speaker % 2) + 1 + if next_speaker == 1: + next_speaker_id = self.speaker1_id + else: + next_speaker_id = self.speaker2_id + if cur_speaker == 2: + if self.decoding == 'nucleus': + output_id = self.nucleus_sampling(input_ids_list, token_type_ids_list, next_speaker_id) + elif self.decoding == 'greedy': + output_id = self.greedy_approach(input_ids_list, token_type_ids_list, next_speaker_id) + elif self.decoding == 'top_k': + output_id = self.top_k_sampling(input_ids_list, token_type_ids_list, next_speaker_id) + elif self.decoding == 'beam': + output_id = self.beam_search(input_ids_list, token_type_ids_list, next_speaker_id) + else: + raise ValueError('No decoding strategy') + res = self.tokenizer.decode(output_id, skip_special_tokens=True) + + print(f"Bot: {res}") + + cur_speaker = copy.deepcopy(next_speaker) + t += 1 + + +if __name__=='__main__': + parser = argparse.ArgumentParser() + #parser.add_argument('--config_path', required=True, default='config.json', help="The path to configuration file.") + parser.add_argument('--mode', required=True, help="Train or inference?") + parser.add_argument('--ckpt_name', required=False, help="Best checkpoint file.") + parser.add_argument('--decoding', required=False, help="Decoding strategy (nucleus, greedy, top_k, beam)") + parser.add_argument('--gender', required=False, help='male, female, other') + parser.add_argument('--age', required=False, help='over 18, under 18') + parser.add_argument('--topic', required=False, help="suicide, anxiety, depressed, relationship, isolated, self-harm, etc.") + + args = parser.parse_args() + + assert args.mode == 'train' or args.mode=='inference', print("Please specify a correct mode name, 'train' or 'inference'.") + + if args.mode == 'train': + manager = Conditional_GPT2(args.mode, ckpt_name=args.ckpt_name) + + manager.train() + + elif args.mode == 'inference': + #assert args.ckpt_name is not None, "Please specify the trained model checkpoint." 
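+        # Illustrative invocation (the checkpoint name is an example, not shipped with the repo):
+        #   python models/conditional_gpt2_model.py --mode inference --ckpt_name best_ckpt \
+        #       --decoding nucleus --gender female --age "under 18" --topic anxiety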
+ + manager = Conditional_GPT2(args.mode, ckpt_name=args.ckpt_name, decoding=args.decoding, age=args.age, gender=args.gender, topic=args.topic) + + manager.inference() \ No newline at end of file diff --git a/models/gpt2_model.py b/models/gpt2_model.py new file mode 100644 index 0000000000000000000000000000000000000000..13dee85dd2c029201d7b33c98d9c86ded851c7f0 --- /dev/null +++ b/models/gpt2_model.py @@ -0,0 +1,443 @@ +import os, sys +os.chdir('/data-imperial') +sys.path.append(os.path.abspath('/data-imperial')) +from transformers import GPT2Tokenizer, GPT2LMHeadModel +from helpers.custom_data import * +from tqdm import tqdm +from torch.utils.data import DataLoader +from torch.nn import functional as F +from itertools import chain + +import torch +import numpy as np +import argparse +import copy + +class Manager(): + ''' + Manager class uses GPT2LMHeadModel during training and inference time. + + - For training, use of the train() function + - For inference, use of the inference() function + + For interence, 4 response generation strategies are implemented: + - greedy_approach + - beam_search + - top_k_sampling + - nucleus_sampling + ''' + + def __init__(self, mode, ckpt_name=None, lr = 5e-4, decoding=None, max_time=2): + ''' + Inputs: + - mode : train or inference + - ckpt_name : the name of the checkpoint file to load the model + - lr : the learning rate (by default 5e-4) + - decoding : the decoding strategy for response generation (greedy, beam search, top_k or nucleus sampling) + - max_time : the number of text messages we consider for the history and the replies + (Number of previous messages considered + the reply) + ''' + self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + # Tokenizer & Vocab + self.tokenizer = GPT2Tokenizer.from_pretrained(f"data/gpt-2") + self.special_tokens = { + 'bos_token': "<bos>", + 'eos_token': "<eos>", + 'pad_token': "<pad>", + 'additional_special_tokens': ["[texter]", "[volunteer]"] + } + self.num_new_tokens = self.tokenizer.add_special_tokens(self.special_tokens) + self.vocab = self.tokenizer.get_vocab() + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab["<bos>"] + self.eos_id = self.vocab["<eos>"] + self.pad_id = self.vocab["<pad>"] + self.speaker1_id = self.vocab["[texter]"] + self.speaker2_id = self.vocab["[volunteer]"] + self.decoding = 'nucleus' if decoding is None else decoding + print('Decoding strategy: ', self.decoding) + + # The number of text messages considered in the history and the reply + self.max_time = 2 if max_time is None else int(max_time) + print('Max time = ', self.max_time) + # Represents the maximum number of tokens fed into the model + self.max_len = 512 #1024 + + # The maximum length of a text message + self.utter_len = (self.max_len-self.max_time-2) // self.max_time + + # Load model + print("Loading the model...", flush=True) + self.model = GPT2LMHeadModel.from_pretrained('gpt2').to(self.device) + self.model.resize_token_embeddings(len(self.tokenizer)) + + if mode == 'train': + # Load optimizer + print("Loading the optimizer...", flush=True) + self.optim = torch.optim.AdamW(self.model.parameters(), lr=lr) + self.best_loss = sys.float_info.max + + # Load train & valid dataset + print("Loading train & valid data...", flush=True) + train_set = CustomDataset('train') + valid_set = CustomDataset('valid') + + batch_size = 4 + self.train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size) + self.valid_loader = DataLoader(valid_set, shuffle=True, batch_size=batch_size) 
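+            # CustomDataset is assumed to yield (input_ids, token_type_ids, lm_labels) per item,
+            # matching the unpacking done in train() and validation() below.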
+ + if not os.path.exists("saved_models"): + os.mkdir("saved_models") + + if ckpt_name is not None: + if os.path.exists(f"{'saved_models'}/{ckpt_name}.tar"): + print("Loading the trained checkpoint...", flush=True) + checkpoint = torch.load(f"{'saved_models'}/{ckpt_name}.tar") + self.model.load_state_dict(checkpoint['model_state_dict']) + + if mode == 'train': + print("The training restarts with the specifed checkpoint.", flush=True) + self.optim.load_state_dict(checkpoint['optim_state_dict']) + self.best_loss = checkpoint['loss'] + self.ckpt_name = ckpt_name + else: + assert mode == 'train', "Please check if the checkpoint name exists." + + print(f"The checkpoint named '{ckpt_name}' does not exist. This becomes the best checkpoint name from now on.", flush=True) + self.ckpt_name = ckpt_name + else: + print("You did not specify the checkpoint name.", flush=True) + print(f"The default name '{'best_ckpt'}' is set.", flush=True) + self.ckpt_name = "best_ckpt" + + print("Setting finished.", flush=True) + + def train(self): + print("Training starts.", flush=True) + + max_epochs = 5 #10 + for epoch in range(1, max_epochs+1): + self.model.train() + + print(f"#################### Epoch: {epoch} ####################", flush=True) + train_losses = [] + train_ppls = [] + for i, batch in enumerate(tqdm(self.train_loader)): + input_ids, token_type_ids, lm_labels = batch + input_ids, token_type_ids, lm_labels = \ + input_ids.to(self.device), token_type_ids.to(self.device), lm_labels.to(self.device) + + outputs = self.model( + input_ids=input_ids, + token_type_ids = token_type_ids, + labels = lm_labels + ) + + loss, logits = outputs[0], outputs[1] + + self.optim.zero_grad() + loss.backward() + self.optim.step() + + train_losses.append(loss.item()) + train_ppls.append(torch.exp(loss).item()) + + train_loss = np.mean(train_losses) + train_ppl = np.mean(train_ppls) + print(f"Train loss: {train_loss} || Train perplexity: {train_ppl}", flush=True) + + valid_loss, valid_ppl = self.validation() + + if valid_loss < self.best_loss: + self.best_loss = valid_loss + state_dict = { + 'model_state_dict': self.model.state_dict(), + 'optim_state_dict': self.optim.state_dict(), + 'loss': self.best_loss, + } + + torch.save(state_dict, f"{'saved_models'}/{self.ckpt_name}.tar") + print(f"***** Current best checkpoint is saved. 
*****", flush=True) + + print(f"Best valid loss: {self.best_loss}", flush=True) + print(f"Valid loss: {valid_loss} || Valid perplexity: {valid_ppl}", flush=True) + + print("Training finished!") + + def validation(self): + print("Validation processing...", flush=True) + self.model.eval() + + valid_losses = [] + valid_ppls = [] + with torch.no_grad(): + for i, batch in enumerate(tqdm(self.valid_loader)): + input_ids, token_type_ids, lm_labels = batch + input_ids, token_type_ids, lm_labels = \ + input_ids.to(self.device), token_type_ids.to(self.device), lm_labels.to(self.device) + + outputs = self.model( + input_ids=input_ids, + token_type_ids = token_type_ids, + labels = lm_labels + ) + + loss, logits = outputs[0], outputs[1] + + valid_losses.append(loss.item()) + valid_ppls.append(torch.exp(loss).item()) + + valid_loss = np.mean(valid_losses) + valid_ppl = np.mean(valid_ppls) + + return valid_loss, valid_ppl + + + def inference(self): + print("Let's start!") + print(f"If you want to quit the conversation, please type Abort!") + self.model.eval() + + with torch.no_grad(): + cur_speaker = 2 + input_ids_list = [] + token_type_ids_list = [] + t = 0 + output_id = None + + while True: + if cur_speaker == 2: + cur_speaker_id = self.speaker2_id + utter = input("You: ") + + if utter == "Abort!": + print("Bot: Good bye.") + break + + input_id = [cur_speaker_id] + self.tokenizer.encode(utter) + + if t == 0: + input_id = [self.bos_id] + input_id + else: + cur_speaker_id = self.speaker1_id + input_id = copy.deepcopy(output_id) + + token_type_id = [cur_speaker_id] * len(input_id) + + if input_id[-1] == self.eos_id: + input_id = input_id[:-1] + token_type_id = token_type_id[:-1] + + input_ids_list.append(input_id) + token_type_ids_list.append(token_type_id) + + if t >= self.max_time: + input_ids_list = input_ids_list[1:] + token_type_ids_list = token_type_ids_list[1:] + + next_speaker = (cur_speaker % 2) + 1 + if next_speaker == 1: + next_speaker_id = self.speaker1_id + else: + next_speaker_id = self.speaker2_id + if cur_speaker == 2: + if self.decoding == 'nucleus': + output_id = self.nucleus_sampling(input_ids_list, token_type_ids_list, next_speaker_id) + elif self.decoding == 'greedy': + output_id = self.greedy_approach(input_ids_list, token_type_ids_list, next_speaker_id) + elif self.decoding == 'top_k': + output_id = self.top_k_sampling(input_ids_list, token_type_ids_list, next_speaker_id) + elif self.decoding == 'beam': + output_id = self.beam_search(input_ids_list, token_type_ids_list, next_speaker_id) + else: + raise ValueError('No decoding strategy') + res = self.tokenizer.decode(output_id, skip_special_tokens=True) + + print(f"Bot: {res}") + + cur_speaker = copy.deepcopy(next_speaker) + t += 1 + + def greedy_approach(self, input_ids_list, token_type_ids_list, next_speaker_id): + output_id = [] + res_id = [next_speaker_id] + res_type_id = [next_speaker_id] + for pos in range(self.utter_len): + input_ids = list(chain.from_iterable(input_ids_list)) + res_id + token_type_ids = list(chain.from_iterable(token_type_ids_list)) + res_type_id + input_len = len(input_ids) + + left = self.max_len - len(input_ids) + input_ids += [self.pad_id] * left + token_type_ids += [self.pad_id] * left + + assert len(input_ids) == len(token_type_ids), "There is something wrong in dialogue process." 
+ + input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device) # (1, L) + token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device) # (1, L) + + output = self.model(input_ids=input_ids, token_type_ids=token_type_ids).logits[:, input_len-1] # (1, vocab_size) + + idx = torch.argmax(output) + + if len(output_id) == self.utter_len or idx == self.eos_id: + break + else: + output_id.append(idx) + res_id.append(idx) + res_type_id.append(next_speaker_id) + + return output_id + + + def beam_search(self, input_ids_list, token_type_ids_list, next_speaker_id, top_B = 2): + sequences = [[list(), 0.0]] + #output_id = [] + res_id = [next_speaker_id] + res_type_id = [next_speaker_id] + for pos in range(self.utter_len): + input_ids = list(chain.from_iterable(input_ids_list)) + res_id + token_type_ids = list(chain.from_iterable(token_type_ids_list)) + res_type_id + input_len = len(input_ids) + + left = self.max_len - len(input_ids) + input_ids += [self.pad_id] * left + token_type_ids += [self.pad_id] * left + + assert len(input_ids) == len(token_type_ids), "There is something wrong in dialogue process." + + input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device) # (1, L) + token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device) # (1, L) + + output = self.model(input_ids=input_ids, token_type_ids=token_type_ids).logits[:, input_len-1] # (1, vocab_size) + probs = F.softmax(output, dim=-1) + + all_candidates = list() + for i in range(len(sequences)): + seq, score = sequences[i] + for j in range(output.shape[1]): + candidate = [seq + [j], score - np.log(probs[0,j].item())] + all_candidates.append(candidate) + # order all candidates by score + ordered = sorted(all_candidates, key=lambda tup:tup[1]) + # select k best + sequences = ordered[:top_B] + # Select the second best option + output_id = sequences[1][0] + idx = output_id[-1] + if len(output_id) == self.utter_len or idx == self.eos_id: + break + else: + #output_id.append(idx) + res_id.append(idx) + res_type_id.append(next_speaker_id) + + return output_id + + + + def top_k_sampling(self, input_ids_list, token_type_ids_list, next_speaker_id): + output_id = [] + res_id = [next_speaker_id] + res_type_id = [next_speaker_id] + for pos in range(self.utter_len): + input_ids = list(chain.from_iterable(input_ids_list)) + res_id + token_type_ids = list(chain.from_iterable(token_type_ids_list)) + res_type_id + input_len = len(input_ids) + + left = self.max_len - len(input_ids) + input_ids += [self.pad_id] * left + token_type_ids += [self.pad_id] * left + + assert len(input_ids) == len(token_type_ids), "There is something wrong in dialogue process." 
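+            # Top-k sampling: keep the k (=10) highest logits, mask the rest to -inf, then sample
+            # the next token from the renormalized softmax distribution.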
+ + input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device) # (1, L) + token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device) # (1, L) + + output = self.model(input_ids=input_ids, token_type_ids=token_type_ids).logits[:, input_len-1] # (1, vocab_size) + top_k = 10 + filter_value = -float('Inf') + idx_remove = output < torch.topk(output, top_k)[0][..., -1, None] + #print(idx_remove.shape) + output[idx_remove] = filter_value + + # Random sampling + probs = F.softmax(output, dim=-1) # (1, vocab_size) + idx = torch.multinomial(probs, 1).squeeze(-1).squeeze(0).item() + + if len(output_id) == self.utter_len or idx == self.eos_id: + break + else: + output_id.append(idx) + res_id.append(idx) + res_type_id.append(next_speaker_id) + + return output_id + + def nucleus_sampling(self, input_ids_list, token_type_ids_list, next_speaker_id): + output_id = [] + res_id = [next_speaker_id] + res_type_id = [next_speaker_id] + for pos in range(self.utter_len): + input_ids = list(chain.from_iterable(input_ids_list)) + res_id + token_type_ids = list(chain.from_iterable(token_type_ids_list)) + res_type_id + input_len = len(input_ids) + + left = self.max_len - len(input_ids) + input_ids += [self.pad_id] * left + token_type_ids += [self.pad_id] * left + + assert len(input_ids) == len(token_type_ids), "There is something wrong in dialogue process." + + input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device) # (1, L) + token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device) # (1, L) + + output = self.model(input_ids=input_ids, token_type_ids=token_type_ids).logits[:, input_len-1] # (1, vocab_size) + output = F.softmax(output, dim=-1) # (1, vocab_size) + + sorted_probs, sorted_idxs = torch.sort(output, descending=True) + cumsum_probs = torch.cumsum(sorted_probs, dim=-1) # (1, vocab_size) + nucleus_p = 0.9 + idx_remove = cumsum_probs > nucleus_p + sorted_probs[idx_remove] = 1e-8 + sorted_probs /= torch.sum(sorted_probs, dim=-1, keepdim=True) # (1, vocab_size) + + # Random sampling + probs = torch.zeros(output.shape).to(self.device).scatter_(-1, sorted_idxs, sorted_probs) # (1, vocab_size) + idx = torch.multinomial(probs, 1).squeeze(-1).squeeze(0).item() + + if len(output_id) == self.utter_len or idx == self.eos_id: + break + else: + output_id.append(idx) + res_id.append(idx) + res_type_id.append(next_speaker_id) + + return output_id + + +if __name__=='__main__': + parser = argparse.ArgumentParser() + #parser.add_argument('--config_path', required=True, default='config.json', help="The path to configuration file.") + parser.add_argument('--mode', required=True, help="Train or inference?") + parser.add_argument('--ckpt_name', required=False, help="Best checkpoint file.") + parser.add_argument('--decoding', required=False, help="Decoding strategy (nucleus, greedy, top_k, beam)") + parser.add_argument('--max_time', required=False, help="Number of text messages considered in the history and the reply") + + + args = parser.parse_args() + + assert args.mode == 'train' or args.mode=='inference', print("Please specify a correct mode name, 'train' or 'inference'.") + + if args.mode == 'train': + manager = Manager(args.mode, ckpt_name=args.ckpt_name) + + manager.train() + + elif args.mode == 'inference': + #assert args.ckpt_name is not None, "Please specify the trained model checkpoint." 
+ + manager = Manager(args.mode, ckpt_name=args.ckpt_name, decoding=args.decoding, max_time=args.max_time) + + manager.inference() diff --git a/models/style_tokens_model/modules.py b/models/style_tokens_model/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..764759551c7240be93f8d1050bc242538d247986 --- /dev/null +++ b/models/style_tokens_model/modules.py @@ -0,0 +1,185 @@ +import torch +from typing import Optional, Tuple +from collections import OrderedDict +from dataclasses import fields +from typing import Any +import numpy as np + + +def is_tensor(x): + """ + Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor`, obj:`jaxlib.xla_extension.DeviceArray` or + :obj:`np.ndarray`. + """ + if isinstance(x, torch.Tensor): + return True + + return isinstance(x, np.ndarray) + + +class ModelOutput(OrderedDict): + """ + Base class for all model outputs as dataclass. Has a ``__getitem__`` that allows indexing by integer or slice (like + a tuple) or strings (like a dictionary) that will ignore the ``None`` attributes. Otherwise behaves like a regular + python dictionary. + .. warning:: + You can't unpack a :obj:`ModelOutput` directly. Use the :meth:`~transformers.file_utils.ModelOutput.to_tuple` + method to convert it to a tuple before. + """ + + def __post_init__(self): + class_fields = fields(self) + + # Safety and consistency checks + assert len(class_fields), f"{self.__class__.__name__} has no fields." + assert all( + field.default is None for field in class_fields[1:] + ), f"{self.__class__.__name__} should not have more than one required field." + + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) + + if other_fields_are_none and not is_tensor(first_field): + if isinstance(first_field, dict): + iterator = first_field.items() + first_field_iterator = True + else: + try: + iterator = iter(first_field) + first_field_iterator = True + except TypeError: + first_field_iterator = False + + # if we provided an iterator as first field and the iterator is a (key, value) iterator + # set the associated fields + if first_field_iterator: + for element in iterator: + if ( + not isinstance(element, (list, tuple)) + or not len(element) == 2 + or not isinstance(element[0], str) + ): + break + setattr(self, element[0], element[1]) + if element[1] is not None: + self[element[0]] = element[1] + elif first_field is not None: + self[class_fields[0].name] = first_field + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __getitem__(self, k): + if isinstance(k, str): + inner_dict = {k: v for (k, v) in self.items()} + return inner_dict[k] + else: + return self.to_tuple()[k] + + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + # Don't call self.__setitem__ to avoid recursion errors + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, 
key, value): + # Will raise a KeyException if needed + super().__setitem__(key, value) + # Don't call self.__setattr__ to avoid recursion errors + super().__setattr__(key, value) + + def to_tuple(self) -> Tuple[Any]: + """ + Convert self to a tuple containing all the attributes/keys that are not ``None``. + """ + return tuple(self[k] for k in self.keys()) + + +class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + Args: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, + 1, hidden_size)` is output. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + Tuple of :obj:`tuple(torch.FloatTensor)` of length :obj:`config.n_layers`, with each tuple having 2 tensors + of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + ``config.is_encoder_decoder=True`` 2 additional tensors of shape :obj:`(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + ``config.is_encoder_decoder=True`` in the cross-attention blocks) that can be used (see + :obj:`past_key_values` input) to speed up sequential decoding. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + +class CausalLMOutputWithCrossAttentions(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. 
+ Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Language modeling loss (for next-token prediction). + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + Cross attentions weights after the attention softmax, used to compute the weighted average in the + cross-attention heads. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + Tuple of :obj:`torch.FloatTensor` tuples of length :obj:`config.n_layers`, with each tuple containing the + cached key, value states of the self-attention and the cross-attention layers if model is used in + encoder-decoder setting. Only relevant if ``config.is_decoder = True``. + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + :obj:`past_key_values` input) to speed up sequential decoding. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None \ No newline at end of file diff --git a/models/style_tokens_model/run_styleGPT2.py b/models/style_tokens_model/run_styleGPT2.py new file mode 100644 index 0000000000000000000000000000000000000000..4b353a45d37398b5fdc2f83a358b6fc2ad6faf06 --- /dev/null +++ b/models/style_tokens_model/run_styleGPT2.py @@ -0,0 +1,463 @@ +from style_token_layer import * +from styleGPT2_model import * + +from transformers import GPT2Tokenizer +from tqdm import tqdm +from torch.utils.data import DataLoader + +import torch +import os, sys +import numpy as np +import argparse +warnings.filterwarnings('ignore') + +os.chdir('/data-imperial') +import sys +sys.path.append(os.path.abspath('/data-imperial')) +from models.gpt2_model import * +from helpers.custom_data import * + +import torch +torch.cuda.empty_cache() + +# Global Style Token model +class GST_model(Manager): + ''' + GST_model class uses StyleGPT2 model during training and inference time. 
+ + - For training, use of the train() function + - For inference, use of the inference() function + + For interence, 4 response generation strategies are implemented: + - greedy_approach + - beam_search + - top_k_sampling + - nucleus_sampling + ''' + + def __init__(self, mode, ckpt_name=None, lr = 5e-4, decoding=None, style_label=None): + ''' + Inputs: + - mode : train or inference + - ckpt_name : the name of the checkpoint file to load the model + - lr : the learning rate (by default 5e-4) + - decoding : the decoding strategy for response generation (greedy, beam search, top_k or nucleus sampling) + - style_label : the index of the style token we consider (1, 3 or 4) + ''' + + self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + # Tokenizer & Vocab + self.tokenizer = GPT2Tokenizer.from_pretrained(f"data/gpt-2") + self.vocab = self.tokenizer.get_vocab() + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab["<bos>"] + self.eos_id = self.vocab["<eos>"] + self.pad_id = self.vocab["<pad>"] + self.speaker1_id = self.vocab["[texter]"] + self.speaker2_id = self.vocab["[volunteer]"] + + # Represent the number of text messages considered (Number of previous messages considered + the reply) + self.max_time = 2 + + # The number of tokens to be fed to the model + self.max_len = 512 + + # The maximum length of a text message + self.utter_len = (self.max_len-self.max_time-2) // self.max_time + + # Decoding strategy + self.decoding = 'nucleus' if decoding is None else decoding + print('Decoding: ', self.decoding) + + # Index of the style token we consider + self.style_label = 1 if style_label is None else int(style_label) + assert self.style_label in [1,3,4], "Please check the style label" + style = {1 : '14-17 years old, anxiety', 3: '25-34 years old, depressed', 4: '18-21 years old, anxiety'} + print('Style token profile: ', style[self.style_label]) + + # Load model + print("Loading the StyleGPT2 model...", flush=True) + # Add style token layer + self.style_token_layer = GST().to(self.device) + # Add the gpt2 part + self.transformer = CustomGPT2LMHeadModel.from_pretrained('gpt2').to(self.device) + self.transformer.resize_token_embeddings(len(self.tokenizer)) + + self.model = StyleGPT2(self.transformer, self.style_token_layer).to(self.device) + + if mode == 'train': + # Load optimizer + print("Loading the optimizer...", flush=True) + self.optim = torch.optim.AdamW(self.model.parameters(), lr=lr) + self.best_loss = sys.float_info.max + + # Load train & valid dataset from StyleDataset + print("Loading train & valid data...", flush=True) + train_set = StyleDataset('train') + valid_set = StyleDataset('valid') + + batch_size = 3 + self.train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size) + self.valid_loader = DataLoader(valid_set, shuffle=True, batch_size=batch_size) + + if not os.path.exists("saved_models"): + os.mkdir("saved_models") + + if ckpt_name is not None: + if os.path.exists(f"{'saved_models'}/{ckpt_name}.tar"): + print("Loading the trained checkpoint...", flush=True) + checkpoint = torch.load(f"{'saved_models'}/{ckpt_name}.tar") + self.model.load_state_dict(checkpoint['model_state_dict']) + + if mode == 'train': + print("The training restarts with the specifed checkpoint.", flush=True) + self.optim.load_state_dict(checkpoint['optim_state_dict']) + self.best_loss = checkpoint['loss'] + self.ckpt_name = ckpt_name + else: + assert mode == 'train', "Please check if the checkpoint name exists." 
+ + print(f"The checkpoint named '{ckpt_name}' does not exist. This becomes the best checkpoint name from now on.", flush=True) + self.ckpt_name = ckpt_name + else: + print("You did not specify the checkpoint name.", flush=True) + print(f"The default name '{'best_ckpt'}' is set.", flush=True) + self.ckpt_name = "best_ckpt" + + print("Setting finished.", flush=True) + + def train(self): + print("Training starts.", flush=True) + + max_epochs = 5 #10 + for epoch in range(1, max_epochs+1): + self.model.train() + + print(f"#################### Epoch: {epoch} ####################", flush=True) + train_losses = [] + train_ppls = [] + for i, batch in enumerate(tqdm(self.train_loader)): + input_ids, token_type_ids, lm_labels, style_tokens = batch + input_ids, token_type_ids, lm_labels, style_tokens = \ + input_ids.to(self.device), token_type_ids.to(self.device), lm_labels.to(self.device), style_tokens.to(self.device) + + outputs = self.model( + input_ids=input_ids, + token_type_ids = token_type_ids, + labels = lm_labels, + style_tokens = style_tokens + ) + + loss, logits = outputs[0], outputs[1] + + self.optim.zero_grad() + loss.backward() + self.optim.step() + + train_losses.append(loss.item()) + train_ppls.append(torch.exp(loss).item()) + + train_loss = np.mean(train_losses) + train_ppl = np.mean(train_ppls) + print(f"Train loss: {train_loss} || Train perplexity: {train_ppl}", flush=True) + + valid_loss, valid_ppl = self.validation() + + if valid_loss < self.best_loss: + self.best_loss = valid_loss + + state_dict = { + 'model_state_dict': self.model.state_dict(), + 'optim_state_dict': self.optim.state_dict(), + 'loss': self.best_loss, + } + + torch.save(state_dict, f"{'saved_models'}/{self.ckpt_name}.tar") + print(f"***** Current best checkpoint is saved. 
*****", flush=True) + + if not os.path.exists(f"{'saved_models'}/{self.ckpt_name}"): + os.mkdir(f"{'saved_models'}/{self.ckpt_name}") + + print(f"Best valid loss: {self.best_loss}", flush=True) + print(f"Valid loss: {valid_loss} || Valid perplexity: {valid_ppl}", flush=True) + + print("Training finished!") + + def validation(self): + print("Validation processing...", flush=True) + self.model.eval() + + valid_losses = [] + valid_ppls = [] + with torch.no_grad(): + for i, batch in enumerate(tqdm(self.valid_loader)): + input_ids, token_type_ids, lm_labels, style_tokens = batch + input_ids, token_type_ids, lm_labels, style_tokens = \ + input_ids.to(self.device), token_type_ids.to(self.device), lm_labels.to(self.device), style_tokens.to(self.device) + + outputs = self.model( + input_ids=input_ids, + token_type_ids = token_type_ids, + labels = lm_labels, + style_tokens = style_tokens + ) + + loss, logits = outputs[0], outputs[1] + + valid_losses.append(loss.item()) + valid_ppls.append(torch.exp(loss).item()) + + valid_loss = np.mean(valid_losses) + valid_ppl = np.mean(valid_ppls) + + return valid_loss, valid_ppl + + def inference(self): + print("Let's start!") + print(f"If you want to quit the conversation, please type Abort!") + self.model.eval() + + with torch.no_grad(): + cur_speaker = 2 + input_ids_list = [] + token_type_ids_list = [] + t = 0 + output_id = None + + style_tokens = torch.load(f"models/style_tokens_model/style_tokens.torch")[self.style_label]#.to(device) + + while True: + if cur_speaker == 2: + cur_speaker_id = self.speaker2_id + utter = input("You: ") + + if utter == "Abort!": + print("Bot: Good bye.") + break + + input_id = [cur_speaker_id] + self.tokenizer.encode(utter) + + if t == 0: + input_id = [self.bos_id] + input_id + else: + cur_speaker_id = self.speaker1_id + input_id = copy.deepcopy(output_id) + + token_type_id = [cur_speaker_id] * len(input_id) + + if input_id[-1] == self.eos_id: + input_id = input_id[:-1] + token_type_id = token_type_id[:-1] + + input_ids_list.append(input_id) + token_type_ids_list.append(token_type_id) + + if t >= self.max_time: + input_ids_list = input_ids_list[1:] + token_type_ids_list = token_type_ids_list[1:] + + next_speaker = (cur_speaker % 2) + 1 + if next_speaker == 1: + next_speaker_id = self.speaker1_id + else: + next_speaker_id = self.speaker2_id + if cur_speaker == 2: + if self.decoding == 'nucleus': + output_id = self.nucleus_sampling(input_ids_list, token_type_ids_list, next_speaker_id, style_tokens) + elif self.decoding == 'greedy': + output_id = self.greedy_approach(input_ids_list, token_type_ids_list, next_speaker_id, style_tokens) + elif self.decoding == 'top_k': + output_id = self.top_k_sampling(input_ids_list, token_type_ids_list, next_speaker_id, style_tokens) + elif self.decoding == 'beam': + output_id = self.beam_search(input_ids_list, token_type_ids_list, next_speaker_id, style_tokens) + else: + raise ValueError('No decoding strategy') + res = self.tokenizer.decode(output_id, skip_special_tokens=True) + + print(f"Bot: {res}") + + cur_speaker = copy.deepcopy(next_speaker) + t += 1 + + def greedy_approach(self, input_ids_list, token_type_ids_list, next_speaker_id, style_tokens): + output_id = [] + res_id = [next_speaker_id] + res_type_id = [next_speaker_id] + for pos in range(self.utter_len): + input_ids = list(chain.from_iterable(input_ids_list)) + res_id + token_type_ids = list(chain.from_iterable(token_type_ids_list)) + res_type_id + input_len = len(input_ids) + + left = self.max_len - len(input_ids) + input_ids += 
[self.pad_id] * left + token_type_ids += [self.pad_id] * left + + assert len(input_ids) == len(token_type_ids), "There is something wrong in dialogue process." + + input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device) # (1, L) + token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device) # (1, L) + + output = self.model(input_ids=input_ids, token_type_ids=token_type_ids, style_tokens=style_tokens).logits[:, input_len-1] # (1, vocab_size) + + idx = torch.argmax(output) + + if len(output_id) == self.utter_len or idx == self.eos_id: + break + else: + output_id.append(idx) + res_id.append(idx) + res_type_id.append(next_speaker_id) + + return output_id + + + def beam_search(self, input_ids_list, token_type_ids_list, next_speaker_id, style_tokens, top_B = 2): + sequences = [[list(), 0.0]] + #output_id = [] + res_id = [next_speaker_id] + res_type_id = [next_speaker_id] + for pos in range(self.utter_len): + input_ids = list(chain.from_iterable(input_ids_list)) + res_id + token_type_ids = list(chain.from_iterable(token_type_ids_list)) + res_type_id + input_len = len(input_ids) + + left = self.max_len - len(input_ids) + input_ids += [self.pad_id] * left + token_type_ids += [self.pad_id] * left + + assert len(input_ids) == len(token_type_ids), "There is something wrong in dialogue process." + + input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device) # (1, L) + token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device) # (1, L) + + output = self.model(input_ids=input_ids, token_type_ids=token_type_ids, style_tokens=style_tokens).logits[:, input_len-1] # (1, vocab_size) + probs = F.softmax(output, dim=-1) + + all_candidates = list() + for i in range(len(sequences)): + seq, score = sequences[i] + for j in range(output.shape[1]): + candidate = [seq + [j], score - np.log(probs[0,j].item())] + all_candidates.append(candidate) + # order all candidates by score + ordered = sorted(all_candidates, key=lambda tup:tup[1]) + # select k best + sequences = ordered[:top_B] + # Select the second best option + output_id = sequences[1][0] + idx = output_id[-1] + if len(output_id) == self.utter_len or idx == self.eos_id: + break + else: + #output_id.append(idx) + res_id.append(idx) + res_type_id.append(next_speaker_id) + + return output_id + + + + def top_k_sampling(self, input_ids_list, token_type_ids_list, next_speaker_id, style_tokens): + output_id = [] + res_id = [next_speaker_id] + res_type_id = [next_speaker_id] + for pos in range(self.utter_len): + input_ids = list(chain.from_iterable(input_ids_list)) + res_id + token_type_ids = list(chain.from_iterable(token_type_ids_list)) + res_type_id + input_len = len(input_ids) + + left = self.max_len - len(input_ids) + input_ids += [self.pad_id] * left + token_type_ids += [self.pad_id] * left + + assert len(input_ids) == len(token_type_ids), "There is something wrong in dialogue process." 
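+            # Same top-k procedure as Manager.top_k_sampling; the only difference is that the
+            # pre-computed style_tokens embedding is passed to the model at every decoding step.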
+ + input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device) # (1, L) + token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device) # (1, L) + + output = self.model(input_ids=input_ids, token_type_ids=token_type_ids, style_tokens=style_tokens).logits[:, input_len-1] # (1, vocab_size) + top_k = 10 + filter_value = -float('Inf') + idx_remove = output < torch.topk(output, top_k)[0][..., -1, None] + output[idx_remove] = filter_value + + # Random sampling + probs = F.softmax(output, dim=-1) # (1, vocab_size) + idx = torch.multinomial(probs, 1).squeeze(-1).squeeze(0).item() + + if len(output_id) == self.utter_len or idx == self.eos_id: + break + else: + output_id.append(idx) + res_id.append(idx) + res_type_id.append(next_speaker_id) + + return output_id + + def nucleus_sampling(self, input_ids_list, token_type_ids_list, next_speaker_id, style_tokens): + output_id = [] + res_id = [next_speaker_id] + res_type_id = [next_speaker_id] + for pos in range(self.utter_len): + input_ids = list(chain.from_iterable(input_ids_list)) + res_id + token_type_ids = list(chain.from_iterable(token_type_ids_list)) + res_type_id + input_len = len(input_ids) + + left = self.max_len - len(input_ids) + input_ids += [self.pad_id] * left + token_type_ids += [self.pad_id] * left + + assert len(input_ids) == len(token_type_ids), "There is something wrong in dialogue process." + + input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device) # (1, L) + token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device) # (1, L) + + output = self.model(input_ids=input_ids, token_type_ids=token_type_ids, style_tokens=style_tokens).logits[:, input_len-1] # (1, vocab_size) + output = F.softmax(output, dim=-1) # (1, vocab_size) + + sorted_probs, sorted_idxs = torch.sort(output, descending=True) + cumsum_probs = torch.cumsum(sorted_probs, dim=-1) # (1, vocab_size) + nucleus_p = 0.9 + idx_remove = cumsum_probs > nucleus_p + sorted_probs[idx_remove] = 1e-8 + sorted_probs /= torch.sum(sorted_probs, dim=-1, keepdim=True) # (1, vocab_size) + + # Random sampling + probs = torch.zeros(output.shape).to(self.device).scatter_(-1, sorted_idxs, sorted_probs.float()).float() # (1, vocab_size) + idx = torch.multinomial(probs, 1).squeeze(-1).squeeze(0).item() + + if len(output_id) == self.utter_len or idx == self.eos_id: + break + else: + output_id.append(idx) + res_id.append(idx) + res_type_id.append(next_speaker_id) + + return output_id + + +if __name__=='__main__': + parser = argparse.ArgumentParser() + #parser.add_argument('--config_path', required=True, default='config.json', help="The path to configuration file.") + parser.add_argument('--mode', required=True, help="Train or inference?") + parser.add_argument('--ckpt_name', required=False, help="Best checkpoint file.") + parser.add_argument('--decoding', required=False, help="Decoding strategy (nucleus, greedy, top_k, beam)") + parser.add_argument('--style_label', required=False, help="Style 1 : 14-17, anxiety // Style 2: 25-34, depressed // Style 3 : 18-21, anxiety") + + + args = parser.parse_args() + + assert args.mode == 'train' or args.mode=='inference', print("Please specify a correct mode name, 'train' or 'inference'.") + + if args.mode == 'train': + manager = GST_model(args.mode, ckpt_name=args.ckpt_name) + + manager.train() + + elif args.mode == 'inference': + #assert args.ckpt_name is not None, "Please specify the trained model checkpoint." 
+ + manager = GST_model(args.mode, ckpt_name=args.ckpt_name, decoding=args.decoding, style_label=args.style_label) + + manager.inference() \ No newline at end of file diff --git a/models/style_tokens_model/styleGPT2_model.py b/models/style_tokens_model/styleGPT2_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b451a4eaee1874738f8ec9cd5a8854fb29f47548 --- /dev/null +++ b/models/style_tokens_model/styleGPT2_model.py @@ -0,0 +1,342 @@ +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F +from transformers import GPT2Model, GPT2LMHeadModel +from modules import * + + +class CustomGPT2Model(GPT2Model): + ''' + Modified GPT2Model that takes style_tokens as input + https://huggingface.co/transformers/_modules/transformers/models/gpt2/modeling_gpt2.html#GPT2Model + ''' + + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim).double() + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim).double() + + self.drop = nn.Dropout(config.embd_pdrop).double() + #self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.num_hidden_layers)]).double() + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon).double() + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + style_tokens=None + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. 
+ if attention_mask is not None: + assert batch_size > 0, "batch_size has to be defined and > 0" + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + + #Add style tokens to hidden states + if style_tokens is None: + raise ValueError("You have to specify style_tokens") + else: + if style_tokens.shape[1] != self.embed_dim: + raise ValueError("Problem with the dimensions of style_tokens") + + dim0, dim1 = style_tokens.size() + style_tokens = style_tokens.unsqueeze(1).expand(dim0, inputs_embeds.shape[1], dim1) + hidden_states = inputs_embeds + position_embeds + style_tokens + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + hidden_states = hidden_states.double() + + output_shape = input_shape + (hidden_states.size(-1),) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + 
head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states.double(), + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(*output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + +class CustomGPT2LMHeadModel(GPT2LMHeadModel): + ''' + Modified GPT2LMHeadModel that takes style_tokens as input + https://huggingface.co/transformers/_modules/transformers/models/gpt2/modeling_gpt2.html#GPT2LMHeadModel + ''' + + def __init__(self, config): + super().__init__(config) + self.transformer = CustomGPT2Model(config).double() + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False).double() + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + style_tokens=None + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to + ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + style_tokens=style_tokens + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + +class StyleGPT2(nn.Module): + ''' + The StyleGPT2 model is composed of 2 modules: + - a style token layer + - a CustomGPT2LMHeadModel where we can take style tokens as inputs + ''' + + def __init__(self, gpt2, gst): + super(StyleGPT2, self).__init__() + self.transformer = gpt2 + self.gst = gst + + def forward(self, input_ids, token_type_ids, style_tokens, labels=None): + style_embed = self.gst(style_tokens) + x = self.transformer(input_ids=input_ids, token_type_ids=token_type_ids, labels=labels, style_tokens=style_embed) + return x + + + diff --git a/models/style_tokens_model/style_token_layer.py b/models/style_tokens_model/style_token_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..3e53ced45f4dbab2a6c6ccca4f514aadaff416cb --- /dev/null +++ b/models/style_tokens_model/style_token_layer.py @@ -0,0 +1,102 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.init as init +import torch.nn.functional as F +import os +import warnings +warnings.filterwarnings('ignore') + +os.chdir('/data-imperial') +import sys +sys.path.append(os.path.abspath('/data-imperial')) + +device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +class GST(nn.Module): + ''' + GST : Global Style Token + + This module is used to compute the style embedding using the style token layer (STL) + inputs : represent a conversation embedding + ''' + def __init__(self): + super().__init__() + self.stl = STL() + + def forward(self, inputs): + enc_out = inputs.to(device) + if len(enc_out.size()) == 1: + enc_out = enc_out.unsqueeze(0) + style_embed = self.stl(enc_out) + + return style_embed + + +class STL(nn.Module): + ''' + STL: Style Token Layer + + 
The style token layer uses attention to compute a similarity measure between the inputs + which represent conversation embeddings and the set of style tokens. + ''' + + def __init__(self): + + super().__init__() + + # set of 3 style tokens + self.embed = torch.load(f"models/style_tokens_model/style_tokens.torch").to(device) + self.E = self.embed.size(1) + d_q = self.E + d_k = self.E + num_units = self.E + self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=num_units, num_heads=1) + + def forward(self, inputs): + N = inputs.size(0) + query = inputs.unsqueeze(1) # [N, 1, E] + keys = F.tanh(self.embed).unsqueeze(0).expand(N, -1, -1) # [N, token_num, E // num_heads] + + style_embed = self.attention(query, keys) + + return style_embed + + +class MultiHeadAttention(nn.Module): + ''' + This is the attention module we use in the style token layer to compute a similarity measure + between the conversation embedding and the set of style tokens. + + We use Query, Keys and Values to compute a similarity score. + ''' + + def __init__(self, query_dim, key_dim, num_units, num_heads): + + super().__init__() + self.num_units = num_units + self.num_heads = num_heads + self.key_dim = key_dim + + self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False).double().to(device) + self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False).double().to(device) + self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False).double().to(device) + + def forward(self, query, key): + querys = self.W_query(query) # [N, T_q, num_units] + keys = self.W_key(key) # [N, T_k, num_units] + values = self.W_value(key) + + split_size = self.num_units // self.num_heads + querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h] + keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + + scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k] + scores = scores / (self.key_dim ** 0.5) + scores = F.softmax(scores, dim=3) + + out = torch.matmul(scores, values) # [h, N, T_q, num_units/h] + out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] + + return out.squeeze(1) #[N, num_units]s \ No newline at end of file diff --git a/models/style_tokens_model/style_tokens.torch b/models/style_tokens_model/style_tokens.torch new file mode 100644 index 0000000000000000000000000000000000000000..e4d36971c060daf98be43e379aba56da5d6e23a7 Binary files /dev/null and b/models/style_tokens_model/style_tokens.torch differ
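
How the style conditioning actually enters CustomGPT2Model.forward above may be easier to see in isolation. The standalone sketch below (toy dimensions chosen purely for illustration, not the real GPT-2 sizes) reproduces just the broadcasting step: a per-example style vector of shape [batch, embed_dim] is expanded along the sequence axis and summed with the token and position embeddings.

import torch

batch_size, seq_len, embed_dim = 2, 5, 8            # toy sizes for illustration only
inputs_embeds = torch.randn(batch_size, seq_len, embed_dim)
position_embeds = torch.randn(batch_size, seq_len, embed_dim)
style_tokens = torch.randn(batch_size, embed_dim)   # one style embedding per example

# unsqueeze to [batch, 1, embed_dim], then expand across the sequence dimension
dim0, dim1 = style_tokens.size()
style_broadcast = style_tokens.unsqueeze(1).expand(dim0, seq_len, dim1)

hidden_states = inputs_embeds + position_embeds + style_broadcast
print(hidden_states.shape)                           # torch.Size([2, 5, 8])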
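
The gradient-checkpointing branch in the block loop wraps each transformer block with torch.utils.checkpoint, so intermediate activations are recomputed during the backward pass instead of being stored. A minimal standalone sketch of that trade-off, using a toy nn.Linear layer as a stand-in for a GPT-2 block:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(16, 16)                  # stand-in for a transformer block
x = torch.randn(4, 16, requires_grad=True)

def run_block(inp):
    return torch.relu(layer(inp))

# Activations inside run_block are not kept; they are recomputed when
# backward() runs, saving memory at the cost of extra compute.
out = checkpoint(run_block, x)
out.sum().backward()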
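
The loss in CustomGPT2LMHeadModel.forward follows the usual causal-LM convention: the logits at position t are scored against the token at position t+1, and labels set to -100 are ignored by CrossEntropyLoss. A small self-contained sketch with made-up numbers:

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
lm_logits = torch.randn(1, 6, vocab_size)            # [batch, seq_len, vocab]
labels = torch.tensor([[5, 2, 7, 1, -100, -100]])    # -100 marks positions to ignore

# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))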
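
Finally, the style token layer itself: STL treats the conversation embedding as a query and the tanh-squashed bank of style tokens as keys, so the returned style embedding is an attention-weighted mixture over that bank. The sketch below strips out the learned W_query/W_key/W_value projections and the stored style_tokens.torch file, using random tensors and toy sizes only to show the shapes involved; it is not the trained layer.

import torch
import torch.nn.functional as F

N, E, token_num = 2, 8, 3                       # toy batch size, embedding dim, number of style tokens
conv_embed = torch.randn(N, E)                  # conversation embeddings (the STL inputs)
style_bank = torch.randn(token_num, E)          # stand-in for the learned style tokens

query = conv_embed.unsqueeze(1)                                   # [N, 1, E]
keys = torch.tanh(style_bank).unsqueeze(0).expand(N, -1, -1)      # [N, token_num, E]

scores = torch.matmul(query, keys.transpose(1, 2)) / (E ** 0.5)   # [N, 1, token_num]
weights = F.softmax(scores, dim=-1)

style_embed = torch.matmul(weights, keys).squeeze(1)              # [N, E], one style embedding per example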