Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions BiRefNetModule/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,19 @@ def __init__(self, device="cpu", usage="General"):
local_dir_use_symlinks=False, # Ensures actual files are downloaded, not just symlinks to the cache
)

self.birefnet = AutoModelForImageSegmentation.from_pretrained(model_local_dir, trust_remote_code=False)
# BiRefNet ships a custom architecture (birefnet.py) in its HF repo, so the
# weights cannot be loaded without executing that code. trust_remote_code=True
# is required and is the officially documented way to load ZhengPeng7/BiRefNet.
# The code is fetched to BiRefNetModule/checkpoints/<repo> by snapshot_download above.
self.birefnet = AutoModelForImageSegmentation.from_pretrained(model_local_dir, trust_remote_code=True)

self.birefnet.to(device)
self.birefnet.eval()
if half_precision:
# fp16 is unstable on Apple's MPS backend (BiRefNet's swin attention can emit
# NaNs); only use half precision on CUDA. Stored on the instance so the model
# weights and the input tensor (see process()) always agree on dtype.
self.use_half = half_precision and str(device).startswith("cuda")
if self.use_half:
self.birefnet.half()

def cleanup(self):
Expand Down Expand Up @@ -169,7 +177,7 @@ def get_frames():
self.resolution = resolution_div_by_32
image_preprocessor = ImagePreprocessor(resolution=tuple(self.resolution))
image_proc = image_preprocessor.proc(pil_image).unsqueeze(0).to(self.device)
if half_precision:
if self.use_half:
image_proc = image_proc.half()

# Inference
Expand Down
51 changes: 51 additions & 0 deletions tests/test_birefnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Unit tests for BiRefNetModule.wrapper — no network, GPU, or weights required.

The model load (snapshot_download + AutoModelForImageSegmentation.from_pretrained)
is mocked, so these exercise only the handler's load contract and dtype logic.
"""

from unittest import mock

import pytest

from BiRefNetModule.wrapper import BiRefNetHandler


def _make_handler(device):
"""Build a BiRefNetHandler with the heavy externals mocked out.

Returns (handler, model_mock, from_pretrained_mock).
"""
with (
mock.patch("BiRefNetModule.wrapper.snapshot_download"),
mock.patch("BiRefNetModule.wrapper.AutoModelForImageSegmentation") as mock_cls,
):
model = mock.MagicMock()
mock_cls.from_pretrained.return_value = model
handler = BiRefNetHandler(device=device, usage="General")
return handler, model, mock_cls.from_pretrained


def test_loads_with_trust_remote_code():
"""BiRefNet ships a custom architecture in its HF repo, so the weights cannot
load without executing that code. Regression for issue #230: with
trust_remote_code=False, from_pretrained raises "contains custom code which
must be executed"."""
_, _, from_pretrained = _make_handler("cpu")
from_pretrained.assert_called_once()
_, kwargs = from_pretrained.call_args
assert kwargs.get("trust_remote_code") is True


@pytest.mark.parametrize(
"device, expect_half",
[("cpu", False), ("mps", False), ("cuda", True), ("cuda:0", True)],
)
def test_half_precision_only_on_cuda(device, expect_half):
"""fp16 is unstable on Apple's MPS backend (BiRefNet's swin attention emits
NaNs), so half precision must be applied only on CUDA. The flag is stored on
the instance so the model weights and the input tensor (in process()) always
agree on dtype."""
handler, model, _ = _make_handler(device)
assert handler.use_half is expect_half
assert model.half.called is expect_half