# entity_dataset.py
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
import pandas as pd
import numpy as np
from ast import literal_eval
import os.path
from agent.target_extraction.BERT.relation_extractor.pairbertnet import TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES

# Maximum number of wordpiece tokens per encoded sequence.
MAX_SEQ_LEN = 128
# BERT's reserved mask token literal.
MASK_TOKEN = '[MASK]'

# Shared tokenizer, loaded once from the project's trained BERT weights.
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)

class EntityDataset(Dataset):
    """Dataset of entity-annotated token sequences backed by a pandas DataFrame.

    In training mode each row provides a stringified token list, an entity
    index and a label; in extraction (production) mode rows provide the token
    list directly, an entity index and the entity string itself.
    """

    def __init__(self, df, training=True, size=None):
        """
        :param df: pandas DataFrame with one instance per row
        :param training: True if rows carry a 'label' column, False for extraction
        :param size: optional cap; if smaller than df, sample that many rows
                     without replacement
        """
        self.df = df
        self.training = training

        # sample data if a size is specified
        if size is not None and size < len(self):
            self.df = self.df.sample(size, replace=False)

    @staticmethod
    def for_extraction(df):
        """Build a non-training dataset (tokens already lists) for inference."""
        dataset = EntityDataset(df, training=False)
        print('Obtained dataset of size', len(dataset))
        return dataset

    @staticmethod
    def from_file(file_name, valid_frac=None, size=None):
        """Load a training dataset from a file in the sibling data directory.

        :param file_name: name of a .json (JSON-lines) or .tsv file under ../data/
        :param valid_frac: optional fraction of rows held out as a validation set
        :param size: optional sample size forwarded to the constructor
        :return: tuple (train_dataset, validation_dataset or None)
        :raises AttributeError: if the file extension is not recognized
        """
        path = os.path.dirname(__file__) + '/../data/' + file_name
        # context manager ensures the handle is closed even on the raise below
        # (the previous version leaked the file descriptor)
        with open(path) as f:
            if file_name.endswith('.json'):
                dataset = EntityDataset(pd.read_json(f, lines=True), size=size)
            elif file_name.endswith('.tsv'):
                dataset = EntityDataset(pd.read_csv(f, sep='\t', error_bad_lines=False), size=size)
            else:
                raise AttributeError('Could not recognize file type')

        if valid_frac is None:
            print('Obtained dataset of size', len(dataset))
            return dataset, None
        else:
            # split rows into a train prefix and validation suffix
            split_idx = int(len(dataset) * (1 - valid_frac))
            dataset.df, valid_df = np.split(dataset.df, [split_idx], axis=0)
            validset = EntityDataset(valid_df)
            print('Obtained train set of size', len(dataset), 'and validation set of size', len(validset))
            return dataset, validset

    def instance_from_row(self, row):
        """Convert one DataFrame row into an EntityInstance."""
        if self.training:
            # training files store tokens as the string repr of a list
            return EntityInstance(literal_eval(row['tokens']),
                                  row['entity_idx'],
                                  label=row['label'])
        else:
            return EntityInstance(row['tokens'],
                                  row['entity_idx'],
                                  entity=row['entity'])

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        return self.instance_from_row(self.df.iloc[idx])

class EntityInstance:
    """A single entity example: a token sequence, the index of the entity
    within it, and either a training label or (at extraction time) the
    entity string itself.
    """

    def __init__(self, tokens, entity_idx, label=None, entity=None):
        # token sequence (list of tokens; parsed from string repr when training)
        self.tokens = tokens
        # index of the entity within the token sequence
        self.entity_idx = entity_idx
        # classification label — training mode only, otherwise None
        self.label = label
        # entity string — extraction mode only, otherwise None
        self.entity = entity


def generate_batch(instances: [EntityInstance]):
    """Collate EntityInstances into padded training tensors.

    Encodes the token sequences (adding special tokens, truncating/padding to
    MAX_SEQ_LEN) and returns (input_ids, attention_mask, entity_indices,
    labels) as PyTorch tensors.
    """
    token_lists = [inst.tokens for inst in instances]
    encoded = tokenizer.batch_encode_plus(token_lists, add_special_tokens=True,
                                          max_length=MAX_SEQ_LEN, pad_to_max_length=True,
                                          is_pretokenized=True, return_tensors='pt')

    entity_indices = torch.tensor([inst.entity_idx for inst in instances])
    labels = torch.tensor([inst.label for inst in instances])

    return encoded['input_ids'], encoded['attention_mask'], entity_indices, labels


def generate_production_batch(instances: [EntityInstance]):
    """Collate EntityInstances for inference.

    Same encoding as generate_batch, but instead of labels the original
    instances are returned in the final slot so predictions can be matched
    back to their entities.
    """
    encoded = tokenizer.batch_encode_plus([inst.tokens for inst in instances],
                                          add_special_tokens=True, max_length=MAX_SEQ_LEN,
                                          pad_to_max_length=True, is_pretokenized=True,
                                          return_tensors='pt')
    ids = encoded['input_ids']
    mask = encoded['attention_mask']
    idx_tensor = torch.tensor([inst.entity_idx for inst in instances])

    return ids, mask, idx_tensor, instances


# def indices_for_entity_ranges(ranges):
#     max_e_len = max(end - start for start, end in ranges)
#     indices = torch.tensor([[[min(t, end)] * HIDDEN_OUTPUT_FEATURES
#                              for t in range(start, start + max_e_len + 1)]
#                             for start, end in ranges])
#     return indices