Commit d735c192 authored by Mao, Bojia's avatar Mao, Bojia
Browse files

recorrect onnx to device

parent 9c8d48af
......@@ -20,7 +20,7 @@ class ONNXInterface:
self.tokenizer = tokenizer
self.defaultmodel = None
self.device = "cpu"
def visualize(self):
print(onnx.helper.printable_graph(self.model.graph))
......@@ -93,6 +93,7 @@ class ONNXInterface:
ort_outs = self.ort_session.run([output_name], ort_inputs)
return ort_outs
def to(self,decive):
self.device = device
if self.defaultmodel != None:
self.defaultmodel.to(device)
def classify(self,image,text_input, image_tensor = None):
......@@ -118,7 +119,7 @@ class ONNXInterface:
if self.defaultmodel == None:
self.defaultmodel = MMBT.from_pretrained("mmbt.hateful_memes.images")
logits = self.defaultmodel.classify(image, text_input, image_tensor=torch.squeeze(image_tensor.to("cuda"), 0))
logits = self.defaultmodel.classify(image, text_input, image_tensor=torch.squeeze(image_tensor.to(self.device), 0))
print("The output of model is invalid, here use default output instead")
scores = nn.functional.softmax(torch.tensor(logits), dim=1)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment