diff --git a/ADA-X/.gitignore b/ADA-X/.gitignore index 0ee535a72acae92bfa868d40a1c2de903a4f7701..926d5d678d3249bc63fabfe78b77393d4526bbf3 100644 --- a/ADA-X/.gitignore +++ b/ADA-X/.gitignore @@ -5,6 +5,7 @@ server/agent/amazon_data/ server/agent/SA/data/ server/agent/target_extraction/data/ server/agent/target_extraction/BERT/data/ +server/agent/target_extraction/eval/qa/ .DS_Store *.pickle *.wv \ No newline at end of file diff --git a/ADA-X/server/agent/target_extraction/BERT/entity_extractor/bert_entity_extractor.py b/ADA-X/server/agent/target_extraction/BERT/entity_extractor/bert_entity_extractor.py index 376bfd7154ce90d0aa2993db6b6facd3004432b6..bf2bfaeb563ca468202708503a0c67284c81afc6 100644 --- a/ADA-X/server/agent/target_extraction/BERT/entity_extractor/bert_entity_extractor.py +++ b/ADA-X/server/agent/target_extraction/BERT/entity_extractor/bert_entity_extractor.py @@ -10,7 +10,7 @@ from sklearn import metrics import statistics from transformers import get_linear_schedule_with_warmup from agent.target_extraction.BERT.entity_extractor.entity_dataset import EntityDataset, generate_batch, generate_production_batch -from agent.target_extraction.BERT.entity_extractor.entitybertnet import NUM_CLASSES, EntityBertNet +from agent.target_extraction.BERT.entity_extractor.entitybertnet import NUM_CLASSES, EntityBertNet, BATCH_SIZE device = torch.device('cuda') @@ -21,7 +21,6 @@ MAX_GRAD_NORM = 1.0 # training N_EPOCHS = 3 -BATCH_SIZE = 32 WARM_UP_FRAC = 0.05 # loss @@ -61,8 +60,7 @@ class BertEntityExtractor: else: train_size = int(size * (1 - valid_frac)) if size is not None else None train_data, _ = EntityDataset.from_file(file_path, size=train_size) - valid_size = int(size * valid_frac) if size is not None else int(len(train_data) * valid_frac) - valid_data, _ = EntityDataset.from_file(valid_file_path, size=valid_size) + valid_data, _ = EntityDataset.from_file(valid_file_path) train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, collate_fn=generate_batch) @@ -119,11 +117,11 @@ class BertEntityExtractor: print('epoch done') + torch.save(self.net.state_dict(), '{}_epoch_{}.pt'.format(save_file, epoch_idx + 1)) + if valid_data is not None: self.evaluate(data=valid_data) - torch.save(self.net.state_dict(), '{}.pt'.format(save_file)) - end = time.time() print('Training took', end - start, 'seconds') @@ -207,3 +205,7 @@ class BertEntityExtractor: probs[ins.entity].append(score) return {t: statistics.mean(t_probs) if len(t_probs) > 0 else None for t, t_probs in probs.items()} + + +BertEntityExtractor.train_and_validate('all_reviews_features.tsv', 'feature_extractor', + valid_file_path='annotated_watch_review_features.tsv') diff --git a/ADA-X/server/agent/target_extraction/BERT/entity_extractor/entity_dataset.py b/ADA-X/server/agent/target_extraction/BERT/entity_extractor/entity_dataset.py index 318dc209c8ef15cec22ef3b2eab9d841250d50a5..6e1152ceb40927230379f0aaed9b4f9c7f644f58 100644 --- a/ADA-X/server/agent/target_extraction/BERT/entity_extractor/entity_dataset.py +++ b/ADA-X/server/agent/target_extraction/BERT/entity_extractor/entity_dataset.py @@ -8,58 +8,22 @@ import os.path from agent.target_extraction.BERT.relation_extractor.pairbertnet import TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES MAX_SEQ_LEN = 128 -LABELS = ['ASPECT', 'NAN'] -LABEL_MAP = {'ASPECT': 1, 'NAN': 0, None: None} MASK_TOKEN = '[MASK]' tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS) -def generate_batch(batch): - encoded = tokenizer.batch_encode_plus([instance.tokens for instance in 
batch], add_special_tokens=True, - max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True, - return_tensors='pt') - input_ids = encoded['input_ids'] - attn_mask = encoded['attention_mask'] - labels = torch.tensor([instance.label for instance in batch]) - - entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch]) - - return input_ids, attn_mask, entity_indices, labels - - -def generate_production_batch(batch): - encoded = tokenizer.batch_encode_plus([instance.tokens for instance in batch], add_special_tokens=True, - max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True, - return_tensors='pt') - input_ids = encoded['input_ids'] - attn_mask = encoded['attention_mask'] - - entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch]) - - return input_ids, attn_mask, entity_indices, batch - - -def indices_for_entity_ranges(ranges): - max_e_len = max(end - start for start, end in ranges) - indices = torch.tensor([[[min(t, end)] * HIDDEN_OUTPUT_FEATURES - for t in range(start, start + max_e_len + 1)] - for start, end in ranges]) - return indices - - class EntityDataset(Dataset): - def __init__(self, df, size=None): - # filter inapplicable rows - self.df = df[df.apply(lambda x: EntityDataset.instance_from_row(x) is not None, axis=1)] - + def __init__(self, df, training=True, size=None): + self.df = df + self.training = training # sample data if a size is specified if size is not None and size < len(self): self.df = self.df.sample(size, replace=False) @staticmethod - def from_df(df, size=None): - dataset = EntityDataset(df, size=size) + def for_extraction(df): + dataset = EntityDataset(df, training=False) print('Obtained dataset of size', len(dataset)) return dataset @@ -83,80 +47,60 @@ class EntityDataset(Dataset): print('Obtained train set of size', len(dataset), 'and validation set of size', len(validset)) return dataset, validset - @staticmethod - def instance_from_row(row): - unpacked_arr = literal_eval(row['entityMentions']) if type(row['entityMentions']) is str else row['entityMentions'] - rms = [rm for rm in unpacked_arr if 'label' not in rm or rm['label'] in LABELS] - if len(rms) == 1: - entity, label = rms[0]['text'], (rms[0]['label'] if 'label' in rms[0] else None) - else: - return None # raise AttributeError('Instances must have exactly one relation') - - text = row['sentText'] - return EntityDataset.get_instance(text, entity, label=label) - - @staticmethod - def get_instance(text, entity, label=None): - tokens = tokenizer.tokenize(text) - - i = 0 - found_entity = False - entity_range = None - while i < len(tokens): - match_length = EntityDataset.token_entity_match(i, entity.lower(), tokens) - if match_length is not None: - if found_entity: - return None # raise AttributeError('Entity {} appears twice in text {}'.format(entity, text)) - found_entity = True - tokens[i:i + match_length] = [MASK_TOKEN] * match_length - entity_range = (i + 1, i + match_length) # + 1 taking into account the [CLS] token - i += match_length - else: - i += 1 - - if found_entity: - return PairRelInstance(tokens, entity, entity_range, LABEL_MAP[label], text) + def instance_from_row(self, row): + if self.training: + return EntityInstance(literal_eval(row['tokens']), + row['entity_idx'], + label=row['label']) else: - return None - - @staticmethod - def token_entity_match(first_token_idx, entity, tokens): - token_idx = first_token_idx - remaining_entity = entity - while remaining_entity: - if remaining_entity == entity or 
remaining_entity.lstrip() != remaining_entity: - # start of new word - remaining_entity = remaining_entity.lstrip() - if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]: - remaining_entity = remaining_entity[len(tokens[token_idx]):] - token_idx += 1 - else: - break - else: - # continuing same word - if (token_idx < len(tokens) and tokens[token_idx].startswith('##') - and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]): - remaining_entity = remaining_entity[len(tokens[token_idx][2:]):] - token_idx += 1 - else: - break - if remaining_entity: - return None - else: - return token_idx - first_token_idx + return EntityInstance(row['tokens'], + row['entity_idx'], + entity=row['entity']) def __len__(self): return len(self.df.index) def __getitem__(self, idx): - return EntityDataset.instance_from_row(self.df.iloc[idx]) + return self.instance_from_row(self.df.iloc[idx]) -class PairRelInstance: +class EntityInstance: - def __init__(self, tokens, entity, entity_range, label, text): + def __init__(self, tokens, entity_idx, label=None, entity=None): self.tokens = tokens - self.entity = entity - self.entity_range = entity_range + self.entity_idx = entity_idx self.label = label - self.text = text + self.entity = entity + + +def generate_batch(instances: [EntityInstance]): + encoded = tokenizer.batch_encode_plus([instance.tokens for instance in instances], add_special_tokens=True, + max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True, + return_tensors='pt') + input_ids = encoded['input_ids'] + attn_mask = encoded['attention_mask'] + + entity_indices = torch.tensor([instance.entity_idx for instance in instances]) + labels = torch.tensor([instance.label for instance in instances]) + + return input_ids, attn_mask, entity_indices, labels + + +def generate_production_batch(instances: [EntityInstance]): + encoded = tokenizer.batch_encode_plus([instance.tokens for instance in instances], add_special_tokens=True, + max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True, + return_tensors='pt') + input_ids = encoded['input_ids'] + attn_mask = encoded['attention_mask'] + + entity_indices = torch.tensor([instance.entity_idx for instance in instances]) + + return input_ids, attn_mask, entity_indices, instances + + +# def indices_for_entity_ranges(ranges): +# max_e_len = max(end - start for start, end in ranges) +# indices = torch.tensor([[[min(t, end)] * HIDDEN_OUTPUT_FEATURES +# for t in range(start, start + max_e_len + 1)] +# for start, end in ranges]) +# return indices diff --git a/ADA-X/server/agent/target_extraction/BERT/entity_extractor/entitybertnet.py b/ADA-X/server/agent/target_extraction/BERT/entity_extractor/entitybertnet.py index 1bd83acc2688859892f70e42aac38beb65ae2f30..7d1aa162f42738c26cf1f7d5e380a858a1ef2184 100644 --- a/ADA-X/server/agent/target_extraction/BERT/entity_extractor/entitybertnet.py +++ b/ADA-X/server/agent/target_extraction/BERT/entity_extractor/entitybertnet.py @@ -5,6 +5,7 @@ from transformers import * HIDDEN_OUTPUT_FEATURES = 768 TRAINED_WEIGHTS = 'bert-base-uncased' NUM_CLASSES = 2 # entity, not entity +BATCH_SIZE = 32 class EntityBertNet(nn.Module): @@ -20,14 +21,9 @@ class EntityBertNet(nn.Module): bert_output, _ = self.bert_base(input_ids=input_ids, attention_mask=attn_mask) # max pooling at entity locations - entity_pooled_output = EntityBertNet.pooled_output(bert_output, entity_indices) + entity_pooled_output = bert_output[torch.arange(0, bert_output.shape[0]), entity_indices] # fc 
layer (softmax activation done in loss function) x = self.fc(entity_pooled_output) return x - @staticmethod - def pooled_output(bert_output, indices): - outputs = torch.gather(bert_output, dim=1, index=indices) - pooled_output, _ = torch.max(outputs, dim=1) - return pooled_output diff --git a/ADA-X/server/agent/target_extraction/BERT/relation_extractor/bert_rel_extractor.py b/ADA-X/server/agent/target_extraction/BERT/relation_extractor/bert_rel_extractor.py index 1dadd5a4e03e4860f6821ec6fb9acd9d882dc0ac..5dd68f4afa24e8fa35db3da9bf97b4cbdc6953b6 100644 --- a/ADA-X/server/agent/target_extraction/BERT/relation_extractor/bert_rel_extractor.py +++ b/ADA-X/server/agent/target_extraction/BERT/relation_extractor/bert_rel_extractor.py @@ -8,8 +8,10 @@ import time import numpy as np from sklearn import metrics from transformers import get_linear_schedule_with_warmup -from agent.target_extraction.BERT.relation_extractor.pair_rel_dataset import PairRelDataset, generate_batch, generate_production_batch -from agent.target_extraction.BERT.relation_extractor.pairbertnet import NUM_CLASSES, PairBertNet +# from agent.target_extraction.BERT.relation_extractor.pair_rel_dataset import PairRelDataset, generate_batch, generate_production_batch +from agent.target_extraction.BERT.relation_extractor.rel_dataset import PairRelDataset, generate_batch, generate_production_batch, RelInstance +# from agent.target_extraction.BERT.relation_extractor.pairbertnet import NUM_CLASSES, PairBertNet +from agent.target_extraction.BERT.relation_extractor.relbertnet import NUM_CLASSES, RelBertNet device = torch.device('cuda') @@ -30,12 +32,12 @@ loss_criterion = CrossEntropyLoss() class BertRelExtractor: def __init__(self): - self.net = PairBertNet() + self.net = RelBertNet() @staticmethod def load_saved(path): extr = BertRelExtractor() - extr.net = PairBertNet() + extr.net = RelBertNet() extr.net.load_state_dict(torch.load(path)) extr.net.eval() return extr @@ -60,8 +62,7 @@ class BertRelExtractor: else: train_size = int(size * (1 - valid_frac)) if size is not None else None train_data, _ = PairRelDataset.from_file(file_path, size=train_size) - valid_size = int(size * valid_frac) if size is not None else int(len(train_data) * valid_frac) - valid_data, _ = PairRelDataset.from_file(valid_file_path, size=valid_size) + valid_data, _ = PairRelDataset.from_file(valid_file_path) train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, collate_fn=generate_batch) @@ -87,16 +88,16 @@ class BertRelExtractor: for batch_idx, batch in enumerate(train_loader): # send batch to gpu - input_ids, attn_mask, masked_indices, fst_indices, snd_indices, target_labels = tuple(i.to(device) for i in batch) + input_ids, attn_mask, entity_indices, entity_mask, labels = tuple(i.to(device) for i in batch) # zero param gradients optimiser.zero_grad() # forward pass - output_scores = self.net(input_ids, attn_mask, masked_indices, fst_indices, snd_indices) + output_scores = self.net(input_ids, attn_mask, entity_indices, entity_mask) # backward pass - loss = loss_criterion(output_scores, target_labels) + loss = loss_criterion(output_scores, labels) loss.backward() # clip gradient norm @@ -117,12 +118,11 @@ class BertRelExtractor: batch_loss = 0.0 print('epoch done') + torch.save(self.net.state_dict(), '{}_epoch_{}.pt'.format(save_file, epoch_idx + 1)) if valid_data is not None: self.evaluate(data=valid_data) - torch.save(self.net.state_dict(), '{}.pt'.format(save_file)) - end = time.time() print('Training took', end - start, 
'seconds') @@ -147,15 +147,14 @@ class BertRelExtractor: with torch.no_grad(): for batch in test_loader: # send batch to gpu - input_ids, attn_mask, masked_indices, fst_indices, snd_indices, target_labels = tuple(i.to(device) - for i in batch) + input_ids, attn_mask, entity_indices, entity_mask, labels = tuple(i.to(device) for i in batch) # forward pass - output_scores = self.net(input_ids, attn_mask, masked_indices, fst_indices, snd_indices) + output_scores = self.net(input_ids, attn_mask, entity_indices, entity_mask) _, output_labels = torch.max(output_scores.data, 1) outputs += output_labels.tolist() - targets += target_labels.tolist() + targets += labels.tolist() assert len(outputs) == len(targets) @@ -176,25 +175,24 @@ class BertRelExtractor: recall = metrics.recall_score(targets, outputs, average=None) print('recall:', recall) - def extract_single_relation(self, text, e1, e2): - ins = PairRelDataset.get_instance(text, e1, e2) - input_ids, attn_mask, masked_indices, prod_indices, feat_indices, instances = generate_production_batch([ins]) + def extract_single_relation(self, text, entities): + ins = RelInstance.from_sentence(text, entities) + input_ids, attn_mask, entity_indices, entity_mask, _ = generate_production_batch([ins]) self.net.cuda() self.net.eval() with torch.no_grad(): # send batch to gpu - input_ids, attn_mask, masked_indices, prod_indices, feat_indices = tuple(i.to(device) for i in - [input_ids, attn_mask, - masked_indices, prod_indices, - feat_indices]) + input_ids, attn_mask, entity_indices, entity_mask = tuple(i.to(device) for i in [input_ids, attn_mask, + entity_indices, + entity_mask]) # forward pass - output_scores = softmax(self.net(input_ids, attn_mask, masked_indices, prod_indices, feat_indices), dim=1) + output_scores = softmax(self.net(input_ids, attn_mask, entity_indices, entity_mask), dim=1) _, output_labels = torch.max(output_scores.data, 1) - print(instances[0].get_relation_for_label(output_labels[0])) + ins.print_results_for_labels(output_labels) def extract_relations(self, n_aspects, aspect_index_map, aspect_counts, file_path=None, dataset=None, size=None): # load data @@ -215,15 +213,14 @@ class BertRelExtractor: count_matrix = np.zeros((n_aspects, n_aspects)) with torch.no_grad(): - for input_ids, attn_mask, masked_indices, prod_indices, feat_indices, instances in loader: + for input_ids, attn_mask, prod_indices, feat_indices, instances in loader: # send batch to gpu - input_ids, attn_mask, masked_indices, prod_indices, feat_indices = tuple(i.to(device) for i in - [input_ids, attn_mask, - masked_indices, prod_indices, - feat_indices]) + input_ids, attn_mask, prod_indices, feat_indices = tuple(i.to(device) for i in [input_ids, attn_mask, + prod_indices, + feat_indices]) # forward pass - output_scores = softmax(self.net(input_ids, attn_mask, masked_indices, prod_indices, feat_indices), dim=1) + output_scores = softmax(self.net(input_ids, attn_mask, prod_indices, feat_indices), dim=1) rel_scores = output_scores.narrow(1, 1, 2) for ins, scores in zip(instances, rel_scores.tolist()): @@ -236,4 +233,38 @@ class BertRelExtractor: return prob_matrix, count_matrix + def extract_relations2(self, n_aspects, dataset): + loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, + collate_fn=generate_production_batch) + + self.net.cuda() + self.net.eval() + + prob_matrix = np.zeros((n_aspects, n_aspects)) + count_matrix = np.zeros((n_aspects, n_aspects)) + + with torch.no_grad(): + for input_ids, attn_mask, entity_indices, combination_indices, 
instances in loader: + # send batch to gpu + input_ids, attn_mask, entity_indices, combination_indices = tuple(i.to(device) for i in + [input_ids, attn_mask, + entity_indices, combination_indices]) + + # forward pass + output_scores = softmax(self.net(input_ids, attn_mask, entity_indices, combination_indices), dim=1) + rel_scores = output_scores.narrow(1, 1, 2).tolist() + + entity_pairs = [ep for instance in instances for ep in instance.entity_pairs] + for ep, scores in zip(entity_pairs, rel_scores): + forward_score, backward_score = scores + prob_matrix[ep.snd.idx][ep.fst.idx] += forward_score + prob_matrix[ep.fst.idx][ep.snd.idx] += backward_score + count_matrix[ep.snd.idx][ep.fst.idx] += 1 + count_matrix[ep.fst.idx][ep.snd.idx] += 1 + + return prob_matrix, count_matrix + +# extr: BertRelExtractor = BertRelExtractor.load_saved('multi_extractor_5_products_epoch_1.pt') +# extr.extract_single_relation('The mixer comes with a stainless steel bowl.', +# ['mixer', 'stainless steel', 'bowl']) diff --git a/ADA-X/server/agent/target_extraction/BERT/relation_extractor/pairbertnet.py b/ADA-X/server/agent/target_extraction/BERT/relation_extractor/pairbertnet.py index c0a3ca909e0ddc5a81bf97685a902928a058b0b9..3f8c0f8f815b2cba813a4def64902e0a29206270 100644 --- a/ADA-X/server/agent/target_extraction/BERT/relation_extractor/pairbertnet.py +++ b/ADA-X/server/agent/target_extraction/BERT/relation_extractor/pairbertnet.py @@ -4,7 +4,7 @@ from transformers import * HIDDEN_OUTPUT_FEATURES = 768 TRAINED_WEIGHTS = 'bert-base-uncased' -NUM_CLASSES = 3 # no relation, fst hasFeature snd, snd hasFeature fst +NUM_CLASSES = 4 # no relation, fst hasFeature snd, snd hasFeature fst, siblings HIDDEN_ENTITY_FEATURES = 6 # lower -> more general but less informative entity representations @@ -18,18 +18,7 @@ class PairBertNet(nn.Module): self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config) self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES * 2, NUM_CLASSES) - def forward(self, input_ids, attn_mask, masked_indices, fst_indices, snd_indices): - # embeddings = self.bert_base.get_input_embeddings() - # input_embeddings = embeddings(input_ids) - # - # # get partially masked input_embeddings for entity terms - # unmasked_entity_embeddings = input_embeddings[masked_indices[:, 0], masked_indices[:, 1]] - # hidden_entity_repr = torch.tanh(self.entity_fc1(unmasked_entity_embeddings)) - # masked_entity_embeddings = torch.repeat_interleave(hidden_entity_repr, 128, dim=1) # 768 / 12 = 64 - # - # # replace input_embeddings with partially masked ones for entities - # input_embeddings[masked_indices[:, 0], masked_indices[:, 1]] = masked_entity_embeddings - + def forward(self, input_ids, attn_mask, fst_indices, snd_indices): # BERT bert_output, _ = self.bert_base(input_ids=input_ids, attention_mask=attn_mask) diff --git a/ADA-X/server/agent/target_extraction/BERT/relation_extractor/rel_dataset.py b/ADA-X/server/agent/target_extraction/BERT/relation_extractor/rel_dataset.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..5f411a59d4778f8e24af071c4c07dd0a8c49b7f7 100644 --- a/ADA-X/server/agent/target_extraction/BERT/relation_extractor/rel_dataset.py +++ b/ADA-X/server/agent/target_extraction/BERT/relation_extractor/rel_dataset.py @@ -0,0 +1,178 @@ +import torch +from torch.utils.data import Dataset +from transformers import BertTokenizer +import pandas as pd +import numpy as np +from ast import literal_eval +from agent.target_extraction.BERT.relation_extractor.relbertnet import TRAINED_WEIGHTS, MAX_SEQ_LEN, MAX_ENTITIES +import 
os + +MASK_TOKEN = '[MASK]' +tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS) + + +def generate_batch(batch): + encoded = tokenizer.batch_encode_plus([instance.tokens for instance in batch], add_special_tokens=True, + max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True, + return_tensors='pt') + input_ids = encoded['input_ids'] + attn_mask = encoded['attention_mask'] + + entity_indices = torch.tensor(list(map(indices_for_instance, batch))) + entity_mask = torch.tensor([[n < instance.get_count() for n in range(MAX_ENTITIES)] for instance in batch]) + labels = torch.tensor([e.label for instance in batch for e in instance.entities]) + + return input_ids, attn_mask, entity_indices, entity_mask, labels + + +def generate_production_batch(batch): + encoded = tokenizer.batch_encode_plus([instance.tokens for instance in batch], add_special_tokens=True, + max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True, + return_tensors='pt') + input_ids = encoded['input_ids'] + attn_mask = encoded['attention_mask'] + + entity_indices = torch.tensor(list(map(indices_for_instance, batch))) + entity_mask = torch.tensor([[n < instance.get_count() for n in range(MAX_ENTITIES)] for instance in batch]) + + return input_ids, attn_mask, entity_indices, entity_mask, batch + + +def indices_for_instance(instance): + indices = [[instance.entities[n].rng[0] if i < instance.entities[n].rng[0] else min(instance.entities[n].rng[1], i) + for i in range(MAX_SEQ_LEN)] + if n < len(instance.entities) else [0] * MAX_SEQ_LEN + for n in range(MAX_ENTITIES)] + return indices + + +class PairRelDataset(Dataset): + + def __init__(self, df, training=True, size=None): + self.df = df + self.training = training + # sample data if a size is specified + if size is not None and size < len(self): + self.df = self.df.sample(size, replace=False) + + @staticmethod + def for_extraction(df): + dataset = PairRelDataset(df, training=False) + print('Obtained dataset of size', len(dataset)) + return dataset + + @staticmethod + def from_file(file_name, valid_frac=None, size=None): + f = open(os.path.dirname(__file__) + '/../data/' + file_name) + dataset = PairRelDataset(pd.read_csv(f, sep='\t', error_bad_lines=False), size=size) + + if valid_frac is None: + print('Obtained dataset of size', len(dataset)) + return dataset, None + else: + split_idx = int(len(dataset) * (1 - valid_frac)) + dataset.df, valid_df = np.split(dataset.df, [split_idx], axis=0) + validset = PairRelDataset(valid_df) + print('Obtained train set of size', len(dataset), 'and validation set of size', len(validset)) + return dataset, validset + + def instance_from_row(self, row): + if self.training: + return RelInstance(literal_eval(row['tokens']), + literal_eval(row['entity_ranges']), + true_labels=literal_eval(row['labels'])) + else: + return RelInstance(row['tokens'], + row['entity_ranges'], + entity_labels=row['entity_labels']) + + def __len__(self): + return len(self.df.index) + + def __getitem__(self, idx): + return self.instance_from_row(self.df.iloc[idx]) + + +class RelInstance: + + def __init__(self, tokens, entity_ranges, true_labels=None, entity_labels=None, entity_texts=None): + self.tokens = tokens + self.entities = [Entity(rng, + label=(true_labels[n] if true_labels else None), + idx=(entity_labels[n] if entity_labels else None), + text=(entity_texts[n] if entity_texts else None)) + for n, rng in enumerate(entity_ranges)] + print(self.tokens) + print(entity_ranges) + + def get_count(self): + return len(self.entities) + + def 
print_results_for_labels(self, labels): + assert len(labels) == len(self.entities) + label_map = ['not an aspect', 'aspect', 'sub-feature'] + for e, l in zip(self.entities, labels): + print('{}: {}'.format(e.text, label_map[l])) + + @staticmethod + def from_sentence(text, entities): + def token_entity_match(first_token_idx, entity, tokens): + token_idx = first_token_idx + remaining_entity = entity + while remaining_entity: + if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity: + # start of new word + remaining_entity = remaining_entity.lstrip() + if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]: + remaining_entity = remaining_entity[len(tokens[token_idx]):] + token_idx += 1 + else: + break + else: + # continuing same word + if (token_idx < len(tokens) and tokens[token_idx].startswith('##') + and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]): + remaining_entity = remaining_entity[len(tokens[token_idx][2:]):] + token_idx += 1 + else: + break + if remaining_entity or (token_idx < len(tokens) and tokens[token_idx].startswith('##')): + return None + else: + return token_idx - first_token_idx + + tokens = tokenizer.tokenize(text) + + i = 0 + entity_ranges = [] + while i < len(tokens): + match = False + # check for aspects + for e in entities: + match_length = token_entity_match(i, e.lower(), tokens) + if match_length is not None: + entity_ranges.append((e, (i + 1, i + match_length))) # + 1 taking into account the [CLS] token + match = True + i += match_length + break + if not match: + i += 1 + + if len(entity_ranges) == 0 or len(entity_ranges) > 3: + return None + + # mask entity mentions + for _, (start, end) in entity_ranges: + tokens[(start - 1):end] = ['[MASK]'] * (end - (start - 1)) + + texts, ranges = zip(*entity_ranges) + return RelInstance(tokens, ranges, entity_texts=texts) + + +class Entity: + + def __init__(self, rng, label=None, idx=None, text=None): + self.rng = rng + self.label = label + self.idx = idx + self.text = text diff --git a/ADA-X/server/agent/target_extraction/BERT/relation_extractor/relbertnet.py b/ADA-X/server/agent/target_extraction/BERT/relation_extractor/relbertnet.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..5e7a4edfa299966f59da55b765d5db668be85140 100644 --- a/ADA-X/server/agent/target_extraction/BERT/relation_extractor/relbertnet.py +++ b/ADA-X/server/agent/target_extraction/BERT/relation_extractor/relbertnet.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +from transformers import * + +TRAINED_WEIGHTS = 'bert-base-uncased' +HIDDEN_OUTPUT_FEATURES = 768 +MAX_SEQ_LEN = 128 +NUM_CLASSES = 3 # no relation, fst hasFeature snd, snd hasFeature fst +MAX_ENTITIES = 3 + + +class RelBertNet(nn.Module): + + def __init__(self): + super(RelBertNet, self).__init__() + config = BertConfig.from_pretrained(TRAINED_WEIGHTS) + self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config) + self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES) + + def forward(self, input_ids, attn_mask, entity_indices, entity_mask): + # BERT + bert_output, _ = self.bert_base(input_ids=input_ids, attention_mask=attn_mask) + + # obtain entity combinations + combinations = RelBertNet.entity_combinations(bert_output, entity_indices, entity_mask) + + # fc layer (softmax activation done in loss function) + x = self.fc(combinations) + return x + + @staticmethod + def entity_combinations(bert_output, entity_indices, entity_mask): + # pool outputs + bert_output_exp = 
bert_output.unsqueeze(1).repeat(1, MAX_ENTITIES, 1, 1) + indices_exp = entity_indices.unsqueeze(3).repeat(1, 1, 1, HIDDEN_OUTPUT_FEATURES) + outputs = torch.gather(bert_output_exp, dim=2, index=indices_exp) + pooled_outputs, _ = torch.max(outputs, dim=2) + # pooled_outputs = torch.flatten(pooled_outputs, start_dim=0, end_dim=1) + return pooled_outputs[entity_mask] + + +# b_output = torch.randn((2, MAX_SEQ_LEN, HIDDEN_OUTPUT_FEATURES)) +# e_indices = torch.tensor([[[0, 1, 1, 1, 1], +# [2, 2, 2, 2, 2], +# [3, 3, 3, 3, 3]], +# +# [[0, 1, 1, 1, 1], +# [2, 2, 2, 2, 2], +# [4, 4, 4, 4, 4]]]) +# +# entity_mask = torch.tensor([[True, False, False], [True, True, True]]) +# +# print(RelBertNet.entity_combinations(b_output, e_indices, entity_mask)) diff --git a/ADA-X/server/agent/target_extraction/entity_annotation.py b/ADA-X/server/agent/target_extraction/entity_annotation.py index 960e4e600c6228f82f2ed9bb7492e18736f07047..5b10b185d46b8456a742fa3b179ddb958b7cb421 100644 --- a/ADA-X/server/agent/target_extraction/entity_annotation.py +++ b/ADA-X/server/agent/target_extraction/entity_annotation.py @@ -12,11 +12,17 @@ import readchar import random from sty import fg, bg from anytree import Node, RenderTree, LevelOrderIter, PreOrderIter +from itertools import combinations, repeat +from pathos.multiprocessing import ProcessingPool as Pool +from transformers import BertTokenizer +from agent.target_extraction.BERT.relation_extractor.relbertnet import TRAINED_WEIGHTS, MAX_ENTITIES PHRASE_THRESHOLD = 4 ROW_CHARACTER_COUNT = 100 stop_words = stopwords.words('english') ann_bgs = [bg.blue, bg.red] # child, parent +pool = Pool(4) +tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS) class EntityAnnotator: @@ -212,24 +218,55 @@ class EntityAnnotator: texts = [text for _, par in reviews['reviewText'].items() if not pd.isnull(par) for text in sent_tokenize(par)] - pair_texts = [t for t in map(lambda t: self.pair_relations_for_text(t), texts) + pair_texts = [t for t in map(lambda t: self.pair_relation_for_text(t), texts) if t is not None] df = pd.DataFrame(pair_texts, columns=['sentText', 'relationMentions']) df.to_csv(save_path, sep='\t', index=False) - def save_annotated_entities(self, save_path): + def save_annotated_pairs2(self, save_path, n): + reviews = pd.read_csv(self.text_file_path, sep='\t', error_bad_lines=False) + texts = [line for _, par in reviews['reviewText'].items() if not pd.isnull(par) + for sent in sent_tokenize(par) for line in sent.splitlines()] + + instances = [] + idx = 0 + while len(instances) < n and idx <= len(texts): + texts_sub = texts[idx:idx+20000] + idx += 20000 + instances += filter(lambda i: i is not None, pool.map(relation_instances_for_text, + repeat(tokenizer, len(texts_sub)), + repeat(self.root, len(texts_sub)), + repeat(self.synset, len(texts_sub)), + texts_sub)) + print(len(instances)) + + instances = instances[:n] + df = pd.DataFrame(instances, columns=['tokens', 'entity_ranges', 'labels']) + df.to_csv(save_path, sep='\t', index=False) + + def save_annotated_entities(self, save_path, n): reviews = pd.read_csv(self.text_file_path, sep='\t', error_bad_lines=False) texts = [text for _, par in reviews['reviewText'].items() if not pd.isnull(par) for text in sent_tokenize(par)] - all_entities = {(e, True) for e in self.get_annotated_entities()}.union( - {(e, False) for e in self.get_nan_entities()}) - - entity_texts = [t for t in map(lambda t: self.entity_mentions_in_text(t, all_entities), texts) - if t is not None] - - df = pd.DataFrame(entity_texts, columns=['sentText', 
'entityMentions']) + product_entities = {s.lower() for s in self.synset[self.root]} + other_entities = {(e.lower(), True) for e in self.get_annotated_features()}.union( + {(e.lower(), False) for e in self.get_nan_entities()}) + + instances = [] + idx = 0 + while len(instances) < n and idx <= len(texts): + texts_sub = texts[idx:idx + 20000] + idx += 20000 + instances += filter(lambda i: i is not None, pool.map(entity_instances_for_text, + repeat(tokenizer, len(texts_sub)), + repeat(product_entities, len(texts_sub)), + repeat(other_entities, len(texts_sub)), + texts_sub)) + print(len(instances)) + + df = pd.DataFrame(instances, columns=['tokens', 'entity_idx', 'label']) df.to_csv(save_path, sep='\t', index=False) @staticmethod @@ -247,7 +284,7 @@ class EntityAnnotator: random.shuffle(m) return {'em1Text': m[0], 'em2Text': m[1], 'label': '/no_relation'} - def pair_relations_for_text(self, text, nan_entities=None): + def pair_relation_for_text(self, text, nan_entities=None): single_tokens = word_tokenize(text) tagged_single = pos_tag(single_tokens) tagged_all = set().union(*[tagged_single, pos_tag(self.phraser[single_tokens])]) @@ -327,6 +364,9 @@ class EntityAnnotator: def get_annotated_entities(self): return {syn.lower() for n in PreOrderIter(self.root) for syn in self.synset[n]} + def get_annotated_features(self): + return {syn.lower() for n in self.root.descendants for syn in self.synset[n]} + def get_nan_entities(self): annotated = self.get_annotated_entities() return {t.replace('_', ' ').lower() for t, _ in self.counter.most_common(self.n_annotated) @@ -358,6 +398,220 @@ class EntityAnnotator: return text, rels -ea: EntityAnnotator = EntityAnnotator.load_saved('annotators/watch_annotator.pickle') -ea.save_annotated_pairs('BERT/data/annotated_watch_review_pairs.tsv') -ea.save_annotated_entities('BERT/data/annotated_watch_review_entities.tsv') +def relation_instances_for_text(tokenizer, root, synset, text): + def joined_tokens(tokens, entity_ranges): + joined = [] + j_token = tokens[0] + start = 0 + for idx, t in enumerate(tokens): + if idx == 0: + continue + if t.startswith('##'): + # continuing same word + j_token += t[2:] + elif any(idx in r and idx-1 in r for r in entity_ranges): + # continuing same multi-word entity + j_token = j_token + " " + t + else: + # new word + joined.append((j_token, start, idx)) + j_token = t + start = idx + if j_token: + joined.append((j_token, start, len(tokens))) + return joined + + def noun_entity_mentions(tokens, entity_mentions): + entity_ranges = [range(em[2][0]-1, em[2][1]) for em in entity_mentions] + joined = joined_tokens(tokens, entity_ranges) + tags = [tag for _, tag in pos_tag([t for t, _, _ in joined])] + noun_ranges = [range(start, end) for idx, (_, start, end) in enumerate(joined) if tags[idx].startswith('NN')] + return [em for idx, em in enumerate(entity_mentions) if entity_ranges[idx] in noun_ranges] + + def token_entity_match(first_token_idx, entity, tokens): + token_idx = first_token_idx + remaining_entity = entity + while remaining_entity: + if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity: + # start of new word + remaining_entity = remaining_entity.lstrip() + if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]: + remaining_entity = remaining_entity[len(tokens[token_idx]):] + token_idx += 1 + else: + break + else: + # continuing same word + if (token_idx < len(tokens) and tokens[token_idx].startswith('##') + and tokens[token_idx][2:] == 
remaining_entity[:len(tokens[token_idx][2:])]): + remaining_entity = remaining_entity[len(tokens[token_idx][2:]):] + token_idx += 1 + else: + break + if remaining_entity or (token_idx < len(tokens) and tokens[token_idx].startswith('##')): + return None + else: + return token_idx - first_token_idx + + def get_rel_label(fst_m, snd_m): + fst_n, _, _ = fst_m + snd_n, _, _ = snd_m + if snd_n in fst_n.descendants: + return 1 + elif fst_n in snd_n.descendants: + return 2 + else: + return 0 + + tokens = tokenizer.tokenize(text) + + i = 0 + entity_mentions = [] + while i < len(tokens): + match = False + for n in PreOrderIter(root): + for syn in synset[n]: + match_length = token_entity_match(i, syn.lower(), tokens) + if match_length is not None: + if any(em[0] == n for em in entity_mentions): + # sentence cannot mention same aspect twice + return None + entity_mentions.append((n, syn, (i + 1, i + match_length))) # + 1 taking into account the [CLS] token + match = True + i += match_length + break + if match: + break + if not match: + i += 1 + + if len(entity_mentions) < 2: + return None + + # filter out non-nouns + entity_mentions = noun_entity_mentions(tokens, entity_mentions) + + if len(entity_mentions) < 2 or len(entity_mentions) > MAX_ENTITIES: + return None + + # mask entity mentions + for _, _, (start, end) in entity_mentions: + tokens[(start-1):end] = ['[MASK]'] * (end-(start-1)) + + entity_mentions = sorted(entity_mentions, key=lambda em: em[2]) + entity_ranges = [em[2] for em in entity_mentions] + labels = {(i, j): get_rel_label(entity_mentions[i], entity_mentions[j]) + for i, j in combinations(range(len(entity_mentions)), 2)} + return tokens, entity_ranges, labels + + +def entity_instances_for_text(tokenizer, product_entities, other_entities, text): + def joined_tokens(tokens, entity_ranges): + joined = [] + j_token = tokens[0] + start = 0 + for idx, t in enumerate(tokens): + if idx == 0: + continue + if t.startswith('##'): + # continuing same word + j_token += t[2:] + elif any(idx in r and idx-1 in r for r in entity_ranges): + # continuing same multi-word entity + j_token = j_token + " " + t + else: + # new word + joined.append((j_token, start, idx)) + j_token = t + start = idx + if j_token: + joined.append((j_token, start, len(tokens))) + return joined + + def noun_entity_mentions(tokens, entity_mentions): + entity_ranges = [range(em[0][0]-1, em[0][1]) for em in entity_mentions] + joined = joined_tokens(tokens, entity_ranges) + tags = [tag for _, tag in pos_tag([t for t, _, _ in joined])] + noun_ranges = [range(start, end) for idx, (_, start, end) in enumerate(joined) if tags[idx].startswith('NN')] + return [em for idx, em in enumerate(entity_mentions) if entity_ranges[idx] in noun_ranges] + + def token_entity_match(first_token_idx, entity, tokens): + token_idx = first_token_idx + remaining_entity = entity + while remaining_entity: + if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity: + # start of new word + remaining_entity = remaining_entity.lstrip() + if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]: + remaining_entity = remaining_entity[len(tokens[token_idx]):] + token_idx += 1 + else: + break + else: + # continuing same word + if (token_idx < len(tokens) and tokens[token_idx].startswith('##') + and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]): + remaining_entity = remaining_entity[len(tokens[token_idx][2:]):] + token_idx += 1 + else: + break + if remaining_entity or (token_idx < 
len(tokens) and tokens[token_idx].startswith('##')): + return None + else: + return token_idx - first_token_idx + + def mask_tokens(tokens, mask_ranges): + for (start, _), m in mask_ranges: + tokens[start-1] = m + return [t for idx, t in enumerate(tokens) + if not any(idx in range(start, end) for (start, end), _ in mask_ranges)] + + + tokens = tokenizer.tokenize(text) + + entity_mentions = [] + product_mentions = [] + for i in range(len(tokens)): + for entity, is_aspect in other_entities: + match_length = token_entity_match(i, entity, tokens) + if match_length is not None: + entity_mentions.append(((i + 1, i + match_length), is_aspect)) # + 1 taking into account the [CLS] token + for entity in product_entities: + match_length = token_entity_match(i, entity, tokens) + if match_length is not None: + product_mentions.append(((i + 1, i + match_length), True)) # + 1 taking into account the [CLS] token + + if len(entity_mentions) != 1: + return None + + # filter out non-nouns + entity_mentions = noun_entity_mentions(tokens, entity_mentions) + # filter intersecting product mentions + product_mentions = list(filter(lambda pm: not any(pm2 != pm and pm2[0][0] <= pm[0][0] and pm2[0][1] >= pm[0][1] for pm2 in product_mentions), product_mentions)) + + if len(entity_mentions) != 1: + return None + + (e_start, e_end), is_aspect = entity_mentions[0] + # mask entity mentions + tokens = mask_tokens(tokens, [((e_start, e_end), '[MASK]')] + [((start, end), 'product') for (start, end), _ in product_mentions]) + + return tokens, e_start, 1 if is_aspect else 0 + + +# ann: EntityAnnotator = EntityAnnotator.load_saved('annotators/acoustic_guitar_annotator.pickle') +# ann.save_annotated_entities('BERT/data/annotated_acoustic_guitar_review_features.tsv', 37000) +ann: EntityAnnotator = EntityAnnotator.load_saved('annotators/backpack_entity_annotator.pickle') +ann.save_annotated_entities('BERT/data/annotated_backpack_review_features.tsv', 37000) +ann: EntityAnnotator = EntityAnnotator.load_saved('annotators/cardigan_entity_annotator.pickle') +ann.save_annotated_entities('BERT/data/annotated_cardigan_review_features.tsv', 37000) +ann: EntityAnnotator = EntityAnnotator.load_saved('annotators/laptop_entity_annotator.pickle') +ann.save_annotated_entities('BERT/data/annotated_laptop_review_features.tsv', 37000) +ann: EntityAnnotator = EntityAnnotator.load_saved('annotators/camera_entity_annotator.pickle') +ann.save_annotated_entities('BERT/data/annotated_camera_review_features.tsv', 37000) +ann: EntityAnnotator = EntityAnnotator.load_saved('annotators/watch_annotator.pickle') +ann.save_annotated_entities('BERT/data/annotated_watch_review_features.tsv', 37000) + + +# ann: EntityAnnotator = EntityAnnotator.load_saved('annotators/acoustic_guitar_annotator.pickle') +# print(ann.synset[ann.root]) diff --git a/ADA-X/server/agent/target_extraction/target_extractor.py b/ADA-X/server/agent/target_extraction/target_extractor.py index f6400461a434a18f7991de502ce0c1a900f524df..81fa1d254da6d2f2018f538569149c2f5859dd9e 100644 --- a/ADA-X/server/agent/target_extraction/target_extractor.py +++ b/ADA-X/server/agent/target_extraction/target_extractor.py @@ -14,8 +14,11 @@ import pickle from agent.target_extraction.product import Product from agent.target_extraction.BERT.entity_extractor.entity_dataset import EntityDataset from agent.target_extraction.BERT.entity_extractor.bert_entity_extractor import BertEntityExtractor -from agent.target_extraction.BERT.relation_extractor.pair_rel_dataset import PairRelDataset +# from 
agent.target_extraction.BERT.relation_extractor.pair_rel_dataset import PairRelDataset +from agent.target_extraction.BERT.relation_extractor.rel_dataset import PairRelDataset from agent.target_extraction.BERT.relation_extractor.bert_rel_extractor import BertRelExtractor +from agent.target_extraction.BERT.relation_extractor.relbertnet import TRAINED_WEIGHTS, MAX_ENTITIES +from transformers import BertTokenizer from pathos.multiprocessing import ProcessingPool as Pool import itertools from time import time @@ -25,9 +28,10 @@ np.set_printoptions(precision=4, threshold=np.inf, suppress=True) stop_words = stopwords.words('english') wnl = WordNetLemmatizer() sentiment_lexicon = pd.read_csv('data/NRC-Sentiment-Lexicon-Wordlevel-v0.92.tsv', sep='\t', index_col=0) -entity_extractor_path = 'BERT/entity_extractor/entity_extractor_five_products_2_epoch_3.pt' -rel_extractor_path = 'BERT/relation_extractor/rel_extractor_five_products_2_epoch_3.pt' +entity_extractor_path = 'BERT/entity_extractor/entity_extractor_multi_epoch_1.pt' # entity_extractor_five_products_2_epoch_3.pt' +rel_extractor_path = 'BERT/relation_extractor/multi_extractor_5_products_epoch_3.pt' # rel_extractor_five_products_2_epoch_3.pt' pool = Pool(4) +tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS) def ngrams(text, phraser): @@ -113,6 +117,91 @@ def entity_mentions_in_text(text, phrase, ngrams, entities): return None +def entity_mentions_in_text_with_tokenizer(text, tokenizer, entities): + def joined_tokens(tokens, entity_ranges): + joined = [] + j_token = tokens[0] + start = 0 + for idx, t in enumerate(tokens): + if idx == 0: + continue + if t.startswith('##'): + # continuing same word + j_token += t[2:] + elif any(idx in r and idx-1 in r for r in entity_ranges): + # continuing same multi-word entity + j_token = j_token + " " + t + else: + # new word + joined.append((j_token, start, idx)) + j_token = t + start = idx + if j_token: + joined.append((j_token, start, len(tokens))) + return joined + + def noun_entity_mentions(tokens, entity_mentions): + entity_ranges = [range(em[0][0]-1, em[0][1]) for em in entity_mentions] + joined = joined_tokens(tokens, entity_ranges) + tags = [tag for _, tag in pos_tag([t for t, _, _ in joined])] + noun_ranges = [range(start, end) for idx, (_, start, end) in enumerate(joined) if tags[idx].startswith('NN')] + return [em for idx, em in enumerate(entity_mentions) if entity_ranges[idx] in noun_ranges] + + def token_entity_match(first_token_idx, entity, tokens): + token_idx = first_token_idx + remaining_entity = entity + while remaining_entity: + if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity: + # start of new word + remaining_entity = remaining_entity.lstrip() + if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]: + remaining_entity = remaining_entity[len(tokens[token_idx]):] + token_idx += 1 + else: + break + else: + # continuing same word + if (token_idx < len(tokens) and tokens[token_idx].startswith('##') + and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]): + remaining_entity = remaining_entity[len(tokens[token_idx][2:]):] + token_idx += 1 + else: + break + if remaining_entity or (token_idx < len(tokens) and tokens[token_idx].startswith('##')): + return None + else: + return token_idx - first_token_idx + + tokens = tokenizer.tokenize(text) + + entity_mentions = [] + for i in range(len(tokens)): + for entity in entities: + match_length = token_entity_match(i, entity.lower(), tokens) + if 
match_length is not None: + entity_mentions.append(((i + 1, i + match_length), entity.lower())) # + 1 taking into account the [CLS] token + + if len(entity_mentions) == 0: + return None + + # filter out non-nouns + entity_mentions = noun_entity_mentions(tokens, entity_mentions) + + if len(entity_mentions) == 0: + return None + + entity_mentions = sorted(entity_mentions, key=lambda em: em[0]) + + # mask entity mentions + masked_tokens = [] + for (start, end), _ in entity_mentions: + masked = tokens.copy() + masked[(start - 1):end] = ['[MASK]'] * (end - (start - 1)) + masked_tokens.append(masked) + + return [(masked_tokens[idx], e, er) for idx, (er, e) in enumerate(entity_mentions)] + + def pair_relations_for_text(text, ngrams, aspects, syn_dict): def overlapping_terms(ts, t): @@ -136,6 +225,103 @@ def pair_relations_for_text(text, ngrams, aspects, syn_dict): return (text, [{'em1Text': found_aspects[0], 'em2Text': found_aspects[1]}]) if len(found_aspects) == 2 else None +def pair_relations_for_text_with_tokenizer(text, tokenizer, aspects, syn_dict): + def joined_tokens(tokens, entity_ranges): + joined = [] + j_token = tokens[0] + start = 0 + for idx, t in enumerate(tokens): + if idx == 0: + continue + if t.startswith('##'): + # continuing same word + j_token += t[2:] + elif any(idx in r and idx-1 in r for r in entity_ranges): + # continuing same multi-word entity + j_token = j_token + " " + t + else: + # new word + joined.append((j_token, start, idx)) + j_token = t + start = idx + if j_token: + joined.append((j_token, start, len(tokens))) + return joined + + def noun_entity_mentions(tokens, entity_mentions): + entity_ranges = [range(em[1][0]-1, em[1][1]) for em in entity_mentions] + joined = joined_tokens(tokens, entity_ranges) + tags = [tag for _, tag in pos_tag([t for t, _, _ in joined])] + noun_ranges = [range(start, end) for idx, (_, start, end) in enumerate(joined) if tags[idx].startswith('NN')] + return [em for idx, em in enumerate(entity_mentions) if entity_ranges[idx] in noun_ranges] + + def token_entity_match(first_token_idx, entity, tokens): + token_idx = first_token_idx + remaining_entity = entity + while remaining_entity: + if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity: + # start of new word + remaining_entity = remaining_entity.lstrip() + if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]: + remaining_entity = remaining_entity[len(tokens[token_idx]):] + token_idx += 1 + else: + break + else: + # continuing same word + if (token_idx < len(tokens) and tokens[token_idx].startswith('##') + and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]): + remaining_entity = remaining_entity[len(tokens[token_idx][2:]):] + token_idx += 1 + else: + break + if remaining_entity: + return None + else: + return token_idx - first_token_idx + + tokens = tokenizer.tokenize(text) + + i = 0 + entity_mentions = [] + while i < len(tokens): + match = False + for a_idx, a in enumerate(aspects): + for syn in syn_dict[a]: + match_length = token_entity_match(i, syn.lower(), tokens) + if match_length is not None: + if any(em[0] == a_idx for em in entity_mentions): + # sentence cannot mention same aspect twice + return None + entity_mentions.append((a_idx, (i + 1, i + match_length))) # + 1 taking into account the [CLS] token + match = True + i += match_length + break + if match: + break + if not match: + i += 1 + + if len(entity_mentions) < 2: + return None + + # filter out non-nouns + entity_mentions = 
noun_entity_mentions(tokens, entity_mentions) + + if len(entity_mentions) < 2 or len(entity_mentions) > MAX_ENTITIES: + return None + + # mask entity mentions + for _, (start, end) in entity_mentions: + tokens[(start - 1):end] = ['[MASK]'] * (end - (start - 1)) + + entity_mentions = sorted(entity_mentions, key=lambda em: em[1]) + entity_ranges = [em[1] for em in entity_mentions] + entity_labels = [em[0] for em in entity_mentions] + + return tokens, entity_ranges, entity_labels + + class TargetExtractor: N_ASPECTS = 100 @@ -148,7 +334,7 @@ class TargetExtractor: MAX_DEPTH = 2 # word2vec - MIN_TERM_COUNT = 0 + MIN_TERM_COUNT = 100 SYNONYM_SIMILARITY = 0.21 PRODUCT_ABSORPTION_MULT = 3 # see product_absorption() WV_SIZE = 300 @@ -166,8 +352,9 @@ class TargetExtractor: print('tokenizing phrases...') # tokenize and normalize phrases - texts = TargetExtractor.obtain_texts(file_path, text_column, n=500000) + texts = TargetExtractor.obtain_texts(file_path, text_column, n=50000) self.sentences = list(itertools.chain.from_iterable(pool.map(sent_tokenize, texts))) + self.sentences = list(itertools.chain.from_iterable(pool.map(str.splitlines, self.sentences))) self.sentences = pool.map(lambda s: s.replace('_', ' ').lower(), self.sentences) self.phrases = pool.map(word_tokenize, self.sentences) @@ -188,7 +375,9 @@ class TargetExtractor: print('mining aspects...') # mine aspects - self.aspects, self.counts = self.get_aspects(self.counter) + self.aspects, self.counts, min_count = self.get_aspects(self.counter) + + print(min_count) t_feature = time() print('Feature extraction took {} seconds'.format(t_feature - t_noun)) @@ -206,12 +395,15 @@ class TargetExtractor: self.aspects = [aspect for aspect in self.aspects if aspect in self.syn_dict.keys()] self.counts = {aspect: sum(self.counts[syn] for syn in self.syn_dict[aspect]) for aspect in self.aspects} self.aspects = sorted(self.aspects, key=self.counts.get, reverse=True) + print(self.aspects) + print(self.counts) + print(self.syn_dict) t_syn = time() print('Synonym extraction took {} seconds'.format(t_syn - t_feature)) print('extracting relatedness matrix...') - self.relatedness_matrix = self.get_bert_relations() + self.relatedness_matrix = self.get_bert_relations2() print('extracting aspect tree...') self.tree = self.get_product_tree() @@ -264,6 +456,7 @@ class TargetExtractor: m = prob_matrix / self.get_aspect_counts() # scale rows by aspect counts non_features = self.product_absorption(m) for idx in non_features: + print(self.aspects[idx]) # absorb probabilities prob_matrix[0] += prob_matrix[idx] prob_matrix[:, 0] += prob_matrix[:, idx] @@ -280,6 +473,43 @@ class TargetExtractor: return self.relatedness_matrix + def get_bert_relations2(self): + print(' select phrases for relation extraction...') + instances = filter(lambda i: i is not None, pool.map(pair_relations_for_text_with_tokenizer, + self.sentences, + itertools.repeat(tokenizer, len(self.sentences)), + itertools.repeat(self.aspects, len(self.sentences)), + itertools.repeat(self.syn_dict, len(self.sentences)))) + df = pd.DataFrame(instances, columns=['tokens', 'entity_ranges', 'entity_labels']) + + print(' extracting relations with BERT...') + dataset = PairRelDataset.for_extraction(df) + bert_extractor = BertRelExtractor.load_saved(rel_extractor_path) + prob_matrix, count_matrix = bert_extractor.extract_relations2(len(self.aspects), dataset) + + # absorb non-features to product + m = prob_matrix / self.get_aspect_counts() # scale rows by aspect counts + non_features = self.product_absorption(m) + 
for idx in non_features: + print(self.aspects[idx]) + # absorb probabilities + prob_matrix[idx][0] = 0 + prob_matrix[0][idx] = 0 + prob_matrix[0] += prob_matrix[idx] + prob_matrix[:, 0] += prob_matrix[:, idx] + # absorb synonyms and counts + self.syn_dict[self.aspects[0]].update(self.syn_dict[self.aspects[idx]]) + self.counts[self.aspects[0]] += self.counts[self.aspects[idx]] + del self.syn_dict[self.aspects[idx]] + del self.counts[self.aspects[idx]] + prob_matrix = np.delete(np.delete(prob_matrix, non_features, axis=0), non_features, axis=1) + self.aspects = [a for idx, a in enumerate(self.aspects) if idx not in non_features] + + # recalculate relatedness matrix + relatedness_matrix = prob_matrix / self.get_aspect_counts() + + return relatedness_matrix + def get_aspect_counts(self): return np.array([self.counts[aspect] for aspect in self.aspects]) @@ -300,37 +530,40 @@ class TargetExtractor: terms = [term for term, count in term_counts] print(' preparing entity texts for BERT...') - entity_texts = [t for t in pool.map(entity_mentions_in_text, self.sentences, self.phrases, self.ngram_phrases, - itertools.repeat(terms, len(self.sentences))) - if t is not None] - df = pd.DataFrame(entity_texts, columns=['sentText', 'entityMentions']) + instances = [instance for instances in + filter(lambda i: i is not None, pool.map(entity_mentions_in_text_with_tokenizer, + self.sentences[:100000], + itertools.repeat(tokenizer, len(self.sentences[:100000])), + itertools.repeat(terms, len(self.sentences[:100000])))) + for instance in instances] + df = pd.DataFrame(instances, columns=['tokens', 'entity', 'entity_range']) print(' extracting entities with BERT...') - dataset = EntityDataset.from_df(df) + dataset = EntityDataset.for_extraction(df) entity_extractor = BertEntityExtractor.load_saved(entity_extractor_path) probs = entity_extractor.extract_entity_probabilities(terms, dataset=dataset) - aspects = [term for term in terms if probs[term] is not None and probs[term] >= TargetExtractor.ENTITY_PROB_THRESHOLD] + l = list(sorted(probs.items(), key=lambda p: p[1] if p[1] else -1, reverse=True)) + for e, s in l: + print('{}: {}'.format(e, s)) - # bring product to front of list - if self.product in aspects: - aspects.remove(self.product) - aspects.insert(0, self.product) + aspects = [term for term in terms if probs[term] is not None and probs[term] >= TargetExtractor.ENTITY_PROB_THRESHOLD] + counts = {term: count for term, count in term_counts if term in aspects} - return aspects, {term: count for term, count in term_counts if term in aspects} + return aspects, counts, counts[aspects[-1]] def get_word2vec_model(self, size, window, min_count): model = Word2Vec(self.ngram_phrases, size=size, window=window, min_count=min_count).wv return model def save(self): - f = open('extractors/{}_extractor_f.pickle'.format(self.product), 'wb') + f = open('extractors/{}_extractor_f2.pickle'.format(self.product), 'wb') pickle.dump(self, f) f.close() @staticmethod def load_saved(product): - f = open('extractors/{}_extractor_f.pickle'.format(product), 'rb') + f = open('extractors/{}_extractor_f2.pickle'.format(product), 'rb') extractor = pickle.load(f) f.close() return extractor @@ -348,32 +581,69 @@ class TargetExtractor: root = Node(self.aspects[0]) root.idx = 0 - deps = {idx: self.aspect_dependence_with_strength(idx) for idx in range(1, len(self.aspects))} + dependencies = {idx: self.aspect_dependence_with_strength(idx) for idx in range(1, len(self.aspects))} - for no_dep_idx in {idx for idx, dep in deps.items() if dep is None}: 
+ for no_dep_idx in {idx for idx, dep in dependencies.items() if dep is None}: node = Node(self.aspects[no_dep_idx], parent=root) node.idx = no_dep_idx - del deps[no_dep_idx] + del dependencies[no_dep_idx] - sorted_deps = sorted(deps.items(), key=lambda x: x[1][1], reverse=True) + n_dependants = {idx: sum(1 for _, (dep_idx, _) in dependencies.items() if dep_idx == idx) + for idx in dependencies.keys()} + dep_groups = [filter(lambda idx: n_dependants[idx] == n, dependencies.keys()) + for n in reversed(range(max(n_dependants.values()) + 1))] - for idx, (dep_idx, _) in sorted_deps: - # print(self.aspects[idx], self.aspects[dep_idx]) - if any(n for n in root.descendants if n.idx == idx): - continue + for group in dep_groups: + sorted_deps = sorted(group, key=lambda idx: dependencies[idx][1], reverse=True) - dep_n = next((n for n in root.descendants if n.idx == dep_idx), None) - if dep_n: - if dep_n.depth < 2: - n = Node(self.aspects[idx], parent=dep_n) + for idx in sorted_deps: + dep_idx = dependencies[idx][0] + + if any(n for n in root.descendants if n.idx == idx): + continue + + dep_n = next((n for n in root.descendants if n.idx == dep_idx), None) + if dep_n: + if dep_n.depth < 2: + n = Node(self.aspects[idx], parent=dep_n) + else: + n = Node(self.aspects[idx], parent=dep_n.parent) else: - n = Node(self.aspects[idx], parent=dep_n.parent) + dep_n = Node(self.aspects[dep_idx], parent=root) + dep_n.idx = dep_idx + n = Node(self.aspects[idx], parent=dep_n) + n.idx = idx + + return root + + def get_product_tree_no_limit(self): + root = Node(self.aspects[0]) + root.idx = 0 + + dependencies = {idx: self.aspect_dependence_with_strength(idx) for idx in range(1, len(self.aspects))} + + for no_dep_idx in {idx for idx, dep in dependencies.items() if dep is None}: + node = Node(self.aspects[no_dep_idx], parent=root) + node.idx = no_dep_idx + del dependencies[no_dep_idx] + + sorted_deps = sorted(dependencies.items(), key=lambda dep: dep[1][1], reverse=True) + unassigned = [] + + for idx, (dep_idx, _) in sorted_deps: + n = next((n for n in unassigned if n.idx == idx), None) + if n is None: + n = Node(self.aspects[idx]) + n.idx = idx else: - print(self.aspects[idx], self.aspects[dep_idx]) - dep_n = Node(self.aspects[dep_idx], parent=root) + unassigned.remove(n) + + dep_n = next((n for n in root.descendants if n.idx == dep_idx), None) + if dep_n is None: + dep_n = Node(self.aspects[dep_idx]) dep_n.idx = dep_idx - n = Node(self.aspects[idx], parent=dep_n) - n.idx = idx + unassigned.append(dep_n) + n.parent = dep_n return root @@ -434,9 +704,6 @@ class TargetExtractor: c2 = np.delete(m[:, idx2], [idx1, idx2]).reshape(1, -1) return cosine_similarity((c1 - r1), (c2 - r2))[0][0] - def get_inverse_co_occurrence(self, m, idx1, idx2): - return (self.counts[self.aspects[idx1]] * self.counts[self.aspects[idx2]]) / (m[idx1][idx2] * m[idx2][idx1]) - class Synset: @@ -480,3 +747,7 @@ class Synset: if w in group: return group return None + + +extr: TargetExtractor = TargetExtractor.load_saved('necklace') +extr.get_aspects(extr.counter)
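A minimal, standalone sketch (not part of the patch) of the two entity-pooling strategies this diff introduces, using small stand-in dimensions (4/6/3 in place of HIDDEN_OUTPUT_FEATURES=768, MAX_SEQ_LEN=128, MAX_ENTITIES=3). The reworked EntityBertNet.forward selects the hidden state at the single masked-entity index with advanced indexing, while RelBertNet.entity_combinations max-pools each entity's clamped token range (as built by indices_for_instance) via torch.gather and keeps only the slots flagged by entity_mask.

import torch

HIDDEN, SEQ_LEN, MAX_ENTITIES = 4, 6, 3   # stand-ins for 768 / 128 / 3

def single_entity_pool(bert_output, entity_indices):
    # bert_output: (batch, SEQ_LEN, HIDDEN); entity_indices: (batch,)
    # picks the hidden state at each instance's masked-entity position
    return bert_output[torch.arange(0, bert_output.shape[0]), entity_indices]

def entity_combinations(bert_output, entity_indices, entity_mask):
    # entity_indices: (batch, MAX_ENTITIES, SEQ_LEN), each row a clamped token range
    # entity_mask:    (batch, MAX_ENTITIES), True where the entity slot is used
    bert_output_exp = bert_output.unsqueeze(1).repeat(1, MAX_ENTITIES, 1, 1)
    indices_exp = entity_indices.unsqueeze(3).repeat(1, 1, 1, HIDDEN)
    outputs = torch.gather(bert_output_exp, dim=2, index=indices_exp)
    pooled, _ = torch.max(outputs, dim=2)           # (batch, MAX_ENTITIES, HIDDEN)
    return pooled[entity_mask]                      # (total_entities, HIDDEN)

bert_output = torch.randn(2, SEQ_LEN, HIDDEN)
print(single_entity_pool(bert_output, torch.tensor([1, 3])).shape)   # torch.Size([2, 4])

entity_indices = torch.tensor([
    [[1, 1, 2, 2, 2, 2], [0] * SEQ_LEN, [0] * SEQ_LEN],       # one entity spanning tokens 1-2
    [[1, 1, 1, 1, 1, 1], [3, 3, 3, 3, 4, 4], [0] * SEQ_LEN],  # entities at token 1 and tokens 3-4
])
entity_mask = torch.tensor([[True, False, False], [True, True, False]])
print(entity_combinations(bert_output, entity_indices, entity_mask).shape)  # torch.Size([3, 4])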
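For reference, a toy illustration of the score aggregation performed by extract_relations2 and get_bert_relations2: each entity pair adds its forward score to prob_matrix[snd][fst] and its backward score to prob_matrix[fst][snd], the count matrix tracks how often each pair co-occurs, and the relatedness matrix is the accumulated probability matrix divided by the aspect counts as in the patch. The pair indices, scores, and counts below are hypothetical.

import numpy as np

n_aspects = 3
prob_matrix = np.zeros((n_aspects, n_aspects))
count_matrix = np.zeros((n_aspects, n_aspects))

# each tuple: (fst_idx, snd_idx, forward_score, backward_score) for one entity pair
pair_scores = [(0, 1, 0.7, 0.1), (1, 2, 0.2, 0.6), (0, 1, 0.5, 0.3)]

for fst, snd, forward_score, backward_score in pair_scores:
    prob_matrix[snd][fst] += forward_score     # fst hasFeature snd
    prob_matrix[fst][snd] += backward_score    # snd hasFeature fst
    count_matrix[snd][fst] += 1
    count_matrix[fst][snd] += 1

aspect_counts = np.array([10, 5, 2])                 # hypothetical mention counts per aspect
relatedness_matrix = prob_matrix / aspect_counts     # scaled as in get_bert_relations2
print(relatedness_matrix)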