Commit 886f11a6 authored by Joel Oksanen's avatar Joel Oksanen

Finished relbertnet (untested)

parent 2e639f30
......@@ -4,6 +4,7 @@ __pycache__/
server/agent/amazon_data/
server/agent/target_extraction/data/
server/agent/target_extraction/stanford-corenlp-full-2018-10-05
server/agent/target_extraction/BERT/data/
.DS_Store
*.pickle
*.wv
\ No newline at end of file
from torch.utils.data import DataLoader
from rel_dataset import RelInstance, RelDataset, generate_batch
from relbertnet import RelBertNet
......@@ -6,6 +7,11 @@ class BertExtractor:
    pass
dataset = RelDataset.from_texts(['Testing if this works.', 'Since I\'m not sure.'])
loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=4, collate_fn=generate_batch)
batch = next(iter(loader))
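# batch is a dict of input_ids / token_type_ids / attention_mask tensors, each of shape [batch_size, MAX_SEQ_LEN]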
net = RelBertNet()
net(batch)
\ No newline at end of file
from transformers import BertTokenizer
from torch.utils.data import Dataset
from relbertnet import TRAINED_WEIGHTS
MAX_SEQ_LEN = 128
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)
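# collate_fn for torch DataLoader: encodes and pads a whole batch of pre-tokenized texts in one tokenizer call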
def generate_batch(batch):
    encoded = tokenizer.batch_encode_plus(list(batch), add_special_tokens=True,
                                          max_length=MAX_SEQ_LEN, pad_to_max_length=True,
                                          is_pretokenized=True, return_tensors='pt')
    return encoded
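# map-style torch Dataset holding one RelInstance per input text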
class RelDataset(Dataset):
    def __init__(self):
        self.data = []

    @staticmethod
    def from_texts(texts):
        dataset = RelDataset()
        dataset.data = [RelInstance(text) for text in texts]
        return dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        instance = self.data[idx]
        return instance.get()
class RelInstance:
    def __init__(self, text):
......@@ -18,4 +46,4 @@ class RelInstance:
        tokens = self.get()
        encoded = tokenizer.encode_plus(tokens, add_special_tokens=True, max_length=MAX_SEQ_LEN,
                                        is_pretokenized=True, return_tensors='pt')
        return encoded
\ No newline at end of file
......@@ -2,12 +2,20 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertConfig, BertModel
from torchcrf import CRF
import itertools
K = 4 # number of hidden layers in Bert2
HIDDEN_OUTPUT_FEATURES = 768
MAX_SEQ_LEN = 128
TRAINED_WEIGHTS = 'bert-base-uncased'
NUM_NE_TAGS = 5 # BIEOS 0-4: [Begin Inside End Outside Single]
NUM_RELS = 3 # 0-2: [e1 featureOf e2, e2 featureOf e1, no relation]
MLP_HIDDEN_LAYER_NODES = 84
# Based on Xue et al. (2019) with some modifications
# Directional tagging scheme from Zheng et al. (2017)
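# Pipeline: BERT1 (first n-K layers) encodes context; BERT2 (last K layers) + CRF tag
# named entities; BERT2 is then re-run per entity pair and an MLP classifies the relation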
class RelBertNet(nn.Module):
    def __init__(self):
......@@ -15,38 +23,83 @@ class RelBertNet(nn.Module):
        # Load pretrained BERT weights
        config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert1 = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
        self.bert1.train()

        # Divide BERT encoder layers into two parts
        self.bert2_layers = self.bert1.encoder.layer[-K:]
        self.bert1.encoder.layer = self.bert1.encoder.layer[:-K]
        self.n = config.num_hidden_layers
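        # bert1 now applies only the first n-K encoder layers; the K layers kept in
        # self.bert2_layers are applied manually in forward() with task-specific masks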
        self.ner_linear = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_NE_TAGS)
        self.crf = CRF(NUM_NE_TAGS, batch_first=True)
        self.crf.train()
        self.mlp1 = nn.Linear(HIDDEN_OUTPUT_FEATURES, MLP_HIDDEN_LAYER_NODES)
        self.mlp2 = nn.Linear(MLP_HIDDEN_LAYER_NODES, NUM_RELS)

    def forward(self, encoded_text, ne_tags=None):
        attn_mask = encoded_text['attention_mask']

        # BERT1 with MASKall for context
        bert_context_output, _ = self.bert1(**encoded_text)
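        # MASKall lets every token attend to all non-pad tokens; MASKrc (below) restricts
        # attention to the CLS token and a single entity pair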
        # BERT2 with MASKall for NER
        bert_ner_output = bert_context_output
        # binary mask -> additive mask expected by BertLayer: 0 to keep, -10000 to mask pads
        extended_attn_mask = (1.0 - attn_mask[:, None, None, :].float()) * -10000.0
        for layer in self.bert2_layers:
            bert_ner_output, = layer(bert_ner_output, attention_mask=extended_attn_mask)

        # CRF for NER
        bert_ner_output = bert_ner_output.narrow(1, 1, bert_ner_output.size()[1] - 1)  # without CLS token
        crf_attn_mask = attn_mask.narrow(1, 1, attn_mask.size()[1] - 1).type(torch.uint8)  # mask out CLS token
        emissions = self.ner_linear(bert_ner_output)
        ner_output = self.crf.decode(emissions, mask=crf_attn_mask)
        # calculate NER loss only if gold tags are provided
        ner_loss = -self.crf(emissions, ne_tags, mask=crf_attn_mask) if ne_tags is not None else None
        # obtain pairs of entities
        entities_by_instance = [RelBertNet.bieos_to_entities(tags) for tags in ner_output]
        combinations_by_instance = [list(itertools.combinations(ent, 2)) for ent in entities_by_instance]
        n_combinations_by_instance = torch.tensor([len(combs) for combs in combinations_by_instance])
        flat_combinations = [comb for combs in combinations_by_instance for comb in combs]

        # if no entity pairs, cannot find relations, so return early
        if len(flat_combinations) == 0:
            return ner_output, None, ner_loss
        # for each pair of named entities recognized, perform BERT2 with MASKrc for RC
        rc_attn_mask = torch.zeros((len(flat_combinations), MAX_SEQ_LEN), dtype=torch.long)
        for i, (slice1, slice2) in enumerate(flat_combinations):
            rc_attn_mask[i][0] = 1  # keep CLS visible
            rc_attn_mask[i][slice1] = 1
            rc_attn_mask[i][slice2] = 1
        bert_rc_output = torch.repeat_interleave(bert_context_output, n_combinations_by_instance, dim=0)
        # same binary -> additive mask conversion as in the NER branch
        extended_rc_attn_mask = (1.0 - rc_attn_mask[:, None, None, :].float()) * -10000.0
        for layer in self.bert2_layers:
            bert_rc_output, = layer(bert_rc_output, attention_mask=extended_rc_attn_mask)
        # MLP for RC
        rc_cls_output = bert_rc_output.narrow(1, 0, 1).squeeze(1)  # just CLS token
        rc_hidden_layer_output = torch.tanh(self.mlp1(rc_cls_output))  # tanh activation
        rc_output = F.softmax(self.mlp2(rc_hidden_layer_output), dim=1)  # softmax activation

        # Return NER and RC outputs
        return ner_output, rc_output, ner_loss
    @staticmethod
    def bieos_to_entities(tags):
        entities = []
        b = None
        for idx, tag in enumerate(tags):
            if tag == 0:  # Begin
                b = idx
            if tag == 2 and b is not None:  # End
                entities.append(slice(b + 1, idx + 2))  # +1 comes from CLS token
                b = None
            if tag == 3:  # Outside
                b = None
            if tag == 4:  # Single
                entities.append(slice(idx + 1, idx + 2))  # +1 comes from CLS token
                b = None
        return entities
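    # e.g. tags [0, 1, 2, 3, 4] ([B, I, E, O, S]) -> [slice(1, 4), slice(5, 6)]:
    # a three-token entity and a single-token entity, both offset by 1 for the CLS token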