diff --git a/BiRefNetModule/wrapper.py b/BiRefNetModule/wrapper.py index 8acc75e7d..3949de42a 100644 --- a/BiRefNetModule/wrapper.py +++ b/BiRefNetModule/wrapper.py @@ -80,11 +80,19 @@ def __init__(self, device="cpu", usage="General"): local_dir_use_symlinks=False, # Ensures actual files are downloaded, not just symlinks to the cache ) - self.birefnet = AutoModelForImageSegmentation.from_pretrained(model_local_dir, trust_remote_code=False) + # BiRefNet ships a custom architecture (birefnet.py) in its HF repo, so the + # weights cannot be loaded without executing that code. trust_remote_code=True + # is required and is the officially documented way to load ZhengPeng7/BiRefNet. + # The code is fetched to BiRefNetModule/checkpoints/ by snapshot_download above. + self.birefnet = AutoModelForImageSegmentation.from_pretrained(model_local_dir, trust_remote_code=True) self.birefnet.to(device) self.birefnet.eval() - if half_precision: + # fp16 is unstable on Apple's MPS backend (BiRefNet's swin attention can emit + # NaNs); only use half precision on CUDA. Stored on the instance so the model + # weights and the input tensor (see process()) always agree on dtype. + self.use_half = half_precision and str(device).startswith("cuda") + if self.use_half: self.birefnet.half() def cleanup(self): @@ -169,7 +177,7 @@ def get_frames(): self.resolution = resolution_div_by_32 image_preprocessor = ImagePreprocessor(resolution=tuple(self.resolution)) image_proc = image_preprocessor.proc(pil_image).unsqueeze(0).to(self.device) - if half_precision: + if self.use_half: image_proc = image_proc.half() # Inference diff --git a/tests/test_birefnet.py b/tests/test_birefnet.py new file mode 100644 index 000000000..c7cdc9717 --- /dev/null +++ b/tests/test_birefnet.py @@ -0,0 +1,51 @@ +"""Unit tests for BiRefNetModule.wrapper — no network, GPU, or weights required. 
+
+The model load (snapshot_download + AutoModelForImageSegmentation.from_pretrained)
+is mocked, so these exercise only the handler's load contract and dtype logic.
+"""
+
+from unittest import mock
+
+import pytest
+
+from BiRefNetModule.wrapper import BiRefNetHandler
+
+
+def _make_handler(device):
+    """Build a BiRefNetHandler for `device` with the heavy externals mocked out.
+
+    Returns (handler, model_mock, from_pretrained_mock); the patches are undone on return.
+    """
+    with (
+        mock.patch("BiRefNetModule.wrapper.snapshot_download"),
+        mock.patch("BiRefNetModule.wrapper.AutoModelForImageSegmentation") as mock_cls,
+    ):
+        model = mock.MagicMock()
+        mock_cls.from_pretrained.return_value = model
+        handler = BiRefNetHandler(device=device, usage="General")
+    return handler, model, mock_cls.from_pretrained
+
+
+def test_loads_with_trust_remote_code():
+    """BiRefNet ships a custom architecture in its HF repo, so the weights cannot
+    load without executing that code. Regression for issue #230: with
+    trust_remote_code=False, from_pretrained raises "contains custom code which
+    must be executed"."""
+    _, _, from_pretrained = _make_handler("cpu")
+    from_pretrained.assert_called_once()
+    _, kwargs = from_pretrained.call_args
+    assert kwargs.get("trust_remote_code") is True
+
+
+@pytest.mark.parametrize(
+    "device, expect_half",
+    [("cpu", False), ("mps", False), ("cuda", True), ("cuda:0", True)],
+)
+def test_half_precision_only_on_cuda(device, expect_half):
+    """fp16 is unstable on Apple's MPS backend (BiRefNet's swin attention emits
+    NaNs), so half precision must be applied only on CUDA; the flag lives on the
+    instance so the weights and the process() input tensor agree on dtype.
+    NOTE(review): CUDA cases assume wrapper's half_precision is truthy - confirm."""
+    handler, model, _ = _make_handler(device)
+    assert handler.use_half is expect_half
+    assert model.half.called is expect_half