Commit c7ab3edc authored by Joel Oksanen's avatar Joel Oksanen
Browse files

Acquire and plot attention values for targets

parent 05c24955
...@@ -7,6 +7,7 @@ from bert_dataset import BertDataset, Instance, polarity_indices, generate_batch ...@@ -7,6 +7,7 @@ from bert_dataset import BertDataset, Instance, polarity_indices, generate_batch
import time import time
import numpy as np import numpy as np
from sklearn import metrics from sklearn import metrics
import matplotlib.pyplot as plt
semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml' semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml'
semeval_2014_test_path = 'data/SemEval-2014/Laptops_Test_Gold.xml' semeval_2014_test_path = 'data/SemEval-2014/Laptops_Test_Gold.xml'
...@@ -93,13 +94,25 @@ class BertAnalyzer: ...@@ -93,13 +94,25 @@ class BertAnalyzer:
def analyze_sentence(self, text, char_from, char_to):
    """Predict the sentiment polarity of the target span [char_from, char_to)
    in *text*, and show a bar chart of the target's mean attention per token.

    Returns the predicted polarity index (a tensor from torch.max).
    """
    instance = Instance(text, char_from, char_to)
    tokens, tg_from, tg_to = instance.get()
    texts, target_indices = instance.to_tensor()

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs, attentions = self.net(texts, target_indices)

    # Average over attention heads, then over the target's token positions.
    # The +1 offsets skip the leading [CLS] position in the sequence axis
    # (assumes attentions is (batch, heads, seq, seq) — TODO confirm).
    head_mean = torch.mean(attentions, 1)[0]
    target_rows = head_mean[tg_from + 1:tg_to + 2]
    mean_target_att = torch.mean(target_rows, 0)

    # Plot attention histogram; drop the [CLS]/[SEP] endpoints so the
    # remaining values line up one-to-one with `tokens`.
    att_values = mean_target_att.numpy()[1:-1]
    positions = list(range(1, len(att_values) + 1))
    width = 0.3
    axes = plt.subplot(111)
    axes.bar([p - width / 2 for p in positions], att_values, width=width)
    axes.set_xticks(positions)
    axes.set_xticklabels(tokens, rotation=45, rotation_mode='anchor', ha='right')
    plt.show()

    # Highest-scoring polarity class.
    _, pred = torch.max(outputs.data, 1)
    return pred
...@@ -107,4 +120,5 @@ class BertAnalyzer: ...@@ -107,4 +120,5 @@ class BertAnalyzer:
# Demo: load the saved model and classify one hand-picked target span.
sentiment_analyzer = BertAnalyzer()
sentiment_analyzer.load_saved()
sentiment = sentiment_analyzer.analyze_sentence(
    'I will never buy another computer from HP/Compaq or do business with Circuit City again.',
    39,
    48,
)
print('sentiment:', sentiment)
\ No newline at end of file
...@@ -16,7 +16,6 @@ def generate_batch(batch): ...@@ -16,7 +16,6 @@ def generate_batch(batch):
return_tensors='pt') return_tensors='pt')
max_tg_len = max(entry['to'] - entry['from'] for entry in batch) max_tg_len = max(entry['to'] - entry['from'] for entry in batch)
print(max_tg_len)
target_indices = torch.tensor([[[min(t, entry['to'])] * HIDDEN_OUTPUT_FEATURES target_indices = torch.tensor([[[min(t, entry['to'])] * HIDDEN_OUTPUT_FEATURES
for t in range(entry['from'], entry['from'] + max_tg_len + 1)] for t in range(entry['from'], entry['from'] + max_tg_len + 1)]
for entry in batch]) for entry in batch])
...@@ -27,13 +26,11 @@ def generate_batch(batch): ...@@ -27,13 +26,11 @@ def generate_batch(batch):
def token_for_char(char_idx, text, tokens):
    """Map a character index in *text* to the index of the WordPiece token
    that contains it.

    Works by counting non-whitespace characters up to and including
    char_idx, then walking the token list (ignoring '##' continuation
    markers) until that many characters have been consumed.
    """
    # Non-whitespace character position of char_idx (0-based).
    remaining = len(re.sub(r'\s+', '', text[:char_idx + 1])) - 1
    idx = -1
    while remaining >= 0:
        idx += 1
        # '##' marks a subword continuation; it contributes no characters.
        remaining -= len(tokens[idx].replace('##', ''))
    return idx
...@@ -80,7 +77,7 @@ class Instance: ...@@ -80,7 +77,7 @@ class Instance:
def get(self):
    """Tokenize the instance text and locate the target's token span.

    Returns (tokens, idx_from, idx_to) where the token span is inclusive
    at both ends. char_to is treated as exclusive, so the last covered
    character (char_to - 1) is mapped to find idx_to.
    """
    toks = tokenizer.tokenize(self.text)
    first = token_for_char(self.char_from, self.text, toks)
    last = token_for_char(self.char_to - 1, self.text, toks)
    return toks, first, last
def to_tensor(self): def to_tensor(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment