Commit 28a62442 authored by cz1716's avatar cz1716
Browse files

Merge branch 'master' of gitlab.doc.ic.ac.uk:g207004202/explainable-multimodal-classification

parents f38a1a05 a445c718
......@@ -20,7 +20,7 @@ class ONNXInterface:
self.tokenizer = tokenizer
self.defaultmodel = None
self.device = "cpu"
def visualize(self):
print(onnx.helper.printable_graph(self.model.graph))
......@@ -92,7 +92,8 @@ class ONNXInterface:
ort_outs = self.ort_session.run([output_name], ort_inputs)
return ort_outs
def to(self,decive):
def to(self,device):
self.device = device
if self.defaultmodel != None:
self.defaultmodel.to(device)
def classify(self,image,text_input, image_tensor = None):
......@@ -118,7 +119,7 @@ class ONNXInterface:
if self.defaultmodel == None:
self.defaultmodel = MMBT.from_pretrained("mmbt.hateful_memes.images")
logits = self.defaultmodel.classify(image_path, text_input, image_tensor=torch.squeeze(image_tensor, 0))
logits = self.defaultmodel.classify(image, text_input, image_tensor=torch.squeeze(image_tensor.to(self.device), 0))
print("The output of model is invalid, here use default output instead")
scores = nn.functional.softmax(torch.tensor(logits), dim=1)
......
......@@ -58,6 +58,8 @@ numba==0.53.1
numpy==1.20.2
oauthlib==3.1.0
omegaconf==2.0.6
onnx=1.9.0=pypi_0
onnxruntime=1.7.0=pypi_0
opencv-python==4.5.1.48
packaging==20.9
pandas==1.2.3
......@@ -79,7 +81,6 @@ python-dateutil==2.8.1
pytorch-lightning==1.2.7
pytz==2021.1
PyWavelets==1.1.1
pywin32==300
PyYAML==5.3.1
regex==2021.4.4
requests==2.23.0
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment