Joel Oksanen / individual_project

Commit 1a7440c3, authored May 23, 2020 by Joel Oksanen
Parent: 4f691588
Changes: 4 files

Implemented one way pair extractor
ADA/server/agent/target_extraction/BERT/bert_rel_extractor.py

@@ -11,7 +11,7 @@ from transformers import get_linear_schedule_with_warmup
 from agent.target_extraction.BERT.pair_rel_dataset import PairRelDataset, generate_batch, generate_production_batch
 from agent.target_extraction.BERT.pairbertnet import NUM_CLASSES, PairBertNet

-trained_model_path = 'trained_bert_rel_extractor_camera_and_backpack_with_nan.pt'
+trained_model_path = 'trained_bert_rel_extractor_camera_backpack_laptop_pair.pt'

 device = torch.device('cuda')
 loss_criterion = CrossEntropyLoss()
@@ -22,14 +22,14 @@ MAX_GRAD_NORM = 1.0
 # training
 N_EPOCHS = 3
-BATCH_SIZE = 32
+BATCH_SIZE = 16
 WARM_UP_FRAC = 0.05


 class BertRelExtractor:

     def __init__(self):
-        self.net = None
+        self.net = PairBertNet()

     @staticmethod
     def load_saved(path):
@@ -58,7 +58,6 @@ class BertRelExtractor:
                                   collate_fn=generate_batch)

         # initialise BERT
-        self.net = PairBertNet()
         self.net.cuda()

         # set up optimizer with weight decay
@@ -80,13 +79,13 @@ class BertRelExtractor:
         for batch_idx, batch in enumerate(train_loader):
             # send batch to gpu
-            input_ids, attn_mask, target_labels = tuple(i.to(device) for i in batch)
+            input_ids, attn_mask, prod_indices, feat_indices, target_labels = tuple(i.to(device) for i in batch)

             # zero param gradients
             optimiser.zero_grad()

             # forward pass
-            output_scores = self.net(input_ids, attn_mask)
+            output_scores = self.net(input_ids, attn_mask, prod_indices, feat_indices)

             # backward pass
             loss = loss_criterion(output_scores, target_labels)
@@ -140,10 +139,10 @@ class BertRelExtractor:
         with torch.no_grad():
             for batch in test_loader:
                 # send batch to gpu
-                input_ids, attn_mask, target_labels = tuple(i.to(device) for i in batch)
+                input_ids, attn_mask, prod_indices, feat_indices, target_labels = tuple(i.to(device) for i in batch)

                 # forward pass
-                output_scores = self.net(input_ids, attn_mask)
+                output_scores = self.net(input_ids, attn_mask, prod_indices, feat_indices)

                 _, output_labels = torch.max(output_scores.data, 1)
                 outputs += output_labels.tolist()
@@ -170,17 +169,18 @@ class BertRelExtractor:
     def extract_single_relation(self, text, e1, e2):
         ins = PairRelDataset.get_instance(text, e1, e2)
-        input_ids, attn_mask, instances = generate_production_batch([ins])
+        input_ids, attn_mask, prod_indices, feat_indices, instances = generate_production_batch([ins])

         self.net.cuda()
         self.net.eval()

         with torch.no_grad():
             # send batch to gpu
-            input_ids, attn_mask = tuple(i.to(device) for i in [input_ids, attn_mask])
+            input_ids, attn_mask, prod_indices, feat_indices = tuple(
+                i.to(device) for i in [input_ids, attn_mask, prod_indices, feat_indices])

             # forward pass
-            output_scores = softmax(self.net(input_ids, attn_mask), dim=1)
+            output_scores = softmax(self.net(input_ids, attn_mask, prod_indices, feat_indices), dim=1)
             _, output_labels = torch.max(output_scores.data, 1)

             print(instances[0].get_relation_for_label(output_labels[0]))
@@ -203,12 +203,15 @@ class BertRelExtractor:
         outputs = []
         with torch.no_grad():
-            for input_ids, attn_mask, instances in loader:
+            for input_ids, attn_mask, prod_indices, feat_indices, instances in loader:
                 # send batch to gpu
-                input_ids, attn_mask = tuple(i.to(device) for i in [input_ids, attn_mask])
+                input_ids, attn_mask, prod_indices, feat_indices = tuple(
+                    i.to(device) for i in [input_ids, attn_mask, prod_indices, feat_indices])

                 # forward pass
-                output_scores = softmax(self.net(input_ids, attn_mask), dim=1)
+                output_scores = softmax(self.net(input_ids, attn_mask, prod_indices, feat_indices), dim=1)
                 _, output_labels = torch.max(output_scores.data, 1)

                 outputs += map(lambda x: x[0].get_relation_for_label(x[1]), zip(instances, output_labels.tolist()))
@@ -222,5 +225,5 @@ class BertRelExtractor:
         return outputs


-extr: BertRelExtractor = BertRelExtractor.load_saved('trained_bert_rel_extractor_camera_and_backpack_with_nan.pt')
-extr.evaluate('data/annotated_camera_review_pairs.tsv', size=10000)
+extr: BertRelExtractor = BertRelExtractor.load_saved('trained_bert_rel_extractor_camera_backpack_laptop_pair.pt')
+extr.extract_relations(file_path='data/annotated_acoustic_guitar_review_pairs.tsv', size=50)
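Note: the module-level driver above now runs extraction instead of evaluation. A minimal sketch of calling the new interface directly (the sentence and entity pair are illustrative, assuming the trained checkpoint exists; not part of the commit):

    extractor = BertRelExtractor.load_saved('trained_bert_rel_extractor_camera_backpack_laptop_pair.pt')
    # prints the predicted relation between the two entities, e.g. '/has_feature'
    extractor.extract_single_relation('The camera has a sharp lens.', 'camera', 'lens')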
ADA/server/agent/target_extraction/BERT/pair_rel_dataset.py
@@ -9,10 +9,8 @@ from agent.target_extraction.BERT.pairbertnet import TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES
 MAX_SEQ_LEN = 128
 RELATIONS = ['/has_feature', '/no_relation']
 RELATION_LABEL_MAP = {None: None, '/has_feature': 1, '/no_relation': 0}
-PROD_TOKEN = '[MASK]'
-FEAT_TOKEN = '[MASK]'
+MASK_TOKEN = '[MASK]'

 tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)
-tokenizer.add_tokens([PROD_TOKEN, FEAT_TOKEN])


 def generate_batch(batch):
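Note: the two placeholder constants collapse into a single MASK_TOKEN, and the add_tokens call disappears; '[MASK]' is already a BERT special token, so the embedding table never needs resizing. A quick check, assuming TRAINED_WEIGHTS names a standard uncased BERT checkpoint:

    from transformers import BertTokenizer

    tok = BertTokenizer.from_pretrained('bert-base-uncased')  # stand-in for TRAINED_WEIGHTS
    print(tok.mask_token, tok.mask_token_id)  # '[MASK]' is a built-in special token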
@@ -23,7 +21,10 @@ def generate_batch(batch):
     attn_mask = encoded['attention_mask']

     labels = torch.tensor([instance.label for instance in batch])

-    return input_ids, attn_mask, labels
+    both_ranges = [(instance.prod_range, instance.feat_range) for instance in batch]
+    prod_indices, feat_indices = map(indices_for_entity_ranges, zip(*both_ranges))
+
+    return input_ids, attn_mask, prod_indices, feat_indices, labels


 def generate_production_batch(batch):
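The zip(*both_ranges) idiom transposes the list of (prod_range, feat_range) pairs into one tuple of product ranges and one of feature ranges before each is mapped through indices_for_entity_ranges. A toy illustration with made-up ranges:

    both_ranges = [((1, 2), (5, 6)), ((3, 3), (8, 9))]
    prod_ranges, feat_ranges = zip(*both_ranges)
    print(prod_ranges)  # ((1, 2), (3, 3))
    print(feat_ranges)  # ((5, 6), (8, 9))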
@@ -33,7 +34,18 @@ def generate_production_batch(batch):
     input_ids = encoded['input_ids']
     attn_mask = encoded['attention_mask']

-    return input_ids, attn_mask, batch
+    both_ranges = [(instance.prod_range, instance.feat_range) for instance in batch]
+    prod_indices, feat_indices = map(indices_for_entity_ranges, zip(*both_ranges))
+
+    return input_ids, attn_mask, prod_indices, feat_indices, batch
+
+
+def indices_for_entity_ranges(ranges):
+    max_e_len = max(end - start for start, end in ranges)
+    indices = torch.tensor([[[min(t, end)] * HIDDEN_OUTPUT_FEATURES
+                             for t in range(start, start + max_e_len + 1)]
+                            for start, end in ranges])
+    return indices


 class PairRelDataset(Dataset):
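The new indices_for_entity_ranges helper builds a gather index of shape (batch, max_e_len + 1, HIDDEN_OUTPUT_FEATURES): shorter entity spans are padded to the longest span in the batch by repeating their final token index, which is harmless under the max pooling applied later in the network. A worked example with a small hidden size (768 in practice for BERT base):

    import torch

    HIDDEN_OUTPUT_FEATURES = 4  # reduced from 768 for readability

    def indices_for_entity_ranges(ranges):
        max_e_len = max(end - start for start, end in ranges)
        return torch.tensor([[[min(t, end)] * HIDDEN_OUTPUT_FEATURES
                              for t in range(start, start + max_e_len + 1)]
                             for start, end in ranges])

    idx = indices_for_entity_ranges([(1, 2), (4, 6)])
    print(idx.shape)     # torch.Size([2, 3, 4]): batch x (max_e_len + 1) x hidden
    print(idx[0, :, 0])  # tensor([1, 2, 2]) -- the shorter span repeats its last index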
@@ -89,7 +101,7 @@ class PairRelDataset(Dataset):
         i = 0
         found_entities = []
-        ranges = []
+        ranges = {}
         while i < len(tokens):
             match = False
             for entity in [prod, feat]:
@@ -97,10 +109,10 @@ class PairRelDataset(Dataset):
                 if match_length is not None:
                     if entity in found_entities:
                         return None  # raise AttributeError('Entity {} appears twice in text {}'.format(entity, text))
-                    tokens[i:i+match_length] = [PROD_TOKEN if entity == prod else FEAT_TOKEN] * match_length
+                    tokens[i:i+match_length] = [MASK_TOKEN] * match_length
                     match = True
                     found_entities.append(entity)
-                    ranges.append((i, i + match_length - 1))
+                    ranges[entity] = (i + 1, i + match_length)  # + 1 taking into account the [CLS] token
                     i += match_length
                     break
             if not match:
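The ranges now record positions in the encoded sequence rather than in the raw token list, hence the + 1 shift: the tokenizer prepends [CLS] before the text, moving every token one position to the right. A quick way to see the offset (again assuming a standard uncased checkpoint as a stand-in for TRAINED_WEIGHTS):

    from transformers import BertTokenizer

    tok = BertTokenizer.from_pretrained('bert-base-uncased')
    print(tok.tokenize('the camera'))                            # ['the', 'camera']
    print(tok.convert_ids_to_tokens(tok.encode('the camera')))
    # ['[CLS]', 'the', 'camera', '[SEP]'] -- every token shifts right by one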
@@ -109,7 +121,7 @@ class PairRelDataset(Dataset):
         if len(found_entities) != 2:
             return None  # raise AttributeError('Could not find entities {} and {} in {}. Found entities {}'.format(e1, e2, text, found_entities))

-        return PairRelInstance(tokens, prod, feat, tuple(ranges), RELATION_LABEL_MAP[relation], text)
+        return PairRelInstance(tokens, prod, feat, ranges[prod], ranges[feat], RELATION_LABEL_MAP[relation], text)

     @staticmethod
     def token_entity_match(first_token_idx, entity, tokens):
@@ -146,11 +158,12 @@ class PairRelDataset(Dataset):
 class PairRelInstance:

-    def __init__(self, tokens, prod, feat, entity_ranges, label, text):
+    def __init__(self, tokens, prod, feat, prod_range, feat_range, label, text):
         self.tokens = tokens
         self.prod = prod
         self.feat = feat
-        self.entity_ranges = entity_ranges
+        self.prod_range = prod_range
+        self.feat_range = feat_range
         self.label = label
         self.text = text
ADA/server/agent/target_extraction/BERT/pairbertnet.py
@@ -13,11 +13,21 @@ class PairBertNet(nn.Module):
         super(PairBertNet, self).__init__()
         config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
         self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
-        self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)
+        self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES * 2, NUM_CLASSES)

-    def forward(self, input_ids, attn_mask):
+    def forward(self, input_ids, attn_mask, prod_indices, feat_indices):
         # BERT
-        _, pooler_output = self.bert_base(input_ids=input_ids, attention_mask=attn_mask)
+        bert_output, _ = self.bert_base(input_ids=input_ids, attention_mask=attn_mask)
+
+        # max pooling at entity locations
+        prod_outputs = torch.gather(bert_output, dim=1, index=prod_indices)
+        feat_outputs = torch.gather(bert_output, dim=1, index=feat_indices)
+        prod_pooled_output, _ = torch.max(prod_outputs, dim=1)
+        feat_pooled_output, _ = torch.max(feat_outputs, dim=1)
+
+        # concat pooled outputs from prod and feat entities
+        combined = torch.cat((prod_pooled_output, feat_pooled_output), dim=1)

         # fc layer (softmax activation done in loss function)
-        x = self.fc(pooler_output)
+        x = self.fc(combined)
         return x
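The forward pass now ignores BERT's pooler output and instead max-pools the sequence output over each entity's tokens, concatenating the two pooled vectors before the classifier (hence the doubled nn.Linear input width). A toy check of the gather-then-max idiom on made-up numbers:

    import torch

    bert_output = torch.arange(15.).reshape(1, 5, 3)  # (batch=1, seq_len=5, hidden=3)
    prod_indices = torch.tensor([[[1, 1, 1],          # entity spans tokens 1..2;
                                  [2, 2, 2]]])        # shape (batch, span_len, hidden)
    prod_outputs = torch.gather(bert_output, dim=1, index=prod_indices)
    prod_pooled, _ = torch.max(prod_outputs, dim=1)
    print(prod_pooled)  # tensor([[6., 7., 8.]]) -- elementwise max over the span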
ADA/server/agent/target_extraction/entity_annotation.py
@@ -235,11 +235,12 @@ class EntityAnnotator:
         return {'em1Text': m[0], 'em2Text': m[1], 'label': '/no_relation'}

     def pair_relations_for_text(self, text, nan_entities):
-        tokens = self.phraser[word_tokenize(text)]
+        single_tokens = word_tokenize(text)
+        all_tokens = set().union(*[single_tokens, self.phraser[single_tokens]])

         entity_mentions = []
         for n in PreOrderIter(self.root):
-            cont, mention = self.mention_in_text(text, tokens, node=n)
+            cont, mention = self.mention_in_text(all_tokens, node=n)
             if not cont:
                 # many mentions of same entity
                 return None
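Mention detection now scans the union of the raw word tokens and their phraser-merged counterparts, so an entity can match whether the phraser merged it or not. A toy illustration (the phraser output here is invented for the example):

    single_tokens = ['battery', 'life', 'is', 'great']
    phrased = ['battery_life', 'is', 'great']  # what self.phraser[single_tokens] might yield
    all_tokens = set().union(*[single_tokens, phrased])
    print(sorted(all_tokens))  # ['battery', 'battery_life', 'great', 'is', 'life']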
@@ -255,7 +256,7 @@ class EntityAnnotator:
         if len(entity_mentions) == 1:
             nan_mention = None
             for term in nan_entities:
-                cont, mention = self.mention_in_text(text, tokens, term=term)
+                cont, mention = self.mention_in_text(all_tokens, term=term)
                 if not cont:
                     # many mentions of term
                     return None
@@ -272,7 +273,7 @@ class EntityAnnotator:
     # returns True, (synonym of node / term / None) if there is exactly one or zero such occurrence,
     # otherwise False, None
-    def mention_in_text(self, text, tokens, node=None, term=None):
+    def mention_in_text(self, tokens, node=None, term=None):
         mention = None
         for syn in ({syn.lower() for syn in self.synset[node]} if node is not None else {term}):
             n_matches = sum(1 for token in tokens if syn.lower() == token.lower().replace('_', ' '))
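Since mention_in_text now receives tokens only, matching relies on the replace('_', ' ') normalisation so that a multi-word synonym can hit a phraser-merged token. For instance:

    syn = 'battery life'
    token = 'Battery_Life'
    print(syn.lower() == token.lower().replace('_', ' '))  # True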
@@ -320,4 +321,4 @@ class EntityAnnotator:
 ann: EntityAnnotator = EntityAnnotator.load_saved('acoustic_guitar_annotator.pickle')
-ann.save_annotated_pairs('BERT/data/annotated_acoustic_guitar_review_pairs.tsv')
\ No newline at end of file
+ann.save_annotated_pairs('BERT/data/annotated_acoustic_guitar_review_pairs.tsv')