# entitybertnet.py
import torch
import torch.nn as nn
from transformers import *

# Dimensionality of each token's hidden state in bert-base (768).
HIDDEN_OUTPUT_FEATURES = 768
# Pretrained checkpoint name passed to transformers' from_pretrained.
TRAINED_WEIGHTS = 'bert-base-uncased'
NUM_CLASSES = 2  # entity, not entity

BATCH_SIZE = 32


class EntityBertNet(nn.Module):
    """Binary token-level classifier: a BERT encoder with a linear head.

    For each example in the batch, the hidden state at the position given
    by ``entity_indices`` is projected to ``NUM_CLASSES`` logits
    (entity / not entity). Softmax is left to the loss function.
    """

    def __init__(self):
        super(EntityBertNet, self).__init__()
        config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
        # Projects the 768-dim BERT hidden state to entity/not-entity logits.
        self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)

    def forward(self, input_ids, attn_mask, entity_indices):
        """Compute entity logits for one token position per example.

        Args:
            input_ids: token id tensor, shape (batch, seq_len).
            attn_mask: attention mask tensor, shape (batch, seq_len).
            entity_indices: one token index per example, shape (batch,).
                (Assumed from the indexing below — confirm against caller.)

        Returns:
            Logit tensor of shape (batch, NUM_CLASSES); softmax is applied
            in the loss function, not here.
        """
        # Index [0] selects the last hidden state and works whether the
        # encoder returns a plain tuple (transformers < 4) or a ModelOutput
        # (transformers >= 4, return_dict default) — the original
        # `bert_output, _ = ...` unpacking breaks on the newer return type.
        bert_output = self.bert_base(input_ids=input_ids, attention_mask=attn_mask)[0]

        # Gather the hidden state at each example's entity position.
        # (This is a per-example index select, not max pooling as the
        # original comment claimed.)  device= keeps the index tensor on the
        # same device as the activations.
        batch_range = torch.arange(bert_output.shape[0], device=bert_output.device)
        entity_pooled_output = bert_output[batch_range, entity_indices]

        # fc layer (softmax activation done in loss function)
        return self.fc(entity_pooled_output)