Commit 915dd99a authored by Joel Oksanen

Implemented tag-based BERT extractor.

parent a7869ec7
import torch
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy
import time
from transformers import BertForTokenClassification, AdamW, get_cosine_schedule_with_warmup
from tagged_rel_dataset import TRAINED_WEIGHTS, MAX_SEQ_LEN, RELATIONS, IGNORE_TAG, TaggedRelDataset, generate_train_batch
train_data_path = 'data/train.json'
test_data_path = 'data/test.json'
trained_model_path = 'trained_bert_tag_extractor.pt'
device = torch.device('cuda')
N_TAGS = 4 * len(RELATIONS) * 2 + 1 # 9 for single relation -> 0: O, 1-4: B/I/E/S-1, 5-8: B/I/E/S-2
# optimizer
DECAY_RATE = 0.01
LEARNING_RATE = 0.00002
EPSILON = 1e-8
MAX_GRAD_NORM = 1.0
# training
N_EPOCHS = 3
BATCH_SIZE = 32
WARM_UP_FRAC = 0.05
VALID_FRAC = 0.1
# loss
LOSS_BIAS_WEIGHT = 10
# as defined by Zheng et al. (2017)
def biased_loss(output_scores, target_tags):
    # per-token cross-entropy without reduction; positions tagged IGNORE_TAG contribute zero loss
    ce_losses = torch.empty((output_scores.shape[0], output_scores.shape[1]), device=target_tags.device)
    for idx in range(ce_losses.shape[0]):
        ce_losses[idx] = cross_entropy(output_scores[idx], target_tags[idx], reduction='none', ignore_index=IGNORE_TAG)
    # weight losses on entity tags more heavily than on the O tag (index 0)
    bias_matrix = torch.where(target_tags == 0,
                              torch.tensor(1, device=target_tags.device),
                              torch.tensor(LOSS_BIAS_WEIGHT, device=target_tags.device))
    biased_losses = ce_losses * bias_matrix
    return torch.sum(biased_losses)
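
# A vectorised sketch of the same biased loss (an illustration only; the training loop below uses
# biased_loss, and the name biased_loss_vectorised is not part of the original code). cross_entropy
# accepts [N, C] scores directly, so the per-sample loop can be replaced by flattening the batch
# and sequence dimensions.
def biased_loss_vectorised(output_scores, target_tags):
    ce = cross_entropy(output_scores.reshape(-1, N_TAGS), target_tags.reshape(-1),
                       reduction='none', ignore_index=IGNORE_TAG)
    weights = torch.where(target_tags.reshape(-1) == 0,
                          torch.tensor(1.0, device=target_tags.device),
                          torch.tensor(float(LOSS_BIAS_WEIGHT), device=target_tags.device))
    return torch.sum(ce * weights)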
class BertTagExtractor:

    def __init__(self):
        self.net = None

    @staticmethod
    def default():
        ext = BertTagExtractor()
        ext.load_saved(trained_model_path)
        return ext

    def load_saved(self, path):
        self.net = BertForTokenClassification.from_pretrained(TRAINED_WEIGHTS, num_labels=N_TAGS)
        self.net.load_state_dict(torch.load(path))
        self.net.eval()

    @staticmethod
    def new_trained_with_file(file_path):
        extractor = BertTagExtractor()
        extractor.train_with_file(file_path)
        return extractor

    def train_with_file(self, file_path, size=None):
        # load training data
        train_data = TaggedRelDataset.from_file(file_path, size=size)
        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
                                  collate_fn=generate_train_batch)

        # initialise BERT
        self.net = BertForTokenClassification.from_pretrained(TRAINED_WEIGHTS, num_labels=N_TAGS)
        self.net.cuda()

        # set up optimizer with weight decay for all parameters except biases and layer norms
        params = list(self.net.named_parameters())
        no_decay = ['bias', 'gamma', 'beta']
        grouped_params = [
            {'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 'weight_decay': DECAY_RATE},
            {'params': [p for n, p in params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimiser = AdamW(grouped_params, lr=LEARNING_RATE, eps=EPSILON)

        # set up scheduler for lr
        n_training_steps = len(train_loader) * N_EPOCHS
        scheduler = get_cosine_schedule_with_warmup(
            optimiser,
            num_warmup_steps=int(WARM_UP_FRAC * n_training_steps),
            num_training_steps=n_training_steps
        )
        start = time.time()

        for epoch_idx in range(N_EPOCHS):
            self.net.train()
            batch_loss = 0.0

            for batch_idx, batch in enumerate(train_loader):
                # send batch to gpu
                texts, target_tags = tuple(i.to(device) for i in batch)

                # zero param gradients
                self.net.zero_grad()

                # forward pass
                output_scores = self.net(**texts)[0]

                # backward pass
                l = biased_loss(output_scores, target_tags)
                l.backward()

                # clip gradient norm
                clip_grad_norm_(parameters=self.net.parameters(), max_norm=MAX_GRAD_NORM)

                # optimise
                optimiser.step()

                # update lr
                scheduler.step()

                # print interim stats every 20 batches
                batch_loss += l.item()
                if batch_idx % 20 == 19:
                    batch_no = batch_idx + 1
                    print('epoch:', epoch_idx + 1, '-- progress: {:.4f}'.format(batch_no / len(train_loader)),
                          '-- batch:', batch_no, '-- avg loss:', batch_loss / 20)
                    batch_loss = 0.0

            print('epoch done')

        end = time.time()
        print('Training took', end - start, 'seconds')

        torch.save(self.net.state_dict(), trained_model_path)
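
# Minimal usage sketch (an assumption about how the module is meant to be run; not part of the
# original commit): train an extractor on the training split. train_with_file saves the trained
# weights to trained_model_path when it finishes.
if __name__ == '__main__':
    BertTagExtractor.new_trained_with_file(train_data_path)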

import torch.nn as nn
from transformers import BertModel

TRAINED_WEIGHTS = 'bert-base-cased'  # cased works better for NER


class SingleRelBertNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS)

import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
import pandas as pd
from collections import defaultdict
TRAINED_WEIGHTS = 'bert-base-cased' # cased works better for NER
RELATIONS = ['/location/location/contains']
MAX_SEQ_LEN = 128
MAX_TOKENS = MAX_SEQ_LEN - 2
IGNORE_TAG = -1
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)
def generate_train_batch(instances):
    texts = tokenizer.batch_encode_plus([instance.tokens for instance in instances], add_special_tokens=True,
                                        max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
                                        return_tensors='pt')
    target_tags = torch.tensor([instance.tags for instance in instances])
    return texts, target_tags
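
# For reference: with return_tensors='pt' and pad_to_max_length=True, `texts` is a dict-like
# BatchEncoding holding (at least) 'input_ids', 'token_type_ids' and 'attention_mask' tensors of
# shape [batch, MAX_SEQ_LEN], while `target_tags` is an int64 tensor of the same shape in which
# IGNORE_TAG marks positions the loss should skip (the CLS token and padding).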

# Based on Zheng et al. (2017)
class TaggedRelDataset(Dataset):

    def __init__(self):
        self.df = None
        self.relations = RELATIONS

    @staticmethod
    def from_file(path, valid_frac=None, size=None):
        dataset = TaggedRelDataset()
        dataset.df = pd.read_json(path, lines=True)

        # sample data if a size is specified
        if size is not None and size < len(dataset):
            dataset.df = dataset.df.sample(size, replace=False)

        if valid_frac is None:
            return dataset
        else:
            validset = TaggedRelDataset()
            split_idx = int(len(dataset) * (1 - valid_frac))
            dataset.df, validset.df = dataset.df.iloc[:split_idx], dataset.df.iloc[split_idx:]
            return dataset, validset
    def instance_from_row(self, row):
        text = row['sentText']
        tokens = tokenizer.tokenize(text)[:MAX_TOKENS]
        tag_map = self.map_for_relation_mentions(row['relationMentions'])
        sorted_entities = sorted(tag_map.keys(), key=len, reverse=True)

        tags = [IGNORE_TAG]  # ignore CLS token
        i = 0
        while i < len(tokens):
            found = False
            for entity in sorted_entities:
                match_length = TaggedRelDataset.token_entity_match(i, entity, tokens)
                if match_length is not None:
                    tags += TaggedRelDataset.ne_tags_for_len(match_length, tag_map[entity])
                    found = True
                    i += match_length
                    break
            if not found:
                tags += [0]
                i += 1
        tags += [IGNORE_TAG] * (MAX_SEQ_LEN - len(tags))  # pad to MAX_SEQ_LEN

        return TaggedRelInstance.from_tokens(tokens, tags)
    # NOTE: if entity is present in more than one relation, only one (the one with the highest count) is mapped
    def map_for_relation_mentions(self, relation_mentions):
        m = defaultdict(list)
        for rm in relation_mentions:
            if rm['label'] in self.relations:
                tag_base_idx = 1 + 8 * self.relations.index(rm['label'])  # idx 0 reserved for tag O
                m[rm['em1Text']].append(tag_base_idx)
                m[rm['em2Text']].append(tag_base_idx + 4)
        return {e: max(set(tags), key=tags.count) for e, tags in m.items()}
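
    # Illustration (hypothetical relation mention): with RELATIONS = ['/location/location/contains'],
    # {'em1Text': 'Belgium', 'em2Text': 'Brussels', 'label': '/location/location/contains'} maps to
    # {'Belgium': 1, 'Brussels': 5}, i.e. the first entity uses tag base 1 and the second tag base 5.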
    @staticmethod
    def token_entity_match(first_token_idx, entity, tokens):
        token_idx = first_token_idx
        remaining_entity = entity
        while remaining_entity:
            if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity:
                # start of new word
                remaining_entity = remaining_entity.lstrip()
                if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]:
                    remaining_entity = remaining_entity[len(tokens[token_idx]):]
                    token_idx += 1
                else:
                    break
            else:
                # continuing same word
                if (token_idx < len(tokens) and tokens[token_idx].startswith('##')
                        and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]):
                    remaining_entity = remaining_entity[len(tokens[token_idx][2:]):]
                    token_idx += 1
                else:
                    break
        if remaining_entity:
            return None
        else:
            return token_idx - first_token_idx
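
    # Illustration (hypothetical inputs), matching from index 0:
    #   entity 'New York'  vs tokens ['New', 'York', 'is']    -> 2
    #   entity 'Yorkshire' vs tokens ['York', '##shire', ','] -> 2  (WordPiece continuation)
    #   entity 'New York'  vs tokens ['New', 'Jersey']        -> None (no full match)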
    @staticmethod
    def ne_tags_for_len(n, base):
        # BIES scheme relative to base: B = base, I = base + 1, E = base + 2, S = base + 3
        assert n > 0
        return [base + 3] if n == 1 else [base] + [base + 1] * (n - 2) + [base + 2]
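
    # Illustration: ne_tags_for_len(1, 1) -> [4] (S-1), ne_tags_for_len(3, 1) -> [1, 2, 3] (B-1, I-1, E-1)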

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        return self.instance_from_row(self.df.iloc[idx])

class TaggedRelInstance:

    def __init__(self):
        self.tokens = None
        self.tags = None

    @staticmethod
    def from_tokens(tokens, tags):
        i = TaggedRelInstance()
        i.tokens = tokens
        i.tags = tags
        return i
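
# Minimal smoke-test sketch (an assumption, not part of the original commit; it presumes
# data/train.json exists locally): build a small dataset and collate a few instances to check
# the tensor shapes end to end.
if __name__ == '__main__':
    ds = TaggedRelDataset.from_file('data/train.json', size=8)
    batch_texts, batch_tags = generate_train_batch([ds[i] for i in range(4)])
    print(batch_texts['input_ids'].shape, batch_tags.shape)  # both [4, MAX_SEQ_LEN]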