# bert_analyzer.py — training and evaluation driver for a TD-BERT
# target-dependent sentiment analyzer (SemEval-2014 / Amazon data).
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tdbertnet import TDBertNet
from bert_dataset import BertDataset, polarity_indices, generate_batch
import time
import numpy as np
from sklearn import metrics

# Dataset locations.
semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml'
semeval_2014_test_path = 'data/SemEval-2014/Laptops_Test_Gold.xml'
amazon_test_path = 'data/Amazon/amazon_camera_test.xml'

# Where the fine-tuned weights are persisted / loaded from.
trained_model_path = 'semeval_2014.pt'

# Training hyper-parameters.
BATCH_SIZE = 32
MAX_EPOCHS = 6
LEARNING_RATE = 0.00002

# Multi-class objective over the polarity labels.
loss_criterion = nn.CrossEntropyLoss()


def loss(outputs, labels):
    """Return the cross-entropy loss between model logits and gold labels."""
    return loss_criterion(outputs, labels)

class BertAnalyzer:
    """Trains and evaluates a TD-BERT network for target-dependent sentiment analysis.

    The trained weights live at ``trained_model_path``; ``train`` writes them,
    ``load_saved`` reads them back, and ``evaluate`` reports accuracy, a
    confusion matrix and macro F1 on a held-out dataset.
    """

    def load_saved(self):
        """Load fine-tuned weights from ``trained_model_path`` into eval mode."""
        self.net = TDBertNet(len(polarity_indices))
        self.net.load_state_dict(torch.load(trained_model_path))
        self.net.eval()

    def train(self, dataset):
        """Fine-tune a fresh TDBertNet on ``dataset`` and save the weights.

        ``dataset`` is whatever ``BertDataset`` accepts (here: a path to a
        SemEval-style XML file — see the module-level path constants).
        """
        train_data = BertDataset(dataset)
        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
                                  collate_fn=generate_batch)

        self.net = TDBertNet(len(polarity_indices))
        # BUG FIX: was `net.parameters()` — `net` is undefined here (NameError);
        # the network lives on the instance as `self.net`.
        optimiser = optim.Adam(self.net.parameters(), lr=LEARNING_RATE)

        start = time.time()

        for epoch in range(MAX_EPOCHS):
            batch_loss = 0.0
            for i, (texts, target_indices, labels) in enumerate(train_loader):
                # zero param gradients
                optimiser.zero_grad()

                # forward pass
                outputs = self.net(texts, target_indices)

                # backward pass
                l = loss(outputs, labels)
                l.backward()

                # optimise
                optimiser.step()

                # print interim stats every 10 batches
                batch_loss += l.item()
                if i % 10 == 9:
                    print('epoch:', epoch + 1, '-- batch:', i + 1, '-- avg loss:', batch_loss / 10)
                    batch_loss = 0.0

        end = time.time()
        print('Training took', end - start, 'seconds')

        # BUG FIX: was `net.state_dict()` — save the trained instance network.
        torch.save(self.net.state_dict(), trained_model_path)

    def evaluate(self, dataset):
        """Report accuracy, confusion matrix and macro F1 of ``self.net`` on ``dataset``.

        Requires ``self.net`` to be populated first (via ``train`` or ``load_saved``).
        """
        test_data = BertDataset(dataset)
        test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
                                 collate_fn=generate_batch)

        predicted = []
        truths = []
        # No gradients needed for inference.
        with torch.no_grad():
            for (texts, target_indices, labels) in test_loader:
                outputs = self.net(texts, target_indices)
                # argmax over the polarity dimension -> predicted class index
                _, pred = torch.max(outputs.data, 1)
                predicted += pred.tolist()
                truths += labels.tolist()

        correct = (np.array(predicted) == np.array(truths))
        accuracy = correct.sum() / correct.size
        print('accuracy:', accuracy)

        cm = metrics.confusion_matrix(truths, predicted, labels=range(len(polarity_indices)))
        print('confusion matrix:')
        print(cm)

        f1 = metrics.f1_score(truths, predicted, labels=range(len(polarity_indices)), average='macro')
        print('macro F1:', f1)


if __name__ == '__main__':
    # Guard the script entry point: previously this ran model loading and a
    # full evaluation as an import side effect, which breaks any module that
    # just wants to `import bert_analyzer` for the BertAnalyzer class.
    sentiment_analyzer = BertAnalyzer()
    sentiment_analyzer.load_saved()
    sentiment_analyzer.evaluate(amazon_test_path)