Commit 6ac7c519 authored by Joel Oksanen

RelBertNet now checks which pairs of entities are valid after NER before continuing to RC while training.
parent 1c13559e
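In short: instead of scoring every NER entity pair in RC and masking unlabelled pairs out of the loss with ignore_index=-1, the forward pass now keeps only the pairs for which the training instance actually defines a relation. A minimal sketch of that filtering, where valid_pairs is a hypothetical helper and instance exposes the same relation(range1, range2) lookup used in the diff below (returning a label index or None):

import itertools

def valid_pairs(entities, instance):
    # entities: token slices produced by NER for one instance
    pairs, labels = [], []
    for s1, s2 in itertools.combinations(entities, 2):
        label = instance.relation((s1.start - 1, s1.stop - 1),
                                  (s2.start - 1, s2.stop - 1))
        if label is not None:  # keep only pairs the training data labels
            pairs.append((s1, s2))
            labels.append(label)
    return pairs, labels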
@@ -7,35 +7,20 @@ import numpy as np
 from sklearn import metrics
 import time
 from rel_dataset import RelInstance, RelDataset, generate_batch, tokenizer
-from relbertnet import RelBertNet, NUM_RELS
+from relbertnet import RelBertNet, NUM_RELS, BATCH_SIZE

 train_data_path = 'data/location_contains_train_set.tsv'
 trained_model_path = 'bert_extractor.pt'
-BATCH_SIZE = 2
-MAX_EPOCHS = 6
+MAX_EPOCHS = 4
 LEARNING_RATE = 0.00002


-def loss(ner_loss: torch.Tensor, rc_output, n_combinations_by_instance, entity_ranges, instances):
+def loss(ner_loss, rc_output, target_relation_labels):
     if rc_output is None:
         return ner_loss
-    correct_labels = torch.zeros(len(rc_output), dtype=torch.long)
-    idx_ins = 0
-    n = n_combinations_by_instance[idx_ins]
-    for em_idx, (er1, er2) in enumerate(entity_ranges):
-        while n == 0:
-            idx_ins += 1
-            n = n_combinations_by_instance[idx_ins]
-        relation_label = instances[idx_ins].relation(er1, er2)
-        correct_labels[em_idx] = relation_label if relation_label is not None else -1
-        n -= 1
-    return sum([ner_loss, cross_entropy(rc_output, correct_labels, ignore_index=-1)])
+    else:
+        return sum([ner_loss, cross_entropy(rc_output, target_relation_labels)])


 class BertExtractor:
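With the filtering moved into the network, the loss collapses to the NER loss plus a plain cross-entropy over the kept pairs; no ignore_index is needed because every surviving pair now has a target label. A quick shape sanity check with made-up values (not from the training set):

import torch
from torch.nn.functional import cross_entropy

ner_loss = torch.tensor(0.7)                         # stand-in NER loss
rc_output = torch.randn(4, 3)                        # 4 kept pairs x NUM_RELS logits
target_relation_labels = torch.tensor([0, 2, 1, 0])  # one label per kept pair

total = sum([ner_loss, cross_entropy(rc_output, target_relation_labels)])
print(total.shape)  # torch.Size([]), a scalar ready for backward()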
@@ -53,7 +38,7 @@ class BertExtractor:
     def train(self, data_file):
         train_data = RelDataset.from_file(data_file)
-        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=1,
+        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
                                   collate_fn=generate_batch)

         self.net = RelBertNet()
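For context, generate_batch is the project's custom collate_fn, so the DataLoader hands the training loop pre-assembled batches, and raising num_workers from 1 to 4 parallelises that work across loader processes. A self-contained toy illustration of the collate_fn contract (data and toy_collate here are hypothetical, not the project's):

from torch.utils.data import DataLoader

data = [('sentence one', 0), ('sentence two', 1), ('sentence three', 0)]

def toy_collate(samples):
    # receives the raw list of samples for one batch and may return any structure
    texts, labels = zip(*samples)
    return list(texts), list(labels)

loader = DataLoader(data, batch_size=2, shuffle=True, num_workers=0, collate_fn=toy_collate)
texts, labels = next(iter(loader))  # e.g. ['sentence two', 'sentence one'], [1, 0]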
@@ -68,10 +53,10 @@ class BertExtractor:
                 optimiser.zero_grad()

                 # forward pass
-                _, ner_loss, rc_output, n_combinations_by_instance, entity_ranges = self.net(batch, true_ner_tags)
+                _, ner_loss, rc_output, target_relation_labels = self.net(batch, instances, true_ner_tags)

                 # backward pass
-                l = loss(ner_loss, rc_output, n_combinations_by_instance, entity_ranges, instances)
+                l = loss(ner_loss, rc_output, target_relation_labels)
                 l.backward()

                 # optimise
@@ -79,8 +64,10 @@ class BertExtractor:
-                # print interim stats every 10 batches
+                # print interim stats every 5 batches
                 batch_loss += l.item()
-                if idx % 10 == 9:
-                    print('epoch:', epoch + 1, '-- batch:', idx + 1, '-- avg loss:', batch_loss / 10)
+                if idx % 5 == 4:
+                    batch_no = idx + 1
+                    print('epoch:', epoch + 1, '-- progress: {:.4f}'.format(batch_no / len(train_loader)),
+                          '-- batch:', batch_no, '-- avg loss:', batch_loss / 5)
                     batch_loss = 0.0

         end = time.time()
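The new log line adds fractional progress through the epoch: len(train_loader) is the number of batches, so batch_no / len(train_loader) climbs from near 0 to 1.0. For example, with illustrative numbers:

n_batches = 200  # len(train_loader)
batch_no = 45    # idx + 1
print('progress: {:.4f}'.format(batch_no / n_batches))  # prints: progress: 0.2250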
...

@@ -87,8 +87,6 @@ class RelDataset(Dataset):
         else:
             return token_idx - first_token_idx

     @staticmethod
     def join_tokens(tokens):
         joined_tokens = []
...
@@ -13,7 +13,8 @@ TRAINED_WEIGHTS = 'bert-base-cased'  # cased works better for NER
 NUM_NE_TAGS = 5  # BIEOS 0-4: [Begin, Inside, End, Outside, Single]
 NUM_RELS = 3  # 0-2: [no relation, e1 featureOf e2, e2 featureOf e1]
 MLP_HIDDEN_LAYER_NODES = 84
-MAX_ENTITIES_PER_SENTENCE = 8
+# MAX_ENTITIES_PER_SENTENCE = 8
+BATCH_SIZE = 32

 # Based on Xue et al. (2019) with some modifications
@@ -39,7 +40,7 @@ class RelBertNet(nn.Module):
         self.mlp1 = nn.Linear(HIDDEN_OUTPUT_FEATURES, MLP_HIDDEN_LAYER_NODES)
         self.mlp2 = nn.Linear(MLP_HIDDEN_LAYER_NODES, NUM_RELS)

-    def forward(self, encoded_text, ner_tags=None):
+    def forward(self, encoded_text, instances=None, ner_tags=None):
         attn_mask = encoded_text['attention_mask']

         # BERT1 with MASKall for context
@@ -61,22 +62,31 @@ class RelBertNet(nn.Module):
         # obtain pairs of entities
         entities_by_instance = [RelBertNet.bieos_to_entities(tags) for tags in ner_output]
-        combinations_by_instance = [list(itertools.combinations(ent, 2)) for ent in entities_by_instance]
-        n_combinations_by_instance = torch.tensor([len(combs) for combs in combinations_by_instance])
-        flat_combinations = [comb for combs in combinations_by_instance for comb in combs]
+        entity_combinations = []
+        target_relation_labels = []
+        n_combinations_by_instance = torch.empty(len(entities_by_instance), dtype=torch.long)
+        for idx_ins, entities in enumerate(entities_by_instance):
+            n = 0
+            for slice1, slice2 in itertools.combinations(entities, 2):
+                relation_label = instances[idx_ins].relation((slice1.start - 1, slice1.stop - 1),
+                                                             (slice2.start - 1, slice2.stop - 1))
+                if relation_label is not None:
+                    entity_combinations.append((slice1, slice2))
+                    target_relation_labels.append(relation_label)
+                    n += 1
+            n_combinations_by_instance[idx_ins] = n

         # if no entity pairs, cannot find relations so return
-        if not any(n > 2 for n in n_combinations_by_instance):
-            return ner_output, ner_loss, None, None, None
+        if sum(n_combinations_by_instance) == 0:
+            return ner_output, ner_loss, None, None

         # for each pair of named entities recognized, perform BERT2 with MASKrc for RC
-        rc_attn_mask = torch.zeros((len(flat_combinations), MAX_SEQ_LEN), dtype=torch.long)
-        entity_ranges = torch.zeros((len(flat_combinations), 2, 2), dtype=torch.long)
-        for i, (slice1, slice2) in enumerate(flat_combinations):
+        rc_attn_mask = torch.zeros((len(entity_combinations), MAX_SEQ_LEN), dtype=torch.long)
+        for i, (slice1, slice2) in enumerate(entity_combinations):
             rc_attn_mask[i][0] = 1
             rc_attn_mask[i][slice1] = 1
             rc_attn_mask[i][slice2] = 1
-            entity_ranges[i] = torch.tensor([[slice1.start, slice1.stop], [slice2.start, slice2.stop]]) - 1

         bert_rc_output = torch.repeat_interleave(bert_context_output, n_combinations_by_instance, dim=0)
         extended_rc_attn_mask = rc_attn_mask[:, None, None, :]
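repeat_interleave is what keeps BERT1's context aligned with the filtered pairs: each instance's context encoding is duplicated once per surviving pair, so row i of bert_rc_output lines up with row i of rc_attn_mask. A toy check:

import torch

context = torch.tensor([[1.0], [2.0], [3.0]])  # one context row per instance
n_pairs = torch.tensor([2, 0, 1])              # kept pairs per instance
print(torch.repeat_interleave(context, n_pairs, dim=0))
# tensor([[1.], [1.], [3.]]): the middle instance contributes no rows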
@@ -89,7 +99,7 @@ class RelBertNet(nn.Module):
         rc_output = self.mlp2(rc_hidden_layer_output)  # softmax activation
 
         # Return NER and RC outputs
-        return ner_output, ner_loss, rc_output, n_combinations_by_instance, entity_ranges
+        return ner_output, ner_loss, rc_output, torch.tensor(target_relation_labels, dtype=torch.long)

     @staticmethod
     def bieos_to_entities(tags):
@@ -108,4 +118,4 @@ class RelBertNet(nn.Module):
                 b = None

         # take at most MAX_ENTITIES per instance in order not to overwhelm RC at the start when NER is warming up
-        return entities if len(entities) < MAX_ENTITIES_PER_SENTENCE else sample(entities, MAX_ENTITIES_PER_SENTENCE)
+        return entities  # if len(entities) < MAX_ENTITIES_PER_SENTENCE else sample(entities, MAX_ENTITIES_PER_SENTENCE)
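For reference, bieos_to_entities decodes per-token BIEOS tags into one slice per recognised entity; going by the NUM_NE_TAGS comment above, the ids are 0=Begin, 1=Inside, 2=End, 3=Outside, 4=Single. A minimal sketch of such a decoder (it ignores the [CLS] offset that forward() compensates for with its -1 adjustments):

def bieos_to_entities_sketch(tags):
    entities, b = [], None
    for i, t in enumerate(tags):
        if t == 0:                       # Begin: open a span
            b = i
        elif t == 2 and b is not None:   # End: close the open span
            entities.append(slice(b, i + 1))
            b = None
        elif t == 4:                     # Single: one-token entity
            entities.append(slice(i, i + 1))
            b = None
        elif t == 3:                     # Outside: discard a dangling Begin
            b = None
    return entities

print(bieos_to_entities_sketch([0, 1, 2, 3, 4]))  # [slice(0, 3, None), slice(4, 5, None)]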