Joel Oksanen / individual_project / Commits

Commit e183155c, authored May 18, 2020 by Joel Oksanen
Improved cooperation between entity_annotation and bert_tag_extractor
parent 29ed5986

Showing 3 changed files with 112 additions and 46 deletions:
ADA/server/agent/target_extraction/BERT/bert_tag_extractor.py   +21 -7
ADA/server/agent/target_extraction/BERT/tagged_rel_dataset.py   +13 -5
ADA/server/agent/target_extraction/entity_annotation.py         +78 -34
ADA/server/agent/target_extraction/BERT/bert_tag_extractor.py

@@ -9,7 +9,7 @@ from tagged_rel_dataset import TRAINED_WEIGHTS, MAX_SEQ_LEN, RELATIONS, IGNORE_T
 train_data_path = 'data/train.json'
 test_data_path = 'data/test.json'
-trained_model_path = 'trained_bert_tag_extractor_2.pt'
+trained_model_path = 'trained_bert_tag_extractor_camera.pt'
 device = torch.device('cuda')
 # optimizer
@@ -60,9 +60,15 @@ class BertTagExtractor:
         extractor.train_with_file(file_path, size=size)
         return extractor
 
-    def train_with_file(self, file_path, size=None):
+    @staticmethod
+    def train_and_validate(file_path, valid_frac, size=None):
+        extractor = BertTagExtractor()
+        extractor.train_with_file(file_path, size=size, valid_frac=valid_frac)
+        return extractor
+
+    def train_with_file(self, file_path, size=None, valid_frac=None):
         # load training data
-        train_data = TaggedRelDataset.from_file(file_path, size=size)
+        train_data, valid_data = TaggedRelDataset.from_file(file_path, size=size, valid_frac=valid_frac)
         train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, collate_fn=generate_train_batch)
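With this change, TaggedRelDataset.from_file always returns a (train, valid) pair, with the second element None when no valid_frac is given, so every caller unpacks two values. A minimal sketch of that convention (the loader below is a stand-in for illustration, not the project's class):

    def load(path, valid_frac=None):
        data = list(range(10))          # stand-in for the real dataset
        if valid_frac is None:
            return data, None           # callers still unpack a pair
        split = int(len(data) * (1 - valid_frac))
        return data[:split], data[split:]

    train, valid = load('reviews.tsv')                  # valid is None
    train, valid = load('reviews.tsv', valid_frac=0.2)  # 8/2 split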
@@ -131,9 +137,18 @@ class BertTagExtractor:
         torch.save(self.net.state_dict(), trained_model_path)
 
-    def evaluate(self, file_path):
+        if valid_data is not None:
+            self.evaluate(data=valid_data)
+
+    def evaluate(self, file_path=None, data=None):
         # load training data
-        test_data = TaggedRelDataset.from_file(file_path)
+        if file_path is not None:
+            test_data = TaggedRelDataset.from_file(file_path)
+        else:
+            if data is None:
+                raise AttributeError('file_path and data cannot both be None')
+            test_data = data
         test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, collate_fn=generate_eval_batch)
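evaluate() now accepts either a file path or an in-memory dataset, which is what lets train_with_file score the validation split it just created without a round-trip through disk. The guard-clause pattern in isolation (the loader and the return value are hypothetical stubs; note that ValueError, not AttributeError, is the conventional exception for an invalid argument combination):

    def load_from_file(path):           # hypothetical stand-in for TaggedRelDataset.from_file
        return [path]

    def evaluate(file_path=None, data=None):
        if file_path is not None:
            data = load_from_file(file_path)
        elif data is None:
            raise ValueError('file_path and data cannot both be None')
        return len(data)                # stand-in for the real evaluation loop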
@@ -180,8 +195,7 @@ class BertTagExtractor:
 # print('macro F1:', f1)
 
-extr = BertTagExtractor.new_trained_with_file(train_data_path)
-extr.evaluate(test_data_path)
+BertTagExtractor.train_and_validate('data/annotated_camera_reviews.tsv', 0.05, size=200000)
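For the record, the arithmetic behind this call: with size=200000 and valid_frac=0.05, the split index computed in tagged_rel_dataset.py is int(200000 * (1 - 0.05)) = 190000, so the model trains on 190,000 sentences and validates on the remaining 10,000.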
ADA/server/agent/target_extraction/BERT/tagged_rel_dataset.py

@@ -3,9 +3,11 @@ from torch.utils.data import Dataset
 from transformers import BertTokenizer
 import pandas as pd
 from collections import defaultdict
+import numpy as np
+from ast import literal_eval
 
 TRAINED_WEIGHTS = 'bert-base-cased'  # cased works better for NER
-RELATIONS = ['/location/location/contains']
+RELATIONS = ['/has_feature']
 N_TAGS = 4 * len(RELATIONS) * 2 + 1
 MAX_SEQ_LEN = 128
 MAX_TOKENS = MAX_SEQ_LEN - 2
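Changing RELATIONS also changes the tag-space size through the formula already in the file: N_TAGS = 4 * len(RELATIONS) * 2 + 1 evaluates to 4 * 1 * 2 + 1 = 9 for the single /has_feature relation. The factors are not explained in this excerpt; a plausible reading is four positional tags per entity span, times two entity roles per relation (em1Text and em2Text), plus one tag for tokens outside any mention.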
@@ -43,7 +45,12 @@ class TaggedRelDataset(Dataset):
     @staticmethod
     def from_file(path, valid_frac=None, size=None):
         dataset = TaggedRelDataset()
-        dataset.df = pd.read_json(path, lines=True)
+        if path.endswith('.json'):
+            dataset.df = pd.read_json(path, lines=True)
+        elif path.endswith('.tsv'):
+            dataset.df = pd.read_csv(path, sep='\t', error_bad_lines=False)
+        else:
+            raise AttributeError('Could not recognize file type')
         # sample data if a size is specified
         if size is not None and size < len(dataset):
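The extension dispatch lets the TSV written by entity_annotation.py feed the same dataset class as the original JSON corpus. One portability note: error_bad_lines=False was the read_csv option of the day for skipping malformed rows; it was deprecated in pandas 1.3 and removed in 2.0, where the equivalent is on_bad_lines='skip'. A sketch of the modern form, reusing the path from this commit:

    import pandas as pd

    path = 'BERT/data/annotated_camera_reviews.tsv'
    df = pd.read_csv(path, sep='\t', on_bad_lines='skip')  # pandas >= 1.3 equivalent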
@@ -51,17 +58,18 @@ class TaggedRelDataset(Dataset):
         if valid_frac is None:
             print('Obtained dataset of size', len(dataset))
-            return dataset
+            return dataset, None
         else:
             validset = TaggedRelDataset()
             split_idx = int(len(dataset) * (1 - valid_frac))
-            dataset.df, validset.df = dataset.df[:split_idx, :], dataset.df[split_idx:, :]
+            dataset.df, validset.df = np.split(dataset.df, [split_idx], axis=0)
             print('Obtained train set of size', len(dataset), 'and validation set of size', len(validset))
             return dataset, validset
 
     def instance_from_row(self, row):
         text = row['sentText']
         tokens = tokenizer.tokenize(text)[:MAX_TOKENS]
-        tag_map = self.map_for_relation_mentions(row['relationMentions'])
+        tag_map = self.map_for_relation_mentions(literal_eval(row['relationMentions']))
         sorted_entities = sorted(tag_map.keys(), key=len, reverse=True)
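Two quiet fixes ride along in this hunk. First, the replaced slice dataset.df[:split_idx, :] is NumPy-style two-axis indexing, which is not valid on a pandas DataFrame; np.split(df, [i], axis=0) performs the intended row split (plain df.iloc[:i] / df.iloc[i:] is equivalent and avoids a deprecation warning on recent pandas). Second, once relationMentions has round-tripped through TSV it is the string repr of a list of dicts, so ast.literal_eval is needed to parse it back; unlike eval, it accepts only Python literals. A self-contained check of both:

    import numpy as np
    import pandas as pd
    from ast import literal_eval

    df = pd.DataFrame({'sentText': ['s0', 's1', 's2', 's3']})
    train, valid = np.split(df, [3], axis=0)   # rows 0-2 and row 3
    assert len(train) == 3 and len(valid) == 1

    cell = "[{'em1Text': 'camera', 'em2Text': 'lens', 'label': '/has_feature'}]"
    mentions = literal_eval(cell)              # safely parsed back to a list of dicts
    assert mentions[0]['em2Text'] == 'lens'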
ADA/server/agent/target_extraction/entity_annotation.py

 import pandas as pd
-from xml.etree.ElementTree import ElementTree, parse, tostring, Element, SubElement
 from gensim.models.phrases import Phrases, Phraser
 from nltk import pos_tag
 from nltk.tokenize import word_tokenize, sent_tokenize
@@ -26,16 +25,21 @@ class EntityAnnotator:
         self.counter = counter
         self.save_path = save_path
         self.root = None
+        self.synset = {}
         self.n_annotated = 0
 
     @staticmethod
     def new_from_tsv(file_path, name):
         df = pd.read_csv(file_path, sep='\t', error_bad_lines=False)
         print('tokenizing texts...')
-        texts = [text.replace('_', ' ') for _, par in df['reviewText'].items() if not pd.isnull(par)
+        texts = [text.replace('_', ' ') for _, par in df.sample(frac=1)['reviewText'].items() if not pd.isnull(par)
                  for text in sent_tokenize(par)]
         print('obtaining counter...')
         counter = EntityAnnotator.count_nouns(texts)
         print('finished initialising annotator')
         ann = EntityAnnotator(file_path, counter, name + '.pickle')
         ann.save()
         return ann
 
     @staticmethod
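df.sample(frac=1) is the idiomatic pandas row shuffle: sample 100% of the rows without replacement, in random order. Here it decouples annotation order from the order reviews happen to appear in the TSV. A minimal illustration (random_state added only for reproducibility):

    import pandas as pd

    df = pd.DataFrame({'reviewText': ['first review.', 'second review.', 'third review.']})
    shuffled = df.sample(frac=1, random_state=0)   # all rows, shuffled order
    print(shuffled['reviewText'].tolist())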
@@ -51,15 +55,18 @@ class EntityAnnotator:
         f.close()
 
     @staticmethod
-    def count_nouns(texts):
+    def count_nouns(raw_texts):
+        texts = [word_tokenize(text) for text in raw_texts]
         print('  obtaining phraser...')
         # obtain phraser
         bigram = Phrases(texts, threshold=PHRASE_THRESHOLD)
         trigram = Phrases(bigram[texts], threshold=PHRASE_THRESHOLD)
         phraser = Phraser(trigram)
         print('  counting nouns...')
         # count nouns
         nouns = []
-        for text in texts:
+        for idx, text in enumerate(texts):
             pos_tags = pos_tag(text)
             ngrams = phraser[text]
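count_nouns now accepts raw sentence strings and tokenizes them itself, matching the untokenized sentences new_from_tsv produces, where previously the caller had to pass pre-tokenized text. The bigram-then-trigram Phrases chain then merges frequent collocations into single tokens before POS tagging. A toy version of the pipeline (min_count and threshold lowered only so phrases can form on a three-sentence corpus; word_tokenize needs NLTK's punkt data downloaded once):

    from gensim.models.phrases import Phrases, Phraser
    from nltk.tokenize import word_tokenize   # requires nltk.download('punkt')

    raw_texts = ['the image quality is great',
                 'great image quality for the price',
                 'image quality matters']
    texts = [word_tokenize(t) for t in raw_texts]
    bigram = Phrases(texts, min_count=1, threshold=1)
    trigram = Phrases(bigram[texts], min_count=1, threshold=1)
    phraser = Phraser(trigram)   # frozen phrase model, cheaper to apply
    print(phraser[texts[0]])     # e.g. ['the', 'image_quality', 'is', 'great']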
@@ -79,6 +86,8 @@ class EntityAnnotator:
                 if len(token) > 1 and is_noun and is_valid:
                     nouns.append(token)
                 word_idx += 1
+            if idx % 1000 == 0:
+                print('  {:0.2f} done'.format((idx + 1) / len(texts)))
         return Counter(nouns)
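One wrinkle in the new progress message: (idx + 1) / len(texts) is a fraction, so the output reads e.g. '0.25 done' rather than a percentage; the '{:.0%}' format spec would render it as one.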
@@ -100,43 +109,82 @@ class EntityAnnotator:
         os.system('clear')
-        print(fg.li_blue + '{} entities annotated'.format(self.n_annotated) + fg.rs)
+        print(fg.li_green + '{} entities annotated'.format(self.n_annotated) + fg.rs)
         print('')
         print(fg.li_black + 'root: \'r\'' + fg.rs)
-        print(fg.li_black + 'subfeat: [number of parent node][ENTER]' + fg.rs)
-        print(fg.li_black + 'skip: \'s\'' + fg.rs)
+        print(fg.li_black + 'subfeat: [\'f\'][number of parent node][ENTER]' + fg.rs)
+        print(fg.li_black + 'synonym: [\'s\'][number of syn node][ENTER]' + fg.rs)
         print(fg.li_black + 'nan: \'n\'' + fg.rs)
+        print(fg.li_black + 'remove: \'x\'' + fg.rs)
         print(fg.li_black + 'quit: \'q\'' + fg.rs)
+        print(fg.li_black + 'abort: \'a\'' + fg.rs)
         print('')
         if self.root is not None:
-            print(RenderTree(self.root))
+            print(fg.li_blue + str(RenderTree(self.root)) + fg.rs)
             print('')
         print(entity)
         print('')
         task = readchar.readkey()
         if task == 'r':
+            node = Node(entity)
+            self.synset[node] = [node.name]
             old_root = self.root
-            self.root = Node(entity)
-            old_root.parent = self.root
+            self.root = node
+            if old_root is not None:
+                old_root.parent = self.root
             self.update_tree_indices()
             self.n_annotated += 1
-        if task.isdigit():
-            n = int(task)
+        if task == 'f':
+            n = None
             while True:
                 subtask = readchar.readkey()
                 if subtask.isdigit():
-                    n = n * 10 + int(subtask)
-                if subtask == readchar.key.ENTER:
-                    Node(entity, parent=self.node_with_number(n))
+                    n = n * 10 + int(subtask) if n is not None else int(subtask)
+                if subtask == readchar.key.ENTER and n is not None:
+                    node = Node(entity, parent=self.node_with_number(n))
+                    self.synset[node] = [node.name]
                     self.update_tree_indices()
                     self.n_annotated += 1
                     break
+                if subtask == 'a':
+                    break
+        if task == 's':
+            n = None
+            while True:
+                subtask = readchar.readkey()
+                if subtask.isdigit():
+                    n = n * 10 + int(subtask) if n is not None else int(subtask)
+                if subtask == readchar.key.ENTER and n is not None:
+                    self.synset[self.node_with_number(n)].append(entity)
+                    self.n_annotated += 1
+                    break
+                if subtask == 'a':
+                    break
+        if task == 'x':
+            n = None
+            while True:
+                subtask = readchar.readkey()
+                if subtask.isdigit():
+                    n = n * 10 + int(subtask) if n is not None else int(subtask)
+                if subtask == readchar.key.ENTER and n is not None:
+                    node = self.node_with_number(n)
+                    del self.synset[node]
+                    del node
+                    break
+                if subtask == 'a':
+                    break
         if task == 'n':
            self.n_annotated += 1
         if task == 'q':
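The three new command branches ('f' for subfeature, 's' for synonym, 'x' for remove) share one idiom: digits accumulate into n until ENTER confirms or 'a' aborts, and starting n at None lets each branch reject a bare ENTER with no number. Distilled for illustration (this helper is not a function in the commit):

    import readchar

    # Accumulate digits until ENTER confirms (returns the number)
    # or 'a' aborts (returns None).
    def read_node_number():
        n = None
        while True:
            key = readchar.readkey()
            if key.isdigit():
                n = n * 10 + int(key) if n is not None else int(key)
            if key == readchar.key.ENTER and n is not None:
                return n
            if key == 'a':
                return None

One caveat in the 'x' branch above: del node only removes the local name binding. Assuming these are anytree nodes (Node, RenderTree and PreOrderIter suggest so), detaching a node from the tree requires node.parent = None, so a "removed" entity may still appear in the rendered tree.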
@@ -145,7 +193,7 @@ class EntityAnnotator:
             self.save()
 
     def select_entity(self):
-        entity = self.counter.most_common()[self.n_annotated]
+        entity, _ = self.counter.most_common(self.n_annotated + 1)[-1]
         return entity.replace('_', ' ')
 
     def node_with_number(self, n):
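This one is a straight bug fix: Counter.most_common() returns (element, count) pairs, so the old line bound a tuple to entity and the subsequent entity.replace(...) would raise AttributeError. The rewrite unpacks the pair, and requesting only the top n_annotated + 1 items lets Counter use a partial heap selection instead of sorting every noun. Verifiable in isolation:

    from collections import Counter

    counter = Counter({'camera': 5, 'lens': 3, 'battery': 2})
    n_annotated = 1
    entity, _ = counter.most_common(n_annotated + 1)[-1]
    print(entity)   # 'lens', the second most common noun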
@@ -157,36 +205,32 @@ class EntityAnnotator:
             node.n = i
             i += 1
 
-    # def get_relation_tuples(self):
-    #     rels = []
-    #     for e1 in LevelOrderIter(self.root):
-    #         if e1.isleaf():
-    #             continue
-    #         for e2 in e1.children:
-    #             rels.append((e1.name, e2.name))  # e1 hasFeature e2
-    #     return rels
-
-    def get_annotated_texts(self, save_path):
-        df = pd.read_csv(self.text_file_path, sep='\t', error_bad_lines=False)
-        df['relations'] = df['reviewText'].apply(lambda t: self.relations_for_text(t))
-        df = df[~df['relations'].isnull()]
+    def save_annotated_texts(self, save_path):
+        reviews = pd.read_csv(self.text_file_path, sep='\t', error_bad_lines=False)
+        texts = [text for _, par in reviews.sample(frac=1)['reviewText'].items() if not pd.isnull(par)
+                 for text in sent_tokenize(par)]
+        labelled_texts = [t for t in map(self.relations_for_text, texts) if t is not None]
+        df = pd.DataFrame(labelled_texts, columns=['sentText', 'relationMentions'])
+        df.to_csv(save_path, sep='\t', index=False)
 
     def relations_for_text(self, text):
         rels = []
         child_entities = []
         for e1 in PreOrderIter(self.root):
-            if not e1.isleaf() and e1.name in text:
+            if not e1.is_leaf and e1.name in text:
                 for e2 in e1.children:
                     if e2.name in text:
                         # e1 is a parent of an entity in the text
                         if e1 in child_entities:
                             # e1 cannot be a parent and a child
                             return None
-                        rels.append({'em1Text': e1, 'em2Text': e2, 'label': '/has_feature'})
+                        rels.append({'em1Text': e1.name, 'em2Text': e2.name, 'label': '/has_feature'})
                         child_entities.append(e2)
-        return rels
+        return text, rels
 
-ann = EntityAnnotator.new_from_tsv('data/verified_camera_reviews.tsv', 'camera_entity_annotator')
-ann.annotate()
+ann = EntityAnnotator.load_saved('camera_entity_annotator.pickle')
+# ann.annotate()
+ann.save_annotated_texts('BERT/data/annotated_camera_reviews.tsv')
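Taken together, the "cooperation" in the commit message is a file handoff: the annotator now writes a sentText/relationMentions TSV that the BERT side reads directly via the new .tsv branch and literal_eval in tagged_rel_dataset.py. Reconstructing the intended sequence from the driver code at the bottom of both files (the two paths agree when bert_tag_extractor.py is run from the BERT/ directory):

    # 1. One-time setup from an earlier run: build the noun counter.
    ann = EntityAnnotator.new_from_tsv('data/verified_camera_reviews.tsv', 'camera_entity_annotator')
    # 2. Resume from the pickle and interactively build the entity tree.
    ann = EntityAnnotator.load_saved('camera_entity_annotator.pickle')
    ann.annotate()
    # 3. Export labelled sentences for BERT training.
    ann.save_annotated_texts('BERT/data/annotated_camera_reviews.tsv')
    # 4. In bert_tag_extractor.py: train with a 5% validation split.
    BertTagExtractor.train_and_validate('data/annotated_camera_reviews.tsv', 0.05, size=200000)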