Commit 226a21f0 authored by Joel Oksanen

Attention outputs for BERT

parent c7ab3edc
@@ -8,11 +8,12 @@ import time
 import numpy as np
 from sklearn import metrics
 import matplotlib.pyplot as plt
+import shap
 
 semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml'
 semeval_2014_test_path = 'data/SemEval-2014/Laptops_Test_Gold.xml'
 amazon_test_path = 'data/Amazon/amazon_camera_test.xml'
-trained_model_path = 'semeval_2014.pt'
+trained_model_path = 'semeval_2014_2.pt'
 
 BATCH_SIZE = 32
 MAX_EPOCHS = 6
@@ -25,9 +26,9 @@ def loss(outputs, labels):
 
 
 class BertAnalyzer:
-    def load_saved(self):
+    def load_saved(self, path):
         self.net = TDBertNet(len(polarity_indices))
-        self.net.load_state_dict(torch.load(trained_model_path))
+        self.net.load_state_dict(torch.load(path))
         self.net.eval()
 
     def train(self, dataset):
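Note: with load_saved now taking an explicit path, multiple checkpoints (e.g. semeval_2014.pt and semeval_2014_2.pt) can coexist. A minimal sketch of the underlying PyTorch load pattern, using a generic module rather than the repo's TDBertNet and an illustrative path; the eval() call is what switches off dropout before inference:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for TDBertNet; the checkpoint path is illustrative.
net = nn.Linear(4, 3)
torch.save(net.state_dict(), 'semeval_2014.pt')

restored = nn.Linear(4, 3)  # architecture must match the saved state dict
restored.load_state_dict(torch.load('semeval_2014.pt'))
restored.eval()             # put the module in inference mode (no dropout)
```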
@@ -36,7 +37,7 @@ class BertAnalyzer:
                                collate_fn=generate_batch)
 
         self.net = TDBertNet(len(polarity_indices))
-        optimiser = optim.Adam(net.parameters(), lr=LEARNING_RATE)
+        optimiser = optim.Adam(self.net.parameters(), lr=LEARNING_RATE)
 
         start = time.time()
 
@@ -65,7 +66,7 @@ class BertAnalyzer:
         end = time.time()
         print('Training took', end - start, 'seconds')
 
-        torch.save(net.state_dict(), trained_model_path)
+        torch.save(self.net.state_dict(), trained_model_path)
 
     def evaluate(self, dataset):
         test_data = BertDataset(dataset)
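The two hunks above fix the same bug: train() referred to a bare `net` even though the module is stored as self.net, so both the optimiser construction and torch.save would have raised NameError at runtime. A condensed sketch of the corrected pattern, with a generic module and an illustrative learning rate:

```python
import torch
import torch.nn as nn
import torch.optim as optim

class Trainer:
    def train(self):
        self.net = nn.Linear(4, 3)  # stored on self; there is no local 'net'
        # optim.Adam(net.parameters(), ...) would raise NameError here
        optimiser = optim.Adam(self.net.parameters(), lr=2e-5)
        torch.save(self.net.state_dict(), 'model.pt')

Trainer().train()
```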
@@ -95,30 +96,30 @@ class BertAnalyzer:
     def analyze_sentence(self, text, char_from, char_to):
         instance = Instance(text, char_from, char_to)
         tokens, tg_from, tg_to = instance.get()
-        texts, target_indices = instance.to_tensor()
+        text, target_indices = instance.to_tensor()
 
         with torch.no_grad():
-            outputs, attentions = self.net(texts, target_indices)
-
-        target_attentions = torch.mean(attentions, 1)[0][tg_from+1:tg_to+2]
-        mean_target_att = torch.mean(target_attentions, 0)
-
-        # plot attention histogram
-        att_values = mean_target_att.numpy()[1:-1]
-        ax = plt.subplot(111)
-        width = 0.3
-        bins = [x - width/2 for x in range(1, len(att_values)+1)]
-        ax.bar(bins, att_values, width=width)
-        ax.set_xticks(list(range(1, len(att_values)+1)))
-        ax.set_xticklabels(tokens, rotation=45, rotation_mode='anchor', ha='right')
-        plt.show()
+            outputs, attentions = self.net(text, target_indices)
+
+        # attention_heads = attentions[0]
+        # num_heads = len(attention_heads)
+        # ax = plt.subplot(111)
+        # token_width = 1
+        # head_width = token_width / num_heads
+        # for i, head in enumerate(attention_heads):
+        #     # plot attention histogram
+        #     att_values = torch.mean(head[tg_from+1:tg_to+2], 0)[1:-1].numpy()
+        #
+        #     bins = [x - token_width / 2 + i * head_width for x in range(1, len(att_values) + 1)]
+        #     ax.bar(bins, att_values, width=head_width)
+        #     ax.set_xticks(list(range(1, len(att_values) + 1)))
+        #     ax.set_xticklabels(tokens, rotation=45, rotation_mode='anchor', ha='right')
+        # plt.show()
 
         _, pred = torch.max(outputs.data, 1)
         return pred
 
 
 sentiment_analyzer = BertAnalyzer()
-sentiment_analyzer.load_saved()
-sentiment = sentiment_analyzer.analyze_sentence('I will never buy another computer from HP/Compaq or do business with Circuit City again.', 39, 48)
-print('sentiment:', sentiment)
\ No newline at end of file
+sentiment_analyzer.load_saved('semeval_2014.pt')
+print(sentiment_analyzer.analyze_sentence("Well built laptop with win7.", 11, 17))
\ No newline at end of file
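The replaced block averaged the attention tensor over heads (torch.mean(attentions, 1)), took the rows for the target span (tg_from+1:tg_to+2, offset by one for the [CLS] token), averaged them, and dropped [CLS]/[SEP] with [1:-1] before plotting; the new commented-out version plots each head separately instead. A self-contained sketch of the same head-averaged extraction using a plain Hugging Face BertModel rather than the repo's TDBertNet (the target indices are illustrative):

```python
import torch
import matplotlib.pyplot as plt
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
model.eval()

enc = tokenizer("Well built laptop with win7.", return_tensors='pt')
tokens = tokenizer.convert_ids_to_tokens(enc['input_ids'][0].tolist())

with torch.no_grad():
    out = model(**enc)

# out.attentions holds one (batch, heads, seq, seq) tensor per layer;
# average the last layer over heads, as torch.mean(attentions, 1) did.
att = out.attentions[-1].mean(dim=1)[0]            # (seq, seq)

tg_from = tg_to = 3                                # 'laptop' in this tokenization
target_att = att[tg_from:tg_to + 1].mean(dim=0)    # attention paid by the target span

# Drop [CLS]/[SEP] (the [1:-1] above) and plot the histogram over tokens.
values = target_att[1:-1].numpy()
plt.bar(range(len(values)), values)
plt.xticks(range(len(values)), tokens[1:-1], rotation=45, rotation_mode='anchor', ha='right')
plt.show()
```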
@@ -51,7 +51,7 @@ class BertDataset(Dataset):
             if aspect_terms:
                 for term in aspect_terms:
                     char_from = int(term.attrib['from'])
-                    char_to = int(term.attrib['to']) - 1
+                    char_to = int(term.attrib['to'])
                     polarity = term.attrib['polarity']
                     self.data.append((Instance(text, char_from, char_to), polarity))
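The char_to change matters because SemEval-2014 aspect-term offsets use an exclusive end index: text[char_from:char_to] yields the term exactly, and the old `- 1` clipped its final character (this is also why the example call above passes 11, 17 for 'laptop'). A quick check with an illustrative XML fragment:

```python
import xml.etree.ElementTree as ET

xml = '''<sentence>
  <text>Well built laptop with win7.</text>
  <aspectTerms>
    <aspectTerm term="laptop" polarity="positive" from="11" to="17"/>
  </aspectTerms>
</sentence>'''

sentence = ET.fromstring(xml)
text = sentence.find('text').text
term = sentence.find('aspectTerms').find('aspectTerm')

char_from = int(term.attrib['from'])
char_to = int(term.attrib['to'])            # exclusive end: no '- 1' needed

assert text[char_from:char_to] == 'laptop'  # the old '- 1' would give 'lapto'
```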