Commit a7869ec7 authored by Joel Oksanen

Allow passing string arrays as input to the BERT extractor.

parent 7f31ce78
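
A minimal usage sketch of what this commit enables (the input strings are hypothetical, and the assumption that BertExtractor.default() loads the trained .pt weights is mine, not part of the diff):

    extractor = BertExtractor.default()  # assumed to load trained weights
    extractor.extract_relations([
        'The battery life of this camera is great.',
        'The tripod mount fits most cameras.',
    ])

As committed, extract_relations prints the recognised entities and relations rather than returning them; the strings are wrapped by RelDataset.from_texts and batched with the new tag-free generate_batch collate function introduced below.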
@@ -6,12 +6,12 @@ from torch.utils.data import DataLoader
import numpy as np
from sklearn import metrics
import time
-from rel_dataset import RelInstance, RelDataset, generate_batch, tokenizer
+from rel_dataset import RelInstance, RelDataset, generate_train_batch, generate_batch, tokenizer
from relbertnet import RelBertNet, NUM_RELS, BATCH_SIZE
train_data_path = 'data/train.json'
test_data_path = 'data/test.json'
-trained_model_path = 'bert_extractor.pt'
+trained_model_path = 'bert_extractor2.pt'
MAX_EPOCHS = 3
LEARNING_RATE = 0.00002
@@ -40,7 +40,7 @@ class BertExtractor:
def train(self, data_file):
train_data = RelDataset.from_file(data_file, n_instances=40000)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
-                                  collate_fn=generate_batch)
+                                  collate_fn=generate_train_batch)
self.net = RelBertNet()
optimiser = optim.Adam(self.net.parameters(), lr=LEARNING_RATE)
@@ -54,7 +54,7 @@ class BertExtractor:
optimiser.zero_grad()
# forward pass
-                _, ner_loss, rc_output, target_relation_labels = self.net(batch, instances, true_ner_tags)
+                _, ner_loss, rc_output, target_relation_labels, _ = self.net(batch, instances, ner_tags=true_ner_tags)
# backward pass
l = loss(ner_loss, rc_output, target_relation_labels)
@@ -80,7 +80,7 @@ class BertExtractor:
def evaluate(self, data_file):
test_data = RelDataset.from_file(data_file)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
-                                 collate_fn=generate_batch)
+                                 collate_fn=generate_train_batch)
predicted_ner_tags = []
predicted_rels = []
@@ -88,9 +88,9 @@ class BertExtractor:
true_rels = []
with torch.no_grad():
-            for idx, (batch, b_true_ner_tags, instances) in enumerate(test_loader):
+            for batch, b_true_ner_tags, instances in test_loader:
-                ner_output, _, rc_output, target_relation_labels = self.net(batch, instances)
+                ner_output, _, rc_output, target_relation_labels, _ = self.net(batch, instances)
_, predicted_relation_labels = torch.max(rc_output.data, 1)
predicted_ner_tags += ner_output
@@ -119,29 +119,37 @@ class BertExtractor:
f1 = metrics.f1_score(true_rels, predicted_rels, labels=range(NUM_RELS), average='macro')
print('RC macro F1:', f1)
+    def extract_relations(self, texts):
+        data = RelDataset.from_texts(texts)
+        loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
+                            collate_fn=generate_batch)
-# dataset = RelDataset.from_texts(['A giraffe made friends with a lion at the Zoo.', 'Since I\'m not sure.'])
-# loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=4, collate_fn=generate_batch)
-# texts, _, _ = next(iter(loader))
-#
-# print(tokenizer.convert_ids_to_tokens(texts['input_ids'][0]))
-# net = RelBertNet()
-# ner_output, _, rc_output, n_combinations_by_instance, entity_ranges = net(texts)
-# print(ner_output)
-# print(rc_output)
-# print(n_combinations_by_instance)
-# print(entity_ranges)
-# train_data = RelDataset.from_file(train_set_path)
-# train_loader = DataLoader(train_data, batch_size=1, shuffle=True, num_workers=4,
-#                           collate_fn=generate_batch)
-# for i in [1, 1, 1, 1, 1]:
-#     batch, b_true_ner_tags, b_true_rels, instances = next(iter(train_loader))
-#     print(b_true_rels[0])
-#     print(tokenizer.convert_ids_to_tokens(batch['input_ids'][0]))
-#     print(b_true_ner_tags[0])
-#     print(instances[0].entities_for_tags(b_true_ner_tags[0]))
-#     print('---')
-extractor = BertExtractor.default()
+        with torch.no_grad():
+            for batch, instances in loader:
+                ner_output, _, rc_output, _, decoded_entity_combinations = self.net(batch,
+                                                                                    instances,
+                                                                                    in_production=True)
+                _, rc_output = torch.max(rc_output.data, 1)
+                output_entities = [instances[i].entities_for_tags(ner_output[i]) for i in range(len(instances))]
+                output_relations = [r for r in
+                                    map(lambda i:
+                                        BertExtractor.get_relation(decoded_entity_combinations[i], rc_output[i]),
+                                        range(len(decoded_entity_combinations)))
+                                    if r is not None]
+                print(instances[0].text)
+                print(output_entities)
+                print(output_relations)
+
+    @staticmethod
+    def get_relation(entity_pair, label):
+        if label == 0:
+            return None
+        if label == 1:
+            return entity_pair
+        if label == 2:
+            return entity_pair[1], entity_pair[0]
+extractor = BertExtractor()
extractor.train(train_data_path)
extractor.evaluate(test_data_path)
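
Read as a sketch of how the new get_relation decodes RC labels (the entity strings are hypothetical; label semantics follow the NUM_RELS comment in relbertnet below):

    pair = ('battery life', 'camera')
    BertExtractor.get_relation(pair, 0)  # -> None: no relation predicted
    BertExtractor.get_relation(pair, 1)  # -> ('battery life', 'camera'): e1 featureOf e2
    BertExtractor.get_relation(pair, 2)  # -> ('camera', 'battery life'): e2 featureOf e1, so the pair is flipped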
@@ -10,7 +10,7 @@ NE_TAGS_LEN = MAX_SEQ_LEN - 2
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)
-def generate_batch(instances):
+def generate_train_batch(instances):
texts = tokenizer.batch_encode_plus([instance.tokens for instance in instances], add_special_tokens=True,
max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
return_tensors='pt')
@@ -18,6 +18,13 @@ def generate_batch(instances):
return texts, true_ne_tags, instances
+def generate_batch(instances):
+    texts = tokenizer.batch_encode_plus([instance.tokens for instance in instances], add_special_tokens=True,
+                                        max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
+                                        return_tensors='pt')
+    return texts, instances
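
The two collate functions now separate the training and inference paths: only the train variant returns gold NE tags, so raw strings can be batched for inference. A sketch of the resulting batch tuples (the loader names are placeholders; construction is as in bert_extractor above):

    texts, true_ne_tags, instances = next(iter(train_loader))  # collate_fn=generate_train_batch
    texts, instances = next(iter(plain_loader))                # collate_fn=generate_batch, no tag tensor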
class RelDataset(Dataset):
MAX_RELATIONS = 8
@@ -36,6 +43,12 @@ class RelDataset(Dataset):
print("Obtained dataset of length", len(dataset.data))
return dataset
+    @staticmethod
+    def from_texts(texts):
+        dataset = RelDataset()
+        dataset.data = [RelInstance.from_text(text) for text in texts]
+        return dataset
@staticmethod
def instance_from_row(row):
entities = sorted([em['text'] for em in row['entityMentions']], key=len, reverse=True)
@@ -136,12 +149,6 @@ class RelDataset(Dataset):
assert n > 0
return [4] if n == 1 else [0] + [1] * (n-2) + [2]
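
This helper emits the BIEOS tag ids for an entity of n tokens (ids per relbertnet: 0=Begin, 1=Inside, 2=End, 3=Outside, 4=Single). Worked examples:

    n == 1  ->  [4]           # Single
    n == 2  ->  [0, 2]        # Begin, End
    n == 4  ->  [0, 1, 1, 2]  # Begin, Inside, Inside, End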
-    @staticmethod
-    def from_texts(texts):
-        dataset = RelDataset()
-        dataset.data = [RelInstance.from_text(text) for text in texts]
-        return dataset
def __len__(self):
return len(self.data)
@@ -156,12 +163,12 @@ class RelInstance:
self.ne_tags = None
self.relations = None
self.text = None
-        self.token_to_text = None
@staticmethod
def from_text(text):
i = RelInstance()
i.tokens = tokenizer.tokenize(text)
+        i.text = text
return i
@staticmethod
@@ -172,12 +179,8 @@ class RelInstance:
i.relations = relations
return i
-    def get(self):
-        return self.tokens
-
     def to_tensor(self):
-        tokens = self.get()
-        encoded = tokenizer.encode_plus(tokens, add_special_tokens=True, max_length=MAX_SEQ_LEN,
+        encoded = tokenizer.encode_plus(self.tokens, add_special_tokens=True, max_length=MAX_SEQ_LEN,
is_pretokenized=True, return_tensors='pt')
return encoded
@@ -195,7 +198,7 @@ class RelInstance:
entity = self.tokens[start]
for idx, t in enumerate(self.tokens[start+1:end]):
if t.startswith('##'):
-                entity += t[:2]
+                entity += t[2:]
else:
entity += ' ' + t
return entity
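
The t[:2] -> t[2:] change above fixes WordPiece reassembly: a continuation token should contribute everything after its '##' marker, not the marker itself. Traced on hypothetical tokens:

    # tokens in range: ['bat', '##tery', 'life']
    # 'bat' + 'tery'      -> 'battery'       (t[2:] drops the '##' prefix)
    # 'battery' + ' life' -> 'battery life'  (non-continuation tokens get a space)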
import torch
import torch.nn as nn
-import torch.nn.functional as F
from transformers import *
from torchcrf import CRF
import itertools
@@ -13,7 +12,7 @@ TRAINED_WEIGHTS = 'bert-base-cased' # cased works better for NER
NUM_NE_TAGS = 5 # BIEOS 0-4: [Begin Inside End Outside Single]
NUM_RELS = 3 # 0-2: [no relation, e1 featureOf e2, e2 featureOf e1]
MLP_HIDDEN_LAYER_NODES = 84
-BATCH_SIZE = 32
+BATCH_SIZE = 16
MAX_ENTITIES_PER_SENTENCE = 8
@@ -33,6 +32,12 @@ class RelBertNet(nn.Module):
self.bert1.encoder.layer = self.bert1.encoder.layer[:-K]
self.n = config.num_hidden_layers
+        for p in self.bert1.parameters():
+            p.requires_grad = True
+        for p in self.bert2_layers.parameters():
+            p.requires_grad = True
self.ner_linear = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_NE_TAGS)
self.crf = CRF(NUM_NE_TAGS, batch_first=True)
self.crf.train()
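
Note on the new requires_grad loops: freshly loaded parameters already default to requires_grad=True in PyTorch, so these loops mainly make the intent explicit that both BERT segments are fine-tuned (and would re-enable training if the weights had been frozen elsewhere).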
@@ -40,7 +45,7 @@ class RelBertNet(nn.Module):
self.mlp1 = nn.Linear(HIDDEN_OUTPUT_FEATURES, MLP_HIDDEN_LAYER_NODES)
self.mlp2 = nn.Linear(MLP_HIDDEN_LAYER_NODES, NUM_RELS)
-    def forward(self, encoded_text, instances=None, ner_tags=None):
+    def forward(self, encoded_text, instances, ner_tags=None, in_production=False):
attn_mask = encoded_text['attention_mask']
# BERT1 with MASKall for context
@@ -65,21 +70,32 @@ class RelBertNet(nn.Module):
entity_combinations = []
target_relation_labels = []
+        decoded_entity_combinations = []
n_combinations_by_instance = torch.empty(len(entities_by_instance), dtype=torch.long)
for idx_ins, entities in enumerate(entities_by_instance):
n = 0
for slice1, slice2 in list(itertools.combinations(entities, 2)):
-                relation_label = instances[idx_ins].relation((slice1.start-1, slice1.stop-1),
-                                                             (slice2.start-1, slice2.stop-1))
-                if relation_label is not None:
-                    entity_combinations.append((slice1, slice2))
-                    target_relation_labels.append(relation_label)
-                    n += 1
+                r1 = (slice1.start-1, slice1.stop-1)
+                r2 = (slice2.start-1, slice2.stop-1)
+                if in_production:
+                    # production: add all pairs
+                    entity_combinations.append((slice1, slice2))
+                    n += 1
+                    e1 = instances[idx_ins].entity_for_range(r1)
+                    e2 = instances[idx_ins].entity_for_range(r2)
+                    decoded_entity_combinations.append((e1, e2))
+                else:
+                    # training: filter out pairs that don't match entities in RC
+                    relation_label = instances[idx_ins].relation(r1, r2)
+                    if relation_label is not None:
+                        entity_combinations.append((slice1, slice2))
+                        target_relation_labels.append(relation_label)
+                        n += 1
n_combinations_by_instance[idx_ins] = n
# if no entity pairs, cannot find relations so return
if sum(n_combinations_by_instance) == 0:
-            return ner_output, ner_loss, None, None
+            return ner_output, ner_loss, None, None, None
# for each pair of named entities recognized, perform BERT2 with MASKrc for RC
rc_attn_mask = torch.zeros((len(entity_combinations), MAX_SEQ_LEN), dtype=torch.long)
@@ -107,7 +123,11 @@ class RelBertNet(nn.Module):
total_rc_output = torch.cat(sub_batch_outputs, dim=0)
# Return NER and RC outputs
-        return ner_output, ner_loss, total_rc_output, torch.tensor(target_relation_labels, dtype=torch.long)
+        return (ner_output,
+                ner_loss,
+                total_rc_output,
+                torch.tensor(target_relation_labels, dtype=torch.long),
+                decoded_entity_combinations)
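
Forward now returns a 5-tuple, and both unpacking patterns appear at the call sites in bert_extractor above (variable names here are placeholders):

    # training/evaluation: gold tags supplied, decoded pairs ignored
    ner_output, ner_loss, rc_output, target_labels, _ = net(batch, instances, ner_tags=true_ner_tags)
    # production: no gold tags; decoded entity pairs feed get_relation
    ner_output, _, rc_output, _, decoded_pairs = net(batch, instances, in_production=True)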
@staticmethod
def bieos_to_entities(tags):