Skip to content
Snippets Groups Projects
Commit 79c7877c authored by Park, Se's avatar Park, Se
Browse files

Delete model.py

parent 90c0fdfe
No related branches found
No related tags found
No related merge requests found
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertConfig
class QualityEstimation(nn.Module):
def __init__(self, hidden_dim):
super(QualityEstimation, self).__init__()
self.hidden_dim = hidden_dim
# Instantiating BERT model object
config = BertConfig()
self.bert = BertModel(config).from_pretrained('bert-base-multilingual-cased')
self.dropout = nn.Dropout(0.25)
# LSTM and classification layers
self.lstm = nn.LSTM(input_size=768, hidden_size=self.hidden_dim,
num_layers=1, batch_first=True,
dropout=0, bidirectional=False)
self.fc1 = nn.Linear(self.hidden_dim, 1)
nn.init.kaiming_normal_(self.fc1.weight)
# self.fc2 = nn.Linear(self.hidden_dim, 1)
# nn.init.kaiming_normal_(self.fc2.weight)
def forward(self, token_ids, segment_ids=None, attention_mask=None):
encoded_layers, _ = self.bert(input_ids=token_ids, token_type_ids=segment_ids, attention_mask=attention_mask)
encoded_layers = self.dropout(encoded_layers)
output, _ = self.lstm(encoded_layers)
# output = torch.tanh(self.fc1(output[:,-1,:]))
qe_scores = self.fc1(output[:,-1,:])
# qe_scores = torch.tanh(qe_scores)
return qe_scores
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment