Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
g207004202
explainable-multimodal-classification
Commits
25f75e6a
Commit
25f75e6a
authored
May 11, 2021
by
Mao, Bojia
Browse files
Torchray clear print
parent
8e43d94f
Changes
3
Hide whitespace changes
Inline
Side-by-side
mmxai/interpretability/classification/torchray/extremal_perturbation/multimodal_extremal_perturbation.py
View file @
25f75e6a
...
...
@@ -559,28 +559,4 @@ def text_explanation_presentation(input_text, image_tensor, image_path, model):
print
(
judge
[
output
[
"label"
]],
" confidence: "
,
output
[
"confidence"
])
# if __name__ == "__main__":
#
# from mmf.models.mmbt import MMBT
# from custom_mmbt import MMBTGridHMInterfaceOnlyImage
#
# text = "How I want to say hello to Asian people"
#
# model = MMBT.from_pretrained("mmbt.hateful_memes.images")
# model = model.to(torch.device(
# "cuda:0" if torch.cuda.is_available() else "cpu"))
#
# image_path = "test_img.jpeg"
# image_tensor = image2tensor(image_path)
#
# # if device has some error just comment it
# #image_tensor = image_tensor.to("cuda:0")
#
# _out, out, = multi_extremal_perturbation(model,
# image_tensor,
# image_path,
# text,
# 0,
# reward_func=contrastive_reward,
# debug=True,
# areas=[0.12])
mmxai/onnx/onnxModel.py
View file @
25f75e6a
...
...
@@ -13,15 +13,36 @@ from mmf.models.mmbt import MMBT
class
ONNXInterface
:
def
__init__
(
self
,
model_path
,
tokenizer
=
None
):
'''
Initilize interface by rebuild model from model path and tokenizer
'''
self
.
model
=
onnx
.
load
(
model_path
)
self
.
ort_session
=
onnxruntime
.
InferenceSession
(
model_path
)
if
not
onnx
.
checker
.
check_model
(
self
.
model
):
assert
(
"Model file error"
)
self
.
tokenizer
=
tokenizer
self
.
defaultmodel
=
None
self
.
device
=
"cpu"
if
tokenizer
!=
None
:
if
tokenizer
==
"BertTokenizer"
:
self
.
tokenizer
=
BertTokenizer
.
from_pretrained
(
"bert-base-cased"
)
elif
tokenizer
==
"BertTokenizerFast"
:
self
.
tokenizer
=
BertTokenizerFast
.
from_pretrained
(
"bert-base-cased"
)
elif
tokenizer
==
"AutoTokenizer"
:
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"bert-base-cased"
)
elif
tokenizer
==
"XLNetTokenizer"
:
self
.
tokenizer
=
XLNetTokenizer
.
from_pretrained
(
"xlnet-base-cased"
)
else
:
assert
(
"NotImplementedError"
)
print
(
"Please contact the development team to update"
)
def
visualize
(
self
):
'''
visualize model structure
'''
print
(
onnx
.
helper
.
printable_graph
(
self
.
model
.
graph
))
def
onnx_model_forward
(
self
,
image_input
,
text_input
):
...
...
@@ -47,51 +68,32 @@ class ONNXInterface:
exit
(
0
)
break
img
=
to_numpy
(
image_input
)
if
self
.
tokenizer
!=
None
:
if
self
.
tokenizer
==
"BertTokenizer"
:
Tokenizer
=
BertTokenizer
.
from_pretrained
(
"bert-base-cased"
)
elif
self
.
tokenizer
==
"BertTokenizerFast"
:
Tokenizer
=
BertTokenizerFast
.
from_pretrained
(
"bert-base-cased"
)
elif
self
.
tokenizer
==
"AutoTokenizer"
:
Tokenizer
=
AutoTokenizer
.
from_pretrained
(
"bert-base-cased"
)
elif
self
.
tokenizer
==
"XLNetTokenizer"
:
Tokenizer
=
XLNetTokenizer
.
from_pretrained
(
"xlnet-base-cased"
)
else
:
assert
(
"NotImplementedError"
)
print
(
"Please contact the development team to update"
)
Tokenizer
=
self
.
tokenizer
if
count
==
3
:
print
(
"Assume it is bert model with only text input"
)
tokens
=
Tokenizer
(
text_input
,
return_tensors
=
"pt"
)
ort_inputs
=
{
k
:
v
.
cpu
().
detach
().
numpy
()
for
k
,
v
in
tokens
.
items
()}
elif
count
==
2
:
print
(
"Assume image and one text"
)
input1
=
Tokenizer
(
text_input
,
return_tensors
=
"pt"
)[
"input_ids"
].
squeeze
().
type
(
torch
.
float
)
ort_inputs
=
{
inputs
[
0
]:
img
,
inputs
[
1
]:
input1
}
elif
count
==
4
:
print
(
"Assume bert + one image input"
)
input1
=
Tokenizer
(
text_input
,
return_tensors
=
"pt"
)[
"input_ids"
].
cpu
().
detach
().
numpy
()
input2
=
Tokenizer
(
text_input
,
return_tensors
=
"pt"
)[
"token_type_ids"
].
cpu
().
detach
().
numpy
()
input3
=
Tokenizer
(
text_input
,
return_tensors
=
"pt"
)[
"attention_mask"
].
cpu
().
detach
().
numpy
()
ort_inputs
=
{
inputs
[
0
]:
img
,
inputs
[
1
]:
input1
,
inputs
[
2
]:
input2
,
inputs
[
3
]:
input3
}
else
:
print
(
"Assume only image input"
)
ort_inputs
=
{
inputs
[
0
]
:
img
}
ort_outs
=
self
.
ort_session
.
run
([
output_name
],
ort_inputs
)
return
ort_outs
def
to
(
self
,
device
):
self
.
device
=
device
...
...
@@ -122,7 +124,7 @@ class ONNXInterface:
self
.
defaultmodel
=
MMBT
.
from_pretrained
(
"mmbt.hateful_memes.images"
)
self
.
defaultmodel
.
to
(
self
.
device
)
logits
=
self
.
defaultmodel
.
classify
(
image
,
text_input
,
image_tensor
=
torch
.
squeeze
(
image_tensor
.
to
(
self
.
device
),
0
))
print
(
"The output of model is invalid, here use default output instead"
)
scores
=
nn
.
functional
.
softmax
(
torch
.
tensor
(
logits
),
dim
=
1
)
...
...
@@ -145,6 +147,9 @@ def image2tensor(image_path):
def to_numpy(tensor):
    """Convert a torch tensor to a numpy array.

    Tensors that participate in autograd are detached first, since
    ``.numpy()`` refuses to operate on a tensor that requires grad.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
...
...
web_app/interpretability4mmf/torchray_mmf.py
View file @
25f75e6a
...
...
@@ -3,8 +3,7 @@ from PIL import Image
def
torchray_multimodal_explain
(
image_name
,
text
,
model
,
target
,
max_iteration
=
800
):
print
(
image_name
)
print
(
text
)
image_path
=
"static/"
+
image_name
image
=
Image
.
open
(
image_path
)
...
...
@@ -39,7 +38,7 @@ def torchray_multimodal_explain(image_name, text, model, target, max_iteration=8
name_split_list
=
image_name
.
split
(
"."
)
exp_image
=
name_split_list
[
0
]
+
"_torchray_img."
+
name_split_list
[
1
]
PIL_image
.
save
(
"static/"
+
exp_image
)
print
(
txt_summary
)
direction
=
[
"non-hateful"
,
"hateful"
]
img_summary
=
(
"The key area that leads to "
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment