# relbertnet.py
import itertools
from random import sample

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchcrf import CRF
from transformers import *

K = 4  # number of hidden layers in Bert2 (the top BERT layers, re-run for NER and RC)
HIDDEN_OUTPUT_FEATURES = 768  # hidden size of bert-base
MAX_SEQ_LEN = 128
TRAINED_WEIGHTS = 'bert-base-cased'  # cased works better for NER
NUM_NE_TAGS = 5  # BIEOS 0-4: [Begin Inside End Outside Single]
NUM_RELS = 3  # 0-2: [no relation, e1 featureOf e2, e2 featureOf e1]
MLP_HIDDEN_LAYER_NODES = 84
MAX_ENTITIES_PER_SENTENCE = 8


# Based on Xue et. al. (2019) with some modifications
# Directional tagging scheme from Zheng et. al. (2017)
class RelBertNet(nn.Module):
    """Joint named-entity recognition (NER) and relation classification (RC).

    The pretrained BERT encoder is split in two: ``bert1`` keeps the lower
    n-K transformer layers and produces a shared context encoding; the top K
    layers (``bert2_layers``) are applied twice -- once with the full
    attention mask for NER (decoded with a CRF over BIEOS tags), and once per
    candidate entity pair with an attention mask restricted to CLS and the
    two entity spans, feeding a small MLP for relation classification.
    """

    def __init__(self):
        super(RelBertNet, self).__init__()
        # Load pretrained BERT weights
        config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert1 = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
        self.bert1.train()

        # Divide BERT encoder layers into two parts: bert1 keeps the lower
        # n-K layers, bert2_layers holds the top K (shared by NER and RC)
        self.bert2_layers = self.bert1.encoder.layer[-K:]
        self.bert1.encoder.layer = self.bert1.encoder.layer[:-K]
        self.n = config.num_hidden_layers

        # per-token emission scores + CRF decoder for BIEOS tagging
        self.ner_linear = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_NE_TAGS)
        self.crf = CRF(NUM_NE_TAGS, batch_first=True)
        self.crf.train()

        # two-layer MLP head for relation classification over the CLS vector
        self.mlp1 = nn.Linear(HIDDEN_OUTPUT_FEATURES, MLP_HIDDEN_LAYER_NODES)
        self.mlp2 = nn.Linear(MLP_HIDDEN_LAYER_NODES, NUM_RELS)

    def forward(self, encoded_text, ner_tags=None):
        """Run NER, then RC on every pair of recognised entities.

        :param encoded_text: tokenizer output dict (``input_ids``,
            ``attention_mask``, ...) padded to MAX_SEQ_LEN.
        :param ner_tags: optional gold BIEOS tag tensor; when given, the
            negated CRF log-likelihood is returned as ``ner_loss``.
        :return: tuple ``(ner_output, ner_loss, rc_output,
            n_combinations_by_instance, entity_ranges)``; the last three are
            None when no candidate entity pair is found.
        """
        attn_mask = encoded_text['attention_mask']

        # BERT1 with MASKall for context
        bert_context_output, _ = self.bert1(**encoded_text)

        # BERT2 with MASKall for NER
        bert_ner_output = bert_context_output
        extended_attn_mask = attn_mask[:, None, None, :]
        for layer in self.bert2_layers:
            bert_ner_output, = layer(bert_ner_output, attention_mask=extended_attn_mask)

        # CRF for NER
        bert_ner_output = bert_ner_output.narrow(1, 1, MAX_SEQ_LEN - 2)  # remove CLS and last token
        # NOTE(review): the mask starts one position later (index 2) than the
        # emissions (index 1); this one-token forward shift makes the final
        # valid position (the SEP token) masked out for the CRF -- appears
        # intentional per the original "mask out SEP token" note; confirm.
        crf_attn_mask = attn_mask.narrow(1, 2, attn_mask.size()[1] - 2).type(torch.uint8)
        emissions = self.ner_linear(bert_ner_output)
        ner_output = self.crf.decode(emissions, mask=crf_attn_mask)
        # calculate loss if tags provided (CRF returns log-likelihood; negate)
        ner_loss = None if ner_tags is None else -self.crf(emissions, ner_tags, mask=crf_attn_mask, reduction='mean')

        # obtain all unordered pairs of recognised entities, per instance
        entities_by_instance = [RelBertNet.bieos_to_entities(tags) for tags in ner_output]
        combinations_by_instance = [list(itertools.combinations(ent, 2)) for ent in entities_by_instance]
        n_combinations_by_instance = torch.tensor([len(combs) for combs in combinations_by_instance])
        flat_combinations = [comb for combs in combinations_by_instance for comb in combs]

        # if no entity pairs, cannot find relations so return
        # (bug fix: was `not any(n > 2 ...)`, which also bailed out when every
        # instance had only one or two candidate pairs, contradicting this
        # guard's stated purpose)
        if not flat_combinations:
            return ner_output, ner_loss, None, None, None

        # for each pair of named entities recognized, perform BERT2 with
        # MASKrc for RC: attend only to CLS and the two entity spans
        rc_attn_mask = torch.zeros((len(flat_combinations), MAX_SEQ_LEN), dtype=torch.long)
        entity_ranges = torch.zeros((len(flat_combinations), 2, 2), dtype=torch.long)
        for i, (slice1, slice2) in enumerate(flat_combinations):
            rc_attn_mask[i][0] = 1  # CLS token
            rc_attn_mask[i][slice1] = 1
            rc_attn_mask[i][slice2] = 1
            # store (start, stop) per entity, shifted back by the CLS offset
            entity_ranges[i] = torch.tensor([[slice1.start, slice1.stop], [slice2.start, slice2.stop]]) - 1

        # repeat each instance's context encoding once per entity pair
        bert_rc_output = torch.repeat_interleave(bert_context_output, n_combinations_by_instance, dim=0)
        extended_rc_attn_mask = rc_attn_mask[:, None, None, :]
        for layer in self.bert2_layers:
            bert_rc_output, = layer(bert_rc_output, attention_mask=extended_rc_attn_mask)

        # MLP for RC
        rc_cls_output = bert_rc_output.narrow(1, 0, 1).squeeze(1)  # just CLS token
        rc_hidden_layer_output = torch.tanh(self.mlp1(rc_cls_output))  # tanh activation
        # raw logits; softmax is expected to be applied by the loss / caller
        rc_output = self.mlp2(rc_hidden_layer_output)

        # Return NER and RC outputs
        return ner_output, ner_loss, rc_output, n_combinations_by_instance, entity_ranges

    @staticmethod
    def bieos_to_entities(tags):
        """Decode a BIEOS tag sequence into entity slices.

        Tags: 0=Begin, 1=Inside, 2=End, 3=Outside, 4=Single. The returned
        slices index the original CLS-prefixed token sequence, hence the +1
        offsets (slice stop is exclusive). Begin tags without a matching End
        (reset by Outside/Single) produce no entity.
        """
        entities = []
        b = None  # index of the currently open Begin tag, if any
        for idx, tag in enumerate(tags):
            if tag == 0:  # Begin
                b = idx
            if tag == 2 and b is not None:  # End
                entities.append(slice(b+1, idx+2))  # +1 comes from CLS token
                b = None
            if tag == 3:  # Outside
                b = None
            if tag == 4:  # Single
                entities.append(slice(idx+1, idx+2))  # +1 comes from CLS token
                b = None

        # take at most MAX_ENTITIES_PER_SENTENCE per instance in order not to
        # overwhelm RC at the start when NER is warming up
        # (fix: `<` needlessly re-sampled/shuffled when exactly at the cap)
        if len(entities) <= MAX_ENTITIES_PER_SENTENCE:
            return entities
        return sample(entities, MAX_ENTITIES_PER_SENTENCE)