Commit e183155c authored by Joel Oksanen

Improved cooperation between entity_annotation and bert_tag_extractor

parent 29ed5986
@@ -9,7 +9,7 @@ from tagged_rel_dataset import TRAINED_WEIGHTS, MAX_SEQ_LEN, RELATIONS, IGNORE_T
 train_data_path = 'data/train.json'
 test_data_path = 'data/test.json'
-trained_model_path = 'trained_bert_tag_extractor_2.pt'
+trained_model_path = 'trained_bert_tag_extractor_camera.pt'
 device = torch.device('cuda')
 # optimizer
@@ -60,9 +60,15 @@ class BertTagExtractor:
         extractor.train_with_file(file_path, size=size)
         return extractor
 
-    def train_with_file(self, file_path, size=None):
+    @staticmethod
+    def train_and_validate(file_path, valid_frac, size=None):
+        extractor = BertTagExtractor()
+        extractor.train_with_file(file_path, size=size, valid_frac=valid_frac)
+        return extractor
+
+    def train_with_file(self, file_path, size=None, valid_frac=None):
         # load training data
-        train_data = TaggedRelDataset.from_file(file_path, size=size)
+        train_data, valid_data = TaggedRelDataset.from_file(file_path, size=size, valid_frac=valid_frac)
         train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
                                   collate_fn=generate_train_batch)
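
A minimal usage sketch of the new entry point (the path and numbers mirror the __main__ call at the bottom of this file; per the next hunk, train_with_file finishes by evaluating on the held-out split when one is requested):

extractor = BertTagExtractor.train_and_validate(
    'data/annotated_camera_reviews.tsv',  # TSV exported by EntityAnnotator.save_annotated_texts
    valid_frac=0.05,                      # hold out 5% of sentences for validation
    size=200000)                          # cap the number of sentences used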
@@ -131,9 +137,18 @@ class BertTagExtractor:
         torch.save(self.net.state_dict(), trained_model_path)
 
-    def evaluate(self, file_path):
+        if valid_data is not None:
+            self.evaluate(data=valid_data)
+
+    def evaluate(self, file_path=None, data=None):
         # load test data
-        test_data = TaggedRelDataset.from_file(file_path)
+        if file_path is not None:
+            test_data, _ = TaggedRelDataset.from_file(file_path)  # from_file now returns a (data, valid) pair
+        else:
+            if data is None:
+                raise ValueError('file_path and data cannot both be None')
+            test_data = data
         test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
                                  collate_fn=generate_eval_batch)
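
evaluate() now supports two mutually exclusive call modes, a file on disk or an in-memory dataset; a sketch with illustrative paths:

extractor = BertTagExtractor.new_trained_with_file('data/train.json')
extractor.evaluate(file_path='data/test.json')   # load a held-out file from disk

_, valid_data = TaggedRelDataset.from_file('data/annotated_camera_reviews.tsv', valid_frac=0.05)
extractor.evaluate(data=valid_data)              # reuse an in-memory validation split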
@@ -180,8 +195,7 @@ class BertTagExtractor:
 #     print('macro F1:', f1)
 
-extr = BertTagExtractor.new_trained_with_file(train_data_path)
-extr.evaluate(test_data_path)
+BertTagExtractor.train_and_validate('data/annotated_camera_reviews.tsv', 0.05, size=200000)
@@ -3,9 +3,11 @@ from torch.utils.data import Dataset
 from transformers import BertTokenizer
 import pandas as pd
 from collections import defaultdict
+import numpy as np
+from ast import literal_eval
 
 TRAINED_WEIGHTS = 'bert-base-cased'  # cased works better for NER
-RELATIONS = ['/location/location/contains']
+RELATIONS = ['/has_feature']
 N_TAGS = 4 * len(RELATIONS) * 2 + 1
 MAX_SEQ_LEN = 128
 MAX_TOKENS = MAX_SEQ_LEN - 2
@@ -43,7 +45,12 @@ class TaggedRelDataset(Dataset):
     @staticmethod
     def from_file(path, valid_frac=None, size=None):
         dataset = TaggedRelDataset()
-        dataset.df = pd.read_json(path, lines=True)
+        if path.endswith('.json'):
+            dataset.df = pd.read_json(path, lines=True)
+        elif path.endswith('.tsv'):
+            dataset.df = pd.read_csv(path, sep='\t', error_bad_lines=False)
+        else:
+            raise ValueError('Could not recognize file type')
         # sample data if a size is specified
         if size is not None and size < len(dataset):
@@ -51,17 +58,18 @@ class TaggedRelDataset(Dataset):
         if valid_frac is None:
             print('Obtained dataset of size', len(dataset))
-            return dataset
+            return dataset, None
         else:
             validset = TaggedRelDataset()
             split_idx = int(len(dataset) * (1 - valid_frac))
-            dataset.df, validset.df = dataset.df[:split_idx, :], dataset.df[split_idx:, :]
+            dataset.df, validset.df = np.split(dataset.df, [split_idx], axis=0)
             print('Obtained train set of size', len(dataset), 'and validation set of size', len(validset))
             return dataset, validset
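
The replaced expression dataset.df[:split_idx, :] is not valid pandas indexing (a DataFrame rejects a 2-D slice key), while np.split performs the intended row-wise cut. A toy illustration:

import numpy as np
import pandas as pd

df = pd.DataFrame({'sentText': list('abcde')})
train_df, valid_df = np.split(df, [4], axis=0)  # rows 0-3 and row 4
print(len(train_df), len(valid_df))             # 4 1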
 
     def instance_from_row(self, row):
         text = row['sentText']
         tokens = tokenizer.tokenize(text)[:MAX_TOKENS]
-        tag_map = self.map_for_relation_mentions(row['relationMentions'])
+        mentions = row['relationMentions']
+        if isinstance(mentions, str):
+            mentions = literal_eval(mentions)  # TSV cells hold the list's string repr; JSON keeps a native list
+        tag_map = self.map_for_relation_mentions(mentions)
         sorted_entities = sorted(tag_map.keys(), key=len, reverse=True)
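
The string check above is a guard (an addition beyond the bare literal_eval call) because the two file formats deserialise differently: pd.read_json yields relationMentions as a native list, while to_csv/read_csv round-trips it as the list's string repr, which literal_eval parses back. Roughly:

from ast import literal_eval

cell = "[{'em1Text': 'camera', 'em2Text': 'lens', 'label': '/has_feature'}]"  # one TSV cell
mentions = literal_eval(cell)  # back to a list of dicts
print(mentions[0]['em2Text'])  # lens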
 import pandas as pd
 from xml.etree.ElementTree import ElementTree, parse, tostring, Element, SubElement
 from gensim.models.phrases import Phrases, Phraser
 from nltk import pos_tag
 from nltk.tokenize import word_tokenize, sent_tokenize
@@ -26,16 +25,21 @@ class EntityAnnotator:
         self.counter = counter
         self.save_path = save_path
         self.root = None
+        self.synset = {}
         self.n_annotated = 0
 
     @staticmethod
     def new_from_tsv(file_path, name):
         df = pd.read_csv(file_path, sep='\t', error_bad_lines=False)
+        print('tokenizing texts...')
         texts = [text.replace('_', ' ')
-                 for _, par in df['reviewText'].items() if not pd.isnull(par)
+                 for _, par in df.sample(frac=1)['reviewText'].items() if not pd.isnull(par)
                  for text in sent_tokenize(par)]
+        print('obtaining counter...')
         counter = EntityAnnotator.count_nouns(texts)
+        print('finished initialising annotator')
         ann = EntityAnnotator(file_path, counter, name + '.pickle')
+        ann.save()
         return ann
@@ -51,15 +55,18 @@ class EntityAnnotator:
         f.close()
 
     @staticmethod
-    def count_nouns(texts):
+    def count_nouns(raw_texts):
+        texts = [word_tokenize(text) for text in raw_texts]
+        print(' obtaining phraser...')
         # obtain phraser
         bigram = Phrases(texts, threshold=PHRASE_THRESHOLD)
         trigram = Phrases(bigram[texts], threshold=PHRASE_THRESHOLD)
         phraser = Phraser(trigram)
 
+        print(' counting nouns...')
         # count nouns
         nouns = []
-        for text in texts:
+        for idx, text in enumerate(texts):
             pos_tags = pos_tag(text)
             ngrams = phraser[text]
@@ -79,6 +86,8 @@ class EntityAnnotator:
                 if len(token) > 1 and is_noun and is_valid:
                     nouns.append(token)
                 word_idx += 1
+            if idx % 1000 == 0:
+                print(' {:0.2f} done'.format((idx + 1) / len(texts)))
         return Counter(nouns)
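
The two Phrases passes work by re-running collocation detection over the already-merged corpus: pass one learns bigrams, pass two sees each bigram as a single token and can join it with a neighbour into a trigram. A self-contained toy sketch (corpus and thresholds are illustrative, not from the repo):

from gensim.models.phrases import Phrases, Phraser

corpus = [['battery', 'life', 'span']] * 30                     # toy corpus
bigram = Phrases(corpus, min_count=1, threshold=0.1)            # pass 1 learns 'battery_life'
trigram = Phrases(bigram[corpus], min_count=1, threshold=0.01)  # pass 2 learns 'battery_life_span'

tokens = Phraser(bigram)[['battery', 'life', 'span']]           # ['battery_life', 'span']
print(Phraser(trigram)[tokens])                                 # ['battery_life_span']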
@@ -100,43 +109,82 @@ class EntityAnnotator:
             os.system('clear')
-            print(fg.li_blue + '{} entities annotated'.format(self.n_annotated) + fg.rs)
+            print(fg.li_green + '{} entities annotated'.format(self.n_annotated) + fg.rs)
             print('')
             print(fg.li_black + 'root: \'r\'' + fg.rs)
-            print(fg.li_black + 'subfeat: [number of parent node][ENTER]' + fg.rs)
-            print(fg.li_black + 'skip: \'s\'' + fg.rs)
+            print(fg.li_black + 'subfeat: [\'f\'][number of parent node][ENTER]' + fg.rs)
+            print(fg.li_black + 'synonym: [\'s\'][number of syn node][ENTER]' + fg.rs)
+            print(fg.li_black + 'nan: \'n\'' + fg.rs)
+            print(fg.li_black + 'remove: \'x\'' + fg.rs)
             print(fg.li_black + 'quit: \'q\'' + fg.rs)
+            print(fg.li_black + 'abort: \'a\'' + fg.rs)
             print('')
             if self.root is not None:
-                print(RenderTree(self.root))
+                print(fg.li_blue + str(RenderTree(self.root)) + fg.rs)
                 print('')
             print(entity)
             print('')
             task = readchar.readkey()
             if task == 'r':
+                node = Node(entity)
+                self.synset[node] = [node.name]
                 old_root = self.root
-                self.root = Node(entity)
-                old_root.parent = self.root
+                self.root = node
+                if old_root is not None:
+                    old_root.parent = self.root
                 self.update_tree_indices()
                 self.n_annotated += 1
-            if task.isdigit():
-                n = int(task)
+            if task == 'f':
+                n = None
                 while True:
                     subtask = readchar.readkey()
                     if subtask.isdigit():
-                        n = n * 10 + int(subtask)
-                    if subtask == readchar.key.ENTER:
-                        Node(entity, parent=self.node_with_number(n))
+                        n = n * 10 + int(subtask) if n is not None else int(subtask)
+                    if subtask == readchar.key.ENTER and n is not None:
+                        node = Node(entity, parent=self.node_with_number(n))
+                        self.synset[node] = [node.name]
                         self.update_tree_indices()
                         self.n_annotated += 1
                         break
+                    if subtask == 'a':
+                        break
+            if task == 's':
+                n = None
+                while True:
+                    subtask = readchar.readkey()
+                    if subtask.isdigit():
+                        n = n * 10 + int(subtask) if n is not None else int(subtask)
+                    if subtask == readchar.key.ENTER and n is not None:
+                        self.synset[self.node_with_number(n)].append(entity)
+                        self.n_annotated += 1
+                        break
+                    if subtask == 'a':
+                        break
+            if task == 'x':
+                n = None
+                while True:
+                    subtask = readchar.readkey()
+                    if subtask.isdigit():
+                        n = n * 10 + int(subtask) if n is not None else int(subtask)
+                    if subtask == readchar.key.ENTER and n is not None:
+                        node = self.node_with_number(n)
+                        del self.synset[node]
+                        node.parent = None  # detach from the tree ('del node' would only unbind the local name)
+                        self.update_tree_indices()
+                        break
+                    if subtask == 'a':
+                        break
             if task == 'n':
                 self.n_annotated += 1
             if task == 'q':
@@ -145,7 +193,7 @@ class EntityAnnotator:
                 self.save()
 
     def select_entity(self):
-        entity = self.counter.most_common()[self.n_annotated]
+        entity, _ = self.counter.most_common(self.n_annotated+1)[-1]
         return entity.replace('_', ' ')
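
Counter.most_common(k) returns the k highest-count (element, count) pairs in descending order, so most_common(self.n_annotated+1)[-1] picks out exactly the next noun to annotate; the tuple unpacking also fixes the old line, which indexed a (word, count) pair and would have crashed on .replace(). For instance:

from collections import Counter

counter = Counter({'camera': 10, 'lens': 7, 'battery': 3})
entity, count = counter.most_common(2)[-1]  # the 2nd most common noun
print(entity, count)                        # lens 7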
 
     def node_with_number(self, n):
@@ -157,36 +205,32 @@ class EntityAnnotator:
             node.n = i
             i += 1
 
-    # def get_relation_tuples(self):
-    #     rels = []
-    #     for e1 in LevelOrderIter(self.root):
-    #         if e1.isleaf():
-    #             continue
-    #         for e2 in e1.children:
-    #             rels.append((e1.name, e2.name)) # e1 hasFeature e2
-    #     return rels
 
-    def get_annotated_texts(self, save_path):
-        df = pd.read_csv(self.text_file_path, sep='\t', error_bad_lines=False)
-        df['relations'] = df['reviewText'].apply(lambda t: self.relations_for_text(t))
-        df = df[~df['relations'].isnull()]
+    def save_annotated_texts(self, save_path):
+        reviews = pd.read_csv(self.text_file_path, sep='\t', error_bad_lines=False)
+        texts = [text for _, par in reviews.sample(frac=1)['reviewText'].items() if not pd.isnull(par)
+                 for text in sent_tokenize(par)]
+        labelled_texts = [t for t in map(self.relations_for_text, texts) if t is not None]
+        df = pd.DataFrame(labelled_texts, columns=['sentText', 'relationMentions'])
         df.to_csv(save_path, sep='\t', index=False)
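
This TSV is the handoff between the annotator and the BERT side named in the commit message: each row holds one sentence and its mention list, matching what relations_for_text returns below. One illustrative row (values invented):

row = ('The camera has a great lens.',
       [{'em1Text': 'camera', 'em2Text': 'lens', 'label': '/has_feature'}])
# to_csv stringifies the list; TaggedRelDataset.from_file undoes this with literal_eval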
 
     def relations_for_text(self, text):
         rels = []
         child_entities = []
         for e1 in PreOrderIter(self.root):
-            if not e1.isleaf() and e1.name in text:
+            if not e1.is_leaf and e1.name in text:
                 for e2 in e1.children:
                     if e2.name in text:
                         # e1 is a parent of an entity in the text
                         if e1 in child_entities:
                             # e1 cannot be a parent and a child
                             return None
-                        rels.append({'em1Text': e1, 'em2Text': e2, 'label': '/has_feature'})
+                        rels.append({'em1Text': e1.name, 'em2Text': e2.name, 'label': '/has_feature'})
                         child_entities.append(e2)
-        return rels
+        return text, rels
 
-ann = EntityAnnotator.new_from_tsv('data/verified_camera_reviews.tsv', 'camera_entity_annotator')
-ann.annotate()
+ann = EntityAnnotator.load_saved('camera_entity_annotator.pickle')
+# ann.annotate()
+ann.save_annotated_texts('BERT/data/annotated_camera_reviews.tsv')
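
Read together with the BERT-side __main__ change above, the intended end-to-end flow is roughly the following (a sketch; each script appears to run from its own working directory, which is why the TSV path differs between steps 2 and 3):

# 1. Build the annotator from raw reviews, then annotate interactively (run once).
ann = EntityAnnotator.new_from_tsv('data/verified_camera_reviews.tsv', 'camera_entity_annotator')
ann.annotate()

# 2. Reload the saved annotator and export labelled sentences for BERT.
ann = EntityAnnotator.load_saved('camera_entity_annotator.pickle')
ann.save_annotated_texts('BERT/data/annotated_camera_reviews.tsv')

# 3. Train the extractor on the exported TSV with a 5% validation split.
BertTagExtractor.train_and_validate('data/annotated_camera_reviews.tsv', 0.05, size=200000)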