import os.path
from ast import literal_eval
from typing import List

import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer

from agent.target_extraction.BERT.relation_extractor.pairbertnet import (
    TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES)

MAX_SEQ_LEN = 128
MASK_TOKEN = '[MASK]'

tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)


class EntityDataset(Dataset):

    def __init__(self, df, training=True, size=None):
        self.df = df
        self.training = training
        # subsample the data if a size is specified
        if size is not None and size < len(self):
            self.df = self.df.sample(size, replace=False)

    @staticmethod
    def for_extraction(df):
        dataset = EntityDataset(df, training=False)
        print('Obtained dataset of size', len(dataset))
        return dataset

    @staticmethod
    def from_file(file_name, valid_frac=None, size=None):
        path = os.path.join(os.path.dirname(__file__), '..', 'data', file_name)
        if file_name.endswith('.json'):
            dataset = EntityDataset(pd.read_json(path, lines=True), size=size)
        elif file_name.endswith('.tsv'):
            dataset = EntityDataset(pd.read_csv(path, sep='\t', on_bad_lines='skip'),
                                    size=size)
        else:
            raise AttributeError('Could not recognize file type')

        if valid_frac is None:
            print('Obtained dataset of size', len(dataset))
            return dataset, None
        else:
            # hold out the last valid_frac of rows as a validation set
            split_idx = int(len(dataset) * (1 - valid_frac))
            valid_df = dataset.df.iloc[split_idx:]
            dataset.df = dataset.df.iloc[:split_idx]
            validset = EntityDataset(valid_df)
            print('Obtained train set of size', len(dataset),
                  'and validation set of size', len(validset))
            return dataset, validset

    def instance_from_row(self, row):
        if self.training:
            # training files store the token list as a string literal
            return EntityInstance(literal_eval(row['tokens']), row['entity_idx'],
                                  label=row['label'])
        else:
            return EntityInstance(row['tokens'], row['entity_idx'],
                                  entity=row['entity'])

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        return self.instance_from_row(self.df.iloc[idx])


class EntityInstance:

    def __init__(self, tokens, entity_idx, label=None, entity=None):
        self.tokens = tokens
        self.entity_idx = entity_idx
        self.label = label    # set for training instances only
        self.entity = entity  # set for extraction (production) instances only


def _encode(instances: List[EntityInstance]):
    # tokenize the pre-split tokens, add special tokens, and pad/truncate
    # every sequence to MAX_SEQ_LEN
    encoded = tokenizer.batch_encode_plus([instance.tokens for instance in instances],
                                          add_special_tokens=True,
                                          max_length=MAX_SEQ_LEN,
                                          padding='max_length',
                                          truncation=True,
                                          is_split_into_words=True,
                                          return_tensors='pt')
    return encoded['input_ids'], encoded['attention_mask']


def generate_batch(instances: List[EntityInstance]):
    input_ids, attn_mask = _encode(instances)
    entity_indices = torch.tensor([instance.entity_idx for instance in instances])
    labels = torch.tensor([instance.label for instance in instances])
    return input_ids, attn_mask, entity_indices, labels


def generate_production_batch(instances: List[EntityInstance]):
    input_ids, attn_mask = _encode(instances)
    entity_indices = torch.tensor([instance.entity_idx for instance in instances])
    return input_ids, attn_mask, entity_indices, instances


# def indices_for_entity_ranges(ranges):
#     max_e_len = max(end - start for start, end in ranges)
#     indices = torch.tensor([[[min(t, end)] * HIDDEN_OUTPUT_FEATURES
#                              for t in range(start, start + max_e_len + 1)]
#                             for start, end in ranges])
#     return indices
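

# ---------------------------------------------------------------------------
# Usage sketch: a minimal example of how this module plugs into a torch
# DataLoader, with `generate_batch` as the collate_fn so each batch arrives
# tokenized, padded, and converted to tensors. The file name
# 'entity_instances.tsv' and the batch size are hypothetical, chosen for
# illustration only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    # hypothetical data file under ../data/; substitute a real training file
    train_set, valid_set = EntityDataset.from_file('entity_instances.tsv',
                                                   valid_frac=0.1)
    loader = DataLoader(train_set, batch_size=32, shuffle=True,
                        collate_fn=generate_batch)
    for input_ids, attn_mask, entity_indices, labels in loader:
        # input_ids and attn_mask are (batch_size, MAX_SEQ_LEN) tensors;
        # entity_indices carries each instance's entity_idx, labels its label
        print(input_ids.shape, attn_mask.shape, entity_indices.shape, labels.shape)
        break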