Commit 174b15e4 authored by Joel Oksanen
parents 6986b0d8 c7ab3edc
@@ -3,13 +3,15 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
 from tdbertnet import TDBertNet
-from bert_dataset import BertDataset, polarity_indices, generate_batch
+from bert_dataset import BertDataset, Instance, polarity_indices, generate_batch
 import time
 import numpy as np
 from sklearn import metrics
+import matplotlib.pyplot as plt

 semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml'
 semeval_2014_test_path = 'data/SemEval-2014/Laptops_Test_Gold.xml'
+amazon_test_path = 'data/Amazon/amazon_camera_test.xml'
 trained_model_path = 'semeval_2014.pt'

 BATCH_SIZE = 32
@@ -45,7 +47,7 @@ class BertAnalyzer:
             optimiser.zero_grad()

             # forward pass
-            outputs = self.net(texts, target_indices)
+            outputs, _ = self.net(texts, target_indices)

             # backward pass
             l = loss(outputs, labels)
@@ -74,7 +76,7 @@ class BertAnalyzer:
         truths = []
         with torch.no_grad():
             for (texts, target_indices, labels) in test_loader:
-                outputs = self.net(texts, target_indices)
+                outputs, attentions = self.net(texts, target_indices)
                 _, pred = torch.max(outputs.data, 1)
                 predicted += pred.tolist()
                 truths += labels.tolist()
@@ -90,7 +92,33 @@ class BertAnalyzer:
         f1 = metrics.f1_score(truths, predicted, labels=range(len(polarity_indices)), average='macro')
         print('macro F1:', f1)

+    def analyze_sentence(self, text, char_from, char_to):
+        instance = Instance(text, char_from, char_to)
+        tokens, tg_from, tg_to = instance.get()
+        texts, target_indices = instance.to_tensor()
+
+        with torch.no_grad():
+            outputs, attentions = self.net(texts, target_indices)
+
+        target_attentions = torch.mean(attentions, 1)[0][tg_from+1:tg_to+2]
+        mean_target_att = torch.mean(target_attentions, 0)
+
+        # plot attention histogram
+        att_values = mean_target_att.numpy()[1:-1]
+        ax = plt.subplot(111)
+        width = 0.3
+        bins = [x - width/2 for x in range(1, len(att_values)+1)]
+        ax.bar(bins, att_values, width=width)
+        ax.set_xticks(list(range(1, len(att_values)+1)))
+        ax.set_xticklabels(tokens, rotation=45, rotation_mode='anchor', ha='right')
+        plt.show()
+
+        _, pred = torch.max(outputs.data, 1)
+        return pred
+

 sentiment_analyzer = BertAnalyzer()
 sentiment_analyzer.load_saved()
-sentiment_analyzer.evaluate(semeval_2014_test_path)
\ No newline at end of file
+sentiment = sentiment_analyzer.analyze_sentence('I will never buy another computer from HP/Compaq or do business with Circuit City again.', 39, 48)
+print('sentiment:', sentiment)
\ No newline at end of file
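A note on the two index conventions at play in `analyze_sentence`: the character span is interpreted with slice semantics over the raw sentence, while the attention tensor is indexed in token positions shifted by one, because `encode_plus(add_special_tokens=True)` prepends `[CLS]`. A minimal sketch of both (plain Python, no model needed; the span values come from the call above):

```python
text = 'I will never buy another computer from HP/Compaq or do business with Circuit City again.'

# char_from=39, char_to=48 behave like a slice: the target term is
assert text[39:48] == 'HP/Compaq'

# attentions returned by TDBertNet has shape (batch, heads, seq, seq);
# torch.mean(attentions, 1) averages over the heads, [0] picks the batch item.
# Token positions from Instance.get() are shifted by +1 in the encoded
# sequence ([CLS] sits at position 0), so target tokens tg_from..tg_to sit
# at rows tg_from+1..tg_to+1, i.e. the slice [tg_from+1:tg_to+2].
# att_values[1:-1] then drops the [CLS] and [SEP] columns before plotting.
```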
bert_dataset.py
@@ -26,13 +26,11 @@ def generate_batch(batch):

 def token_for_char(char_idx, text, tokens):
-    compressed_idx = len(re.sub(r'\s+', '', text)[:char_idx+1]) - 1
+    compressed_idx = len(re.sub(r'\s+', '', text[:char_idx+1])) - 1
     token_idx = -1
     while compressed_idx >= 0:
         token_idx += 1
         compressed_idx -= len(tokens[token_idx].replace('##', ''))
     return token_idx
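The one-line change to `token_for_char` fixes where the slice is taken: the old code removed whitespace from the whole sentence and then sliced, so the compressed index overshot by one position for every space before `char_idx` and could walk past the correct token. A standalone repro with a hand-built token list (not the real tokenizer):

```python
import re

text = 'a b c d e'
char_idx = text.index('c')          # 4
tokens = ['a', 'b', 'c', 'd', 'e']  # hand-built stand-in for WordPiece output

old_idx = len(re.sub(r'\s+', '', text)[:char_idx + 1]) - 1  # 4, counts 'abcde'
new_idx = len(re.sub(r'\s+', '', text[:char_idx + 1])) - 1  # 2, counts 'abc'

def walk(compressed_idx):
    # same token walk as token_for_char
    token_idx = -1
    while compressed_idx >= 0:
        token_idx += 1
        compressed_idx -= len(tokens[token_idx].replace('##', ''))
    return token_idx

print(walk(old_idx))  # 4 -> wrongly lands on 'e'
print(walk(new_idx))  # 2 -> correctly lands on 'c'
```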
@@ -55,17 +53,36 @@ class BertDataset(Dataset):
                 char_from = int(term.attrib['from'])
                 char_to = int(term.attrib['to']) - 1
                 polarity = term.attrib['polarity']
-                self.data.append((text, char_from, char_to, polarity))
+                self.data.append((Instance(text, char_from, char_to), polarity))

     def __len__(self):
         return len(self.data)

     def __getitem__(self, idx):
-        text, char_from, char_to, polarity_str = self.data[idx]
-        tokens = tokenizer.tokenize(text)
-        idx_from = token_for_char(char_from, text, tokens)
-        idx_to = token_for_char(char_to, text, tokens)
+        instance, polarity_str = self.data[idx]
+        tokens, idx_from, idx_to = instance.get()
         polarity = polarity_index(polarity_str)
         return {'tokens': tokens, 'from': idx_from, 'to': idx_to, 'polarity': polarity}
+
+
+class Instance:
+
+    def __init__(self, text, char_from, char_to):
+        self.text = text
+        self.char_from = char_from
+        self.char_to = char_to
+
+    def get(self):
+        tokens = tokenizer.tokenize(self.text)
+        idx_from = token_for_char(self.char_from, self.text, tokens)
+        idx_to = token_for_char(self.char_to-1, self.text, tokens)
+        return tokens, idx_from, idx_to
+
+    def to_tensor(self):
+        tokens, idx_from, idx_to = self.get()
+        text = tokenizer.encode_plus(tokens, add_special_tokens=True, max_length=MAX_SEQ_LEN,
+                                     is_pretokenized=True, return_tensors='pt')
+        target_indices = torch.tensor([[[t] * HIDDEN_OUTPUT_FEATURES for t in range(idx_from, idx_to + 1)]])
+        return text, target_indices
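For context on the shape built in `to_tensor`: `target_indices` repeats each target token position `HIDDEN_OUTPUT_FEATURES` times so that `torch.gather` in `TDBertNet.forward` can pull the full 768-dimensional hidden vector at each target position. A shape-only sketch with random tensors (no trained weights involved):

```python
import torch

HIDDEN = 768             # HIDDEN_OUTPUT_FEATURES
idx_from, idx_to = 3, 5  # example target token span

# (1, n_target_tokens, 768): row t is [t, t, ..., t]
target_indices = torch.tensor(
    [[[t] * HIDDEN for t in range(idx_from, idx_to + 1)]])

bert_output = torch.randn(1, 128, HIDDEN)  # (batch, seq_len, hidden)

# gather along dim=1 copies rows idx_from..idx_to of the sequence output
target_outputs = torch.gather(bert_output, dim=1, index=target_indices)
print(target_outputs.shape)                # torch.Size([1, 3, 768])

# element-wise max over the target tokens, as in TDBertNet.forward
pooled = torch.max(target_outputs, dim=1)[0]
print(pooled.shape)                        # torch.Size([1, 768])
```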
tdbertnet.py
@@ -6,19 +6,23 @@ from transformers import *

 HIDDEN_OUTPUT_FEATURES = 768
 TRAINED_WEIGHTS = 'bert-base-uncased'

 class TDBertNet(nn.Module):

     def __init__(self, num_class):
         super(TDBertNet, self).__init__()
-        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS)
+        config = BertConfig.from_pretrained(TRAINED_WEIGHTS, output_attentions=True)
+        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
+        self.bert_base.config.output_attentions = True
         self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, num_class)  # n of hidden features, n of output labels

     def forward(self, texts, target_indices):
         # BERT
-        bert_output = self.bert_base(**texts)[0]
+        bert_output, _, attentions = self.bert_base(**texts)
         # max pooling at target locations
         target_outputs = torch.gather(bert_output, dim=1, index=target_indices)
         pooled_output = torch.max(target_outputs, dim=1)[0]
         # fc layer
         x = self.fc(pooled_output)
-        return x
+        return x, attentions[-1]
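With `output_attentions=True`, the `transformers` release this code appears to target returns a plain tuple from `BertModel`: `(sequence_output, pooled_output, attentions)`, where `attentions` holds one tensor per encoder layer shaped `(batch, num_heads, seq_len, seq_len)`; `attentions[-1]` is thus the final layer's map, which `analyze_sentence` averages over heads. A quick shape check under that assumption (newer `transformers` versions return model-output objects instead, so the unpacking below is version-dependent):

```python
import torch
from transformers import BertConfig, BertModel, BertTokenizer

config = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True)
model = BertModel.from_pretrained('bert-base-uncased', config=config)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

inputs = tokenizer.encode_plus('battery life is great', return_tensors='pt')
with torch.no_grad():
    sequence_output, pooled_output, attentions = model(**inputs)

print(len(attentions))       # 12 -> one tensor per encoder layer
print(attentions[-1].shape)  # torch.Size([1, 12, 6, 6]) -> (batch, heads, seq, seq)
```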