
@chaichontat
Contributor

This PR introduces a TensorRT inference runtime for the CellposeSAM model targeting NVIDIA GPUs with native BF16 support (A100, H100, RTX 30/40/50 series). TensorRT is an inference optimizer and runtime that compiles a trained neural network into a highly optimized, hardware-specific engine.

This implementation achieves significant acceleration (1.7x-2.2x; see benchmarks below) by optimizing the computational graph without altering the model's architecture or its learned weights.

All additions are contained within the contrib/ directory, and the new class is a drop-in replacement for the current CellposeModel.

Usage

First, build the optimized engine from the pretrained model.

python cellpose/contrib/cellposetrt/trt_build.py cpsam -o builds/cpsam_b4_smXX_bf16.plan

To use the TensorRT runtime, replace CellposeModel with CellposeModelTRT and pass the engine path as pretrained_model.

from cellpose.contrib.cellposetrt import CellposeModelTRT

# New TensorRT model
m_trt = CellposeModelTRT(gpu=True, pretrained_model="builds/cpsam_b4_smXX_bf16.plan")
out_trt = m_trt.eval(img)

Benchmarks

Run the benchmark script to compare the engine against the original PyTorch model:

python cellpose/contrib/cellposetrt/trt_benchmark.py \
  --image /data/registered/reg-0076.tif \
  --pretrained cpsam \
  --engine builds/cpsam_b4_smXX_bf16.plan \
  --batch-size 4 \
  --n-samples 20 \
  --save-masks /tmp/cpsam_masks.tif

- Mask Parity: IoU parity across 20 test images, with a median of 0.9989 and a minimum of 0.9876.
- Performance (256×256 tiles, timings after CUDA synchronization):
| GPU      | Batch Size | Metric        | PyTorch (ms) | TensorRT (ms) | Speedup |
|----------|------------|---------------|--------------|---------------|---------|
| RTX 5090 | 1          | Full Pipeline | 180.104      | 96.456        | 1.87x   |
| RTX 5090 | 1          | Net-Only      | 14.628       | 6.154         | 2.38x   |
| RTX 5090 | 4          | Full Pipeline | 147.818      | 104.818       | 1.41x   |
| RTX 5090 | 4          | Net-Only      | 46.812       | 27.680        | 1.69x   |
| RTX 4090 | 1          | Full Pipeline | 188.727      | 118.944       | 1.59x   |
| RTX 4090 | 1          | Net-Only      | 14.688       | 7.435         | 1.98x   |
| RTX 4090 | 4          | Full Pipeline | 168.550      | 137.602       | 1.22x   |
| RTX 4090 | 4          | Net-Only      | 57.410       | 41.426        | 1.39x   |
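
The mask-parity figures above compare predicted masks via IoU. As an illustration only (this is not the benchmark script's actual code), one way to compute foreground IoU between two masks with NumPy:

```python
import numpy as np

def mask_iou(a: np.ndarray, b: np.ndarray) -> float:
    """IoU of the foreground (non-zero) pixels of two label masks."""
    fa, fb = a > 0, b > 0
    union = np.logical_or(fa, fb).sum()
    if union == 0:
        return 1.0  # both masks empty: treat as perfect agreement
    return float(np.logical_and(fa, fb).sum() / union)

# Toy example: masks sharing 3 of 4 total foreground pixels
a = np.zeros((4, 4), dtype=np.uint8)
b = np.zeros((4, 4), dtype=np.uint8)
a[0, 0:4] = 1          # 4 foreground pixels
b[0, 1:4] = 1          # 3 pixels, all shared with a
print(mask_iou(a, b))  # → 0.75
```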

Note on Batch Size 4: The engines are built with a dynamic batch dimension to handle image sets that are not perfectly divisible by the batch size. This dynamic scheduling incurs some overhead, resulting in a smaller relative speedup compared to a fixed batch size of 1.
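
A minimal sketch of why the dynamic batch dimension is needed (illustrative only, not the runtime's actual code): a tiled image rarely produces a tile count divisible by the batch size, so the final chunk is smaller and a fixed-batch engine would reject it.

```python
def batch_sizes(n_tiles: int, batch_size: int):
    """Yield the batch size of each chunk needed to cover n_tiles tiles."""
    for start in range(0, n_tiles, batch_size):
        yield min(batch_size, n_tiles - start)

# 10 tiles at batch size 4 -> chunks of 4, 4, and 2; only an engine
# built with a dynamic batch range can run that final chunk of 2.
print(list(batch_sizes(10, 4)))  # → [4, 4, 2]
```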

Signed-off-by: Chaichontat Sriworarat <[email protected]>

codecov bot commented Oct 24, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 42.30%. Comparing base (bf958cb) to head (76ba87f).
⚠️ Report is 37 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1356      +/-   ##
==========================================
+ Coverage   42.19%   42.30%   +0.10%     
==========================================
  Files          16       16              
  Lines        3773     3773              
==========================================
+ Hits         1592     1596       +4     
+ Misses       2181     2177       -4     
