diff --git a/models/conditional_gpt2_model.py b/models/conditional_gpt2_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..848d1cb8eb98eb5e3950d839ff46ce65afb8319f
--- /dev/null
+++ b/models/conditional_gpt2_model.py
@@ -0,0 +1,228 @@
+import os, sys
+os.chdir('/data-imperial')
+sys.path.append(os.path.abspath('/data-imperial'))
+from models.gpt2_model import Manager
+from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from tqdm import tqdm
+from helpers.custom_data import *
+from torch.utils.data import DataLoader
+
+import torch
+import argparse
+import copy
+
+class Conditional_GPT2(Manager):
+
+    '''
+    The Conditional_GPT2 class wraps a persona-conditioned GPT-2 model for training and inference.
+
+    - For training, use the train() function.
+    - For inference, use the inference() function.
+
+    For inference, 4 response generation strategies are implemented: 
+    - greedy_approach
+    - beam_search
+    - top_k_sampling
+    - nucleus_sampling
+    '''
+
+    def __init__(self, mode, ckpt_name=None, lr = 5e-4, decoding=None, age=None, gender=None, topic=None):
+        '''
+        Inputs:
+        - mode : train or inference
+        - ckpt_name : the name of the checkpoint file to load the model
+        - lr : the learning rate (by default 5e-4)
+        - decoding : the decoding strategy for response generation (greedy, beam search, top_k or nucleus sampling)
+        - age : 'under 18' or 'over 18'
+        - gender : 'male', 'female' or 'other'
+        - topic : conversation topic (suicide, anxiety, etc.)  
+        '''
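+        # A minimal usage sketch (hypothetical values; the checkpoint name is an assumption):
+        #   manager = Conditional_GPT2('inference', ckpt_name='best_ckpt', decoding='nucleus',
+        #                              age='under 18', gender='female', topic='anxiety')
+        #   manager.inference()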
+
+        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+        
+        # Tokenizer & Vocab
+        self.tokenizer = GPT2Tokenizer.from_pretrained("data/gpt-2")
+        self.special_tokens = {
+            'bos_token': "<bos>",
+            'eos_token': "<eos>",
+            'pad_token': "<pad>",
+            'additional_special_tokens': ["[texter]", "[volunteer]"]
+        }
+        self.num_new_tokens = self.tokenizer.add_special_tokens(self.special_tokens)
+        self.vocab = self.tokenizer.get_vocab()
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab["<bos>"]
+        self.eos_id = self.vocab["<eos>"]
+        self.pad_id = self.vocab["<pad>"]
+        self.speaker1_id = self.vocab["[texter]"]
+        self.speaker2_id = self.vocab["[volunteer]"]
+
+        self.decoding = 'nucleus' if decoding is None else decoding
+        self.age = 'under 18' if age is None else age
+        self.gender = 'female' if gender is None else gender
+        self.topic = 'depressed' if topic is None else topic
+        assert self.gender in ['male', 'female', 'other'], "Please check the gender."
+        assert self.age in ['over 18', 'under 18'], "Please check the age"
+        assert self.topic in ['other', 'isolated', 'self_harm', 'anxiety', 'suicide',
+                        'bereavement', 'relationship', 'gender', 'bully', 'depressed',
+                        'eating', 'none', 'abuse_sexual', '3rd_party', 'abuse_physical',
+                        'abuse_emotional', 'covid_19', 'prank', 'substance',
+                        'abuse_unspecified', 'abuse_domestic', 'testing', 'abuse_child',
+                        'abuse_other'], "Please check the topic"
+        
+        # Represents the persona part of the input
+        self.cond_labels = f'I am a {self.gender}. I am {self.age}. I want to talk about {self.topic}.'
+        print('Decoding: ', self.decoding)
+        print('Persona profile: ', self.cond_labels)
+
+        # Represents the number of text messages considered (number of previous messages + the reply)
+        self.max_time = 2
+
+        # The number of tokens to be fed to the model
+        self.max_len = 512 
+
+        # The maximum length of a text message
+        self.utter_len = (self.max_len-self.max_time-2) // self.max_time
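+        # Worked example: with max_len=512 and max_time=2, utter_len = (512 - 2 - 2) // 2 = 254 tokens per message.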
+        
+        # Load model    
+        print("Loading the model...", flush=True)
+        self.model = GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
+        self.model.resize_token_embeddings(len(self.tokenizer))
+            
+        if mode == 'train':            
+            # Load optimizer
+            print("Loading the optimizer...", flush=True)
+            self.optim = torch.optim.AdamW(self.model.parameters(), lr=lr)
+            self.best_loss = sys.float_info.max
+            
+            # Load train & valid dataset
+            print("Loading train & valid data...", flush=True)
+            train_set = ConditionalDataset('train')
+            valid_set = ConditionalDataset('valid')
+
+            batch_size = 4
+            self.train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
+            self.valid_loader = DataLoader(valid_set, shuffle=True, batch_size=batch_size)
+            
+            if not os.path.exists("saved_models"):
+                os.mkdir("saved_models")
+        
+        if ckpt_name is not None:
+            if os.path.exists(f"saved_models/{ckpt_name}.tar"):
+                print("Loading the trained checkpoint...", flush=True)
+                checkpoint = torch.load(f"saved_models/{ckpt_name}.tar")
+                self.model.load_state_dict(checkpoint['model_state_dict'])
+                
+                if mode == 'train':
+                    print("The training restarts with the specifed checkpoint.", flush=True)
+                    self.optim.load_state_dict(checkpoint['optim_state_dict'])
+                    self.best_loss = checkpoint['loss']
+                    self.ckpt_name = ckpt_name
+            else:
+                assert mode == 'train', "The specified checkpoint does not exist; inference requires a trained checkpoint."
+                
+                print(f"The checkpoint named '{ckpt_name}' does not exist. This becomes the best checkpoint name from now on.", flush=True)
+                self.ckpt_name = ckpt_name
+        else:
+            print("You did not specify the checkpoint name.", flush=True)
+            print(f"The default name '{'best_ckpt'}' is set.", flush=True)
+            self.ckpt_name = "best_ckpt"      
+              
+        print("Setting finished.")
+
+    def inference(self):
+        print("Let's start!")
+        print(f"If you want to quit the conversation, please type Abort!")
+        self.model.eval()
+        
+        with torch.no_grad():
+            cur_speaker = 2
+            input_ids_list = []
+            token_type_ids_list = []
+            t = 0
+            output_id = None
+            
+            while True:
+                if t == 0:
+                    cond_labels = [self.bos_id] + self.tokenizer.encode(self.cond_labels)
+                    cond_len = len(cond_labels)
+                    token_type_id = [self.speaker1_id] * cond_len
+                    input_ids_list.append(cond_labels)
+                    token_type_ids_list.append(token_type_id)
+
+                if cur_speaker == 2:
+                    cur_speaker_id = self.speaker2_id
+                    utter = input("You: ")
+                    
+                    if utter == "Abort!":
+                        print("Bot: Good bye.")
+                        break
+                    
+                    input_id = [cur_speaker_id] + self.tokenizer.encode(utter)
+                    
+                else:
+                    cur_speaker_id = self.speaker1_id
+                    input_id = copy.deepcopy(output_id)
+                    
+                token_type_id = [cur_speaker_id] * len(input_id)
+                
+                if input_id[-1] == self.eos_id:
+                    input_id = input_id[:-1]
+                    token_type_id = token_type_id[:-1] 
+                
+                input_ids_list.append(input_id)
+                token_type_ids_list.append(token_type_id)
+                
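+                # Once the history exceeds max_time turns, drop the oldest utterance (index 1) while keeping the persona prefix at index 0.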
+                if t >= self.max_time:
+                    input_ids_list = input_ids_list[:1] + input_ids_list[2:]
+                    token_type_ids_list = token_type_ids_list[:1] + token_type_ids_list[2:]
+                
+                next_speaker = (cur_speaker % 2) + 1
+                if next_speaker == 1:
+                    next_speaker_id = self.speaker1_id
+                else:
+                    next_speaker_id = self.speaker2_id
+                if cur_speaker == 2:
+                    if self.decoding == 'nucleus':
+                        output_id = self.nucleus_sampling(input_ids_list, token_type_ids_list, next_speaker_id)
+                    elif self.decoding == 'greedy':
+                        output_id = self.greedy_approach(input_ids_list, token_type_ids_list, next_speaker_id)
+                    elif self.decoding == 'top_k':
+                        output_id = self.top_k_sampling(input_ids_list, token_type_ids_list, next_speaker_id)
+                    elif self.decoding == 'beam':
+                        output_id = self.beam_search(input_ids_list, token_type_ids_list, next_speaker_id)
+                    else:
+                        raise ValueError('No decoding strategy')
+                    res = self.tokenizer.decode(output_id, skip_special_tokens=True)
+
+                    print(f"Bot: {res}")
+                
+                cur_speaker = copy.deepcopy(next_speaker)
+                t += 1
+
+
+if __name__=='__main__':
+    parser = argparse.ArgumentParser()
+    #parser.add_argument('--config_path', required=True, default='config.json', help="The path to configuration file.")
+    parser.add_argument('--mode', required=True, help="Train or inference?")
+    parser.add_argument('--ckpt_name', required=False, help="Best checkpoint file.")
+    parser.add_argument('--decoding', required=False, help="Decoding strategy (nucleus, greedy, top_k, beam)")
+    parser.add_argument('--gender', required=False, help='male, female, other')
+    parser.add_argument('--age', required=False, help='over 18, under 18')
+    parser.add_argument('--topic', required=False, help="suicide, anxiety, depressed, relationship, isolated, self-harm, etc.")
+
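+    # Example invocations (a sketch; checkpoint name and persona values are assumptions):
+    #   python models/conditional_gpt2_model.py --mode train --ckpt_name best_ckpt
+    #   python models/conditional_gpt2_model.py --mode inference --ckpt_name best_ckpt \
+    #       --decoding nucleus --age "under 18" --gender female --topic anxiety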
+    args = parser.parse_args()
+    
+    assert args.mode in ('train', 'inference'), "Please specify a correct mode name, 'train' or 'inference'."
+              
+    if args.mode == 'train':
+        manager = Conditional_GPT2(args.mode, ckpt_name=args.ckpt_name)
+
+        manager.train()
+        
+    elif args.mode == 'inference':
+        #assert args.ckpt_name is not None, "Please specify the trained model checkpoint."
+        
+        manager = Conditional_GPT2(args.mode, ckpt_name=args.ckpt_name, decoding=args.decoding, age=args.age, gender=args.gender, topic=args.topic)
+        
+        manager.inference()
\ No newline at end of file
diff --git a/models/gpt2_model.py b/models/gpt2_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..13dee85dd2c029201d7b33c98d9c86ded851c7f0
--- /dev/null
+++ b/models/gpt2_model.py
@@ -0,0 +1,443 @@
+import os, sys
+os.chdir('/data-imperial')
+sys.path.append(os.path.abspath('/data-imperial'))
+from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from helpers.custom_data import *
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+from torch.nn import functional as F
+from itertools import chain
+
+import torch
+import numpy as np
+import argparse
+import copy
+
+class Manager():
+    '''
+    The Manager class wraps GPT2LMHeadModel for training and inference.
+
+    - For training, use the train() function.
+    - For inference, use the inference() function.
+
+    For inference, 4 response generation strategies are implemented: 
+    - greedy_approach
+    - beam_search
+    - top_k_sampling
+    - nucleus_sampling
+    '''
+
+    def __init__(self, mode, ckpt_name=None, lr = 5e-4, decoding=None, max_time=2):
+        '''
+        Inputs:
+        - mode : train or inference
+        - ckpt_name : the name of the checkpoint file to load the model
+        - lr : the learning rate (by default 5e-4)
+        - decoding : the decoding strategy for response generation (greedy, beam search, top_k or nucleus sampling)
+        - max_time : the number of text messages we consider for the history and the replies 
+        (Number of previous messages considered + the reply) 
+        '''
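+        # A minimal usage sketch (the checkpoint name is an assumption):
+        #   manager = Manager('inference', ckpt_name='best_ckpt', decoding='top_k', max_time=2)
+        #   manager.inference()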
+        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+        
+        # Tokenizer & Vocab
+        self.tokenizer = GPT2Tokenizer.from_pretrained("data/gpt-2")
+        self.special_tokens = {
+            'bos_token': "<bos>",
+            'eos_token': "<eos>",
+            'pad_token': "<pad>",
+            'additional_special_tokens': ["[texter]", "[volunteer]"]
+        }
+        self.num_new_tokens = self.tokenizer.add_special_tokens(self.special_tokens)
+        self.vocab = self.tokenizer.get_vocab()
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab["<bos>"]
+        self.eos_id = self.vocab["<eos>"]
+        self.pad_id = self.vocab["<pad>"]
+        self.speaker1_id = self.vocab["[texter]"]
+        self.speaker2_id = self.vocab["[volunteer]"]
+        self.decoding = 'nucleus' if decoding is None else decoding
+        print('Decoding strategy: ', self.decoding)
+        
+        # The number of text messages considered in the history and the reply
+        self.max_time = 2 if max_time is None else int(max_time)
+        print('Max time = ', self.max_time)
+        # Represents the maximum number of tokens fed into the model
+        self.max_len = 512 #1024
+
+        # The maximum length of a text message
+        self.utter_len = (self.max_len-self.max_time-2) // self.max_time
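+        # Worked example: with max_len=512 and max_time=2, utter_len = (512 - 2 - 2) // 2 = 254 tokens per message.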
+        
+        # Load model    
+        print("Loading the model...", flush=True)
+        self.model = GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
+        self.model.resize_token_embeddings(len(self.tokenizer))
+            
+        if mode == 'train':            
+            # Load optimizer
+            print("Loading the optimizer...", flush=True)
+            self.optim = torch.optim.AdamW(self.model.parameters(), lr=lr)
+            self.best_loss = sys.float_info.max
+            
+            # Load train & valid dataset
+            print("Loading train & valid data...", flush=True)
+            train_set = CustomDataset('train')
+            valid_set = CustomDataset('valid')
+
+            batch_size = 4
+            self.train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
+            self.valid_loader = DataLoader(valid_set, shuffle=True, batch_size=batch_size)
+            
+            if not os.path.exists("saved_models"):
+                os.mkdir("saved_models")
+        
+        if ckpt_name is not None:
+            if os.path.exists(f"saved_models/{ckpt_name}.tar"):
+                print("Loading the trained checkpoint...", flush=True)
+                checkpoint = torch.load(f"saved_models/{ckpt_name}.tar")
+                self.model.load_state_dict(checkpoint['model_state_dict'])
+                
+                if mode == 'train':
+                    print("The training restarts with the specifed checkpoint.", flush=True)
+                    self.optim.load_state_dict(checkpoint['optim_state_dict'])
+                    self.best_loss = checkpoint['loss']
+                    self.ckpt_name = ckpt_name
+            else:
+                assert mode == 'train', "The specified checkpoint does not exist; inference requires a trained checkpoint."
+                
+                print(f"The checkpoint named '{ckpt_name}' does not exist. This becomes the best checkpoint name from now on.", flush=True)
+                self.ckpt_name = ckpt_name
+        else:
+            print("You did not specify the checkpoint name.", flush=True)
+            print(f"The default name '{'best_ckpt'}' is set.", flush=True)
+            self.ckpt_name = "best_ckpt"      
+              
+        print("Setting finished.", flush=True)
+              
+    def train(self):
+        print("Training starts.", flush=True)
+
+        max_epochs = 5 #10      
+        for epoch in range(1, max_epochs+1):
+            self.model.train()
+            
+            print(f"#################### Epoch: {epoch} ####################", flush=True)
+            train_losses = []
+            train_ppls = []
+            for i, batch in enumerate(tqdm(self.train_loader)):
+                input_ids, token_type_ids, lm_labels = batch
+                input_ids, token_type_ids, lm_labels = \
+                    input_ids.to(self.device), token_type_ids.to(self.device), lm_labels.to(self.device)
+                
+                outputs = self.model(
+                    input_ids=input_ids,
+                    token_type_ids = token_type_ids,
+                    labels = lm_labels
+                )
+                
+                loss, logits = outputs[0], outputs[1]
+                
+                self.optim.zero_grad()
+                loss.backward()
+                self.optim.step()
+                
+                train_losses.append(loss.item())
+                train_ppls.append(torch.exp(loss).item())
+            
+            train_loss = np.mean(train_losses)
+            train_ppl = np.mean(train_ppls)
+            print(f"Train loss: {train_loss} || Train perplexity: {train_ppl}", flush=True)
+            
+            valid_loss, valid_ppl = self.validation()
+              
+            if valid_loss < self.best_loss:
+                self.best_loss = valid_loss
+                state_dict = {
+                    'model_state_dict': self.model.state_dict(),
+                    'optim_state_dict': self.optim.state_dict(),
+                    'loss': self.best_loss,
+                }
+              
+                torch.save(state_dict, f"saved_models/{self.ckpt_name}.tar")
+                print("***** Current best checkpoint is saved. *****", flush=True)
+              
+            print(f"Best valid loss: {self.best_loss}", flush=True)
+            print(f"Valid loss: {valid_loss} || Valid perplexity: {valid_ppl}", flush=True)
+              
+        print("Training finished!")
+    
+    def validation(self):
+        print("Validation processing...", flush=True)
+        self.model.eval()
+              
+        valid_losses = []
+        valid_ppls = []
+        with torch.no_grad():
+            for i, batch in enumerate(tqdm(self.valid_loader)):
+                input_ids, token_type_ids, lm_labels = batch
+                input_ids, token_type_ids, lm_labels = \
+                    input_ids.to(self.device), token_type_ids.to(self.device), lm_labels.to(self.device)
+                
+                outputs = self.model(
+                    input_ids=input_ids,
+                    token_type_ids = token_type_ids,
+                    labels = lm_labels
+                )
+                
+                loss, logits = outputs[0], outputs[1]
+                
+                valid_losses.append(loss.item())
+                valid_ppls.append(torch.exp(loss).item())
+              
+            valid_loss = np.mean(valid_losses)
+            valid_ppl = np.mean(valid_ppls)
+              
+        return valid_loss, valid_ppl
+        
+              
+    def inference(self):
+        print("Let's start!")
+        print(f"If you want to quit the conversation, please type Abort!")
+        self.model.eval()
+        
+        with torch.no_grad():
+            cur_speaker = 2
+            input_ids_list = []
+            token_type_ids_list = []
+            t = 0
+            output_id = None
+            
+            while True:
+                if cur_speaker == 2:
+                    cur_speaker_id = self.speaker2_id
+                    utter = input("You: ")
+                    
+                    if utter == "Abort!":
+                        print("Bot: Good bye.")
+                        break
+                    
+                    input_id = [cur_speaker_id] + self.tokenizer.encode(utter)
+                    
+                    if t == 0:
+                        input_id = [self.bos_id] + input_id
+                else:
+                    cur_speaker_id = self.speaker1_id
+                    input_id = copy.deepcopy(output_id)
+                    
+                token_type_id = [cur_speaker_id] * len(input_id)
+                
+                if input_id[-1] == self.eos_id:
+                    input_id = input_id[:-1]
+                    token_type_id = token_type_id[:-1] 
+                
+                input_ids_list.append(input_id)
+                token_type_ids_list.append(token_type_id)
+                
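+                # Once the history exceeds max_time turns, drop the oldest utterance so the context window stays bounded.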
+                if t >= self.max_time:
+                    input_ids_list = input_ids_list[1:]
+                    token_type_ids_list = token_type_ids_list[1:]
+                
+                next_speaker = (cur_speaker % 2) + 1
+                if next_speaker == 1:
+                    next_speaker_id = self.speaker1_id
+                else:
+                    next_speaker_id = self.speaker2_id
+                if cur_speaker == 2:
+                    if self.decoding == 'nucleus':
+                        output_id = self.nucleus_sampling(input_ids_list, token_type_ids_list, next_speaker_id)
+                    elif self.decoding == 'greedy':
+                        output_id = self.greedy_approach(input_ids_list, token_type_ids_list, next_speaker_id)
+                    elif self.decoding == 'top_k':
+                        output_id = self.top_k_sampling(input_ids_list, token_type_ids_list, next_speaker_id)
+                    elif self.decoding == 'beam':
+                        output_id = self.beam_search(input_ids_list, token_type_ids_list, next_speaker_id)
+                    else:
+                        raise ValueError('No decoding strategy')
+                    res = self.tokenizer.decode(output_id, skip_special_tokens=True)
+
+                    print(f"Bot: {res}")
+                
+                cur_speaker = copy.deepcopy(next_speaker)
+                t += 1
+
+    def greedy_approach(self, input_ids_list, token_type_ids_list, next_speaker_id):
+        output_id = []
+        res_id = [next_speaker_id]
+        res_type_id = [next_speaker_id]
+        for pos in range(self.utter_len):
+            input_ids = list(chain.from_iterable(input_ids_list)) + res_id
+            token_type_ids = list(chain.from_iterable(token_type_ids_list)) + res_type_id
+            input_len = len(input_ids)
+            
+            left = self.max_len - len(input_ids)
+            input_ids += [self.pad_id] * left
+            token_type_ids += [self.pad_id] * left
+
+            assert len(input_ids) == len(token_type_ids), "There is something wrong in dialogue process."
+            
+            input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device)  # (1, L)
+            token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device)  # (1, L)
+            
+            output = self.model(input_ids=input_ids, token_type_ids=token_type_ids).logits[:, input_len-1]  # (1, vocab_size)
+            
+            idx = torch.argmax(output).item()
+            
+            if len(output_id) == self.utter_len or idx == self.eos_id:
+                break
+            else:
+                output_id.append(idx)
+                res_id.append(idx)
+                res_type_id.append(next_speaker_id)
+                
+        return output_id
+
+
+    def beam_search(self, input_ids_list, token_type_ids_list, next_speaker_id, top_B = 2):
+        sequences = [[list(), 0.0]]
+        #output_id = []
+        res_id = [next_speaker_id]
+        res_type_id = [next_speaker_id]
+        for pos in range(self.utter_len):
+            input_ids = list(chain.from_iterable(input_ids_list)) + res_id
+            token_type_ids = list(chain.from_iterable(token_type_ids_list)) + res_type_id
+            input_len = len(input_ids)
+            
+            left = self.max_len - len(input_ids)
+            input_ids += [self.pad_id] * left
+            token_type_ids += [self.pad_id] * left
+
+            assert len(input_ids) == len(token_type_ids), "There is something wrong in dialogue process."
+            
+            input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device)  # (1, L)
+            token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device)  # (1, L)
+            
+            output = self.model(input_ids=input_ids, token_type_ids=token_type_ids).logits[:, input_len-1]  # (1, vocab_size)
+            probs = F.softmax(output, dim=-1)
+
+            all_candidates = list()
+            for i in range(len(sequences)):
+                seq, score = sequences[i]
+                for j in range(output.shape[1]):
+                    candidate = [seq + [j], score - np.log(probs[0,j].item())]
+                    all_candidates.append(candidate)
+            # order all candidates by score
+            ordered = sorted(all_candidates, key=lambda tup:tup[1])
+            # select k best
+            sequences = ordered[:top_B]
+            # Select the second best option
+            output_id = sequences[1][0]
+            idx = output_id[-1]
+            if len(output_id) == self.utter_len or idx == self.eos_id:
+                break
+            else:
+                #output_id.append(idx)
+                res_id.append(idx)
+                res_type_id.append(next_speaker_id)
+                
+        return output_id
+
+    
+
+    def top_k_sampling(self, input_ids_list, token_type_ids_list, next_speaker_id):
+        output_id = []
+        res_id = [next_speaker_id]
+        res_type_id = [next_speaker_id]
+        for pos in range(self.utter_len):
+            input_ids = list(chain.from_iterable(input_ids_list)) + res_id
+            token_type_ids = list(chain.from_iterable(token_type_ids_list)) + res_type_id
+            input_len = len(input_ids)
+            
+            left = self.max_len - len(input_ids)
+            input_ids += [self.pad_id] * left
+            token_type_ids += [self.pad_id] * left
+
+            assert len(input_ids) == len(token_type_ids), "There is something wrong in dialogue process."
+            
+            input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device)  # (1, L)
+            token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device)  # (1, L)
+            
+            output = self.model(input_ids=input_ids, token_type_ids=token_type_ids).logits[:, input_len-1]  # (1, vocab_size)
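+            # Top-k filtering: mask every logit outside the top_k highest with -inf so sampling only draws from those k tokens.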
+            top_k = 10
+            filter_value = -float('Inf')
+            idx_remove = output < torch.topk(output, top_k)[0][..., -1, None]
+            output[idx_remove] = filter_value
+            
+            # Random sampling
+            probs = F.softmax(output, dim=-1)  # (1, vocab_size)
+            idx = torch.multinomial(probs, 1).squeeze(-1).squeeze(0).item()
+            
+            if len(output_id) == self.utter_len or idx == self.eos_id:
+                break
+            else:
+                output_id.append(idx)
+                res_id.append(idx)
+                res_type_id.append(next_speaker_id)
+                
+        return output_id
+
+    def nucleus_sampling(self, input_ids_list, token_type_ids_list, next_speaker_id):
+        output_id = []
+        res_id = [next_speaker_id]
+        res_type_id = [next_speaker_id]
+        for pos in range(self.utter_len):
+            input_ids = list(chain.from_iterable(input_ids_list)) + res_id
+            token_type_ids = list(chain.from_iterable(token_type_ids_list)) + res_type_id
+            input_len = len(input_ids)
+            
+            left = self.max_len - len(input_ids)
+            input_ids += [self.pad_id] * left
+            token_type_ids += [self.pad_id] * left
+
+            assert len(input_ids) == len(token_type_ids), "There is something wrong in dialogue process."
+            
+            input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device)  # (1, L)
+            token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device)  # (1, L)
+            
+            output = self.model(input_ids=input_ids, token_type_ids=token_type_ids).logits[:, input_len-1]  # (1, vocab_size)
+            output = F.softmax(output, dim=-1)  # (1, vocab_size)
+            
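+            # Nucleus (top-p) filtering: sort the probabilities, zero out the tail whose cumulative mass exceeds nucleus_p, renormalise, then sample.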
+            sorted_probs, sorted_idxs = torch.sort(output, descending=True)
+            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)  # (1, vocab_size)
+            nucleus_p = 0.9
+            idx_remove = cumsum_probs > nucleus_p
+            sorted_probs[idx_remove] = 1e-8
+            sorted_probs /= torch.sum(sorted_probs, dim=-1, keepdim=True)  # (1, vocab_size)
+            
+            # Random sampling
+            probs = torch.zeros(output.shape).to(self.device).scatter_(-1, sorted_idxs, sorted_probs)  # (1, vocab_size)
+            idx = torch.multinomial(probs, 1).squeeze(-1).squeeze(0).item()
+            
+            if len(output_id) == self.utter_len or idx == self.eos_id:
+                break
+            else:
+                output_id.append(idx)
+                res_id.append(idx)
+                res_type_id.append(next_speaker_id)
+                
+        return output_id
+                    
+
+if __name__=='__main__':
+    parser = argparse.ArgumentParser()
+    #parser.add_argument('--config_path', required=True, default='config.json', help="The path to configuration file.")
+    parser.add_argument('--mode', required=True, help="Train or inference?")
+    parser.add_argument('--ckpt_name', required=False, help="Best checkpoint file.")
+    parser.add_argument('--decoding', required=False, help="Decoding strategy (nucleus, greedy, top_k, beam)")
+    parser.add_argument('--max_time', required=False, help="Number of text messages considered in the history and the reply")
+
+              
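+    # Example invocations (a sketch; the checkpoint name is an assumption):
+    #   python models/gpt2_model.py --mode train --ckpt_name best_ckpt
+    #   python models/gpt2_model.py --mode inference --ckpt_name best_ckpt --decoding beam --max_time 2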
+    args = parser.parse_args()
+    
+    assert args.mode in ('train', 'inference'), "Please specify a correct mode name, 'train' or 'inference'."
+              
+    if args.mode == 'train':
+        manager = Manager(args.mode, ckpt_name=args.ckpt_name)
+
+        manager.train()
+        
+    elif args.mode == 'inference':
+        #assert args.ckpt_name is not None, "Please specify the trained model checkpoint."
+        
+        manager = Manager(args.mode, ckpt_name=args.ckpt_name, decoding=args.decoding, max_time=args.max_time)
+        
+        manager.inference()
diff --git a/models/style_tokens_model/modules.py b/models/style_tokens_model/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..764759551c7240be93f8d1050bc242538d247986
--- /dev/null
+++ b/models/style_tokens_model/modules.py
@@ -0,0 +1,185 @@
+import torch
+from typing import Optional, Tuple
+from collections import OrderedDict
+from dataclasses import fields
+from typing import Any
+import numpy as np
+
+
+def is_tensor(x):
+    """
+    Tests if ``x`` is a :obj:`torch.Tensor` or :obj:`np.ndarray`.
+    """
+    if isinstance(x, torch.Tensor):
+        return True
+
+    return isinstance(x, np.ndarray)
+
+
+class ModelOutput(OrderedDict):
+    """
+    Base class for all model outputs as dataclass. Has a ``__getitem__`` that allows indexing by integer or slice (like
+    a tuple) or strings (like a dictionary) that will ignore the ``None`` attributes. Otherwise behaves like a regular
+    python dictionary.
+    .. warning::
+        You can't unpack a :obj:`ModelOutput` directly. Use the :meth:`~transformers.file_utils.ModelOutput.to_tuple`
+        method to convert it to a tuple before.
+    """
+
+    def __post_init__(self):
+        class_fields = fields(self)
+
+        # Safety and consistency checks
+        assert len(class_fields), f"{self.__class__.__name__} has no fields."
+        assert all(
+            field.default is None for field in class_fields[1:]
+        ), f"{self.__class__.__name__} should not have more than one required field."
+
+        first_field = getattr(self, class_fields[0].name)
+        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
+
+        if other_fields_are_none and not is_tensor(first_field):
+            if isinstance(first_field, dict):
+                iterator = first_field.items()
+                first_field_iterator = True
+            else:
+                try:
+                    iterator = iter(first_field)
+                    first_field_iterator = True
+                except TypeError:
+                    first_field_iterator = False
+
+            # if we provided an iterator as first field and the iterator is a (key, value) iterator
+            # set the associated fields
+            if first_field_iterator:
+                for element in iterator:
+                    if (
+                        not isinstance(element, (list, tuple))
+                        or not len(element) == 2
+                        or not isinstance(element[0], str)
+                    ):
+                        break
+                    setattr(self, element[0], element[1])
+                    if element[1] is not None:
+                        self[element[0]] = element[1]
+            elif first_field is not None:
+                self[class_fields[0].name] = first_field
+        else:
+            for field in class_fields:
+                v = getattr(self, field.name)
+                if v is not None:
+                    self[field.name] = v
+
+    def __delitem__(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+    def setdefault(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+    def pop(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+    def update(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+    def __getitem__(self, k):
+        if isinstance(k, str):
+            inner_dict = {k: v for (k, v) in self.items()}
+            return inner_dict[k]
+        else:
+            return self.to_tuple()[k]
+
+    def __setattr__(self, name, value):
+        if name in self.keys() and value is not None:
+            # Don't call self.__setitem__ to avoid recursion errors
+            super().__setitem__(name, value)
+        super().__setattr__(name, value)
+
+    def __setitem__(self, key, value):
+        # Will raise a KeyException if needed
+        super().__setitem__(key, value)
+        # Don't call self.__setattr__ to avoid recursion errors
+        super().__setattr__(key, value)
+
+    def to_tuple(self) -> Tuple[Any]:
+        """
+        Convert self to a tuple containing all the attributes/keys that are not ``None``.
+        """
+        return tuple(self[k] for k in self.keys())
+
+
+class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
+    """
+    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+    Args:
+        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+            If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
+            1, hidden_size)` is output.
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+            Tuple of :obj:`tuple(torch.FloatTensor)` of length :obj:`config.n_layers`, with each tuple having 2 tensors
+            of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
+            ``config.is_encoder_decoder=True`` 2 additional tensors of shape :obj:`(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+            ``config.is_encoder_decoder=True`` in the cross-attention blocks) that can be used (see
+            :obj:`past_key_values` input) to speed up sequential decoding.
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+    """
+
+    last_hidden_state: torch.FloatTensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+class CausalLMOutputWithCrossAttentions(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+    Args:
+        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+            Cross attentions weights after the attention softmax, used to compute the weighted average in the
+            cross-attention heads.
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+            Tuple of :obj:`torch.FloatTensor` tuples of length :obj:`config.n_layers`, with each tuple containing the
+            cached key, value states of the self-attention and the cross-attention layers if model is used in
+            encoder-decoder setting. Only relevant if ``config.is_decoder = True``.
+            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+            :obj:`past_key_values` input) to speed up sequential decoding.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
\ No newline at end of file
diff --git a/models/style_tokens_model/run_styleGPT2.py b/models/style_tokens_model/run_styleGPT2.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b353a45d37398b5fdc2f83a358b6fc2ad6faf06
--- /dev/null
+++ b/models/style_tokens_model/run_styleGPT2.py
@@ -0,0 +1,463 @@
+from style_token_layer import *
+from styleGPT2_model import *
+
+from transformers import GPT2Tokenizer
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+
+import torch
+import os, sys
+import numpy as np
+import argparse
+import warnings
+warnings.filterwarnings('ignore')
+
+os.chdir('/data-imperial')
+sys.path.append(os.path.abspath('/data-imperial'))
+from models.gpt2_model import *
+from helpers.custom_data import *
+
+torch.cuda.empty_cache()
+
+# Global Style Token model
+class GST_model(Manager):
+    '''
+    The GST_model class wraps the StyleGPT2 model for training and inference.
+
+    - For training, use the train() function.
+    - For inference, use the inference() function.
+
+    For inference, 4 response generation strategies are implemented: 
+    - greedy_approach
+    - beam_search
+    - top_k_sampling
+    - nucleus_sampling
+    '''
+
+    def __init__(self, mode, ckpt_name=None, lr = 5e-4, decoding=None, style_label=None):
+        '''
+        Inputs:
+        - mode : train or inference
+        - ckpt_name : the name of the checkpoint file to load the model
+        - lr : the learning rate (by default 5e-4)
+        - decoding : the decoding strategy for response generation (greedy, beam search, top_k or nucleus sampling)
+        - style_label : the index of the style token we consider (1, 3 or 4)
+        '''
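+        # A minimal usage sketch (checkpoint name and style label are assumptions):
+        #   gst = GST_model('inference', ckpt_name='best_ckpt', decoding='nucleus', style_label=1)
+        #   gst.inference()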
+
+        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+        
+        # Tokenizer & Vocab
+        self.tokenizer = GPT2Tokenizer.from_pretrained("data/gpt-2")
+        self.vocab = self.tokenizer.get_vocab()
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab["<bos>"]
+        self.eos_id = self.vocab["<eos>"]
+        self.pad_id = self.vocab["<pad>"]
+        self.speaker1_id = self.vocab["[texter]"]
+        self.speaker2_id = self.vocab["[volunteer]"]
+        
+        # Represents the number of text messages considered (number of previous messages + the reply)
+        self.max_time = 2
+
+        # The number of tokens to be fed to the model
+        self.max_len = 512 
+
+        # The maximum length of a text message
+        self.utter_len = (self.max_len-self.max_time-2) // self.max_time
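+        # Worked example: with max_len=512 and max_time=2, utter_len = (512 - 2 - 2) // 2 = 254 tokens per message.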
+
+        # Decoding strategy
+        self.decoding = 'nucleus' if decoding is None else decoding
+        print('Decoding: ', self.decoding)
+
+        # Index of the style token we consider
+        self.style_label = 1 if style_label is None else int(style_label)
+        assert self.style_label in [1,3,4], "Please check the style label"
+        style = {1 : '14-17 years old, anxiety', 3: '25-34 years old, depressed', 4: '18-21 years old, anxiety'}
+        print('Style token profile: ', style[self.style_label])
+
+        # Load model    
+        print("Loading the StyleGPT2 model...", flush=True)
+        # Add style token layer
+        self.style_token_layer = GST().to(self.device)
+        # Add the gpt2 part
+        self.transformer = CustomGPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
+        self.transformer.resize_token_embeddings(len(self.tokenizer))
+
+        self.model = StyleGPT2(self.transformer, self.style_token_layer).to(self.device)
+        
+        if mode == 'train':            
+            # Load optimizer
+            print("Loading the optimizer...", flush=True)
+            self.optim = torch.optim.AdamW(self.model.parameters(), lr=lr)
+            self.best_loss = sys.float_info.max
+            
+            # Load train & valid dataset from StyleDataset
+            print("Loading train & valid data...", flush=True)
+            train_set = StyleDataset('train')
+            valid_set = StyleDataset('valid')
+
+            batch_size = 3
+            self.train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
+            self.valid_loader = DataLoader(valid_set, shuffle=True, batch_size=batch_size)
+            
+            if not os.path.exists("saved_models"):
+                os.mkdir("saved_models")
+        
+        if ckpt_name is not None:
+            if os.path.exists(f"saved_models/{ckpt_name}.tar"):
+                print("Loading the trained checkpoint...", flush=True)
+                checkpoint = torch.load(f"saved_models/{ckpt_name}.tar")
+                self.model.load_state_dict(checkpoint['model_state_dict'])
+                
+                if mode == 'train':
+                    print("The training restarts with the specifed checkpoint.", flush=True)
+                    self.optim.load_state_dict(checkpoint['optim_state_dict'])
+                    self.best_loss = checkpoint['loss']
+                    self.ckpt_name = ckpt_name
+            else:
+                assert mode == 'train', "The specified checkpoint does not exist; inference requires a trained checkpoint."
+                
+                print(f"The checkpoint named '{ckpt_name}' does not exist. This becomes the best checkpoint name from now on.", flush=True)
+                self.ckpt_name = ckpt_name
+        else:
+            print("You did not specify the checkpoint name.", flush=True)
+            print(f"The default name '{'best_ckpt'}' is set.", flush=True)
+            self.ckpt_name = "best_ckpt"      
+              
+        print("Setting finished.", flush=True)
+
+    def train(self):
+        print("Training starts.", flush=True)
+
+        max_epochs = 5 #10      
+        for epoch in range(1, max_epochs+1):
+            self.model.train()
+            
+            print(f"#################### Epoch: {epoch} ####################", flush=True)
+            train_losses = []
+            train_ppls = []
+            for i, batch in enumerate(tqdm(self.train_loader)):
+                input_ids, token_type_ids, lm_labels, style_tokens = batch
+                input_ids, token_type_ids, lm_labels, style_tokens = \
+                    input_ids.to(self.device), token_type_ids.to(self.device), lm_labels.to(self.device), style_tokens.to(self.device)
+
+                outputs = self.model(
+                    input_ids=input_ids,
+                    token_type_ids = token_type_ids,
+                    labels = lm_labels,
+                    style_tokens = style_tokens
+                )
+                
+                loss, logits = outputs[0], outputs[1]
+                
+                self.optim.zero_grad()
+                loss.backward()
+                self.optim.step()
+                
+                train_losses.append(loss.item())
+                train_ppls.append(torch.exp(loss).item())
+            
+            train_loss = np.mean(train_losses)
+            train_ppl = np.mean(train_ppls)
+            print(f"Train loss: {train_loss} || Train perplexity: {train_ppl}", flush=True)
+            
+            valid_loss, valid_ppl = self.validation()
+            
+            if valid_loss < self.best_loss:
+                self.best_loss = valid_loss
+                
+                state_dict = {
+                    'model_state_dict': self.model.state_dict(),
+                    'optim_state_dict': self.optim.state_dict(),
+                    'loss': self.best_loss,
+                }
+                
+                torch.save(state_dict, f"saved_models/{self.ckpt_name}.tar")
+                print("***** Current best checkpoint is saved. *****", flush=True)
+            
+            if not os.path.exists(f"saved_models/{self.ckpt_name}"):
+                os.mkdir(f"saved_models/{self.ckpt_name}")
+
+            print(f"Best valid loss: {self.best_loss}", flush=True)
+            print(f"Valid loss: {valid_loss} || Valid perplexity: {valid_ppl}", flush=True)
+              
+        print("Training finished!")
+    
+    def validation(self):
+        print("Validation processing...", flush=True)
+        self.model.eval()
+              
+        valid_losses = []
+        valid_ppls = []
+        with torch.no_grad():
+            for i, batch in enumerate(tqdm(self.valid_loader)):
+                input_ids, token_type_ids, lm_labels, style_tokens = batch
+                input_ids, token_type_ids, lm_labels, style_tokens = \
+                    input_ids.to(self.device), token_type_ids.to(self.device), lm_labels.to(self.device), style_tokens.to(self.device)
+                
+                outputs = self.model(
+                    input_ids=input_ids,
+                    token_type_ids = token_type_ids,
+                    labels = lm_labels,
+                    style_tokens = style_tokens
+                )
+                
+                loss, logits = outputs[0], outputs[1]
+                
+                valid_losses.append(loss.item())
+                valid_ppls.append(torch.exp(loss).item())
+              
+            valid_loss = np.mean(valid_losses)
+            valid_ppl = np.mean(valid_ppls)
+              
+        return valid_loss, valid_ppl
+
+    def inference(self):
+        print("Let's start!")
+        print(f"If you want to quit the conversation, please type Abort!")
+        self.model.eval()
+        
+        with torch.no_grad():
+            cur_speaker = 2
+            input_ids_list = []
+            token_type_ids_list = []
+            t = 0
+            output_id = None
+            
+            style_tokens = torch.load(f"models/style_tokens_model/style_tokens.torch")[self.style_label]#.to(device)
+            
+            while True:
+                if cur_speaker == 2:
+                    cur_speaker_id = self.speaker2_id
+                    utter = input("You: ")
+                    
+                    if utter == "Abort!":
+                        print("Bot: Good bye.")
+                        break
+                    
+                    input_id = [cur_speaker_id] + self.tokenizer.encode(utter)
+                    
+                    if t == 0:
+                        input_id = [self.bos_id] + input_id
+                else:
+                    cur_speaker_id = self.speaker1_id
+                    input_id = copy.deepcopy(output_id)
+                    
+                token_type_id = [cur_speaker_id] * len(input_id)
+                
+                if input_id[-1] == self.eos_id:
+                    input_id = input_id[:-1]
+                    token_type_id = token_type_id[:-1] 
+                
+                input_ids_list.append(input_id)
+                token_type_ids_list.append(token_type_id)
+                
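+                # Once the history exceeds max_time turns, drop the oldest utterance so the context window stays bounded.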
+                if t >= self.max_time:
+                    input_ids_list = input_ids_list[1:]
+                    token_type_ids_list = token_type_ids_list[1:]
+                
+                next_speaker = (cur_speaker % 2) + 1
+                if next_speaker == 1:
+                    next_speaker_id = self.speaker1_id
+                else:
+                    next_speaker_id = self.speaker2_id
+                if cur_speaker == 2:
+                    if self.decoding == 'nucleus':
+                        output_id = self.nucleus_sampling(input_ids_list, token_type_ids_list, next_speaker_id, style_tokens)
+                    elif self.decoding == 'greedy':
+                        output_id = self.greedy_approach(input_ids_list, token_type_ids_list, next_speaker_id, style_tokens)
+                    elif self.decoding == 'top_k':
+                        output_id = self.top_k_sampling(input_ids_list, token_type_ids_list, next_speaker_id, style_tokens)
+                    elif self.decoding == 'beam':
+                        output_id = self.beam_search(input_ids_list, token_type_ids_list, next_speaker_id, style_tokens)
+                    else:
+                        raise ValueError('No decoding strategy')
+                    res = self.tokenizer.decode(output_id, skip_special_tokens=True)
+
+                    print(f"Bot: {res}")
+                
+                cur_speaker = copy.deepcopy(next_speaker)
+                t += 1
+
+    def greedy_approach(self, input_ids_list, token_type_ids_list, next_speaker_id, style_tokens):
+        output_id = []
+        res_id = [next_speaker_id]
+        res_type_id = [next_speaker_id]
+        for pos in range(self.utter_len):
+            input_ids = list(chain.from_iterable(input_ids_list)) + res_id
+            token_type_ids = list(chain.from_iterable(token_type_ids_list)) + res_type_id
+            input_len = len(input_ids)
+            
+            left = self.max_len - len(input_ids)
+            input_ids += [self.pad_id] * left
+            token_type_ids += [self.pad_id] * left
+
+            assert len(input_ids) == len(token_type_ids), "There is something wrong in dialogue process."
+            
+            input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device)  # (1, L)
+            token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device)  # (1, L)
+
+            output = self.model(input_ids=input_ids, token_type_ids=token_type_ids, style_tokens=style_tokens).logits[:, input_len-1]  # (1, vocab_size)
+            
+            idx = torch.argmax(output).item()
+            
+            if len(output_id) == self.utter_len or idx == self.eos_id:
+                break
+            else:
+                output_id.append(idx)
+                res_id.append(idx)
+                res_type_id.append(next_speaker_id)
+                
+        return output_id
+
+
+    def beam_search(self, input_ids_list, token_type_ids_list, next_speaker_id, style_tokens, top_B = 2):
+        sequences = [[list(), 0.0]]
+        #output_id = []
+        res_id = [next_speaker_id]
+        res_type_id = [next_speaker_id]
+        for pos in range(self.utter_len):
+            input_ids = list(chain.from_iterable(input_ids_list)) + res_id
+            token_type_ids = list(chain.from_iterable(token_type_ids_list)) + res_type_id
+            input_len = len(input_ids)
+            
+            left = self.max_len - len(input_ids)
+            input_ids += [self.pad_id] * left
+            token_type_ids += [self.pad_id] * left
+
+            assert len(input_ids) == len(token_type_ids), "There is something wrong in dialogue process."
+            
+            input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device)  # (1, L)
+            token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device)  # (1, L)
+            
+            output = self.model(input_ids=input_ids, token_type_ids=token_type_ids, style_tokens=style_tokens).logits[:, input_len-1]  # (1, vocab_size)
+            probs = F.softmax(output, dim=-1)
+
+            all_candidates = list()
+            for i in range(len(sequences)):
+                seq, score = sequences[i]
+                for j in range(output.shape[1]):
+                    candidate = [seq + [j], score - np.log(probs[0,j].item())]
+                    all_candidates.append(candidate)
+            # order all candidates by score
+            ordered = sorted(all_candidates, key=lambda tup:tup[1])
+            # select k best
+            sequences = ordered[:top_B]
+            # Select the second best option
+            output_id = sequences[1][0]
+            idx = output_id[-1]
+            if len(output_id) == self.utter_len or idx == self.eos_id:
+                break
+            else:
+                #output_id.append(idx)
+                res_id.append(idx)
+                res_type_id.append(next_speaker_id)
+                
+        return output_id
+
+    
+
+    def top_k_sampling(self, input_ids_list, token_type_ids_list, next_speaker_id, style_tokens):
+        output_id = []
+        res_id = [next_speaker_id]
+        res_type_id = [next_speaker_id]
+        for pos in range(self.utter_len):
+            input_ids = list(chain.from_iterable(input_ids_list)) + res_id
+            token_type_ids = list(chain.from_iterable(token_type_ids_list)) + res_type_id
+            input_len = len(input_ids)
+            
+            left = self.max_len - len(input_ids)
+            input_ids += [self.pad_id] * left
+            token_type_ids += [self.pad_id] * left
+
+            assert len(input_ids) == len(token_type_ids), "There is something wrong in dialogue process."
+            
+            input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device)  # (1, L)
+            token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device)  # (1, L)
+            
+            output = self.model(input_ids=input_ids, token_type_ids=token_type_ids, style_tokens=style_tokens).logits[:, input_len-1]  # (1, vocab_size)
+            top_k = 10
+            filter_value = -float('Inf')
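+            # Top-k filtering: mask every logit below the k-th largest so that, after the
+            # softmax below, only the top_k tokens keep a non-zero sampling probability.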
+            idx_remove = output < torch.topk(output, top_k)[0][..., -1, None]
+            output[idx_remove] = filter_value
+            
+            # Random sampling
+            probs = F.softmax(output, dim=-1)  # (1, vocab_size)
+            idx = torch.multinomial(probs, 1).squeeze(-1).squeeze(0).item()
+            
+            if len(output_id) == self.utter_len or idx == self.eos_id:
+                break
+            else:
+                output_id.append(idx)
+                res_id.append(idx)
+                res_type_id.append(next_speaker_id)
+                
+        return output_id
+
+    def nucleus_sampling(self, input_ids_list, token_type_ids_list, next_speaker_id, style_tokens):
+        output_id = []
+        res_id = [next_speaker_id]
+        res_type_id = [next_speaker_id]
+        for pos in range(self.utter_len):
+            input_ids = list(chain.from_iterable(input_ids_list)) + res_id
+            token_type_ids = list(chain.from_iterable(token_type_ids_list)) + res_type_id
+            input_len = len(input_ids)
+            
+            left = self.max_len - len(input_ids)
+            input_ids += [self.pad_id] * left
+            token_type_ids += [self.pad_id] * left
+
+            assert len(input_ids) == len(token_type_ids), "There is something wrong in dialogue process."
+            
+            input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device)  # (1, L)
+            token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device)  # (1, L)
+            
+            output = self.model(input_ids=input_ids, token_type_ids=token_type_ids, style_tokens=style_tokens).logits[:, input_len-1]  # (1, vocab_size)
+            output = F.softmax(output, dim=-1)  # (1, vocab_size)
+            
+            sorted_probs, sorted_idxs = torch.sort(output, descending=True)
+            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)  # (1, vocab_size)
+            nucleus_p = 0.9
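+            # Nucleus (top-p) filtering: keep the smallest set of tokens whose cumulative
+            # probability exceeds nucleus_p, give the rest a negligible mass, then renormalise.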
+            idx_remove = cumsum_probs > nucleus_p
+            sorted_probs[idx_remove] = 1e-8
+            sorted_probs /= torch.sum(sorted_probs, dim=-1, keepdim=True)  # (1, vocab_size)
+            
+            # Random sampling
+            probs = torch.zeros(output.shape).to(self.device).scatter_(-1, sorted_idxs, sorted_probs.float()).float()  # (1, vocab_size)
+            idx = torch.multinomial(probs, 1).squeeze(-1).squeeze(0).item()
+            
+            if len(output_id) == self.utter_len or idx == self.eos_id:
+                break
+            else:
+                output_id.append(idx)
+                res_id.append(idx)
+                res_type_id.append(next_speaker_id)
+                
+        return output_id
+
+
+if __name__=='__main__':
+    parser = argparse.ArgumentParser()
+    #parser.add_argument('--config_path', required=True, default='config.json', help="The path to configuration file.")
+    parser.add_argument('--mode', required=True, help="Train or inference?")
+    parser.add_argument('--ckpt_name', required=False, help="Best checkpoint file.")
+    parser.add_argument('--decoding', required=False, help="Decoding strategy (nucleus, greedy, top_k, beam)")
+    parser.add_argument('--style_label', required=False, help="Style 1 : 14-17, anxiety // Style 2: 25-34, depressed // Style 3 : 18-21, anxiety")
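+
+    # Hypothetical usage (script path, checkpoint name and label value are placeholders):
+    #   python <this_script>.py --mode train
+    #   python <this_script>.py --mode inference --ckpt_name <ckpt> --decoding nucleus --style_label <1|2|3>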
+
+
+    args = parser.parse_args()
+    
+    assert args.mode in ['train', 'inference'], "Please specify a correct mode name, 'train' or 'inference'."
+              
+    if args.mode == 'train':
+        manager = GST_model(args.mode, ckpt_name=args.ckpt_name)
+
+        manager.train()
+        
+    elif args.mode == 'inference':
+        #assert args.ckpt_name is not None, "Please specify the trained model checkpoint."
+        
+        manager = GST_model(args.mode, ckpt_name=args.ckpt_name, decoding=args.decoding, style_label=args.style_label)
+        
+        manager.inference()
\ No newline at end of file
diff --git a/models/style_tokens_model/styleGPT2_model.py b/models/style_tokens_model/styleGPT2_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b451a4eaee1874738f8ec9cd5a8854fb29f47548
--- /dev/null
+++ b/models/style_tokens_model/styleGPT2_model.py
@@ -0,0 +1,342 @@
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+from transformers import GPT2Model, GPT2LMHeadModel
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+)
+from modules import *
+
+
+class CustomGPT2Model(GPT2Model):
+    '''
+    Modified GPT2Model that takes style_tokens as input
+    https://huggingface.co/transformers/_modules/transformers/models/gpt2/modeling_gpt2.html#GPT2Model
+    '''
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.embed_dim = config.hidden_size
+
+        self.wte = nn.Embedding(config.vocab_size, self.embed_dim).double()
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim).double()
+
+        self.drop = nn.Dropout(config.embd_pdrop).double()
+        #self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.num_hidden_layers)]).double()
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon).double()
+
+        self.init_weights()
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        style_tokens=None
+    ):
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+            batch_size = input_ids.shape[0]
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            batch_size = inputs_embeds.shape[0]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+        if position_ids is not None:
+            position_ids = position_ids.view(-1, input_shape[-1])
+
+        if past_key_values is None:
+            past_length = 0
+            past_key_values = tuple([None] * len(self.h))
+        else:
+            past_length = past_key_values[0][0].size(-2)
+        if position_ids is None:
+            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+        # GPT2Attention mask.
+        if attention_mask is not None:
+            assert batch_size > 0, "batch_size has to be defined and > 0"
+            attention_mask = attention_mask.view(batch_size, -1)
+            # We create a 3D attention mask from a 2D tensor mask.
+            # Sizes are [batch_size, 1, 1, to_seq_length]
+            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+            # this attention mask is more simple than the triangular masking of causal attention
+            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+            attention_mask = attention_mask[:, None, None, :]
+
+            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+            # masked positions, this operation will create a tensor which is 0.0 for
+            # positions we want to attend and -10000.0 for masked positions.
+            # Since we are adding it to the raw scores before the softmax, this is
+            # effectively the same as removing these entirely.
+            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * -10000.0
+
+        # If a 2D or 3D attention mask is provided for the cross-attention,
+        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.add_cross_attention and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # head_mask has shape n_layer x batch x n_heads x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+
+        # Add style tokens to hidden states
+        if style_tokens is None:
+            raise ValueError("You have to specify style_tokens")
+        else:
+            if style_tokens.shape[1] != self.embed_dim:
+                raise ValueError("Problem with the dimensions of style_tokens")
+        
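+        # Broadcast the per-conversation style embedding over the sequence dimension and
+        # add it to the token and position embeddings (style_tokens is expected as [batch, n_embd]).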
+        dim0, dim1 = style_tokens.size()
+        style_tokens = style_tokens.unsqueeze(1).expand(dim0, inputs_embeds.shape[1], dim1)
+        hidden_states = inputs_embeds + position_embeds + style_tokens
+
+        if token_type_ids is not None:
+            token_type_embeds = self.wte(token_type_ids)
+            hidden_states = hidden_states + token_type_embeds
+
+        hidden_states = self.drop(hidden_states)
+        hidden_states = hidden_states.double()
+
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        presents = () if use_cache else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+        all_hidden_states = () if output_hidden_states else None
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+
+            # Model parallel
+            if self.model_parallel:
+                torch.cuda.set_device(hidden_states.device)
+                # Ensure layer_past is on same device as hidden_states (might not be correct)
+                if layer_past is not None:
+                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+                # Ensure that attention_mask is always on the same device as hidden_states
+                if attention_mask is not None:
+                    attention_mask = attention_mask.to(hidden_states.device)
+                if isinstance(head_mask, torch.Tensor):
+                    head_mask = head_mask.to(hidden_states.device)
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, use_cache, output_attentions)
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    None,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                outputs = block(
+                    hidden_states.double(),
+                    layer_past=layer_past,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask[i],
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = outputs[0]
+            if use_cache is True:
+                presents = presents + (outputs[1],)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+
+            # Model Parallel: If it's the last layer for that device, put things on the next device
+            if self.model_parallel:
+                for k, v in self.device_map.items():
+                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
+                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+        hidden_states = self.ln_f(hidden_states)
+
+        hidden_states = hidden_states.view(*output_shape)
+        # Add last hidden state
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+class CustomGPT2LMHeadModel(GPT2LMHeadModel):
+    '''
+    Modified GPT2LMHeadModel that takes style_tokens as input
+    https://huggingface.co/transformers/_modules/transformers/models/gpt2/modeling_gpt2.html#GPT2LMHeadModel
+    '''
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.transformer = CustomGPT2Model(config).double()
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False).double()
+
+        self.init_weights()
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        style_tokens=None
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
+            ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            style_tokens=style_tokens
+        )
+        hidden_states = transformer_outputs[0] 
+
+        # Set device for model parallelism
+        if self.model_parallel:
+            torch.cuda.set_device(self.transformer.first_device)
+            hidden_states = hidden_states.to(self.lm_head.weight.device)
+
+        lm_logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+            cross_attentions=transformer_outputs.cross_attentions,
+        )
+
+
+class StyleGPT2(nn.Module):
+    '''
+    The StyleGPT2 model is composed of 2 modules:
+    - a style token layer 
+    - a CustomGPT2LMHeadModel where we can take style tokens as inputs
+    '''
+
+    def __init__(self, gpt2, gst):
+        super(StyleGPT2, self).__init__()
+        self.transformer = gpt2
+        self.gst = gst
+
+    def forward(self, input_ids, token_type_ids, style_tokens, labels=None):
+        style_embed = self.gst(style_tokens)
+        x = self.transformer(input_ids=input_ids, token_type_ids=token_type_ids, labels=labels, style_tokens=style_embed)
+        return x
+
+
+
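+# --- Hypothetical smoke test (not part of the training pipeline) ---
+# A minimal sketch to check that tensor shapes flow end to end through StyleGPT2.
+# The tiny config, the random ids and the nn.Linear standing in for the trained
+# GST module are assumptions for illustration only.
+if __name__ == "__main__":
+    from transformers import GPT2Config
+
+    config = GPT2Config(vocab_size=100, n_positions=64, n_embd=32, n_layer=2, n_head=2)
+    gpt2 = CustomGPT2LMHeadModel(config)
+    gst_stub = nn.Linear(32, 32, bias=False).double()  # stand-in for the GST module
+    model = StyleGPT2(gpt2, gst_stub)
+
+    input_ids = torch.randint(0, 100, (1, 10))
+    token_type_ids = torch.zeros_like(input_ids)
+    conv_embed = torch.randn(1, 32).double()  # dummy conversation embedding
+    out = model(input_ids, token_type_ids, conv_embed, labels=input_ids)
+    print(out.loss.item(), out.logits.shape)  # scalar LM loss, (1, 10, 100)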
diff --git a/models/style_tokens_model/style_token_layer.py b/models/style_tokens_model/style_token_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e53ced45f4dbab2a6c6ccca4f514aadaff416cb
--- /dev/null
+++ b/models/style_tokens_model/style_token_layer.py
@@ -0,0 +1,102 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+import torch.nn.functional as F
+import os
+import warnings
+warnings.filterwarnings('ignore')
+
+os.chdir('/data-imperial')
+import sys
+sys.path.append(os.path.abspath('/data-imperial'))
+
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+class GST(nn.Module):
+    '''
+    GST : Global Style Token 
+
+    This module is used to compute the style embedding using the style token layer (STL)
+    inputs : represent a conversation embedding
+    '''
+    def __init__(self):
+        super().__init__()
+        self.stl = STL()
+
+    def forward(self, inputs):
+        enc_out = inputs.to(device)
+        if len(enc_out.size()) == 1:
+            enc_out = enc_out.unsqueeze(0)
+        style_embed = self.stl(enc_out)
+
+        return style_embed
+
+
+class STL(nn.Module):
+    '''
+    STL: Style Token Layer
+    
+    The style token layer uses attention to compute a similarity measure between the inputs
+    which represent conversation embeddings and the set of style tokens.
+    '''
+
+    def __init__(self):
+
+        super().__init__()
+
+        # set of 3 style tokens
+        self.embed = torch.load(f"models/style_tokens_model/style_tokens.torch").to(device)
+        self.E = self.embed.size(1)
+        d_q = self.E 
+        d_k = self.E
+        num_units = self.E
+        self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=num_units, num_heads=1)
+
+    def forward(self, inputs):
+        N = inputs.size(0)
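+        # The conversation embedding acts as the attention query; the tanh-squashed style
+        # tokens act as keys (and values), so the output is a weighted mix of style tokens.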
+        query = inputs.unsqueeze(1)  # [N, 1, E]
+        keys = torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1)  # [N, token_num, E]
+
+        style_embed = self.attention(query, keys)
+
+        return style_embed
+
+
+class MultiHeadAttention(nn.Module):
+    '''
+    This is the attention module we use in the style token layer to compute a similarity measure
+    between the conversation embedding and the set of style tokens. 
+
+    We use Query, Keys and Values to compute a similarity score.
+    '''
+
+    def __init__(self, query_dim, key_dim, num_units, num_heads):
+
+        super().__init__()
+        self.num_units = num_units
+        self.num_heads = num_heads
+        self.key_dim = key_dim
+
+        self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False).double().to(device)
+        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False).double().to(device)
+        self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False).double().to(device)
+
+    def forward(self, query, key):
+        querys = self.W_query(query)  # [N, T_q, num_units]
+        keys = self.W_key(key)  # [N, T_k, num_units]
+        values = self.W_value(key)
+
+        split_size = self.num_units // self.num_heads
+        querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0)  # [h, N, T_q, num_units/h]
+        keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
+        values = torch.stack(torch.split(values, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
+
+        scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
+        scores = scores / (self.key_dim ** 0.5)
+        scores = F.softmax(scores, dim=3)
+
+        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
+        out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]
+
+        return out.squeeze(1)  # [N, num_units]
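+
+
+# Hypothetical shape check (E and the number of style tokens are assumed here,
+# not read from the trained style_tokens.torch bank): one conversation embedding
+# attended over three random style tokens.
+if __name__ == "__main__":
+    E = 768
+    attn = MultiHeadAttention(query_dim=E, key_dim=E, num_units=E, num_heads=1)
+    conv_embed = torch.randn(1, 1, E, dtype=torch.double).to(device)    # [N=1, T_q=1, E]
+    style_tokens = torch.randn(1, 3, E, dtype=torch.double).to(device)  # [N=1, token_num=3, E]
+    style_embed = attn(conv_embed, style_tokens)
+    print(style_embed.shape)  # expected: torch.Size([1, 768])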
\ No newline at end of file
diff --git a/models/style_tokens_model/style_tokens.torch b/models/style_tokens_model/style_tokens.torch
new file mode 100644
index 0000000000000000000000000000000000000000..e4d36971c060daf98be43e379aba56da5d6e23a7
Binary files /dev/null and b/models/style_tokens_model/style_tokens.torch differ