From b544ae18415eb112fbb39346ca8cbb1f7e266571 Mon Sep 17 00:00:00 2001 From: Eric Eaglstun Date: Sat, 6 Jun 2026 15:43:58 -0600 Subject: [PATCH] fix(birefnet): load with trust_remote_code and disable fp16 on MPS MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BiRefNet ships a custom architecture (birefnet.py) in its HuggingFace repo, so AutoModelForImageSegmentation.from_pretrained cannot load the weights without executing that code. With trust_remote_code=False it raises "contains custom code which must be executed", making BiRefNet alpha-hint generation completely non-functional. Load with trust_remote_code=True — the officially documented way to load ZhengPeng7/BiRefNet; the code is fetched locally by snapshot_download just above. Also gate half precision to CUDA only. fp16 is unstable on Apple's MPS backend (BiRefNet's swin attention emits NaNs), and the model weights and the input tensor in process() were keyed off the same module-level flag, so they could silently disagree on dtype. The decision is now stored on the instance (self.use_half) and used in both places. Adds tests/test_birefnet.py (model load mocked — no network/GPU/weights): asserts trust_remote_code=True and that half() is applied only on CUDA. Both fail on the prior code and pass with the fix. Closes #230 Co-Authored-By: Claude Opus 4.8 (1M context) --- BiRefNetModule/wrapper.py | 14 ++++++++--- tests/test_birefnet.py | 51 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 3 deletions(-) create mode 100644 tests/test_birefnet.py diff --git a/BiRefNetModule/wrapper.py b/BiRefNetModule/wrapper.py index 8acc75e7d..3949de42a 100644 --- a/BiRefNetModule/wrapper.py +++ b/BiRefNetModule/wrapper.py @@ -80,11 +80,19 @@ def __init__(self, device="cpu", usage="General"): local_dir_use_symlinks=False, # Ensures actual files are downloaded, not just symlinks to the cache ) - self.birefnet = AutoModelForImageSegmentation.from_pretrained(model_local_dir, trust_remote_code=False) + # BiRefNet ships a custom architecture (birefnet.py) in its HF repo, so the + # weights cannot be loaded without executing that code. trust_remote_code=True + # is required and is the officially documented way to load ZhengPeng7/BiRefNet. + # The code is fetched to BiRefNetModule/checkpoints/ by snapshot_download above. + self.birefnet = AutoModelForImageSegmentation.from_pretrained(model_local_dir, trust_remote_code=True) self.birefnet.to(device) self.birefnet.eval() - if half_precision: + # fp16 is unstable on Apple's MPS backend (BiRefNet's swin attention can emit + # NaNs); only use half precision on CUDA. Stored on the instance so the model + # weights and the input tensor (see process()) always agree on dtype. + self.use_half = half_precision and str(device).startswith("cuda") + if self.use_half: self.birefnet.half() def cleanup(self): @@ -169,7 +177,7 @@ def get_frames(): self.resolution = resolution_div_by_32 image_preprocessor = ImagePreprocessor(resolution=tuple(self.resolution)) image_proc = image_preprocessor.proc(pil_image).unsqueeze(0).to(self.device) - if half_precision: + if self.use_half: image_proc = image_proc.half() # Inference diff --git a/tests/test_birefnet.py b/tests/test_birefnet.py new file mode 100644 index 000000000..c7cdc9717 --- /dev/null +++ b/tests/test_birefnet.py @@ -0,0 +1,51 @@ +"""Unit tests for BiRefNetModule.wrapper — no network, GPU, or weights required. + +The model load (snapshot_download + AutoModelForImageSegmentation.from_pretrained) +is mocked, so these exercise only the handler's load contract and dtype logic. +""" + +from unittest import mock + +import pytest + +from BiRefNetModule.wrapper import BiRefNetHandler + + +def _make_handler(device): + """Build a BiRefNetHandler with the heavy externals mocked out. + + Returns (handler, model_mock, from_pretrained_mock). + """ + with ( + mock.patch("BiRefNetModule.wrapper.snapshot_download"), + mock.patch("BiRefNetModule.wrapper.AutoModelForImageSegmentation") as mock_cls, + ): + model = mock.MagicMock() + mock_cls.from_pretrained.return_value = model + handler = BiRefNetHandler(device=device, usage="General") + return handler, model, mock_cls.from_pretrained + + +def test_loads_with_trust_remote_code(): + """BiRefNet ships a custom architecture in its HF repo, so the weights cannot + load without executing that code. Regression for issue #230: with + trust_remote_code=False, from_pretrained raises "contains custom code which + must be executed".""" + _, _, from_pretrained = _make_handler("cpu") + from_pretrained.assert_called_once() + _, kwargs = from_pretrained.call_args + assert kwargs.get("trust_remote_code") is True + + +@pytest.mark.parametrize( + "device, expect_half", + [("cpu", False), ("mps", False), ("cuda", True), ("cuda:0", True)], +) +def test_half_precision_only_on_cuda(device, expect_half): + """fp16 is unstable on Apple's MPS backend (BiRefNet's swin attention emits + NaNs), so half precision must be applied only on CUDA. The flag is stored on + the instance so the model weights and the input tensor (in process()) always + agree on dtype.""" + handler, model, _ = _make_handler(device) + assert handler.use_half is expect_half + assert model.half.called is expect_half