Commit 63ca1361 authored by cz1716's avatar cz1716
Browse files

Merge branch 'master' of gitlab.doc.ic.ac.uk:g207004202/explainable-multimodal-classification

parents 81a02308 0b423535
......@@ -559,28 +559,4 @@ def text_explanation_presentation(input_text, image_tensor, image_path, model):
print(judge[output["label"]], " confidence: ", output["confidence"])
# if __name__ == "__main__":
#
# from mmf.models.mmbt import MMBT
# from custom_mmbt import MMBTGridHMInterfaceOnlyImage
#
# text = "How I want to say hello to Asian people"
#
# model = MMBT.from_pretrained("mmbt.hateful_memes.images")
# model = model.to(torch.device(
# "cuda:0" if torch.cuda.is_available() else "cpu"))
#
# image_path = "test_img.jpeg"
# image_tensor = image2tensor(image_path)
#
# # if device has some error just comment it
# #image_tensor = image_tensor.to("cuda:0")
#
# _out, out, = multi_extremal_perturbation(model,
# image_tensor,
# image_path,
# text,
# 0,
# reward_func=contrastive_reward,
# debug=True,
# areas=[0.12])
......@@ -13,15 +13,36 @@ from mmf.models.mmbt import MMBT
class ONNXInterface:
def __init__(self, model_path, tokenizer=None):
    '''
    Initialize the interface by loading the ONNX model from *model_path*
    and resolving the requested tokenizer.

    Args:
        model_path: path to the serialized ONNX model file.
        tokenizer: optional tokenizer name; one of "BertTokenizer",
            "BertTokenizerFast", "AutoTokenizer" or "XLNetTokenizer".

    Raises:
        NotImplementedError: if *tokenizer* is an unsupported name.
    '''
    self.model = onnx.load(model_path)
    self.ort_session = onnxruntime.InferenceSession(model_path)
    # check_model() returns None and raises on an invalid model, so the old
    # `if not check_model(...): assert("Model file error")` branch was a
    # no-op (assert on a non-empty string never fails). Let it raise.
    onnx.checker.check_model(self.model)
    self.tokenizer = tokenizer
    self.defaultmodel = None
    self.device = "cpu"
    if tokenizer is not None:
        # Map tokenizer name -> (tokenizer class, pretrained checkpoint).
        supported = {
            "BertTokenizer": (BertTokenizer, "bert-base-cased"),
            "BertTokenizerFast": (BertTokenizerFast, "bert-base-cased"),
            "AutoTokenizer": (AutoTokenizer, "bert-base-cased"),
            "XLNetTokenizer": (XLNetTokenizer, "xlnet-base-cased"),
        }
        if tokenizer not in supported:
            # The original `assert("NotImplementedError")` never fired;
            # raise a real error so misconfiguration is visible.
            raise NotImplementedError(
                "Unsupported tokenizer: " + str(tokenizer)
                + ". Please contact the development team to update")
        tok_cls, checkpoint = supported[tokenizer]
        self.tokenizer = tok_cls.from_pretrained(checkpoint)
def visualize(self):
    '''
    Print a human-readable summary of the loaded ONNX model's graph.
    '''
    graph_repr = onnx.helper.printable_graph(self.model.graph)
    print(graph_repr)
def onnx_model_forward(self, image_input,text_input):
......@@ -47,55 +68,35 @@ class ONNXInterface:
exit(0)
break
img = to_numpy(image_input)
if self.tokenizer != None:
if self.tokenizer == "BertTokenizer":
Tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
elif self.tokenizer == "BertTokenizerFast":
Tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
elif self.tokenizer == "AutoTokenizer":
Tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
elif self.tokenizer == "XLNetTokenizer":
Tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
else:
assert("NotImplementedError")
print("Please contact the development team to update")
Tokenizer = self.tokenizer
if count == 3:
print("Assume it is bert model with only text input")
tokens = Tokenizer(text_input, return_tensors="pt")
ort_inputs = {k: v.cpu().detach().numpy() for k, v in tokens.items()}
elif count == 2:
print("Assume image and one text")
input1 = Tokenizer(text_input, return_tensors="pt")["input_ids"].squeeze().type(torch.float)
ort_inputs = {inputs[0]: img, inputs[1]: input1}
elif count == 4:
print("Assume bert + one image input")
input1 = Tokenizer(text_input, return_tensors="pt")["input_ids"].cpu().detach().numpy()
input2 = Tokenizer(text_input, return_tensors="pt")["token_type_ids"].cpu().detach().numpy()
input3 = Tokenizer(text_input, return_tensors="pt")["attention_mask"].cpu().detach().numpy()
ort_inputs = {inputs[0]: img, inputs[1]: input1,inputs[2]: input2,inputs[3]: input3}
else:
print("Assume only image input")
ort_inputs = {inputs[0] : img}
ort_outs = self.ort_session.run([output_name], ort_inputs)
return ort_outs
def to(self, device):
    '''
    Remember *device* for lazy model construction and move the fallback
    model there if it has already been instantiated.

    Args:
        device: device identifier (e.g. "cpu" or "cuda:0").
    '''
    self.device = device
    # `is not None` instead of `!= None` — identity check is the Python idiom.
    if self.defaultmodel is not None:
        self.defaultmodel.to(device)
def classify(self,image,text_input, image_tensor = None):
'''
Args:
......@@ -106,8 +107,10 @@ class ONNXInterface:
Returns :
label of model prediction and the corresponding confidence
'''
scoreFlag = False
if image_tensor != None:
scoreFlag = True
logits = self.onnx_model_forward(image_tensor,text_input)
else:
p = transforms.Compose([transforms.Scale((224,224))])
......@@ -119,12 +122,13 @@ class ONNXInterface:
if self.defaultmodel == None:
self.defaultmodel = MMBT.from_pretrained("mmbt.hateful_memes.images")
self.defaultmodel.to(self.device)
logits = self.defaultmodel.classify(image, text_input, image_tensor=torch.squeeze(image_tensor.to(self.device), 0))
print("The output of model is invalid, here use default output instead")
scores = nn.functional.softmax(torch.tensor(logits), dim=1)
if image_tensor != None:
if scoreFlag == True:
return scores
confidence, label = torch.max(scores, dim=1)
......@@ -143,6 +147,9 @@ def image2tensor(image_path):
def to_numpy(tensor):
    """
    convert torch tensor to numpy array
    """
    # Detach first when the tensor participates in autograd, since
    # .numpy() refuses to run on a tensor that requires grad.
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
......
tests/web_app/01245_shap_img.png

161 KB | W: | H:

tests/web_app/01245_shap_img.png

160 KB | W: | H:

tests/web_app/01245_shap_img.png
tests/web_app/01245_shap_img.png
tests/web_app/01245_shap_img.png
tests/web_app/01245_shap_img.png
  • 2-up
  • Swipe
  • Onion skin
tests/web_app/01245_shap_txt.png

11.1 KB | W: | H:

tests/web_app/01245_shap_txt.png

11.9 KB | W: | H:

tests/web_app/01245_shap_txt.png
tests/web_app/01245_shap_txt.png
tests/web_app/01245_shap_txt.png
tests/web_app/01245_shap_txt.png
  • 2-up
  • Swipe
  • Onion skin
......@@ -304,7 +304,7 @@ def test_predict_model_built_error_handling(client):
def test_select_example(client):
    """POST a valid example ID to /selectExample and expect a redirect (302)."""
    url = '/selectExample'
    data = {}
    # The original assigned "001" and then immediately overwrote it with
    # "10398" (dead store); keep only the value that actually took effect.
    data["exampleID"] = "10398"
    response = client.post(url, data=data)
    assert response.status_code == 302
......@@ -316,6 +316,6 @@ def test_fetch_example(client):
assert response.status_code == 302
with client.session_transaction() as session:
session["exampleID"] = "001"
session["exampleID"] = "10398"
response = client.post(url, data=data)
assert response.status_code == 302
......@@ -45,21 +45,17 @@ def before_request():
#ip_addr=request.headers['X-Real-Ip']
total_user = user_info.query.filter_by(ip_addr=ip_addr).all()
if len(total_user) > 3:
print("more than 3 !!!")
user_delete = user_info.query.filter_by(ip_addr=ip_addr).first()
db.session.delete(user_delete)
db.session.commit()
user_id = fm.generate_random_str(8)
fm.mkdir('./static/user/' + user_id)
print(user_id + "created")
session['user'] = user_id
file_name = session['user']
expired_time = datetime.now() + timedelta(days=1)
user_insert = user_info(file_name, ip_addr, expired_time)
db.session.add(user_insert)
db.session.commit()
else:
print(session.get('user') + " has existed")
@app.route("/")
......@@ -244,7 +240,6 @@ def inpaint():
inpainted_image_name = img_name_no_extension + "_inpainted" + img_extension
save_path = folder_path + inpainted_image_name
print(save_path)
if not os.path.isfile(save_path):
# Load the inpainter
try:
......@@ -330,9 +325,7 @@ def predict():
f"Your uploaded image and text combination "
f"looks like a <strong>{hateful}</strong> meme, with {cls_confidence * 100: .2f}% confidence. "
)
print(cls_result)
t0 = datetime.now()
try:
if exp_method == "shap":
text_exp, img_exp, txt_msg, img_msg = shap_mmf.shap_multimodal_explain(
......@@ -390,9 +383,6 @@ def predict():
)
return redirect(url_for("hateful_memes"))
elapsed = datetime.now() - t0
print(f"predicting using {exp_method} took {elapsed.seconds: .2f} seconds.")
session["clsResult"] = cls_result
session["imgText"] = img_text
session["textExp"] = text_exp
......@@ -401,12 +391,10 @@ def predict():
exp_text_visl, _ = os.path.splitext(img_exp)
exp_text_visl = exp_text_visl[:-4] + "_txt.png"
print(txt_msg, img_msg)
session["txtMsg"] = txt_msg
session["imgMsg"] = img_msg
try:
print(text_exp)
ut.text_visualisation(text_exp, cls_label, exp_text_visl)
session["textExp"] = exp_text_visl
except:
......@@ -416,11 +404,6 @@ def predict():
)
session["textExp"] = None
print(session["imgText"])
print(session["textExp"])
print(session["imgExp"])
print(session["modelPath"])
flash(
"Done! Hover over the output images to see how to interpret the results",
"success",
......
# -*- coding: UTF-8 -*-
from mmf.models.mmbt import MMBT
from mmf.models.fusions import LateFusion
from mmf.models.vilbert import ViLBERT
......@@ -12,6 +10,7 @@ import json
import re
from mmxai.onnx.onnxModel import ONNXInterface
class InputError(Exception):
    """Raised when user-supplied model input is invalid."""

    def __init__(self, msg):
        # Forward the message to Exception so str(e) and tracebacks show it
        # (the original stored it only on self.msg, leaving str(e) empty).
        super().__init__(msg)
        self.msg = msg

    def message(self):
        """Return the error message; call sites use `e.message()`."""
        return self.msg
......@@ -31,9 +30,9 @@ def prepare_explanation(
try:
model = setup_model(user_model, model_type, model_path)
except InputError as e:
raise InputError(e.message()) from e # TODO: customise error classes
raise InputError(e.message()) from e
model = model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
img = Image.open(img_name)
try:
......@@ -51,8 +50,8 @@ def prepare_explanation(
return model, label_to_explain, label, conf
# get model output in binary classification format with 2 labels """
def model_output(cls_label, cls_confidence):
""" get model output in binary classification format with 2 labels """
out = np.zeros(2)
out[cls_label] = cls_confidence
out[1 - cls_label] = 1 - cls_confidence
......@@ -139,8 +138,6 @@ def text_visualisation(exp, pred_res, save_path):
else:
plt_title = "not hateful"
print(exp)
# handle different output formats from explainers
vals = []
names = []
......@@ -174,17 +171,6 @@ def text_visualisation(exp, pred_res, save_path):
def read_examples_metadata(path="static/examples/metadata.json"):
    """ Load the examples metadata JSON file.

    Args:
        path: path to the metadata JSON file.

    Returns:
        the deserialized JSON content (a dictionary for the default file)

    Raises:
        FileNotFoundError: if *path* does not exist
        json.JSONDecodeError: if the file is not valid JSON
    """
    # Explicit encoding so behaviour does not depend on the platform default.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
......
......@@ -2,13 +2,19 @@ import os
import sqlite3
import random
from datetime import datetime
#This library includes functions for uploaded file management and database update
#Functions will be called by APScheduler in app.py
#Configuration can be found in config.py
"""
This library includes functions for uploaded file management and database update
Functions will be called by APScheduler in app.py
Configuration can be found in config.py
"""
#This method will generate a random string (default length 16), which is used as the name of a new user's directory
def generate_random_str(random_length=16):
"""
This method will generate a random string (default length 16), which is used as the name of a new user's directory
:param random_length: desired length of string
:return: randomly generated string
"""
random_str = ""
base_str = "ABCDEFGHIGKLMNOPQRSTUVWXYZabcdefghigklmnopqrstuvwxyz0123456789"
length = len(base_str) - 1
......@@ -16,33 +22,43 @@ def generate_random_str(random_length=16):
random_str += base_str[random.randint(0, length)]
return random_str
#This function can create a new directory which path is passed in with parameter "path"
def mkdir(path):
    """
    This function can create a new directory
    :param path: file path
    :return: boolean: whether the directory has been created
    """
    path = path.strip()
    path = path.rstrip("\\")
    # EAFP: attempting the creation and catching FileExistsError avoids the
    # check-then-act race the old os.path.exists() guard had.
    try:
        os.makedirs(path)
    except FileExistsError:
        print(path + " Directory has already existed!")
        return False
    print(path + " Directory created successfully")
    return True
#This method will delete all files under "path"
def clean(path):
    """
    This method will delete all files under "path"
    :param path: file path
    :return:
    """
    # First empty every subdirectory, then remove the (now empty) entries.
    clear_dir(path)
    for entry in os.listdir(path):
        os.rmdir(os.path.join(path, entry))
    print("clean finished")
#This method will delete all files and directories under "path", then delete the parent directory.
def clear_dir(path):
"""
This method will delete all files and directories under "path", then delete the parent directory.
:param path: file path
:return:
"""
isExists = os.path.exists(path)
if not isExists:
print("no such dir")
return
else:
for i in os.listdir(path):
......@@ -55,12 +71,14 @@ def clear_dir(path):
if os.path.isfile(path_file2):
os.remove(path_file2)
#Periodic method called by APScheduler in app.py.
#It can check all user information in our local database and delete all expired information.
#Additionally, check the "static/user" directory to erase expired directories as well.
def check_database(object):
"""
Periodic method called by APScheduler in app.py.
It can check all user information in our local database and delete all expired information.
Additionally, check the "static/user" directory to erase expired directories as well.
"""
db = 'user_info.sqlite3'
#db = 'sqlite:///user_info.sqlite3'
con = sqlite3.connect(db)
cur = con.cursor()
select_sql = "select file_name,expired_time,id,ip_addr from user_info"
......@@ -82,8 +100,6 @@ def check_database(object):
for dir in os.listdir("./static/user"):
flag = 0
for row in date_set:
print(row)
print(type(row[2]))
if dir == row[0] and datetime.now() < datetime.strptime(row[1], "%Y-%m-%d %H:%M:%S.%f"):
flag = 1
if flag == 0:
......
......@@ -3,8 +3,7 @@ from PIL import Image
def torchray_multimodal_explain(image_name, text, model, target, max_iteration=800):
print(image_name)
print(text)
image_path = "static/" + image_name
image = Image.open(image_path)
......@@ -39,7 +38,7 @@ def torchray_multimodal_explain(image_name, text, model, target, max_iteration=8
name_split_list = image_name.split(".")
exp_image = name_split_list[0] + "_torchray_img." + name_split_list[1]
PIL_image.save("static/" + exp_image)
print(txt_summary)
direction = ["non-hateful", "hateful"]
img_summary = (
"The key area that leads to "
......
......@@ -30,7 +30,7 @@
"imgTexts": "when Japan thinks you are gonna invade their mainland",
"clsResult": "Your uploaded image and text combination looks like a <strong>HATEFUL</strong> meme, with 94.79% confidence.",
"shap": {
"modelType": "MMBT",
"modelType": "LateFusion",
"imgExp": "examples/40375/40375_shap_img.png",
"txtExp": "examples/40375/40375_shap_txt.png",
"imgMsg": "<p><strong>tl;dr:</strong><br> <span style=\"color: red\">Red</span> (<span style=\"color: blue\">Blue</span>) regions move the model output towards Hateful (Non-hateful).</p><p><strong>Details</strong>:<br>The input image is segmented into 52 regions, and text string is split into 9 features. The shapley values of those 61 features represent their additive contributions towards the model output for the current inclination selected, 0.948, on top of the base value. The base value, 0.0374, is the expected model output without those features. The sum of all shapley values and the base value should equate the selected model output, i.e.</p><p><em>model_output = base_value + total_image_shapley_values + total_text_shapley_values</em>.</p><p>The sum of shapley values for the image features is 0.5587.</p><span style=\"font-size: 0.75rem\">*note that the results may change slightly if the number of evaluations is small due to random sampling in the algorithm</span>",
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment