Commit d9b522d6 authored by Joel Oksanen

Implemented masked two-way rel extractor and entity extractor

parent 816e6691
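For orientation: the commit adds training and evaluation entry points for the new entity extractor, but no inference helper yet. Below is a minimal sketch (not part of this commit) of how the added pieces could be combined at prediction time; it assumes the names from bert_entity_extractor.py and entity_dataset.py are imported into one scope, and the function name, weights path and example arguments are placeholders.

# sketch only -- not part of this commit
def aspect_probability(weights_path, text, entity):
    extr = BertEntityExtractor.load_saved(weights_path)
    extr.net.cuda()
    # masks the entity mention in the tokenized text and records its token range;
    # returns None if the entity cannot be found exactly once
    ins = EntityDataset.get_instance(text, entity)
    input_ids, attn_mask, entity_indices, _ = generate_production_batch([ins])
    with torch.no_grad():
        output_scores = softmax(extr.net(input_ids.to(device), attn_mask.to(device),
                                         entity_indices.to(device)), dim=1)
    # probability that the masked mention is an ASPECT entity
    return output_scores[0][LABEL_MAP['ASPECT']].item()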
import torch
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.nn.functional import softmax
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import time
import numpy as np
from sklearn import metrics
from transformers import get_linear_schedule_with_warmup
from agent.target_extraction.BERT.entity_extractor.entity_dataset import EntityDataset, generate_batch, generate_production_batch
from agent.target_extraction.BERT.entity_extractor.entitybertnet import NUM_CLASSES, EntityBertNet
device = torch.device('cuda')
# optimizer
DECAY_RATE = 0.01
LEARNING_RATE = 0.00002
MAX_GRAD_NORM = 1.0
# training
N_EPOCHS = 2
BATCH_SIZE = 32
WARM_UP_FRAC = 0.05
# loss
loss_criterion = CrossEntropyLoss()
class BertEntityExtractor:
def __init__(self):
self.net = EntityBertNet()
@staticmethod
def load_saved(path):
extr = BertEntityExtractor()
extr.net = EntityBertNet()
extr.net.load_state_dict(torch.load(path))
extr.net.eval()
return extr
@staticmethod
def new_trained_with_file(file_path, save_path, size=None):
extractor = BertEntityExtractor()
extractor.train_with_file(file_path, save_path, size=size)
return extractor
@staticmethod
def train_and_validate(file_path, save_path, size=None, valid_frac=None, valid_file_path=None):
extractor = BertEntityExtractor()
extractor.train_with_file(file_path, save_path, size=size, valid_frac=valid_frac,
valid_file_path=valid_file_path)
return extractor
def train_with_file(self, file_path, save_path, size=None, valid_frac=None, valid_file_path=None):
# load training data
if valid_file_path is None:
train_data, valid_data = EntityDataset.from_file(file_path, size=size, valid_frac=valid_frac)
else:
train_size = int(size * (1 - valid_frac)) if size is not None else None
train_data, _ = EntityDataset.from_file(file_path, size=train_size)
valid_size = int(size * valid_frac) if size is not None else int(len(train_data) * valid_frac)
valid_data, _ = EntityDataset.from_file(valid_file_path, size=valid_size)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
collate_fn=generate_batch)
# initialise BERT
self.net.cuda()
# set up optimizer with weight decay
optimiser = Adam(self.net.parameters(), lr=LEARNING_RATE, weight_decay=DECAY_RATE)
# set up scheduler for lr
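# (linear warm-up over the first WARM_UP_FRAC of updates, then linear decay of the lr to zero)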
n_training_steps = len(train_loader) * N_EPOCHS
scheduler = get_linear_schedule_with_warmup(
optimiser,
num_warmup_steps=int(WARM_UP_FRAC * n_training_steps),
num_training_steps=n_training_steps
)
start = time.time()
for epoch_idx in range(N_EPOCHS):
self.net.train()
batch_loss = 0.0
for batch_idx, batch in enumerate(train_loader):
# send batch to gpu
input_ids, attn_mask, entity_indices, target_labels = tuple(i.to(device) for i in batch)
# zero param gradients
optimiser.zero_grad()
# forward pass
output_scores = self.net(input_ids, attn_mask, entity_indices)
# backward pass
loss = loss_criterion(output_scores, target_labels)
loss.backward()
# clip gradient norm
clip_grad_norm_(parameters=self.net.parameters(), max_norm=MAX_GRAD_NORM)
# optimise
optimiser.step()
# update lr
scheduler.step()
# print interim stats every 250 batches
batch_loss += loss.item()
if batch_idx % 250 == 249:
batch_no = batch_idx + 1
print('epoch:', epoch_idx + 1, '-- progress: {:.4f}'.format(batch_no / len(train_loader)),
'-- batch:', batch_no, '-- avg loss:', batch_loss / 250)
batch_loss = 0.0
print('epoch done')
if valid_data is not None:
self.evaluate(data=valid_data)
end = time.time()
print('Training took', end - start, 'seconds')
torch.save(self.net.state_dict(), save_path)
def evaluate(self, file_path=None, data=None, size=None):
# load eval data
if file_path is not None:
test_data, _ = EntityDataset.from_file(file_path, size=size)
else:
if data is None:
raise AttributeError('file_path and data cannot both be None')
test_data = data
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
collate_fn=generate_batch)
self.net.cuda()
self.net.eval()
outputs = []
targets = []
with torch.no_grad():
for batch in test_loader:
# send batch to gpu
input_ids, attn_mask, entity_indices, target_labels = tuple(i.to(device) for i in batch)
# forward pass
output_scores = self.net(input_ids, attn_mask, entity_indices)
_, output_labels = torch.max(output_scores.data, 1)
outputs += output_labels.tolist()
targets += target_labels.tolist()
assert len(outputs) == len(targets)
correct = (np.array(outputs) == np.array(targets))
accuracy = correct.sum() / correct.size
print('accuracy:', accuracy)
cm = metrics.confusion_matrix(targets, outputs, labels=range(NUM_CLASSES))
print('confusion matrix:')
print(cm)
f1 = metrics.f1_score(targets, outputs, labels=range(NUM_CLASSES), average='macro')
print('macro F1:', f1)
precision = metrics.precision_score(targets, outputs, average=None)
print('precision:', precision)
recall = metrics.recall_score(targets, outputs, average=None)
print('recall:', recall)
extr: BertEntityExtractor = BertEntityExtractor.train_and_validate('camera_backpack_laptop_review_entities.tsv',
'trained_bert_entity_extractor_camera_backpack_laptop.pt',
valid_frac=0.05,
valid_file_path='annotated_acoustic_guitar_review_entities.tsv')
\ No newline at end of file
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
import pandas as pd
import numpy as np
from ast import literal_eval
import os.path
from agent.target_extraction.BERT.relation_extractor.pairbertnet import TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES
MAX_SEQ_LEN = 128
LABELS = ['ASPECT', 'NAN']
LABEL_MAP = {'ASPECT': 1, 'NAN': 0}
MASK_TOKEN = '[MASK]'
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)
def generate_batch(batch):
encoded = tokenizer.batch_encode_plus([instance.tokens for instance in batch], add_special_tokens=True,
max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
return_tensors='pt')
input_ids = encoded['input_ids']
attn_mask = encoded['attention_mask']
labels = torch.tensor([instance.label for instance in batch])
entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch])
return input_ids, attn_mask, entity_indices, labels
def generate_production_batch(batch):
encoded = tokenizer.batch_encode_plus([instance.tokens for instance in batch], add_special_tokens=True,
max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
return_tensors='pt')
input_ids = encoded['input_ids']
attn_mask = encoded['attention_mask']
entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch])
return input_ids, attn_mask, entity_indices, batch
def indices_for_entity_ranges(ranges):
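# builds a (batch, max_entity_len + 1, hidden_size) index tensor for torch.gather: each range is
# padded to the longest entity in the batch by clamping positions to its own end index, and every
# position is repeated across the hidden dimension so the entity token vectors can be selected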
max_e_len = max(end - start for start, end in ranges)
indices = torch.tensor([[[min(t, end)] * HIDDEN_OUTPUT_FEATURES
for t in range(start, start + max_e_len + 1)]
for start, end in ranges])
return indices
class EntityDataset(Dataset):
def __init__(self, df, size=None):
# filter inapplicable rows
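# (a row is usable only if instance_from_row can build exactly one maskable entity mention from it)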
self.df = df[df.apply(lambda x: EntityDataset.instance_from_row(x) is not None, axis=1)]
# sample data if a size is specified
if size is not None and size < len(self):
self.df = self.df.sample(size, replace=False)
@staticmethod
def from_df(df, size=None):
dataset = EntityDataset(df, size=size)
print('Obtained dataset of size', len(dataset))
return dataset
@staticmethod
def from_file(file_name, valid_frac=None, size=None):
f = open(os.path.dirname(__file__) + '/../data/' + file_name)
if file_name.endswith('.json'):
dataset = EntityDataset(pd.read_json(f, lines=True), size=size)
elif file_name.endswith('.tsv'):
dataset = EntityDataset(pd.read_csv(f, sep='\t', error_bad_lines=False), size=size)
else:
raise AttributeError('Could not recognize file type')
if valid_frac is None:
print('Obtained dataset of size', len(dataset))
return dataset, None
else:
split_idx = int(len(dataset) * (1 - valid_frac))
dataset.df, valid_df = np.split(dataset.df, [split_idx], axis=0)
validset = EntityDataset(valid_df)
print('Obtained train set of size', len(dataset), 'and validation set of size', len(validset))
return dataset, validset
@staticmethod
def instance_from_row(row):
unpacked_arr = literal_eval(row['entityMentions']) if type(row['entityMentions']) is str else row['entityMentions']
rms = [rm for rm in unpacked_arr if 'label' not in rm or rm['label'] in LABELS]
if len(rms) == 1:
entity, label = rms[0]['text'], (rms[0]['label'] if 'label' in rms[0] else None)
else:
return None # raise AttributeError('Instances must have exactly one relation')
text = row['sentText']
return EntityDataset.get_instance(text, entity, label=label)
@staticmethod
def get_instance(text, entity, label=None):
tokens = tokenizer.tokenize(text)
i = 0
found_entity = False
entity_range = None
while i < len(tokens):
match_length = EntityDataset.token_entity_match(i, entity.lower(), tokens)
if match_length is not None:
if found_entity:
return None # raise AttributeError('Entity {} appears twice in text {}'.format(entity, text))
found_entity = True
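# replace the entity's WordPiece tokens with [MASK] so the classifier must rely on the
# surrounding context rather than the entity's surface form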
tokens[i:i + match_length] = [MASK_TOKEN] * match_length
entity_range = (i + 1, i + match_length) # + 1 taking into account the [CLS] token
i += match_length
else:
i += 1
if found_entity:
return PairRelInstance(tokens, entity, entity_range, LABEL_MAP[label] if label is not None else None, text)
else:
return None
@staticmethod
def token_entity_match(first_token_idx, entity, tokens):
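# greedily consumes the (lower-cased) entity string against the WordPiece tokens starting at
# first_token_idx, treating '##' pieces as continuations of the current word; returns the number
# of tokens matched, or None if the entity string cannot be fully matched at this position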
token_idx = first_token_idx
remaining_entity = entity
while remaining_entity:
if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity:
# start of new word
remaining_entity = remaining_entity.lstrip()
if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]:
remaining_entity = remaining_entity[len(tokens[token_idx]):]
token_idx += 1
else:
break
else:
# continuing same word
if (token_idx < len(tokens) and tokens[token_idx].startswith('##')
and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]):
remaining_entity = remaining_entity[len(tokens[token_idx][2:]):]
token_idx += 1
else:
break
if remaining_entity:
return None
else:
return token_idx - first_token_idx
def __len__(self):
return len(self.df.index)
def __getitem__(self, idx):
return EntityDataset.instance_from_row(self.df.iloc[idx])
class PairRelInstance:
def __init__(self, tokens, entity, entity_range, label, text):
self.tokens = tokens
self.entity = entity
self.entity_range = entity_range
self.label = label
self.text = text
@@ -4,30 +4,30 @@ from transformers import *
HIDDEN_OUTPUT_FEATURES = 768
TRAINED_WEIGHTS = 'bert-base-uncased'
-NUM_CLASSES = 3 # no relation, fst hasFeature snd, snd hasFeature fst
+NUM_CLASSES = 2 # entity, not entity
-class PairBertNet(nn.Module):
+class EntityBertNet(nn.Module):
def __init__(self):
-super(PairBertNet, self).__init__()
+super(EntityBertNet, self).__init__()
config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
-self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES * 2, NUM_CLASSES)
+self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)
-def forward(self, input_ids, attn_mask, prod_indices, feat_indices):
+def forward(self, input_ids, attn_mask, entity_indices):
# BERT
bert_output, _ = self.bert_base(input_ids=input_ids, attention_mask=attn_mask)
# max pooling at entity locations
-fst_outputs = torch.gather(bert_output, dim=1, index=prod_indices)
-snd_outputs = torch.gather(bert_output, dim=1, index=feat_indices)
-fst_pooled_output, _ = torch.max(fst_outputs, dim=1)
-snd_pooled_output, _ = torch.max(snd_outputs, dim=1)
-# concat pooled outputs from prod and feat entities
-combined = torch.cat((fst_pooled_output, snd_pooled_output), dim=1)
+entity_pooled_output = EntityBertNet.pooled_output(bert_output, entity_indices)
# fc layer (softmax activation done in loss function)
-x = self.fc(combined)
+x = self.fc(entity_pooled_output)
return x
+@staticmethod
+def pooled_output(bert_output, indices):
+outputs = torch.gather(bert_output, dim=1, index=indices)
+pooled_output, _ = torch.max(outputs, dim=1)
+return pooled_output
@@ -8,12 +8,10 @@ import time
import numpy as np
from sklearn import metrics
from transformers import get_linear_schedule_with_warmup
-from agent.target_extraction.BERT.pair_rel_dataset import PairRelDataset, generate_batch, generate_production_batch
-from agent.target_extraction.BERT.pairbertnet import NUM_CLASSES, PairBertNet
+from agent.target_extraction.BERT.relation_extractor.pair_rel_dataset import PairRelDataset, generate_batch, generate_production_batch
+from agent.target_extraction.BERT.relation_extractor.pairbertnet import NUM_CLASSES, PairBertNet
-trained_model_path = 'trained_bert_rel_extractor_camera_backpack_laptop_bi_directional.pt'
device = torch.device('cuda')
-loss_criterion = CrossEntropyLoss()
# optimizer
DECAY_RATE = 0.01
@@ -21,10 +19,13 @@ LEARNING_RATE = 0.00002
MAX_GRAD_NORM = 1.0
# training
-N_EPOCHS = 3
+N_EPOCHS = 2
BATCH_SIZE = 16
WARM_UP_FRAC = 0.05
+# loss
+loss_criterion = CrossEntropyLoss()
class BertRelExtractor:
@@ -40,24 +41,27 @@ class BertRelExtractor:
return extr
@staticmethod
-def new_trained_with_file(file_path, size=None):
+def new_trained_with_file(file_path, save_path, size=None):
extractor = BertRelExtractor()
-extractor.train_with_file(file_path, size=size)
+extractor.train_with_file(file_path, save_path, size=size)
return extractor
@staticmethod
-def train_and_validate(file_path, size=None, valid_frac=None, valid_file_path=None):
+def train_and_validate(file_path, save_path, size=None, valid_frac=None, valid_file_path=None):
extractor = BertRelExtractor()
-extractor.train_with_file(file_path, size=size, valid_frac=valid_frac, valid_file_path=valid_file_path)
+extractor.train_with_file(file_path, save_path, size=size, valid_frac=valid_frac,
+valid_file_path=valid_file_path)
return extractor
-def train_with_file(self, file_path, size=None, valid_frac=None, valid_file_path=None):
+def train_with_file(self, file_path, save_path, size=None, valid_frac=None, valid_file_path=None):
# load training data
if valid_file_path is None:
train_data, valid_data = PairRelDataset.from_file(file_path, size=size, valid_frac=valid_frac)
else:
-train_data, _ = PairRelDataset.from_file(file_path, size=int(size * (1 - valid_frac)))
-valid_data, _ = PairRelDataset.from_file(valid_file_path, size=int(size * valid_frac))
+train_size = int(size * (1 - valid_frac)) if size is not None else None
+train_data, _ = PairRelDataset.from_file(file_path, size=train_size)
+valid_size = int(size * valid_frac) if size is not None else int(len(train_data) * valid_frac)
+valid_data, _ = PairRelDataset.from_file(valid_file_path, size=valid_size)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
collate_fn=generate_batch)
@@ -83,13 +87,13 @@ class BertRelExtractor:
for batch_idx, batch in enumerate(train_loader):
# send batch to gpu
-input_ids, attn_mask, fst_indices, snd_indices, target_labels = tuple(i.to(device) for i in batch)
+input_ids, attn_mask, masked_indices, fst_indices, snd_indices, target_labels = tuple(i.to(device) for i in batch)
# zero param gradients
optimiser.zero_grad()
# forward pass
-output_scores = self.net(input_ids, attn_mask, fst_indices, snd_indices)
+output_scores = self.net(input_ids, attn_mask, masked_indices, fst_indices, snd_indices)
# backward pass
loss = loss_criterion(output_scores, target_labels)
@@ -120,7 +124,7 @@ class BertRelExtractor:
end = time.time()
print('Training took', end - start, 'seconds')
-torch.save(self.net.state_dict(), trained_model_path)
+torch.save(self.net.state_dict(), save_path)
def evaluate(self, file_path=None, data=None, size=None):
# load eval data
@@ -143,10 +147,11 @@ class BertRelExtractor:
with torch.no_grad():
for batch in test_loader:
# send batch to gpu
-input_ids, attn_mask, fst_indices, snd_indices, target_labels = tuple(i.to(device) for i in batch)
+input_ids, attn_mask, masked_indices, fst_indices, snd_indices, target_labels = tuple(i.to(device)
+for i in batch)
# forward pass
-output_scores = self.net(input_ids, attn_mask, fst_indices, snd_indices)
+output_scores = self.net(input_ids, attn_mask, masked_indices, fst_indices, snd_indices)
_, output_labels = torch.max(output_scores.data, 1)
outputs += output_labels.tolist()
@@ -173,21 +178,20 @@ class BertRelExtractor:
def extract_single_relation(self, text, e1, e2):
ins = PairRelDataset.get_instance(text, e1, e2)
-input_ids, attn_mask, prod_indices, feat_indices, instances = generate_production_batch([ins])
+input_ids, attn_mask, masked_indices, prod_indices, feat_indices, instances = generate_production_batch([ins])
self.net.cuda()
self.net.eval()
with torch.no_grad():
embeddings = self.net.bert_base.get_input_embeddings()
print(torch.narrow(embeddings(input_ids.to(device)), 1, 1, 4))
# send batch to gpu
-input_ids, attn_mask, prod_indices, feat_indices = tuple(i.to(device) for i in [input_ids, attn_mask,
-prod_indices, feat_indices])
+input_ids, attn_mask, masked_indices, prod_indices, feat_indices = tuple(i.to(device) for i in
+[input_ids, attn_mask,
+masked_indices, prod_indices,
+feat_indices])
# forward pass
-output_scores = softmax(self.net(input_ids, attn_mask, prod_indices, feat_indices), dim=1)
+output_scores = softmax(self.net(input_ids, attn_mask, masked_indices, prod_indices, feat_indices), dim=1)
_, output_labels = torch.max(output_scores.data, 1)
print(instances[0].get_relation_for_label(output_labels[0]))
@@ -210,21 +214,22 @@ class BertRelExtractor:
outputs = []
with torch.no_grad():
-for input_ids, attn_mask, prod_indices, feat_indices, instances in loader:
+for input_ids, attn_mask, masked_indices, prod_indices, feat_indices, instances in loader:
# send batch to gpu
-input_ids, attn_mask, prod_indices, feat_indices = tuple(i.to(device) for i in [input_ids,
-attn_mask,
-prod_indices,
-feat_indices])
+input_ids, attn_mask, masked_indices, prod_indices, feat_indices = tuple(i.to(device) for i in
+[input_ids, attn_mask,
+masked_indices, prod_indices,
+feat_indices])
# forward pass
-output_scores = softmax(self.net(input_ids, attn_mask, prod_indices, feat_indices), dim=1)
+output_scores = softmax(self.net(input_ids, attn_mask, masked_indices, prod_indices, feat_indices), dim=1)
_, output_labels = torch.max(output_scores.data, 1)
outputs += map(lambda x: x[0].get_relation_for_label(x[1]), zip(instances, output_labels.tolist()))
for ins, scores, out in zip(instances, output_scores.tolist(), output_labels.tolist()):
print(ins.text)
print(ins.tokens)
print(scores)
print(ins.get_relation_for_label(out))
print('---')
@@ -232,8 +237,10 @@ class BertRelExtractor:
return outputs
-extr: BertRelExtractor = BertRelExtractor.train_and_validate('data/camera_backpack_laptop_review_pairs.tsv',
-size=5000,
+extr: BertRelExtractor = BertRelExtractor.train_and_validate('../data/camera_backpack_laptop_review_pairs_no_nan.tsv',
+'trained_bert_rel_extractor_camera_backpack_laptop_no_nan.pt',
+size=10000,
valid_frac=0.05,
valid_file_path='data/annotated_acoustic_guitar_review_pairs.tsv')
@@ -4,7 +4,7 @@ from transformers import BertTokenizer
import pandas as pd
import numpy as np
from ast import literal_eval
-from agent.target_extraction.BERT.pairbertnet import TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES
+from agent.target_extraction.BERT.relation_extractor.pairbertnet import TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES
MAX_SEQ_LEN = 128
RELATIONS = ['/has_feature', '/no_relation']
@@ -22,9 +22,11 @@ def generate_batch(batch):
labels = torch.tensor([instance.label for instance in batch])
both_ranges = [instance.ranges for instance in batch]
+masked_indices = torch.tensor([[ins_idx, token_idx] for ins_idx, ranges in enumerate(both_ranges)
+for start, end in ranges for token_idx in range(start, end + 1)])
fst_indices, snd_indices = map(indices_for_entity_ranges, zip(*both_ranges))
-return input_ids, attn_mask, fst_indices, snd_indices, labels
+return input_ids, attn_mask, masked_indices, fst_indices, snd_indices, labels
def generate_production_batch(batch):
@@ -35,9 +37,11 @@ def generate_production_batch(batch):
attn_mask = encoded['attention_mask']