rel_dataset.py
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
import pandas as pd
import numpy as np
from ast import literal_eval
from agent.target_extraction.BERT.relation_extractor.relbertnet import TRAINED_WEIGHTS, MAX_SEQ_LEN, MAX_ENTITIES
import os

MASK_TOKEN = '[MASK]'
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)


def generate_batch(batch):
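    """Collate function for training batches: encode the pre-tokenized instances
    with the shared BERT tokenizer (padded/truncated to MAX_SEQ_LEN) and return
    input ids, attention masks, per-entity index matrices, an entity validity mask
    and the flattened entity labels."""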
    encoded = tokenizer.batch_encode_plus([instance.tokens for instance in batch], add_special_tokens=True,
                                          max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
                                          return_tensors='pt')
    input_ids = encoded['input_ids']
    attn_mask = encoded['attention_mask']

    entity_indices = torch.tensor(list(map(indices_for_instance, batch)))
    entity_mask = torch.tensor([[n < instance.get_count() for n in range(MAX_ENTITIES)] for instance in batch])
    labels = torch.tensor([e.label for instance in batch for e in instance.entities])

    return input_ids, attn_mask, entity_indices, entity_mask, labels


def generate_production_batch(batch):
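    """Collate function for inference: same encoding as generate_batch, but returns
    the instances themselves in place of labels so predictions can be mapped back
    to their entities."""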
    encoded = tokenizer.batch_encode_plus([instance.tokens for instance in batch], add_special_tokens=True,
                                          max_length=MAX_SEQ_LEN, pad_to_max_length=True, is_pretokenized=True,
                                          return_tensors='pt')
    input_ids = encoded['input_ids']
    attn_mask = encoded['attention_mask']

    entity_indices = torch.tensor(list(map(indices_for_instance, batch)))
    entity_mask = torch.tensor([[n < instance.get_count() for n in range(MAX_ENTITIES)] for instance in batch])

    return input_ids, attn_mask, entity_indices, entity_mask, batch


def indices_for_instance(instance):
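    """Build a MAX_ENTITIES x MAX_SEQ_LEN index matrix for one instance. For entity
    n, position i is clamped into the entity's token range, so gathering hidden
    states with these indices presumably lets the relation model pool over each
    entity's span; rows for entities beyond the instance's count are all zeros."""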
    indices = [[min(instance.entities[n].rng[1], max(instance.entities[n].rng[0], i))
                for i in range(MAX_SEQ_LEN)]
               if n < len(instance.entities) else [0] * MAX_SEQ_LEN
               for n in range(MAX_ENTITIES)]
    return indices


class PairRelDataset(Dataset):
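    """Dataset of relation-extraction instances backed by a pandas DataFrame. In
    training mode rows hold stringified token lists, entity ranges and labels; in
    extraction mode rows already hold parsed values plus entity labels."""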

    def __init__(self, df, training=True, size=None):
        self.df = df
        self.training = training
        # sample data if a size is specified
        if size is not None and size < len(self):
            self.df = self.df.sample(size, replace=False)

    @staticmethod
    def for_extraction(df):
        dataset = PairRelDataset(df, training=False)
        print('Obtained dataset of size', len(dataset))
        return dataset

    @staticmethod
    def from_file(file_name, valid_frac=None, size=None):
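        """Load a tab-separated dataset from the sibling data directory, optionally
        subsampled to `size`, and optionally split off a validation set of fraction
        `valid_frac`."""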
        path = os.path.join(os.path.dirname(__file__), '..', 'data', file_name)
        with open(path) as f:
            dataset = PairRelDataset(pd.read_csv(f, sep='\t', error_bad_lines=False), size=size)

        if valid_frac is None:
            print('Obtained dataset of size', len(dataset))
            return dataset, None
        else:
            split_idx = int(len(dataset) * (1 - valid_frac))
            dataset.df, valid_df = np.split(dataset.df, [split_idx], axis=0)
            validset = PairRelDataset(valid_df)
            print('Obtained train set of size', len(dataset), 'and validation set of size', len(validset))
            return dataset, validset

    def instance_from_row(self, row):
        if self.training:
            return RelInstance(literal_eval(row['tokens']),
                               literal_eval(row['entity_ranges']),
                               true_labels=literal_eval(row['labels']))
        else:
            return RelInstance(row['tokens'],
                               row['entity_ranges'],
                               entity_labels=row['entity_labels'])

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        return self.instance_from_row(self.df.iloc[idx])


class RelInstance:
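    """A single sentence, as BERT tokens, together with its candidate entities,
    each described by a token range and optionally a label, an identifier and the
    original text."""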

    def __init__(self, tokens, entity_ranges, true_labels=None, entity_labels=None, entity_texts=None):
        self.tokens = tokens
        self.entities = [Entity(rng,
                                label=(true_labels[n] if true_labels else None),
                                idx=(entity_labels[n] if entity_labels else None),
                                text=(entity_texts[n] if entity_texts else None))
                         for n, rng in enumerate(entity_ranges)]

    def get_count(self):
        return len(self.entities)

    def print_results_for_labels(self, labels):
        assert len(labels) == len(self.entities)
        label_map = ['not an aspect', 'aspect', 'sub-feature']
        for e, l in zip(self.entities, labels):
            print('{}: {}'.format(e.text, label_map[l]))

    @staticmethod
    def from_sentence(text, entities):
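        """Tokenize a raw sentence, locate each candidate entity's wordpiece span,
        mask the mentions and return a RelInstance, or None if no entity or more
        than three entities are found."""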
        def token_entity_match(first_token_idx, entity, tokens):
            token_idx = first_token_idx
            remaining_entity = entity
            while remaining_entity:
                if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity:
                    # start of new word
                    remaining_entity = remaining_entity.lstrip()
                    if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]:
                        remaining_entity = remaining_entity[len(tokens[token_idx]):]
                        token_idx += 1
                    else:
                        break
                else:
                    # continuing same word
                    if (token_idx < len(tokens) and tokens[token_idx].startswith('##')
                            and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]):
                        remaining_entity = remaining_entity[len(tokens[token_idx][2:]):]
                        token_idx += 1
                    else:
                        break
            if remaining_entity or (token_idx < len(tokens) and tokens[token_idx].startswith('##')):
                return None
            else:
                return token_idx - first_token_idx

        tokens = tokenizer.tokenize(text)

        i = 0
        entity_ranges = []
        while i < len(tokens):
            match = False
            # check for aspects
            for e in entities:
                match_length = token_entity_match(i, e.lower(), tokens)
                if match_length is not None:
                    entity_ranges.append((e, (i + 1, i + match_length)))  # + 1 taking into account the [CLS] token
                    match = True
                    i += match_length
                    break
            if not match:
                i += 1

        if len(entity_ranges) == 0 or len(entity_ranges) > 3:
            return None

        # mask entity mentions
        for _, (start, end) in entity_ranges:
            tokens[(start - 1):end] = [MASK_TOKEN] * (end - (start - 1))

        texts, ranges = zip(*entity_ranges)
        return RelInstance(tokens, ranges, entity_texts=texts)


class Entity:
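    """A candidate entity: its token range within the encoded sentence, plus an
    optional relation label, identifier and surface text."""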

    def __init__(self, rng, label=None, idx=None, text=None):
        self.rng = rng
        self.label = label
        self.idx = idx
        self.text = text
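

# Minimal usage sketch (illustrative only; the sentence and entity strings below
# are assumptions, not part of any dataset). It builds a production instance from
# raw text and runs it through the inference collate function.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    example = RelInstance.from_sentence('The battery life of this camera is great.',
                                        ['battery life', 'camera'])
    if example is not None:
        loader = DataLoader([example], batch_size=1, collate_fn=generate_production_batch)
        for input_ids, attn_mask, entity_indices, entity_mask, instances in loader:
            print(input_ids.shape, entity_indices.shape, entity_mask.shape)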