Commit 81a02308 authored by cz1716's avatar cz1716
Browse files

add doc strings

parent 28a62442
......@@ -8,6 +8,12 @@ from mmxai.text_removal.image_loader import loadImage
class SmartTextRemover:
def __init__(self, detector_path):
"""
Function to instantiate SmartTextRemover object.
INPUTS:
detector_path - str: Path to the text detector.
"""
self.__detector = cv.dnn.readNet(detector_path)
@property
......@@ -16,19 +22,26 @@ class SmartTextRemover:
def inpaint(self, image, dilation=0.02, method=cv.INPAINT_TELEA):
"""
function to inpaint the text region of a image
return an image as PIL Image format
Function to inpaint the text inside an image
INPUTS:
image - PIL.Image or numpy.ndarray or str:
The Image object or pixel array or path to the image.
dilation - float: The amount of dilation to apply to the text mask.
method: The inpainting algorithm to use, either cv.INPAINT_TELEA or cv.INPAINT_NS.
RETURNS:
PIL.Image: The inpainted image.
"""
image = loadImage(image).convert("RGB")
image_array = np.array(image)
vertices = self.getTextBoxes(image_array, show_boxes=False)
vertices = self.getTextBoxes(image_array, debug_show_boxes=False)
mask = self.generateTextMask(
vertices,
image_array.shape[0],
image_array.shape[1],
enlargement=dilation,
dilation=dilation,
)
impainted_image = cv.inpaint(image_array, mask, 15, method)
......@@ -45,14 +58,25 @@ class SmartTextRemover:
nms_threshold=0.4,
inp_width=320,
inp_height=320,
show_boxes=False,
debug_show_boxes=False,
):
"""
Returns a numpy arrays storeing the vertices of the text boxes
Helper function to detect text inside an image and return the text box vertices
INPUTS:
img_PIL - numpy.ndarray: Image pixel array in RGB format.
conf_threshold - float: Confidence threshold.
nms_threshold - float: Non-maximum suppression threshold.
inp_width - float: Preprocess input image by resizing to a specific width.
inp_height - float: Preprocess input image by resizing to a specific height.
debug_show_boxes - bool: If true, will show the textboxes in a cv.window.
RETURNS:
numpy.ndarray - The text box vertices in shape of (:, 4, 2).
"""
# Create a new named window if required
if show_boxes: # pragma: no cover
if debug_show_boxes: # pragma: no cover
kWinName = "EAST: An Efficient and Accurate Scene Text Detector"
cv.namedWindow(kWinName, cv.WINDOW_NORMAL)
......@@ -106,7 +130,7 @@ class SmartTextRemover:
vertices_all[n, :, :] = vertices
# Add the boxes to current frame if required
if show_boxes: # pragma: no cover
if debug_show_boxes: # pragma: no cover
for j in range(4):
p1 = (vertices[j][0].astype(int), vertices[j][1].astype(int))
p2 = (
......@@ -116,7 +140,7 @@ class SmartTextRemover:
cv.line(img, p1, p2, (0, 255, 0), 1)
# Display the image frame if required
if show_boxes: # pragma: no cover
if debug_show_boxes: # pragma: no cover
cv.imshow(kWinName, img)
# waits for user to press any key
......@@ -127,9 +151,19 @@ class SmartTextRemover:
return vertices_all
def generateTextMask(self, vertices, img_height, img_width, enlargement=0.0):
def generateTextMask(self, vertices, img_height, img_width, dilation=0.0):
"""
Function to generate a mask of of the text boxes
Helper function to transform text box vertices to a mask
INPUTS:
vertices - numpy.ndarray:
The text box vertices in shape of (:, 4, 2).
img_height - int: Image height measured in number of pixels
img_width - int: Image width measured in number of pixels
dilation - float: The amount of dilation to apply to the text mask.
RETURNS:
numpy.ndarray: Text mask of shape (img_height, img_width).
"""
# Generate the mesh grid
......@@ -139,8 +173,8 @@ class SmartTextRemover:
x, y = np.meshgrid(x_range, y_range, sparse=True)
# Enlarge all the text boxes for more coverage
del_x = enlargement * img_width
del_y = enlargement * img_height
del_x = dilation * img_width
del_y = dilation * img_height
vertices[:, :2, 0] -= del_x
vertices[:, 2:, 0] += del_x
......@@ -184,8 +218,16 @@ class SmartTextRemover:
def findGradientAndYIntersec(self, p1, p2):
"""
Function to find the gradient and the y-intersection of a straight line
Helper function to find the gradient and the y-intersection of a straight line
using the two-point formula
INPUTS:
p1 - numpy.ndarray: coordinate of point 1. p1 in shape of (2,).
p2 - numpy.ndarray: coordinate of point 2. p2 in shape of (2,).
RETURNS:
float: gradient
float: y intersection
"""
gradient = (p2[1] - p1[1]) / (p2[0] - p1[0])
y_intersec = -gradient * p1[0] + p1[1]
......@@ -193,6 +235,20 @@ class SmartTextRemover:
return gradient, y_intersec
def isOnRight(self, p1, p2, x, y):
"""
Helper function to check if the points are on the right of a straight line
defined by the two-point formula
INPUTS:
p1 - numpy.ndarray: coordinate of point 1. p1 in shape of (2,).
p2 - numpy.ndarray: coordinate of point 2. p2 in shape of (2,).
x - numpy.ndarray: x coordinates of the points to be checked.
y - numpy.ndarray: y coordinates of the points to be checked.
(x.shape must == y.shape)
RETURNS:
np.ndarray: array of bools
"""
gradient, y_intersec = self.findGradientAndYIntersec(p1, p2)
def findXFromY(y):
......@@ -203,6 +259,20 @@ class SmartTextRemover:
return x > x_line
def isAbove(self, p1, p2, x, y):
"""
Helper function to check if the points are above a straight line
defined by the two-point formula
INPUTS:
p1 - numpy.ndarray: coordinate of point 1. p1 in shape of (2,).
p2 - numpy.ndarray: coordinate of point 2. p2 in shape of (2,).
x - numpy.ndarray: x coordinates of the points to be checked.
y - numpy.ndarray: y coordinates of the points to be checked.
(x.shape must == y.shape)
RETURNS:
np.ndarray: array of bools
"""
gradient, y_intersec = self.findGradientAndYIntersec(p1, p2)
y_line = gradient * x + y_intersec
......@@ -210,7 +280,12 @@ class SmartTextRemover:
def convertRgbToBgr(self, img_RGB: np.ndarray):
"""
helper function to convert image array from RGB format to BGR
Helper function to convert image array from RGB format to BGR
INPUT:
img_RGB - numpy.ndarray: image pixel array in RGB format
RETURNS:
numpy.ndarray: image pixel array in BGR format
"""
img_BGR = img_RGB[:, :, ::-1].copy()
......@@ -218,6 +293,11 @@ class SmartTextRemover:
return img_BGR
def decodeBoundingBoxes(self, scores, geometry, scoreThresh):
"""
Helper function to decode bounding boxes.
See https://github.com/opencv/opencv/blob/master/samples/dnn/text_detection.py
for the example usage
"""
detections = []
confidences = []
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment