import torch
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.nn.functional import softmax
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import time
import numpy as np
from sklearn import metrics
import statistics
from transformers import get_linear_schedule_with_warmup
from agent.target_extraction.BERT.entity_extractor.entity_dataset import EntityDataset, generate_batch, generate_production_batch
from agent.target_extraction.BERT.entity_extractor.entitybertnet import NUM_CLASSES, EntityBertNet, BATCH_SIZE

device = torch.device('cuda')
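# NOTE: assumes a CUDA-capable GPU is available; the network and all batches are moved onto this device below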

# optimizer
DECAY_RATE = 0.01
LEARNING_RATE = 0.00002
MAX_GRAD_NORM = 1.0

# training
N_EPOCHS = 3
WARM_UP_FRAC = 0.05

# loss
loss_criterion = CrossEntropyLoss()


class BertEntityExtractor:
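    """Wrapper around EntityBertNet for training, evaluating, and scoring candidate entity terms."""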

    def __init__(self):
        self.net = EntityBertNet()

    @staticmethod
    def load_saved(path):
        extr = BertEntityExtractor()
        extr.net.load_state_dict(torch.load(path))
        extr.net.eval()
        return extr

    @staticmethod
    def new_trained_with_file(file_path, save_path, size=None):
        extractor = BertEntityExtractor()
        extractor.train_with_file(file_path, save_path, size=size)
        return extractor

    @staticmethod
    def train_and_validate(file_path, save_file, size=None, valid_frac=None, valid_file_path=None):
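        """Train a new extractor on file_path, validating either on a held-out fraction or on a separate file."""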
        extractor = BertEntityExtractor()
        extractor.train_with_file(file_path, save_file, size=size, valid_frac=valid_frac,
                                  valid_file_path=valid_file_path)
        return extractor

    def train_with_file(self, file_path, save_file, size=None, valid_frac=None, valid_file_path=None):
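        """Fine-tune the network on the data in file_path, saving a checkpoint after every epoch
        and evaluating on the validation data (if any) at the end of each epoch."""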
        # load training data
        if valid_file_path is None:
            train_data, valid_data = EntityDataset.from_file(file_path, size=size, valid_frac=valid_frac)
        else:
            train_size = int(size * (1 - valid_frac)) if size is not None and valid_frac is not None else size
            train_data, _ = EntityDataset.from_file(file_path, size=train_size)
            valid_data, _ = EntityDataset.from_file(valid_file_path)
        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
                                  collate_fn=generate_batch)

        # initialise BERT
        self.net.cuda()

        # set up Adam optimiser with weight decay and the configured learning rate
        optimiser = Adam(self.net.parameters(), lr=LEARNING_RATE, weight_decay=DECAY_RATE)

        # linear learning-rate schedule: warm up over the first WARM_UP_FRAC of steps, then decay to zero
        n_training_steps = len(train_loader) * N_EPOCHS
        scheduler = get_linear_schedule_with_warmup(
            optimiser,
            num_warmup_steps=int(WARM_UP_FRAC * n_training_steps),
            num_training_steps=n_training_steps
        )

        start = time.time()

        for epoch_idx in range(N_EPOCHS):
            self.net.train()
            batch_loss = 0.0

            for batch_idx, batch in enumerate(train_loader):
                # send batch to gpu
                input_ids, attn_mask, entity_indices, target_labels = tuple(i.to(device) for i in batch)

                # zero param gradients
                optimiser.zero_grad()

                # forward pass
                output_scores = self.net(input_ids, attn_mask, entity_indices)

                # backward pass
                loss = loss_criterion(output_scores, target_labels)
                loss.backward()

                # clip gradient norm
                clip_grad_norm_(parameters=self.net.parameters(), max_norm=MAX_GRAD_NORM)

                # optimise
                optimiser.step()

                # update lr
                scheduler.step()

                # print interim stats every 250 batches
                batch_loss += loss.item()
                if batch_idx % 250 == 249:
                    batch_no = batch_idx + 1
                    print('epoch:', epoch_idx + 1, '-- progress: {:.4f}'.format(batch_no / len(train_loader)),
                          '-- batch:', batch_no, '-- avg loss:', batch_loss / 250)
                    batch_loss = 0.0

            print('epoch', epoch_idx + 1, 'done')

            torch.save(self.net.state_dict(), '{}_epoch_{}.pt'.format(save_file, epoch_idx + 1))

            if valid_data is not None:
                self.evaluate(data=valid_data)

        end = time.time()
        print('Training took', end - start, 'seconds')

    def evaluate(self, file_path=None, data=None, size=None):
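        """Report accuracy, confusion matrix, macro F1, precision, and recall on labelled data."""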
        # load eval data
        if file_path is not None:
            test_data, _ = EntityDataset.from_file(file_path, size=size)
        else:
            if data is None:
                raise ValueError('file_path and data cannot both be None')
            test_data = data

        test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
                                 collate_fn=generate_batch)

        self.net.cuda()
        self.net.eval()

        outputs = []
        targets = []

        with torch.no_grad():
            for batch in test_loader:
                # send batch to gpu
                input_ids, attn_mask, entity_indices, target_labels = tuple(i.to(device) for i in batch)

                # forward pass
                output_scores = self.net(input_ids, attn_mask, entity_indices)
                _, output_labels = torch.max(output_scores.data, 1)

                outputs += output_labels.tolist()
                targets += target_labels.tolist()

        assert len(outputs) == len(targets)

        correct = (np.array(outputs) == np.array(targets))
        accuracy = correct.sum() / correct.size
        print('accuracy:', accuracy)

        cm = metrics.confusion_matrix(targets, outputs, labels=range(NUM_CLASSES))
        print('confusion matrix:')
        print(cm)

        f1 = metrics.f1_score(targets, outputs, labels=range(NUM_CLASSES), average='macro')
        print('macro F1:', f1)

        precision = metrics.precision_score(targets, outputs, average=None)
        print('precision:', precision)

        recall = metrics.recall_score(targets, outputs, average=None)
        print('recall:', recall)

    def extract_entity_probabilities(self, terms, file_path=None, dataset=None, size=None):
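        """Return the mean probability of the entity class for each term in terms,
        averaged over all of the term's instances in the data (None for terms that never occur)."""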
        # load data
        if file_path is not None:
            data, _ = EntityDataset.from_file(file_path, size=size)
        else:
            if dataset is None:
                raise ValueError('file_path and dataset cannot both be None')
            data = dataset

        loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
                            collate_fn=generate_production_batch)

        self.net.cuda()
        self.net.eval()

        probs = {term: [] for term in terms}

        with torch.no_grad():
            for input_ids, attn_mask, entity_indices, instances in loader:
                # send batch to gpu
                input_ids, attn_mask, entity_indices = tuple(i.to(device) for i in [input_ids, attn_mask,
                                                                                    entity_indices])

                # forward pass
                output_scores = softmax(self.net(input_ids, attn_mask, entity_indices), dim=1)
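                # probability of the entity class (column 1 of the softmax output) for each instance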
                entity_scores = output_scores.narrow(1, 1, 1).flatten()

                for ins, score in zip(instances, entity_scores.tolist()):
                    probs[ins.entity].append(score)

        return {t: statistics.mean(t_probs) if len(t_probs) > 0 else None for t, t_probs in probs.items()}


if __name__ == '__main__':
    BertEntityExtractor.train_and_validate('all_reviews_features.tsv', 'feature_extractor',
                                           valid_file_path='annotated_watch_review_features.tsv')