Commit 4f691588 authored by Joel Oksanen

Implemented prod feat bert extractor

parent 9288c473
......@@ -24,19 +24,20 @@ def get_df(path):
pd.set_option('display.max_colwidth', None)
category = 'Backpacks'
category = 'Acoustic Guitars'
metadata_iter = pd.read_json('amazon_data/meta_Clothing_Shoes_and_Jewelry.json', lines=True, chunksize=1000)
metadata = pd.concat([metadata[metadata['category'].apply(lambda cl: category in cl)] for metadata in metadata_iter])
metadata_iter = pd.read_json('amazon_data/meta_Musical_Instruments.json', lines=True, chunksize=1000)
metadata = pd.concat([metadata[metadata['category'].apply(lambda cl: type(cl) is list and category in cl)]
for metadata in metadata_iter])
print(len(metadata.index))
review_iter = pd.read_json('amazon_data/Clothing_Shoes_and_Jewelry.json', lines=True, chunksize=1000)
review_iter = pd.read_json('amazon_data/Musical_Instruments.json', lines=True, chunksize=1000)
reviews = pd.concat([reviews[reviews['asin'].isin(metadata['asin'])] for reviews in review_iter])
print(len(reviews.index))
reviews.to_csv('target_extraction/data/verified_backpack_reviews.tsv', sep='\t', index=False)
reviews.to_csv('target_extraction/data/verified_acoustic_guitar_reviews.tsv', sep='\t', index=False)
# child_product = 'speaker'
# reviews = pd.read_csv('amazon_data/amazon_reviews_us_Electronics_v1_00.tsv.gz', sep='\t', error_bad_lines=False,
......
import torch
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy
from torch.nn.functional import softmax
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import time
......@@ -65,10 +65,10 @@ class BertRelExtractor:
optimiser = Adam(self.net.parameters(), lr=LEARNING_RATE)
# set up scheduler for lr
n_training_steps = len(train_loader)*N_EPOCHS
n_training_steps = len(train_loader) * N_EPOCHS
scheduler = get_linear_schedule_with_warmup(
optimiser,
num_warmup_steps=int(WARM_UP_FRAC*n_training_steps),
num_warmup_steps=int(WARM_UP_FRAC * n_training_steps),
num_training_steps=n_training_steps
)
......@@ -80,13 +80,13 @@ class BertRelExtractor:
for batch_idx, batch in enumerate(train_loader):
# send batch to gpu
input_ids, attn_mask, fst_indices, snd_indices, target_labels = tuple(i.to(device) for i in batch)
input_ids, attn_mask, target_labels = tuple(i.to(device) for i in batch)
# zero param gradients
optimiser.zero_grad()
# forward pass
output_scores = self.net(input_ids, attn_mask, fst_indices, snd_indices)
output_scores = self.net(input_ids, attn_mask)
# backward pass
loss = loss_criterion(output_scores, target_labels)
......@@ -140,10 +140,10 @@ class BertRelExtractor:
with torch.no_grad():
for batch in test_loader:
# send batch to gpu
input_ids, attn_mask, fst_indices, snd_indices, target_labels = tuple(i.to(device) for i in batch)
input_ids, attn_mask, target_labels = tuple(i.to(device) for i in batch)
# forward pass
output_scores = self.net(input_ids, attn_mask, fst_indices, snd_indices)
output_scores = self.net(input_ids, attn_mask)
_, output_labels = torch.max(output_scores.data, 1)
outputs += output_labels.tolist()
......@@ -162,20 +162,25 @@ class BertRelExtractor:
f1 = metrics.f1_score(targets, outputs, labels=range(NUM_CLASSES), average='macro')
print('macro F1:', f1)
precision = metrics.precision_score(targets, outputs, pos_label=1)
print('precision:', precision)
recall = metrics.recall_score(targets, outputs, pos_label=1)
print('recall:', recall)
def extract_single_relation(self, text, e1, e2):
ins = PairRelDataset.get_instance(text, e1, e2)
input_ids, attn_mask, fst_indices, snd_indices, instances = generate_production_batch([ins])
input_ids, attn_mask, instances = generate_production_batch([ins])
self.net.cuda()
self.net.eval()
with torch.no_grad():
# send batch to gpu
input_ids, attn_mask, fst_indices, snd_indices = tuple(i.to(device) for i in
[input_ids, attn_mask, fst_indices, snd_indices])
input_ids, attn_mask = tuple(i.to(device) for i in [input_ids, attn_mask])
# forward pass
output_scores = self.net(input_ids, attn_mask, fst_indices, snd_indices)
output_scores = softmax(self.net(input_ids, attn_mask), dim=1)
_, output_labels = torch.max(output_scores.data, 1)
print(instances[0].get_relation_for_label(output_labels[0]))
......@@ -198,24 +203,24 @@ class BertRelExtractor:
outputs = []
with torch.no_grad():
for input_ids, attn_mask, fst_indices, snd_indices, instances in loader:
for input_ids, attn_mask, instances in loader:
# send batch to gpu
input_ids, attn_mask, fst_indices, snd_indices = tuple(i.to(device) for i in
[input_ids, attn_mask, fst_indices, snd_indices])
input_ids, attn_mask = tuple(i.to(device) for i in [input_ids, attn_mask])
# forward pass
output_scores = self.net(input_ids, attn_mask, fst_indices, snd_indices)
output_scores = softmax(self.net(input_ids, attn_mask), dim=1)
_, output_labels = torch.max(output_scores.data, 1)
outputs += map(lambda x: x[0].get_relation_for_label(x[1]), zip(instances, output_labels.tolist()))
for ins, out in zip(instances, output_labels.tolist()):
for ins, scores, out in zip(instances, output_scores.tolist(), output_labels.tolist()):
print(ins.text)
print(scores)
print(ins.get_relation_for_label(out))
print('---')
return outputs
extr = BertRelExtractor.load_saved('trained_bert_rel_extractor_camera_and_backpack_with_nan.pt')
extr.evaluate(file_path='data/annotated_laptop_review_pairs_with_nan.tsv', size=10000)
extr: BertRelExtractor = BertRelExtractor.load_saved('trained_bert_rel_extractor_camera_and_backpack_with_nan.pt')
extr.evaluate('data/annotated_camera_review_pairs.tsv', size=10000)
......@@ -7,9 +7,12 @@ from ast import literal_eval
from agent.target_extraction.BERT.pairbertnet import TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES
MAX_SEQ_LEN = 128
MASK_TOKEN = '[MASK]'
RELATIONS = ['/has_feature', '/no_relation', '/nan_relation']
RELATIONS = ['/has_feature', '/no_relation']
RELATION_LABEL_MAP = {None: None, '/has_feature': 1, '/no_relation': 0}
PROD_TOKEN = '[MASK]'
FEAT_TOKEN = '[MASK]'
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)
tokenizer.add_tokens([PROD_TOKEN, FEAT_TOKEN])
def generate_batch(batch):
......@@ -18,13 +21,9 @@ def generate_batch(batch):
return_tensors='pt')
input_ids = encoded['input_ids']
attn_mask = encoded['attention_mask']
both_ranges = [instance.entity_ranges for instance in batch]
fst_indices, snd_indices = map(indices_for_entity_ranges, zip(*both_ranges))
labels = torch.tensor([instance.label for instance in batch])
return input_ids, attn_mask, fst_indices, snd_indices, labels
return input_ids, attn_mask, labels
def generate_production_batch(batch):
......@@ -34,18 +33,7 @@ def generate_production_batch(batch):
input_ids = encoded['input_ids']
attn_mask = encoded['attention_mask']
both_ranges = [instance.entity_ranges for instance in batch]
fst_indices, snd_indices = map(indices_for_entity_ranges, zip(*both_ranges))
return input_ids, attn_mask, fst_indices, snd_indices, batch
def indices_for_entity_ranges(ranges):
    """Build an index tensor covering the token span of each entity.

    ``ranges`` is a sequence of ``(start, end)`` token positions, one pair per
    instance. Every row is padded to the longest span by repeating that
    instance's ``end`` position, and each position is repeated
    ``HIDDEN_OUTPUT_FEATURES`` times, giving a tensor of shape
    ``(len(ranges), longest_span + 1, HIDDEN_OUTPUT_FEATURES)``.
    """
    longest = max(last - first for first, last in ranges)
    rows = []
    for first, last in ranges:
        row = []
        for pos in range(first, first + longest + 1):
            # clamp to the entity's own end so shorter spans are padded with it
            row.append([min(pos, last)] * HIDDEN_OUTPUT_FEATURES)
        rows.append(row)
    return torch.tensor(rows)
return input_ids, attn_mask, batch
class PairRelDataset(Dataset):
......@@ -88,15 +76,15 @@ class PairRelDataset(Dataset):
unpacked_arr = literal_eval(row['relationMentions']) if type(row['relationMentions']) is str else row['relationMentions']
rms = [rm for rm in unpacked_arr if 'label' not in rm or rm['label'] in RELATIONS]
if len(rms) == 1:
e1, e2, label = rms[0]['em1Text'], rms[0]['em2Text'], (rms[0]['label'] if 'label' in rms[0] else None)
prod, feat, rel = rms[0]['em1Text'], rms[0]['em2Text'], (rms[0]['label'] if 'label' in rms[0] else None)
else:
return None # raise AttributeError('Instances must have exactly one relation')
text = row['sentText']
return PairRelDataset.get_instance(text, e1, e2, label=label)
return PairRelDataset.get_instance(text, prod, feat, relation=rel)
@staticmethod
def get_instance(text, e1, e2, label=None):
def get_instance(text, prod, feat, relation=None):
tokens = tokenizer.tokenize(text)
i = 0
......@@ -104,12 +92,12 @@ class PairRelDataset(Dataset):
ranges = []
while i < len(tokens):
match = False
for entity in [e1, e2]:
for entity in [prod, feat]:
match_length = PairRelDataset.token_entity_match(i, entity.lower(), tokens)
if match_length is not None:
if entity in found_entities:
return None # raise AttributeError('Entity {} appears twice in text {}'.format(entity, text))
tokens[i:i + match_length] = [MASK_TOKEN] * match_length
tokens[i:i + match_length] = [PROD_TOKEN if entity == prod else FEAT_TOKEN] * match_length
match = True
found_entities.append(entity)
ranges.append((i, i + match_length - 1))
......@@ -118,20 +106,11 @@ class PairRelDataset(Dataset):
if not match:
i += 1
if found_entities == [e1, e2] and label is None:
return PairRelInstance(tokens, e1, e2, tuple(ranges), None, text)
elif found_entities == [e2, e1] and label is None:
return PairRelInstance(tokens, e2, e1, tuple(ranges), None, text)
if len(found_entities) == 2 and label in ['/no_relation', '/nan_relation']:
return PairRelInstance(tokens, e1, e2, tuple(ranges), 0, text)
elif found_entities == [e1, e2] and label == '/has_feature':
return PairRelInstance(tokens, e1, e2, tuple(ranges), 1, text)
elif found_entities == [e2, e1] and label == '/has_feature':
return PairRelInstance(tokens, e2, e1, tuple(ranges), 2, text)
else:
if len(found_entities) != 2:
return None # raise AttributeError('Could not find entities {} and {} in {}. Found entities {}'.format(e1, e2, text, found_entities))
return PairRelInstance(tokens, prod, feat, tuple(ranges), RELATION_LABEL_MAP[relation], text)
@staticmethod
def token_entity_match(first_token_idx, entity, tokens):
token_idx = first_token_idx
......@@ -167,28 +146,16 @@ class PairRelDataset(Dataset):
# Holds one tokenized sentence together with its two entities, their token ranges,
# the numeric label and the raw text (see the attribute assignments below).
# NOTE(review): this span is a rendered diff hunk with no +/- markers and with
# indentation stripped, so two versions of the class are interleaved. The
# fst_e/snd_e lines, the `label == 2` branch and the commented-out range_to_entity
# appear to be the pre-commit version; the prod/feat lines appear to be the
# post-commit version -- confirm against the repository before reusing this text.
class PairRelInstance:
def __init__(self, tokens, fst_e, snd_e, entity_ranges, label, text):
def __init__(self, tokens, prod, feat, entity_ranges, label, text):
self.tokens = tokens
self.fst_e = fst_e
self.snd_e = snd_e
self.prod = prod
self.feat = feat
self.entity_ranges = entity_ranges
self.label = label
self.text = text
# Maps a numeric class label back to an (entity, relation, entity) triple.
# In the prod/feat lines, label 0 yields '/no_relation' and label 1 '/has_feature'.
def get_relation_for_label(self, label):
if label == 0:
return self.fst_e, '/no_relation', self.snd_e
return self.prod, '/no_relation', self.feat
if label == 1:
return self.fst_e, '/has_feature', self.snd_e
if label == 2:
return self.snd_e, '/has_feature', self.fst_e
# def range_to_entity(self, e_range):
# start, end = e_range
# text = self.tokens[start]
# for t in self.tokens[start + 1:end]:
# if t.startswith('##'):
# text += t[2:]
# else:
# text += ' ' + t
# return text
return self.prod, '/has_feature', self.feat
......@@ -4,7 +4,7 @@ from transformers import *
HIDDEN_OUTPUT_FEATURES = 768
TRAINED_WEIGHTS = 'bert-base-uncased'
NUM_CLASSES = 3 # no relation, e1 featureOf e2, e2 featureOf e1
NUM_CLASSES = 2 # no relation, prod hasFeature feat
class PairBertNet(nn.Module):
......@@ -13,19 +13,11 @@ class PairBertNet(nn.Module):
super(PairBertNet, self).__init__()
config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES) # 2 * n of hidden features, n of output labels
self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)
def forward(self, input_ids, attn_mask, fst_e_indices, snd_e_indices):
def forward(self, input_ids, attn_mask):
# BERT
_, pooler_output = self.bert_base(input_ids=input_ids, attention_mask=attn_mask)
# max pooling at entity locations
# fst_e_outputs = torch.gather(bert_output, dim=1, index=fst_e_indices)
# snd_e_outputs = torch.gather(bert_output, dim=1, index=snd_e_indices)
# fst_pooled_output, _ = torch.max( , dim=1)
# snd_pooled_output, _ = torch.max(snd_e_outputs, dim=1)
# combined = torch.cat((fst_pooled_output, snd_pooled_output), dim=1)
# fc layer (softmax activation done in loss function)
x = self.fc(pooler_output)
return x
......@@ -317,3 +317,7 @@ class EntityAnnotator:
rels.append({'em1Text': e1.name, 'em2Text': e2.name, 'label': '/has_feature'})
child_entities.append(e2)
return text, rels
ann: EntityAnnotator = EntityAnnotator.load_saved('acoustic_guitar_annotator.pickle')
ann.save_annotated_pairs('BERT/data/annotated_acoustic_guitar_review_pairs.tsv')
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment