Commit 226a21f0 authored by Joel Oksanen

Attention outputs for BERT

parent c7ab3edc
@@ -8,11 +8,12 @@ import time
 import numpy as np
 from sklearn import metrics
 import matplotlib.pyplot as plt
+import shap
 
 semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml'
 semeval_2014_test_path = 'data/SemEval-2014/Laptops_Test_Gold.xml'
 amazon_test_path = 'data/Amazon/amazon_camera_test.xml'
 
-trained_model_path = 'semeval_2014.pt'
+trained_model_path = 'semeval_2014_2.pt'
 BATCH_SIZE = 32
 MAX_EPOCHS = 6
@@ -25,9 +26,9 @@ def loss(outputs, labels):
 class BertAnalyzer:
 
-    def load_saved(self):
+    def load_saved(self, path):
         self.net = TDBertNet(len(polarity_indices))
-        self.net.load_state_dict(torch.load(trained_model_path))
+        self.net.load_state_dict(torch.load(path))
         self.net.eval()
 
     def train(self, dataset):
@@ -36,7 +37,7 @@ class BertAnalyzer:
                                   collate_fn=generate_batch)
 
         self.net = TDBertNet(len(polarity_indices))
-        optimiser = optim.Adam(net.parameters(), lr=LEARNING_RATE)
+        optimiser = optim.Adam(self.net.parameters(), lr=LEARNING_RATE)
 
         start = time.time()
@@ -65,7 +66,7 @@ class BertAnalyzer:
         end = time.time()
         print('Training took', end - start, 'seconds')
 
-        torch.save(net.state_dict(), trained_model_path)
+        torch.save(self.net.state_dict(), trained_model_path)
 
     def evaluate(self, dataset):
         test_data = BertDataset(dataset)
@@ -95,30 +96,30 @@ class BertAnalyzer:
     def analyze_sentence(self, text, char_from, char_to):
         instance = Instance(text, char_from, char_to)
         tokens, tg_from, tg_to = instance.get()
-        texts, target_indices = instance.to_tensor()
+        text, target_indices = instance.to_tensor()
 
         with torch.no_grad():
-            outputs, attentions = self.net(texts, target_indices)
+            outputs, attentions = self.net(text, target_indices)
 
-            target_attentions = torch.mean(attentions, 1)[0][tg_from+1:tg_to+2]
-            mean_target_att = torch.mean(target_attentions, 0)
-
-            # plot attention histogram
-            att_values = mean_target_att.numpy()[1:-1]
-
-            ax = plt.subplot(111)
-            width = 0.3
-            bins = [x - width/2 for x in range(1, len(att_values)+1)]
-            ax.bar(bins, att_values, width=width)
-            ax.set_xticks(list(range(1, len(att_values)+1)))
-            ax.set_xticklabels(tokens, rotation=45, rotation_mode='anchor', ha='right')
-            plt.show()
+            # attention_heads = attentions[0]
+            # num_heads = len(attention_heads)
+            # ax = plt.subplot(111)
+            # token_width = 1
+            # head_width = token_width / num_heads
+            # for i, head in enumerate(attention_heads):
+            #     # plot attention histogram
+            #     att_values = torch.mean(head[tg_from+1:tg_to+2], 0)[1:-1].numpy()
+            #
+            #     bins = [x - token_width / 2 + i * head_width for x in range(1, len(att_values) + 1)]
+            #     ax.bar(bins, att_values, width=head_width)
+            #     ax.set_xticks(list(range(1, len(att_values) + 1)))
+            #     ax.set_xticklabels(tokens, rotation=45, rotation_mode='anchor', ha='right')
+            # plt.show()
 
             _, pred = torch.max(outputs.data, 1)
             return pred
 
 
 sentiment_analyzer = BertAnalyzer()
-sentiment_analyzer.load_saved()
-sentiment = sentiment_analyzer.analyze_sentence('I will never buy another computer from HP/Compaq or do business with Circuit City again.', 39, 48)
-print('sentiment:', sentiment)
\ No newline at end of file
+sentiment_analyzer.load_saved('semeval_2014.pt')
+print(sentiment_analyzer.analyze_sentence("Well built laptop with win7.", 11, 17))
\ No newline at end of file
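Aside (not part of the commit): the commented-out block in analyze_sentence sketches a per-head histogram of the attention the target span pays to the rest of the sentence. Below is a minimal standalone version of that idea. It assumes head_attentions is a tensor of shape (num_heads, seq_len, seq_len) for a single sentence (matching the attentions[0] indexing in the comments), that tokens/tg_from/tg_to follow the same conventions as Instance.get(), and that positions 0 and -1 are the [CLS]/[SEP] tokens.

import torch
import matplotlib.pyplot as plt


def plot_target_attention(head_attentions, tokens, tg_from, tg_to):
    # Assumed shape: (num_heads, seq_len, seq_len) for one sentence.
    num_heads = head_attentions.shape[0]
    token_width = 1
    head_width = token_width / num_heads

    ax = plt.subplot(111)
    for i, head in enumerate(head_attentions):
        # Mean attention paid by the target span (+1 skips [CLS]) to each token,
        # dropping the [CLS]/[SEP] columns before plotting.
        att_values = torch.mean(head[tg_from + 1:tg_to + 2], 0)[1:-1].numpy()
        # One narrow bar per head, grouped around each token position.
        bins = [x - token_width / 2 + i * head_width for x in range(1, len(att_values) + 1)]
        ax.bar(bins, att_values, width=head_width)

    ax.set_xticks(list(range(1, len(att_values) + 1)))
    ax.set_xticklabels(tokens, rotation=45, rotation_mode='anchor', ha='right')
    plt.show()

Grouping one bar per head beside each token makes it easier to spot heads that concentrate on sentiment-bearing words near the target, which is what the commented-out loop in the diff appears to be after.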
@@ -51,7 +51,7 @@ class BertDataset(Dataset):
             if aspect_terms:
                 for term in aspect_terms:
                     char_from = int(term.attrib['from'])
-                    char_to = int(term.attrib['to']) - 1
+                    char_to = int(term.attrib['to'])
                     polarity = term.attrib['polarity']
                     self.data.append((Instance(text, char_from, char_to), polarity))
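Aside (not part of the commit): dropping the "- 1" above keeps char_to as an exclusive end offset, so text[char_from:char_to] recovers the annotated aspect term directly, matching Python slicing and the character arguments passed to analyze_sentence. A minimal sketch of reading those attributes with xml.etree.ElementTree is below; the element and attribute names (sentence, text, aspectTerm, term, from, to, polarity) are assumed to follow the SemEval-2014 ABSA XML format, and the path is taken from the constants in the first file of this diff.

import xml.etree.ElementTree as ET

semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml'

tree = ET.parse(semeval_2014_train_path)
for sentence in tree.getroot().iter('sentence'):
    text = sentence.find('text').text
    for term in sentence.iter('aspectTerm'):
        char_from = int(term.attrib['from'])
        char_to = int(term.attrib['to'])  # exclusive end offset, as in the new code
        # With an exclusive end offset, the slice should match the 'term' attribute.
        print(text[char_from:char_to], term.attrib['polarity'])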