Skip to content
Snippets Groups Projects
Commit c7ab3edc authored by Joel Oksanen's avatar Joel Oksanen
Browse files

Acquire and plot attention values for targets

parent 05c24955
No related branches found
No related tags found
No related merge requests found
...@@ -7,6 +7,7 @@ from bert_dataset import BertDataset, Instance, polarity_indices, generate_batch ...@@ -7,6 +7,7 @@ from bert_dataset import BertDataset, Instance, polarity_indices, generate_batch
import time import time
import numpy as np import numpy as np
from sklearn import metrics from sklearn import metrics
import matplotlib.pyplot as plt
semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml' semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml'
semeval_2014_test_path = 'data/SemEval-2014/Laptops_Test_Gold.xml' semeval_2014_test_path = 'data/SemEval-2014/Laptops_Test_Gold.xml'
...@@ -93,13 +94,25 @@ class BertAnalyzer: ...@@ -93,13 +94,25 @@ class BertAnalyzer:
def analyze_sentence(self, text, char_from, char_to):
    """Classify the sentiment of the target span text[char_from:char_to].

    Also renders a bar chart of the mean self-attention weight each input
    token receives from the target tokens (side effect: opens a matplotlib
    window via plt.show()).

    :param text: full sentence containing the target
    :param char_from: character index where the target starts (inclusive)
    :param char_to: character index where the target ends (exclusive)
    :return: predicted polarity index tensor from torch.max over the logits
    """
    instance = Instance(text, char_from, char_to)
    tokens, tg_from, tg_to = instance.get()
    texts, target_indices = instance.to_tensor()

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs, attentions = self.net(texts, target_indices)

    # Average attention over heads (dim 1), then pick the rows belonging to
    # the target tokens. The +1/+2 shift presumably skips the leading [CLS]
    # token added by the BERT tokenizer -- TODO confirm against Instance.to_tensor.
    target_attentions = torch.mean(attentions, 1)[0][tg_from + 1:tg_to + 2]
    # Mean attention the target pays to every input position.
    mean_target_att = torch.mean(target_attentions, 0)

    # Plot attention histogram; [1:-1] drops the special tokens at both ends
    # so the bars line up with the plain word-piece tokens.
    att_values = mean_target_att.numpy()[1:-1]
    ax = plt.subplot(111)
    width = 0.3
    bins = [x - width / 2 for x in range(1, len(att_values) + 1)]
    ax.bar(bins, att_values, width=width)
    ax.set_xticks(list(range(1, len(att_values) + 1)))
    ax.set_xticklabels(tokens, rotation=45, rotation_mode='anchor', ha='right')
    plt.show()

    _, pred = torch.max(outputs.data, 1)
    return pred
...@@ -107,4 +120,5 @@ class BertAnalyzer: ...@@ -107,4 +120,5 @@ class BertAnalyzer:
# Demo: load the saved model and classify the sentiment expressed toward the
# target span 'HP/Compaq' (characters 39-48) in the example sentence.
sentiment_analyzer = BertAnalyzer()
sentiment_analyzer.load_saved()
sentiment = sentiment_analyzer.analyze_sentence(
    'I will never buy another computer from HP/Compaq or do business with Circuit City again.',
    39, 48)
print('sentiment:', sentiment)
\ No newline at end of file
...@@ -16,7 +16,6 @@ def generate_batch(batch): ...@@ -16,7 +16,6 @@ def generate_batch(batch):
return_tensors='pt') return_tensors='pt')
max_tg_len = max(entry['to'] - entry['from'] for entry in batch) max_tg_len = max(entry['to'] - entry['from'] for entry in batch)
print(max_tg_len)
target_indices = torch.tensor([[[min(t, entry['to'])] * HIDDEN_OUTPUT_FEATURES target_indices = torch.tensor([[[min(t, entry['to'])] * HIDDEN_OUTPUT_FEATURES
for t in range(entry['from'], entry['from'] + max_tg_len + 1)] for t in range(entry['from'], entry['from'] + max_tg_len + 1)]
for entry in batch]) for entry in batch])
...@@ -27,13 +26,11 @@ def generate_batch(batch): ...@@ -27,13 +26,11 @@ def generate_batch(batch):
def token_for_char(char_idx, text, tokens):
    """Map a character index in the raw text to the index of the word-piece
    token that contains it.

    Works by counting non-whitespace characters: whitespace is stripped from
    the text prefix up to char_idx (tokenizers drop it), then token lengths
    (minus the '##' continuation marker BERT word-pieces carry) are consumed
    until the character position is covered.

    :param char_idx: 0-based character index into ``text``
    :param text: the original, untokenized sentence
    :param tokens: word-piece tokens of ``text`` (e.g. from BertTokenizer)
    :return: 0-based index into ``tokens`` of the covering token
    """
    # Number of non-whitespace characters up to and including char_idx,
    # minus one, i.e. a 0-based position in the whitespace-free text.
    # NOTE: the slice must be applied to text BEFORE stripping whitespace,
    # otherwise indices after any space are shifted.
    compressed_idx = len(re.sub(r'\s+', '', text[:char_idx + 1])) - 1
    token_idx = -1
    while compressed_idx >= 0:
        token_idx += 1
        # '##' is a word-piece continuation marker, not real characters.
        compressed_idx -= len(tokens[token_idx].replace('##', ''))
    return token_idx
...@@ -80,7 +77,7 @@ class Instance: ...@@ -80,7 +77,7 @@ class Instance:
def get(self):
    """Tokenize this instance's text and locate its target span.

    :return: (tokens, idx_from, idx_to) where ``tokens`` is the word-piece
             token list and ``idx_from``/``idx_to`` are INCLUSIVE token
             indices of the target span.
    """
    tokens = tokenizer.tokenize(self.text)
    idx_from = token_for_char(self.char_from, self.text, tokens)
    # char_to is an exclusive character bound, so the last character of the
    # target is at char_to-1; without the -1 the span could spill into the
    # following token.
    idx_to = token_for_char(self.char_to - 1, self.text, tokens)
    return tokens, idx_from, idx_to
def to_tensor(self): def to_tensor(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment