Commit af560174 authored by Joel Oksanen

Server can now use BERT for SA; it seems to work better than the Bayes SA

parent 56fb62b9
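
For context, a minimal sketch of how a caller might use the batch BERT analyzer introduced in this commit. The review texts and aspect-term offsets are made up; the class and method names are as in the diff below.

# Hedged usage sketch: BertAnalyzer and get_batch_sentiment_polarity come from
# this commit; the input data is hypothetical.
from agent.SA.bert_analyzer import BertAnalyzer

analyzer = BertAnalyzer.default()  # loads the saved SemEval-2014 weights

# Each item is (text, char_from, char_to), locating one aspect term in the text.
data = [('The battery life is great.', 4, 16),
        ('The screen is far too dim.', 4, 10)]
polarities = analyzer.get_batch_sentiment_polarity(data)  # signed scores, 0 for neutral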
agent/SA/bert_analyzer.py

@@ -8,10 +8,10 @@ import time
 import numpy as np
 from sklearn import metrics

-semeval_2014_train_path = 'agent/SA/data/SemEval-2014/Laptop_Train_v2.xml'
-semeval_2014_test_path = 'agent/SA/data/SemEval-2014/Laptops_Test_Gold.xml'
+semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml'
+semeval_2014_test_path = 'data/SemEval-2014/Laptops_Test_Gold.xml'
 amazon_test_path = 'agent/SA/data/Amazon/annotated_amazon_laptop_reviews.xml'
-trained_model_path = 'agent/SA/semeval_2014_2.pt'
+trained_model_path = 'semeval_2014_2.pt'
 BATCH_SIZE = 32
 MAX_EPOCHS = 6
@@ -22,12 +22,13 @@ loss_criterion = nn.CrossEntropyLoss()
 def loss(outputs, labels):
     return loss_criterion(outputs, labels)

 class BertAnalyzer:

     @staticmethod
     def default():
         sa = BertAnalyzer()
-        sa.load_saved('agent/SA/semeval_2014.pt')
+        sa.load_saved(trained_model_path)
         return sa

     def load_saved(self, path):
@@ -35,8 +36,8 @@ class BertAnalyzer:
         self.net.load_state_dict(torch.load(path))
         self.net.eval()

-    def train(self, dataset):
-        train_data = BertDataset(dataset)
+    def train(self, data_file):
+        train_data = BertDataset.from_file(data_file)
         train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
                                   collate_fn=generate_batch)
@@ -72,8 +73,8 @@ class BertAnalyzer:
         torch.save(self.net.state_dict(), trained_model_path)

-    def evaluate(self, dataset):
-        test_data = BertDataset(dataset)
+    def evaluate(self, data_file):
+        test_data = BertDataset.from_file(data_file)
         test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
                                  collate_fn=generate_batch)
@@ -97,6 +98,19 @@ class BertAnalyzer:
         f1 = metrics.f1_score(truths, predicted, labels=range(len(polarity_indices)), average='macro')
         print('macro F1:', f1)

+    def get_batch_sentiment_polarity(self, data):
+        dataset = BertDataset.from_data(data)
+        loader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=8, collate_fn=generate_batch)
+
+        predicted = []
+        with torch.no_grad():
+            for texts, target_indices, _ in loader:
+                outputs, attentions = self.net(texts, target_indices)
+                batch_val, batch_pred = torch.max(outputs.data, 1)
+                predicted += [BertAnalyzer.get_polarity(val, pred) for val, pred in zip(batch_val, batch_pred)]
+
+        return predicted
+
     def get_sentiment_polarity(self, text, char_from, char_to):
         instance = Instance(text, char_from, char_to)
         tokens, tg_from, tg_to = instance.get()
@@ -121,6 +135,10 @@ class BertAnalyzer:
         # plt.show()

         val, pred = torch.max(outputs.data, 1)
+        return BertAnalyzer.get_polarity(val, pred)
+
+    @staticmethod
+    def get_polarity(val, pred):
         if pred == 0:
             # positive
             return val
@@ -129,5 +147,4 @@ class BertAnalyzer:
             return -val
         else:
             # neutral or conflicted
             return 0
\ No newline at end of file
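
With the path constants above, a training and evaluation run might look like the following. A hedged sketch: the full train/evaluate bodies are collapsed in this diff, so construction details are assumed; only the signatures and path constants come from the code above.

# Hypothetical driver; the path constants are module-level names shown above.
sa = BertAnalyzer()
sa.train(semeval_2014_train_path)      # fine-tunes and saves to trained_model_path
sa.evaluate(semeval_2014_test_path)    # prints accuracy and macro F1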
agent/SA/bert_dataset.py

@@ -40,11 +40,14 @@ def polarity_index(polarity):

 class BertDataset(Dataset):

-    def __init__(self, xml_file):
-        tree = ET.parse(xml_file)
+    def __init__(self):
         self.data = []

+    @staticmethod
+    def from_file(file):
+        dataset = BertDataset()
+        tree = ET.parse(file)
+        dataset.data = []
         for sentence in tree.getroot():
             text = sentence.find('text').text
             aspect_terms = sentence.find('aspectTerms')
@@ -53,7 +56,14 @@ class BertDataset(Dataset):
                 char_from = int(term.attrib['from'])
                 char_to = int(term.attrib['to'])
                 polarity = term.attrib['polarity']
-                self.data.append((Instance(text, char_from, char_to), polarity))
+                dataset.data.append((Instance(text, char_from, char_to), polarity))
+        return dataset
+
+    @staticmethod
+    def from_data(data):
+        dataset = BertDataset()
+        dataset.data = [(Instance(text, char_from, char_to), 'neutral') for text, char_from, char_to in data]
+        return dataset

     def __len__(self):
         return len(self.data)
...
@@ -22,7 +22,7 @@ class TDBertNet(nn.Module):
         # max pooling at target locations
         target_outputs = torch.gather(bert_output, dim=1, index=target_indices)
         pooled_output = torch.max(target_outputs, dim=1)[0]
-        # fc layer
-        x = self.fc(pooled_output)
+        # fc layer with softmax activation
+        x = F.softmax(self.fc(pooled_output), 1)
         return x, attentions[-1]
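
The added softmax turns the fc logits into class probabilities, so the value that torch.max later extracts is a confidence in (0, 1); that is what the 0.95 sentiment threshold in Review (later in this commit) compares against. A self-contained illustration with made-up logits, assuming F is torch.nn.functional:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.1, -0.3, 0.4, -1.0]])  # hypothetical fc output for one instance
probs = F.softmax(logits, 1)                     # each row now sums to 1
val, pred = torch.max(probs.data, 1)             # val: confidence in (0, 1); pred: class index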
-from nltk.tokenize import sent_tokenize
 import re
 from agent.review_tokenizer import ReviewTokenizer
 from anytree import PostOrderIter
@@ -6,9 +5,11 @@ import pickle
 from agent.argument import *
 from functools import reduce
 from agent.SA.bert_analyzer import BertAnalyzer
+from agent.review import Review

 class Agent:
-    sentiment_threshold = 0.95
     review_tokenizer = ReviewTokenizer()
     bert_analyzer = BertAnalyzer.default()
@@ -18,15 +19,6 @@ class Agent:
         self.classifier = pickle.load(f)
         f.close()

-    # extract phrases
-    def extract_phrases(self, review_body):
-        sentences = sent_tokenize(review_body)
-        phrases = []
-        for sentence in sentences:
-            phrases += re.split(' but | although | though | otherwise | however | unless | whereas | despite |<br />',
-                                sentence)
-        return phrases
-
     # analyze sentiment
     def get_bayes_sentiment(self, phrase):
         # get classification
@@ -36,60 +28,42 @@ class Agent:
         strength = (prob_classification.prob(classification) - 0.5) * 2
         return strength if classification == '+' else -strength

-    def get_bert_sentiment(self, text, char_from, char_to):
-        return self.bert_analyzer.get_sentiment_polarity(text, char_from, char_to)
+    def get_bert_sentiments(self, data):
+        return list(self.bert_analyzer.get_batch_sentiment_polarity(data))

-    # remove all ancestors of node in list l
-    def remove_ancestors(self, node, l):
-        if node.parent != None:
-            try:
-                l.remove(node.parent)
-            except ValueError:
-                pass
-            self.remove_ancestors(node.parent, l)
-
-    # get argument(s) that match phrase
-    def get_arguments(self, phrase):
-        argument_matches = []
-        arguments = [node for node in PostOrderIter(camera)]
-        while len(arguments) > 0:
-            f = arguments.pop(0)
-            for word in glossary[f]:
-                matches = [(f, m.start(), m.end()) for m in re.finditer(word, phrase)]
-                if matches:
-                    argument_matches += matches
-                    self.remove_ancestors(f, arguments)
-                    break
-        return argument_matches
-
-    def extract_votes(self, phrases):
-        votes = {}
-        vote_phrases = {}
-        for phrase in phrases:
-            for argument, start, end in self.get_arguments(phrase):
-                sentiment = self.get_bayes_sentiment(phrase)  # self.get_bert_sentiment(phrase, start, end)
-                if abs(sentiment) > self.sentiment_threshold:
-                    if (argument not in votes) or (abs(votes[argument]) < abs(sentiment)):
-                        votes[argument] = sentiment  # what if there's two phrases with same argument?
-                        vote_phrases[argument] = {'phrase': phrase, 'sentiment': sentiment}
-        # normalize votes to 1 (+) or -1 (-)
-        for argument in votes:
-            votes[argument] = 1 if votes[argument] > 0 else -1
-        return votes, vote_phrases
-
-    # augment votes (Definition 4.3) obtained for a single critic
-    def augment_votes(self, votes):
-        arguments = [node for node in PostOrderIter(camera)]
-        for argument in arguments:
-            if argument not in votes:
-                polar_sum = 0
-                for subfeat in argument.children:
-                    if subfeat in votes:
-                        polar_sum += votes[subfeat]
-                if polar_sum != 0:
-                    votes[argument] = 1 if polar_sum > 0 else -1
-
-    def get_qbaf(self, ra, review_count):
+    def extract_votes(self, reviews):
+        labelled_phrases = [(phrase.text, arg.start, arg.end) for review in reviews for phrase in review.phrases for arg
+                            in phrase.args]
+
+        sentiments = self.get_bert_sentiments(labelled_phrases)
+
+        for review in reviews:
+            for phrase in review.phrases:
+                bayes_sentiment = self.get_bayes_sentiment(phrase.text)
+                for arg in phrase.args:
+                    sentiment = sentiments.pop(0)
+                    print(phrase.text)
+                    print('arg:', arg.start, '-', arg.end)
+                    print('bert:', sentiment)
+                    print('bayes:', bayes_sentiment)
+                    arg.set_sentiment(sentiment)
+
+    @staticmethod
+    def get_aggregates(reviews):
+        ra = []
+        vote_sum = {arg: 0 for arg in arguments}
+        vote_phrases = {arg: [] for arg in arguments}
+        for review in reviews:
+            for phrase in review.phrases:
+                for arg, sentiment in phrase.get_votes().items():
+                    vote_phrases[arg].append({'phrase': phrase.text, 'sentiment': sentiment})
+            for arg, sentiment in review.get_votes().items():
+                ra.append({'review_id': review.id, 'argument': arg, 'vote': sentiment})
+                vote_sum[arg] += sentiment
+        return ra, vote_sum, vote_phrases
+
+    @staticmethod
+    def get_qbaf(ra, review_count):
         # sums of all positive and negative votes for arguments
         argument_sums = {}
         for argument in arguments:
@@ -147,31 +121,17 @@ class Agent:
                                           supporter_strengths)
         return strengths

-    def analyze_reviews(self, reviews):
-        # get ra
-        self.ra = []
-        self.vote_sum = {argument: 0 for argument in arguments}
-        self.vote_phrases = {argument: [] for argument in arguments}
-        voting_reviews = 0
-        review_count = 0
-        for _, review in reviews.iterrows():
-            review_id = review['review_id']
-            review_count += 1
-            phrases = self.extract_phrases(review['review_body'])
-            votes, vote_phrases = self.extract_votes(phrases)
-            self.augment_votes(votes)
-            voting_reviews += 1 if len(votes) > 0 else 0
-            # add final vote tuples to ra with simplified polarity in {+ (true), - (false)}
-            for argument in votes:
-                self.ra.append({'review_id': review_id, 'argument': argument, 'vote': votes[argument]})
-                self.vote_sum[argument] += votes[argument]
-            for argument in vote_phrases:
-                self.vote_phrases[argument].append(vote_phrases[argument])
-        # only consider items that obtained votes from at least 33% of reviewers
-        if voting_reviews / review_count < 0.33:
+    def analyze_reviews(self, csv):
+        reviews = [Review(row) for _, row in csv.iterrows()]
+        # extract augmented votes
+        self.extract_votes(reviews)
+        voting_reviews = list(filter(lambda r: r.is_voting(), reviews))
+        if len(voting_reviews) / len(reviews) < 0.33:
             print('warning: only a small fraction of reviews generated votes')
+        # get aggregates
+        ra, self.vote_sum, self.vote_phrases = Agent.get_aggregates(reviews)
         # get qbaf from ra
-        self.qbaf = self.get_qbaf(self.ra, review_count)
+        self.qbaf = self.get_qbaf(ra, len(reviews))
         # apply gradual semantics
         self.strengths = self.get_strengths(self.qbaf)
         # print results
...
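
Putting the refactor together, the new pipeline could be driven roughly as follows. A hedged sketch: the Agent module path, its constructor arguments, and the input file are assumptions; the 'review_id' and 'review_body' columns follow Review.__init__ later in this commit.

import pandas as pd

from agent.agent import Agent  # module path assumed

csv = pd.read_csv('reviews.tsv', sep='\t')  # hypothetical file with 'review_id' and 'review_body' columns
agent = Agent()
agent.analyze_reviews(csv)   # extracts votes, builds the QBAF, computes strengths
print(agent.strengths)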
@@ -27,8 +27,12 @@ class Communicator:
     def __init__(self, dl):
         self.dl = dl
+        self.product_id = None

-    def set_product(self, product_id):
+    def has_loaded_product(self, product_id):
+        return self.product_id == product_id
+
+    def load_product(self, product_id):
         self.product_id = product_id
         self.arguments = {arguments[i] : Argument(i, arguments[i].name) for i in range(len(arguments))}
         self.argument_nodes = arguments
...
agent/review.py (new file)

+import re
+from nltk.tokenize import sent_tokenize
+from agent.SA.bert_dataset import MAX_SEQ_LEN
+from anytree import PostOrderIter
+from agent.argument import *
+
+
+class Review:
+    SENTIMENT_THRESHOLD = 0.95
+    PHRASE_MAX_WORDS = MAX_SEQ_LEN * 0.3
+
+    def __init__(self, data):
+        self.id = data['review_id']
+        self.body = data['review_body']
+        self.phrases = Review.extract_phrases(self.body)
+        self.votes = {}
+
+    # extract phrases
+    @staticmethod
+    def extract_phrases(review_body):
+        sentences = sent_tokenize(review_body)
+        texts = []
+        for sentence in sentences:
+            texts += re.split(' but | although | though | otherwise | however | unless | whereas | despite |<br />',
+                              sentence)
+        texts = filter(lambda t: len(t.split()) < Review.PHRASE_MAX_WORDS, texts)
+        phrases = [Phrase(text) for text in texts]
+        return phrases
+
+    def get_votes(self):
+        for arg, sentiment in [(arg, sentiment) for phrase in self.phrases for arg, sentiment in phrase.votes.items()]:
+            if arg not in self.votes or abs(sentiment) > abs(self.votes[arg]):
+                self.votes[arg] = sentiment
+        # normalize
+        for arg in self.votes:
+            self.votes[arg] = 1 if self.votes[arg] > 0 else -1
+        self.augment_votes()
+        return self.votes
+
+    # augment votes (Definition 4.3) obtained for a single critic
+    def augment_votes(self):
+        arguments = [node for node in PostOrderIter(camera)]
+        for argument in arguments:
+            if argument not in self.votes:
+                polar_sum = 0
+                for subfeat in argument.children:
+                    if subfeat in self.votes:
+                        polar_sum += self.votes[subfeat]
+                if polar_sum != 0:
+                    self.votes[argument] = 1 if polar_sum > 0 else -1
+
+    def is_voting(self):
+        return len(self.votes) > 0
+
+
+class Phrase:
+
+    def __init__(self, text):
+        self.text = text
+        self.args = self.get_args(text)
+        self.votes = {}
+
+    # get argument(s) that match phrase
+    def get_args(self, phrase):
+        argument_matches = []
+        arguments = [node for node in PostOrderIter(camera)]
+        while len(arguments) > 0:
+            f = arguments.pop(0)
+            for word in glossary[f]:
+                matches = [Arg(f, m.start(), m.end()) for m in re.finditer(word, phrase)]
+                if matches:
+                    argument_matches += matches
+                    self.remove_ancestors(f, arguments)
+                    break
+        return argument_matches
+
+    # remove all ancestors of node in list l
+    def remove_ancestors(self, node, l):
+        if node.parent != None:
+            try:
+                l.remove(node.parent)
+            except ValueError:
+                pass
+            self.remove_ancestors(node.parent, l)
+
+    def add_arg(self, arg):
+        self.args.append(arg)
+
+    def num_args(self):
+        return len(self.args)
+
+    def get_votes(self):
+        for arg in self.args:
+            if (abs(arg.sentiment) > Review.SENTIMENT_THRESHOLD and
+                    (arg.node not in self.votes or abs(arg.sentiment) > abs(self.votes[arg.node]))):
+                self.votes[arg.node] = arg.sentiment
+        return self.votes
+
+
+class Arg:
+
+    def __init__(self, node, start, end):
+        self.node = node
+        self.start = start
+        self.end = end
+        self.sentiment = None
+
+    def set_sentiment(self, sentiment):
+        self.sentiment = sentiment
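
To make the augmentation rule (Definition 4.3) concrete: a parent argument with no direct vote inherits the sign of the sum of its children's votes, and stays unvoted on a tie. A minimal re-creation with plain dicts; the feature names and tree are hypothetical:

votes = {'lens': 1, 'zoom': -1}                  # hypothetical leaf votes in {+1, -1}
children = {'image quality': ['lens', 'zoom']}   # hypothetical feature tree
for parent, subfeats in children.items():
    if parent not in votes:
        polar_sum = sum(votes.get(s, 0) for s in subfeats)
        if polar_sum != 0:
            votes[parent] = 1 if polar_sum > 0 else -1
# lens (+1) and zoom (-1) cancel out, so 'image quality' receives no vote here.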
@@ -20,7 +20,10 @@ def product(request):
     star_rating = dl.get_avg_star_rating(id)
     image_url = 'https://ws-na.amazon-adsystem.com/widgets/q?_encoding=UTF8&MarketPlace=US&ASIN=' + id + '&ServiceVersion=20070822&ID=AsinImage&WS=1&Format=SL250'

-    communicator.set_product(id)
+    if not communicator.has_loaded_product(id):
+        communicator.load_product(id)
+        return HttpResponse("OK")

     init_message = communicator.get_init_message()

 class Empty:
...