Commit bbbeb235 authored by Mao, Bojia

ONNX example

parent 3da72c68
@@ -5,8 +5,11 @@ from torch import nn
import onnx
from onnx import helper, TensorProto, checker
from torchvision import transforms
from torchray.utils import imsc
from PIL import Image
from transformers import BertTokenizer, AutoTokenizer, BertTokenizerFast, XLNetTokenizer
from mmf.models.mmbt import MMBT
class ONNXInterface:
    def __init__(self, model_path, tokenizer=None):
@@ -16,22 +19,36 @@ class ONNXInterface:
assert("Model file error")
self.tokenizer = tokenizer
self.defaultmodel = None
def visualize(self):
print(onnx.helper.printable_graph(self.model.graph))
    def onnx_model_forward(self, image_input, text_input):
        '''
        Model-oriented forward pass that supports several models with different input signatures.

        Args:
            image_input: image torch.Tensor of shape (1, 3, 224, 224)
            text_input: input text as a str

        Returns:
            logits computed by the model forward pass, as a list
        '''
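        # Example call (sketch; the interface name and text below are illustrative):
        #   iface = ONNXInterface("model.onnx", "BertTokenizer")
        #   logits = iface.onnx_model_forward(image_input, "example meme text")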
        output_name = self.ort_session.get_outputs()[0].name
        input_name1 = self.ort_session.get_inputs()[0].name
        input_name2 = self.ort_session.get_inputs()[1].name

        # Collect every input name; the number of inputs decides which model type is assumed below
        inputs = []
        count = 0
        while True:
            try:
                input_name = self.ort_session.get_inputs()[count].name
                inputs.append(input_name)
                count += 1
            except IndexError:
                if count > 4:
                    print("The model is neither a BERT nor an MMF model and is not supported; please contact the development team.")
                    exit(0)
                break

        img = to_numpy(image_input)
@@ -48,11 +65,30 @@ class ONNXInterface:
assert("NotImplementedError")
print("Please contact the development team to update")
input2 = Tokenizer(text_input, return_tensors="pt")["input_ids"].squeeze().type(torch.float)
input2 = to_numpy(input2)
ort_inputs = {input_name1: img, input_name2: input2}
        if count == 3:
            print("Assuming a BERT model with text-only input")
            tokens = Tokenizer(text_input, return_tensors="pt")
            ort_inputs = {k: v.cpu().detach().numpy() for k, v in tokens.items()}
        elif count == 2:
            print("Assuming one image and one text input")
            input1 = Tokenizer(text_input, return_tensors="pt")["input_ids"].squeeze().type(torch.float)
            # onnxruntime expects numpy arrays, not torch tensors
            ort_inputs = {inputs[0]: img, inputs[1]: to_numpy(input1)}
        elif count == 4:
            print("Assuming BERT inputs plus one image input")
            tokens = Tokenizer(text_input, return_tensors="pt")
            input1 = tokens["input_ids"].cpu().detach().numpy()
            input2 = tokens["token_type_ids"].cpu().detach().numpy()
            input3 = tokens["attention_mask"].cpu().detach().numpy()
            ort_inputs = {inputs[0]: img, inputs[1]: input1, inputs[2]: input2, inputs[3]: input3}
        else:
            print("Assuming an image-only input")
            ort_inputs = {inputs[0]: img}

        ort_outs = self.ort_session.run([output_name], ort_inputs)
        return ort_outs
@@ -72,6 +108,14 @@ class ONNXInterface:
        else:
            image_tensor = image2tensor(image_path)

        logits = self.onnx_model_forward(image_tensor, text_input)

        if list(torch.tensor(logits).size()) != [1, 2]:
            if self.defaultmodel is None:
                self.defaultmodel = MMBT.from_pretrained("mmbt.hateful_memes.images")
            # fall back to the pretrained MMBT model when the ONNX output shape is not (1, 2)
            logits = self.defaultmodel.classify(image_path, text_input, image_tensor=torch.squeeze(image_tensor, 0))
            print("The model output is invalid; using the default MMBT output instead")

        scores = nn.functional.softmax(torch.tensor(logits), dim=1)

        if image_tensor is not None:
@@ -97,7 +141,8 @@ def to_numpy(tensor):
if __name__ == "__main__":
    model_path = "model.onnx"
    model_path = "Bert.onnx"
    tokenizers = ["BertTokenizer", "BertTokenizerFast", "AutoTokenizer", "XLNetTokenizer"]
    tokenizer = tokenizers[0]
    model = ONNXInterface(model_path, tokenizer)
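    # Usage sketch: run the exported graph end to end. Assumptions (not in the original
    # commit): "test_image.png" is a hypothetical local sample image; image2tensor,
    # to_numpy and the class methods used below are the ones defined in this file.
    model.visualize()
    image_tensor = image2tensor("test_image.png")
    logits = model.onnx_model_forward(image_tensor, "example meme text")
    scores = nn.functional.softmax(torch.tensor(logits[0]), dim=1)
    print(scores)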