import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertConfig
class QualityEstimation(nn.Module):
    """Sentence-level quality-estimation regressor: BERT -> LSTM -> linear.

    The token representations produced by a pretrained multilingual BERT
    encoder are passed through dropout and a single-layer unidirectional
    LSTM; the LSTM output at the last sequence position is projected to a
    single unbounded score per example.
    """

    def __init__(self, hidden_dim):
        """Build the encoder, the LSTM and the scoring head.

        Args:
            hidden_dim: Size of the LSTM hidden state, which is also the
                input width of the final linear layer.
        """
        super().__init__()
        self.hidden_dim = hidden_dim

        # `from_pretrained` is a classmethod: call it on the class so we do
        # not first build (and immediately discard) a randomly initialised
        # BERT from a default config.
        self.bert = BertModel.from_pretrained('bert-base-multilingual-cased')
        self.dropout = nn.Dropout(0.25)

        # LSTM and scoring layers. The LSTM input width is taken from the
        # loaded encoder's config instead of being hard-coded.
        self.lstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=self.hidden_dim,
            num_layers=1,
            batch_first=True,
            dropout=0,
            bidirectional=False,
        )
        self.fc1 = nn.Linear(self.hidden_dim, 1)

    def forward(self, token_ids, segment_ids=None, attention_mask=None):
        """Compute one quality score per input sequence.

        Args:
            token_ids: Token id tensor forwarded to BERT as `input_ids`
                (presumably shaped (batch, seq_len) -- confirm with caller).
            segment_ids: Optional tensor forwarded to BERT as
                `token_type_ids`.
            attention_mask: Optional tensor forwarded to BERT as
                `attention_mask`.

        Returns:
            The output of the final linear layer applied to the LSTM output
            at the last sequence position; no output activation is applied.
        """
        bert_outputs = self.bert(
            input_ids=token_ids,
            token_type_ids=segment_ids,
            attention_mask=attention_mask,
        )
        # Index instead of tuple-unpacking: the encoder may return either a
        # plain tuple or a dict-like ModelOutput (whose iteration yields
        # keys, not tensors). Position 0 is the per-token hidden states in
        # both cases.
        encoded_layers = self.dropout(bert_outputs[0])

        output, _ = self.lstm(encoded_layers)

        # NOTE(review): the last sequence position is used regardless of
        # `attention_mask`, so for right-padded batches this reads the LSTM
        # state after padding tokens. Kept as-is for compatibility with
        # existing checkpoints -- confirm whether this is intended.
        qe_scores = self.fc1(output[:, -1, :])
        return qe_scores
