Commit 1c13559e authored by Joel Oksanen

Fixed bug where RC received very large batches while NER was still warming up

parent d79bbd1a
@@ -12,13 +12,16 @@ from relbertnet import RelBertNet, NUM_RELS
 
 train_data_path = 'data/location_contains_train_set.tsv'
 trained_model_path = 'bert_extractor.pt'
-BATCH_SIZE = 8
+BATCH_SIZE = 2
 MAX_EPOCHS = 6
 LEARNING_RATE = 0.00002
 
 def loss(ner_loss: torch.Tensor, rc_output, n_combinations_by_instance, entity_ranges, instances):
-    correct_labels = torch.zeros(len(rc_output), dtype=torch.uint8)
+    if rc_output is None:
+        return ner_loss
+    correct_labels = torch.zeros(len(rc_output), dtype=torch.long)
     idx_ins = 0
     n = n_combinations_by_instance[idx_ins]
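
Taken on its own, the new guard behaves as below (a minimal sketch, assuming the RC loss is a standard cross-entropy over the pair logits; combined_loss and rc_loss_fn are illustrative names, not from the diff):

import torch
import torch.nn as nn

rc_loss_fn = nn.CrossEntropyLoss()  # requires torch.long targets, hence the uint8 -> long change

def combined_loss(ner_loss, rc_output, correct_labels):
    # While NER is still warming up it may tag no entities, so RC receives
    # no candidate pairs and rc_output is None; fall back to the NER loss alone.
    if rc_output is None:
        return ner_loss
    return ner_loss + rc_loss_fn(rc_output, correct_labels)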
@@ -50,7 +53,7 @@ class BertExtractor:
 
     def train(self, data_file):
         train_data = RelDataset.from_file(data_file)
-        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
+        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=1,
                                   collate_fn=generate_batch)
         self.net = RelBertNet()
@@ -76,7 +79,7 @@ class BertExtractor:
 
                 # print interim stats every 10 batches
                 batch_loss += l.item()
-                if idx % 5 == 4:
+                if idx % 10 == 9:
                     print('epoch:', epoch + 1, '-- batch:', idx + 1, '-- avg loss:', batch_loss / 10)
                     batch_loss = 0.0
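
This aligns the print interval with the divisor: the old code printed every 5 batches yet divided the accumulated loss by 10, halving the reported average. A minimal sketch of the corrected pattern (train_step is a hypothetical stand-in for the per-batch forward/backward step):

batch_loss = 0.0
for idx, batch in enumerate(train_loader):
    l = train_step(batch)  # hypothetical helper returning the batch loss tensor
    batch_loss += l.item()
    if idx % 10 == 9:  # every 10 batches, matching the divisor below
        print('avg loss:', batch_loss / 10)
        batch_loss = 0.0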
relbertnet.py
@@ -4,6 +4,7 @@ import torch.nn.functional as F
 
 from transformers import *
 from torchcrf import CRF
 import itertools
+from random import sample
 
 K = 4  # number of hidden layers in Bert2
 HIDDEN_OUTPUT_FEATURES = 768
@@ -12,6 +13,7 @@ TRAINED_WEIGHTS = 'bert-base-cased'  # cased works better for NER
 
 NUM_NE_TAGS = 5  # BIEOS 0-4: [Begin Inside End Outside Single]
 NUM_RELS = 3  # 0-2: [no relation, e1 featureOf e2, e2 featureOf e1]
 MLP_HIDDEN_LAYER_NODES = 84
+MAX_ENTITIES_PER_SENTENCE = 8
 
 # Based on Xue et al. (2019) with some modifications
@@ -86,8 +88,6 @@ class RelBertNet(nn.Module):
 
         rc_hidden_layer_output = torch.tanh(self.mlp1(rc_cls_output))  # tanh activation
         rc_output = self.mlp2(rc_hidden_layer_output)  # softmax activation
-        # entity_masks = entity_masks.narrow(2, 1, MAX_SEQ_LEN - 2).type(torch.uint8)  # without CLS and SEP tokens
-
         # Return NER and RC outputs
         return ner_output, ner_loss, rc_output, n_combinations_by_instance, entity_ranges
@@ -106,4 +106,6 @@ class RelBertNet(nn.Module):
 
             if tag == 4:  # Single
                 entities.append(slice(idx+1, idx+2))  # +1 comes from CLS token
                 b = None
-        return entities
+        # take at most MAX_ENTITIES_PER_SENTENCE entities per instance so as not to
+        # overwhelm RC at the start of training while NER is still warming up
+        return entities if len(entities) < MAX_ENTITIES_PER_SENTENCE else sample(entities, MAX_ENTITIES_PER_SENTENCE)
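
Capping the entity list bounds the RC workload, because candidate pairs grow quadratically in the number of entities: 8 entities give at most 8 * 7 / 2 = 28 unordered pairs, while a warming-up NER that tags, say, 40 spurious entities would give 780. A self-contained sketch of the same capping logic (cap_entities is an illustrative name):

from random import sample

MAX_ENTITIES_PER_SENTENCE = 8

def cap_entities(entities):
    # Randomly subsample when NER over-generates, so RC is not flooded
    # with candidate pairs early in training.
    if len(entities) < MAX_ENTITIES_PER_SENTENCE:
        return entities
    return sample(entities, MAX_ENTITIES_PER_SENTENCE)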