"""BERT-based classifier for the relation between two entity mentions."""

import torch
import torch.nn as nn
from transformers import *  # noqa: F401,F403 -- supplies BertConfig and BertModel

HIDDEN_OUTPUT_FEATURES = 768
TRAINED_WEIGHTS = 'bert-base-uncased'
NUM_CLASSES = 4  # no relation, fst hasFeature snd, snd hasFeature fst, siblings
# lower -> more general but less informative entity representations
# NOTE: only referenced by a disabled entity-bottleneck experiment
# (entity_fc1 / entity_fc2); the model below does not use it.
HIDDEN_ENTITY_FEATURES = 6


class PairBertNet(nn.Module):
    """Classify the relation between two entities mentioned in one text.

    The text is encoded with a pretrained BERT model. The token
    representations of each entity are max-pooled, the two pooled vectors
    are concatenated, and a single linear layer maps them to
    ``NUM_CLASSES`` logits.
    """

    def __init__(self):
        """Load the pretrained BERT encoder and create the output layer."""
        super().__init__()
        config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
        # Two pooled entity vectors are concatenated, hence the factor of 2.
        self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES * 2, NUM_CLASSES)

    def forward(self, input_ids, attn_mask, fst_indices, snd_indices):
        """Return unnormalised class scores for the entity pair.

        Args:
            input_ids: Token ids, passed to BERT as ``input_ids``.
            attn_mask: Attention mask, passed to BERT as ``attention_mask``.
            fst_indices: ``torch.gather`` index (along dim 1) selecting the
                first entity's token positions; see ``pooled_output``.
            snd_indices: Same as ``fst_indices`` for the second entity.

        Returns:
            The output of the final linear layer (``NUM_CLASSES`` features
            per example). No softmax is applied; that is left to the loss
            function.
        """
        # BERT: take the first output (the per-token hidden states) by
        # position. Indexing works whether the model returns a plain tuple
        # or a dict-like ModelOutput; tuple-unpacking a ModelOutput would
        # yield its key strings instead of tensors.
        bert_output = self.bert_base(
            input_ids=input_ids, attention_mask=attn_mask
        )[0]

        # max pooling at entity locations
        fst_pooled_output = PairBertNet.pooled_output(bert_output, fst_indices)
        snd_pooled_output = PairBertNet.pooled_output(bert_output, snd_indices)

        # concat pooled outputs from prod and feat entities
        combined = torch.cat((fst_pooled_output, snd_pooled_output), dim=1)

        # fc layer (softmax activation done in loss function)
        return self.fc(combined)

    @staticmethod
    def pooled_output(bert_output, indices):
        """Max-pool the rows of ``bert_output`` selected by ``indices``.

        Args:
            bert_output: Tensor to pool over along dim 1.
            indices: Integer index tensor for ``torch.gather`` along dim 1.
                It must have the same number of dimensions as
                ``bert_output``; the gathered tensor takes the shape of
                ``indices``. For the result to fit ``self.fc`` in
                ``forward``, the last dimension is presumably expanded to
                ``HIDDEN_OUTPUT_FEATURES`` by the caller -- verify against
                the data pipeline.

        Returns:
            The element-wise maximum over dim 1 of the gathered values.
        """
        outputs = torch.gather(bert_output, dim=1, index=indices)
        pooled_output, _ = torch.max(outputs, dim=1)
        return pooled_output