draw_keypoints() float support#8276
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/8276
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New Failures, 4 Unrelated FailuresAs of commit c9fd6ca with merge base c8c3839 ( NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but was present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
NicolasHug
left a comment
There was a problem hiding this comment.
Thanks a lot for the PR @GsnMithra !
I made a few comments below but the PR looks great overall, so I took the liberty to address these comments myself.
test/test_utils.py
Outdated
| keypoints_cp = keypoints.clone() | ||
|
|
There was a problem hiding this comment.
This doesn't seem to be used anywhere
| keypoints_cp = keypoints.clone() |
test/test_utils.py
Outdated
| torch.testing.assert_close(out[:, overlap], interpolated_overlap, rtol=0.0, atol=1.0) | ||
|
|
||
|
|
||
| def test_draw_keypoints_dtypes(): |
There was a problem hiding this comment.
Let's move that test down below, so that it is located next to the rest of the draw_keypoints tests. Right now it's in the middle of the draw_segmentation_mask tests which is a bit confusing.
test/test_utils.py
Outdated
|
|
||
|
|
||
| def test_draw_keypoints_dtypes(): | ||
| image_uint8 = torch.full((3, 100, 100), 0, dtype=torch.uint8) |
There was a problem hiding this comment.
This image should not be just zeros, otherwise it will be easy to miss subtle bugs. This should be the same as for the other test i.e.:
| image_uint8 = torch.full((3, 100, 100), 0, dtype=torch.uint8) | |
| torch.randint(0, 256, size=(3, 100, 100), dtype=torch.uint8) |
and in fact you'll see that there's a bug because the test will fail
torchvision/utils.py
Outdated
| ) | ||
|
|
||
| return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) | ||
| return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=original_dtype) |
There was a problem hiding this comment.
Just calling .to here won't scale the float images back down to [0, 1] so we would end up with a flaot image in [0, 255] (that's why the test would fail). It's best to just call to_dtype() for both the uint8 <-> float conversions.
|
Hey @NicolasHug, Thank you for pointing out my mistakes. While it might be second nature for you to find these bugs, I am still in the learning process. I apologize for making you go through my code once again. Thanks again. |
|
No problem at all @GsnMithra thank you for the PR |
Reviewed By: vmoens Differential Revision: D55062794 fbshipit-source-id: 1a9484e4959fef604153857cc7d4a6d7262cbea9 Co-authored-by: Nicolas Hug <contact@nicolas-hug.com> Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
Follow-up PR: #8150
Issue: #8138
Hey there!
I've added functionality to the draw_keypoints() method, allowing it to handle both uint8 and float32 image types.
I welcome any feedback you may have on these changes.
Thank you!