Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
<!-- <a href="https://icml.cc/media/PosterPDFs/ICML%202022/a8acc28734d4fe90ea24353d901ae678.png"> <img src="https://img.shields.io/badge/Poster-grey?logo=airplayvideo&logoColor=white" alt="Poster"></a> -->
</p>

This repo contains the sample code of our proposed ```Unleashing Mask (UM)``` and its variant ```Unleashing Mask Adopt Pruning (UMAP)``` to adjust the given well-trained model for OOD detection in our paper: [Unleashing Mask: Explore the Intrinsic Out-of-Distribution Detection Capability](https://https://github.com/ZFancy/Unleashing-Mask) (ICML 2023).
This repo contains the sample code of our proposed ```Unleashing Mask (UM)``` and its variant ```Unleashing Mask Adopt Pruning (UMAP)``` to adjust the given well-trained model for OOD detection in our paper: [Unleashing Mask: Explore the Intrinsic Out-of-Distribution Detection Capability](https://github.com/ZFancy/Unleashing-Mask) (ICML 2023).
<p align="center"><img src="./figures/framework_overview.jpg" width=100% height=50%></p>
<p align="center"><em>Figure.</em> Framework overview of UM.</p>

Expand Down Expand Up @@ -39,7 +39,7 @@ class CustomLoss(nn.Module):
def forward(self, outputs, labels):
loss = args.beta * self.CrossEntropyLoss(outputs, labels)
if self.criterion != "CrossEntropyLoss":
# args.UM is the estimated loss constriant
# args.UM is the estimated loss constraint
loss = (loss - args.UM).abs() + args.UM
return loss
```
Expand Down Expand Up @@ -73,7 +73,7 @@ criterion = nn.CrossEntropyLoss()
for input, target in data:
...
loss = criterion(model(input), target)
loss = (loss - args.UM).abs() + args.UM # args.UM is the estimated loss constriant
loss = (loss - args.UM).abs() + args.UM # args.UM is the estimated loss constraint
loss.backward()

...
Expand Down Expand Up @@ -130,7 +130,6 @@ python main.py --config configs/smallscale/densenet_cifar10.yaml \
--UM <estimated_loss>\
--energy
```
python main.py --config configs/smallscale/densenet_cifar10.yaml

To experiment with the Outlier Exposure (OE) OOD detection methods. Use flag ```--oe-ood-method``` to control the OE-based methods.

Expand Down