Commit 7f31ce78 authored by Joel Oksanen

Trained and evaluated the first BERT extractor; NER and RC accuracy are 0.772 and 0.658 respectively.

parent 9343c134
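For context, the reported numbers come from the entry point changed at the bottom of the extractor script (see the last hunk of that file below). A minimal usage sketch; the module name bert_extractor is assumed from the saved weights file, and BertExtractor.default() is assumed to load bert_extractor.pt, which this diff does not show:

```python
from bert_extractor import BertExtractor  # module name assumed, not shown in the diff

# One-off training on the NYT-style JSON data (paths as set at the top of the script):
# extractor = BertExtractor()
# extractor.train('data/train.json')

# Evaluate a trained model; .default() is assumed to load 'bert_extractor.pt'
# (the loader itself is outside the lines shown in this diff).
extractor = BertExtractor.default()
extractor.evaluate('data/test.json')
```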
@@ -9,10 +9,11 @@ import time
from rel_dataset import RelInstance, RelDataset, generate_batch, tokenizer
from relbertnet import RelBertNet, NUM_RELS, BATCH_SIZE
train_data_path = 'data/location_contains_train_set.tsv'
train_data_path = 'data/train.json'
test_data_path = 'data/test.json'
trained_model_path = 'bert_extractor.pt'
MAX_EPOCHS = 4
MAX_EPOCHS = 3
LEARNING_RATE = 0.00002
@@ -37,7 +38,7 @@ class BertExtractor:
self.net.eval()
def train(self, data_file):
train_data = RelDataset.from_file(data_file)
train_data = RelDataset.from_file(data_file, n_instances=40000)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
collate_fn=generate_batch)
@@ -62,13 +63,14 @@ class BertExtractor:
# optimise
optimiser.step()
# print interim stats every 10 batches
# print interim stats every 20 batches
batch_loss += l.item()
if idx % 5 == 4:
if idx % 20 == 19:
batch_no = idx + 1
print('epoch:', epoch + 1, '--progress: {:.4f}'.format(batch_no / len(train_loader)),
'-- batch:', batch_no, '-- avg loss:', batch_loss / 5)
'-- batch:', batch_no, '-- avg loss:', batch_loss / 20)
batch_loss = 0.0
print('epoch done')
end = time.time()
print('Training took', end - start, 'seconds')
@@ -86,27 +88,36 @@ class BertExtractor:
true_rels = []
with torch.no_grad():
for batch, b_true_ner_tags, b_true_rels in test_loader:
for idx, (batch, b_true_ner_tags, instances) in enumerate(test_loader):
ner_output, _, rc_output = self.net(batch)
_, rc_max = torch.max(rc_output.data, 1)
ner_output, _, rc_output, target_relation_labels = self.net(batch, instances)
_, predicted_relation_labels = torch.max(rc_output.data, 1)
predicted_ner_tags += ner_output.tolist()
predicted_rels += rc_max.tolist()
true_rels += b_true_rels.tolist()
predicted_ner_tags += ner_output
predicted_rels += predicted_relation_labels.tolist()
true_rels += target_relation_labels.tolist()
true_ner_tags += b_true_ner_tags.tolist()
for pred, truth in [(predicted_ner_tags, true_ner_tags), (predicted_rels, true_rels)]:
correct = (np.array(pred) == np.array(truth))
accuracy = correct.sum() / correct.size
print('accuracy:', accuracy)
# NER
ner_correct = []
for i in range(len(predicted_ner_tags)):
# remove padding from truths and compare
ins_correct = 1 if predicted_ner_tags[i] == true_ner_tags[i][:len(predicted_ner_tags[i])] else 0
ner_correct.append(ins_correct)
ner_accuracy = sum(ner_correct) / len(ner_correct)
print('NER accuracy:', ner_accuracy)
# RC
rc_correct = (np.array(predicted_rels) == np.array(true_rels))
rc_accuracy = rc_correct.sum() / rc_correct.size
print('RC accuracy:', rc_accuracy)
cm = metrics.confusion_matrix(true_rels, predicted_rels, labels=range(NUM_RELS))
print('confusion matrix:')
print('RC confusion matrix:')
print(cm)
f1 = metrics.f1_score(true_rels, predicted_rels, labels=range(NUM_RELS), average='macro')
print('macro F1:', f1)
print('RC macro F1:', f1)
# dataset = RelDataset.from_texts(['A giraffe made friends with a lion at the Zoo.', 'Since I\'m not sure.'])
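The two accuracies reported in the commit message are computed differently in the evaluation code above: NER counts an instance as correct only if the whole predicted tag sequence matches the de-padded ground truth, while RC accuracy is element-wise over all candidate entity pairs. A toy sketch of the distinction (the tag sequences and labels below are made up):

```python
import numpy as np

# Made-up toy outputs, purely to illustrate the two accuracy definitions.
predicted_ner_tags = [[0, 2, 3], [4, 3]]        # unpadded per-instance tag sequences
true_ner_tags = [[0, 2, 3, 3, 3], [4, 3, 3]]    # padded ground-truth sequences
predicted_rels = [1, 0, 2, 0]
true_rels = [1, 0, 1, 0]

# NER: exact sequence match per instance, after trimming padding from the truth.
ner_correct = [int(p == t[:len(p)]) for p, t in zip(predicted_ner_tags, true_ner_tags)]
print(sum(ner_correct) / len(ner_correct))      # 1.0

# RC: element-wise over all relation labels.
rc_correct = np.array(predicted_rels) == np.array(true_rels)
print(rc_correct.sum() / rc_correct.size)       # 0.75
```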
@@ -132,5 +143,5 @@ class BertExtractor:
# print(instances[0].entities_for_tags(b_true_ner_tags[0]))
# print('---')
extractor = BertExtractor()
extractor.train(train_data_path)
extractor = BertExtractor.default()
extractor.evaluate(test_data_path)
@@ -2,8 +2,9 @@ from transformers import BertTokenizer
import torch
from torch.utils.data import Dataset
import pandas as pd
from relbertnet import TRAINED_WEIGHTS, MAX_SEQ_LEN
from relbertnet import TRAINED_WEIGHTS, MAX_SEQ_LEN, MAX_ENTITIES_PER_SENTENCE
from ast import literal_eval
import random
NE_TAGS_LEN = MAX_SEQ_LEN - 2
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)
@@ -18,32 +19,37 @@ def generate_batch(instances):
class RelDataset(Dataset):
MAX_RELATIONS = 8
def __init__(self):
self.data = []
@staticmethod
def from_file(path):
def from_file(path, n_instances=None):
dataset = RelDataset()
data = pd.read_csv(path, sep='\t', error_bad_lines=False)
data = pd.read_json(path, lines=True)
rows = [row for _, row in data.iterrows()]
dataset.data = [x for x in map(RelDataset.instance_from_row, rows) if x is not None]
if n_instances is not None:
random.shuffle(dataset.data)
dataset.data = dataset.data[:n_instances]
print("Obtained dataset of length", len(dataset.data))
return dataset
@staticmethod
def instance_from_row(row):
entities = sorted([em['text'] for em in row['entityMentions']], key=len, reverse=True)
relations = {(e1, e2): label for t in row['relationMentions']
for e1, e2, label in RelDataset.relations_from_tuple(t)}
if len(relations) == 0 or len(relations) > RelDataset.MAX_RELATIONS:
return None # include only texts with at least one and no more than MAX_RELATIONS relations
text = row['sentText']
tokens = tokenizer.tokenize(text)
if len(tokens) > NE_TAGS_LEN:
return None # include only texts that can be represented in 126 tokens, filters out 98 texts from NYT train
entities = sorted([em['text'] for em in literal_eval(row['entityMentions'])], key=len, reverse=True)
data_relations = [(m['em1Text'], m['em2Text']) for m in literal_eval(row['relationMentions'])]
relations = {(entity1, entity2): 1 if (entity1, entity2) in data_relations
else (2 if (entity2, entity1) in data_relations else 0)
for entity1 in entities for entity2 in entities}
return None # include only texts that can be represented in 126 tokens
n_entities = 0
ne_tags = []
i = 0
while i < len(tokens):
@@ -54,11 +60,15 @@ class RelDataset(Dataset):
ne_tags += RelDataset.ne_tags_for_len(match_length)
found = True
i += match_length
n_entities += 1
break
if not found:
ne_tags += [3]
i += 1
if n_entities > MAX_ENTITIES_PER_SENTENCE:
return None # include only texts with at most MAX_ENTITIES_PER_SENTENCE entities (including duplicates)
if len(ne_tags) > NE_TAGS_LEN: # without CLS and SEP tokens
ne_tags = ne_tags[:NE_TAGS_LEN] # trim to length
else:
@@ -66,6 +76,15 @@ class RelDataset(Dataset):
return RelInstance.from_tokens(tokens, ne_tags, relations)
@staticmethod
def relations_from_tuple(t):
if t['label'] == '/location/location/contains':
return [(t['em1Text'], t['em2Text'], 1), (t['em2Text'], t['em1Text'], 2)]
elif t['label'] == 'None':
return [(t['em1Text'], t['em2Text'], 0)]
else:
return []
@staticmethod
def token_entity_match(first_token_idx, entity, tokens):
token_idx = first_token_idx
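The RelDataset changes above replace the old relations_from_tuple mapping with an inline label map over all ordered entity pairs (1 if the pair appears as a relation mention, 2 if the reverse pair does, 0 otherwise). A toy sketch of that labelling, using a made-up row in the NYT-style JSON format the new loader expects:

```python
from ast import literal_eval

# Hypothetical row; field names follow the loader code above, values are invented.
row = {
    'sentText': 'Paris is the capital of France.',
    'entityMentions': "[{'text': 'Paris'}, {'text': 'France'}]",
    'relationMentions': "[{'em1Text': 'France', 'em2Text': 'Paris', 'label': '/location/location/contains'}]",
}

entities = sorted([em['text'] for em in literal_eval(row['entityMentions'])], key=len, reverse=True)
data_relations = [(m['em1Text'], m['em2Text']) for m in literal_eval(row['relationMentions'])]

# Label every ordered entity pair: 1 if (e1, e2) is a mention, 2 if (e2, e1) is, else 0.
relations = {(e1, e2): 1 if (e1, e2) in data_relations
             else (2 if (e2, e1) in data_relations else 0)
             for e1 in entities for e2 in entities}
print(relations)
# {('France', 'France'): 0, ('France', 'Paris'): 1, ('Paris', 'France'): 2, ('Paris', 'Paris'): 0}
```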
@@ -13,8 +13,8 @@ TRAINED_WEIGHTS = 'bert-base-cased' # cased works better for NER
NUM_NE_TAGS = 5 # BIEOS 0-4: [Begin Inside End Outside Single]
NUM_RELS = 3 # 0-2: [no relation, e1 featureOf e2, e2 featureOf e1]
MLP_HIDDEN_LAYER_NODES = 84
# MAX_ENTITIES_PER_SENTENCE = 8
BATCH_SIZE = 32
MAX_ENTITIES_PER_SENTENCE = 8
# Based on Xue et al. (2019) with some modifications
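To make the NUM_NE_TAGS line above concrete, a made-up tagging example under the BIEOS scheme it names (the sentence and entity spans are invented for illustration only):

```python
# BIEOS tag ids as defined above: 0=Begin, 1=Inside, 2=End, 3=Outside, 4=Single.
tokens = ['New', 'York', 'lies', 'in', 'the', 'USA']
ne_tags = [0, 2, 3, 3, 3, 4]   # 'New York' = Begin..End, 'USA' = Single, rest Outside
print(list(zip(tokens, ne_tags)))
```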
@@ -26,7 +26,7 @@ class RelBertNet(nn.Module):
# Load pretrained BERT weights
config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
self.bert1 = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
self.bert1.train()
# self.bert1.train()
# Divide BERT encoder layers into two parts
self.bert2_layers = self.bert1.encoder.layer[-K:]
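The constructor change above sits next to the encoder split ("Divide BERT encoder layers into two parts"). A minimal sketch of what that slice produces, assuming K is a small constant defined elsewhere in relbertnet.py (its value is not visible in this diff):

```python
from transformers import BertConfig, BertModel

TRAINED_WEIGHTS = 'bert-base-cased'
K = 2  # hypothetical value; the real K is defined outside the lines shown here

config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
bert1 = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)

# encoder.layer is an nn.ModuleList of 12 transformer layers for bert-base;
# slicing keeps a handle on the last K of them so they can be run separately.
bert2_layers = bert1.encoder.layer[-K:]
print(len(bert1.encoder.layer), len(bert2_layers))  # 12 2
```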
@@ -65,7 +65,7 @@ class RelBertNet(nn.Module):
entity_combinations = []
target_relation_labels = []
n_combinations_by_instance = torch.empty(BATCH_SIZE, dtype=torch.long)
n_combinations_by_instance = torch.empty(len(entities_by_instance), dtype=torch.long)
for idx_ins, entities in enumerate(entities_by_instance):
n = 0
for slice1, slice2 in list(itertools.combinations(entities, 2)):
@@ -125,5 +125,6 @@ class RelBertNet(nn.Module):
entities.append(slice(idx+1, idx+2)) # +1 comes from CLS token
b = None
# take at most MAX_ENTITIES_PER_SENTENCE entities per instance in order not to overwhelm RC at the start when NER is warming up
return entities # if len(entities) < MAX_ENTITIES_PER_SENTENCE else sample(entities, MAX_ENTITIES_PER_SENTENCE)
# take at most MAX_ENTITIES_PER_SENTENCE entities per instance in order not to overwhelm the system at the start when NER is warming up
# (n_entities choose 2) candidate pairs per sentence otherwise
return entities if len(entities) < MAX_ENTITIES_PER_SENTENCE else sample(entities, MAX_ENTITIES_PER_SENTENCE)
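Finally, the re-enabled entity cap at the end of relbertnet.py bounds how many candidate pairs the relation classifier sees per sentence; without it, RC would have to score C(n, 2) pairs for n extracted entity spans. A small sketch (the spans below are made up):

```python
import itertools
from random import sample

MAX_ENTITIES_PER_SENTENCE = 8  # as set in this commit

# Made-up entity spans (slices into the token sequence, as built in the NER output step).
entities = [slice(i + 1, i + 2) for i in range(12)]

# Without the cap, RC would score C(n, 2) pairs per sentence:
print(len(list(itertools.combinations(entities, 2))))  # 66 for n = 12

# With the cap applied as in the return statement above:
capped = entities if len(entities) < MAX_ENTITIES_PER_SENTENCE else sample(entities, MAX_ENTITIES_PER_SENTENCE)
print(len(list(itertools.combinations(capped, 2))))    # at most C(8, 2) = 28
```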