# bert_dataset.py
import torch
from torch.utils.data import Dataset
import xml.etree.ElementTree as ET
from transformers import *
from tdbertnet import TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES
import re

# Maximum sequence length (in tokens) passed to the tokenizer when encoding.
MAX_SEQ_LEN = 128
# Mapping from the polarity strings found in the XML 'polarity' attribute
# to integer class labels used for training.
polarity_indices = {'positive': 0, 'negative': 1, 'neutral': 2, 'conflict': 3}
# Shared WordPiece tokenizer matching the pretrained BERT weights used by the model.
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)


def generate_batch(batch):
    """Collate a list of dataset entries into model-ready tensors.

    batch: list of dicts with keys 'tokens', 'from', 'to', 'polarity'
    (as produced by BertDataset.__getitem__).

    Returns (texts, target_indices, polarity_labels). target_indices has
    one row per entry, padded to the longest target span in the batch by
    clamping positions to the entry's own last target token; each position
    is repeated HIDDEN_OUTPUT_FEATURES times along the last dimension.
    """
    token_lists = [entry['tokens'] for entry in batch]
    texts = tokenizer.batch_encode_plus(token_lists, add_special_tokens=True,
                                        max_length=MAX_SEQ_LEN, pad_to_max_length=True,
                                        is_pretokenized=True, return_tensors='pt')

    # The widest from->to distance in the batch fixes the padded span width.
    max_tg_len = max(entry['to'] - entry['from'] for entry in batch)
    index_rows = []
    for entry in batch:
        start, stop = entry['from'], entry['to']
        # Positions past this entry's own span are clamped to its last token.
        row = [[min(pos, stop)] * HIDDEN_OUTPUT_FEATURES
               for pos in range(start, start + max_tg_len + 1)]
        index_rows.append(row)
    target_indices = torch.tensor(index_rows)

    polarity_labels = torch.tensor([entry['polarity'] for entry in batch])

    return texts, target_indices, polarity_labels


def token_for_char(char_idx, text, tokens):
    """Map a character offset in *text* to the index of the WordPiece token
    that contains it.

    Whitespace is ignored on both sides of the match: the character offset
    is first compressed (all whitespace stripped), then located against the
    cumulative lengths of the tokens with their '##' continuation markers
    removed.

    Returns -1 when no non-whitespace character lies at or before char_idx.
    Raises IndexError when the offset lies beyond the tokenized text.

    NOTE(review): assumes tokenization only strips whitespace and adds '##'
    markers — characters the tokenizer rewrites would skew the mapping.
    """
    # Rank of the target character among the non-whitespace characters.
    compressed_idx = len(re.sub(r'\s+', '', text[:char_idx + 1])) - 1
    if compressed_idx < 0:
        return -1
    consumed = 0
    for token_idx, token in enumerate(tokens):
        consumed += len(token.replace('##', ''))
        if consumed > compressed_idx:
            return token_idx
    # Mirrors the original behaviour of indexing past the end of the list.
    raise IndexError('character offset beyond tokenized text')


def polarity_index(polarity):
    """Return the integer class label for *polarity*.

    Looks the string up in the module-level ``polarity_indices`` table;
    raises KeyError for strings outside that table.
    """
    return polarity_indices[polarity]


class BertDataset(Dataset):
    """Aspect-based sentiment dataset parsed from a SemEval-style XML file.

    Each sample is one (sentence, aspect term) pair: the tokenized sentence,
    the token span of the aspect term, and its polarity class label.
    """

    def __init__(self, xml_file):
        """Parse *xml_file* and collect one entry per aspect term.

        Sentences without an <aspectTerms> element contribute no entries.
        """
        tree = ET.parse(xml_file)

        self.data = []  # list of (Instance, polarity string) pairs

        for sentence in tree.getroot():
            text = sentence.find('text').text
            aspect_terms = sentence.find('aspectTerms')
            # Explicit None check: truth-testing an Element is deprecated by
            # ElementTree (and a childless element is falsy, which would
            # merely skip an empty loop here anyway).
            if aspect_terms is not None:
                for term in aspect_terms:
                    char_from = int(term.attrib['from'])
                    # 'from'/'to' are character offsets; store to-1 as the
                    # last-character index.
                    # NOTE(review): Instance.get() subtracts 1 from char_to
                    # again — confirm this double decrement is intended.
                    char_to = int(term.attrib['to']) - 1
                    polarity = term.attrib['polarity']
                    self.data.append((Instance(text, char_from, char_to), polarity))

    def __len__(self):
        """Number of (sentence, aspect term) samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample *idx* as a dict with keys
        'tokens', 'from', 'to', 'polarity' (the shape generate_batch expects).
        """
        instance, polarity_str = self.data[idx]
        tokens, idx_from, idx_to = instance.get()
        polarity = polarity_index(polarity_str)
        return {'tokens': tokens, 'from': idx_from, 'to': idx_to, 'polarity': polarity}


class Instance:
    """A sentence paired with the character span of one aspect term."""

    def __init__(self, text, char_from, char_to):
        # Raw sentence text and the character offsets delimiting the term.
        self.text = text
        self.char_from = char_from
        self.char_to = char_to

    def get(self):
        """Tokenize the sentence and locate the term's token span.

        Returns (tokens, idx_from, idx_to): the WordPiece tokens of the
        whole sentence and the inclusive token indices of the term.
        """
        wordpieces = tokenizer.tokenize(self.text)
        first_token = token_for_char(self.char_from, self.text, wordpieces)
        last_token = token_for_char(self.char_to - 1, self.text, wordpieces)
        return wordpieces, first_token, last_token

    def to_tensor(self):
        """Encode this single instance for the model.

        Returns (text, target_indices): the tokenizer encoding and a tensor
        of the term's token positions, each repeated across the hidden
        dimension.
        """
        wordpieces, first_token, last_token = self.get()
        encoded = tokenizer.encode_plus(wordpieces, add_special_tokens=True,
                                        max_length=MAX_SEQ_LEN,
                                        is_pretokenized=True, return_tensors='pt')
        positions = range(first_token, last_token + 1)
        target_indices = torch.tensor([[[pos] * HIDDEN_OUTPUT_FEATURES
                                        for pos in positions]])
        return encoded, target_indices