Commit 683e9e5a authored by JunqiJiang's avatar JunqiJiang
Browse files

Merge branch 'master' of...

Merge branch 'master' of https://gitlab.doc.ic.ac.uk/g207004202/explainable-multimodal-classification
parents 569f60f1 a0439675
......@@ -559,28 +559,4 @@ def text_explanation_presentation(input_text, image_tensor, image_path, model):
print(judge[output["label"]], " confidence: ", output["confidence"])
# if __name__ == "__main__":
#
# from mmf.models.mmbt import MMBT
# from custom_mmbt import MMBTGridHMInterfaceOnlyImage
#
# text = "How I want to say hello to Asian people"
#
# model = MMBT.from_pretrained("mmbt.hateful_memes.images")
# model = model.to(torch.device(
# "cuda:0" if torch.cuda.is_available() else "cpu"))
#
# image_path = "test_img.jpeg"
# image_tensor = image2tensor(image_path)
#
# # if device has some error just comment it
# #image_tensor = image_tensor.to("cuda:0")
#
# _out, out, = multi_extremal_perturbation(model,
# image_tensor,
# image_path,
# text,
# 0,
# reward_func=contrastive_reward,
# debug=True,
# areas=[0.12])
......@@ -13,15 +13,36 @@ from mmf.models.mmbt import MMBT
class ONNXInterface:
def __init__(self, model_path, tokenizer=None):
    '''
    Initialize the interface by loading the ONNX model from model_path and,
    optionally, building a HuggingFace tokenizer for text preprocessing.

    Args:
        model_path: path to the serialized .onnx model file
        tokenizer: optional tokenizer name; one of "BertTokenizer",
            "BertTokenizerFast", "AutoTokenizer", "XLNetTokenizer"

    Raises:
        NotImplementedError: if an unsupported tokenizer name is given
    '''
    self.model = onnx.load(model_path)
    self.ort_session = onnxruntime.InferenceSession(model_path)
    # check_model() raises on an invalid model and returns None on success.
    # The original `if not check_model(...): assert("Model file error")` was a
    # double no-op: the condition was always true (None -> not None), and
    # asserting a non-empty string never fails. Let the checker raise directly.
    onnx.checker.check_model(self.model)
    self.tokenizer = tokenizer
    self.defaultmodel = None  # lazily-created MMBT fallback (see classify)
    self.device = "cpu"
    if tokenizer is not None:
        if tokenizer == "BertTokenizer":
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        elif tokenizer == "BertTokenizerFast":
            self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
        elif tokenizer == "AutoTokenizer":
            self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        elif tokenizer == "XLNetTokenizer":
            self.tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
        else:
            # Was `assert("NotImplementedError")`, which never fires; raise for real.
            raise NotImplementedError(
                "Unsupported tokenizer; please contact the development team to update"
            )
def visualize(self):
    '''
    Print a human-readable summary of the loaded ONNX graph structure.
    '''
    graph_text = onnx.helper.printable_graph(self.model.graph)
    print(graph_text)
def onnx_model_forward(self, image_input,text_input):
......@@ -47,53 +68,36 @@ class ONNXInterface:
exit(0)
break
img = to_numpy(image_input)
if self.tokenizer != None:
if self.tokenizer == "BertTokenizer":
Tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
elif self.tokenizer == "BertTokenizerFast":
Tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
elif self.tokenizer == "AutoTokenizer":
Tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
elif self.tokenizer == "XLNetTokenizer":
Tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
else:
assert("NotImplementedError")
print("Please contact the development team to update")
Tokenizer = self.tokenizer
if count == 3:
print("Assume it is bert model with only text input")
tokens = Tokenizer(text_input, return_tensors="pt")
ort_inputs = {k: v.cpu().detach().numpy() for k, v in token.items()}
ort_inputs = {k: v.cpu().detach().numpy() for k, v in tokens.items()}
elif count == 2:
print("Assume image and one text")
input1 = Tokenizer(text_input, return_tensors="pt")["input_ids"].squeeze().type(torch.float)
ort_inputs = {inputs[0]: img, inputs[1]: input1}
elif count == 4:
print("Assume bert + one image input")
input1 = Tokenizer(text_input, return_tensors="pt")["input_ids"].cpu().detach().numpy()
input2 = Tokenizer(text_input, return_tensors="pt")["token_type_ids"].cpu().detach().numpy()
input3 = Tokenizer(text_input, return_tensors="pt")["attention_mask"].cpu().detach().numpy()
ort_inputs = {inputs[0]: img, inputs[1]: input1,inputs[2]: input2,inputs[3]: input3}
else:
print("Assume only image input")
ort_inputs = {inputs[0] : img}
ort_outs = self.ort_session.run([output_name], ort_inputs)
return ort_outs
def classify(self,image_path,text_input, image_tensor = None):
def to(self, device):
    '''
    Record the target device string (e.g. "cpu" or "cuda:0"); it is used
    when falling back to the default model during classification.
    '''
    self.device = device
def classify(self,image,text_input, image_tensor = None):
'''
Args:
image_path: directory of input image
......@@ -103,22 +107,28 @@ class ONNXInterface:
Returns :
label of model prediction and the corresponding confidence
'''
scoreFlag = False
if image_tensor != None:
scoreFlag = True
logits = self.onnx_model_forward(image_tensor,text_input)
else:
image_tensor = image2tensor(image_path)
p = transforms.Compose([transforms.Scale((224,224))])
image,i = imsc(p(image),quiet=True)
image_tensor = torch.reshape(image, (1,3,224,224))
logits = self.onnx_model_forward(image_tensor,text_input)
if list(torch.tensor(logits).size()) != [1, 2]:
if self.defaultmodel == None:
self.defaultmodel = MMBT.from_pretrained("mmbt.hateful_memes.images")
logits = self.defaultmodel.classify(image_path, text_input, image_tensor=torch.squeeze(image_tensor, 0))
print("The output of model is invalid, here use default output instead")
self.defaultmodel.to(self.device)
logits = self.defaultmodel.classify(image, text_input, image_tensor=torch.squeeze(image_tensor.to(self.device), 0))
scores = nn.functional.softmax(torch.tensor(logits), dim=1)
if image_tensor != None:
if scoreFlag == True:
return scores
confidence, label = torch.max(scores, dim=1)
......@@ -137,6 +147,9 @@ def image2tensor(image_path):
def to_numpy(tensor):
    """
    Convert a torch tensor to a numpy array.

    Detaches the tensor from the autograd graph when it requires grad,
    then moves it to host memory before conversion.
    """
    host_tensor = tensor.detach() if tensor.requires_grad else tensor
    return host_tensor.cpu().numpy()
......
......@@ -14,7 +14,7 @@ def loadImage(img):
img = Image.open(requests.get(img, stream=True).raw)
else:
img = Image.open(img)
else:
else: # pragma: no cover
sys.exit("ERROR: Unsupported img type. Abort")
return img
......@@ -8,6 +8,12 @@ from mmxai.text_removal.image_loader import loadImage
class SmartTextRemover:
def __init__(self, detector_path):
    """
    Function to instantiate SmartTextRemover object.

    INPUTS:
        detector_path - str: Path to the text detector, a serialized network
            loaded through OpenCV's dnn module (presumably the EAST detector
            referenced elsewhere in this class — confirm with the model file).
    """
    # Name-mangled (private) handle to the loaded text-detection network.
    self.__detector = cv.dnn.readNet(detector_path)
@property
......@@ -16,19 +22,26 @@ class SmartTextRemover:
def inpaint(self, image, dilation=0.02, method=cv.INPAINT_TELEA):
"""
function to inpaint the text region of a image
return an image as PIL Image format
Function to inpaint the text in side an image
INPUTS:
image - PIL.Image or numpy.ndarray or str:
The Image object or pixel array or path to the image.
dilation - float: The amount of dilation to apply to the text mask.
method: The inpainting algorithm to used. cv.INPAINT_TELEA or cv.INPAINT_NS.
RETURNS:
PIL.Image: The inpainted image.
"""
image = loadImage(image).convert("RGB")
image_array = np.array(image)
vertices = self.getTextBoxes(image_array, show_boxes=False)
vertices = self.getTextBoxes(image_array, debug_show_boxes=False)
mask = self.generateTextMask(
vertices,
image_array.shape[0],
image_array.shape[1],
enlargement=dilation,
dilation=dilation,
)
impainted_image = cv.inpaint(image_array, mask, 15, method)
......@@ -45,14 +58,25 @@ class SmartTextRemover:
nms_threshold=0.4,
inp_width=320,
inp_height=320,
show_boxes=False,
debug_show_boxes=False,
):
"""
Returns a numpy arrays storeing the vertices of the text boxes
Helper function to detect text inside an image return the text box vertices
INPUTS:
img_PIL - numpy.ndarray: Image pixel array in RGB format.
conf_threshold - float: Confidence threshold.
nms_threshold - float: Non-maximum suppression threshold.
inp_width - float: Preprocess input image by resizing to a specific width.
inp_height - float: Preprocess input image by resizing to a specific height.
debug_show_boxes - bool: If true, will show the textboxes in a cv.window.
RETURNS:
numpy.ndarray - The text box vertices in shape of (:, 4, 2).
"""
# Create a new named window if required
if show_boxes:
if debug_show_boxes: # pragma: no cover
kWinName = "EAST: An Efficient and Accurate Scene Text Detector"
cv.namedWindow(kWinName, cv.WINDOW_NORMAL)
......@@ -106,7 +130,7 @@ class SmartTextRemover:
vertices_all[n, :, :] = vertices
# Add the boxes to current frame if required
if show_boxes:
if debug_show_boxes: # pragma: no cover
for j in range(4):
p1 = (vertices[j][0].astype(int), vertices[j][1].astype(int))
p2 = (
......@@ -116,7 +140,7 @@ class SmartTextRemover:
cv.line(img, p1, p2, (0, 255, 0), 1)
# Display the image frame if required
if show_boxes:
if debug_show_boxes: # pragma: no cover
cv.imshow(kWinName, img)
# waits for user to press any key
......@@ -127,9 +151,19 @@ class SmartTextRemover:
return vertices_all
def generateTextMask(self, vertices, img_height, img_width, enlargement=0.0):
def generateTextMask(self, vertices, img_height, img_width, dilation=0.0):
"""
Function to generate a mask of of the text boxes
Helper function to transform text box vertices to a mask
INPUTS:
vertices - numpy.ndarray:
The text box vertices in shape of (:, 4, 2).
img_height - int: Image height measured in number of pixels
img_width - int: Image width measured in number of pixels
dilation - float: The amount of dilation to apply to the text mask.
RETURNS:
numpy.ndarray: Text mask of shape (img_height, img_width).
"""
# Generate the mesh grid
......@@ -139,8 +173,8 @@ class SmartTextRemover:
x, y = np.meshgrid(x_range, y_range, sparse=True)
# Enlarge all the text boxes for more coverage
del_x = enlargement * img_width
del_y = enlargement * img_height
del_x = dilation * img_width
del_y = dilation * img_height
vertices[:, :2, 0] -= del_x
vertices[:, 2:, 0] += del_x
......@@ -184,8 +218,16 @@ class SmartTextRemover:
def findGradientAndYIntersec(self, p1, p2):
"""
Function to find the gradient and the y-intersection of a straight line
Helper function to find the gradient and the y-intersection of a straight line
using the two-point formula
INPUTS:
p1 - numpy.ndarray: coordinate of point 1. p1 in shape of (2,).
p2 - numpy.ndarray: coordinate of point 2. p2 in shape of (2,).
RETURNS:
float: gradient
float: y intersection
"""
gradient = (p2[1] - p1[1]) / (p2[0] - p1[0])
y_intersec = -gradient * p1[0] + p1[1]
......@@ -193,6 +235,20 @@ class SmartTextRemover:
return gradient, y_intersec
def isOnRight(self, p1, p2, x, y):
"""
Helper function to check if the points are on the right of a straight line
defined by the two-point formula
INPUTS:
p1 - numpy.ndarray: coordinate of point 1. p1 in shape of (2,).
p2 - numpy.ndarray: coordinate of point 2. p2 in shape of (2,).
x - numpy.ndarray: x coordinates of the points to be checked.
y - numpy.ndarray: y coordinates of the points to be checked.
(x.shape must == y.shape)
RETURNS:
np.ndarray: array of bools
"""
gradient, y_intersec = self.findGradientAndYIntersec(p1, p2)
def findXFromY(y):
......@@ -203,6 +259,20 @@ class SmartTextRemover:
return x > x_line
def isAbove(self, p1, p2, x, y):
"""
Helper function to check if the points are above a straight line
defined by the two-point formula
INPUTS:
p1 - numpy.ndarray: coordinate of point 1. p1 in shape of (2,).
p2 - numpy.ndarray: coordinate of point 2. p2 in shape of (2,).
x - numpy.ndarray: x coordinates of the points to be checked.
y - numpy.ndarray: y coordinates of the points to be checked.
(x.shape must == y.shape)
RETURNS:
np.ndarray: array of bools
"""
gradient, y_intersec = self.findGradientAndYIntersec(p1, p2)
y_line = gradient * x + y_intersec
......@@ -210,7 +280,12 @@ class SmartTextRemover:
def convertRgbToBgr(self, img_RGB: np.ndarray):
"""
helper function to convert image array from RGB format to BGR
Helper function to convert image array from RGB format to BGR
INPUT:
img_RGB - numpy.ndarray: image pixel array in RBG format
RETURNS:
numpy.ndarray: image pixel array in BGR format
"""
img_BGR = img_RGB[:, :, ::-1].copy()
......@@ -218,6 +293,11 @@ class SmartTextRemover:
return img_BGR
def decodeBoundingBoxes(self, scores, geometry, scoreThresh):
"""
Helper function to decode bounding boxes.
See https://github.com/opencv/opencv/blob/master/samples/dnn/text_detection.py
for the example usage
"""
detections = []
confidences = []
......@@ -280,7 +360,7 @@ class SmartTextRemover:
return [detections, confidences]
if __name__ == "__main__":
if __name__ == "__main__": # pragma: no cover
remover = SmartTextRemover("mmxai/text_removal/frozen_east_text_detection.pb")
img = remover.inpaint("https://www.iqmetrix.com/hubfs/Meme%2021.jpg")
img.show()
......@@ -58,6 +58,8 @@ numba==0.53.1
numpy==1.20.2
oauthlib==3.1.0
omegaconf==2.0.6
onnx==1.9.0
onnxruntime==1.7.0
opencv-python==4.5.1.48
packaging==20.9
pandas==1.2.3
......@@ -79,7 +81,6 @@ python-dateutil==2.8.1
pytorch-lightning==1.2.7
pytz==2021.1
PyWavelets==1.1.1
pywin32==300
PyYAML==5.3.1
regex==2021.4.4
requests==2.23.0
......
from numpy.__config__ import show
from numpy.lib.arraysetops import isin
from mmxai.text_removal.smart_text_removal import SmartTextRemover
from PIL import Image
......@@ -112,7 +110,7 @@ def testGenerateTextMaskWithMaskEnlargement():
[2.5, 0.5],
[2.6, 2.6]]])
mask = remover.generateTextMask(vertices, img_height, img_width, enlargement=0.2)
mask = remover.generateTextMask(vertices, img_height, img_width, dilation=0.2)
ideal_mask = np.zeros((10, 10), dtype=np.uint8)
ideal_mask[:5, :5] = 255
......
tests/web_app/01245_shap_img.png

161 KB | W: | H:

tests/web_app/01245_shap_img.png

160 KB | W: | H:

tests/web_app/01245_shap_img.png
tests/web_app/01245_shap_img.png
tests/web_app/01245_shap_img.png
tests/web_app/01245_shap_img.png
  • 2-up
  • Swipe
  • Onion skin
tests/web_app/01245_shap_txt.png

11.1 KB | W: | H:

tests/web_app/01245_shap_txt.png

11.9 KB | W: | H:

tests/web_app/01245_shap_txt.png
tests/web_app/01245_shap_txt.png
tests/web_app/01245_shap_txt.png
tests/web_app/01245_shap_txt.png
  • 2-up
  • Swipe
  • Onion skin
......@@ -304,7 +304,7 @@ def test_predict_model_built_error_handling(client):
def test_select_example(client):
url = '/selectExample'
data = {}
data["exampleID"] = "001"
data["exampleID"] = "10398"
response = client.post(url, data=data)
assert response.status_code == 302
......@@ -316,6 +316,6 @@ def test_fetch_example(client):
assert response.status_code == 302
with client.session_transaction() as session:
session["exampleID"] = "001"
session["exampleID"] = "10398"
response = client.post(url, data=data)
assert response.status_code == 302
......@@ -45,21 +45,17 @@ def before_request():
#ip_addr=request.headers['X-Real-Ip']
total_user = user_info.query.filter_by(ip_addr=ip_addr).all()
if len(total_user) > 3:
print("more than 3 !!!")
user_delete = user_info.query.filter_by(ip_addr=ip_addr).first()
db.session.delete(user_delete)
db.session.commit()
user_id = fm.generate_random_str(8)
fm.mkdir('./static/user/' + user_id)
print(user_id + "created")
session['user'] = user_id
file_name = session['user']
expired_time = datetime.now() + timedelta(days=1)
user_insert = user_info(file_name, ip_addr, expired_time)
db.session.add(user_insert)
db.session.commit()
else:
print(session.get('user') + " has existed")
@app.route("/")
......@@ -244,7 +240,6 @@ def inpaint():
inpainted_image_name = img_name_no_extension + "_inpainted" + img_extension
save_path = folder_path + inpainted_image_name
print(save_path)
if not os.path.isfile(save_path):
# Load the inpainter
try:
......@@ -330,9 +325,7 @@ def predict():
f"Your uploaded image and text combination "
f"looks like a <strong>{hateful}</strong> meme, with {cls_confidence * 100: .2f}% confidence. "
)
print(cls_result)
t0 = datetime.now()
try:
if exp_method == "shap":
text_exp, img_exp, txt_msg, img_msg = shap_mmf.shap_multimodal_explain(
......@@ -390,9 +383,6 @@ def predict():
)
return redirect(url_for("hateful_memes"))
elapsed = datetime.now() - t0
print(f"predicting using {exp_method} took {elapsed.seconds: .2f} seconds.")
session["clsResult"] = cls_result
session["imgText"] = img_text
session["textExp"] = text_exp
......@@ -401,12 +391,10 @@ def predict():
exp_text_visl, _ = os.path.splitext(img_exp)
exp_text_visl = exp_text_visl[:-4] + "_txt.png"
print(txt_msg, img_msg)
session["txtMsg"] = txt_msg
session["imgMsg"] = img_msg
try:
print(text_exp)
ut.text_visualisation(text_exp, cls_label, exp_text_visl)
session["textExp"] = exp_text_visl
except:
......@@ -416,11 +404,6 @@ def predict():
)
session["textExp"] = None
print(session["imgText"])
print(session["textExp"])
print(session["imgExp"])
print(session["modelPath"])
flash(
"Done! Hover over the output images to see how to interpret the results",
"success",
......
# -*- coding: UTF-8 -*-
from mmf.models.mmbt import MMBT
from mmf.models.fusions import LateFusion
from mmf.models.vilbert import ViLBERT
......@@ -12,6 +10,7 @@ import json
import re
from mmxai.onnx.onnxModel import ONNXInterface
class InputError(Exception):
def __init__(self, msg):
self.msg = msg
......@@ -31,9 +30,9 @@ def prepare_explanation(
try:
model = setup_model(user_model, model_type, model_path)
except InputError as e:
raise InputError(e.message()) from e # TODO: customise error classes
raise InputError(e.message()) from e
model = model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
img = Image.open(img_name)
try:
......@@ -51,8 +50,8 @@ def prepare_explanation(
return model, label_to_explain, label, conf
# get model output in binary classification format with 2 labels """
def model_output(cls_label, cls_confidence):
""" get model output in binary classification format with 2 labels """
out = np.zeros(2)
out[cls_label] = cls_confidence
out[1 - cls_label] = 1 - cls_confidence
......@@ -139,8 +138,6 @@ def text_visualisation(exp, pred_res, save_path):
else:
plt_title = "not hateful"
print(exp)
# handle different output formats from explainers
vals = []
names = []
......@@ -174,17 +171,6 @@ def text_visualisation(exp, pred_res, save_path):
def read_examples_metadata(path="static/examples/metadata.json"):
""" Function/Class description
Args:
param1:
Returns:
a dictionary
Raises:
KeyError: An example
"""
with open(path) as f:
s = json.load(f)
return s
......
......@@ -2,13 +2,19 @@ import os
import sqlite3
import random
from datetime import datetime
#This library includes functions for uploaded file management and database update
#Functions will be called by APScheduler in app.py
#Configuration can be found in config.py
"""
This library includes functions for uploaded file management and database update
Functions will be called by APScheduler in app.py
Configuration can be found in config.py
"""
#This method will generate a random string (default length 16), which is used as the name of a new user's directory
def generate_random_str(random_length=16):
"""
This method will generate a random string (default length 16), which is used as the name of a new user's directory
:param random_length: desired length of string
:return: randomly generated string
"""
random_str = ""
base_str = "ABCDEFGHIGKLMNOPQRSTUVWXYZabcdefghigklmnopqrstuvwxyz0123456789"
length = len(base_str) - 1