# bert_analyzer.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tdbertnet import TDBertNet
from bert_dataset import BertDataset, Instance, polarity_indices, generate_batch
import time
import numpy as np
from sklearn import metrics
import matplotlib.pyplot as plt

semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml'
semeval_2014_test_path = 'data/SemEval-2014/Laptops_Test_Gold.xml'
amazon_test_path = 'data/Amazon/amazon_camera_test.xml'
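
# fine-tuned weights are saved here by train() and loaded by load_saved()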
trained_model_path = 'semeval_2014.pt'

BATCH_SIZE = 32
MAX_EPOCHS = 6
LEARNING_RATE = 0.00002  # 2e-5, within the BERT paper's recommended fine-tuning range
loss_criterion = nn.CrossEntropyLoss()


def loss(outputs, labels):
    return loss_criterion(outputs, labels)

# Target-dependent sentiment analyzer built around TDBertNet.
# Call train() or load_saved() before evaluate()/analyze_sentence() so self.net exists.
class BertAnalyzer:

    def load_saved(self):
        self.net = TDBertNet(len(polarity_indices))
        self.net.load_state_dict(torch.load(trained_model_path))
        self.net.eval()

    def train(self, dataset):
        train_data = BertDataset(dataset)
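        # generate_batch collates dataset items into (texts, target_indices, labels) batches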
        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
                                  collate_fn=generate_batch)

        self.net = TDBertNet(len(polarity_indices))  # fresh network, one output per polarity class
        optimiser = optim.Adam(self.net.parameters(), lr=LEARNING_RATE)

        start = time.time()

        for epoch in range(MAX_EPOCHS):
            batch_loss = 0.0
            for i, (texts, target_indices, labels) in enumerate(train_loader):
                # zero param gradients
                optimiser.zero_grad()

                # forward pass
                outputs, _ = self.net(texts, target_indices)

                # backward pass
                l = loss(outputs, labels)
                l.backward()

                # optimise
                optimiser.step()

                # print interim stats every 10 batches
                batch_loss += l.item()
                if i % 10 == 9:
                    print('epoch:', epoch + 1, '-- batch:', i + 1, '-- avg loss:', batch_loss / 10)
                    batch_loss = 0.0

        end = time.time()
        print('Training took', end - start, 'seconds')

        torch.save(self.net.state_dict(), trained_model_path)

    def evaluate(self, dataset):
        test_data = BertDataset(dataset)
        test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
                                 collate_fn=generate_batch)

        self.net.eval()  # disable dropout so evaluation is deterministic

        predicted = []
        truths = []
        with torch.no_grad():
            for (texts, target_indices, labels) in test_loader:
                outputs, attentions = self.net(texts, target_indices)
                _, pred = torch.max(outputs.data, 1)
                predicted += pred.tolist()
                truths += labels.tolist()

        correct = (np.array(predicted) == np.array(truths))
        accuracy = correct.sum() / correct.size
        print('accuracy:', accuracy)

        cm = metrics.confusion_matrix(truths, predicted, labels=range(len(polarity_indices)))
        print('confusion matrix:')
        print(cm)

        f1 = metrics.f1_score(truths, predicted, labels=range(len(polarity_indices)), average='macro')
        print('macro F1:', f1)

    def analyze_sentence(self, text, char_from, char_to):
        # char_from/char_to give the character span of the target term within text
        instance = Instance(text, char_from, char_to)
        tokens, tg_from, tg_to = instance.get()
        texts, target_indices = instance.to_tensor()

        with torch.no_grad():
            outputs, attentions = self.net(texts, target_indices)

        # average attention over heads, then take the rows for the target tokens
        # (the +1/+2 offsets account for the [CLS] token BERT prepends)
        target_attentions = torch.mean(attentions, 1)[0][tg_from+1:tg_to+2]
        # average over the target tokens to get one attention weight per sequence position
        mean_target_att = torch.mean(target_attentions, 0)

        # plot attention histogram
        att_values = mean_target_att.numpy()[1:-1]  # drop the [CLS] and [SEP] positions

        ax = plt.subplot(111)
        width = 0.3
        bins = [x - width/2 for x in range(1, len(att_values)+1)]
        ax.bar(bins, att_values, width=width)
        ax.set_xticks(list(range(1, len(att_values)+1)))
        ax.set_xticklabels(tokens, rotation=45, rotation_mode='anchor', ha='right')
        plt.show()

        _, pred = torch.max(outputs.data, 1)
        return pred  # index into polarity_indices for the predicted class


sentiment_analyzer = BertAnalyzer()
sentiment_analyzer.load_saved()
# characters 39-48 cover the target term 'HP/Compaq'
sentiment = sentiment_analyzer.analyze_sentence('I will never buy another computer from HP/Compaq or do business with Circuit City again.', 39, 48)
print('sentiment:', sentiment)
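
# A minimal sketch of the train/evaluate flow, assuming BertDataset accepts the
# XML paths defined above (uncomment to run; note train() overwrites semeval_2014.pt):
#
# analyzer = BertAnalyzer()
# analyzer.train(semeval_2014_train_path)
# analyzer.evaluate(semeval_2014_test_path)
# analyzer.evaluate(amazon_test_path)  # cross-domain check on Amazon camera reviews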