Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
g207004202
explainable-multimodal-classification
Commits
bbbeb235
Commit
bbbeb235
authored
May 10, 2021
by
Mao, Bojia
Browse files
onnx example
parent
3da72c68
Changes
1
Hide whitespace changes
Inline
Side-by-side
mmxai/onnx/onnxModel.py
View file @
bbbeb235
...
...
@@ -5,8 +5,11 @@ from torch import nn
import
onnx
from
onnx
import
helper
,
TensorProto
,
checker
from
torchvision
import
transforms
from
torchray.utils
import
imsc
from
PIL
import
Image
from
transformers
import
BertTokenizer
,
AutoTokenizer
,
BertTokenizerFast
,
XLNetTokenizer
from
mmf.models.mmbt
import
MMBT
class
ONNXInterface
:
def
__init__
(
self
,
model_path
,
tokenizer
=
None
):
...
...
@@ -16,22 +19,36 @@ class ONNXInterface:
assert
(
"Model file error"
)
self
.
tokenizer
=
tokenizer
self
.
defaultmodel
=
None
def
visualize
(
self
):
print
(
onnx
.
helper
.
printable_graph
(
self
.
model
.
graph
))
def
onnx_model_forward
(
self
,
image_input
,
text_input
):
'''
It is an model oriented function which will supports several models with different Input Type
Args:
image_input: the image torch.tensor with size (1,3,224,224)
text_input : the text input Str
Returns :
logits computed by model.forward List()
'''
output_name
=
self
.
ort_session
.
get_outputs
()[
0
].
name
input_name1
=
self
.
ort_session
.
get_inputs
()[
0
].
name
input_name2
=
self
.
ort_session
.
get_inputs
()[
1
].
name
inputs
=
[]
count
=
0
while
True
:
try
:
input_name
=
self
.
ort_session
.
get_inputs
()[
count
].
name
inputs
.
append
(
input_name
)
count
+=
1
except
:
if
count
>
4
:
print
(
"The input model is not bert or MMF models, they are not supported please contact the development teams"
)
exit
(
0
)
break
img
=
to_numpy
(
image_input
)
...
...
@@ -48,11 +65,30 @@ class ONNXInterface:
assert
(
"NotImplementedError"
)
print
(
"Please contact the development team to update"
)
input2
=
Tokenizer
(
text_input
,
return_tensors
=
"pt"
)[
"input_ids"
].
squeeze
().
type
(
torch
.
float
)
input2
=
to_numpy
(
input2
)
ort_inputs
=
{
input_name1
:
img
,
input_name2
:
input2
}
if
count
==
3
:
print
(
"Assume it is bert model with only text input"
)
tokens
=
Tokenizer
(
text_input
,
return_tensors
=
"pt"
)
ort_inputs
=
{
k
:
v
.
cpu
().
detach
().
numpy
()
for
k
,
v
in
token
.
items
()}
elif
count
==
2
:
print
(
"Assume image and one text"
)
input1
=
Tokenizer
(
text_input
,
return_tensors
=
"pt"
)[
"input_ids"
].
squeeze
().
type
(
torch
.
float
)
ort_inputs
=
{
inputs
[
0
]:
img
,
inputs
[
1
]:
input1
}
elif
count
==
4
:
print
(
"Assume bert + one image input"
)
input1
=
Tokenizer
(
text_input
,
return_tensors
=
"pt"
)[
"input_ids"
].
cpu
().
detach
().
numpy
()
input2
=
Tokenizer
(
text_input
,
return_tensors
=
"pt"
)[
"token_type_ids"
].
cpu
().
detach
().
numpy
()
input3
=
Tokenizer
(
text_input
,
return_tensors
=
"pt"
)[
"attention_mask"
].
cpu
().
detach
().
numpy
()
ort_inputs
=
{
inputs
[
0
]:
img
,
inputs
[
1
]:
input1
,
inputs
[
2
]:
input2
,
inputs
[
3
]:
input3
}
else
:
print
(
"Assume only image input"
)
ort_inputs
=
{
inputs
[
0
]
:
img
}
ort_outs
=
self
.
ort_session
.
run
([
output_name
],
ort_inputs
)
return
ort_outs
...
...
@@ -72,6 +108,14 @@ class ONNXInterface:
else
:
image_tensor
=
image2tensor
(
image_path
)
logits
=
self
.
onnx_model_forward
(
image_tensor
,
text_input
)
if
list
(
torch
.
tensor
(
logits
).
size
())
!=
[
1
,
2
]:
if
self
.
defaultmodel
==
None
:
self
.
defaultmodel
=
MMBT
.
from_pretrained
(
"mmbt.hateful_memes.images"
)
logits
=
self
.
defaultmodel
.
classify
(
image_path
,
text_input
,
image_tensor
=
torch
.
squeeze
(
image_tensor
,
0
))
print
(
"The output of model is invalid, here use default output instead"
)
scores
=
nn
.
functional
.
softmax
(
torch
.
tensor
(
logits
),
dim
=
1
)
if
image_tensor
!=
None
:
...
...
@@ -97,7 +141,8 @@ def to_numpy(tensor):
if
__name__
==
"__main__"
:
model_path
=
"
model
.onnx"
model_path
=
"
Bert
.onnx"
tokenizers
=
[
"BertTokenizer"
,
"BertTokenizerFast"
,
"AutoTokenizer"
,
"XLNetTokenizer"
]
tokenizer
=
tokenizers
[
0
]
model
=
ONNXInterface
(
model_path
,
tokenizer
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment