Commit 1c13559e authored by Joel Oksanen

Fixed bug with very large batch sizes in RC due to NER warming up

parent d79bbd1a
@@ -12,13 +12,16 @@ from relbertnet import RelBertNet, NUM_RELS
 train_data_path = 'data/location_contains_train_set.tsv'
 trained_model_path = 'bert_extractor.pt'
-BATCH_SIZE = 8
+BATCH_SIZE = 2
 MAX_EPOCHS = 6
 LEARNING_RATE = 0.00002

 def loss(ner_loss: torch.Tensor, rc_output, n_combinations_by_instance, entity_ranges, instances):
-    correct_labels = torch.zeros(len(rc_output), dtype=torch.uint8)
+    if rc_output is None:
+        return ner_loss
+    correct_labels = torch.zeros(len(rc_output), dtype=torch.long)
     idx_ins = 0
     n = n_combinations_by_instance[idx_ins]

@@ -50,7 +53,7 @@ class BertExtractor:
     def train(self, data_file):
         train_data = RelDataset.from_file(data_file)
-        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
+        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=1,
                                   collate_fn=generate_batch)
         self.net = RelBertNet()

@@ -76,7 +79,7 @@ class BertExtractor:
             # print interim stats every 10 batches
             batch_loss += l.item()
-            if idx % 5 == 4:
+            if idx % 10 == 9:
                 print('epoch:', epoch + 1, '-- batch:', idx + 1, '-- avg loss:', batch_loss / 10)
                 batch_loss = 0.0
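The guard added to loss() handles batches that yield no RC input at all: presumably, when NER finds fewer than two entities in every instance there are no candidate pairs, RelBertNet returns rc_output as None, and only the NER loss applies. The dtype change matters too, assuming correct_labels is later filled with relation labels and fed to a cross-entropy loss: PyTorch requires int64 class targets, so the uint8 tensor would raise a dtype error. A minimal sketch with made-up shapes:

    import torch
    import torch.nn.functional as F

    rc_output = torch.randn(6, 3)  # hypothetical: 6 candidate pairs, NUM_RELS = 3 logits each

    print(F.cross_entropy(rc_output, torch.zeros(6, dtype=torch.long)))  # fine: targets are int64
    try:
        F.cross_entropy(rc_output, torch.zeros(6, dtype=torch.uint8))
    except RuntimeError as err:
        print(err)  # roughly: "expected scalar type Long but found Byte"

The second half of the commit, in relbertnet.py (the module imported above), bounds how many entities NER may hand over to RC per sentence.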
@@ -4,6 +4,7 @@ import torch.nn.functional as F
 from transformers import *
 from torchcrf import CRF
 import itertools
+from random import sample

 K = 4  # number of hidden layers in Bert2
 HIDDEN_OUTPUT_FEATURES = 768

@@ -12,6 +13,7 @@ TRAINED_WEIGHTS = 'bert-base-cased'  # cased works better for NER
 NUM_NE_TAGS = 5  # BIEOS 0-4: [Begin Inside End Outside Single]
 NUM_RELS = 3  # 0-2: [no relation, e1 featureOf e2, e2 featureOf e1]
 MLP_HIDDEN_LAYER_NODES = 84
+MAX_ENTITIES_PER_SENTENCE = 8

 # Based on Xue et al. (2019) with some modifications

@@ -86,8 +88,6 @@ class RelBertNet(nn.Module):
         rc_hidden_layer_output = torch.tanh(self.mlp1(rc_cls_output))  # tanh activation
         rc_output = self.mlp2(rc_hidden_layer_output)  # softmax activation
-        # entity_masks = entity_masks.narrow(2, 1, MAX_SEQ_LEN - 2).type(torch.uint8)  # without CLS and SEP tokens
-
         # Return NER and RC outputs
         return ner_output, ner_loss, rc_output, n_combinations_by_instance, entity_ranges

@@ -106,4 +106,6 @@ class RelBertNet(nn.Module):
             if tag == 4:  # Single
                 entities.append(slice(idx+1, idx+2))  # +1 comes from CLS token
                 b = None
-        return entities
+        # take at most MAX_ENTITIES_PER_SENTENCE per instance so as not to overwhelm RC while NER is warming up
+        return entities if len(entities) < MAX_ENTITIES_PER_SENTENCE else sample(entities, MAX_ENTITIES_PER_SENTENCE)
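The cap is the actual fix for the bug named in the commit message: before the NER head has learned anything, it can tag most tokens of a sentence as entities, and since RC classifies every pair of extracted entities, the number of pair instances grows quadratically, blowing up the effective RC batch size. A rough sketch of the effect, assuming unordered pairs via itertools.combinations (direction is carried by the relation labels, per NUM_RELS):

    import itertools
    from random import sample

    MAX_ENTITIES_PER_SENTENCE = 8

    def cap(entities):
        # same one-liner as the commit
        return entities if len(entities) < MAX_ENTITIES_PER_SENTENCE else sample(entities, MAX_ENTITIES_PER_SENTENCE)

    for n in (4, 8, 40):  # 40 ~ a warm-up NER head tagging most tokens of a sentence
        n_pairs = sum(1 for _ in itertools.combinations(cap(list(range(n))), 2))
        print(n, 'entities ->', n_pairs, 'pairs')  # 4 -> 6, 8 -> 28, 40 -> 28 (instead of 780)

Sampling with random.sample rather than truncating keeps the retained entities unbiased with respect to sentence position, at the cost of a slightly noisy training signal during warm-up. It is plausibly also why BATCH_SIZE drops from 8 to 2 in the training script: even capped at 8 entities, each sentence can still contribute up to 28 pair instances to the RC pass.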