Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Joel Oksanen
individual_project
Commits
226a21f0
Commit
226a21f0
authored
Apr 15, 2020
by
Joel Oksanen
Browse files
Attention outputs for BERT
parent
c7ab3edc
Changes
2
Hide whitespace changes
Inline
Side-by-side
ADA/SA/bert_analyzer.py
View file @
226a21f0
...
...
import numpy as np
from sklearn import metrics
import matplotlib.pyplot as plt
import shap

# Dataset locations (relative to the SA working directory).
semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml'
semeval_2014_test_path = 'data/SemEval-2014/Laptops_Test_Gold.xml'
amazon_test_path = 'data/Amazon/amazon_camera_test.xml'

# Checkpoint written by train(); presumably a state_dict file — confirm against load_saved().
trained_model_path = 'semeval_2014_2.pt'

# Training hyperparameters.
BATCH_SIZE = 32
MAX_EPOCHS = 6
...
...
@@ -25,9 +26,9 @@ def loss(outputs, labels):
class
BertAnalyzer
:
def
load_saved
(
self
):
def
load_saved
(
self
,
path
):
self
.
net
=
TDBertNet
(
len
(
polarity_indices
))
self
.
net
.
load_state_dict
(
torch
.
load
(
trained_model_
path
))
self
.
net
.
load_state_dict
(
torch
.
load
(
path
))
self
.
net
.
eval
()
def
train
(
self
,
dataset
):
...
...
@@ -36,7 +37,7 @@ class BertAnalyzer:
collate_fn
=
generate_batch
)
self
.
net
=
TDBertNet
(
len
(
polarity_indices
))
optimiser
=
optim
.
Adam
(
net
.
parameters
(),
lr
=
LEARNING_RATE
)
optimiser
=
optim
.
Adam
(
self
.
net
.
parameters
(),
lr
=
LEARNING_RATE
)
start
=
time
.
time
()
...
...
@@ -65,7 +66,7 @@ class BertAnalyzer:
end
=
time
.
time
()
print
(
'Training took'
,
end
-
start
,
'seconds'
)
torch
.
save
(
net
.
state_dict
(),
trained_model_path
)
torch
.
save
(
self
.
net
.
state_dict
(),
trained_model_path
)
def
evaluate
(
self
,
dataset
):
test_data
=
BertDataset
(
dataset
)
...
...
@@ -95,30 +96,30 @@ class BertAnalyzer:
def
analyze_sentence
(
self
,
text
,
char_from
,
char_to
):
instance
=
Instance
(
text
,
char_from
,
char_to
)
tokens
,
tg_from
,
tg_to
=
instance
.
get
()
text
s
,
target_indices
=
instance
.
to_tensor
()
text
,
target_indices
=
instance
.
to_tensor
()
with
torch
.
no_grad
():
outputs
,
attentions
=
self
.
net
(
texts
,
target_indices
)
target_attentions
=
torch
.
mean
(
attentions
,
1
)[
0
][
tg_from
+
1
:
tg_to
+
2
]
mean_target_att
=
torch
.
mean
(
target_attentions
,
0
)
# plot attention histogram
att_values
=
mean_target_att
.
numpy
()[
1
:
-
1
]
ax
=
plt
.
subplot
(
111
)
width
=
0.3
bins
=
[
x
-
width
/
2
for
x
in
range
(
1
,
len
(
att_values
)
+
1
)]
ax
.
bar
(
bins
,
att_values
,
width
=
width
)
ax
.
set_xticks
(
list
(
range
(
1
,
len
(
att_values
)
+
1
)))
ax
.
set_xticklabels
(
tokens
,
rotation
=
45
,
rotation_mode
=
'anchor'
,
ha
=
'right'
)
plt
.
show
()
outputs
,
attentions
=
self
.
net
(
text
,
target_indices
)
# attention_heads = attentions[0]
# num_heads = len(attention_heads)
# ax = plt.subplot(111)
# token_width = 1
# head_width = token_width / num_heads
# for i, head in enumerate(attention_heads):
# # plot attention histogram
# att_values = torch.mean(head[tg_from+1:tg_to+2], 0)[1:-1].numpy()
#
# bins = [x - token_width / 2 + i * head_width for x in range(1, len(att_values) + 1)]
# ax.bar(bins, att_values, width=head_width)
# ax.set_xticks(list(range(1, len(att_values) + 1)))
# ax.set_xticklabels(tokens, rotation=45, rotation_mode='anchor', ha='right')
# plt.show()
_
,
pred
=
torch
.
max
(
outputs
.
data
,
1
)
return
pred
# Demo entry point: guard prevents the model load + inference from running as
# a side effect when this module is imported (e.g. for the BertAnalyzer class).
if __name__ == '__main__':
    sentiment_analyzer = BertAnalyzer()
    sentiment_analyzer.load_saved('semeval_2014.pt')
    # Aspect term "laptop" spans characters 11-17 of the sample review.
    print(sentiment_analyzer.analyze_sentence("Well built laptop with win7.", 11, 17))
ADA/SA/bert_dataset.py
View file @
226a21f0
...
...
@@ -51,7 +51,7 @@ class BertDataset(Dataset):
if
aspect_terms
:
for
term
in
aspect_terms
:
char_from
=
int
(
term
.
attrib
[
'from'
])
char_to
=
int
(
term
.
attrib
[
'to'
])
-
1
char_to
=
int
(
term
.
attrib
[
'to'
])
polarity
=
term
.
attrib
[
'polarity'
]
self
.
data
.
append
((
Instance
(
text
,
char_from
,
char_to
),
polarity
))
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment