Commit 816e6691 authored by Joel Oksanen's avatar Joel Oksanen

Reverted to bidirectional with fixed bugs

parent 1a7440c3
......@@ -11,7 +11,7 @@ from transformers import get_linear_schedule_with_warmup
from agent.target_extraction.BERT.pair_rel_dataset import PairRelDataset, generate_batch, generate_production_batch
from agent.target_extraction.BERT.pairbertnet import NUM_CLASSES, PairBertNet
trained_model_path = 'trained_bert_rel_extractor_camera_backpack_laptop_pair.pt'
trained_model_path = 'trained_bert_rel_extractor_camera_backpack_laptop_bi_directional.pt'
device = torch.device('cuda')
loss_criterion = CrossEntropyLoss()
......@@ -46,14 +46,18 @@ class BertRelExtractor:
return extractor
@staticmethod
def train_and_validate(file_path, valid_frac, size=None):
def train_and_validate(file_path, size=None, valid_frac=None, valid_file_path=None):
extractor = BertRelExtractor()
extractor.train_with_file(file_path, size=size, valid_frac=valid_frac)
extractor.train_with_file(file_path, size=size, valid_frac=valid_frac, valid_file_path=valid_file_path)
return extractor
def train_with_file(self, file_path, size=None, valid_frac=None):
def train_with_file(self, file_path, size=None, valid_frac=None, valid_file_path=None):
# load training data
train_data, valid_data = PairRelDataset.from_file(file_path, size=size, valid_frac=valid_frac)
if valid_file_path is None:
train_data, valid_data = PairRelDataset.from_file(file_path, size=size, valid_frac=valid_frac)
else:
train_data, _ = PairRelDataset.from_file(file_path, size=int(size * (1 - valid_frac)))
valid_data, _ = PairRelDataset.from_file(valid_file_path, size=int(size * valid_frac))
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
collate_fn=generate_batch)
......@@ -79,13 +83,13 @@ class BertRelExtractor:
for batch_idx, batch in enumerate(train_loader):
# send batch to gpu
input_ids, attn_mask, prod_indices, feat_indices, target_labels = tuple(i.to(device) for i in batch)
input_ids, attn_mask, fst_indices, snd_indices, target_labels = tuple(i.to(device) for i in batch)
# zero param gradients
optimiser.zero_grad()
# forward pass
output_scores = self.net(input_ids, attn_mask, prod_indices, feat_indices)
output_scores = self.net(input_ids, attn_mask, fst_indices, snd_indices)
# backward pass
loss = loss_criterion(output_scores, target_labels)
......@@ -139,10 +143,10 @@ class BertRelExtractor:
with torch.no_grad():
for batch in test_loader:
# send batch to gpu
input_ids, attn_mask, prod_indices, feat_indices, target_labels = tuple(i.to(device) for i in batch)
input_ids, attn_mask, fst_indices, snd_indices, target_labels = tuple(i.to(device) for i in batch)
# forward pass
output_scores = self.net(input_ids, attn_mask, prod_indices, feat_indices)
output_scores = self.net(input_ids, attn_mask, fst_indices, snd_indices)
_, output_labels = torch.max(output_scores.data, 1)
outputs += output_labels.tolist()
......@@ -161,10 +165,10 @@ class BertRelExtractor:
f1 = metrics.f1_score(targets, outputs, labels=range(NUM_CLASSES), average='macro')
print('macro F1:', f1)
precision = metrics.precision_score(targets, outputs, pos_label=1)
precision = metrics.precision_score(targets, outputs, average=None)
print('precision:', precision)
recall = metrics.recall_score(targets, outputs, pos_label=1)
recall = metrics.recall_score(targets, outputs, average=None)
print('recall:', recall)
def extract_single_relation(self, text, e1, e2):
......@@ -175,6 +179,9 @@ class BertRelExtractor:
self.net.eval()
with torch.no_grad():
embeddings = self.net.bert_base.get_input_embeddings()
print(torch.narrow(embeddings(input_ids.to(device)), 1, 1, 4))
# send batch to gpu
input_ids, attn_mask, prod_indices, feat_indices = tuple(i.to(device) for i in [input_ids, attn_mask,
prod_indices, feat_indices])
......@@ -225,5 +232,8 @@ class BertRelExtractor:
return outputs
extr: BertRelExtractor = BertRelExtractor.load_saved('trained_bert_rel_extractor_camera_backpack_laptop_pair.pt')
extr.extract_relations(file_path='data/annotated_acoustic_guitar_review_pairs.tsv', size=50)
extr: BertRelExtractor = BertRelExtractor.train_and_validate('data/camera_backpack_laptop_review_pairs.tsv',
size=5000,
valid_frac=0.05,
valid_file_path='data/annotated_acoustic_guitar_review_pairs.tsv')
......@@ -21,10 +21,10 @@ def generate_batch(batch):
attn_mask = encoded['attention_mask']
labels = torch.tensor([instance.label for instance in batch])
both_ranges = [(instance.prod_range, instance.feat_range) for instance in batch]
prod_indices, feat_indices = map(indices_for_entity_ranges, zip(*both_ranges))
both_ranges = [instance.ranges for instance in batch]
fst_indices, snd_indices = map(indices_for_entity_ranges, zip(*both_ranges))
return input_ids, attn_mask, prod_indices, feat_indices, labels
return input_ids, attn_mask, fst_indices, snd_indices, labels
def generate_production_batch(batch):
......@@ -34,10 +34,10 @@ def generate_production_batch(batch):
input_ids = encoded['input_ids']
attn_mask = encoded['attention_mask']
both_ranges = [(instance.prod_range, instance.feat_range) for instance in batch]
prod_indices, feat_indices = map(indices_for_entity_ranges, zip(*both_ranges))
both_ranges = [instance.ranges for instance in batch]
fst_indices, snd_indices = map(indices_for_entity_ranges, zip(*both_ranges))
return input_ids, attn_mask, prod_indices, feat_indices, batch
return input_ids, attn_mask, fst_indices, snd_indices, batch
def indices_for_entity_ranges(ranges):
......@@ -101,7 +101,7 @@ class PairRelDataset(Dataset):
i = 0
found_entities = []
ranges = {}
ranges = []
while i < len(tokens):
match = False
for entity in [prod, feat]:
......@@ -112,7 +112,7 @@ class PairRelDataset(Dataset):
tokens[i:i + match_length] = [MASK_TOKEN] * match_length
match = True
found_entities.append(entity)
ranges[entity] = (i + 1, i + match_length) # + 1 taking into account the [CLS] token
ranges.append((i + 1, i + match_length)) # + 1 taking into account the [CLS] token
i += match_length
break
if not match:
......@@ -121,7 +121,18 @@ class PairRelDataset(Dataset):
if len(found_entities) != 2:
return None # raise AttributeError('Could not find entities {} and {} in {}. Found entities {}'.format(e1, e2, text, found_entities))
return PairRelInstance(tokens, prod, feat, ranges[prod], ranges[feat], RELATION_LABEL_MAP[relation], text)
if relation is None:
return PairRelInstance(tokens, found_entities[0], found_entities[1], ranges, None, text)
if relation == '/has_feature':
if found_entities == [prod, feat]:
return PairRelInstance(tokens, found_entities[0], found_entities[1], ranges, 1, text)
else:
assert found_entities == [feat, prod]
return PairRelInstance(tokens, found_entities[0], found_entities[1], ranges, 2, text)
assert relation == '/no_relation'
return PairRelInstance(tokens, found_entities[0], found_entities[1], ranges, 0, text)
@staticmethod
def token_entity_match(first_token_idx, entity, tokens):
......@@ -158,17 +169,18 @@ class PairRelDataset(Dataset):
class PairRelInstance:
def __init__(self, tokens, prod, feat, prod_range, feat_range, label, text):
def __init__(self, tokens, fst, snd, ranges, label, text):
self.tokens = tokens
self.prod = prod
self.feat = feat
self.prod_range = prod_range
self.feat_range = feat_range
self.fst = fst
self.snd = snd
self.ranges = ranges
self.label = label
self.text = text
def get_relation_for_label(self, label):
if label == 0:
return self.prod, '/no_relation', self.feat
return self.fst, '/no_relation', self.snd
if label == 1:
return self.prod, '/has_feature', self.feat
return self.fst, '/has_feature', self.snd
if label == 2:
return self.snd, '/has_feature', self.fst
......@@ -4,7 +4,7 @@ from transformers import *
HIDDEN_OUTPUT_FEATURES = 768
TRAINED_WEIGHTS = 'bert-base-uncased'
NUM_CLASSES = 2 # no relation, prod hasFeature feat
NUM_CLASSES = 3 # no relation, fst hasFeature snd, snd hasFeature fst
class PairBertNet(nn.Module):
......@@ -20,13 +20,13 @@ class PairBertNet(nn.Module):
bert_output, _ = self.bert_base(input_ids=input_ids, attention_mask=attn_mask)
# max pooling at entity locations
prod_outputs = torch.gather(bert_output, dim=1, index=prod_indices)
feat_outputs = torch.gather(bert_output, dim=1, index=feat_indices)
prod_pooled_output, _ = torch.max(prod_outputs, dim=1)
feat_pooled_output, _ = torch.max(feat_outputs, dim=1)
fst_outputs = torch.gather(bert_output, dim=1, index=prod_indices)
snd_outputs = torch.gather(bert_output, dim=1, index=feat_indices)
fst_pooled_output, _ = torch.max(fst_outputs, dim=1)
snd_pooled_output, _ = torch.max(snd_outputs, dim=1)
# concat pooled outputs from prod and feat entities
combined = torch.cat((prod_pooled_output, feat_pooled_output), dim=1)
combined = torch.cat((fst_pooled_output, snd_pooled_output), dim=1)
# fc layer (softmax activation done in loss function)
x = self.fc(combined)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment