Commit c7ab3edc authored by Joel Oksanen

Acquire and plot attention values for targets

parent 05c24955
@@ -7,6 +7,7 @@ from bert_dataset import BertDataset, Instance, polarity_indices, generate_batch
 import time
 import numpy as np
 from sklearn import metrics
+import matplotlib.pyplot as plt

 semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml'
 semeval_2014_test_path = 'data/SemEval-2014/Laptops_Test_Gold.xml'
@@ -93,13 +94,25 @@ class BertAnalyzer:
     def analyze_sentence(self, text, char_from, char_to):
         instance = Instance(text, char_from, char_to)
-        _, tg_from, tg_to = instance.get()
+        tokens, tg_from, tg_to = instance.get()
         texts, target_indices = instance.to_tensor()

         with torch.no_grad():
             outputs, attentions = self.net(texts, target_indices)

+        target_attentions = torch.mean(attentions, 1)[0][tg_from+1:tg_to+2]
+        mean_target_att = torch.mean(target_attentions, 0)
+
+        # plot attention histogram
+        att_values = mean_target_att.numpy()[1:-1]
+        ax = plt.subplot(111)
+        width = 0.3
+        bins = [x - width/2 for x in range(1, len(att_values)+1)]
+        ax.bar(bins, att_values, width=width)
+        ax.set_xticks(list(range(1, len(att_values)+1)))
+        ax.set_xticklabels(tokens, rotation=45, rotation_mode='anchor', ha='right')
+        plt.show()
+
         _, pred = torch.max(outputs.data, 1)
         return pred
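For reference, the new block averages the attention maps over all heads, slices out the rows for the target tokens (offset by one for the leading [CLS]), averages those rows into a single per-token attention vector, and plots it with the [CLS]/[SEP] columns dropped. A minimal standalone sketch of that logic, using a random attention tensor and a made-up token list in place of real model output (the head count is an assumption):

import torch
import matplotlib.pyplot as plt

tokens = ['i', 'hate', 'this', 'laptop']   # WordPiece tokens, without [CLS]/[SEP]
seq_len = len(tokens) + 2                  # +2 for [CLS] and [SEP]
n_heads = 12                               # assumed head count (bert-base)
attentions = torch.rand(1, n_heads, seq_len, seq_len)  # stand-in for the model's attention
tg_from, tg_to = 3, 3                      # target span: the single token 'laptop'

target_attentions = torch.mean(attentions, 1)[0][tg_from+1:tg_to+2]  # mean over heads, target rows
mean_target_att = torch.mean(target_attentions, 0)                   # mean over target tokens

att_values = mean_target_att.numpy()[1:-1]  # drop the [CLS] and [SEP] columns
width = 0.3
bins = [x - width/2 for x in range(1, len(att_values)+1)]
ax = plt.subplot(111)
ax.bar(bins, att_values, width=width)
ax.set_xticks(list(range(1, len(att_values)+1)))
ax.set_xticklabels(tokens, rotation=45, rotation_mode='anchor', ha='right')
plt.show()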
@@ -107,4 +120,5 @@ class BertAnalyzer:
 sentiment_analyzer = BertAnalyzer()
 sentiment_analyzer.load_saved()
-sentiment = sentiment_analyzer.analyze_sentence('I hate this laptop', 12, 18)
+sentiment = sentiment_analyzer.analyze_sentence('I will never buy another computer from HP/Compaq or do business with Circuit City again.', 39, 48)
 print('sentiment:', sentiment)
\ No newline at end of file
@@ -16,7 +16,6 @@ def generate_batch(batch):
                            return_tensors='pt')

     max_tg_len = max(entry['to'] - entry['from'] for entry in batch)
-    print(max_tg_len)
     target_indices = torch.tensor([[[min(t, entry['to'])] * HIDDEN_OUTPUT_FEATURES
                                     for t in range(entry['from'], entry['from'] + max_tg_len + 1)]
                                    for entry in batch])
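The target_indices construction pads every target span in the batch to the longest span by clamping trailing indices with min(t, entry['to']), so shorter targets repeat their last token index. A small hand-run of that expression with dummy offsets (HIDDEN_OUTPUT_FEATURES = 768 and the presumed later torch.gather use are assumptions, not shown in this hunk):

import torch

HIDDEN_OUTPUT_FEATURES = 768    # assumed: bert-base hidden size
batch = [{'from': 2, 'to': 3},  # two-token target
         {'from': 5, 'to': 5}]  # one-token target, padded to length two

max_tg_len = max(entry['to'] - entry['from'] for entry in batch)  # 1
target_indices = torch.tensor([[[min(t, entry['to'])] * HIDDEN_OUTPUT_FEATURES
                                for t in range(entry['from'], entry['from'] + max_tg_len + 1)]
                               for entry in batch])

print(target_indices.shape)     # torch.Size([2, 2, 768])
print(target_indices[1, :, 0])  # tensor([5, 5]) -- last index repeated as padding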
@@ -27,13 +26,11 @@ def generate_batch(batch):

 def token_for_char(char_idx, text, tokens):
-    compressed_idx = len(re.sub(r'\s+', '', text)[:char_idx+1]) - 1
+    compressed_idx = len(re.sub(r'\s+', '', text[:char_idx+1])) - 1
     token_idx = -1
     while compressed_idx >= 0:
         token_idx += 1
         compressed_idx -= len(tokens[token_idx].replace('##', ''))
     return token_idx
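The fix moves the slice inside re.sub: the old version stripped whitespace from the whole text before slicing, so every space ahead of char_idx shifted the compressed index one position too far. A hand-checked comparison:

import re

text = 'I hate this laptop'
char_idx = 12  # the 'l' of 'laptop'

old = len(re.sub(r'\s+', '', text)[:char_idx+1]) - 1  # strips, then slices: 12 (wrong, drifts past 'l')
new = len(re.sub(r'\s+', '', text[:char_idx+1])) - 1  # slices, then strips: 9 ('l' in 'Ihatethislaptop')

print(old, new)  # 12 9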
@@ -80,7 +77,7 @@ class Instance:
     def get(self):
         tokens = tokenizer.tokenize(self.text)
         idx_from = token_for_char(self.char_from, self.text, tokens)
-        idx_to = token_for_char(self.char_to, self.text, tokens)
+        idx_to = token_for_char(self.char_to-1, self.text, tokens)
         return tokens, idx_from, idx_to

     def to_tensor(self):
......
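char_to is an exclusive end offset, so looking up the token at char_to itself can land on whatever follows the target; char_to-1 looks up the target's final character instead. A hand-checked example using token_for_char as fixed above (the token list is an assumed tokenizer output):

import re

def token_for_char(char_idx, text, tokens):  # as fixed above
    compressed_idx = len(re.sub(r'\s+', '', text[:char_idx+1])) - 1
    token_idx = -1
    while compressed_idx >= 0:
        token_idx += 1
        compressed_idx -= len(tokens[token_idx].replace('##', ''))
    return token_idx

text = 'I hate this laptop!'
tokens = ['i', 'hate', 'this', 'laptop', '!']  # assumed tokenizer output
char_from, char_to = 12, 18  # 'laptop' spans chars 12..17; char_to is exclusive

print(token_for_char(char_to, text, tokens))      # 4: the '!' token, one too far
print(token_for_char(char_to - 1, text, tokens))  # 3: 'laptop', as intended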