Commit a28029dd authored by Joel Oksanen's avatar Joel Oksanen
Browse files

Started working on data preparations for net for NYT data

parent 886f11a6
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from sklearn import metrics
import time
from rel_dataset import RelInstance, RelDataset, generate_batch
from relbertnet import RelBertNet
from relbertnet import RelBertNet, NUM_RELS
train_set_path = 'data/location_contains_train_set.tsv'
trained_model_path = 'bert_extractor.pt'
BATCH_SIZE = 32
MAX_EPOCHS = 6
LEARNING_RATE = 0.00002
loss_criterion = nn.CrossEntropyLoss()
def loss(ner_output, rc_output, ner_loss, true_rels):
return torch.sum(ner_loss)
class BertExtractor:
pass
@staticmethod
def default():
sa = BertExtractor()
sa.load_saved(trained_model_path)
return sa
def load_saved(self, path):
self.net = RelBertNet()
self.net.load_state_dict(torch.load(path))
self.net.eval()
def train(self, data_file):
train_data = RelDataset.from_file(data_file)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
collate_fn=generate_batch)
self.net = RelBertNet()
optimiser = optim.Adam(self.net.parameters(), lr=LEARNING_RATE)
start = time.time()
for epoch in range(MAX_EPOCHS):
batch_loss = 0.0
for idx, (batch, true_ner_tags, true_rels) in enumerate(train_loader):
# zero param gradients
optimiser.zero_grad()
# forward pass
ner_output, rc_output, ner_loss = self.net(batch, true_ner_tags)
# backward pass
l = loss(ner_output, rc_output, ner_loss, true_rels)
l.backward()
# optimise
optimiser.step()
# print interim stats every 10 batches
batch_loss += l.item()
if idx % 10 == 9:
print('epoch:', epoch + 1, '-- batch:', idx + 1, '-- avg loss:', batch_loss / 10)
batch_loss = 0.0
end = time.time()
print('Training took', end - start, 'seconds')
torch.save(self.net.state_dict(), trained_model_path)
def evaluate(self, data_file):
test_data = RelDataset.from_file(data_file)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
collate_fn=generate_batch)
predicted_ner_tags = []
predicted_rels = []
true_ner_tags = []
true_rels = []
with torch.no_grad():
for batch, b_true_ner_tags, b_true_rels in test_loader:
ner_output, rc_output, _ = self.net(batch)
_, rc_max = torch.max(rc_output.data, 1)
predicted_ner_tags += ner_output.tolist()
predicted_rels += rc_max.tolist()
true_rels += b_true_rels.tolist()
true_ner_tags += b_true_ner_tags.tolist()
for pred, truth in [(predicted_ner_tags, true_ner_tags), (predicted_rels, true_rels)]:
correct = (np.array(pred) == np.array(truth))
accuracy = correct.sum() / correct.size
print('accuracy:', accuracy)
cm = metrics.confusion_matrix(true_rels, predicted_rels, labels=range(NUM_RELS))
print('confusion matrix:')
print(cm)
dataset = RelDataset.from_texts(['Testing if this works.', 'Since I\'m not sure.'])
loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=4, collate_fn=generate_batch)
batch = next(iter(loader))
f1 = metrics.f1_score(true_rels, predicted_rels, labels=range(NUM_RELS), average='macro')
print('macro F1:', f1)
i = RelInstance('Testing if this works.')
net = RelBertNet()
net(batch)
# dataset = RelDataset.from_texts(['Testing if this works.', 'Since I\'m not sure.'])
# loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=4, collate_fn=generate_batch)
# batch = next(iter(loader))
#
# i = RelInstance('Testing if this works.')
#
# net = RelBertNet()
# net(batch)
\ No newline at end of file
from transformers import BertTokenizer
from torch.utils.data import Dataset
import pandas as pd
from relbertnet import TRAINED_WEIGHTS
MAX_SEQ_LEN = 128
......@@ -10,7 +11,6 @@ def generate_batch(batch):
texts = tokenizer.batch_encode_plus([tokens for tokens in batch], add_special_tokens=True,
max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
return_tensors='pt')
return texts
......@@ -19,10 +19,42 @@ class RelDataset(Dataset):
def __init__(self):
self.data = []
@staticmethod
def from_data(path):
dataset = RelDataset()
data = pd.read_csv(path, sep='\t', error_bad_lines=False)
dataset.data = [RelDataset.instance_from_row(row) for _, row in data.iterrows()]
return dataset
@staticmethod
def instance_from_row(row):
tokens = row['sentText'].split(' ')
entities = sorted([em['text'].split(' ') for em in row['entityMentions']], key=len, reverse=True)
relations = [(m['em1Text'], m['em2Text']) for m in row['relationMentions']]
ne_tags = []
i = 0
while i < len(tokens):
found = False
for entity in entities:
if tokens[i:i+len(entity)] == entity:
ne_tags += RelDataset.ne_tags_for_len(len(entity))
found = True
i += len(entity)
break
if not found:
ne_tags += [3]
i += 1
return RelInstance.from_tokens(tokens, ne_tags, relations)
@staticmethod
def ne_tags_for_len(n):
assert n > 0
return [4] if n == 1 else [1] + [2] * (n-2) + [3]
@staticmethod
def from_texts(texts):
dataset = RelDataset()
dataset.data = [RelInstance(text) for text in texts]
dataset.data = [RelInstance.from_text(text) for text in texts]
return dataset
def __len__(self):
......@@ -35,15 +67,30 @@ class RelDataset(Dataset):
class RelInstance:
def __init__(self, text):
self.text = text
def __init__(self):
self.tokens = None
self.ne_tags = None
self.relations = None
@staticmethod
def from_text(text):
i = RelInstance()
i.tokens = tokenizer.tokenize(text)
return i
@staticmethod
def from_tokens(tokens, ne_tags, relations):
i = RelInstance()
i.tokens = tokens
i.ne_tags = ne_tags
i.relations = relations
return i
def get(self):
tokens = tokenizer.tokenize(self.text)
return tokens
return self.tokens
def to_tensor(self):
tokens = self.get()
encoded = tokenizer.encode_plus(tokens, add_special_tokens=True, max_length=MAX_SEQ_LEN,
is_pretokenized=True, return_tensors='pt')
return encoded
\ No newline at end of file
return encoded
......@@ -37,7 +37,7 @@ class RelBertNet(nn.Module):
self.mlp1 = nn.Linear(HIDDEN_OUTPUT_FEATURES, MLP_HIDDEN_LAYER_NODES)
self.mlp2 = nn.Linear(MLP_HIDDEN_LAYER_NODES, NUM_RELS)
def forward(self, encoded_text, ne_tags=None):
def forward(self, encoded_text, ner_tags=None):
attn_mask = encoded_text['attention_mask']
# BERT1 with MASKall for context
......@@ -55,7 +55,7 @@ class RelBertNet(nn.Module):
emissions = self.ner_linear(bert_ner_output)
ner_output = self.crf.decode(emissions, mask=crf_attn_mask)
# calculate loss if tags provided
ner_loss = -self.crf(emissions, ne_tags, mask=crf_attn_mask) if ne_tags is not None else None
ner_loss = -self.crf(emissions, ner_tags, mask=crf_attn_mask, reduction='mean') if ner_tags else None
# obtain pairs of entities
entities_by_instance = [RelBertNet.bieos_to_entities(tags) for tags in ner_output]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment