Commit b397963d authored by Mao, Bojia's avatar Mao, Bojia
Browse files

recorrect onnx

parent 942bec66
......@@ -11,7 +11,7 @@ from PIL import Image
from transformers import BertTokenizer, AutoTokenizer, BertTokenizerFast, XLNetTokenizer
from mmf.models.mmbt import MMBT
class ONNXInterface:
class ONNXInterface(torch.nn.Module):
def __init__(self,model_path,tokenizer = None):
self.model = onnx.load(model_path)
self.ort_session = onnxruntime.InferenceSession(model_path)
......@@ -93,7 +93,7 @@ class ONNXInterface:
ort_outs = self.ort_session.run([output_name], ort_inputs)
return ort_outs
def classify(self,image_path,text_input, image_tensor = None):
def classify(self,image,text_input, image_tensor = None):
'''
Args:
image_path: directory of input image
......@@ -103,10 +103,13 @@ class ONNXInterface:
Returns :
label of model prediction and the corresponding confidence
'''
if image_tensor != None:
logits = self.onnx_model_forward(image_tensor,text_input)
else:
image_tensor = image2tensor(image_path)
p = transforms.Compose([transforms.Scale((224,224))])
image,i = imsc(p(image),quiet=True)
image_tensor = torch.reshape(image, (1,3,224,224))
logits = self.onnx_model_forward(image_tensor,text_input)
if list(torch.tensor(logits).size()) != [1, 2]:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment