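"""
BERT-based relation extractor.

Wraps RelBertNet to train, evaluate, and run relation classification over
entity pairs in sentences, aggregating pairwise relation probabilities
into aspect-level matrices.
"""
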
import time

import numpy as np
import torch
from sklearn import metrics
from torch.nn import CrossEntropyLoss
from torch.nn.functional import softmax
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

# from agent.target_extraction.BERT.relation_extractor.pair_rel_dataset import PairRelDataset, generate_batch, generate_production_batch
from agent.target_extraction.BERT.relation_extractor.rel_dataset import PairRelDataset, generate_batch, generate_production_batch, RelInstance
# from agent.target_extraction.BERT.relation_extractor.pairbertnet import NUM_CLASSES, PairBertNet
from agent.target_extraction.BERT.relation_extractor.relbertnet import NUM_CLASSES, RelBertNet

device = torch.device('cuda')  # a CUDA-capable GPU is assumed throughout

# optimizer
DECAY_RATE = 0.01
LEARNING_RATE = 0.00002
MAX_GRAD_NORM = 1.0

# training
N_EPOCHS = 3
BATCH_SIZE = 16
WARM_UP_FRAC = 0.05

# loss
loss_criterion = CrossEntropyLoss()


class BertRelExtractor:

    def __init__(self):
        self.net = RelBertNet()

    @staticmethod
    def load_saved(path):
        extr = BertRelExtractor()
        extr.net.load_state_dict(torch.load(path))
        extr.net.eval()
        return extr

    @staticmethod
    def new_trained_with_file(file_path, save_path, size=None):
        extractor = BertRelExtractor()
        extractor.train_with_file(file_path, save_path, size=size)
        return extractor

    @staticmethod
    def train_and_validate(file_path, save_file, size=None, valid_frac=None, valid_file_path=None):
        extractor = BertRelExtractor()
        extractor.train_with_file(file_path, save_file, size=size, valid_frac=valid_frac,
                                  valid_file_path=valid_file_path)
        return extractor

    def train_with_file(self, file_path, save_file, size=None, valid_frac=None, valid_file_path=None):
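        # fine-tune the network on a labelled file, saving a checkpoint after each
        # epoch and evaluating on the validation set if one is available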
        # load training data
        if valid_file_path is None:
            train_data, valid_data = PairRelDataset.from_file(file_path, size=size, valid_frac=valid_frac)
        else:
            train_size = int(size * (1 - valid_frac)) if size is not None and valid_frac is not None else size
            train_data, _ = PairRelDataset.from_file(file_path, size=train_size)
            valid_data, _ = PairRelDataset.from_file(valid_file_path)
        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
                                  collate_fn=generate_batch)

        # move network to GPU
        self.net.cuda()

        # set up optimizer with weight decay
        optimiser = AdamW(self.net.parameters(), lr=LEARNING_RATE, weight_decay=DECAY_RATE)

        # set up linear lr scheduler with warm-up
        n_training_steps = len(train_loader) * N_EPOCHS
        scheduler = get_linear_schedule_with_warmup(
            optimiser,
            num_warmup_steps=int(WARM_UP_FRAC * n_training_steps),
            num_training_steps=n_training_steps
        )

        start = time.time()

        for epoch_idx in range(N_EPOCHS):
            self.net.train()
            batch_loss = 0.0

            for batch_idx, batch in enumerate(train_loader):
                # send batch to gpu
                input_ids, attn_mask, entity_indices, entity_mask, labels = tuple(i.to(device) for i in batch)

                # zero param gradients
                optimiser.zero_grad()

                # forward pass
                output_scores = self.net(input_ids, attn_mask, entity_indices, entity_mask)

                # backward pass
                loss = loss_criterion(output_scores, labels)
                loss.backward()

                # clip gradient norm
                clip_grad_norm_(parameters=self.net.parameters(), max_norm=MAX_GRAD_NORM)

                # optimise
                optimiser.step()

                # update lr
                scheduler.step()

                # print interim stats every 500 batches
                batch_loss += loss.item()
                if batch_idx % 500 == 499:
                    batch_no = batch_idx + 1
                    print('epoch:', epoch_idx + 1, '-- progress: {:.4f}'.format(batch_no / len(train_loader)),
                          '-- batch:', batch_no, '-- avg loss:', batch_loss / 500)
                    batch_loss = 0.0

            print('epoch', epoch_idx + 1, 'done')
            torch.save(self.net.state_dict(), '{}_epoch_{}.pt'.format(save_file, epoch_idx + 1))

            if valid_data is not None:
                self.evaluate(data=valid_data)

        end = time.time()
        print('Training took', end - start, 'seconds')

    def evaluate(self, file_path=None, data=None, size=None):
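        # report accuracy, confusion matrix, macro F1, and per-class precision/recall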
        # load eval data
        if file_path is not None:
            test_data, _ = PairRelDataset.from_file(file_path, size=size)
        else:
            if data is None:
                raise AttributeError('file_path and data cannot both be None')
            test_data = data

        test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
                                 collate_fn=generate_batch)

        self.net.cuda()
        self.net.eval()

        outputs = []
        targets = []

        with torch.no_grad():
            for batch in test_loader:
                # send batch to gpu
                input_ids, attn_mask, entity_indices, entity_mask, labels = tuple(i.to(device) for i in batch)

                # forward pass
                output_scores = self.net(input_ids, attn_mask, entity_indices, entity_mask)
                output_labels = torch.argmax(output_scores, dim=1)

                outputs += output_labels.tolist()
                targets += labels.tolist()

        assert len(outputs) == len(targets)

        correct = (np.array(outputs) == np.array(targets))
        accuracy = correct.sum() / correct.size
        print('accuracy:', accuracy)

        cm = metrics.confusion_matrix(targets, outputs, labels=range(NUM_CLASSES))
        print('confusion matrix:')
        print(cm)

        f1 = metrics.f1_score(targets, outputs, labels=range(NUM_CLASSES), average='macro')
        print('macro F1:', f1)

        precision = metrics.precision_score(targets, outputs, average=None)
        print('precision:', precision)

        recall = metrics.recall_score(targets, outputs, average=None)
        print('recall:', recall)

    def extract_single_relation(self, text, entities):
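        # build a single production instance from the raw sentence and its candidate entities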
        ins = RelInstance.from_sentence(text, entities)
        input_ids, attn_mask, entity_indices, entity_mask, _ = generate_production_batch([ins])

        self.net.cuda()
        self.net.eval()

        with torch.no_grad():
            # send batch to gpu
            input_ids, attn_mask, entity_indices, entity_mask = tuple(i.to(device) for i in [input_ids, attn_mask,
                                                                                             entity_indices,
                                                                                             entity_mask])

            # forward pass
            output_scores = softmax(self.net(input_ids, attn_mask, entity_indices, entity_mask), dim=1)
            output_labels = torch.argmax(output_scores, dim=1)

            ins.print_results_for_labels(output_labels)

    def extract_relations(self, n_aspects, aspect_index_map, aspect_counts, file_path=None, dataset=None, size=None):
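        # aggregate directed relation probabilities between aspects over a whole dataset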
        # load data
        if file_path is not None:
            data, _ = PairRelDataset.from_file(file_path, size=size)
        else:
            if dataset is None:
                raise AttributeError('file_path and dataset cannot both be None')
            data = dataset

        loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
                            collate_fn=generate_production_batch)

        self.net.cuda()
        self.net.eval()

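        # accumulate pairwise relation probabilities and instance counts between aspects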
        prob_matrix = np.zeros((n_aspects, n_aspects))
        count_matrix = np.zeros((n_aspects, n_aspects))

        with torch.no_grad():
            for input_ids, attn_mask, prod_indices, feat_indices, instances in loader:
                # send batch to gpu
                input_ids, attn_mask, prod_indices, feat_indices = tuple(i.to(device) for i in [input_ids, attn_mask,
                                                                                                prod_indices,
                                                                                                feat_indices])

                # forward pass
                output_scores = softmax(self.net(input_ids, attn_mask, prod_indices, feat_indices), dim=1)
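                # keep the probabilities of the two directed relation classes (columns 1 and 2);
                # column 0 is assumed to be the no-relation class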
                rel_scores = output_scores.narrow(1, 1, 2)

                for ins, scores in zip(instances, rel_scores.tolist()):
                    forward_score, backward_score = scores
                    fst_idx, snd_idx = aspect_index_map[ins.fst], aspect_index_map[ins.snd]
                    prob_matrix[snd_idx][fst_idx] += forward_score
                    prob_matrix[fst_idx][snd_idx] += backward_score
                    count_matrix[snd_idx][fst_idx] += 1
                    count_matrix[fst_idx][snd_idx] += 1

        return prob_matrix, count_matrix

    def extract_relations2(self, n_aspects, dataset):
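        # variant of extract_relations in which each instance carries multiple entity
        # pairs with precomputed aspect indices, scored together in one pass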
        loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
                            collate_fn=generate_production_batch)

        self.net.cuda()
        self.net.eval()

        prob_matrix = np.zeros((n_aspects, n_aspects))
        count_matrix = np.zeros((n_aspects, n_aspects))

        with torch.no_grad():
            for input_ids, attn_mask, entity_indices, combination_indices, instances in loader:
                # send batch to gpu
                input_ids, attn_mask, entity_indices, combination_indices = tuple(i.to(device) for i in
                                                                                  [input_ids, attn_mask,
                                                                                   entity_indices, combination_indices])

                # forward pass
                output_scores = softmax(self.net(input_ids, attn_mask, entity_indices, combination_indices), dim=1)
                rel_scores = output_scores.narrow(1, 1, 2).tolist()

                entity_pairs = [ep for instance in instances for ep in instance.entity_pairs]
                for ep, scores in zip(entity_pairs, rel_scores):
                    forward_score, backward_score = scores
                    prob_matrix[ep.snd.idx][ep.fst.idx] += forward_score
                    prob_matrix[ep.fst.idx][ep.snd.idx] += backward_score
                    count_matrix[ep.snd.idx][ep.fst.idx] += 1
                    count_matrix[ep.fst.idx][ep.snd.idx] += 1

        return prob_matrix, count_matrix


# extr: BertRelExtractor = BertRelExtractor.load_saved('multi_extractor_5_products_epoch_1.pt')
# extr.extract_single_relation('The mixer comes with a stainless steel bowl.',
#                              ['mixer', 'stainless steel', 'bowl'])
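# Training follows the same pattern (a sketch; the file path, save prefix, and
# hyperparameter values below are illustrative placeholders, not from the original):
# extr = BertRelExtractor.train_and_validate('pair_rel_train_data.tsv', 'multi_extractor',
#                                            size=100000, valid_frac=0.05)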