Commit efbad354 authored by Joel Oksanen

Fixed some bugs in tag extractor code, testing now.

parent 915dd99a
@@ -19,7 +19,7 @@ EPSILON = 1e-8
 MAX_GRAD_NORM = 1.0

 # training
-N_EPOCHS = 3
+N_EPOCHS = 4
 BATCH_SIZE = 32
 WARM_UP_FRAC = 0.05
 VALID_FRAC = 0.1
@@ -30,10 +30,10 @@ LOSS_BIAS_WEIGHT = 10
 # as defined by Zheng et al. (2017)
 def biased_loss(output_scores, target_tags):
-    ce_losses = torch.empty((output_scores.shape[0], output_scores.shape[1]))
+    ce_losses = torch.empty((output_scores.shape[0], output_scores.shape[1])).to(device)
     for idx in range(ce_losses.shape[0]):
         ce_losses[idx] = cross_entropy(output_scores[idx], target_tags[idx], reduction='none', ignore_index=IGNORE_TAG)
-    bias_matrix = torch.where(target_tags == 0, torch.tensor(1), torch.tensor(10))
+    bias_matrix = torch.where(target_tags == 0, torch.tensor(1).to(device), torch.tensor(10).to(device))
     biased_losses = ce_losses * bias_matrix
     return torch.sum(biased_losses)
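
Note: the two `.to(device)` fixes above keep every tensor in `biased_loss` on the same device as the model outputs. For reference, a minimal device-safe sketch of the same loss, vectorized instead of looping over the batch; the constant values below are assumptions standing in for the ones defined elsewhere in this file:

import torch
from torch.nn.functional import cross_entropy

IGNORE_TAG = -1        # assumed value; defined elsewhere in the repo
LOSS_BIAS_WEIGHT = 10  # matches the constant visible in the hunk header

def biased_loss_sketch(output_scores, target_tags):
    # output_scores: (batch, seq_len, n_tags); target_tags: (batch, seq_len).
    # cross_entropy wants the class dimension second, hence the transpose.
    ce_losses = cross_entropy(output_scores.transpose(1, 2), target_tags,
                              reduction='none', ignore_index=IGNORE_TAG)
    # Weight losses on non-O tags (tag != 0) more heavily, per Zheng et al. (2017).
    # Deriving the weights from target_tags via ones_like/full_like keeps them on
    # the same device automatically, avoiding the CPU/GPU mismatch this commit fixes.
    bias_matrix = torch.where(target_tags == 0,
                              torch.ones_like(ce_losses),
                              torch.full_like(ce_losses, LOSS_BIAS_WEIGHT))
    return torch.sum(ce_losses * bias_matrix)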
@@ -55,9 +55,9 @@ class BertTagExtractor:
         self.net.eval()

     @staticmethod
-    def new_trained_with_file(file_path):
+    def new_trained_with_file(file_path, size=None):
         extractor = BertTagExtractor()
-        extractor.train_with_file(file_path)
+        extractor.train_with_file(file_path, size=size)
         return extractor

     def train_with_file(self, file_path, size=None):
@@ -95,13 +95,13 @@ class BertTagExtractor:
         for batch_idx, batch in enumerate(train_loader):
             # send batch to gpu
-            texts, target_tags = tuple(i.to(device) for i in batch)
+            input_ids, attn_mask, target_tags = tuple(i.to(device) for i in batch)

             # zero param gradients
             self.net.zero_grad()

             # forward pass
-            output_scores = self.net(**texts)[0]
+            output_scores = self.net(input_ids=input_ids, attention_mask=attn_mask)[0]

             # backward pass
             l = biased_loss(output_scores, target_tags)
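
Note: passing `input_ids` and `attention_mask` by keyword matches the batch now being a plain tuple of tensors; the old `**texts` unpacking only worked while the batch was a dict. A self-contained sketch of the resulting forward pass; the model name and label count are illustrative assumptions, not taken from this repo:

import torch
from transformers import BertForTokenClassification

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=10).to(device)

# Dummy batch shaped like the (input_ids, attn_mask, target_tags) tuples from the loader.
input_ids = torch.randint(0, 28996, (2, 16)).to(device)   # 28996 = bert-base-cased vocab size
attn_mask = torch.ones(2, 16, dtype=torch.long).to(device)

net.zero_grad()
output_scores = net(input_ids=input_ids, attention_mask=attn_mask)[0]
print(output_scores.shape)  # (2, 16, 10): per-token tag logits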
@@ -132,8 +132,7 @@ class BertTagExtractor:
         torch.save(self.net.state_dict(), trained_model_path)

+BertTagExtractor.new_trained_with_file(train_data_path, size=200000)
@@ -13,11 +13,13 @@ tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)

 def generate_train_batch(instances):
-    texts = tokenizer.batch_encode_plus([instance.tokens for instance in instances], add_special_tokens=True,
-                                        max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
-                                        return_tensors='pt')
+    encoded = tokenizer.batch_encode_plus([instance.tokens for instance in instances], add_special_tokens=True,
+                                          max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
+                                          return_tensors='pt')
+    input_ids = encoded['input_ids']
+    attn_mask = encoded['attention_mask']
     target_tags = torch.tensor([instance.tags for instance in instances])
-    return texts, target_tags
+    return input_ids, attn_mask, target_tags

 # Based on Zheng et al. (2017)
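
Note: `batch_encode_plus` returns a dict-like encoding; splitting out `input_ids` and `attention_mask` here is what lets the training loop move each tensor to the GPU individually. A small standalone illustration using the same arguments as the repo's call (tokenizer name and token lists are made up for the example, and `pad_to_max_length`/`is_pretokenized` assume the older transformers version this code targets):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
encoded = tokenizer.batch_encode_plus([['great', 'camera'], ['poor', 'battery', 'life']],
                                      add_special_tokens=True, max_length=8,
                                      pad_to_max_length=True, is_pretokenized=True,
                                      return_tensors='pt')
print(encoded['input_ids'].shape)    # torch.Size([2, 8])
print(encoded['attention_mask'][1])  # 1 for real tokens, 0 for padding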
@@ -108,7 +110,7 @@ class TaggedRelDataset(Dataset):
     @staticmethod
     def ne_tags_for_len(n, base):
         assert n > 0
-        return [base + 3] if n == 1 else [base] + [base + 1] * (n - 2) + [base]
+        return [base + 3] if n == 1 else [base] + [base + 1] * (n - 2) + [base + 2]

     def __len__(self):
         return len(self.df.index)
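
Note: the `[base + 2]` fix means multi-token entity spans now end with an E tag rather than a second B tag. Assuming the usual BIES offsets (base = B, base + 1 = I, base + 2 = E, base + 3 = S), the corrected function behaves like this:

def ne_tags_for_len(n, base):
    # BIES offsets assumed: B = base, I = base + 1, E = base + 2, S = base + 3.
    assert n > 0
    return [base + 3] if n == 1 else [base] + [base + 1] * (n - 2) + [base + 2]

print(ne_tags_for_len(1, 1))  # [4]       -> S: single-token entity
print(ne_tags_for_len(3, 1))  # [1, 2, 3] -> B, I, E
print(ne_tags_for_len(2, 1))  # [1, 3]    -> B, E (the old code returned B, B)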