Commit 66460561 authored by Se Park

Latest working version

parent d80c5a59
@@ -73,5 +73,6 @@ class LoadData(Dataset):
tokens_ids = torch.tensor(tokens_ids, dtype=torch.long)
segment_ids = torch.tensor(segment_ids, dtype=torch.long)
attn_mask = torch.tensor(attn_mask, dtype=torch.long)
score = torch.tensor(score, dtype=torch.float)
return tokens_ids, attn_mask, segment_ids, score
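
The lines above are the tail of LoadData.__getitem__: however the pair is tokenised earlier in the method, the item is returned as three long tensors plus a float score. A minimal sketch of how such an item could be produced with the transformers tokenizer is shown below; the checkpoint name, the encode_plus call and the encode_pair helper are illustrative assumptions, and only the returned tensor dtypes and order come from the hunk.

import torch
from transformers import BertTokenizer

# Hypothetical helper mirroring the tail of LoadData.__getitem__ shown above.
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')  # checkpoint name is an assumption

def encode_pair(src, mt, score, maxlen):
    enc = tokenizer.encode_plus(src, mt,
                                max_length=maxlen,
                                padding='max_length',
                                truncation=True)
    tokens_ids = torch.tensor(enc['input_ids'], dtype=torch.long)
    segment_ids = torch.tensor(enc['token_type_ids'], dtype=torch.long)
    attn_mask = torch.tensor(enc['attention_mask'], dtype=torch.long)
    score = torch.tensor(score, dtype=torch.float)
    return tokens_ids, attn_mask, segment_ids, score
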
@@ -21,7 +21,7 @@ def evaluate(model, loss_fn, dataloader, device):
for token_ids, segment_ids, attn_masks, labels in dataloader:
    token_ids, segment_ids, attn_masks, labels = token_ids.to(device), segment_ids.to(device), attn_masks.to(device), labels.to(device)
    qe_scores = model(token_ids, segment_ids, attn_masks)
    loss = loss_fn(qe_scores.view(-1), labels.float())
    loss = loss_fn(qe_scores.view(-1), labels.view(-1))
    qe_scores = qe_scores.detach().cpu().numpy()
    qe_scores = qe_scores.reshape((qe_scores.shape[0],))
@@ -29,7 +29,9 @@ def evaluate(model, loss_fn, dataloader, device):
pred = np.concatenate((pred, qe_scores))
ref = np.concatenate((ref, labels))
print (f'pred: {pred}')
print (f'ref: {ref}')
eval_loss += loss.item()
count += 1
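
Taken together, the two evaluate hunks amount to a loop like the sketch below. The torch.no_grad() context, the running-loss bookkeeping and the Pearson correlation at the end are assumptions (the diff only shows pred and ref being accumulated and the loss counter being updated). The batch unpacking order follows this loop, which differs from the tuple order returned by the LoadData hunk above (tokens_ids, attn_mask, segment_ids, score).

import numpy as np
import torch

def evaluate(model, loss_fn, dataloader, device):
    model.eval()
    eval_loss, count = 0.0, 0
    pred, ref = np.array([]), np.array([])
    with torch.no_grad():
        for token_ids, segment_ids, attn_masks, labels in dataloader:
            token_ids, segment_ids, attn_masks, labels = (
                token_ids.to(device), segment_ids.to(device),
                attn_masks.to(device), labels.to(device))
            qe_scores = model(token_ids, segment_ids, attn_masks)
            loss = loss_fn(qe_scores.view(-1), labels.view(-1))
            # Collect flat numpy copies of the predictions and gold scores
            pred = np.concatenate((pred, qe_scores.detach().cpu().numpy().reshape(-1)))
            ref = np.concatenate((ref, labels.detach().cpu().numpy().reshape(-1)))
            eval_loss += loss.item()
            count += 1
    # Assumed summary metrics: mean loss plus Pearson's r between predictions and references
    pearson = np.corrcoef(pred, ref)[0, 1]
    return eval_loss / count, pearson
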
@@ -52,7 +54,7 @@ def train(model, loss_fn, optimizer, train_loader, val_loader, num_epoch, device
# Obtaining scores from the model
qe_scores = model(token_ids, segment_ids, attn_masks)
# Computing loss
loss = loss_fn(qe_scores.view(-1), labels.float())
loss = loss_fn(qe_scores.view(-1), labels.view(-1))
# Backpropagating the gradients
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
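
The training hunk applies the same change to the loss target. One inner optimisation step then plausibly reads as follows; the optimizer.zero_grad() and optimizer.step() calls are assumed, as they fall outside the lines shown.

for token_ids, segment_ids, attn_masks, labels in train_loader:
    token_ids, segment_ids, attn_masks, labels = (
        token_ids.to(device), segment_ids.to(device),
        attn_masks.to(device), labels.to(device))
    optimizer.zero_grad()
    # Obtaining scores from the model
    qe_scores = model(token_ids, segment_ids, attn_masks)
    # Computing loss against the gold quality scores
    loss = loss_fn(qe_scores.view(-1), labels.view(-1))
    # Backpropagating the gradients, with clipping to stabilise BERT fine-tuning
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
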
@@ -74,7 +76,6 @@ if __name__ == "__main__":
PATH = Path("/vol/bitbucket/shp2918/nlp")
use_cuda = torch.cuda.is_available()
# use_cuda = False
device = torch.device('cuda' if use_cuda else 'cpu')
print("Using GPU: {}".format(use_cuda))
@@ -83,7 +84,7 @@ if __name__ == "__main__":
model.cuda()
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=5e-5)
optimizer = optim.AdamW(model.parameters(), lr=2e-5)
MAX_LEN = 64
train_set = LoadData(src_file=PATH/'data/train.ende.src', mt_file=PATH/'data/train.ende.mt', score_file=PATH/'data/train.ende.scores', maxlen=MAX_LEN)
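
For context, the __main__ setup after this commit plausibly looks like the sketch below. The paths, the MSE loss, the AdamW optimiser at lr=2e-5 and MAX_LEN come from the hunks; the QualityEstimation constructor arguments, the DataLoader batch size and the import locations of LoadData and QualityEstimation are assumptions.

from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# LoadData and QualityEstimation are assumed importable from the repo's own modules.

if __name__ == "__main__":
    PATH = Path("/vol/bitbucket/shp2918/nlp")
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print("Using GPU: {}".format(use_cuda))

    model = QualityEstimation(hidden_dim=768)   # constructor arguments are an assumption
    if use_cuda:
        model.cuda()
    loss_fn = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=2e-5)

    MAX_LEN = 64
    train_set = LoadData(src_file=PATH/'data/train.ende.src',
                         mt_file=PATH/'data/train.ende.mt',
                         score_file=PATH/'data/train.ende.scores',
                         maxlen=MAX_LEN)
    train_loader = DataLoader(train_set, batch_size=16, shuffle=True)   # batch size is an assumption
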
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertConfig
class QualityEstimation(nn.Module):
@@ -17,21 +18,23 @@ class QualityEstimation(nn.Module):
self.lstm = nn.LSTM(input_size=768, hidden_size=self.hidden_dim,
                    num_layers=1, batch_first=True,
                    dropout=0, bidirectional=False)
self.fc1 = nn.Linear(self.hidden_dim, self.hidden_dim)
self.fc2 = nn.Linear(self.hidden_dim, 1)
self.fc1 = nn.Linear(self.hidden_dim, 1)
nn.init.kaiming_normal_(self.fc1.weight)
# self.fc2 = nn.Linear(self.hidden_dim, 1)
# nn.init.kaiming_normal_(self.fc2.weight)
def forward(self, token_ids, segment_ids=None, attention_mask=None):
    # Feeding the input to BERT model to obtain contextualized representations
    flat_token_ids = token_ids.view(-1, token_ids.size(-1))
    flat_segment_ids = segment_ids.view(-1, segment_ids.size(-1))
    flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
    # flat_token_ids = token_ids.view(-1, token_ids.size(-1))
    # flat_segment_ids = segment_ids.view(-1, segment_ids.size(-1))
    # flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
    encoded_layers, _ = self.bert(flat_token_ids, flat_segment_ids, flat_attention_mask)
    # encoded_layers, _ = self.bert(input_ids=token_ids, token_type_ids=segment_ids, attention_mask=attention_mask)
    # encoded_layers, _ = self.bert(flat_token_ids, flat_segment_ids, flat_attention_mask)
    encoded_layers, _ = self.bert(input_ids=token_ids, token_type_ids=segment_ids, attention_mask=attention_mask)
    encoded_layers = self.dropout(encoded_layers)
    output, _ = self.lstm(encoded_layers)
    # output = torch.tanh(self.fc1(output[:,-1,:]))
    qe_scores = self.fc1(output[:,-1,:])
    return qe_scores
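
Putting the model hunks together, the module after this commit reduces to roughly the sketch below: BERT encodes the source/MT pair, a unidirectional LSTM runs over the token representations, and a single linear layer regresses the quality score directly (the earlier tanh/sigmoid head is dropped). The constructor arguments, dropout value and BERT checkpoint name are assumptions, and outputs[0] is used so the sketch works whether BertModel returns a tuple or a ModelOutput.

import torch
import torch.nn as nn
from transformers import BertModel

class QualityEstimation(nn.Module):
    # BERT encoder -> dropout -> LSTM -> linear regression head.

    def __init__(self, hidden_dim=768, dropout=0.1):          # arguments are an assumption
        super().__init__()
        self.hidden_dim = hidden_dim
        self.bert = BertModel.from_pretrained('bert-base-multilingual-cased')  # checkpoint is an assumption
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(input_size=768, hidden_size=self.hidden_dim,
                            num_layers=1, batch_first=True,
                            dropout=0, bidirectional=False)
        self.fc1 = nn.Linear(self.hidden_dim, 1)
        nn.init.kaiming_normal_(self.fc1.weight)

    def forward(self, token_ids, segment_ids=None, attention_mask=None):
        outputs = self.bert(input_ids=token_ids,
                            token_type_ids=segment_ids,
                            attention_mask=attention_mask)
        encoded_layers = self.dropout(outputs[0])     # (batch, seq_len, 768) token representations
        output, _ = self.lstm(encoded_layers)         # LSTM over the BERT token outputs
        qe_scores = self.fc1(output[:, -1, :])        # unbounded score from the last LSTM state
        return qe_scores
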