Commit ad94642b authored by Mao, Bojia's avatar Mao, Bojia
Browse files

recorrect onnx to device

parent a445c718
......@@ -94,8 +94,7 @@ class ONNXInterface:
return ort_outs
def to(self,device):
self.device = device
if self.defaultmodel != None:
self.defaultmodel.to(device)
def classify(self,image,text_input, image_tensor = None):
'''
Args:
......@@ -119,6 +118,7 @@ class ONNXInterface:
if self.defaultmodel == None:
self.defaultmodel = MMBT.from_pretrained("mmbt.hateful_memes.images")
self.defaultmodel.to(self.device)
logits = self.defaultmodel.classify(image, text_input, image_tensor=torch.squeeze(image_tensor.to(self.device), 0))
print("The output of model is invalid, here use default output instead")
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment