Skip to content

Commit c0f8c21

Browse files
committed
refactor code and add gradio
1 parent 11ed6fc commit c0f8c21

16 files changed

+477
-53
lines changed

LICENSE

Whitespace-only changes.

asset/demo.jpg

80.7 KB
Loading

asset/demo2.jpeg

59.6 KB
Loading

asset/demo3.jpeg

96.6 KB
Loading

asset/demo4.jpeg

417 KB
Loading

asset/demo5.jpeg

599 KB
Loading

asset/demo_output.jpg

85.6 KB
Loading

demo/demo.py

+28-53
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,29 @@
1+
import argparse
12
import os
2-
3-
from dds_cloudapi_sdk import Config
4-
from dds_cloudapi_sdk import Client
5-
from dds_cloudapi_sdk import DetectionTask
6-
from dds_cloudapi_sdk import TextPrompt
7-
from dds_cloudapi_sdk import DetectionModel
8-
from dds_cloudapi_sdk import DetectionTarget
9-
10-
# Step 1: initialize the config
11-
token = os.getenv("DDS_API")
12-
config = Config(token)
13-
14-
# Step 2: initialize the client
15-
client = Client(config)
16-
17-
# Step 3: run the task by DetectionTask class
18-
image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
19-
# if you are processing local image file, upload them to DDS server to get the image url
20-
# image_url = client.upload_file("/path/to/your/prompt/image.png")
21-
22-
task = DetectionTask(
23-
image_url=image_url,
24-
prompts=[TextPrompt(text="iron man")],
25-
targets=[DetectionTarget.Mask, DetectionTarget.BBox], # detect both bbox and mask
26-
model=DetectionModel.GDino1_5_Pro, # detect with GroundingDino-1.5-Pro model.
27-
# Available models: []
28-
)
29-
30-
client.run_task(task)
31-
result = task.result
32-
33-
print(result.mask_url)
34-
35-
objects = result.objects # the list of detected objects
36-
for idx, obj in enumerate(objects):
37-
print(obj.score) # 0.42
38-
39-
print(obj.category) # "iron man"
40-
41-
print(obj.bbox) # [635.0, 458.0, 704.0, 508.0]
42-
43-
print(
44-
obj.mask.counts
45-
) # RLE compressed to string, ]o`f08fa14M3L2O2M2O1O1O1O1N2O1N2O1N2N3M2O3L3M3N2M2N3N1N2O...
46-
47-
# convert the RLE format to RGBA image
48-
mask_image = task.rle2rgba(obj.mask)
49-
print(mask_image.size) # (1600, 1170)
50-
51-
# save the image to file
52-
mask_image.save(f"mask_{idx}.png")
53-
54-
break
3+
from gdino import GroundingDINOAPIWrapper, visualize
4+
from PIL import Image
5+
import numpy as np
6+
7+
8+
def get_args():
9+
parser = argparse.ArgumentParser(description="Interactive Inference")
10+
parser.add_argument(
11+
"--token",
12+
type=str,
13+
help="The token for T-Rex2 API. We are now opening free API access to T-Rex2",
14+
)
15+
parser.add_argument(
16+
"--box_threshold", type=float, default=0.3, help="The threshold for box score"
17+
)
18+
return parser.parse_args()
19+
20+
if __name__ == "__main__":
21+
args = get_args()
22+
gdino = GroundingDINOAPIWrapper(args.token)
23+
prompts = dict(image='asset/demo.jpg', prompt='person.pigeon.tree')
24+
results = gdino.inference(prompts)
25+
# now visualize the results
26+
image_pil = Image.open(prompts['image'])
27+
image_pil = visualize(image_pil, results)
28+
# dump the image to the disk
29+
image_pil.save('asset/demo_output.jpg')

gdino/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .model_wrapper import GroundingDINOAPIWrapper
2+
from .visualize import visualize
3+
4+
__all__ = ["GroundingDINOAPIWrapper", "visualize"]
3.09 KB
Binary file not shown.

gdino/model_wrapper.py

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import tempfile
2+
from typing import Dict, List, Union
3+
import numpy as np
4+
from dds_cloudapi_sdk import (
5+
DetectionTask,
6+
Client,
7+
Config,
8+
TextPrompt,
9+
DetectionModel,
10+
DetectionTarget,
11+
)
12+
from PIL import Image
13+
import concurrent.futures
14+
15+
class GroundingDINOAPIWrapper:
16+
"""API wrapper for Grounding DINO 1.5
17+
18+
Args:
19+
token (str): The token for Grounding DINO 1.5 API. We are now opening free API access to Grounding DINO 1.5. For
20+
educators, students, and researchers, we offer an API with extensive usage times to
21+
support your educational and research endeavors. You can get free API token at here:
22+
https://deepdataspace.com/request_api
23+
24+
"""
25+
26+
def __init__(self, token: str):
27+
self.client = Client(Config(token=token))
28+
29+
def inference(self, prompt: Dict, return_mask:bool=False):
30+
"""Main inference function of Grounding DINO 1.5. We take batch as input and
31+
each image is a dict. N. We do not support batch inference for now.
32+
33+
Args:
34+
prompts (dict): Annotations with the following keys:
35+
- "image" (str): Path to image. E.g. "test1.jpg",
36+
- "prompt" (str): Text prompt sepearted by '.' E.g. 'cate1 . cate2 . cate3'
37+
return_mask (bool): Whether to return mask. Defaults to False.
38+
39+
Returns:
40+
(Dict): Detection results in dict format with keys::
41+
- "scores": (List[float]): A list of scores for each object in the batch
42+
- "labels": (List[int]): A list of labels for each object in the batch
43+
- "boxes": (List[List[int]]): A list of boxes for each object in the batch,
44+
in format [xmin, ymin, xmax, ymax]
45+
- "masks": (List[np.ndarray]): A list of segmentations for each object in the batch
46+
"""
47+
# construct input prompts
48+
image=self.get_image_url(prompt["image"]),
49+
task=DetectionTask(
50+
image_url=image[0],
51+
prompts=[TextPrompt(text=prompt['prompt'])],
52+
targets=[DetectionTarget.Mask, DetectionTarget.BBox] if return_mask else [DetectionTarget.BBox],
53+
model=DetectionModel.GDino1_5_Pro,
54+
)
55+
self.client.run_task(task)
56+
result = task.result
57+
return self.postprocess(result, task, return_mask)
58+
59+
60+
def postprocess(self, result, task, return_mask):
61+
"""Postprocess the result from the API call
62+
63+
Args:
64+
result (TaskResult): Task result with the following keys:
65+
- objects (List[DetectionObject]): Each DetectionObject has the following keys:
66+
- bbox (List[float]): Box in xyxy format
67+
- category (str): Detection category
68+
- score (float): Detection score
69+
- mask (DetectionObjectMask): Use mask.counts to parse RLE mask
70+
task (DetectionTask): The task object
71+
return_mask (bool): Whether to return mask
72+
73+
Returns:
74+
(Dict): Return dict in format:
75+
{
76+
"scores": (List[float]): A list of scores for each object
77+
"categorys": (List[str]): A list of categorys for each object
78+
"boxes": (List[List[int]]): A list of boxes for each object
79+
"masks": (List[PIL.Image]): A list of masks in the format of PIL.Image
80+
}
81+
"""
82+
def process_object_with_mask(object):
83+
box = object.bbox
84+
score = object.score
85+
category = object.category
86+
mask = task.rle2rgba(object.mask)
87+
return box, score, category, mask
88+
89+
def process_object_without_mask(object):
90+
box = object.bbox
91+
score = object.score
92+
category = object.category
93+
mask = None
94+
return box, score, category, mask
95+
96+
boxes, scores, categorys, masks = [], [], [], []
97+
with concurrent.futures.ThreadPoolExecutor() as executor:
98+
if return_mask:
99+
process_object = process_object_with_mask
100+
else:
101+
process_object = process_object_without_mask
102+
futures = [executor.submit(process_object, obj) for obj in result.objects]
103+
for future in concurrent.futures.as_completed(futures):
104+
box, score, category, mask = future.result()
105+
boxes.append(box)
106+
scores.append(score)
107+
categorys.append(category)
108+
if mask is not None:
109+
masks.append(mask)
110+
111+
return dict(boxes=boxes, categorys=categorys, scores=scores, masks=masks)
112+
113+
def get_image_url(self, image: Union[str, np.ndarray]):
114+
"""Upload Image to server and return the url
115+
116+
Args:
117+
image (Union[str, np.ndarray]): The image to upload. Can be a file path or np.ndarray.
118+
If it is a np.ndarray, it will be saved to a temporary file.
119+
120+
Returns:
121+
str: The url of the image
122+
"""
123+
if isinstance(image, str):
124+
url = self.client.upload_file(image)
125+
else:
126+
with tempfile.NamedTemporaryFile(delete=True, suffix=".png") as tmp_file:
127+
# image is in numpy format, convert to PIL Image
128+
image = Image.fromarray(image)
129+
image.save(tmp_file, format="PNG")
130+
tmp_file_path = tmp_file.name
131+
url = self.client.upload_file(tmp_file_path)
132+
return url

gdino/version.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = 'v1.5'

gdino/visualize.py

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from typing import Dict
2+
3+
import numpy as np
4+
from PIL import Image, ImageDraw, ImageFont, ImageOps
5+
import random
6+
7+
8+
def draw_mask(mask, draw, random_color=True):
9+
"""Draws a mask with a specified color on an image.
10+
11+
Args:
12+
mask (np.array): Binary mask as a NumPy array.
13+
draw (ImageDraw.Draw): ImageDraw object to draw on the image.
14+
random_color (bool): Whether to use a random color for the mask.
15+
"""
16+
if random_color:
17+
color = (
18+
random.randint(0, 255),
19+
random.randint(0, 255),
20+
random.randint(0, 255),
21+
153,
22+
)
23+
else:
24+
color = (30, 144, 255, 153)
25+
26+
nonzero_coords = np.transpose(np.nonzero(mask))
27+
28+
for coord in nonzero_coords:
29+
draw.point(coord[::-1], fill=color)
30+
31+
def visualize(image_pil: Image,
32+
result: Dict,
33+
draw_width: float = 6.0,
34+
return_mask=True,
35+
draw_score=True) -> Image:
36+
"""Plot bounding boxes and labels on an image.
37+
38+
Args:
39+
image_pil (PIL.Image): The input image as a PIL Image object.
40+
result (Dict[str, Union[torch.Tensor, List[torch.Tensor]]]): The target dictionary containing
41+
the bounding boxes and labels. The keys are:
42+
- boxes (List[int]): A list of bounding boxes in shape (N, 4), [x1, y1, x2, y2] format.
43+
- scores (List[float]): A list of scores for each bounding box. shape (N)
44+
- categorys (List[str]): A list of categorys for each object
45+
- masks (List[PIL.Image]): A list of masks in the format of PIL.Image
46+
draw_score (bool): Draw score on the image. Defaults to False.
47+
48+
Returns:
49+
PIL.Image: The input image with plotted bounding boxes, labels, and masks.
50+
"""
51+
# Get the bounding boxes and labels from the target dictionary
52+
boxes = result["boxes"]
53+
scores = result["scores"]
54+
categorys = result["categorys"]
55+
masks = result.get("masks", [])
56+
57+
# Find all unique categories and build a cate2color dictionary
58+
cate2color = {}
59+
unique_categorys = set(categorys)
60+
for cate in unique_categorys:
61+
cate2color[cate] = tuple(np.random.randint(0, 255, size=3).tolist())
62+
63+
# Create a PIL ImageDraw object to draw on the input image
64+
if isinstance(image_pil, np.ndarray):
65+
image_pil = Image.fromarray(image_pil)
66+
draw = ImageDraw.Draw(image_pil)
67+
68+
# Create a new binary mask image with the same size as the input image
69+
mask = Image.new("L", image_pil.size, 0)
70+
# Create a PIL ImageDraw object to draw on the mask image
71+
mask_draw = ImageDraw.Draw(mask)
72+
73+
# Draw boxes, labels, and masks for each box and label in the target dictionary
74+
for box, score, category in zip(boxes, scores, categorys):
75+
# Extract the box coordinates
76+
x0, y0, x1, y1 = box
77+
x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
78+
color = cate2color[category]
79+
80+
# Draw the box outline on the input image
81+
draw.rectangle([x0, y0, x1, y1], outline=color, width=int(draw_width))
82+
83+
# Draw the label and score on the input image
84+
if draw_score:
85+
text = f"{category} {score:.2f}"
86+
else:
87+
text = f"{category}"
88+
89+
font = ImageFont.load_default()
90+
if hasattr(font, "getbbox"):
91+
bbox = draw.textbbox((x0, y0), text, font)
92+
else:
93+
w, h = draw.textsize(text, font)
94+
bbox = (x0, y0, w + x0, y0 + h)
95+
draw.rectangle(bbox, fill=color)
96+
draw.text((x0, y0), text, fill="white")
97+
98+
# Draw the mask on the input image if masks are provided
99+
if len(masks) > 0 and return_mask:
100+
size = image_pil.size
101+
mask_image = Image.new("RGBA", size, color=(0, 0, 0, 0))
102+
mask_draw = ImageDraw.Draw(mask_image)
103+
for mask in masks:
104+
mask = np.array(mask)[:, :, -1]
105+
draw_mask(mask, mask_draw)
106+
107+
image_pil = Image.alpha_composite(image_pil.convert("RGBA"), mask_image).convert("RGB")
108+
return image_pil

0 commit comments

Comments
 (0)