Commit 1a7440c3 authored by Joel Oksanen's avatar Joel Oksanen

Implemented one-way pair extractor

parent 4f691588
......@@ -11,7 +11,7 @@ from transformers import get_linear_schedule_with_warmup
from agent.target_extraction.BERT.pair_rel_dataset import PairRelDataset, generate_batch, generate_production_batch
from agent.target_extraction.BERT.pairbertnet import NUM_CLASSES, PairBertNet
trained_model_path = 'trained_bert_rel_extractor_camera_and_backpack_with_nan.pt'
trained_model_path = 'trained_bert_rel_extractor_camera_backpack_laptop_pair.pt'
device = torch.device('cuda')
loss_criterion = CrossEntropyLoss()
......@@ -22,14 +22,14 @@ MAX_GRAD_NORM = 1.0
# training
N_EPOCHS = 3
BATCH_SIZE = 32
BATCH_SIZE = 16
WARM_UP_FRAC = 0.05
class BertRelExtractor:
def __init__(self):
self.net = None
self.net = PairBertNet()
@staticmethod
def load_saved(path):
......@@ -58,7 +58,6 @@ class BertRelExtractor:
collate_fn=generate_batch)
# initialise BERT
self.net = PairBertNet()
self.net.cuda()
# set up optimizer with weight decay
......@@ -80,13 +79,13 @@ class BertRelExtractor:
for batch_idx, batch in enumerate(train_loader):
# send batch to gpu
input_ids, attn_mask, target_labels = tuple(i.to(device) for i in batch)
input_ids, attn_mask, prod_indices, feat_indices, target_labels = tuple(i.to(device) for i in batch)
# zero param gradients
optimiser.zero_grad()
# forward pass
output_scores = self.net(input_ids, attn_mask)
output_scores = self.net(input_ids, attn_mask, prod_indices, feat_indices)
# backward pass
loss = loss_criterion(output_scores, target_labels)
......@@ -140,10 +139,10 @@ class BertRelExtractor:
with torch.no_grad():
for batch in test_loader:
# send batch to gpu
input_ids, attn_mask, target_labels = tuple(i.to(device) for i in batch)
input_ids, attn_mask, prod_indices, feat_indices, target_labels = tuple(i.to(device) for i in batch)
# forward pass
output_scores = self.net(input_ids, attn_mask)
output_scores = self.net(input_ids, attn_mask, prod_indices, feat_indices)
_, output_labels = torch.max(output_scores.data, 1)
outputs += output_labels.tolist()
......@@ -170,17 +169,18 @@ class BertRelExtractor:
def extract_single_relation(self, text, e1, e2):
ins = PairRelDataset.get_instance(text, e1, e2)
input_ids, attn_mask, instances = generate_production_batch([ins])
input_ids, attn_mask, prod_indices, feat_indices, instances = generate_production_batch([ins])
self.net.cuda()
self.net.eval()
with torch.no_grad():
# send batch to gpu
input_ids, attn_mask = tuple(i.to(device) for i in [input_ids, attn_mask])
input_ids, attn_mask, prod_indices, feat_indices = tuple(i.to(device) for i in [input_ids, attn_mask,
prod_indices, feat_indices])
# forward pass
output_scores = softmax(self.net(input_ids, attn_mask), dim=1)
output_scores = softmax(self.net(input_ids, attn_mask, prod_indices, feat_indices), dim=1)
_, output_labels = torch.max(output_scores.data, 1)
print(instances[0].get_relation_for_label(output_labels[0]))
......@@ -203,12 +203,15 @@ class BertRelExtractor:
outputs = []
with torch.no_grad():
for input_ids, attn_mask, instances in loader:
for input_ids, attn_mask, prod_indices, feat_indices, instances in loader:
# send batch to gpu
input_ids, attn_mask = tuple(i.to(device) for i in [input_ids, attn_mask])
input_ids, attn_mask, prod_indices, feat_indices = tuple(i.to(device) for i in [input_ids,
attn_mask,
prod_indices,
feat_indices])
# forward pass
output_scores = softmax(self.net(input_ids, attn_mask), dim=1)
output_scores = softmax(self.net(input_ids, attn_mask, prod_indices, feat_indices), dim=1)
_, output_labels = torch.max(output_scores.data, 1)
outputs += map(lambda x: x[0].get_relation_for_label(x[1]), zip(instances, output_labels.tolist()))
......@@ -222,5 +225,5 @@ class BertRelExtractor:
return outputs
extr: BertRelExtractor = BertRelExtractor.load_saved('trained_bert_rel_extractor_camera_and_backpack_with_nan.pt')
extr.evaluate('data/annotated_camera_review_pairs.tsv', size=10000)
extr: BertRelExtractor = BertRelExtractor.load_saved('trained_bert_rel_extractor_camera_backpack_laptop_pair.pt')
extr.extract_relations(file_path='data/annotated_acoustic_guitar_review_pairs.tsv', size=50)
......@@ -9,10 +9,8 @@ from agent.target_extraction.BERT.pairbertnet import TRAINED_WEIGHTS, HIDDEN_OUT
MAX_SEQ_LEN = 128
RELATIONS = ['/has_feature', '/no_relation']
RELATION_LABEL_MAP = {None: None, '/has_feature': 1, '/no_relation': 0}
PROD_TOKEN = '[MASK]'
FEAT_TOKEN = '[MASK]'
MASK_TOKEN = '[MASK]'
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)
tokenizer.add_tokens([PROD_TOKEN, FEAT_TOKEN])
def generate_batch(batch):
......@@ -23,7 +21,10 @@ def generate_batch(batch):
attn_mask = encoded['attention_mask']
labels = torch.tensor([instance.label for instance in batch])
return input_ids, attn_mask, labels
both_ranges = [(instance.prod_range, instance.feat_range) for instance in batch]
prod_indices, feat_indices = map(indices_for_entity_ranges, zip(*both_ranges))
return input_ids, attn_mask, prod_indices, feat_indices, labels
def generate_production_batch(batch):
......@@ -33,7 +34,18 @@ def generate_production_batch(batch):
input_ids = encoded['input_ids']
attn_mask = encoded['attention_mask']
return input_ids, attn_mask, batch
both_ranges = [(instance.prod_range, instance.feat_range) for instance in batch]
prod_indices, feat_indices = map(indices_for_entity_ranges, zip(*both_ranges))
return input_ids, attn_mask, prod_indices, feat_indices, batch
def indices_for_entity_ranges(ranges):
    """Build a gather-index tensor for a batch of (start, end) entity token ranges.

    Every span is padded out to the length of the longest span in the batch by
    repeating its final token index (``min(pos, stop)``), and each index is
    replicated ``HIDDEN_OUTPUT_FEATURES`` times along the last axis so the
    result can be passed directly as the ``index`` argument of ``torch.gather``
    over BERT's per-token hidden states.

    :param ranges: iterable of (start, end) pairs of token positions
                   (end is inclusive — see how the widths are measured below)
    :return: LongTensor of shape (batch, longest_span + 1, HIDDEN_OUTPUT_FEATURES)
    """
    # width of the widest span, measured as end - start
    widest = max(stop - begin for begin, stop in ranges)
    rows = [
        [[min(pos, stop)] * HIDDEN_OUTPUT_FEATURES
         for pos in range(begin, begin + widest + 1)]
        for begin, stop in ranges
    ]
    return torch.tensor(rows)
class PairRelDataset(Dataset):
......@@ -89,7 +101,7 @@ class PairRelDataset(Dataset):
i = 0
found_entities = []
ranges = []
ranges = {}
while i < len(tokens):
match = False
for entity in [prod, feat]:
......@@ -97,10 +109,10 @@ class PairRelDataset(Dataset):
if match_length is not None:
if entity in found_entities:
return None # raise AttributeError('Entity {} appears twice in text {}'.format(entity, text))
tokens[i:i + match_length] = [PROD_TOKEN if entity == prod else FEAT_TOKEN] * match_length
tokens[i:i + match_length] = [MASK_TOKEN] * match_length
match = True
found_entities.append(entity)
ranges.append((i, i + match_length - 1))
ranges[entity] = (i + 1, i + match_length) # + 1 taking into account the [CLS] token
i += match_length
break
if not match:
......@@ -109,7 +121,7 @@ class PairRelDataset(Dataset):
if len(found_entities) != 2:
return None # raise AttributeError('Could not find entities {} and {} in {}. Found entities {}'.format(e1, e2, text, found_entities))
return PairRelInstance(tokens, prod, feat, tuple(ranges), RELATION_LABEL_MAP[relation], text)
return PairRelInstance(tokens, prod, feat, ranges[prod], ranges[feat], RELATION_LABEL_MAP[relation], text)
@staticmethod
def token_entity_match(first_token_idx, entity, tokens):
......@@ -146,11 +158,12 @@ class PairRelDataset(Dataset):
class PairRelInstance:
def __init__(self, tokens, prod, feat, entity_ranges, label, text):
def __init__(self, tokens, prod, feat, prod_range, feat_range, label, text):
self.tokens = tokens
self.prod = prod
self.feat = feat
self.entity_ranges = entity_ranges
self.prod_range = prod_range
self.feat_range = feat_range
self.label = label
self.text = text
......
......@@ -13,11 +13,21 @@ class PairBertNet(nn.Module):
super(PairBertNet, self).__init__()
config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)
self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES * 2, NUM_CLASSES)
def forward(self, input_ids, attn_mask):
def forward(self, input_ids, attn_mask, prod_indices, feat_indices):
# BERT
_, pooler_output = self.bert_base(input_ids=input_ids, attention_mask=attn_mask)
bert_output, _ = self.bert_base(input_ids=input_ids, attention_mask=attn_mask)
# max pooling at entity locations
prod_outputs = torch.gather(bert_output, dim=1, index=prod_indices)
feat_outputs = torch.gather(bert_output, dim=1, index=feat_indices)
prod_pooled_output, _ = torch.max(prod_outputs, dim=1)
feat_pooled_output, _ = torch.max(feat_outputs, dim=1)
# concat pooled outputs from prod and feat entities
combined = torch.cat((prod_pooled_output, feat_pooled_output), dim=1)
# fc layer (softmax activation done in loss function)
x = self.fc(pooler_output)
x = self.fc(combined)
return x
......@@ -235,11 +235,12 @@ class EntityAnnotator:
return {'em1Text': m[0], 'em2Text': m[1], 'label': '/no_relation'}
def pair_relations_for_text(self, text, nan_entities):
tokens = self.phraser[word_tokenize(text)]
single_tokens = word_tokenize(text)
all_tokens = set().union(*[single_tokens, self.phraser[single_tokens]])
entity_mentions = []
for n in PreOrderIter(self.root):
cont, mention = self.mention_in_text(text, tokens, node=n)
cont, mention = self.mention_in_text(all_tokens, node=n)
if not cont:
# many mentions of same entity
return None
......@@ -255,7 +256,7 @@ class EntityAnnotator:
if len(entity_mentions) == 1:
nan_mention = None
for term in nan_entities:
cont, mention = self.mention_in_text(text, tokens, term=term)
cont, mention = self.mention_in_text(all_tokens, term=term)
if not cont:
# many mentions of term
return None
......@@ -272,7 +273,7 @@ class EntityAnnotator:
# returns True, (synonym of node / term / None) if there is exactly one or zero such occurrence,
# otherwise False, None
def mention_in_text(self, text, tokens, node=None, term=None):
def mention_in_text(self, tokens, node=None, term=None):
mention = None
for syn in ({syn.lower() for syn in self.synset[node]} if node is not None else {term}):
n_matches = sum(1 for token in tokens if syn.lower() == token.lower().replace('_', ' '))
......@@ -320,4 +321,4 @@ class EntityAnnotator:
ann: EntityAnnotator = EntityAnnotator.load_saved('acoustic_guitar_annotator.pickle')
ann.save_annotated_pairs('BERT/data/annotated_acoustic_guitar_review_pairs.tsv')
\ No newline at end of file
ann.save_annotated_pairs('BERT/data/annotated_acoustic_guitar_review_pairs.tsv')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment