Commit 25f75e6a authored by Mao, Bojia's avatar Mao, Bojia
Browse files

Torchray clear print

parent 8e43d94f
......@@ -559,28 +559,4 @@ def text_explanation_presentation(input_text, image_tensor, image_path, model):
print(judge[output["label"]], " confidence: ", output["confidence"])
# if __name__ == "__main__":
#
# from mmf.models.mmbt import MMBT
# from custom_mmbt import MMBTGridHMInterfaceOnlyImage
#
# text = "How I want to say hello to Asian people"
#
# model = MMBT.from_pretrained("mmbt.hateful_memes.images")
# model = model.to(torch.device(
# "cuda:0" if torch.cuda.is_available() else "cpu"))
#
# image_path = "test_img.jpeg"
# image_tensor = image2tensor(image_path)
#
# # if device has some error just comment it
# #image_tensor = image_tensor.to("cuda:0")
#
# _out, out, = multi_extremal_perturbation(model,
# image_tensor,
# image_path,
# text,
# 0,
# reward_func=contrastive_reward,
# debug=True,
# areas=[0.12])
......@@ -13,15 +13,36 @@ from mmf.models.mmbt import MMBT
class ONNXInterface:
def __init__(self, model_path, tokenizer=None):
    """
    Initialize the interface by loading an ONNX model and resolving an
    optional tokenizer name to a pretrained tokenizer instance.

    Parameters
    ----------
    model_path : str
        Path to the serialized ONNX model file.
    tokenizer : str, optional
        Name of the tokenizer class to instantiate, one of
        "BertTokenizer", "BertTokenizerFast", "AutoTokenizer",
        "XLNetTokenizer". None leaves the interface without a tokenizer.

    Raises
    ------
    NotImplementedError
        If *tokenizer* names an unsupported tokenizer class.
    """
    self.model = onnx.load(model_path)
    self.ort_session = onnxruntime.InferenceSession(model_path)
    # check_model() returns None and raises on an invalid model, so the
    # original `if not check_model(...): assert("Model file error")` was a
    # no-op (assert on a non-empty string never fails). Let the checker's
    # own exception propagate instead.
    onnx.checker.check_model(self.model)
    self.tokenizer = tokenizer
    self.defaultmodel = None  # lazily-built fallback MMBT model
    self.device = "cpu"
    if tokenizer is not None:
        # Dispatch table: tokenizer name -> (class, pretrained checkpoint).
        factories = {
            "BertTokenizer": (BertTokenizer, "bert-base-cased"),
            "BertTokenizerFast": (BertTokenizerFast, "bert-base-cased"),
            "AutoTokenizer": (AutoTokenizer, "bert-base-cased"),
            "XLNetTokenizer": (XLNetTokenizer, "xlnet-base-cased"),
        }
        if tokenizer not in factories:
            # The original `assert("NotImplementedError")` never fired;
            # fail loudly for unsupported names instead.
            raise NotImplementedError(
                "Please contact the development team to update")
        tok_cls, checkpoint = factories[tokenizer]
        self.tokenizer = tok_cls.from_pretrained(checkpoint)
def visualize(self):
    """Print a human-readable summary of the loaded ONNX graph."""
    graph_text = onnx.helper.printable_graph(self.model.graph)
    print(graph_text)
def onnx_model_forward(self, image_input,text_input):
......@@ -47,51 +68,32 @@ class ONNXInterface:
exit(0)
break
img = to_numpy(image_input)
if self.tokenizer != None:
if self.tokenizer == "BertTokenizer":
Tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
elif self.tokenizer == "BertTokenizerFast":
Tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
elif self.tokenizer == "AutoTokenizer":
Tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
elif self.tokenizer == "XLNetTokenizer":
Tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
else:
assert("NotImplementedError")
print("Please contact the development team to update")
Tokenizer = self.tokenizer
if count == 3:
print("Assume it is bert model with only text input")
tokens = Tokenizer(text_input, return_tensors="pt")
ort_inputs = {k: v.cpu().detach().numpy() for k, v in tokens.items()}
elif count == 2:
print("Assume image and one text")
input1 = Tokenizer(text_input, return_tensors="pt")["input_ids"].squeeze().type(torch.float)
ort_inputs = {inputs[0]: img, inputs[1]: input1}
elif count == 4:
print("Assume bert + one image input")
input1 = Tokenizer(text_input, return_tensors="pt")["input_ids"].cpu().detach().numpy()
input2 = Tokenizer(text_input, return_tensors="pt")["token_type_ids"].cpu().detach().numpy()
input3 = Tokenizer(text_input, return_tensors="pt")["attention_mask"].cpu().detach().numpy()
ort_inputs = {inputs[0]: img, inputs[1]: input1,inputs[2]: input2,inputs[3]: input3}
else:
print("Assume only image input")
ort_inputs = {inputs[0] : img}
ort_outs = self.ort_session.run([output_name], ort_inputs)
return ort_outs
def to(self,device):
self.device = device
......@@ -122,7 +124,7 @@ class ONNXInterface:
self.defaultmodel = MMBT.from_pretrained("mmbt.hateful_memes.images")
self.defaultmodel.to(self.device)
logits = self.defaultmodel.classify(image, text_input, image_tensor=torch.squeeze(image_tensor.to(self.device), 0))
print("The output of model is invalid, here use default output instead")
scores = nn.functional.softmax(torch.tensor(logits), dim=1)
......@@ -145,6 +147,9 @@ def image2tensor(image_path):
def to_numpy(tensor):
    """
    Convert a torch tensor to a numpy array.

    ``detach()`` is a cheap view operation and is safe whether or not the
    tensor requires grad, so the original ``requires_grad`` branch was
    redundant; ``.numpy()`` only fails on tensors still attached to the
    autograd graph, which ``detach()`` always resolves.
    """
    return tensor.detach().cpu().numpy()
......
......@@ -3,8 +3,7 @@ from PIL import Image
def torchray_multimodal_explain(image_name, text, model, target, max_iteration=800):
print(image_name)
print(text)
image_path = "static/" + image_name
image = Image.open(image_path)
......@@ -39,7 +38,7 @@ def torchray_multimodal_explain(image_name, text, model, target, max_iteration=8
name_split_list = image_name.split(".")
exp_image = name_split_list[0] + "_torchray_img." + name_split_list[1]
PIL_image.save("static/" + exp_image)
print(txt_summary)
direction = ["non-hateful", "hateful"]
img_summary = (
"The key area that leads to "
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.