Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
g207004202
explainable-multimodal-classification
Commits
1b41bc63
Commit
1b41bc63
authored
May 11, 2021
by
Mao, Bojia
Browse files
recorrect onnx to device
parent
e62f88a2
Changes
1
Hide whitespace changes
Inline
Side-by-side
mmxai/onnx/onnxModel.py
View file @
1b41bc63
...
...
@@ -11,7 +11,7 @@ from PIL import Image
from
transformers
import
BertTokenizer
,
AutoTokenizer
,
BertTokenizerFast
,
XLNetTokenizer
from
mmf.models.mmbt
import
MMBT
class
ONNXInterface
(
torch
.
nn
.
Module
)
:
class
ONNXInterface
:
def
__init__
(
self
,
model_path
,
tokenizer
=
None
):
self
.
model
=
onnx
.
load
(
model_path
)
self
.
ort_session
=
onnxruntime
.
InferenceSession
(
model_path
)
...
...
@@ -92,7 +92,9 @@ class ONNXInterface(torch.nn.Module):
ort_outs
=
self
.
ort_session
.
run
([
output_name
],
ort_inputs
)
return
ort_outs
def
to
(
self
,
decive
):
if
self
.
defaultmodel
!=
None
:
self
.
defaultmodel
.
to
(
device
)
def
classify
(
self
,
image
,
text_input
,
image_tensor
=
None
):
'''
Args:
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment