Commit b84c4eb6 authored by Joel Oksanen

Experimentation with further improving feature and relation extraction

parent 71a5c117
Showing changed files with 934 additions and 215 deletions
@@ -5,6 +5,7 @@ server/agent/amazon_data/
server/agent/SA/data/
server/agent/target_extraction/data/
server/agent/target_extraction/BERT/data/
server/agent/target_extraction/eval/qa/
.DS_Store
*.pickle
*.wv
\ No newline at end of file
@@ -10,7 +10,7 @@ from sklearn import metrics
import statistics
from transformers import get_linear_schedule_with_warmup
from agent.target_extraction.BERT.entity_extractor.entity_dataset import EntityDataset, generate_batch, generate_production_batch
from agent.target_extraction.BERT.entity_extractor.entitybertnet import NUM_CLASSES, EntityBertNet
from agent.target_extraction.BERT.entity_extractor.entitybertnet import NUM_CLASSES, EntityBertNet, BATCH_SIZE
device = torch.device('cuda')
@@ -21,7 +21,6 @@ MAX_GRAD_NORM = 1.0
# training
N_EPOCHS = 3
BATCH_SIZE = 32
WARM_UP_FRAC = 0.05
# loss
@@ -61,8 +60,7 @@ class BertEntityExtractor:
else:
train_size = int(size * (1 - valid_frac)) if size is not None else None
train_data, _ = EntityDataset.from_file(file_path, size=train_size)
valid_size = int(size * valid_frac) if size is not None else int(len(train_data) * valid_frac)
valid_data, _ = EntityDataset.from_file(valid_file_path, size=valid_size)
valid_data, _ = EntityDataset.from_file(valid_file_path)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
collate_fn=generate_batch)
@@ -119,11 +117,11 @@ class BertEntityExtractor:
print('epoch done')
torch.save(self.net.state_dict(), '{}_epoch_{}.pt'.format(save_file, epoch_idx + 1))
if valid_data is not None:
self.evaluate(data=valid_data)
torch.save(self.net.state_dict(), '{}.pt'.format(save_file))
end = time.time()
print('Training took', end - start, 'seconds')
@@ -207,3 +205,7 @@ class BertEntityExtractor:
probs[ins.entity].append(score)
return {t: statistics.mean(t_probs) if len(t_probs) > 0 else None for t, t_probs in probs.items()}
BertEntityExtractor.train_and_validate('all_reviews_features.tsv', 'feature_extractor',
valid_file_path='annotated_watch_review_features.tsv')
@@ -8,58 +8,22 @@ import os.path
from agent.target_extraction.BERT.relation_extractor.pairbertnet import TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES
MAX_SEQ_LEN = 128
LABELS = ['ASPECT', 'NAN']
LABEL_MAP = {'ASPECT': 1, 'NAN': 0, None: None}
MASK_TOKEN = '[MASK]'
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)
def generate_batch(batch):
encoded = tokenizer.batch_encode_plus([instance.tokens for instance in batch], add_special_tokens=True,
max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
return_tensors='pt')
input_ids = encoded['input_ids']
attn_mask = encoded['attention_mask']
labels = torch.tensor([instance.label for instance in batch])
entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch])
return input_ids, attn_mask, entity_indices, labels
def generate_production_batch(batch):
encoded = tokenizer.batch_encode_plus([instance.tokens for instance in batch], add_special_tokens=True,
max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
return_tensors='pt')
input_ids = encoded['input_ids']
attn_mask = encoded['attention_mask']
entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch])
return input_ids, attn_mask, entity_indices, batch
def indices_for_entity_ranges(ranges):
max_e_len = max(end - start for start, end in ranges)
indices = torch.tensor([[[min(t, end)] * HIDDEN_OUTPUT_FEATURES
for t in range(start, start + max_e_len + 1)]
for start, end in ranges])
return indices
class EntityDataset(Dataset):
def __init__(self, df, size=None):
# filter inapplicable rows
self.df = df[df.apply(lambda x: EntityDataset.instance_from_row(x) is not None, axis=1)]
def __init__(self, df, training=True, size=None):
self.df = df
self.training = training
# sample data if a size is specified
if size is not None and size < len(self):
self.df = self.df.sample(size, replace=False)
@staticmethod
def from_df(df, size=None):
dataset = EntityDataset(df, size=size)
def for_extraction(df):
dataset = EntityDataset(df, training=False)
print('Obtained dataset of size', len(dataset))
return dataset
@@ -83,80 +47,60 @@ class EntityDataset(Dataset):
print('Obtained train set of size', len(dataset), 'and validation set of size', len(validset))
return dataset, validset
@staticmethod
def instance_from_row(row):
unpacked_arr = literal_eval(row['entityMentions']) if type(row['entityMentions']) is str else row['entityMentions']
rms = [rm for rm in unpacked_arr if 'label' not in rm or rm['label'] in LABELS]
if len(rms) == 1:
entity, label = rms[0]['text'], (rms[0]['label'] if 'label' in rms[0] else None)
else:
return None # raise AttributeError('Instances must have exactly one relation')
text = row['sentText']
return EntityDataset.get_instance(text, entity, label=label)
@staticmethod
def get_instance(text, entity, label=None):
tokens = tokenizer.tokenize(text)
i = 0
found_entity = False
entity_range = None
while i < len(tokens):
match_length = EntityDataset.token_entity_match(i, entity.lower(), tokens)
if match_length is not None:
if found_entity:
return None # raise AttributeError('Entity {} appears twice in text {}'.format(entity, text))
found_entity = True
tokens[i:i + match_length] = [MASK_TOKEN] * match_length
entity_range = (i + 1, i + match_length) # + 1 taking into account the [CLS] token
i += match_length
else:
i += 1
if found_entity:
return PairRelInstance(tokens, entity, entity_range, LABEL_MAP[label], text)
def instance_from_row(self, row):
if self.training:
return EntityInstance(literal_eval(row['tokens']),
row['entity_idx'],
label=row['label'])
else:
return None
@staticmethod
def token_entity_match(first_token_idx, entity, tokens):
token_idx = first_token_idx
remaining_entity = entity
while remaining_entity:
if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity:
# start of new word
remaining_entity = remaining_entity.lstrip()
if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]:
remaining_entity = remaining_entity[len(tokens[token_idx]):]
token_idx += 1
else:
break
else:
# continuing same word
if (token_idx < len(tokens) and tokens[token_idx].startswith('##')
and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]):
remaining_entity = remaining_entity[len(tokens[token_idx][2:]):]
token_idx += 1
else:
break
if remaining_entity:
return None
else:
return token_idx - first_token_idx
return EntityInstance(row['tokens'],
row['entity_idx'],
entity=row['entity'])
def __len__(self):
return len(self.df.index)
def __getitem__(self, idx):
return EntityDataset.instance_from_row(self.df.iloc[idx])
return self.instance_from_row(self.df.iloc[idx])
class PairRelInstance:
class EntityInstance:
def __init__(self, tokens, entity, entity_range, label, text):
def __init__(self, tokens, entity_idx, label=None, entity=None):
self.tokens = tokens
self.entity = entity
self.entity_range = entity_range
self.entity_idx = entity_idx
self.label = label
self.text = text
self.entity = entity
def generate_batch(instances: [EntityInstance]):
encoded = tokenizer.batch_encode_plus([instance.tokens for instance in instances], add_special_tokens=True,
max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
return_tensors='pt')
input_ids = encoded['input_ids']
attn_mask = encoded['attention_mask']
entity_indices = torch.tensor([instance.entity_idx for instance in instances])
labels = torch.tensor([instance.label for instance in instances])
return input_ids, attn_mask, entity_indices, labels
def generate_production_batch(instances: [EntityInstance]):
encoded = tokenizer.batch_encode_plus([instance.tokens for instance in instances], add_special_tokens=True,
max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
return_tensors='pt')
input_ids = encoded['input_ids']
attn_mask = encoded['attention_mask']
entity_indices = torch.tensor([instance.entity_idx for instance in instances])
return input_ids, attn_mask, entity_indices, instances
# def indices_for_entity_ranges(ranges):
# max_e_len = max(end - start for start, end in ranges)
# indices = torch.tensor([[[min(t, end)] * HIDDEN_OUTPUT_FEATURES
# for t in range(start, start + max_e_len + 1)]
# for start, end in ranges])
# return indices
@@ -5,6 +5,7 @@ from transformers import *
HIDDEN_OUTPUT_FEATURES = 768
TRAINED_WEIGHTS = 'bert-base-uncased'
NUM_CLASSES = 2 # entity, not entity
BATCH_SIZE = 32
class EntityBertNet(nn.Module):
@@ -20,14 +21,9 @@ class EntityBertNet(nn.Module):
bert_output, _ = self.bert_base(input_ids=input_ids, attention_mask=attn_mask)
# max pooling at entity locations
entity_pooled_output = EntityBertNet.pooled_output(bert_output, entity_indices)
entity_pooled_output = bert_output[torch.arange(0, bert_output.shape[0]), entity_indices]
# fc layer (softmax activation done in loss function)
x = self.fc(entity_pooled_output)
return x
@staticmethod
def pooled_output(bert_output, indices):
outputs = torch.gather(bert_output, dim=1, index=indices)
pooled_output, _ = torch.max(outputs, dim=1)
return pooled_output
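For illustration, a minimal sketch (toy tensors, not part of the commit) of the new single-index selection that replaces the gather-and-max pooling removed above:

import torch

batch_size, seq_len, hidden = 2, 5, 4
bert_output = torch.randn(batch_size, seq_len, hidden)   # stand-in for the BERT hidden states
entity_indices = torch.tensor([1, 3])                     # one entity position per sequence

# advanced indexing: row i takes bert_output[i, entity_indices[i], :]
entity_pooled_output = bert_output[torch.arange(0, batch_size), entity_indices]
print(entity_pooled_output.shape)  # torch.Size([2, 4])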
@@ -8,8 +8,10 @@ import time
import numpy as np
from sklearn import metrics
from transformers import get_linear_schedule_with_warmup
from agent.target_extraction.BERT.relation_extractor.pair_rel_dataset import PairRelDataset, generate_batch, generate_production_batch
from agent.target_extraction.BERT.relation_extractor.pairbertnet import NUM_CLASSES, PairBertNet
# from agent.target_extraction.BERT.relation_extractor.pair_rel_dataset import PairRelDataset, generate_batch, generate_production_batch
from agent.target_extraction.BERT.relation_extractor.rel_dataset import PairRelDataset, generate_batch, generate_production_batch, RelInstance
# from agent.target_extraction.BERT.relation_extractor.pairbertnet import NUM_CLASSES, PairBertNet
from agent.target_extraction.BERT.relation_extractor.relbertnet import NUM_CLASSES, RelBertNet
device = torch.device('cuda')
@@ -30,12 +32,12 @@ loss_criterion = CrossEntropyLoss()
class BertRelExtractor:
def __init__(self):
self.net = PairBertNet()
self.net = RelBertNet()
@staticmethod
def load_saved(path):
extr = BertRelExtractor()
extr.net = PairBertNet()
extr.net = RelBertNet()
extr.net.load_state_dict(torch.load(path))
extr.net.eval()
return extr
@@ -60,8 +62,7 @@ class BertRelExtractor:
else:
train_size = int(size * (1 - valid_frac)) if size is not None else None
train_data, _ = PairRelDataset.from_file(file_path, size=train_size)
valid_size = int(size * valid_frac) if size is not None else int(len(train_data) * valid_frac)
valid_data, _ = PairRelDataset.from_file(valid_file_path, size=valid_size)
valid_data, _ = PairRelDataset.from_file(valid_file_path)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
collate_fn=generate_batch)
@@ -87,16 +88,16 @@ class BertRelExtractor:
for batch_idx, batch in enumerate(train_loader):
# send batch to gpu
input_ids, attn_mask, masked_indices, fst_indices, snd_indices, target_labels = tuple(i.to(device) for i in batch)
input_ids, attn_mask, entity_indices, entity_mask, labels = tuple(i.to(device) for i in batch)
# zero param gradients
optimiser.zero_grad()
# forward pass
output_scores = self.net(input_ids, attn_mask, masked_indices, fst_indices, snd_indices)
output_scores = self.net(input_ids, attn_mask, entity_indices, entity_mask)
# backward pass
loss = loss_criterion(output_scores, target_labels)
loss = loss_criterion(output_scores, labels)
loss.backward()
# clip gradient norm
@@ -117,12 +118,11 @@ class BertRelExtractor:
batch_loss = 0.0
print('epoch done')
torch.save(self.net.state_dict(), '{}_epoch_{}.pt'.format(save_file, epoch_idx + 1))
if valid_data is not None:
self.evaluate(data=valid_data)
torch.save(self.net.state_dict(), '{}.pt'.format(save_file))
end = time.time()
print('Training took', end - start, 'seconds')
@@ -147,15 +147,14 @@ class BertRelExtractor:
with torch.no_grad():
for batch in test_loader:
# send batch to gpu
input_ids, attn_mask, masked_indices, fst_indices, snd_indices, target_labels = tuple(i.to(device)
for i in batch)
input_ids, attn_mask, entity_indices, entity_mask, labels = tuple(i.to(device) for i in batch)
# forward pass
output_scores = self.net(input_ids, attn_mask, masked_indices, fst_indices, snd_indices)
output_scores = self.net(input_ids, attn_mask, entity_indices, entity_mask)
_, output_labels = torch.max(output_scores.data, 1)
outputs += output_labels.tolist()
targets += target_labels.tolist()
targets += labels.tolist()
assert len(outputs) == len(targets)
@@ -176,25 +175,24 @@ class BertRelExtractor:
recall = metrics.recall_score(targets, outputs, average=None)
print('recall:', recall)
def extract_single_relation(self, text, e1, e2):
ins = PairRelDataset.get_instance(text, e1, e2)
input_ids, attn_mask, masked_indices, prod_indices, feat_indices, instances = generate_production_batch([ins])
def extract_single_relation(self, text, entities):
ins = RelInstance.from_sentence(text, entities)
input_ids, attn_mask, entity_indices, entity_mask, _ = generate_production_batch([ins])
self.net.cuda()
self.net.eval()
with torch.no_grad():
# send batch to gpu
input_ids, attn_mask, masked_indices, prod_indices, feat_indices = tuple(i.to(device) for i in
[input_ids, attn_mask,
masked_indices, prod_indices,
feat_indices])
input_ids, attn_mask, entity_indices, entity_mask = tuple(i.to(device) for i in [input_ids, attn_mask,
entity_indices,
entity_mask])
# forward pass
output_scores = softmax(self.net(input_ids, attn_mask, masked_indices, prod_indices, feat_indices), dim=1)
output_scores = softmax(self.net(input_ids, attn_mask, entity_indices, entity_mask), dim=1)
_, output_labels = torch.max(output_scores.data, 1)
print(instances[0].get_relation_for_label(output_labels[0]))
ins.print_results_for_labels(output_labels)
def extract_relations(self, n_aspects, aspect_index_map, aspect_counts, file_path=None, dataset=None, size=None):
# load data
@@ -215,15 +213,14 @@ class BertRelExtractor:
count_matrix = np.zeros((n_aspects, n_aspects))
with torch.no_grad():
for input_ids, attn_mask, masked_indices, prod_indices, feat_indices, instances in loader:
for input_ids, attn_mask, prod_indices, feat_indices, instances in loader:
# send batch to gpu
input_ids, attn_mask, masked_indices, prod_indices, feat_indices = tuple(i.to(device) for i in
[input_ids, attn_mask,
masked_indices, prod_indices,
feat_indices])
input_ids, attn_mask, prod_indices, feat_indices = tuple(i.to(device) for i in [input_ids, attn_mask,
prod_indices,
feat_indices])
# forward pass
output_scores = softmax(self.net(input_ids, attn_mask, masked_indices, prod_indices, feat_indices), dim=1)
output_scores = softmax(self.net(input_ids, attn_mask, prod_indices, feat_indices), dim=1)
rel_scores = output_scores.narrow(1, 1, 2)
for ins, scores in zip(instances, rel_scores.tolist()):
@@ -236,4 +233,38 @@ class BertRelExtractor:
return prob_matrix, count_matrix
def extract_relations2(self, n_aspects, dataset):
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
collate_fn=generate_production_batch)
self.net.cuda()
self.net.eval()
prob_matrix = np.zeros((n_aspects, n_aspects))
count_matrix = np.zeros((n_aspects, n_aspects))
with torch.no_grad():
for input_ids, attn_mask, entity_indices, combination_indices, instances in loader:
# send batch to gpu
input_ids, attn_mask, entity_indices, combination_indices = tuple(i.to(device) for i in
[input_ids, attn_mask,
entity_indices, combination_indices])
# forward pass
output_scores = softmax(self.net(input_ids, attn_mask, entity_indices, combination_indices), dim=1)
rel_scores = output_scores.narrow(1, 1, 2).tolist()
entity_pairs = [ep for instance in instances for ep in instance.entity_pairs]
for ep, scores in zip(entity_pairs, rel_scores):
forward_score, backward_score = scores
prob_matrix[ep.snd.idx][ep.fst.idx] += forward_score
prob_matrix[ep.fst.idx][ep.snd.idx] += backward_score
count_matrix[ep.snd.idx][ep.fst.idx] += 1
count_matrix[ep.fst.idx][ep.snd.idx] += 1
return prob_matrix, count_matrix
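One plausible post-processing step (not shown in this commit) for turning the accumulated matrices into average relation probabilities, guarding against aspect pairs that never co-occur:

import numpy as np

prob_matrix = np.array([[0.0, 1.4], [0.2, 0.0]])    # toy accumulated relation scores
count_matrix = np.array([[0, 2], [2, 0]])            # toy co-occurrence counts

# divide only where a pair was actually observed; leave unseen pairs at zero
avg_matrix = np.divide(prob_matrix, count_matrix,
                       out=np.zeros_like(prob_matrix), where=count_matrix > 0)
print(avg_matrix)  # [[0.   0.7 ] [0.1  0.  ]]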
# extr: BertRelExtractor = BertRelExtractor.load_saved('multi_extractor_5_products_epoch_1.pt')
# extr.extract_single_relation('The mixer comes with a stainless steel bowl.',
# ['mixer', 'stainless steel', 'bowl'])
@@ -4,7 +4,7 @@ from transformers import *
HIDDEN_OUTPUT_FEATURES = 768
TRAINED_WEIGHTS = 'bert-base-uncased'
NUM_CLASSES = 3 # no relation, fst hasFeature snd, snd hasFeature fst
NUM_CLASSES = 4 # no relation, fst hasFeature snd, snd hasFeature fst, siblings
HIDDEN_ENTITY_FEATURES = 6 # lower -> more general but less informative entity representations
@@ -18,18 +18,7 @@ class PairBertNet(nn.Module):
self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES * 2, NUM_CLASSES)
def forward(self, input_ids, attn_mask, masked_indices, fst_indices, snd_indices):
# embeddings = self.bert_base.get_input_embeddings()
# input_embeddings = embeddings(input_ids)
#
# # get partially masked input_embeddings for entity terms
# unmasked_entity_embeddings = input_embeddings[masked_indices[:, 0], masked_indices[:, 1]]
# hidden_entity_repr = torch.tanh(self.entity_fc1(unmasked_entity_embeddings))
# masked_entity_embeddings = torch.repeat_interleave(hidden_entity_repr, 128, dim=1) # 768 / 12 = 64
#
# # replace input_embeddings with partially masked ones for entities
# input_embeddings[masked_indices[:, 0], masked_indices[:, 1]] = masked_entity_embeddings
def forward(self, input_ids, attn_mask, fst_indices, snd_indices):
# BERT
bert_output, _ = self.bert_base(input_ids=input_ids, attention_mask=attn_mask)
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
import pandas as pd
import numpy as np
from ast import literal_eval
from agent.target_extraction.BERT.relation_extractor.relbertnet import TRAINED_WEIGHTS, MAX_SEQ_LEN, MAX_ENTITIES
import os
MASK_TOKEN = '[MASK]'
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)
def generate_batch(batch):
encoded = tokenizer.batch_encode_plus([instance.tokens for instance in batch], add_special_tokens=True,
max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
return_tensors='pt')
input_ids = encoded['input_ids']
attn_mask = encoded['attention_mask']
entity_indices = torch.tensor(list(map(indices_for_instance, batch)))
entity_mask = torch.tensor([[n < instance.get_count() for n in range(MAX_ENTITIES)] for instance in batch])
labels = torch.tensor([e.label for instance in batch for e in instance.entities])
return input_ids, attn_mask, entity_indices, entity_mask, labels
def generate_production_batch(batch):
encoded = tokenizer.batch_encode_plus([instance.tokens for instance in batch], add_special_tokens=True,
max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
return_tensors='pt')
input_ids = encoded['input_ids']
attn_mask = encoded['attention_mask']
entity_indices = torch.tensor(list(map(indices_for_instance, batch)))
entity_mask = torch.tensor([[n < instance.get_count() for n in range(MAX_ENTITIES)] for instance in batch])
return input_ids, attn_mask, entity_indices, entity_mask, batch
def indices_for_instance(instance):
indices = [[instance.entities[n].rng[0] if i < instance.entities[n].rng[0] else min(instance.entities[n].rng[1], i)
for i in range(MAX_SEQ_LEN)]
if n < len(instance.entities) else [0] * MAX_SEQ_LEN
for n in range(MAX_ENTITIES)]
return indices
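A small-scale sketch of what indices_for_instance produces (MAX_SEQ_LEN and MAX_ENTITIES shrunk to 8 and 3 for readability): each row clamps every sequence position into one entity's token range, so the later gather-and-max in RelBertNet pools only over that entity, while unused entity slots are padded with zeros.

def toy_indices(entity_ranges, max_seq_len=8, max_entities=3):
    rows = []
    for n in range(max_entities):
        if n < len(entity_ranges):
            start, end = entity_ranges[n]
            rows.append([start if i < start else min(end, i) for i in range(max_seq_len)])
        else:
            rows.append([0] * max_seq_len)  # padding row for an unused entity slot
    return rows

print(toy_indices([(2, 4), (6, 6)]))
# [[2, 2, 2, 3, 4, 4, 4, 4], [6, 6, 6, 6, 6, 6, 6, 6], [0, 0, 0, 0, 0, 0, 0, 0]]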
class PairRelDataset(Dataset):
def __init__(self, df, training=True, size=None):
self.df = df
self.training = training
# sample data if a size is specified
if size is not None and size < len(self):
self.df = self.df.sample(size, replace=False)
@staticmethod
def for_extraction(df):
dataset = PairRelDataset(df, training=False)
print('Obtained dataset of size', len(dataset))
return dataset
@staticmethod
def from_file(file_name, valid_frac=None, size=None):
f = open(os.path.dirname(__file__) + '/../data/' + file_name)
dataset = PairRelDataset(pd.read_csv(f, sep='\t', error_bad_lines=False), size=size)
if valid_frac is None:
print('Obtained dataset of size', len(dataset))
return dataset, None
else:
split_idx = int(len(dataset) * (1 - valid_frac))
dataset.df, valid_df = np.split(dataset.df, [split_idx], axis=0)
validset = PairRelDataset(valid_df)
print('Obtained train set of size', len(dataset), 'and validation set of size', len(validset))
return dataset, validset
def instance_from_row(self, row):
if self.training:
return RelInstance(literal_eval(row['tokens']),
literal_eval(row['entity_ranges']),
true_labels=literal_eval(row['labels']))
else:
return RelInstance(row['tokens'],
row['entity_ranges'],
entity_labels=row['entity_labels'])
def __len__(self):
return len(self.df.index)
def __getitem__(self, idx):
return self.instance_from_row(self.df.iloc[idx])
class RelInstance:
def __init__(self, tokens, entity_ranges, true_labels=None, entity_labels=None, entity_texts=None):
self.tokens = tokens
self.entities = [Entity(rng,
label=(true_labels[n] if true_labels else None),
idx=(entity_labels[n] if entity_labels else None),
text=(entity_texts[n] if entity_texts else None))
for n, rng in enumerate(entity_ranges)]
print(self.tokens)
print(entity_ranges)
def get_count(self):
return len(self.entities)
def print_results_for_labels(self, labels):
assert len(labels) == len(self.entities)
label_map = ['not an aspect', 'aspect', 'sub-feature']
for e, l in zip(self.entities, labels):
print('{}: {}'.format(e.text, label_map[l]))
@staticmethod
def from_sentence(text, entities):
def token_entity_match(first_token_idx, entity, tokens):
token_idx = first_token_idx
remaining_entity = entity
while remaining_entity:
if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity:
# start of new word
remaining_entity = remaining_entity.lstrip()
if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]:
remaining_entity = remaining_entity[len(tokens[token_idx]):]
token_idx += 1
else:
break
else:
# continuing same word
if (token_idx < len(tokens) and tokens[token_idx].startswith('##')
and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]):
remaining_entity = remaining_entity[len(tokens[token_idx][2:]):]
token_idx += 1
else:
break
if remaining_entity or (token_idx < len(tokens) and tokens[token_idx].startswith('##')):
return None
else:
return token_idx - first_token_idx
tokens = tokenizer.tokenize(text)
i = 0
entity_ranges = []
while i < len(tokens):
match = False
# check for aspects
for e in entities:
match_length = token_entity_match(i, e.lower(), tokens)
if match_length is not None:
entity_ranges.append((e, (i + 1, i + match_length))) # + 1 taking into account the [CLS] token
match = True
i += match_length
break
if not match:
i += 1
if len(entity_ranges) == 0 or len(entity_ranges) > 3:
return None
# mask entity mentions
for _, (start, end) in entity_ranges:
tokens[(start - 1):end] = ['[MASK]'] * (end - (start - 1))
texts, ranges = zip(*entity_ranges)
return RelInstance(tokens, ranges, entity_texts=texts)
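For context, a quick look (output is hypothetical and depends on the BERT vocabulary) at the WordPiece continuations that token_entity_match has to bridge when aligning entity strings with the tokenised sentence:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# multi-word entities span several pieces, and out-of-vocabulary words are split into
# '##' continuation pieces, so the matcher consumes the entity string piece by piece
print(tokenizer.tokenize('The mixer comes with a stainless steel bowl.'))
# e.g. ['the', 'mixer', 'comes', 'with', 'a', 'stainless', 'steel', 'bowl', '.']
print(tokenizer.tokenize('wristband'))
# e.g. ['wrist', '##band']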
class Entity:
def __init__(self, rng, label=None, idx=None, text=None):
self.rng = rng
self.label = label
self.idx = idx
self.text = text
import torch
import torch.nn as nn
from transformers import *
TRAINED_WEIGHTS = 'bert-base-uncased'
HIDDEN_OUTPUT_FEATURES = 768
MAX_SEQ_LEN = 128
NUM_CLASSES = 3 # no relation, fst hasFeature snd, snd hasFeature fst
MAX_ENTITIES = 3
class RelBertNet(nn.Module):
def __init__(self):
super(RelBertNet, self).__init__()
config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)
def forward(self, input_ids, attn_mask, entity_indices, entity_mask):
# BERT
bert_output, _ = self.bert_base(input_ids=input_ids, attention_mask=attn_mask)
# obtain entity combinations
combinations = RelBertNet.entity_combinations(bert_output, entity_indices, entity_mask)
# fc layer (softmax activation done in loss function)
x = self.fc(combinations)
return x
@staticmethod
def entity_combinations(bert_output, entity_indices, entity_mask):
# pool outputs
bert_output_exp = bert_output.unsqueeze(1).repeat(1, MAX_ENTITIES, 1, 1)
indices_exp = entity_indices.unsqueeze(3).repeat(1, 1, 1, HIDDEN_OUTPUT_FEATURES)
outputs = torch.gather(bert_output_exp, dim=2, index=indices_exp)
pooled_outputs, _ = torch.max(outputs, dim=2)
# pooled_outputs = torch.flatten(pooled_outputs, start_dim=0, end_dim=1)
return pooled_outputs[entity_mask]
# b_output = torch.randn((2, MAX_SEQ_LEN, HIDDEN_OUTPUT_FEATURES))
# e_indices = torch.tensor([[[0, 1, 1, 1, 1],
# [2, 2, 2, 2, 2],
# [3, 3, 3, 3, 3]],
#
# [[0, 1, 1, 1, 1],
# [2, 2, 2, 2, 2],
# [4, 4, 4, 4, 4]]])
#
# entity_mask = torch.tensor([[True, False, False], [True, True, True]])
#
# print(RelBertNet.entity_combinations(b_output, e_indices, entity_mask))
@@ -12,11 +12,17 @@ import readchar
import random
from sty import fg, bg
from anytree import Node, RenderTree, LevelOrderIter, PreOrderIter
from itertools import combinations, repeat
from pathos.multiprocessing import ProcessingPool as Pool
from transformers import BertTokenizer
from agent.target_extraction.BERT.relation_extractor.relbertnet import TRAINED_WEIGHTS, MAX_ENTITIES
PHRASE_THRESHOLD = 4
ROW_CHARACTER_COUNT = 100
stop_words = stopwords.words('english')
ann_bgs = [bg.blue, bg.red] # child, parent
pool = Pool(4)
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)
class EntityAnnotator:
@@ -212,24 +218,55 @@ class EntityAnnotator:
texts = [text for _, par in reviews['reviewText'].items() if not pd.isnull(par)
for text in sent_tokenize(par)]
pair_texts = [t for t in map(lambda t: self.pair_relations_for_text(t), texts)
pair_texts = [t for t in map(lambda t: self.pair_relation_for_text(t), texts)
if t is not None]
df = pd.DataFrame(pair_texts, columns=['sentText', 'relationMentions'])
df.to_csv(save_path, sep='\t', index=False)
def save_annotated_entities(self, save_path):
def save_annotated_pairs2(self, save_path, n):
reviews = pd.read_csv(self.text_file_path, sep='\t', error_bad_lines=False)
texts = [line for _, par in reviews['reviewText'].items() if not pd.isnull(par)
for sent in sent_tokenize(par) for line in sent.splitlines()]
instances = []
idx = 0
while len(instances) < n and idx <= len(texts):
texts_sub = texts[idx:idx+20000]
idx += 20000
instances += filter(lambda i: i is not None, pool.map(relation_instances_for_text,
repeat(tokenizer, len(texts_sub)),
repeat(self.root, len(texts_sub)),
repeat(self.synset, len(texts_sub)),
texts_sub))
print(len(instances))
instances = instances[:n]
df = pd.DataFrame(instances, columns=['tokens', 'entity_ranges', 'labels'])
df.to_csv(save_path, sep='\t', index=False)
def save_annotated_entities(self, save_path, n):
reviews = pd.read_csv(self.text_file_path, sep='\t', error_bad_lines=False)
texts = [text for _, par in reviews['reviewText'].items() if not pd.isnull(par)
for text in sent_tokenize(par)]
all_entities = {(e, True) for e in self.get_annotated_entities()}.union(
{(e, False) for e in self.get_nan_entities()})
entity_texts = [t for t in map(lambda t: self.entity_mentions_in_text(t, all_entities), texts)
if t is not None]
df = pd.DataFrame(entity_texts, columns=['sentText', 'entityMentions'])
product_entities = {s.lower() for s in self.synset[self.root]}
other_entities = {(e.lower(), True) for e in self.get_annotated_features()}.union(
{(e.lower(), False) for e in self.get_nan_entities()})
instances = []
idx = 0
while len(instances) < n and idx <= len(texts):
texts_sub = texts[idx:idx + 20000]
idx += 20000
instances += filter(lambda i: i is not None, pool.map(entity_instances_for_text,
repeat(tokenizer, len(texts_sub)),
repeat(product_entities, len(texts_sub)),
repeat(other_entities, len(texts_sub)),
texts_sub))
print(len(instances))
df = pd.DataFrame(instances, columns=['tokens', 'entity_idx', 'label'])
df.to_csv(save_path, sep='\t', index=False)
@staticmethod
@@ -247,7 +284,7 @@ class EntityAnnotator:
random.shuffle(m)
return {'em1Text': m[0], 'em2Text': m[1], 'label': '/no_relation'}
def pair_relations_for_text(self, text, nan_entities=None):
def pair_relation_for_text(self, text, nan_entities=None):
single_tokens = word_tokenize(text)
tagged_single = pos_tag(single_tokens)
tagged_all = set().union(*[tagged_single, pos_tag(self.phraser[single_tokens])])
@@ -327,6 +364,9 @@ class EntityAnnotator:
def get_annotated_entities(self):
return {syn.lower() for n in PreOrderIter(self.root) for syn in self.synset[n]}
def get_annotated_features(self):
return {syn.lower() for n in self.root.descendants for syn in self.synset[n]}
def get_nan_entities(self):
annotated = self.get_annotated_entities()
return {t.replace('_', ' ').lower() for t, _ in self.counter.most_common(self.n_annotated)
@@ -358,6 +398,220 @@ class EntityAnnotator:
return text, rels
ea: EntityAnnotator = EntityAnnotator.load_saved('annotators/watch_annotator.pickle')
ea.save_annotated_pairs('BERT/data/annotated_watch_review_pairs.tsv')
ea.save_annotated_entities('BERT/data/annotated_watch_review_entities.tsv')
def relation_instances_for_text(tokenizer, root, synset, text):
def joined_tokens(tokens, entity_ranges):
joined = []
j_token = tokens[0]
start = 0
for idx, t in enumerate(tokens):
if idx == 0:
continue
if t.startswith('##'):
# continuing same word
j_token += t[2:]
elif any(idx in r and idx-1 in r for r in entity_ranges):
# continuing same multi-word entity
j_token = j_token + " " + t
else:
# new word
joined.append((j_token, start, idx))
j_token = t
start = idx
if j_token:
joined.append((j_token, start, len(tokens)))
return joined
def noun_entity_mentions(tokens, entity_mentions):
entity_ranges = [range(em[2][0]-1, em[2][1]) for em in entity_mentions]
joined = joined_tokens(tokens, entity_ranges)
tags = [tag for _, tag in pos_tag([t for t, _, _ in joined])]
noun_ranges = [range(start, end) for idx, (_, start, end) in enumerate(joined) if tags[idx].startswith('NN')]
return [em for idx, em in enumerate(entity_mentions) if entity_ranges[idx] in noun_ranges]
def token_entity_match(first_token_idx, entity, tokens):
token_idx = first_token_idx
remaining_entity = entity
while remaining_entity:
if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity:
# start of new word
remaining_entity = remaining_entity.lstrip()
if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]:
remaining_entity = remaining_entity[len(tokens[token_idx]):]
token_idx += 1
else:
break
else:
# continuing same word
if (token_idx < len(tokens) and tokens[token_idx].startswith('##')
and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]):
remaining_entity = remaining_entity[len(tokens[token_idx][2:]):]
token_idx += 1
else:
break
if remaining_entity or (token_idx < len(tokens) and tokens[token_idx].startswith('##')):
return None
else:
return token_idx - first_token_idx
def get_rel_label(fst_m, snd_m):
fst_n, _, _ = fst_m
snd_n, _, _ = snd_m
if snd_n in fst_n.descendants:
return 1
elif fst_n in snd_n.descendants:
return 2
else:
return 0
tokens = tokenizer.tokenize(text)
i = 0
entity_mentions = []
while i < len(tokens):
match = False
for n in PreOrderIter(root):
for syn in synset[n]:
match_length = token_entity_match(i, syn.lower(), tokens)
if match_length is not None:
if any(em[0] == n for em in entity_mentions):
# sentence cannot mention same aspect twice
return None
entity_mentions.append((n, syn, (i + 1, i + match_length))) # + 1 taking into account the [CLS] token
match = True
i += match_length
break
if match:
break
if not match:
i += 1
if len(entity_mentions) < 2:
return None
# filter out non-nouns
entity_mentions = noun_entity_mentions(tokens, entity_mentions)
if len(entity_mentions) < 2 or len(entity_mentions) > MAX_ENTITIES:
return None
# mask entity mentions
for _, _, (start, end) in entity_mentions:
tokens[(start-1):end] = ['[MASK]'] * (end-(start-1))
entity_mentions = sorted(entity_mentions, key=lambda em: em[2])
entity_ranges = [em[2] for em in entity_mentions]
labels = {(i, j): get_rel_label(entity_mentions[i], entity_mentions[j])
for i, j in combinations(range(len(entity_mentions)), 2)}
return tokens, entity_ranges, labels
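A hand-constructed example (assuming 'watch' is the root aspect and 'strap' one of its annotated descendant features) of the tuple relation_instances_for_text would return for the sentence "The watch has a leather strap":

tokens = ['the', '[MASK]', 'has', 'a', 'leather', '[MASK]']   # both aspect mentions masked
entity_ranges = [(2, 2), (6, 6)]   # 1-based ranges, offset by one for the [CLS] token
labels = {(0, 1): 1}               # 1 = first entity hasFeature second entity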
def entity_instances_for_text(tokenizer, product_entities, other_entities, text):
def joined_tokens(tokens, entity_ranges):
joined = []
j_token = tokens[0]
start = 0
for idx, t in enumerate(tokens):
if idx == 0:
continue
if t.startswith('##'):
# continuing same word
j_token += t[2:]
elif any(idx in r and idx-1 in r for r in entity_ranges):
# continuing same multi-word entity
j_token = j_token + " " + t
else:
# new word
joined.append((j_token, start, idx))
j_token = t
start = idx
if j_token:
joined.append((j_token, start, len(tokens)))
return joined
def noun_entity_mentions(tokens, entity_mentions):
entity_ranges = [range(em[0][0]-1, em[0][1]) for em in entity_mentions]
joined = joined_tokens(tokens, entity_ranges)
tags = [tag for _, tag in pos_tag([t for t, _, _ in joined])]
noun_ranges = [range(start, end) for idx, (_, start, end) in enumerate(joined) if tags[idx].startswith('NN')]
return [em for idx, em in enumerate(entity_mentions) if entity_ranges[idx] in noun_ranges]
def token_entity_match(first_token_idx, entity, tokens):
token_idx = first_token_idx
remaining_entity = entity
while remaining_entity:
if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity:
# start of new word
remaining_entity = remaining_entity.lstrip()
if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]:
remaining_entity = remaining_entity[len(tokens[token_idx]):]
token_idx += 1
else:
break
else:
# continuing same word
if (token_idx < len(tokens) and tokens[token_idx].startswith('##')
and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]):
remaining_entity = remaining_entity[len(tokens[token_idx][2:]):]
token_idx += 1
else:
break
if remaining_entity or (token_idx < len(tokens) and tokens[token_idx].startswith('##')):
return None
else:
return token_idx - first_token_idx
def mask_tokens(tokens, mask_ranges):
for (start, _), m in mask_ranges:
tokens[start-1] = m
return [t for idx, t in enumerate(tokens)
if not any(idx in range(start, end) for (start, end), _ in mask_ranges)]
tokens = tokenizer.tokenize(text)
entity_mentions = []
product_mentions = []
for i in range(len(tokens)):
for entity, is_aspect in other_entities:
match_length = token_entity_match(i, entity, tokens)
if match_length is not None:
entity_mentions.append(((i + 1, i + match_length), is_aspect)) # + 1 taking into account the [CLS] token
for entity in product_entities:
match_length = token_entity_match(i, entity, tokens)
if match_length is not None:
product_mentions.append(((i + 1, i + match_length), True)) # + 1 taking into account the [CLS] token
if len(entity_mentions) != 1:
return None
# filter out non-nouns
entity_mentions = noun_entity_mentions(tokens, entity_mentions)
# filter intersecting product mentions
product_mentions = list(filter(lambda pm: not any(pm2 != pm and pm2[0][0] <= pm[0][0] and pm2[0][1] >= pm[0][1] for pm2 in product_mentions), product_mentions))
if len(entity_mentions) != 1:
return None
(e_start, e_end), is_aspect = entity_mentions[0]
# mask entity mentions
tokens = mask_tokens(tokens, [((e_start, e_end), '[MASK]')] + [((start, end), 'product') for (start, end), _ in product_mentions])
return tokens, e_start, 1 if is_aspect else 0
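Similarly, a hand-constructed example (assuming 'watch' is a product synonym and 'strap' an annotated feature) of what entity_instances_for_text would return for "The watch has a great strap":

tokens = ['the', 'product', 'has', 'a', 'great', '[MASK]']   # product mention replaced, feature masked
entity_idx = 6   # 1-based position of the masked feature, accounting for the [CLS] token
label = 1        # 1 = aspect, 0 = not an aspect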
# ann: EntityAnnotator = EntityAnnotator.load_saved('annotators/acoustic_guitar_annotator.pickle')
# ann.save_annotated_entities('BERT/data/annotated_acoustic_guitar_review_features.tsv', 37000)
ann: EntityAnnotator = EntityAnnotator.load_saved('annotators/backpack_entity_annotator.pickle')
ann.save_annotated_entities('BERT/data/annotated_backpack_review_features.tsv', 37000)
ann: EntityAnnotator = EntityAnnotator.load_saved('annotators/cardigan_entity_annotator.pickle')
ann.save_annotated_entities('BERT/data/annotated_cardigan_review_features.tsv', 37000)
ann: EntityAnnotator = EntityAnnotator.load_saved('annotators/laptop_entity_annotator.pickle')
ann.save_annotated_entities('BERT/data/annotated_laptop_review_features.tsv', 37000)
ann: EntityAnnotator = EntityAnnotator.load_saved('annotators/camera_entity_annotator.pickle')
ann.save_annotated_entities('BERT/data/annotated_camera_review_features.tsv', 37000)
ann: EntityAnnotator = EntityAnnotator.load_saved('annotators/watch_annotator.pickle')
ann.save_annotated_entities('BERT/data/annotated_watch_review_features.tsv', 37000)
# ann: EntityAnnotator = EntityAnnotator.load_saved('annotators/acoustic_guitar_annotator.pickle')
# print(ann.synset[ann.root])