1
+ import tempfile
2
+ from typing import Dict , List , Union
3
+ import numpy as np
4
+ from dds_cloudapi_sdk import (
5
+ DetectionTask ,
6
+ Client ,
7
+ Config ,
8
+ TextPrompt ,
9
+ DetectionModel ,
10
+ DetectionTarget ,
11
+ )
12
+ from PIL import Image
13
+ import concurrent .futures
14
+
15
class GroundingDINOAPIWrapper:
    """API wrapper for Grounding DINO 1.5.

    Args:
        token (str): The token for the Grounding DINO 1.5 API. Free API access
            to Grounding DINO 1.5 is currently open; educators, students, and
            researchers are offered an API with extensive usage times to
            support educational and research endeavors. A free API token can
            be requested here: https://deepdataspace.com/request_api
    """

    def __init__(self, token: str):
        self.client = Client(Config(token=token))

    def inference(self, prompt: Dict, return_mask: bool = False):
        """Run Grounding DINO 1.5 detection on a single image.

        Batch inference is not supported; each call handles one image.

        Args:
            prompt (dict): Annotation with the following keys:
                - "image" (Union[str, np.ndarray]): Path to the image, e.g.
                  "test1.jpg", or the image itself as a numpy array (it is
                  forwarded unchanged to ``get_image_url``).
                - "prompt" (str): Text prompt with categories separated by
                  '.', e.g. 'cate1 . cate2 . cate3'.
            return_mask (bool): Whether to also request masks. Defaults to
                False.

        Returns:
            (Dict): Detection results exactly as built by ``postprocess``,
            with keys "boxes", "categorys", "scores" and "masks".
        """
        image_url = self.get_image_url(prompt["image"])

        # Masks are only requested on demand; boxes are always requested.
        if return_mask:
            targets = [DetectionTarget.Mask, DetectionTarget.BBox]
        else:
            targets = [DetectionTarget.BBox]

        task = DetectionTask(
            image_url=image_url,
            prompts=[TextPrompt(text=prompt["prompt"])],
            targets=targets,
            model=DetectionModel.GDino1_5_Pro,
        )
        self.client.run_task(task)
        return self.postprocess(task.result, task, return_mask)

    def postprocess(self, result, task, return_mask):
        """Convert the API task result into plain Python lists.

        Args:
            result (TaskResult): Task result exposing ``objects``, a list of
                detection objects. Each object provides:
                - bbox (List[float]): Box in xyxy format
                - category (str): Detection category
                - score (float): Detection score
                - mask (DetectionObjectMask): RLE mask, decoded here through
                  ``task.rle2rgba``
            task (DetectionTask): The task object; only ``rle2rgba`` is used.
            return_mask (bool): Whether to decode and return masks.

        Returns:
            (Dict): Return dict in format:
                {
                    "boxes": a list with the bbox of each object
                    "categorys": (List[str]) a list with the category of each object
                    "scores": (List[float]) a list with the score of each object
                    "masks": a list with the decoded mask of each object as
                        returned by ``task.rle2rgba`` (presumably PIL images --
                        confirm against the SDK); empty when ``return_mask``
                        is False
                }
            All lists follow the order of ``result.objects``.
        """
        detections = list(result.objects)

        boxes = [detection.bbox for detection in detections]
        scores = [detection.score for detection in detections]
        categorys = [detection.category for detection in detections]

        masks = []
        if return_mask and detections:
            # Decode masks concurrently. Executor.map yields results in input
            # order, so the output stays aligned with result.objects no matter
            # which worker finishes first.
            with concurrent.futures.ThreadPoolExecutor() as executor:
                decoded = executor.map(
                    lambda detection: task.rle2rgba(detection.mask), detections
                )
                masks = [mask for mask in decoded if mask is not None]

        return dict(boxes=boxes, categorys=categorys, scores=scores, masks=masks)

    def get_image_url(self, image: Union[str, np.ndarray]):
        """Upload an image to the server and return its url.

        Args:
            image (Union[str, np.ndarray]): The image to upload. Can be a file
                path or an np.ndarray. An np.ndarray is first written to a
                temporary PNG file, which is removed after the upload.

        Returns:
            str: The url of the image
        """
        if isinstance(image, str):
            return self.client.upload_file(image)

        import os

        # Save into a temporary directory instead of an open
        # NamedTemporaryFile: the PNG is fully written and closed before it is
        # uploaded by path, and an open NamedTemporaryFile cannot be reopened
        # by name on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_file_path = os.path.join(tmp_dir, "image.png")
            Image.fromarray(image).save(tmp_file_path, format="PNG")
            return self.client.upload_file(tmp_file_path)
0 commit comments