# Object categories used as the text prompt for the grounding model below.
tag_list = [
    "bush",
    "shrub",
    "tree",
    "plant",
]
# Single caption string with the categories separated by ". ",
# i.e. "bush. shrub. tree. plant".
tags = ". ".join(tag_list)
def get_grounding_output(
    model: torch.nn.Module,
    image: torch.Tensor,
    caption: str,
    box_threshold: float,
    text_threshold: float,
    device: str = "cpu",
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    """
    Run a grounding model on one image/caption pair and keep confident detections.

    Parameters:
    - model (torch.nn.Module): Grounding model. It is called as
      ``model(batched_image, captions=[caption])``, must return a mapping
      with ``"pred_logits"`` and ``"pred_boxes"`` entries, and must expose a
      ``tokenizer`` attribute that is applied to the caption.
    - image (torch.Tensor): A single, unbatched image tensor. A leading batch
      dimension is added before the forward pass.
    - caption (str): Text prompt. It is lower-cased, stripped, and given a
      trailing "." if it does not already end with one.
    - box_threshold (float): A query is kept only when its highest sigmoid
      score over all tokens is strictly greater than this value.
    - text_threshold (float): Per-token sigmoid score a token must exceed to
      contribute to the phrase of a kept query.
    - device (str, optional): Device the model and image are moved to before
      inference ('cpu' or 'cuda'). Defaults to 'cpu'.

    Returns:
    - tuple:
      - boxes_filt (torch.Tensor): Boxes of the kept queries, exactly as the
        model returned them (still on ``device``).
      - scores (torch.Tensor): CPU tensor holding each kept query's highest
        token score.
      - pred_phrases (list of str): One "<phrase> (<score>)" label per kept box.

    Raises:
    - RuntimeError: If prediction or phrase extraction fails. The original
      exception is chained as ``__cause__``.
    """
    # Normalize the caption: lower-case, trimmed, terminated with ".".
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption += "."

    model = model.to(device)
    image = image.to(device)

    try:
        with torch.no_grad():
            # The model takes a batch, so add a leading batch dimension.
            outputs = model(image.unsqueeze(0), captions=[caption])
            # Only one image was passed, so select batch element 0.
            logits = outputs["pred_logits"].sigmoid()[0]
            boxes = outputs["pred_boxes"][0]

            # Keep queries whose best token score beats box_threshold.
            max_logits = logits.max(dim=1)[0]
            filt_mask = max_logits > box_threshold
            logits_filt = logits[filt_mask]
            boxes_filt = boxes[filt_mask]

            tokenizer = model.tokenizer
            tokenized = tokenizer(caption)
            pred_phrases, scores = [], []
            for logit in logits_filt:
                pred_phrase = groundingdino.util.utils.get_phrases_from_posmap(
                    logit > text_threshold, tokenized, tokenizer
                )
                score = logit.max().item()
                pred_phrases.append(f"{pred_phrase} ({score:.4f})")
                scores.append(score)
            return boxes_filt, torch.tensor(scores), pred_phrases
    except Exception as e:
        # Chain the original error so its type and traceback are not lost.
        raise RuntimeError(f"An error occurred during model prediction: {e}") from e
# Find bounding boxes with GroundingDINO (box threshold 0.35, text threshold 0.25).
boxes_filt, scores, pred_phrases = get_grounding_output(
    dino_model,
    image,
    tags,
    0.35,
    0.25,
    device=cfg.DEVICE,
)
# NMS and plotting below operate on CPU tensors.
boxes_filt = boxes_filt.cpu()

# Convert boxes from normalized (cx, cy, w, h) to pixel (x0, y0, x1, y1):
# scale by the image size, move the centre back by half the size to get the
# top-left corner, then add that corner to the size to get the opposite corner.
size = image_pil.size
W, H = size[0], size[1]
boxes_filt = boxes_filt * boxes_filt.new_tensor([W, H, W, H])
boxes_filt[:, :2] -= boxes_filt[:, 2:] / 2
boxes_filt[:, 2:] += boxes_filt[:, :2]

if cfg.DO_IOU_MERGE:
    # Use NMS (IoU threshold 0.5) to drop overlapping boxes; computed only
    # when its result is actually used.
    nms_idx = torchvision.ops.nms(boxes_filt, scores, 0.5).numpy().tolist()
    boxes_filt_clean = boxes_filt[nms_idx]
    pred_phrases_clean = [pred_phrases[idx] for idx in nms_idx]
    print(f"NMS: before {boxes_filt.shape[0]} boxes, after {boxes_filt_clean.shape[0]} boxes")
else:
    boxes_filt_clean = boxes_filt
    pred_phrases_clean = pred_phrases
def show_box(box: Iterable[float], ax: matplotlib.axes.Axes, label: str) -> None:
    """
    Draw one labelled bounding box onto a matplotlib Axes.

    Parameters:
    - box: Corner coordinates (x0, y0, x1, y1). Although annotated as
      ``Iterable``, it must support indexing, because elements 0-3 are read
      by position.
    - ax (matplotlib.axes.Axes): Axes that receive the rectangle and the label.
    - label (str): Text drawn at (x0, y0) on a semi-transparent black
      background, with its top edge aligned to that point.
    """
    x_min, y_min = box[0], box[1]
    width = box[2] - x_min
    height = box[3] - y_min
    ax.add_patch(
        plt.Rectangle(
            (x_min, y_min), width, height, edgecolor="green", facecolor="none", lw=2
        )
    )
    label_background = {"facecolor": "black", "alpha": 0.5}
    ax.text(
        x_min,
        y_min,
        label,
        verticalalignment="top",
        color="white",
        fontsize=8,
        bbox=label_background,
    )
# Two panels side by side: the plain input image ("Origineel" is Dutch for
# "Original") and the same image with the detected boxes and labels drawn on it.
fig, axs = plt.subplots(1, 2, figsize=(10, 5), dpi=100, squeeze=False)
for col, title in enumerate(("Origineel", f"GroundingDino tags: {tag_list}")):
    ax = axs[0, col]
    ax.imshow(image_np)
    if col == 1:
        # Only the right-hand panel gets the detection overlays.
        for box, label in zip(boxes_filt_clean, pred_phrases_clean):
            show_box(box.numpy(), ax, label)
    ax.set_title(title, wrap=True)
    ax.axis("off")
fig.tight_layout();