Commit d79bbd1a authored by Joel Oksanen

Finished loss functionality for bert_extractor. Testing next.

parent a28029dd
bert_extractor.py

import torch
import torch.nn as nn
import torch.optim as optim
+from torch.nn.functional import cross_entropy
from torch.utils.data import DataLoader
import numpy as np
from sklearn import metrics
import time
-from rel_dataset import RelInstance, RelDataset, generate_batch
+from rel_dataset import RelInstance, RelDataset, generate_batch, tokenizer
from relbertnet import RelBertNet, NUM_RELS
-train_set_path = 'data/location_contains_train_set.tsv'
+train_data_path = 'data/location_contains_train_set.tsv'
trained_model_path = 'bert_extractor.pt'
-BATCH_SIZE = 32
+BATCH_SIZE = 8
MAX_EPOCHS = 6
LEARNING_RATE = 0.00002
-loss_criterion = nn.CrossEntropyLoss()
-def loss(ner_output, rc_output, ner_loss, true_rels):
-    return torch.sum(ner_loss)
+def loss(ner_loss: torch.Tensor, rc_output, n_combinations_by_instance, entity_ranges, instances):
+    # look up the gold relation label for each candidate entity pair, walking the
+    # flattened entity_ranges in step with the per-instance combination counts
+    correct_labels = torch.zeros(len(rc_output), dtype=torch.long)  # long, not uint8: cross_entropy targets must be int64, and -1 would wrap in uint8
+    idx_ins = 0
+    n = n_combinations_by_instance[idx_ins]
+    for em_idx, (er1, er2) in enumerate(entity_ranges):
+        while n == 0:  # skip instances that produced no entity pairs
+            idx_ins += 1
+            n = n_combinations_by_instance[idx_ins]
+        relation_label = instances[idx_ins].relation(er1, er2)
+        correct_labels[em_idx] = relation_label if relation_label is not None else -1
+        n -= 1
+    # joint objective: CRF NER loss plus RC cross-entropy, ignoring unlabelled pairs
+    return ner_loss + cross_entropy(rc_output, correct_labels, ignore_index=-1)
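# Illustration of the alignment above: with n_combinations_by_instance = [2, 0, 1],
# the first two rows of rc_output are labelled against instances[0] and the third
# against instances[2]; instances[1] yielded no entity pairs, so the while-loop skips it.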
class BertExtractor:
@@ -46,15 +60,15 @@ class BertExtractor:
        for epoch in range(MAX_EPOCHS):
            batch_loss = 0.0
-            for idx, (batch, true_ner_tags, true_rels) in enumerate(train_loader):
+            for idx, (batch, true_ner_tags, instances) in enumerate(train_loader):
                # zero param gradients
                optimiser.zero_grad()

                # forward pass
-                ner_output, rc_output, ner_loss = self.net(batch, true_ner_tags)
+                _, ner_loss, rc_output, n_combinations_by_instance, entity_ranges = self.net(batch, true_ner_tags)

                # backward pass
-                l = loss(ner_output, rc_output, ner_loss, true_rels)
+                l = loss(ner_loss, rc_output, n_combinations_by_instance, entity_ranges, instances)
                l.backward()

                # optimise
@@ -62,7 +76,7 @@ class BertExtractor:
                # print interim stats every 5 batches
                batch_loss += l.item()
-                if idx % 10 == 9:
+                if idx % 5 == 4:
                    print('epoch:', epoch + 1, '-- batch:', idx + 1, '-- avg loss:', batch_loss / 5)
                    batch_loss = 0.0
@@ -84,7 +98,7 @@ class BertExtractor:
        with torch.no_grad():
            for batch, b_true_ner_tags, b_true_rels in test_loader:
-                ner_output, rc_output, _ = self.net(batch)
+                ner_output, _, rc_output, _, _ = self.net(batch)  # net now returns five values
                _, rc_max = torch.max(rc_output.data, 1)
                predicted_ner_tags += ner_output.tolist()
@@ -105,11 +119,28 @@ class BertExtractor:
        print('macro F1:', f1)
-# dataset = RelDataset.from_texts(['Testing if this works.', 'Since I\'m not sure.'])
+# dataset = RelDataset.from_texts(['A giraffe made friends with a lion at the Zoo.', 'Since I\'m not sure.'])
# loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=4, collate_fn=generate_batch)
-# batch = next(iter(loader))
-#
-# i = RelInstance('Testing if this works.')
+# texts, _, _ = next(iter(loader))
+#
+# print(tokenizer.convert_ids_to_tokens(texts['input_ids'][0]))
# net = RelBertNet()
-# net(batch)
\ No newline at end of file
+# ner_output, _, rc_output, n_combinations_by_instance, entity_ranges = net(texts)
+# print(ner_output)
+# print(rc_output)
+# print(n_combinations_by_instance)
+# print(entity_ranges)
+
+# train_data = RelDataset.from_file(train_data_path)
+# train_loader = DataLoader(train_data, batch_size=1, shuffle=True, num_workers=4,
+#                           collate_fn=generate_batch)
+# for i in [1, 1, 1, 1, 1]:
+#     batch, b_true_ner_tags, instances = next(iter(train_loader))
+#     print(tokenizer.convert_ids_to_tokens(batch['input_ids'][0]))
+#     print(b_true_ner_tags[0])
+#     print(instances[0].entities_for_tags(b_true_ner_tags[0]))
+#     print('---')
+
+extractor = BertExtractor()
+extractor.train(train_data_path)
rel_dataset.py

from transformers import BertTokenizer
import torch
from torch.utils.data import Dataset
import pandas as pd
-from relbertnet import TRAINED_WEIGHTS
+from relbertnet import TRAINED_WEIGHTS, MAX_SEQ_LEN
+from ast import literal_eval

-MAX_SEQ_LEN = 128
+NE_TAGS_LEN = MAX_SEQ_LEN - 2
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)
-def generate_batch(batch):
-    texts = tokenizer.batch_encode_plus([tokens for tokens in batch], add_special_tokens=True,
+def generate_batch(instances):
+    texts = tokenizer.batch_encode_plus([instance.tokens for instance in instances], add_special_tokens=True,
                                        max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
                                        return_tensors='pt')
-    return texts
+    true_ne_tags = torch.tensor([instance.ne_tags for instance in instances])
+    return texts, true_ne_tags, instances
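# generate_batch is meant to be passed as the collate_fn of a DataLoader over a
# RelDataset, e.g. DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=generate_batch),
# as in the commented examples in bert_extractor above.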
class RelDataset(Dataset):
@@ -20,7 +23,7 @@ class RelDataset(Dataset):
        self.data = []

    @staticmethod
-    def from_data(path):
+    def from_file(path):
        dataset = RelDataset()
        data = pd.read_csv(path, sep='\t', error_bad_lines=False)
        dataset.data = [RelDataset.instance_from_row(row) for _, row in data.iterrows()]
@@ -28,28 +31,88 @@ class RelDataset(Dataset):
    @staticmethod
    def instance_from_row(row):
-        tokens = row['sentText'].split(' ')
-        entities = sorted([em['text'].split(' ') for em in row['entityMentions']], key=len, reverse=True)
-        relations = [(m['em1Text'], m['em2Text']) for m in row['relationMentions']]
+        text = row['sentText']
+        tokens = tokenizer.tokenize(text)
+        # longest mentions first, so overlapping mentions resolve to the longest match
+        entities = sorted([em['text'] for em in literal_eval(row['entityMentions'])], key=len, reverse=True)
+        data_relations = [(m['em1Text'], m['em2Text']) for m in literal_eval(row['relationMentions'])]
+        # label every ordered entity pair: 1 for (e1, e2), 2 for the reverse direction, 0 for no relation
+        relations = {(entity1, entity2): 1 if (entity1, entity2) in data_relations
+                     else (2 if (entity2, entity1) in data_relations else 0)
+                     for entity1 in entities for entity2 in entities}

        ne_tags = []
        i = 0
        while i < len(tokens):
            found = False
            for entity in entities:
-                if tokens[i:i+len(entity)] == entity:
-                    ne_tags += RelDataset.ne_tags_for_len(len(entity))
+                match_length = RelDataset.token_entity_match(i, entity, tokens)
+                if match_length is not None:
+                    ne_tags += RelDataset.ne_tags_for_len(match_length)
                    found = True
-                    i += len(entity)
+                    i += match_length
                    break
            if not found:
                ne_tags += [3]  # outside
                i += 1
+        if len(ne_tags) > NE_TAGS_LEN:  # without CLS and SEP tokens
+            ne_tags = ne_tags[:NE_TAGS_LEN]  # trim to length
+        else:
+            ne_tags = ne_tags + [3] * (NE_TAGS_LEN - len(ne_tags))  # pad to length
        return RelInstance.from_tokens(tokens, ne_tags, relations)
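# Illustration: if row['relationMentions'] contains only an entry with
# em1Text 'London' and em2Text 'UK', then relations[('London', 'UK')] == 1,
# relations[('UK', 'London')] == 2, and every other pair maps to 0 (no relation).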
+    @staticmethod
+    def token_entity_match(first_token_idx, entity, tokens):
+        token_idx = first_token_idx
+        remaining_entity = entity
+        while remaining_entity:
+            if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity:
+                # start of new word
+                remaining_entity = remaining_entity.lstrip()
+                if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]:
+                    remaining_entity = remaining_entity[len(tokens[token_idx]):]
+                    token_idx += 1
+                else:
+                    break
+            else:
+                # continuing same word
+                if (token_idx < len(tokens) and tokens[token_idx].startswith('##')
+                        and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]):
+                    remaining_entity = remaining_entity[len(tokens[token_idx][2:]):]
+                    token_idx += 1
+                else:
+                    break
+        if remaining_entity:
+            return None
+        else:
+            return token_idx - first_token_idx
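# Behaviour of the matcher above, for illustration:
# token_entity_match(0, 'New York', ['New', 'York', 'City']) -> 2
# token_entity_match(0, 'playing', ['play', '##ing']) -> 2 (via the '##' branch)
# token_entity_match(0, 'New Jersey', ['New', 'York']) -> None (partial match)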
+    @staticmethod
+    def join_tokens(tokens):
+        # merge WordPiece continuations ('##...') back into the token they extend
+        joined_tokens = []
+        for token in tokens:
+            if token.startswith('##') and joined_tokens:
+                joined_tokens[-1] += token[2:]
+            else:
+                joined_tokens.append(token)
+        return joined_tokens
+    @staticmethod
+    def range_for_token(tokens, text, t_idx):
+        start = sum(len(t.replace('##', '')) for t in tokens[:t_idx])
+        end = start + len(tokens[t_idx].replace('##', ''))
+        return range(start, end + 1)
    @staticmethod
    def ne_tags_for_len(n):
        assert n > 0
-        return [4] if n == 1 else [1] + [2] * (n-2) + [3]
+        return [4] if n == 1 else [0] + [1] * (n-2) + [2]
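# e.g. ne_tags_for_len(1) -> [4] (Single) and ne_tags_for_len(3) -> [0, 1, 2]
# (Begin, Inside, End), matching the BIEOS indices defined in relbertnet.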
    @staticmethod
    def from_texts(texts):
@@ -61,8 +124,7 @@ class RelDataset(Dataset):
        return len(self.data)

    def __getitem__(self, idx):
-        instance = self.data[idx]
-        return instance.get()
+        return self.data[idx]
class RelInstance:
@@ -71,6 +133,8 @@ class RelInstance:
        self.tokens = None
        self.ne_tags = None
        self.relations = None
+        self.text = None
+        self.token_to_text = None

    @staticmethod
    def from_text(text):
@@ -94,3 +158,51 @@ class RelInstance:
        encoded = tokenizer.encode_plus(tokens, add_special_tokens=True, max_length=MAX_SEQ_LEN,
                                        is_pretokenized=True, return_tensors='pt')
        return encoded
+    def relation(self, er1, er2):
+        e1 = self.entity_for_range(er1)
+        e2 = self.entity_for_range(er2)
+        if (e1, e2) in self.relations:
+            return self.relations[(e1, e2)]
+        else:
+            return None
+
+    def entity_for_range(self, r):
+        start, end = r
+        entity = self.tokens[start]
+        for t in self.tokens[start+1:end]:
+            if t.startswith('##'):
+                entity += t[2:]  # append the continuation without its '##' prefix
+            else:
+                entity += ' ' + t
+        return entity
+    def entities_for_tags(self, tags):
+        assert len(tags) == len(self.tokens)
+        entities = []
+        entity_in_progress = None
+        for idx, tag in enumerate(tags):
+            assert tag in range(5)
+            if tag == 0 and not self.tokens[idx].startswith('##'):
+                entity_in_progress = self.tokens[idx]
+            if tag == 1 and entity_in_progress:
+                if self.tokens[idx].startswith('##'):
+                    entity_in_progress = entity_in_progress + self.tokens[idx][2:]
+                else:
+                    entity_in_progress = entity_in_progress + ' ' + self.tokens[idx]
+            if tag == 2 and entity_in_progress:
+                if self.tokens[idx].startswith('##'):
+                    entities.append(entity_in_progress + self.tokens[idx][2:])
+                else:
+                    entities.append(entity_in_progress + ' ' + self.tokens[idx])
+                entity_in_progress = None
+            if tag == 3 and entity_in_progress:
+                entity_in_progress = None
+            if tag == 4:
+                entities.append(self.tokens[idx])
+                entity_in_progress = None
+        return entities
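# Illustration: with tokens ['New', 'York', 'is', 'big'] and decoded tags
# [0, 2, 3, 4] (B E O S), entities_for_tags returns ['New York', 'big'].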
relbertnet.py

@@ -8,9 +8,9 @@ import itertools
K = 4 # number of hidden layers in Bert2
HIDDEN_OUTPUT_FEATURES = 768
MAX_SEQ_LEN = 128
-TRAINED_WEIGHTS = 'bert-base-uncased'
+TRAINED_WEIGHTS = 'bert-base-cased'  # cased works better for NER
NUM_NE_TAGS = 5 # BIEOS 0-4: [Begin Inside End Outside Single]
-NUM_RELS = 3  # 0-2: [e1 featureOf e2, e2 featureOf e1, no relation]
+NUM_RELS = 3  # 0-2: [no relation, e1 featureOf e2, e2 featureOf e1]
MLP_HIDDEN_LAYER_NODES = 84
@@ -50,12 +50,12 @@ class RelBertNet(nn.Module):
            bert_ner_output, = layer(bert_ner_output, attention_mask=extended_attn_mask)

        # CRF for NER
-        bert_ner_output = bert_ner_output.narrow(1, 1, bert_ner_output.size()[1] - 1)
-        crf_attn_mask = attn_mask.narrow(1, 1, attn_mask.size()[1] - 1).type(torch.uint8)  # mask out CLS token
+        bert_ner_output = bert_ner_output.narrow(1, 1, MAX_SEQ_LEN - 2)  # remove CLS and last token
+        crf_attn_mask = attn_mask.narrow(1, 2, attn_mask.size()[1] - 2).type(torch.uint8)  # mask out SEP token
        emissions = self.ner_linear(bert_ner_output)
        ner_output = self.crf.decode(emissions, mask=crf_attn_mask)

        # calculate loss if tags provided
-        ner_loss = -self.crf(emissions, ner_tags, mask=crf_attn_mask, reduction='mean') if ner_tags else None
+        ner_loss = None if ner_tags is None else -self.crf(emissions, ner_tags, mask=crf_attn_mask, reduction='mean')

        # obtain pairs of entities
        entities_by_instance = [RelBertNet.bieos_to_entities(tags) for tags in ner_output]
@@ -64,15 +64,17 @@ class RelBertNet(nn.Module):
        flat_combinations = [comb for combs in combinations_by_instance for comb in combs]

        # if no entity pairs, cannot find relations so return
-        if not any(n > 2 for n in n_combinations_by_instance) == 0:
-            return ner_output, None, ner_loss
+        if not any(n > 2 for n in n_combinations_by_instance):
+            return ner_output, ner_loss, None, None, None

        # for each pair of named entities recognized, perform BERT2 with MASKrc for RC
        rc_attn_mask = torch.zeros((len(flat_combinations), MAX_SEQ_LEN), dtype=torch.long)
+        entity_ranges = torch.zeros((len(flat_combinations), 2, 2), dtype=torch.long)
        for i, (slice1, slice2) in enumerate(flat_combinations):
            rc_attn_mask[i][0] = 1
            rc_attn_mask[i][slice1] = 1
            rc_attn_mask[i][slice2] = 1
+            entity_ranges[i] = torch.tensor([[slice1.start, slice1.stop], [slice2.start, slice2.stop]]) - 1  # -1: back to token indices without [CLS]
        bert_rc_output = torch.repeat_interleave(bert_context_output, n_combinations_by_instance, dim=0)
        extended_rc_attn_mask = rc_attn_mask[:, None, None, :]
@@ -82,10 +84,12 @@ class RelBertNet(nn.Module):
        # MLP for RC
        rc_cls_output = bert_rc_output.narrow(1, 0, 1).squeeze(1)  # just CLS token
        rc_hidden_layer_output = torch.tanh(self.mlp1(rc_cls_output))  # tanh activation
-        rc_output = F.softmax(self.mlp2(rc_hidden_layer_output), 1)  # softmax activation
+        rc_output = self.mlp2(rc_hidden_layer_output)  # raw logits; cross_entropy applies log-softmax in the loss
        # entity_masks = entity_masks.narrow(2, 1, MAX_SEQ_LEN - 2).type(torch.uint8)  # without CLS and SEP tokens

        # Return NER and RC outputs
-        return ner_output, rc_output, ner_loss
+        return ner_output, ner_loss, rc_output, n_combinations_by_instance, entity_ranges
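# Shape illustration for the return values above (hypothetical batch): for 2
# instances where the first yields 2 entity-pair combinations and the second none,
# n_combinations_by_instance == tensor([2, 0]), rc_output is a (2, NUM_RELS) tensor
# of logits, and entity_ranges[j] holds the two [start, stop) token ranges of pair j.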
    @staticmethod
    def bieos_to_entities(tags):
        ...