diff --git a/ADA/SA/bert_analyzer.py b/ADA/SA/bert_analyzer.py
index c397669deab37a51cc14b8f869a421ba567518e2..beb28dfbd75f208c7e47714ed911c0f6db28d116 100644
--- a/ADA/SA/bert_analyzer.py
+++ b/ADA/SA/bert_analyzer.py
@@ -7,6 +7,7 @@ from bert_dataset import BertDataset, Instance, polarity_indices, generate_batch
 import time
 import numpy as np
 from sklearn import metrics
+import matplotlib.pyplot as plt
 
 semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml'
 semeval_2014_test_path = 'data/SemEval-2014/Laptops_Test_Gold.xml'
@@ -93,13 +94,25 @@ class BertAnalyzer:
 
     def analyze_sentence(self, text, char_from, char_to):
         instance = Instance(text, char_from, char_to)
-        _, tg_from, tg_to = instance.get()
+        tokens, tg_from, tg_to = instance.get()
         texts, target_indices = instance.to_tensor()
 
         with torch.no_grad():
             outputs, attentions = self.net(texts, target_indices)
 
         target_attentions = torch.mean(attentions, 1)[0][tg_from+1:tg_to+2]
+        mean_target_att = torch.mean(target_attentions, 0)
+
+        # plot attention histogram
+        att_values = mean_target_att.numpy()[1:-1]
+
+        ax = plt.subplot(111)
+        width = 0.3
+        bins = [x - width/2 for x in range(1, len(att_values)+1)]
+        ax.bar(bins, att_values, width=width)
+        ax.set_xticks(list(range(1, len(att_values)+1)))
+        ax.set_xticklabels(tokens, rotation=45, rotation_mode='anchor', ha='right')
+        plt.show()
 
         _, pred = torch.max(outputs.data, 1)
         return pred
@@ -107,4 +120,5 @@ class BertAnalyzer:
 
 sentiment_analyzer = BertAnalyzer()
 sentiment_analyzer.load_saved()
-sentiment = sentiment_analyzer.analyze_sentence('I hate this laptop', 12, 18)
+sentiment = sentiment_analyzer.analyze_sentence('I will never buy another computer from HP/Compaq or do business with Circuit City again.', 39, 48)
+print('sentiment:', sentiment)
\ No newline at end of file
diff --git a/ADA/SA/bert_dataset.py b/ADA/SA/bert_dataset.py
index df7faf73d08c796616cdaadce9444daaf556549d..d05beeb6d70cfdd8d2097bf9caba79b2cebcfecc 100644
--- a/ADA/SA/bert_dataset.py
+++ b/ADA/SA/bert_dataset.py
@@ -16,7 +16,6 @@ def generate_batch(batch):
                                    return_tensors='pt')
 
     max_tg_len = max(entry['to'] - entry['from'] for entry in batch)
-    print(max_tg_len)
     target_indices = torch.tensor([[[min(t, entry['to'])] * HIDDEN_OUTPUT_FEATURES
                                     for t in range(entry['from'], entry['from'] + max_tg_len + 1)]
                                    for entry in batch])
@@ -27,13 +26,11 @@ def generate_batch(batch):
 
 
 def token_for_char(char_idx, text, tokens):
-    compressed_idx = len(re.sub(r'\s+', '', text)[:char_idx+1]) - 1
-
+    compressed_idx = len(re.sub(r'\s+', '', text[:char_idx+1])) - 1
     token_idx = -1
     while compressed_idx >= 0:
         token_idx += 1
         compressed_idx -= len(tokens[token_idx].replace('##', ''))
-
     return token_idx
 
 
@@ -80,7 +77,7 @@ class Instance:
     def get(self):
         tokens = tokenizer.tokenize(self.text)
         idx_from = token_for_char(self.char_from, self.text, tokens)
-        idx_to = token_for_char(self.char_to, self.text, tokens)
+        idx_to = token_for_char(self.char_to-1, self.text, tokens)
        return tokens, idx_from, idx_to
 
     def to_tensor(self):
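
Note on the token_for_char changes above: slicing the raw text before stripping whitespace keeps character offsets aligned with the original string, and passing char_to-1 in Instance.get() makes the end-exclusive character span resolve to the last token inside the target rather than the token just past it. A minimal standalone sketch of this behaviour (the token list is hand-written in WordPiece style rather than real tokenizer.tokenize output, and the example sentence and span offsets are assumptions, not taken from the patch):

import re

# Standalone copy of the patched token_for_char: returns the index of the
# token containing the character at char_idx, by counting non-whitespace
# characters up to char_idx and consuming tokens ('##' continuation markers
# stripped) until that count is exhausted.
def token_for_char(char_idx, text, tokens):
    compressed_idx = len(re.sub(r'\s+', '', text[:char_idx + 1])) - 1
    token_idx = -1
    while compressed_idx >= 0:
        token_idx += 1
        compressed_idx -= len(tokens[token_idx].replace('##', ''))
    return token_idx

text = 'I hate this laptop!'
tokens = ['i', 'hate', 'this', 'lap', '##top', '!']  # hand-rolled, not actual tokenizer output
char_from, char_to = 12, 18  # span of 'laptop', end-exclusive

print(token_for_char(char_from, text, tokens))    # 3 -> 'lap'
print(token_for_char(char_to - 1, text, tokens))  # 4 -> '##top', last target token
print(token_for_char(char_to, text, tokens))      # 5 -> '!', the overshoot the fix avoids

The tg_from+1:tg_to+2 slice in analyze_sentence then shifts these token indices by one, presumably to account for the [CLS] token BERT prepends, which would also be why the plotted attention vector drops its first and last entries ([CLS] and [SEP]) before being labelled with the raw tokens.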