Commit 9343c134 authored by Joel Oksanen's avatar Joel Oksanen
Browse files

Added sub-batching for RC (relation classification) with NE pairs

parent 6ac7c519
......@@ -26,13 +26,18 @@ class RelDataset(Dataset):
def from_file(path):
    """Load a RelDataset from a tab-separated values file.

    Each row is converted via ``RelDataset.instance_from_row``; rows for
    which that helper returns ``None`` (e.g. texts exceeding the token
    limit) are dropped from the dataset.

    :param path: filesystem path to the TSV file
    :return: a populated ``RelDataset``
    """
    dataset = RelDataset()
    data = pd.read_csv(path, sep='\t', error_bad_lines=False)
    # Single pass: build instances lazily and keep only the valid ones.
    dataset.data = [inst
                    for inst in (RelDataset.instance_from_row(row)
                                 for _, row in data.iterrows())
                    if inst is not None]
    return dataset
@staticmethod
def instance_from_row(row):
text = row['sentText']
tokens = tokenizer.tokenize(text)
if len(tokens) > NE_TAGS_LEN:
return None # include only texts that can be represented in 126 tokens, filters out 98 texts from NYT train
entities = sorted([em['text'] for em in literal_eval(row['entityMentions'])], key=len, reverse=True)
data_relations = [(m['em1Text'], m['em2Text']) for m in literal_eval(row['relationMentions'])]
relations = {(entity1, entity2): 1 if (entity1, entity2) in data_relations
......
......@@ -88,18 +88,26 @@ class RelBertNet(nn.Module):
rc_attn_mask[i][slice1] = 1
rc_attn_mask[i][slice2] = 1
bert_rc_output = torch.repeat_interleave(bert_context_output, n_combinations_by_instance, dim=0)
extended_rc_attn_mask = rc_attn_mask[:, None, None, :]
for layer in self.bert2_layers:
bert_rc_output, = layer(bert_rc_output, attention_mask=extended_rc_attn_mask)
bert_context_expanded = torch.repeat_interleave(bert_context_output, n_combinations_by_instance, dim=0)
sub_batch_outputs = []
for sub_batch_idx in range(0, len(entity_combinations), BATCH_SIZE):
sub_batch_end = sub_batch_idx + BATCH_SIZE
bert_rc_output = bert_context_expanded[sub_batch_idx:sub_batch_end]
extended_rc_attn_mask = rc_attn_mask[sub_batch_idx:sub_batch_end][:, None, None, :]
for layer in self.bert2_layers:
bert_rc_output, = layer(bert_rc_output, attention_mask=extended_rc_attn_mask)
# MLP for RC
rc_cls_output = bert_rc_output.narrow(1, 0, 1).squeeze(1) # just CLS token
rc_hidden_layer_output = torch.tanh(self.mlp1(rc_cls_output)) # tanh activation
rc_output = self.mlp2(rc_hidden_layer_output) # softmax activation
sub_batch_outputs.append(rc_output)
# MLP for RC
rc_cls_output = bert_rc_output.narrow(1, 0, 1).squeeze(1) # just CLS token
rc_hidden_layer_output = torch.tanh(self.mlp1(rc_cls_output)) # tanh activation
rc_output = self.mlp2(rc_hidden_layer_output) # softmax activation
total_rc_output = torch.cat(sub_batch_outputs, dim=0)
# Return NER and RC outputs
return ner_output, ner_loss, rc_output, torch.tensor(target_relation_labels, dtype=torch.long)
return ner_output, ner_loss, total_rc_output, torch.tensor(target_relation_labels, dtype=torch.long)
@staticmethod
def bieos_to_entities(tags):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment