Commit 8baedd6a authored by Mao, Bojia's avatar Mao, Bojia
Browse files

recorrect onnx to device

parent 1b41bc63
......@@ -118,7 +118,7 @@ class ONNXInterface:
if self.defaultmodel == None:
self.defaultmodel = MMBT.from_pretrained("mmbt.hateful_memes.images")
logits = self.defaultmodel.classify(image_path, text_input, image_tensor=torch.squeeze(image_tensor, 0))
logits = self.defaultmodel.classify(image, text_input, image_tensor=torch.squeeze("cuda"), 0))
print("The output of model is invalid, here use default output instead")
scores = nn.functional.softmax(torch.tensor(logits), dim=1)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment