Commit cbd56512 authored by Joel Oksanen

Added evaluation for BERT tag extractor.

parent efbad354
@@ -3,14 +3,14 @@ from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy
import time
from sklearn import metrics
from transformers import BertForTokenClassification, AdamW, get_cosine_schedule_with_warmup
from tagged_rel_dataset import TRAINED_WEIGHTS, MAX_SEQ_LEN, RELATIONS, IGNORE_TAG, TaggedRelDataset, generate_train_batch
from tagged_rel_dataset import TRAINED_WEIGHTS, MAX_SEQ_LEN, RELATIONS, IGNORE_TAG, N_TAGS, TaggedRelDataset, generate_train_batch, generate_eval_batch
train_data_path = 'data/train.json'
test_data_path = 'data/test.json'
trained_model_path = 'trained_bert_tag_extractor.pt'
trained_model_path = 'trained_bert_tag_extractor_2.pt'
device = torch.device('cuda')
N_TAGS = 4 * len(RELATIONS) * 2 + 1 # 9 for single relation -> 0: O, 1-4: B/I/E/S-1, 5-8: B/I/E/S-2
# optimizer
DECAY_RATE = 0.01
@@ -25,7 +25,7 @@ WARM_UP_FRAC = 0.05
VALID_FRAC = 0.1
# loss
LOSS_BIAS_WEIGHT = 10
LOSS_BIAS_WEIGHT = 1
# as defined by Zheng et al. (2017)
@@ -116,12 +116,12 @@ class BertTagExtractor:
# update lr
scheduler.step()
# print interim stats every 20 batches
# print interim stats every 100 batches
batch_loss += l.item()
if batch_idx % 20 == 19:
if batch_idx % 100 == 99:
batch_no = batch_idx + 1
print('epoch:', epoch_idx + 1, '--progress: {:.4f}'.format(batch_no / len(train_loader)),
'-- batch:', batch_no, '-- avg loss:', batch_loss / 20)
print('epoch:', epoch_idx + 1, '-- progress: {:.4f}'.format(batch_no / len(train_loader)),
'-- batch:', batch_no, '-- avg loss:', batch_loss / 100)
batch_loss = 0.0
print('epoch done')
@@ -131,8 +131,57 @@ class BertTagExtractor:
torch.save(self.net.state_dict(), trained_model_path)
def evaluate(self, file_path):
# load test data
test_data = TaggedRelDataset.from_file(file_path)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
collate_fn=generate_eval_batch)
self.net.cuda()
self.net.eval()
BertTagExtractor.new_trained_with_file(train_data_path, size=200000)
instances = []
outputs = []
targets = []
with torch.no_grad():
for batch_idx, (input_ids, attn_mask, target_tags, b_instances) in enumerate(test_loader):
# send batch to gpu
input_ids, attn_mask = tuple(i.to(device) for i in [input_ids, attn_mask])
# forward pass
output_scores = self.net(input_ids=input_ids, attention_mask=attn_mask)[0]
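# greedy decoding: take the highest-scoring tag at each token position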
_, output_tags = torch.max(output_scores.data, 2)
instances += b_instances
outputs += output_tags.tolist()
targets += target_tags.tolist()
assert len(outputs) == len(targets)
# align predictions with targets: drop IGNORE_TAG padding from the targets, and
# skip the leading CLS token (and trailing SEP/PAD positions) in the outputs
cropped_targets = [[t for t in ins if t != IGNORE_TAG] for ins in targets]
cropped_outputs = [outputs[idx][1:len(cropped_targets[idx])+1] for idx in range(len(outputs))]
for ins, o_tags, t_tags in zip(instances, cropped_outputs, cropped_targets):
print('text:', ins.text)
print(ins.tags)
print('output:', ins.relations_for_tags(o_tags))
print('target:', ins.relations_for_tags(t_tags))
print('---')
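# exact-match accuracy: an instance counts as correct only if its whole tag sequence matches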
n_correct = sum(1 for idx in range(len(outputs)) if cropped_outputs[idx] == cropped_targets[idx])
accuracy = n_correct / len(outputs)
print('accuracy:', accuracy)
# cm = metrics.confusion_matrix(targets, outputs, labels=range(N_TAGS))
# print('confusion matrix:')
# print(cm)
#
# f1 = metrics.f1_score(targets, outputs, labels=range(N_TAGS), average='macro')
# print('macro F1:', f1)
extr = BertTagExtractor.new_trained_with_file(train_data_path)
extr.evaluate(test_data_path)
tagged_rel_dataset.py
@@ -6,6 +6,7 @@ from collections import defaultdict
TRAINED_WEIGHTS = 'bert-base-cased' # cased works better for NER
RELATIONS = ['/location/location/contains']
N_TAGS = 4 * len(RELATIONS) * 2 + 1 # 9 for single relation -> 0: O, 1-4: B/I/E/S-1, 5-8: B/I/E/S-2
MAX_SEQ_LEN = 128
MAX_TOKENS = MAX_SEQ_LEN - 2
IGNORE_TAG = -1
@@ -22,6 +23,16 @@ def generate_train_batch(instances):
return input_ids, attn_mask, target_tags
def generate_eval_batch(instances):
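# unlike the train batch, the eval batch also carries the raw instances so that predictions can be decoded back into text and relations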
encoded = tokenizer.batch_encode_plus([instance.tokens for instance in instances], add_special_tokens=True,
max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
return_tensors='pt')
input_ids = encoded['input_ids']
attn_mask = encoded['attention_mask']
target_tags = torch.tensor([instance.tags for instance in instances])
return input_ids, attn_mask, target_tags, instances
# Based on Zheng et al. (2017)
class TaggedRelDataset(Dataset):
@@ -39,6 +50,7 @@ class TaggedRelDataset(Dataset):
dataset.df = dataset.df.sample(size, replace=False)
if valid_frac is None:
print('Obtained dataset of size', len(dataset))
return dataset
else:
validset = TaggedRelDataset()
@@ -69,7 +81,7 @@ class TaggedRelDataset(Dataset):
i += 1
tags += [IGNORE_TAG] * (MAX_SEQ_LEN - len(tags)) # pad to MAX_SEQ_LEN
return TaggedRelInstance.from_tokens(tokens, tags)
return TaggedRelInstance.from_tokens(tokens, tags, text)
# NOTE: if entity is present in more than one relation, only one (the one with the highest count) is mapped
def map_for_relation_mentions(self, relation_mentions):
@@ -124,10 +136,93 @@ class TaggedRelInstance:
def __init__(self):
self.tokens = None
self.tags = None
self.text = None
@staticmethod
def from_tokens(tokens, tags):
def from_tokens(tokens, tags, text):
i = TaggedRelInstance()
i.tokens = tokens
i.tags = tags
i.text = text
return i
def relations_for_tags(self, tags):
# find entities
entities = self.entities_for_tags(tags)
# find relations among entities
relations = set()
for entity in entities:
# find nearest matching relation
matches = [other for other in entities - {entity}
if other.rel_idx == entity.rel_idx and other.rel_pos_idx != entity.rel_pos_idx]
if matches:
if len(matches) > 1:
match = min(matches, key=lambda m: entity.distance_to(m))
else:
match = matches[0]
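# order the pair by rel_pos_idx so the tuple reads (first entity, relation, second entity)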
fst, snd = map(self.entity_to_text, sorted([entity, match], key=lambda e: e.rel_pos_idx))
relations.add((fst, RELATIONS[entity.rel_idx], snd))
return relations
def entities_for_tags(self, tags):
assert len(tags) == len(self.tokens)
entities = set()
entity_in_progress = None
for idx, tag in enumerate(tags):
assert tag in range(N_TAGS)
if tag == 0:
entity_in_progress = None
continue
# decode the tag id, e.g. with one relation: tag 6 -> divmod(5, 8) = (0, 5), divmod(5, 4) = (1, 1), i.e. I-2 for relation 0
rel_idx, tag = divmod(tag - 1, 8) # which relation the tag encodes
rel_pos_idx, tag = divmod(tag, 4) # first (0) or second (1) entity, then B/I/E/S index
if tag == 0:
if self.tokens[idx].startswith('##'):
entity_in_progress = None
else:
entity_in_progress = Entity(idx, rel_idx, rel_pos_idx)
if tag == 1 and entity_in_progress:
entity_in_progress = entity_in_progress.add(idx, rel_idx, rel_pos_idx)
if tag == 2 and entity_in_progress:
entity_in_progress = entity_in_progress.add(idx, rel_idx, rel_pos_idx)
if entity_in_progress:
entities.add(entity_in_progress)
entity_in_progress = None
if tag == 3:
entities.add(Entity(idx, rel_idx, rel_pos_idx))
entity_in_progress = None
return entities
def entity_to_text(self, entity):
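# rebuild the surface string: '##' WordPiece continuations are appended directly, other tokens are preceded by a space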
text = self.tokens[entity.start]
for t in self.tokens[entity.start+1:entity.end+1]:
if t.startswith('##'):
text += t[2:]
else:
text += ' ' + t
return text
class Entity:
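# a contiguous token span [start, end] tagged as the rel_pos_idx-th entity (0 or 1) of relation rel_idx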
def __init__(self, start, rel_idx, rel_pos_idx):
self.start = start
self.end = start
self.rel_idx = rel_idx
self.rel_pos_idx = rel_pos_idx
def add(self, idx, rel_idx, rel_pos_idx):
if rel_idx == self.rel_idx and rel_pos_idx == self.rel_pos_idx:
self.end = idx
return self
else:
return None
def distance_to(self, e2):
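# token distance between the nearest endpoints of the two spans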
return min(abs(self.start - e2.end), abs(e2.start - self.end))
\ No newline at end of file