4609: Add AUC-Margin Loss for AUROC optimization #8719
Status: Open
shubham-61969 wants to merge 5 commits into Project-MONAI:dev from shubham-61969:4609-aucm-loss (base: dev)
Changes from all commits (5 commits):
- 3ee0c07: Add AUC-Margin loss for AUROC optimization (#4609) (shubham-61969)
- f1d38f4: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- c550c29: Correct masked mean computation in AUCMLoss and update docstrings (shubham-61969)
- 2a56f54: Validate binary targets, clarify reduction, and fix AUCM typing (shubham-61969)
- 448c5df: Validate imratio and input shape, added test cases for it and fix non… (shubham-61969)
New file (+150 lines) implementing `AUCMLoss`:

```python
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss

from monai.utils import LossReduction


class AUCMLoss(_Loss):
    """
    AUC-Margin loss with a squared-hinge surrogate for optimizing AUROC.

    The loss optimizes the Area Under the ROC Curve (AUROC) by using margin-based constraints
    on positive and negative predictions. It supports two versions: 'v1' includes class prior
    information, while 'v2' removes this dependency for better generalization.

    Reference:
        Yuan, Zhuoning, Yan, Yan, Sonka, Milan, and Yang, Tianbao.
        "Large-scale Robust Deep AUC Maximization: A New Surrogate Loss and Empirical Studies
        on Medical Image Classification."
        Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021.
        https://arxiv.org/abs/2012.03173

    Implementation based on: https://github.com/Optimization-AI/LibAUC/blob/1.4.0/libauc/losses/auc.py

    Example:
        >>> import torch
        >>> from monai.losses import AUCMLoss
        >>> loss_fn = AUCMLoss()
        >>> input = torch.randn(32, 1, requires_grad=True)
        >>> target = torch.randint(0, 2, (32, 1)).float()
        >>> loss = loss_fn(input, target)
    """

    def __init__(
        self,
        margin: float = 1.0,
        imratio: float | None = None,
        version: str = "v1",
        reduction: LossReduction | str = LossReduction.MEAN,
    ) -> None:
        """
        Args:
            margin: margin for the squared-hinge surrogate loss (default: ``1.0``).
            imratio: ratio of the number of positive samples to the total number of samples
                in the training dataset. If not given, it is estimated from each mini-batch.
                Ignored when ``version`` is ``'v2'``.
            version: whether to include prior class information in the objective function
                (default: ``'v1'``). 'v1' includes the class prior; 'v2' removes this dependency.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
                Note: this loss is computed at the batch level and always returns a scalar;
                the parameter is accepted for API consistency but has no effect.

        Raises:
            ValueError: When ``version`` is not one of ``["v1", "v2"]``.
            ValueError: When ``imratio`` is not in ``[0, 1]``.

        Example:
            >>> import torch
            >>> from monai.losses import AUCMLoss
            >>> loss_fn = AUCMLoss(version='v2')
            >>> input = torch.randn(32, 1, requires_grad=True)
            >>> target = torch.randint(0, 2, (32, 1)).float()
            >>> loss = loss_fn(input, target)
        """
        super().__init__(reduction=LossReduction(reduction).value)
        if version not in ["v1", "v2"]:
            raise ValueError(f"version should be 'v1' or 'v2', got {version}")
        if imratio is not None and not (0.0 <= imratio <= 1.0):
            raise ValueError(f"imratio must be in [0, 1], got {imratio}")
        self.margin = margin
        self.imratio = imratio
        self.version = version
        self.a = nn.Parameter(torch.tensor(0.0))
        self.b = nn.Parameter(torch.tensor(0.0))
        self.alpha = nn.Parameter(torch.tensor(0.0))

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be B1HW[D], where the channel dimension is 1 for binary classification.
            target: the shape should be B1HW[D], with values 0 or 1.

        Returns:
            torch.Tensor: scalar AUCM loss.

        Raises:
            ValueError: When input or target have fewer than 2 dimensions.
            ValueError: When input or target have incorrect shapes.
            ValueError: When target contains non-binary values.
        """
        if input.ndim < 2 or target.ndim < 2:
            raise ValueError("Input and target must have at least 2 dimensions (B, C, ...)")
        if input.shape[1] != 1:
            raise ValueError(f"Input should have 1 channel for binary classification, got {input.shape[1]}")
        if target.shape[1] != 1:
            raise ValueError(f"Target should have 1 channel, got {target.shape[1]}")
        if input.shape != target.shape:
            raise ValueError(f"Input and target shapes do not match: {input.shape} vs {target.shape}")

        input = input.flatten()
        target = target.flatten()

        if not torch.all((target == 0) | (target == 1)):
            raise ValueError("Target must contain only binary values (0 or 1)")

        pos_mask = (target == 1).float()
        neg_mask = (target == 0).float()

        if self.version == "v1":
            p = float(self.imratio) if self.imratio is not None else float(pos_mask.mean().item())
            loss = (
                (1 - p) * self._safe_mean((input - self.a) ** 2, pos_mask)
                + p * self._safe_mean((input - self.b) ** 2, neg_mask)
                + 2
                * self.alpha
                * (
                    p * (1 - p) * self.margin
                    + self._safe_mean(p * input * neg_mask - (1 - p) * input * pos_mask, pos_mask + neg_mask)
                )
                - p * (1 - p) * self.alpha**2
            )
        else:
            loss = (
                self._safe_mean((input - self.a) ** 2, pos_mask)
                + self._safe_mean((input - self.b) ** 2, neg_mask)
                + 2 * self.alpha * (self.margin + self._safe_mean(input, neg_mask) - self._safe_mean(input, pos_mask))
                - self.alpha**2
            )

        return loss

    def _safe_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Compute the mean over masked elements, returning 0 when the mask selects nothing."""
        denom = mask.sum()
        if denom == 0:
            return torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype)
        return (tensor * mask).sum() / denom
```
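The two objectives can be checked by hand. The following is a dependency-free numeric sketch of the 'v1' and 'v2' formulas (plain Python lists standing in for tensors; this is an illustration, not the MONAI module itself), with `a`, `b`, and `alpha` fixed at their initial value of 0 unless passed explicitly:

```python
def safe_mean(values, mask):
    # Masked mean that returns 0.0 for an empty mask instead of dividing by zero,
    # mirroring the _safe_mean helper above.
    denom = sum(mask)
    return sum(v * m for v, m in zip(values, mask)) / denom if denom else 0.0

def aucm_v1(x, y, a=0.0, b=0.0, alpha=0.0, margin=1.0, imratio=None):
    # 'v1' objective: weights the squared terms by the class prior p.
    pos = [1.0 if t == 1 else 0.0 for t in y]
    neg = [1.0 - m for m in pos]
    p = imratio if imratio is not None else sum(pos) / len(pos)
    cross = [p * xi * ni - (1 - p) * xi * mi for xi, mi, ni in zip(x, pos, neg)]
    return ((1 - p) * safe_mean([(xi - a) ** 2 for xi in x], pos)
            + p * safe_mean([(xi - b) ** 2 for xi in x], neg)
            + 2 * alpha * (p * (1 - p) * margin + safe_mean(cross, [1.0] * len(x)))
            - p * (1 - p) * alpha ** 2)

def aucm_v2(x, y, a=0.0, b=0.0, alpha=0.0, margin=1.0):
    # 'v2' objective: same structure without the class-prior weighting.
    pos = [1.0 if t == 1 else 0.0 for t in y]
    neg = [1.0 - m for m in pos]
    return (safe_mean([(xi - a) ** 2 for xi in x], pos)
            + safe_mean([(xi - b) ** 2 for xi in x], neg)
            + 2 * alpha * (margin + safe_mean(x, neg) - safe_mean(x, pos))
            - alpha ** 2)

scores, labels = [0.8, 0.2, 0.6, 0.4], [1, 0, 1, 0]
print(round(aucm_v2(scores, labels), 6))             # 0.6
print(round(aucm_v2(scores, labels, alpha=0.5), 6))  # 0.95
print(round(aucm_v1(scores, labels), 6))             # 0.3
```

Note that the module registers `a`, `b`, and `alpha` as `nn.Parameter`s, so they only update during training if `loss_fn.parameters()` is passed to an optimizer alongside the model's parameters.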
New test file (+111 lines) with unit tests for `AUCMLoss`:

```python
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch

from monai.losses import AUCMLoss
from tests.test_utils import test_script_save


class TestAUCMLoss(unittest.TestCase):
    """Test cases for AUCMLoss."""

    def test_v1(self):
        """Test AUCMLoss with version 'v1'."""
        loss_fn = AUCMLoss(version="v1")
        input = torch.randn(32, 1, requires_grad=True)
        target = torch.randint(0, 2, (32, 1)).float()
        loss = loss_fn(input, target)
        self.assertIsInstance(loss, torch.Tensor)
        self.assertEqual(loss.ndim, 0)

    def test_v2(self):
        """Test AUCMLoss with version 'v2'."""
        loss_fn = AUCMLoss(version="v2")
        input = torch.randn(32, 1, requires_grad=True)
        target = torch.randint(0, 2, (32, 1)).float()
        loss = loss_fn(input, target)
        self.assertIsInstance(loss, torch.Tensor)
        self.assertEqual(loss.ndim, 0)

    def test_invalid_version(self):
        """Test that an invalid version raises ValueError."""
        with self.assertRaises(ValueError):
            AUCMLoss(version="invalid")

    def test_invalid_imratio(self):
        """Test that an out-of-range imratio raises ValueError."""
        with self.assertRaises(ValueError):
            AUCMLoss(imratio=1.5)
        with self.assertRaises(ValueError):
            AUCMLoss(imratio=-0.1)

    def test_invalid_input_shape(self):
        """Test that an invalid input shape raises ValueError."""
        loss_fn = AUCMLoss()
        input = torch.randn(32, 2)  # wrong channel count
        target = torch.randint(0, 2, (32, 1)).float()
        with self.assertRaises(ValueError):
            loss_fn(input, target)

    def test_invalid_target_shape(self):
        """Test that an invalid target shape raises ValueError."""
        loss_fn = AUCMLoss()
        input = torch.randn(32, 1)
        target = torch.randint(0, 2, (32, 2)).float()  # wrong channel count
        with self.assertRaises(ValueError):
            loss_fn(input, target)

    def test_insufficient_dimensions(self):
        """Test that tensors with fewer than 2 dimensions raise ValueError."""
        loss_fn = AUCMLoss()
        input = torch.randn(32)  # 1D tensor
        target = torch.randint(0, 2, (32, 1)).float()
        with self.assertRaises(ValueError):
            loss_fn(input, target)

    def test_shape_mismatch(self):
        """Test that mismatched shapes raise ValueError."""
        loss_fn = AUCMLoss()
        input = torch.randn(32, 1)
        target = torch.randint(0, 2, (16, 1)).float()
        with self.assertRaises(ValueError):
            loss_fn(input, target)

    def test_non_binary_target(self):
        """Test that non-binary target values raise ValueError."""
        loss_fn = AUCMLoss()
        input = torch.randn(32, 1)
        target = torch.tensor([[0.5], [1.0], [2.0], [0.0]] * 8)  # 32x1, still non-binary
        with self.assertRaises(ValueError):
            loss_fn(input, target)

    def test_backward(self):
        """Test that gradients can be computed."""
        loss_fn = AUCMLoss()
        input = torch.randn(32, 1, requires_grad=True)
        target = torch.randint(0, 2, (32, 1)).float()
        loss = loss_fn(input, target)
        loss.backward()
        self.assertIsNotNone(input.grad)

    def test_script_save(self):
        """Test that the loss can be saved as TorchScript."""
        loss_fn = AUCMLoss()
        test_script_save(loss_fn, torch.randn(32, 1), torch.randint(0, 2, (32, 1)).float())


if __name__ == "__main__":
    unittest.main()
```
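One edge case worth noting alongside these tests is a mini-batch containing only one class, which commit c550c29's masked-mean fix addresses. The guard can be sketched in plain Python (this re-implements the idea behind `_safe_mean` for illustration; it is not the module's code):

```python
def safe_mean(values, mask):
    # Returns 0.0 when the mask selects nothing, instead of dividing by zero.
    denom = sum(mask)
    return sum(v * m for v, m in zip(values, mask)) / denom if denom else 0.0

# An all-positive mini-batch: the negative mask selects no elements,
# so the negative-class terms of the loss contribute 0 rather than NaN.
scores = [0.9, 0.7, 0.8]
neg_mask = [0.0, 0.0, 0.0]
pos_mask = [1.0, 1.0, 1.0]
print(safe_mean(scores, neg_mask))            # 0.0, not NaN
print(round(safe_mean(scores, pos_mask), 6))  # 0.8
```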