import torch
import torch.nn as nn
from transformers import BertConfig, BertModel

HIDDEN_OUTPUT_FEATURES = 768  # hidden size of bert-base
TRAINED_WEIGHTS = 'bert-base-uncased'
NUM_CLASSES = 2  # entity, not entity
BATCH_SIZE = 32


class EntityBertNet(nn.Module):
    """Binary token classifier on top of a pretrained BERT encoder.

    Encodes a batch with ``bert-base-uncased`` and classifies the hidden
    state at one designated (entity) token position per example into
    ``NUM_CLASSES`` logits (entity / not entity). Softmax is deliberately
    left to the loss function (e.g. ``nn.CrossEntropyLoss``).
    """

    def __init__(self):
        super(EntityBertNet, self).__init__()
        # Load config explicitly so the classifier head's input size below
        # matches the encoder's hidden size (768 for bert-base).
        config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
        self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)

    def forward(self, input_ids, attn_mask, entity_indices):
        """Compute per-example entity/not-entity logits.

        Args:
            input_ids: (batch, seq_len) token id tensor.
            attn_mask: (batch, seq_len) attention mask.
            entity_indices: per-example position(s) of the entity token to
                classify — presumably one index per batch row so that the
                gather below yields a (batch, hidden) tensor; TODO confirm
                against callers.

        Returns:
            (batch, NUM_CLASSES) unnormalized logits.
        """
        # Index [0] selects last_hidden_state and works with both tuple
        # outputs (older transformers) and ModelOutput (>= 4.x). Tuple
        # unpacking (`out, _ = ...`) would iterate ModelOutput's dict KEYS
        # and silently bind strings instead of tensors.
        bert_output = self.bert_base(input_ids=input_ids,
                                     attention_mask=attn_mask)[0]
        # Gather the hidden state at each example's entity token position
        # via advanced indexing — one row per example, not a max/avg pool.
        entity_pooled_output = bert_output[
            torch.arange(0, bert_output.shape[0]), entity_indices]
        # fc layer (softmax activation done in loss function)
        return self.fc(entity_pooled_output)