Commit 1b41bc63 authored by Mao, Bojia's avatar Mao, Bojia
Browse files

recorrect onnx to device

parent e62f88a2
......@@ -11,7 +11,7 @@ from PIL import Image
from transformers import BertTokenizer, AutoTokenizer, BertTokenizerFast, XLNetTokenizer
from mmf.models.mmbt import MMBT
class ONNXInterface(torch.nn.Module):
class ONNXInterface:
def __init__(self,model_path,tokenizer = None):
self.model = onnx.load(model_path)
self.ort_session = onnxruntime.InferenceSession(model_path)
......@@ -92,7 +92,9 @@ class ONNXInterface(torch.nn.Module):
ort_outs = self.ort_session.run([output_name], ort_inputs)
return ort_outs
def to(self,decive):
if self.defaultmodel != None:
self.defaultmodel.to(device)
def classify(self,image,text_input, image_tensor = None):
'''
Args:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment