diff --git a/ADA/SA/bert_analyzer.py b/ADA/SA/bert_analyzer.py
index bc8fe61d7bf3732eb8c5d1f230098c2bf6a4b4ef..beb28dfbd75f208c7e47714ed911c0f6db28d116 100644
--- a/ADA/SA/bert_analyzer.py
+++ b/ADA/SA/bert_analyzer.py
@@ -3,13 +3,15 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
 from tdbertnet import TDBertNet
-from bert_dataset import BertDataset, polarity_indices, generate_batch
+from bert_dataset import BertDataset, Instance, polarity_indices, generate_batch
 import time
 import numpy as np
 from sklearn import metrics
+import matplotlib.pyplot as plt
 
 semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml'
 semeval_2014_test_path = 'data/SemEval-2014/Laptops_Test_Gold.xml'
+amazon_test_path = 'data/Amazon/amazon_camera_test.xml'
 trained_model_path = 'semeval_2014.pt'
 
 BATCH_SIZE = 32
@@ -45,7 +47,7 @@ class BertAnalyzer:
                 optimiser.zero_grad()
 
                 # forward pass
-                outputs = self.net(texts, target_indices)
+                outputs, _ = self.net(texts, target_indices)
 
                 # backward pass
                 l = loss(outputs, labels)
@@ -74,7 +76,7 @@ class BertAnalyzer:
         truths = []
         with torch.no_grad():
             for (texts, target_indices, labels) in test_loader:
-                outputs = self.net(texts, target_indices)
+                outputs, attentions = self.net(texts, target_indices)
                 _, pred = torch.max(outputs.data, 1)
                 predicted += pred.tolist()
                 truths += labels.tolist()
@@ -90,7 +92,33 @@ class BertAnalyzer:
         f1 = metrics.f1_score(truths, predicted, labels=range(len(polarity_indices)), average='macro')
         print('macro F1:', f1)
 
+    def analyze_sentence(self, text, char_from, char_to):
+        instance = Instance(text, char_from, char_to)
+        tokens, tg_from, tg_to = instance.get()
+        texts, target_indices = instance.to_tensor()
+
+        with torch.no_grad():
+            outputs, attentions = self.net(texts, target_indices)
+
+        target_attentions = torch.mean(attentions, 1)[0][tg_from+1:tg_to+2]
+        mean_target_att = torch.mean(target_attentions, 0)
+
+        # plot attention histogram
+        att_values = mean_target_att.numpy()[1:-1]
+
+        ax = plt.subplot(111)
+        width = 0.3
+        bins = [x - width/2 for x in range(1, len(att_values)+1)]
+        ax.bar(bins, att_values, width=width)
+        ax.set_xticks(list(range(1, len(att_values)+1)))
+        ax.set_xticklabels(tokens, rotation=45, rotation_mode='anchor', ha='right')
+        plt.show()
+
+        _, pred = torch.max(outputs.data, 1)
+        return pred
+
 
 sentiment_analyzer = BertAnalyzer()
 sentiment_analyzer.load_saved()
-sentiment_analyzer.evaluate(semeval_2014_test_path)
\ No newline at end of file
+sentiment = sentiment_analyzer.analyze_sentence('I will never buy another computer from HP/Compaq or do business with Circuit City again.', 39, 48)
+print('sentiment:', sentiment)
\ No newline at end of file
diff --git a/ADA/SA/bert_dataset.py b/ADA/SA/bert_dataset.py
index a6bde144a68503b717254820423aa18436f5620e..d05beeb6d70cfdd8d2097bf9caba79b2cebcfecc 100644
--- a/ADA/SA/bert_dataset.py
+++ b/ADA/SA/bert_dataset.py
@@ -26,13 +26,11 @@ def generate_batch(batch):
 
 
 def token_for_char(char_idx, text, tokens):
-    compressed_idx = len(re.sub(r'\s+', '', text)[:char_idx+1]) - 1
-
+    compressed_idx = len(re.sub(r'\s+', '', text[:char_idx+1])) - 1
     token_idx = -1
     while compressed_idx >= 0:
         token_idx += 1
         compressed_idx -= len(tokens[token_idx].replace('##', ''))
-
     return token_idx
 
 
@@ -55,17 +53,36 @@ class BertDataset(Dataset):
                     char_from = int(term.attrib['from'])
-                    char_to = int(term.attrib['to']) - 1
+                    char_to = int(term.attrib['to'])
                     polarity = term.attrib['polarity']
-                    self.data.append((text, char_from, char_to, polarity))
+                    self.data.append((Instance(text, char_from, char_to), polarity))
 
     def __len__(self):
         return len(self.data)
 
     def __getitem__(self, idx):
-        text, char_from, char_to, polarity_str = self.data[idx]
+        instance, polarity_str = self.data[idx]
 
-        tokens = tokenizer.tokenize(text)
-        idx_from = token_for_char(char_from, text, tokens)
-        idx_to = token_for_char(char_to, text, tokens)
+        tokens, idx_from, idx_to = instance.get()
 
         polarity = polarity_index(polarity_str)
         return {'tokens': tokens, 'from': idx_from, 'to': idx_to, 'polarity': polarity}
+
+
+class Instance:
+
+    def __init__(self, text, char_from, char_to):
+        self.text = text
+        self.char_from = char_from
+        self.char_to = char_to
+
+    def get(self):
+        tokens = tokenizer.tokenize(self.text)
+        idx_from = token_for_char(self.char_from, self.text, tokens)
+        idx_to = token_for_char(self.char_to-1, self.text, tokens)
+        return tokens, idx_from, idx_to
+
+    def to_tensor(self):
+        tokens, idx_from, idx_to = self.get()
+        text = tokenizer.encode_plus(tokens, add_special_tokens=True, max_length=MAX_SEQ_LEN,
+                                     is_pretokenized=True, return_tensors='pt')
+        target_indices = torch.tensor([[[t] * HIDDEN_OUTPUT_FEATURES for t in range(idx_from, idx_to + 1)]])
+        return text, target_indices
diff --git a/ADA/SA/tdbertnet.py b/ADA/SA/tdbertnet.py
index 24a69f12868ade1d0ac8661ea11320ff1bdf4465..d3db489907bafec239a3d5a015781a31628007cd 100644
--- a/ADA/SA/tdbertnet.py
+++ b/ADA/SA/tdbertnet.py
@@ -6,19 +6,23 @@ from transformers import *
 HIDDEN_OUTPUT_FEATURES = 768
 TRAINED_WEIGHTS = 'bert-base-uncased'
 
+
 class TDBertNet(nn.Module):
 
     def __init__(self, num_class):
         super(TDBertNet, self).__init__()
-        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS)
+        config = BertConfig.from_pretrained(TRAINED_WEIGHTS, output_attentions=True)
+        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
+        self.bert_base.config.output_attentions = True
         self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, num_class)  # n of hidden features, n of output labels
 
     def forward(self, texts, target_indices):
         # BERT
-        bert_output = self.bert_base(**texts)[0]
+        bert_output, _, attentions = self.bert_base(**texts)
         # max pooling at target locations
         target_outputs = torch.gather(bert_output, dim=1, index=target_indices)
         pooled_output = torch.max(target_outputs, dim=1)[0]
         # fc layer
         x = self.fc(pooled_output)
-        return x
+        return x, attentions[-1]
+
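
Note on the attention pooling in `analyze_sentence`: the model now returns the last layer's attentions with shape (batch, heads, seq_len, seq_len); the method averages over heads, takes the rows belonging to the target tokens (shifted by one for the [CLS] token that `encode_plus` prepends), averages those rows, and strips the [CLS]/[SEP] columns before plotting. A shape-only sketch with dummy tensors; the head count, sequence length, and span below are illustrative, not taken from the diff:

    import torch

    n_heads, seq_len = 12, 21          # assumed: 19 word pieces + [CLS] + [SEP]
    attentions = torch.rand(1, n_heads, seq_len, seq_len)  # stand-in for the last BERT layer

    tg_from, tg_to = 8, 10             # assumed target span, indices into the word pieces
    head_avg = torch.mean(attentions, 1)[0]        # (seq_len, seq_len), averaged over heads
    # row indices shift by one because encode_plus prepends [CLS]
    target_rows = head_avg[tg_from + 1:tg_to + 2]  # one attention row per target token
    mean_target_att = torch.mean(target_rows, 0)   # (seq_len,) attention mass per position
    att_values = mean_target_att.numpy()[1:-1]     # drop the [CLS] and [SEP] columns
    print(att_values.shape)                        # (19,): one bar per word piece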
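The `token_for_char` fix matters: the old code stripped whitespace from the whole text and then sliced, so every whitespace character after `char_idx` still counted toward the compressed index; slicing first keeps the count aligned with the character offset. A standalone sanity check of the alignment, reusing the corrected function body against the example sentence from `bert_analyzer.py` (the printed word pieces depend on the `bert-base-uncased` vocabulary):

    import re
    from transformers import BertTokenizer

    def token_for_char(char_idx, text, tokens):
        # count non-whitespace characters up to and including char_idx, then walk
        # the word pieces until that many characters have been consumed
        compressed_idx = len(re.sub(r'\s+', '', text[:char_idx + 1])) - 1
        token_idx = -1
        while compressed_idx >= 0:
            token_idx += 1
            compressed_idx -= len(tokens[token_idx].replace('##', ''))
        return token_idx

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    text = 'I will never buy another computer from HP/Compaq or do business with Circuit City again.'
    tokens = tokenizer.tokenize(text)
    # 'HP/Compaq' spans characters [39, 48), so its last included character is index 47
    print(tokens[token_for_char(39, text, tokens):token_for_char(47, text, tokens) + 1])
    # expected: the word pieces covering 'hp', '/', and 'compaq'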
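`Instance.to_tensor` repeats each target position `HIDDEN_OUTPUT_FEATURES` times because `torch.gather` indexes along one dimension while the index tensor must line up with the input's trailing hidden dimension; `TDBertNet.forward` then gathers those rows and max-pools element-wise over the target span. A minimal sketch with random tensors and an assumed sequence length of 128:

    import torch

    HIDDEN_OUTPUT_FEATURES = 768
    bert_output = torch.randn(1, 128, HIDDEN_OUTPUT_FEATURES)  # (batch, seq_len, hidden)
    idx_from, idx_to = 9, 12                                   # assumed target span, inclusive

    # each position index is repeated `hidden` times so that gather picks the
    # full hidden vector at every target position
    target_indices = torch.tensor([[[t] * HIDDEN_OUTPUT_FEATURES
                                    for t in range(idx_from, idx_to + 1)]])
    target_outputs = torch.gather(bert_output, dim=1, index=target_indices)  # (1, 4, 768)
    pooled_output = torch.max(target_outputs, dim=1)[0]                      # (1, 768)
    print(target_outputs.shape, pooled_output.shape)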