Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,23 @@ def test_draw_boxes():
assert_equal(img, img_cp)


@pytest.mark.parametrize("fill", [True, False])
def test_draw_boxes_dtypes(fill):
    """draw_bounding_boxes accepts uint8 and float images, returns a new
    tensor of the matching dtype, and both paths agree up to 1/255."""
    base_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    result_uint8 = utils.draw_bounding_boxes(base_uint8, boxes, fill=fill)

    # The input must not be drawn on in place.
    assert result_uint8 is not base_uint8
    assert result_uint8.dtype == torch.uint8

    base_float = to_dtype(base_uint8, torch.float, scale=True)
    result_float = utils.draw_bounding_boxes(base_float, boxes, fill=fill)

    assert result_float is not base_float
    assert result_float.is_floating_point()

    # Rounding through uint8 may differ by at most one level per channel.
    torch.testing.assert_close(result_uint8, to_dtype(result_float, torch.uint8, scale=True), rtol=0, atol=1)


@pytest.mark.parametrize("colors", [None, ["red", "blue", "#FF00FF", (1, 34, 122)], "red", "#FF00FF", (1, 34, 122)])
def test_draw_boxes_colors(colors):
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
Expand Down Expand Up @@ -152,7 +169,6 @@ def test_draw_boxes_grayscale():

def test_draw_invalid_boxes():
img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
Expand All @@ -162,8 +178,6 @@ def test_draw_invalid_boxes():

with pytest.raises(TypeError, match="Tensor expected"):
utils.draw_bounding_boxes(img_tp, boxes)
with pytest.raises(ValueError, match="Tensor uint8 expected"):
utils.draw_bounding_boxes(img_wrong1, boxes)
with pytest.raises(ValueError, match="Pass individual images, not batches"):
utils.draw_bounding_boxes(img_wrong2, boxes)
with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"):
Expand Down
23 changes: 15 additions & 8 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,12 @@ def draw_bounding_boxes(
) -> torch.Tensor:

"""
Draws bounding boxes on given image.
The values of the input image should be uint8 between 0 and 255.
Draws bounding boxes on given RGB image.
The image values should be uint8 in [0, 255] or float in [0, 1].
    If fill is True, the resulting Tensor should be saved as a PNG image.

Args:
image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
image (Tensor): Tensor of shape (C, H, W) and dtype uint8 or float.
boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
`0 <= ymin < ymax < H`.
Expand All @@ -188,13 +188,14 @@ def draw_bounding_boxes(
Returns:
        img (Tensor[C, H, W]): Image Tensor with bounding boxes plotted, of the same dtype as the input image.
"""
import torchvision.transforms.v2.functional as F # noqa

if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_bounding_boxes)
if not isinstance(image, torch.Tensor):
raise TypeError(f"Tensor expected, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
elif not (image.dtype == torch.uint8 or image.is_floating_point()):
raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size(0) not in {1, 3}:
Expand Down Expand Up @@ -230,8 +231,11 @@ def draw_bounding_boxes(
if image.size(0) == 1:
image = torch.tile(image, (3, 1, 1))

ndarr = image.permute(1, 2, 0).cpu().numpy()
img_to_draw = Image.fromarray(ndarr)
original_dtype = image.dtype
if original_dtype.is_floating_point:
image = F.to_dtype(image, dtype=torch.uint8, scale=True)

img_to_draw = F.to_pil_image(image)
img_boxes = boxes.to(torch.int64).tolist()

if fill:
Expand All @@ -250,7 +254,10 @@ def draw_bounding_boxes(
margin = width + 1
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)

return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
out = F.pil_to_tensor(img_to_draw)
if original_dtype.is_floating_point:
out = F.to_dtype(out, dtype=original_dtype, scale=True)
return out


@torch.no_grad()
Expand Down