Commit 9288c473 authored by Joel Oksanen

Implemented BertRelExtractor and started implementing next version.

parent e183155c
import torch
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import time
import numpy as np
from sklearn import metrics
from transformers import get_linear_schedule_with_warmup
from agent.target_extraction.BERT.pair_rel_dataset import PairRelDataset, generate_batch, generate_production_batch
from agent.target_extraction.BERT.pairbertnet import NUM_CLASSES, PairBertNet
trained_model_path = 'trained_bert_rel_extractor_camera_and_backpack_with_nan.pt'
device = torch.device('cuda')
loss_criterion = CrossEntropyLoss()
# optimizer
DECAY_RATE = 0.01
LEARNING_RATE = 0.00002
MAX_GRAD_NORM = 1.0
# training
N_EPOCHS = 3
BATCH_SIZE = 32
WARM_UP_FRAC = 0.05
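# (the values above are within the fine-tuning ranges recommended in the BERT paper: batch size 16/32,
#  learning rate 5e-5/3e-5/2e-5, 2-4 epochs; gradient clipping at 1.0 and a linear warmup schedule are
#  also common defaults for BERT fine-tuning)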

class BertRelExtractor:

    def __init__(self):
        self.net = None

    @staticmethod
    def load_saved(path):
        extr = BertRelExtractor()
        extr.net = PairBertNet()
        extr.net.load_state_dict(torch.load(path))
        extr.net.eval()
        return extr

    @staticmethod
    def new_trained_with_file(file_path, size=None):
        extractor = BertRelExtractor()
        extractor.train_with_file(file_path, size=size)
        return extractor

    @staticmethod
    def train_and_validate(file_path, valid_frac, size=None):
        extractor = BertRelExtractor()
        extractor.train_with_file(file_path, size=size, valid_frac=valid_frac)
        return extractor

    def train_with_file(self, file_path, size=None, valid_frac=None):
        # load training data
        train_data, valid_data = PairRelDataset.from_file(file_path, size=size, valid_frac=valid_frac)
        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
                                  collate_fn=generate_batch)

        # initialise BERT
        self.net = PairBertNet()
        self.net.cuda()

        # set up optimizer (note: DECAY_RATE above is not applied here)
        optimiser = Adam(self.net.parameters(), lr=LEARNING_RATE)

        # set up scheduler for lr
        n_training_steps = len(train_loader) * N_EPOCHS
        scheduler = get_linear_schedule_with_warmup(
            optimiser,
            num_warmup_steps=int(WARM_UP_FRAC * n_training_steps),
            num_training_steps=n_training_steps
        )

        start = time.time()

        for epoch_idx in range(N_EPOCHS):
            self.net.train()
            batch_loss = 0.0

            for batch_idx, batch in enumerate(train_loader):
                # send batch to gpu
                input_ids, attn_mask, fst_indices, snd_indices, target_labels = tuple(i.to(device) for i in batch)

                # zero param gradients
                optimiser.zero_grad()

                # forward pass
                output_scores = self.net(input_ids, attn_mask, fst_indices, snd_indices)

                # backward pass
                loss = loss_criterion(output_scores, target_labels)
                loss.backward()

                # clip gradient norm
                clip_grad_norm_(parameters=self.net.parameters(), max_norm=MAX_GRAD_NORM)

                # optimise
                optimiser.step()

                # update lr
                scheduler.step()

                # print interim stats every 500 batches
                batch_loss += loss.item()
                if batch_idx % 500 == 499:
                    batch_no = batch_idx + 1
                    print('epoch:', epoch_idx + 1, '-- progress: {:.4f}'.format(batch_no / len(train_loader)),
                          '-- batch:', batch_no, '-- avg loss:', batch_loss / 500)
                    batch_loss = 0.0

            print('epoch done')

            if valid_data is not None:
                self.evaluate(data=valid_data)

        end = time.time()
        print('Training took', end - start, 'seconds')

        torch.save(self.net.state_dict(), trained_model_path)

    def evaluate(self, file_path=None, data=None, size=None):
        # load eval data
        if file_path is not None:
            test_data, _ = PairRelDataset.from_file(file_path, size=size)
        else:
            if data is None:
                raise AttributeError('file_path and data cannot both be None')
            test_data = data

        test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
                                 collate_fn=generate_batch)

        self.net.cuda()
        self.net.eval()

        outputs = []
        targets = []

        with torch.no_grad():
            for batch in test_loader:
                # send batch to gpu
                input_ids, attn_mask, fst_indices, snd_indices, target_labels = tuple(i.to(device) for i in batch)

                # forward pass
                output_scores = self.net(input_ids, attn_mask, fst_indices, snd_indices)
                _, output_labels = torch.max(output_scores.data, 1)

                outputs += output_labels.tolist()
                targets += target_labels.tolist()

        assert len(outputs) == len(targets)

        correct = (np.array(outputs) == np.array(targets))
        accuracy = correct.sum() / correct.size
        print('accuracy:', accuracy)

        cm = metrics.confusion_matrix(targets, outputs, labels=range(NUM_CLASSES))
        print('confusion matrix:')
        print(cm)

        f1 = metrics.f1_score(targets, outputs, labels=range(NUM_CLASSES), average='macro')
        print('macro F1:', f1)

    def extract_single_relation(self, text, e1, e2):
        ins = PairRelDataset.get_instance(text, e1, e2)
        input_ids, attn_mask, fst_indices, snd_indices, instances = generate_production_batch([ins])

        self.net.cuda()
        self.net.eval()

        with torch.no_grad():
            # send batch to gpu
            input_ids, attn_mask, fst_indices, snd_indices = tuple(i.to(device) for i in
                                                                   [input_ids, attn_mask, fst_indices, snd_indices])

            # forward pass
            output_scores = self.net(input_ids, attn_mask, fst_indices, snd_indices)
            _, output_labels = torch.max(output_scores.data, 1)

            print(instances[0].get_relation_for_label(output_labels[0]))

    def extract_relations(self, file_path=None, dataset=None, size=None):
        # load data
        if file_path is not None:
            data, _ = PairRelDataset.from_file(file_path, size=size)
        else:
            if dataset is None:
                raise AttributeError('file_path and dataset cannot both be None')
            data = dataset

        loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
                            collate_fn=generate_production_batch)

        self.net.cuda()
        self.net.eval()

        outputs = []

        with torch.no_grad():
            for input_ids, attn_mask, fst_indices, snd_indices, instances in loader:
                # send batch to gpu
                input_ids, attn_mask, fst_indices, snd_indices = tuple(i.to(device) for i in
                                                                       [input_ids, attn_mask, fst_indices, snd_indices])

                # forward pass
                output_scores = self.net(input_ids, attn_mask, fst_indices, snd_indices)
                _, output_labels = torch.max(output_scores.data, 1)

                outputs += map(lambda x: x[0].get_relation_for_label(x[1]), zip(instances, output_labels.tolist()))

                for ins, out in zip(instances, output_labels.tolist()):
                    print(ins.text)
                    print(ins.get_relation_for_label(out))
                    print('---')

        return outputs

extr = BertRelExtractor.load_saved('trained_bert_rel_extractor_camera_and_backpack_with_nan.pt')
extr.evaluate(file_path='data/annotated_laptop_review_pairs_with_nan.tsv', size=10000)
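
# Example usage of extract_single_relation (illustrative only: the sentence and entity
# strings below are made-up inputs, not taken from the training data):
#   extr.extract_single_relation('The battery life of this camera is great.', 'camera', 'battery life')
# which prints a triple such as ('camera', '/has_feature', 'battery life')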

@@ -122,28 +122,28 @@ class BertTagExtractor:
                 # update lr
                 scheduler.step()

-                # print interim stats every 100 batches
+                # print interim stats every 500 batches
                 batch_loss += l.item()
-                if batch_idx % 100 == 99:
+                if batch_idx % 500 == 499:
                     batch_no = batch_idx + 1
                     print('epoch:', epoch_idx + 1, '-- progress: {:.4f}'.format(batch_no / len(train_loader)),
-                          '-- batch:', batch_no, '-- avg loss:', batch_loss / 100)
+                          '-- batch:', batch_no, '-- avg loss:', batch_loss / 500)
                     batch_loss = 0.0

             print('epoch done')

+            if valid_data is not None:
+                self.evaluate(data=valid_data)
+
         end = time.time()
         print('Training took', end - start, 'seconds')

         torch.save(self.net.state_dict(), trained_model_path)

-        if valid_data is not None:
-            self.evaluate(data=valid_data)
-
-    def evaluate(self, file_path=None, data=None):
+    def evaluate(self, file_path=None, data=None, size=None):
         # load training data
         if file_path is not None:
-            test_data = TaggedRelDataset.from_file(file_path)
+            test_data, _ = TaggedRelDataset.from_file(file_path, size=size)
         else:
             if data is None:
                 raise AttributeError('file_path and data cannot both be None')
@@ -195,8 +195,9 @@ class BertTagExtractor:
         # print('macro F1:', f1)

-BertTagExtractor.train_and_validate('data/annotated_camera_reviews.tsv', 0.05, size=200000)
+# BertTagExtractor.train_and_validate('data/annotated_camera_reviews.tsv', 0.05, size=200000)
+extr: BertTagExtractor = BertTagExtractor.default()
+extr.evaluate(file_path='data/annotated_laptop_reviews.tsv', size=20000)
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
import pandas as pd
import numpy as np
from ast import literal_eval
from agent.target_extraction.BERT.pairbertnet import TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES
MAX_SEQ_LEN = 128
MASK_TOKEN = '[MASK]'
RELATIONS = ['/has_feature', '/no_relation', '/nan_relation']
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)

def generate_batch(batch):
    encoded = tokenizer.batch_encode_plus([instance.tokens for instance in batch], add_special_tokens=True,
                                          max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
                                          return_tensors='pt')
    input_ids = encoded['input_ids']
    attn_mask = encoded['attention_mask']

    both_ranges = [instance.entity_ranges for instance in batch]
    fst_indices, snd_indices = map(indices_for_entity_ranges, zip(*both_ranges))

    labels = torch.tensor([instance.label for instance in batch])

    return input_ids, attn_mask, fst_indices, snd_indices, labels

def generate_production_batch(batch):
    encoded = tokenizer.batch_encode_plus([instance.tokens for instance in batch], add_special_tokens=True,
                                          max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
                                          return_tensors='pt')
    input_ids = encoded['input_ids']
    attn_mask = encoded['attention_mask']

    both_ranges = [instance.entity_ranges for instance in batch]
    fst_indices, snd_indices = map(indices_for_entity_ranges, zip(*both_ranges))

    return input_ids, attn_mask, fst_indices, snd_indices, batch

def indices_for_entity_ranges(ranges):
    max_e_len = max(end - start for start, end in ranges)
    indices = torch.tensor([[[min(t, end)] * HIDDEN_OUTPUT_FEATURES
                             for t in range(start, start + max_e_len + 1)]
                            for start, end in ranges])
    return indices
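
# The index tensors built above are shaped (batch_size, max_entity_len, HIDDEN_OUTPUT_FEATURES):
# for each (start, end) range (token indices, inclusive), row t holds min(start + t, end) repeated
# across the feature dimension, so shorter entities repeat their last token index. Judging by the
# commented-out code in pairbertnet.py, they are intended for torch.gather(bert_output, dim=1, ...)
# followed by a max over dim=1, i.e. max pooling over each entity's hidden states. For example,
# ranges = [(3, 4), (7, 7)] gives token-index rows [3, 4] for the first instance and [7, 7] for
# the second.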

class PairRelDataset(Dataset):

    def __init__(self, df, size=None):
        # filter inapplicable rows
        self.df = df[df.apply(lambda x: PairRelDataset.instance_from_row(x) is not None, axis=1)]

        # sample data if a size is specified
        if size is not None and size < len(self):
            self.df = self.df.sample(size, replace=False)

    @staticmethod
    def from_df(df, size=None):
        dataset = PairRelDataset(df, size=size)
        print('Obtained dataset of size', len(dataset))
        return dataset

    @staticmethod
    def from_file(path, valid_frac=None, size=None):
        if path.endswith('.json'):
            dataset = PairRelDataset(pd.read_json(path, lines=True), size=size)
        elif path.endswith('.tsv'):
            dataset = PairRelDataset(pd.read_csv(path, sep='\t', error_bad_lines=False), size=size)
        else:
            raise AttributeError('Could not recognize file type')

        if valid_frac is None:
            print('Obtained dataset of size', len(dataset))
            return dataset, None
        else:
            split_idx = int(len(dataset) * (1 - valid_frac))
            dataset.df, valid_df = np.split(dataset.df, [split_idx], axis=0)
            validset = PairRelDataset(valid_df)
            print('Obtained train set of size', len(dataset), 'and validation set of size', len(validset))
            return dataset, validset

    @staticmethod
    def instance_from_row(row):
        unpacked_arr = (literal_eval(row['relationMentions']) if type(row['relationMentions']) is str
                        else row['relationMentions'])
        rms = [rm for rm in unpacked_arr if 'label' not in rm or rm['label'] in RELATIONS]
        if len(rms) == 1:
            e1, e2, label = rms[0]['em1Text'], rms[0]['em2Text'], (rms[0]['label'] if 'label' in rms[0] else None)
        else:
            return None  # raise AttributeError('Instances must have exactly one relation')

        text = row['sentText']
        return PairRelDataset.get_instance(text, e1, e2, label=label)

    @staticmethod
    def get_instance(text, e1, e2, label=None):
        tokens = tokenizer.tokenize(text)

        i = 0
        found_entities = []
        ranges = []
        while i < len(tokens):
            match = False
            for entity in [e1, e2]:
                match_length = PairRelDataset.token_entity_match(i, entity.lower(), tokens)
                if match_length is not None:
                    if entity in found_entities:
                        return None  # raise AttributeError('Entity {} appears twice in text {}'.format(entity, text))
                    tokens[i:i + match_length] = [MASK_TOKEN] * match_length
                    match = True
                    found_entities.append(entity)
                    ranges.append((i, i + match_length - 1))
                    i += match_length
                    break
            if not match:
                i += 1

        if found_entities == [e1, e2] and label is None:
            return PairRelInstance(tokens, e1, e2, tuple(ranges), None, text)
        elif found_entities == [e2, e1] and label is None:
            return PairRelInstance(tokens, e2, e1, tuple(ranges), None, text)

        if len(found_entities) == 2 and label in ['/no_relation', '/nan_relation']:
            return PairRelInstance(tokens, e1, e2, tuple(ranges), 0, text)
        elif found_entities == [e1, e2] and label == '/has_feature':
            return PairRelInstance(tokens, e1, e2, tuple(ranges), 1, text)
        elif found_entities == [e2, e1] and label == '/has_feature':
            return PairRelInstance(tokens, e2, e1, tuple(ranges), 2, text)
        else:
            return None  # raise AttributeError('Could not find entities {} and {} in {}. Found entities {}'.format(e1, e2, text, found_entities))
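
    # Label scheme produced above (it matches NUM_CLASSES = 3 and get_relation_for_label below):
    #   0 -> no relation (also used for /nan_relation pairs)
    #   1 -> first found entity /has_feature second found entity
    #   2 -> second found entity /has_feature first found entity
    # Both entity mentions are replaced with [MASK] tokens, presumably so that the classifier
    # cannot simply memorise the entity surface forms.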

    @staticmethod
    def token_entity_match(first_token_idx, entity, tokens):
        token_idx = first_token_idx
        remaining_entity = entity
        while remaining_entity:
            if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity:
                # start of new word
                remaining_entity = remaining_entity.lstrip()
                if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]:
                    remaining_entity = remaining_entity[len(tokens[token_idx]):]
                    token_idx += 1
                else:
                    break
            else:
                # continuing same word
                if (token_idx < len(tokens) and tokens[token_idx].startswith('##')
                        and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]):
                    remaining_entity = remaining_entity[len(tokens[token_idx][2:]):]
                    token_idx += 1
                else:
                    break
        if remaining_entity:
            return None
        else:
            return token_idx - first_token_idx
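
    # Worked example (added for illustration; the token sequences are hypothetical):
    #   token_entity_match(1, 'battery life', ['the', 'battery', 'life', 'is', 'great'])  -> 2
    #   token_entity_match(0, 'touchpad', ['touch', '##pad', 'is', 'nice'])               -> 2
    #   token_entity_match(0, 'battery', ['charger', 'cable'])                            -> None
    # i.e. the return value is the number of WordPiece tokens the entity spans starting at
    # first_token_idx, or None if the entity does not start there.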

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        return PairRelDataset.instance_from_row(self.df.iloc[idx])

class PairRelInstance:

    def __init__(self, tokens, fst_e, snd_e, entity_ranges, label, text):
        self.tokens = tokens
        self.fst_e = fst_e
        self.snd_e = snd_e
        self.entity_ranges = entity_ranges
        self.label = label
        self.text = text

    def get_relation_for_label(self, label):
        if label == 0:
            return self.fst_e, '/no_relation', self.snd_e
        if label == 1:
            return self.fst_e, '/has_feature', self.snd_e
        if label == 2:
            return self.snd_e, '/has_feature', self.fst_e

    # def range_to_entity(self, e_range):
    #     start, end = e_range
    #     text = self.tokens[start]
    #     for t in self.tokens[start + 1:end]:
    #         if t.startswith('##'):
    #             text += t[2:]
    #         else:
    #             text += ' ' + t
    #     return text
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel
HIDDEN_OUTPUT_FEATURES = 768
TRAINED_WEIGHTS = 'bert-base-uncased'
NUM_CLASSES = 3 # no relation, e1 featureOf e2, e2 featureOf e1

class PairBertNet(nn.Module):

    def __init__(self):
        super(PairBertNet, self).__init__()
        config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
        self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)  # n of hidden features in, n of output labels out

    def forward(self, input_ids, attn_mask, fst_e_indices, snd_e_indices):
        # BERT
        _, pooler_output = self.bert_base(input_ids=input_ids, attention_mask=attn_mask)

        # max pooling at entity locations (in-progress code for the next version, currently unused)
        # fst_e_outputs = torch.gather(bert_output, dim=1, index=fst_e_indices)
        # snd_e_outputs = torch.gather(bert_output, dim=1, index=snd_e_indices)
        # fst_pooled_output, _ = torch.max(fst_e_outputs, dim=1)
        # snd_pooled_output, _ = torch.max(snd_e_outputs, dim=1)
        # combined = torch.cat((fst_pooled_output, snd_pooled_output), dim=1)

        # fc layer (softmax activation done in loss function)
        x = self.fc(pooler_output)
        return x
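
# Sketch of how the commented-out entity pooling above could be completed in the next version
# (added for illustration only, not part of this commit; variable names follow the commented-out
# lines in forward(), and the fc layer would then need in_features = 2 * HIDDEN_OUTPUT_FEATURES):
#
#     bert_output, _ = self.bert_base(input_ids=input_ids, attention_mask=attn_mask)
#     fst_e_outputs = torch.gather(bert_output, dim=1, index=fst_e_indices)
#     snd_e_outputs = torch.gather(bert_output, dim=1, index=snd_e_indices)
#     fst_pooled_output, _ = torch.max(fst_e_outputs, dim=1)
#     snd_pooled_output, _ = torch.max(snd_e_outputs, dim=1)
#     combined = torch.cat((fst_pooled_output, snd_pooled_output), dim=1)
#     x = self.fc(combined)
#
# (if the entity ranges are computed before the [CLS] token is prepended by the tokenizer,
#  the gather indices may also need a +1 offset)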

@@ -9,6 +9,7 @@ from collections import Counter
 import pickle
 import os
 import readchar
+import random
 from sty import fg, bg
 from anytree import Node, RenderTree, LevelOrderIter, PreOrderIter

@@ -20,9 +21,10 @@ ann_bgs = [bg.blue, bg.red] # child, parent
 class EntityAnnotator:
-    def __init__(self, text_file_path, counter, save_path):
+    def __init__(self, text_file_path, counter, phraser, save_path):
         self.text_file_path = text_file_path
         self.counter = counter
+        self.phraser = phraser
         self.save_path = save_path
         self.root = None
         self.synset = {}

@@ -36,9 +38,9 @@ class EntityAnnotator:
                  for _, par in df.sample(frac=1)['reviewText'].items() if not pd.isnull(par)
                  for text in sent_tokenize(par)]
         print('obtaining counter...')
-        counter = EntityAnnotator.count_nouns(texts)
+        counter, phraser = EntityAnnotator.count_nouns(texts)
         print('finished initialising annotator')
-        ann = EntityAnnotator(file_path, counter, name + '.pickle')
+        ann = EntityAnnotator(file_path, counter, phraser, name + '.pickle')
         ann.save()
         return ann

@@ -89,7 +91,7 @@ class EntityAnnotator:
             if idx % 1000 == 0:
                 print(' {:0.2f} done'.format((idx + 1) / len(texts)))
-        return Counter(nouns)
+        return Counter(nouns), phraser

     @staticmethod
     def is_noun(pos_tagged):

@@ -205,6 +207,92 @@ class EntityAnnotator:
             node.n = i
             i += 1

+    def save_annotated_pairs(self, save_path):
+        reviews = pd.read_csv(self.text_file_path, sep='\t', error_bad_lines=False)
+        texts = [text for _, par in reviews['reviewText'].items() if not pd.isnull(par)
+                 for text in sent_tokenize(par)]
+        nan_entities = self.get_nan_entities()
+        pair_texts = [t for t in map(lambda t: self.pair_relations_for_text(t, nan_entities), texts)
+                      if t is not None]
+        df = pd.DataFrame(pair_texts, columns=['sentText', 'relationMentions'])
+        df.to_csv(save_path, sep='\t', index=False)
+
+    @staticmethod
+    def get_entity_relation(mention1, mention2):
+        n1, e1 = mention1
+        n2, e2 = mention2
+        if n2 in n1.descendants:
+            return {'em1Text': e1, 'em2Text': e2, 'label': '/has_feature'}
+        elif n1 in n2.descendants:
+            return {'em1Text': e2, 'em2Text': e1, 'label': '/has_feature'}
+        else:
+            # randomise order of no rel tuple to avoid bias
+            m = [e1, e2]
+            random.shuffle(m)
+            return {'em1Text': m[0], 'em2Text': m[1], 'label': '/no_relation'}
+
+    def pair_relations_for_text(self, text, nan_entities):
+        tokens = self.phraser[word_tokenize(text)]
+
+        entity_mentions = []
+        for n in PreOrderIter(self.root):
+            cont, mention = self.mention_in_text(text, tokens, node=n)
+            if not cont:
+                # many mentions of same entity
+                return None
+            if mention is not None:
+                entity_mentions.append((n, mention))
+
+        if len(entity_mentions) > 2:
+            # text cannot have more than two entity mentions
+            return None
+
+        if len(entity_mentions) == 2:
+            return text, [EntityAnnotator.get_entity_relation(entity_mentions[0], entity_mentions[1])]
+
+        if len(entity_mentions) == 1:
+            nan_mention = None
+            for term in nan_entities:
+                cont, mention = self.mention_in_text(text, tokens, term=term)
+                if not cont:
+                    # many mentions of term
+                    return None
+                if mention is not None:
+                    if nan_mention is not None:
+                        # text cannot have more than one nan mention
+                        return None
+                    nan_mention = mention
+            if nan_mention is not None: