Skip to content
Snippets Groups Projects
Commit c7ab3edc authored by Joel Oksanen's avatar Joel Oksanen
Browse files

Acquire and plot attention values for targets

parent 05c24955
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,7 @@ from bert_dataset import BertDataset, Instance, polarity_indices, generate_batch
import time
import numpy as np
from sklearn import metrics
import matplotlib.pyplot as plt
semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml'
semeval_2014_test_path = 'data/SemEval-2014/Laptops_Test_Gold.xml'
......@@ -93,13 +94,25 @@ class BertAnalyzer:
def analyze_sentence(self, text, char_from, char_to):
    """Classify the sentiment polarity of a target span within a sentence.

    Also plots a bar chart of the mean attention each sentence token
    receives from the target tokens.

    Args:
        text: the sentence to analyse.
        char_from: character index where the target span starts (inclusive).
        char_to: character index where the target span ends (exclusive).

    Returns:
        Tensor containing the predicted polarity-class index.
    """
    instance = Instance(text, char_from, char_to)
    tokens, tg_from, tg_to = instance.get()
    texts, target_indices = instance.to_tensor()

    with torch.no_grad():
        outputs, attentions = self.net(texts, target_indices)

    # Mean over attention heads (dim 1); the +1/+2 offsets shift the target
    # token indices — presumably to skip a leading [CLS] token. TODO confirm.
    target_attentions = torch.mean(attentions, 1)[0][tg_from+1:tg_to+2]
    mean_target_att = torch.mean(target_attentions, 0)

    # Plot attention histogram; [1:-1] drops the first/last attention
    # positions — presumably [CLS]/[SEP] — so values align with `tokens`.
    att_values = mean_target_att.numpy()[1:-1]
    ax = plt.subplot(111)
    width = 0.3
    bins = [x - width/2 for x in range(1, len(att_values)+1)]
    ax.bar(bins, att_values, width=width)
    ax.set_xticks(list(range(1, len(att_values)+1)))
    ax.set_xticklabels(tokens, rotation=45, rotation_mode='anchor', ha='right')
    plt.show()

    _, pred = torch.max(outputs.data, 1)
    return pred
......@@ -107,4 +120,5 @@ class BertAnalyzer:
# Demo: load a trained analyzer and classify the sentiment of the target
# span 'HP/Compaq' (characters 39-48) in the example sentence.
sentiment_analyzer = BertAnalyzer()
sentiment_analyzer.load_saved()
sentiment = sentiment_analyzer.analyze_sentence(
    'I will never buy another computer from HP/Compaq or do business with Circuit City again.',
    39, 48)
print('sentiment:', sentiment)
......@@ -16,7 +16,6 @@ def generate_batch(batch):
return_tensors='pt')
max_tg_len = max(entry['to'] - entry['from'] for entry in batch)
print(max_tg_len)
target_indices = torch.tensor([[[min(t, entry['to'])] * HIDDEN_OUTPUT_FEATURES
for t in range(entry['from'], entry['from'] + max_tg_len + 1)]
for entry in batch])
......@@ -27,13 +26,11 @@ def generate_batch(batch):
def token_for_char(char_idx, text, tokens):
    """Map a character index in `text` to the index of the WordPiece token
    containing that character.

    Args:
        char_idx: character position within the original (untokenized) text.
        text: the original text that was tokenized.
        tokens: WordPiece tokens of `text`; continuation pieces carry a
            leading '##' marker that does not correspond to source characters.

    Returns:
        Index into `tokens` of the token covering `char_idx`.
    """
    # Position of the character within the whitespace-stripped text: the
    # tokenizer discards whitespace, so indices must be compressed first.
    # (Note the slice is taken BEFORE stripping, so whitespace at or after
    # char_idx cannot distort the count.)
    compressed_idx = len(re.sub(r'\s+', '', text[:char_idx+1])) - 1
    # Walk the tokens, consuming each one's source-character count, until
    # the compressed position has been covered.
    token_idx = -1
    while compressed_idx >= 0:
        token_idx += 1
        compressed_idx -= len(tokens[token_idx].replace('##', ''))
    return token_idx
......@@ -80,7 +77,7 @@ class Instance:
def get(self):
    """Tokenize the instance text and locate the target span in token space.

    Returns:
        Tuple (tokens, idx_from, idx_to): the WordPiece tokens of the text
        and the inclusive token indices covering the character range
        [char_from, char_to).
    """
    tokens = tokenizer.tokenize(self.text)
    idx_from = token_for_char(self.char_from, self.text, tokens)
    # char_to is exclusive, so the last character of the target is char_to-1.
    idx_to = token_for_char(self.char_to-1, self.text, tokens)
    return tokens, idx_from, idx_to
def to_tensor(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment