Commit a28029dd authored by Joel Oksanen

Started working on data preparation for the net for NYT data

parent 886f11a6

bert_extractor.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from sklearn import metrics
import time
from rel_dataset import RelInstance, RelDataset, generate_batch
from relbertnet import RelBertNet, NUM_RELS
train_set_path = 'data/location_contains_train_set.tsv'
trained_model_path = 'bert_extractor.pt'
BATCH_SIZE = 32
MAX_EPOCHS = 6
LEARNING_RATE = 0.00002
loss_criterion = nn.CrossEntropyLoss()
def loss(ner_output, rc_output, ner_loss, true_rels):
    # WIP: only the NER (CRF) loss is used so far; rc_output, true_rels and
    # loss_criterion are presumably reserved for the relation classification loss
    return torch.sum(ner_loss)
class BertExtractor:
    @staticmethod
    def default():
        sa = BertExtractor()
        sa.load_saved(trained_model_path)
        return sa

    def load_saved(self, path):
        self.net = RelBertNet()
        self.net.load_state_dict(torch.load(path))
        self.net.eval()
    def train(self, data_file):
        train_data = RelDataset.from_data(data_file)
        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
                                  collate_fn=generate_batch)

        self.net = RelBertNet()
        optimiser = optim.Adam(self.net.parameters(), lr=LEARNING_RATE)

        start = time.time()

        for epoch in range(MAX_EPOCHS):
            batch_loss = 0.0
            for idx, (batch, true_ner_tags, true_rels) in enumerate(train_loader):
                # zero param gradients
                optimiser.zero_grad()

                # forward pass
                ner_output, rc_output, ner_loss = self.net(batch, true_ner_tags)

                # backward pass
                l = loss(ner_output, rc_output, ner_loss, true_rels)
                l.backward()

                # optimise
                optimiser.step()

                # print interim stats every 10 batches
                batch_loss += l.item()
                if idx % 10 == 9:
                    print('epoch:', epoch + 1, '-- batch:', idx + 1, '-- avg loss:', batch_loss / 10)
                    batch_loss = 0.0

        end = time.time()
        print('Training took', end - start, 'seconds')

        torch.save(self.net.state_dict(), trained_model_path)
    def evaluate(self, data_file):
        test_data = RelDataset.from_data(data_file)
        test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
                                 collate_fn=generate_batch)

        predicted_ner_tags = []
        predicted_rels = []
        true_ner_tags = []
        true_rels = []
        with torch.no_grad():
            for batch, b_true_ner_tags, b_true_rels in test_loader:
                ner_output, rc_output, _ = self.net(batch)
                _, rc_max = torch.max(rc_output.data, 1)
                predicted_ner_tags += ner_output.tolist()
                predicted_rels += rc_max.tolist()
                true_rels += b_true_rels.tolist()
                true_ner_tags += b_true_ner_tags.tolist()

        for pred, truth in [(predicted_ner_tags, true_ner_tags), (predicted_rels, true_rels)]:
            correct = (np.array(pred) == np.array(truth))
            accuracy = correct.sum() / correct.size
            print('accuracy:', accuracy)

        cm = metrics.confusion_matrix(true_rels, predicted_rels, labels=range(NUM_RELS))
        print('confusion matrix:')
        print(cm)

        f1 = metrics.f1_score(true_rels, predicted_rels, labels=range(NUM_RELS), average='macro')
        print('macro F1:', f1)


# dataset = RelDataset.from_texts(['Testing if this works.', 'Since I\'m not sure.'])
# loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=4, collate_fn=generate_batch)
# batch = next(iter(loader))
#
# i = RelInstance('Testing if this works.')
#
# net = RelBertNet()
# net(batch)
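# Example usage (a sketch; the test set path is hypothetical):
# extractor = BertExtractor()
# extractor.train(train_set_path)
# extractor.evaluate('data/location_contains_test_set.tsv')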
rel_dataset.py

from transformers import BertTokenizer
from torch.utils.data import Dataset
import pandas as pd
from relbertnet import TRAINED_WEIGHTS

MAX_SEQ_LEN = 128
@@ -10,7 +11,6 @@ def generate_batch(batch):
    texts = tokenizer.batch_encode_plus([tokens for tokens in batch], add_special_tokens=True,
                                        max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
                                        return_tensors='pt')
    return texts
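# NOTE: generate_batch currently returns only the encoded texts; the DataLoaders in
# bert_extractor expect (batch, ner_tags, rels) triples, so batching of the tags and
# relations presumably still needs to be added here.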
@@ -19,10 +19,42 @@ class RelDataset(Dataset):
    def __init__(self):
        self.data = []
    @staticmethod
    def from_data(path):
        dataset = RelDataset()
        data = pd.read_csv(path, sep='\t', error_bad_lines=False)
        dataset.data = [RelDataset.instance_from_row(row) for _, row in data.iterrows()]
        return dataset
    @staticmethod
    def instance_from_row(row):
        tokens = row['sentText'].split(' ')
        # longest mentions first, so that greedy matching prefers the longer entity
        entities = sorted([em['text'].split(' ') for em in row['entityMentions']], key=len, reverse=True)
        relations = [(m['em1Text'], m['em2Text']) for m in row['relationMentions']]
        ne_tags = []
        i = 0
        while i < len(tokens):
            found = False
            for entity in entities:
                if tokens[i:i+len(entity)] == entity:
                    ne_tags += RelDataset.ne_tags_for_len(len(entity))
                    found = True
                    i += len(entity)
                    break
            if not found:
                ne_tags += [0]  # O tag; 3 would clash with the E tag from ne_tags_for_len
                i += 1
        return RelInstance.from_tokens(tokens, ne_tags, relations)
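    # A row of the NYT-style data is assumed to look roughly like this
    # (hypothetical example; only the fields accessed above are shown):
    #   sentText:         'Barack Obama was born in Hawaii .'
    #   entityMentions:   [{'text': 'Barack Obama'}, {'text': 'Hawaii'}]
    #   relationMentions: [{'em1Text': 'Hawaii', 'em2Text': 'Barack Obama'}]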
    @staticmethod
    def ne_tags_for_len(n):
        # BIEOS tags with B=1, I=2, E=3, S=4 (O=0), e.g. n=1 -> [4], n=3 -> [1, 2, 3]
        assert n > 0
        return [4] if n == 1 else [1] + [2] * (n-2) + [3]
    @staticmethod
    def from_texts(texts):
        dataset = RelDataset()
        dataset.data = [RelInstance.from_text(text) for text in texts]
        return dataset

    def __len__(self):
@@ -35,15 +67,30 @@ class RelDataset(Dataset):
class RelInstance:
    def __init__(self):
        self.tokens = None
        self.ne_tags = None
        self.relations = None

    @staticmethod
    def from_text(text):
        i = RelInstance()
        i.tokens = tokenizer.tokenize(text)
        return i

    @staticmethod
    def from_tokens(tokens, ne_tags, relations):
        i = RelInstance()
        i.tokens = tokens
        i.ne_tags = ne_tags
        i.relations = relations
        return i

    def get(self):
        return self.tokens

    def to_tensor(self):
        tokens = self.get()
        encoded = tokenizer.encode_plus(tokens, add_special_tokens=True, max_length=MAX_SEQ_LEN,
                                        is_pretokenized=True, return_tensors='pt')
        return encoded
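# Example (a sketch): building an instance from raw text and encoding it for BERT.
# encode_plus returns a BatchEncoding holding 'input_ids', 'token_type_ids' and
# 'attention_mask' tensors.
# inst = RelInstance.from_text('Barack Obama was born in Hawaii.')
# encoded = inst.to_tensor()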
relbertnet.py

@@ -37,7 +37,7 @@ class RelBertNet(nn.Module):
        self.mlp1 = nn.Linear(HIDDEN_OUTPUT_FEATURES, MLP_HIDDEN_LAYER_NODES)
        self.mlp2 = nn.Linear(MLP_HIDDEN_LAYER_NODES, NUM_RELS)

    def forward(self, encoded_text, ner_tags=None):
        attn_mask = encoded_text['attention_mask']

        # BERT1 with MASKall for context
@@ -55,7 +55,7 @@ class RelBertNet(nn.Module):
        emissions = self.ner_linear(bert_ner_output)
        ner_output = self.crf.decode(emissions, mask=crf_attn_mask)

        # calculate loss if tags provided (compare against None explicitly: the
        # truthiness of a multi-element tensor is ambiguous and raises an error)
        ner_loss = -self.crf(emissions, ner_tags, mask=crf_attn_mask, reduction='mean') if ner_tags is not None else None

        # obtain pairs of entities
        entities_by_instance = [RelBertNet.bieos_to_entities(tags) for tags in ner_output]
...