Commit 56fb62b9 authored by Joel Oksanen's avatar Joel Oksanen
Browse files

Major refactors in order to make server the main source folder

parent 57b5fe71
.idea *.idea
*.pt *.pt
__pycache__/ __pycache__/
amazon_data/ server/agent/amazon_data/
.DS_Store .DS_Store
This diff is collapsed.
...@@ -2,18 +2,16 @@ import torch ...@@ -2,18 +2,16 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tdbertnet import TDBertNet from agent.SA.tdbertnet import TDBertNet
from bert_dataset import BertDataset, Instance, polarity_indices, generate_batch from agent.SA.bert_dataset import BertDataset, Instance, polarity_indices, generate_batch
import time import time
import numpy as np import numpy as np
from sklearn import metrics from sklearn import metrics
import matplotlib.pyplot as plt
import shap
semeval_2014_train_path = 'data/SemEval-2014/Laptop_Train_v2.xml' semeval_2014_train_path = 'agent/SA/data/SemEval-2014/Laptop_Train_v2.xml'
semeval_2014_test_path = 'data/SemEval-2014/Laptops_Test_Gold.xml' semeval_2014_test_path = 'agent/SA/data/SemEval-2014/Laptops_Test_Gold.xml'
amazon_test_path = 'data/Amazon/annotated_amazon_laptop_reviews.xml' amazon_test_path = 'agent/SA/data/Amazon/annotated_amazon_laptop_reviews.xml'
trained_model_path = 'semeval_2014_2.pt' trained_model_path = 'agent/SA/semeval_2014_2.pt'
BATCH_SIZE = 32 BATCH_SIZE = 32
MAX_EPOCHS = 6 MAX_EPOCHS = 6
...@@ -29,7 +27,7 @@ class BertAnalyzer: ...@@ -29,7 +27,7 @@ class BertAnalyzer:
@staticmethod @staticmethod
def default(): def default():
sa = BertAnalyzer() sa = BertAnalyzer()
sa.load_saved('semeval_2014.pt') sa.load_saved('agent/SA/semeval_2014.pt')
return sa return sa
def load_saved(self, path): def load_saved(self, path):
...@@ -133,7 +131,3 @@ class BertAnalyzer: ...@@ -133,7 +131,3 @@ class BertAnalyzer:
# neutral or conflicted # neutral or conflicted
return 0 return 0
sentiment_analyzer = BertAnalyzer.default()
sentiment_analyzer.evaluate(semeval_2014_test_path)
sentiment_analyzer.evaluate(amazon_test_path)
...@@ -2,7 +2,7 @@ import torch ...@@ -2,7 +2,7 @@ import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from transformers import * from transformers import *
from tdbertnet import TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES from agent.SA.tdbertnet import TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES
import re import re
MAX_SEQ_LEN = 128 MAX_SEQ_LEN = 128
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment