# Motivation

Fix the failures (failed on XPU) introduced by pytorch#169081 due to the `CUDA` hardcode.

# Additional Context

Fix pytorch#170006

Pull Request resolved: pytorch#174755
Approved by: https://github.com/Lucaskabela
Update the torch-xpu-ops commit to [intel/torch-xpu-ops@de4f69](intel/torch-xpu-ops@de4f698), which includes:

- Clean up SYCL cmake
- Support all PyTorch strides for the FlashAttention fwd/bwd kernel
- Apply the sycl::reqd_sub_group_size attribute only to device code
- Support complex dtype logaddexp for XPU

Pull Request resolved: pytorch#174657
Approved by: https://github.com/EikanWang
…ch#174310)" This reverts commit 3eb20b2. Reverted pytorch#174310 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](pytorch#174310 (comment)))
…ACppScheduling. (pytorch#160688) Following the design of pytorch#160175, this PR renames `CUDACppScheduling` to `CUTLASSScheduling` so that it can be reused for XPU.

Pull Request resolved: pytorch#160688
Approved by: https://github.com/mlazos
…`fresh_cache()` to avoid race condition (pytorch#174834) `mock.patch.dict(os.environ, ...)` calls `os.environ.copy()` internally, which iterates all env var keys and fetches values in separate steps. This is not atomic and can race with background threads (e.g. Triton async compilation) modifying the environment, causing KeyError. Replace with explicit `os.environ.get()` for saving and manual restore in finally blocks, which uses atomic C-level lookups. Pull Request resolved: pytorch#174834 Approved by: https://github.com/oulgen
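The save/restore pattern this entry describes can be sketched as follows (a minimal sketch; the helper name `with_env` is illustrative, not the actual `fresh_cache()` implementation):

```python
import os

def with_env(key, value, fn):
    """Temporarily set one env var without snapshotting os.environ.

    mock.patch.dict(os.environ, ...) copies the whole environment via
    os.environ.copy(), which iterates keys and fetches values in separate
    steps and can race with background threads mutating the environment.
    Single-key os.environ.get()/[]= operations are atomic C-level lookups.
    """
    saved = os.environ.get(key)          # atomic single-key read
    os.environ[key] = value
    try:
        return fn()
    finally:
        if saved is None:
            os.environ.pop(key, None)    # key was absent: restore absence
        else:
            os.environ[key] = saved      # restore the prior value
```

The same shape generalizes to several keys by saving each one with its own `os.environ.get()` call rather than copying the mapping.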
…3669) Missing tests! Will add after landing Nicolas' PR. Ahh I still need to clean some stuff up: gotta remove my clang-tidy removal oop. And I'm now reconsidering whether I want float4 in the supported dtypes lol. cc @pearu if you have thoughts on this second point. Pull Request resolved: pytorch#173669 Approved by: https://github.com/mikaylagawarecki
Pull Request resolved: pytorch#174733 Approved by: https://github.com/zou3519 ghstack dependencies: pytorch#173669
…174734) This fixes the bug where SymInts, MemoryFormat, ScalarType, and Layout were all entering the IntType path instead of taking their respective branches. We do explicitly point SymInt to Int for now, because we do not yet support creating a StableIValue from a SymInt, but nothing has decreased in support due to this change. If you don't believe me, check out the test case I added one PR up the stack, which verifies that with or without this change, registering certain ops with SymInt schemas works and doesn't cause recompiles.

Pull Request resolved: pytorch#174734
Approved by: https://github.com/albanD
ghstack dependencies: pytorch#173669, pytorch#174733
We already had a shim for it, so we're merely adding a C++ convenience wrapper and tests. This is why we can stick it in libtorch_agn_2_9.

I had wanted to just write a test that would take in a Tensor, return a Layout, and then check that the layout matches in Python, but I learned that our Python binding code (pytorch/torch/csrc/jit/python/pybind_utils.cpp) doesn't properly convert Layout IValues to torch.layout (because layout, dtype, and memory format are all treated like ints). So the test case I ended up writing takes in a Layout and runs the check from C++.

Further testing I've done: I wanted to confirm that my_layout would still run with no problems on 2.9 even if it needs to be built with 2.11 for the layout API. So I:

- built a wheel for libtorch_agn_2_9 with my PR's torch
- went to a different conda env
- installed 2.9.1 torch
- installed the wheel I just built
- ran the test `python test/cpp_extensions/test_libtorch_agnostic.py -v -k my_layout`

and it passed!

Pull Request resolved: pytorch#174735
Approved by: https://github.com/andrewor14
ghstack dependencies: pytorch#173669, pytorch#174733, pytorch#174734
pytorch#174338 (comment) causes an issue with vLLM benchmark for some models, so I need to revert it. Pull Request resolved: pytorch#174838 Approved by: https://github.com/atalman
Upgrades to our latest type checker, which includes performance improvements. Will follow up with a PR that removes unused suppressions to keep things tidy. Test: `lintrunner -a` Pull Request resolved: pytorch#174426 Approved by: https://github.com/Skylion007
how allow_unbacked_sharding gets passed <img width="935" height="477" alt="Screenshot 2026-02-10 at 22 03 01" src="https://github.com/user-attachments/assets/6c90c8e7-045d-4680-b3b7-5e75f0160ef7" /> Pull Request resolved: pytorch#172385 Approved by: https://github.com/wconstab
This pull request refactors how ONNX operator signatures and related schema types are handled in the exporter codebase. The main change is to consistently use types from the new `onnx_ir` module (`ir.schemas`) instead of the legacy `_schemas` module.

**Dependency updates**

* Updated the pinned version of `onnx-ir` in `.ci/docker/requirements-ci.txt` from `0.1.15` to `0.1.16` to ensure compatibility with the new IR features.

Pull Request resolved: pytorch#174740
Approved by: https://github.com/malfet
…4790) These guards do not matter practically: even if something changes, something else will recompile anyway, such as the function defaults. These guards also end up being very expensive, because for a large model they can accumulate tons of OBJECT ALIASING guards.

Pull Request resolved: pytorch#174790
Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42
Pull Request resolved: pytorch#174449 Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42, https://github.com/mlazos
…ytorch#174713) Previously, running in inference mode required the user to call torch.compiler.mark_step_begin() after each call to the compiled function, which is not the intended behavior and did not match the behavior of cudagraph trees in the inductor backend.

Pull Request resolved: pytorch#174713
Approved by: https://github.com/eellison, https://github.com/BoyuanFeng
The main call emission already handles unwrapping unspecialized args passed as 0d tensors through `prepare_triton_kernel_call`. Do the same for the autotuning path. This was previously hidden because Triton would automatically unwrap the tensor in some cases. Fixes pytorch#174420. Pull Request resolved: pytorch#174583 Approved by: https://github.com/eellison, https://github.com/jansel, https://github.com/mlazos
…c8 kernel (pytorch#174362) This will allow `sm_103` devices to call vec8 kernels.

Verification script:

```Python
import torch
from torch.profiler import profile, ProfilerActivity

device = torch.device("cuda")

for dtype in (torch.bfloat16, torch.float16):
    x = torch.randn(1024, device=device, dtype=dtype)
    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        y = torch.relu(x)
    stats = prof.key_averages()
    for entry in stats:
        if "at::native::vectorized_elementwise_kernel" in entry.key:
            print(entry.key)
```

Before:

```
void at::native::vectorized_elementwise_kernel<4, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::{lambda()#1}::operator()() const::{lambda()#9}::operator()() const::{lambda(c10::BFloat16)#1}, std::array<char*, 2ul> >(int, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::{lambda()#1}::operator()() const::{lambda()#9}::operator()() const::{lambda(c10::BFloat16)#1}, std::array<char*, 2ul>)
void at::native::vectorized_elementwise_kernel<4, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::{lambda()#1}::operator()() const::{lambda()#8}::operator()() const::{lambda(c10::Half)#1}, std::array<char*, 2ul> >(int, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::{lambda()#1}::operator()() const::{lambda()#8}::operator()() const::{lambda(c10::Half)#1}, std::array<char*, 2ul>)
```

After:

```
void at::native::vectorized_elementwise_kernel<8, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::{lambda()#1}::operator()() const::{lambda()#9}::operator()() const::{lambda(c10::BFloat16)#1}, std::array<char*, 2ul> >(int, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::{lambda()#1}::operator()() const::{lambda()#9}::operator()() const::{lambda(c10::BFloat16)#1}, std::array<char*, 2ul>)
void at::native::vectorized_elementwise_kernel<8, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::{lambda()#1}::operator()() const::{lambda()#8}::operator()() const::{lambda(c10::Half)#1}, std::array<char*, 2ul> >(int, at::native::(anonymous namespace)::launch_clamp_scalar(at::TensorIteratorBase&, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::{lambda()#1}::operator()() const::{lambda()#8}::operator()() const::{lambda(c10::Half)#1}, std::array<char*, 2ul>)
```

Pull Request resolved: pytorch#174362
Approved by: https://github.com/ngimel
…74831) Add explicit `use_strided_shard_as_shard_order` flag to DTensorSpec to control whether _StridedShard placements should be interpreted as encoding shard order. Previously this was inferred ad-hoc at each usage site by checking for _StridedShard presence in placements. The flag is auto-detected in __post_init__ and can be explicitly set to False (e.g. in view propagation) to treat _StridedShard as a regular Shard. Also extracts `_update_shard_order_and_placements` as a shared helper for updating placements and shard order after each transform step, used by both `stringify_transform_infos` and the upcoming `fill_transform_infos_unsharded_shape`. Pull Request resolved: pytorch#174831 Approved by: https://github.com/weifengpy
…ytorch#173358) Tests error out with `RuntimeError: _convert_weight_to_int4pack_cuda is only supported on AMD gpu arch greater than or equal to CDNA2`, so this adds the corresponding check to skip on unsupported AMD architectures.

Tested on Radeon Pro V710

Pull Request resolved: pytorch#173358
Approved by: https://github.com/isuruf
…on (pytorch#173349) These test cases were implemented with the PT2E API and thus removed during the PT2E migration to torchao. They are re-enabled here without the PT2E API.

Updated files:

1. test/inductor/test_aot_inductor.py
2. test/inductor/test_cpu_cpp_wrapper.py
3. test/inductor/test_cpu_repro.py
4. test/inductor/test_cpu_select_algorithm.py
5. test/inductor/test_mkldnn_pattern_matcher.py

Pull Request resolved: pytorch#173349
Approved by: https://github.com/jerryzh168
This PR is part of pytorch#160175. It extracts the CUDA-independent functionality from `CUDACodeCache` into `CUTLASSCodeCache`, which `CUDACodeCache` then inherits and extends with CUDA-specific logic. This design allows `CUTLASSCodeCache` to be reused by XPU as well. In addition, CUDA compilation logic has been moved into torch/_inductor/codegen/cuda/compile_utils.py, making codecache.py cleaner. Pull Request resolved: pytorch#160706 Approved by: https://github.com/EikanWang, https://github.com/mlazos ghstack dependencies: pytorch#160688
…h#170578) Pull Request resolved: pytorch#170578 Approved by: https://github.com/eellison
Pull Request resolved: pytorch#174564 Approved by: https://github.com/malfet
…er in c10d (pytorch#174202) Integrates the **TorchComms** backend wrapper at the c10d level. The primary goal is to enable DeviceMesh to use TorchComms as an alternative to the traditional NCCL/Gloo Process Group backends. It adds an opt-in feature, gated by a single environment variable, TORCH_DISTRIBUTED_USE_TORCHCOMMS=1, that allows process groups and device mesh to use the torchcomms backend.

_BackendWrapper: https://github.com/meta-pytorch/torchcomms/blob/main/comms/torchcomms/_comms.pyi

This is a rehash of pytorch#170132 from @fduwjj

Pull Request resolved: pytorch#174202
Approved by: https://github.com/d4l3k

Co-authored-by: fduwjj <fduwjj@gmail.com>
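The opt-in gate described above can be sketched like this (a minimal sketch; the helper name `use_torchcomms` is illustrative, only the environment variable name comes from the PR):

```python
import os

def use_torchcomms() -> bool:
    """Return True when the TorchComms backend wrapper is opted into.

    The feature is off by default; setting the documented env variable
    TORCH_DISTRIBUTED_USE_TORCHCOMMS=1 enables it for process groups
    and DeviceMesh.
    """
    return os.environ.get("TORCH_DISTRIBUTED_USE_TORCHCOMMS", "0") == "1"
```

Launching with `TORCH_DISTRIBUTED_USE_TORCHCOMMS=1 torchrun ...` would then route backend construction through the wrapper instead of the traditional NCCL/Gloo path.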
…172385)" This reverts commit a74aec1. Reverted pytorch#172385 on behalf of https://github.com/jeffdaily due to broke ROCm, signal was missed but still allowed to land ([comment](pytorch#172385 (comment)))
)

# Motivation

In the SYCL 2020 specification, [`info::device::local_mem_size`](https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#api:info-device-local-mem-size) is described as "`the size of local memory arena in bytes.`", which is somewhat ambiguous and leaves room for interpretation. From the relevant SYCL extensions - specifically [work_group_static](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_work_group_static.asciidoc#total-allocation-check) and [work_group_scratch_memory](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_work_group_scratch_memory.asciidoc#error-conditions) - it becomes clear that `local_mem_size` refers to the device's local memory capacity that can be allocated per work-group. In other words, it represents the maximum amount of local memory available to each work-group. On the [Level Zero backend](https://oneapi-src.github.io/level-zero-spec/level-zero/latest/core/api.html#_CPPv4N30ze_device_compute_properties_t20maxSharedLocalMemoryE), it returns `maxSharedLocalMemory`, the maximum shared local memory per group. This interpretation aligns with CUDA's `SharedMemPerBlock`, which likewise describes the per-block shared memory limit.

# Additional Context

Have confirmed with the SYCL team in internal ticket: CMPLRLLVM-72698

Fix intel/torch-xpu-ops#2723

Pull Request resolved: pytorch#172314
Approved by: https://github.com/EikanWang, https://github.com/albanD
Fixes pytorch#165875

### Summary

The `test_pw_kernel_benchmark` test was flaky, failing intermittently with:

```
Traceback (most recent call last):
  File "/var/lib/jenkins/pytorch/test/inductor/test_kernel_benchmark.py", line 145, in test_pw_kernel_benchmark
    self.verify_compiled_kernels()
  File "/var/lib/jenkins/pytorch/test/inductor/test_kernel_benchmark.py", line 78, in verify_compiled_kernels
    ).run(bench_out)
RuntimeError: Expected to not find "GB/s" but found it
None UNK c4fmwupa7r
0.007ms 0.000 GB 0.01GB/s 23 regs 0 spills 0 shared mem @ XBLOCK: 8, num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
0.007ms 0.000 GB 0.01GB/s 23 regs 0 spills 0 shared mem @ XBLOCK: 8, waves_per_eu: 2, num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
~~~~ <--- HERE
0.007ms 0.000 GB 0.01GB/s 23 regs 0 spills 0 shared mem @ XBLOCK: 8, waves_per_eu: 1, num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
From CHECK-NOT: GB/s

To execute this test, run the following from the base repo dir:
PYTORCH_TEST_WITH_ROCM=1 python test/inductor/test_kernel_benchmark.py TestKernelBenchmark.test_pw_kernel_benchmark

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
```

The test required exactly one "GB/s" occurrence. However, when using the `-kc` flag (which benchmarks all autotuning configurations), multiple configs are benchmarked and each prints "GB/s", causing the test to fail when more than one config is tested.

### Fix:

Changed the assertion to check for **at least** `GB_count` occurrences instead of exactly that many, by removing the `exactly=1`. This aligns with the intended behavior of the `-kc` flag and makes the test robust to variable numbers of autotuning configs.

### Testing

All of the below tests passed on MI300.

- `test_pw_kernel_benchmark` - passes consistently
- `test_matmul_triton_kernel_benchmark` - still passes
- `test_mm_triton_kernel_benchmark` - still passes

Pull Request resolved: pytorch#174765
Approved by: https://github.com/eellison
…orch#174729) Summary: Change a mistaken OrderedSet.union() to OrderedSet.update(). This caused cudagraph cache miskeying when a partition takes dynamically-shaped inputs from non-cudagraphable ops. Differential Revision: D92879629 Pull Request resolved: pytorch#174729 Approved by: https://github.com/eellison, https://github.com/BoyuanFeng, https://github.com/Skylion007
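The distinction the fix hinges on can be shown with Python's built-in `set`, whose `union`/`update` API `OrderedSet` mirrors: `union()` returns a new set and leaves the receiver unchanged, while `update()` mutates it in place.

```python
s = {1, 2}

# union() builds and RETURNS a new set; s itself is unchanged.
# Calling it for its side effect (and discarding the result) is the bug
# pattern: the caller thinks s now contains 3, but it does not.
result = s.union({3})
assert result == {1, 2, 3}
assert s == {1, 2}          # s missed the new element

# update() mutates s in place (and returns None), which is what a
# cache-key accumulator actually needs.
s.update({3})
assert s == {1, 2, 3}
```

With `union()` used by mistake, the dynamically-shaped inputs never made it into the accumulated set, so the cudagraph cache key omitted them and distinct partitions could collide.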
Differential Revision: D92907658 Pass `enable_tf32` to the autotune process pool Pull Request resolved: pytorch#174742 Approved by: https://github.com/masnesral
Pull Request resolved: pytorch#172372 Approved by: https://github.com/anijain2305 ghstack dependencies: pytorch#172152
Pull Request resolved: pytorch#172395 Approved by: https://github.com/anijain2305 ghstack dependencies: pytorch#172152, pytorch#172372
[BE] Tesor -> Tensor Pull Request resolved: pytorch#175061 Approved by: https://github.com/jcaip
Pull Request resolved: pytorch#175184 Approved by: https://github.com/liangel-02
Pull Request resolved: pytorch#175185 Approved by: https://github.com/albanD ghstack dependencies: pytorch#175184
Support multi-output ops like split, unbind, topk, sort.

Tested for these ops and things look reasonable (not an exhaustive test of all multi-output ops):

- unbind: 0 true positives because its strategy unshards the unbind dimension, so all non-trivial rules involve Replicate inputs → skipped. This is correct behavior (the validator only tests non-fully-replicated combos).
- topk: 14 true positives, 0 false positives
- sort: 102 true positives, 0 false positives
- split_with_sizes: 24 true positives, 0 false positives
- chunk: 18 true positives, 0 false positives

No unexpected issues with any of the multi-output operators. The implementation handles all of them correctly — single-output and multi-output ops with varying tuple sizes (unbind's dynamic N outputs, topk/sort's 2-element tuples, split's variable chunks).

Pull Request resolved: pytorch#174995
Approved by: https://github.com/pianpwk, https://github.com/zpcore
ghstack dependencies: pytorch#174799, pytorch#174800
Fixes pytorch#159945.

## Motivation:

This PR provides an alternate approach to fix the issue mentioned above. We had initially proposed a solution in pytorch#160063, which is not yet merged. If this PR gets merged, the previous one will be deleted.

## Changes:

This pull request makes a targeted improvement to the backend registration logic in `torch/distributed/distributed_c10d.py`. The change allows remapping devices from a "fake" backend to a real backend, while preventing the "fake" backend from claiming devices already mapped.

Backend registration logic improvement:

* Updated `register_backend` to permit devices mapped to the "fake" backend to be remapped to a real backend, while preventing the "fake" backend from claiming devices already assigned to another backend. (`torch/distributed/distributed_c10d.py`)

@H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @EikanWang @kwen2501 @ankurneog @albanD @jerome-habana: please review.

Pull Request resolved: pytorch#174764
Approved by: https://github.com/d4l3k, https://github.com/jeromean
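The registration rule described above can be modeled with a plain dict (a simplified sketch; the names and structure are illustrative, not the actual `distributed_c10d` internals):

```python
# device name -> backend name; stands in for the real registration table.
backend_for_device = {}

def register_backend(device: str, backend: str) -> None:
    existing = backend_for_device.get(device)
    if existing is None:
        # Unclaimed device: any backend may register it.
        backend_for_device[device] = backend
    elif existing == "fake":
        # A device held by the "fake" backend may be remapped to a
        # real backend.
        backend_for_device[device] = backend
    elif backend == "fake":
        # The "fake" backend may NOT claim a device already assigned
        # to another backend: keep the existing mapping.
        pass
    else:
        raise ValueError(f"{device} is already mapped to {existing}")
```

So "fake" acts as a placeholder that real backends can always displace, but never the other way around.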
Pull Request resolved: pytorch#174846 Approved by: https://github.com/laithsakka
Pull Request resolved: pytorch#174793 Approved by: https://github.com/laithsakka
pytorch#173950

To prepare for moving CUDA 13 wheels to stable wheels, we need to add periodic CUDA 13 tests.

Pull Request resolved: pytorch#174850
Approved by: https://github.com/atalman

Co-authored-by: Andrey Talman <atalman@fb.com>
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned vision hash. Pull Request resolved: pytorch#175214 Approved by: https://github.com/pytorchbot
…ytorch#174200) The test was failing because it expected cuDNN Attention backend selection with cuDNN >= 9.9.0, but the C++ implementation requires cuDNN > 9.15.0 (per issue pytorch#169849). Additionally, the test only checked for device capabilities (9,0) and (10,0), while the C++ code also supports minor version 3: (9,3) and (10,3).

This change updates the test to match the C++ check_prefer_cudnn_attention() function exactly, ensuring the test correctly validates the backend selection logic on systems with cuDNN > 9.15.0 and supported device capabilities.

Fixes pytorch#174199

CC: @morrison-turnansky @groenenboomj @stmcgovern @jewelkm89 @cleonard530

Pull Request resolved: pytorch#174200
Approved by: https://github.com/eqy
…on bug (pytorch#175168) For issue: pytorch#158212 Perf run: https://github.com/pytorch/pytorch/actions/runs/22110761079 At the microbench level this change appears within noise for flex: <img width="507" height="424" alt="image" src="https://github.com/user-attachments/assets/ba731e23-f49f-4151-a3dc-c47547672dfd" /> Pull Request resolved: pytorch#175168 Approved by: https://github.com/mlazos
Pull Request resolved: pytorch#171482 Approved by: https://github.com/zou3519
…rch#174938) FIXES pytorch#174187 The `remove_no_ops` pass (e.g. `add(x, 0) → x`) is unsound when `x` is subsequently mutated in-place (e.g. via `set_`). The original op acts as an implicit copy; removing it causes downstream users to observe the post-mutation value instead. Guard the optimization by checking the op schema's `alias_info.is_write` to detect if the replacement value is the target of any in-place mutation. Authored with Claude. Pull Request resolved: pytorch#174938 Approved by: https://github.com/eellison ghstack dependencies: pytorch#174924
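The unsoundness can be modeled without PyTorch at all (a minimal plain-Python sketch; a list stands in for a tensor and a slice-copy for the implicit copy that `add(x, 0)` materializes):

```python
x = [0, 0, 0]

# Original program: y = add(x, 0) materializes a copy, so a later
# in-place mutation of x (like set_/add_) cannot be observed through y.
y = x[:]                 # implicit copy made by the no-op add
x[0] = 99                # in-place mutation of x
assert y == [0, 0, 0]    # sound: y keeps the pre-mutation value

# After an unsound remove_no_ops pass: y is replaced by x itself,
# turning the copy into an alias.
x = [0, 0, 0]
y = x                    # "add(x, 0) -> x": no copy, just an alias
x[0] = 99
assert y == [99, 0, 0]   # downstream users now observe the mutation
```

Checking the schema's `alias_info.is_write` on the ops that consume the replacement value is what lets the pass detect this mutation and keep the copy in place.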
… — 4-43x speedup for VLM pos embed resizing (pytorch#174578)

Adds a parallel kernel variant for `upsample_bicubic2d` that parallelizes across batch and channel dimensions, targeting VLM position embedding resizing workloads where spatial grids are small but channel counts are high.

**Kimi K2.5 real-world shapes** (batch=1, channels=1152, input 64x64, float32, B200): Image sizes are converted to output grids via the Kimi K2.5 vision encoder's `navit_resize_image` logic (patch_size=14, merge_kernel_size=2).

| image size | grid (output) | output elems | before (us) | after (us) | speedup |
|------------|---------------|--------------|-------------|------------|---------|
| 224x224 | 16x16 | 256 | 907 | 27 | **34x** |
| 384x384 | 28x28 | 784 | 1,425 | 33 | **43x** |
| 512x512 | 38x38 | 1,444 | 1,267 | 48 | **27x** |
| 640x480 | 36x46 | 1,656 | 1,257 | 50 | **25x** |
| 768x768 | 56x56 | 3,136 | 1,188 | 81 | **15x** |
| 1024x768 | 56x74 | 4,144 | 1,147 | 99 | **12x** |
| 1024x1024 | 74x74 | 5,476 | 1,146 | 121 | **9.5x** |
| 1280x1024 | 74x92 | 6,808 | 1,230 | 144 | **8.6x** |
| 1080x1080 | 78x78 | 6,084 | 1,225 | 132 | **9.3x** |
| 1920x1080 | 78x138 | 10,764 | 1,171 | 217 | **5.4x** |
| 2048x2048 | 130x130 | 16,900 | 1,442 | 334 | **4.3x** |

Ran on B200. The parallel kernel is faster across **all 48 tested Kimi K2.5 grid shapes** (4.3x–48.5x), including shapes above the current 4096 threshold.
<details>
<summary><b>Full Kimi K2.5 sweep (48 unique grid shapes)</b></summary>

```
image_size    grid (out)   out_elems   threshold   orig_us   para_us   speedup   winner
------------------------------------------------------------------------------------------------------------
28x28         2x2          4           below       1129.8    27.7      40.7x     parallel
56x56         4x4          16          below       1207.3    27.5      43.9x     parallel
100x100       8x8          64          below       965.9     25.7      37.6x     parallel
224x224       16x16        256         below       906.6     26.9      33.7x     parallel
256x256       20x20        400         below       1141.8    29.2      39.2x     parallel
320x320       24x24        576         below       1021.5    30.8      33.2x     parallel
384x384       28x28        784         below       1425.0    33.3      42.7x     parallel
512x512       38x38        1444        below       1267.4    47.6      26.6x     parallel
768x768       56x56        3136        below       1188.4    80.6      14.7x     parallel
1024x1024     74x74        5476        ABOVE       1145.5    120.6     9.5x      parallel
1152x1152     84x84        7056        ABOVE       1285.0    148.0     8.7x      parallel
1080x1080     78x78        6084        ABOVE       1225.2    131.6     9.3x      parallel
1080x1350     98x78        7644        ABOVE       1242.5    158.9     7.8x      parallel
1080x1920     138x78       10764       ABOVE       1162.3    217.6     5.3x      parallel
360x800       58x26        1508        below       1835.5    49.7      37.0x     parallel
390x844       62x28        1736        below       1666.8    51.6      32.3x     parallel
414x896       64x30        1920        below       1716.1    55.1      31.2x     parallel
1366x768      56x98        5488        ABOVE       1172.5    120.6     9.7x      parallel
1536x864      62x110       6820        ABOVE       1228.4    143.6     8.6x      parallel
1920x1080     78x138       10764       ABOVE       1170.8    217.4     5.4x      parallel
2560x1440     98x172       16856       ABOVE       1438.8    333.3     4.3x      parallel
640x480       36x46        1656        below       1257.1    50.3      25.0x     parallel
800x600       44x58        2552        below       1203.2    66.3      18.2x     parallel
1024x768      56x74        4144        ABOVE       1146.9    99.4      11.5x     parallel
1280x960      70x92        6440        ABOVE       1243.6    140.0     8.9x      parallel
1600x1200     86x116       9976        ABOVE       1175.3    200.9     5.8x      parallel
2048x1536     110x148      16280       ABOVE       1426.5    321.7     4.4x      parallel
4000x3000     112x150      16800       ABOVE       1430.1    334.0     4.3x      parallel
1280x720      52x92        4784        ABOVE       1235.3    107.1     11.5x     parallel
1280x1024     74x92        6808        ABOVE       1230.4    143.9     8.6x      parallel
3000x2000     106x158      16748       ABOVE       1442.8    333.2     4.3x      parallel
240x320       24x18        432         below       1018.2    29.5      34.5x     parallel
480x640       46x36        1656        below       1255.0    49.7      25.3x     parallel
600x800       58x44        2552        below       1221.6    66.5      18.4x     parallel
768x1024      74x56        4144        ABOVE       1184.1    99.0      12.0x     parallel
720x1280      92x52        4784        ABOVE       1274.6    106.3     12.0x     parallel
960x1280      92x70        6440        ABOVE       1220.0    140.2     8.7x      parallel
1024x1280     92x74        6808        ABOVE       1165.3    144.0     8.1x      parallel
1440x2560     172x98       16856       ABOVE       1486.1    333.8     4.5x      parallel
3000x4000     150x112      16800       ABOVE       1449.3    333.7     4.3x      parallel
4000x6000     158x106      16748       ABOVE       1416.4    331.6     4.3x      parallel
1536x1536     110x110      12100       ABOVE       1222.1    241.6     5.1x      parallel
2048x2048     130x130      16900       ABOVE       1442.2    334.0     4.3x      parallel
7168x7168     128x128      16384       ABOVE       1451.5    320.8     4.5x      parallel
100x2000      144x8        1152        below       2487.3    51.3      48.5x     parallel
2000x100      8x144        1152        below       1278.8    45.1      28.4x     parallel
300x1800      130x22       2860        below       1626.4    72.9      22.3x     parallel
1800x300      22x130       2860        below       1307.9    71.4      18.3x     parallel
```

</details>

**Generic benchmark sweep (various batch/channel/spatial sizes):**

Also swept batches and different sizes:

| batch | channels | input | output | before (us) | after (us) | speedup |
|-------|----------|-------|--------|-------------|------------|---------|
| 1 | 1152 | 16x16 | 6x6 | 794 | 25 | **32x** |
| 1 | 1152 | 32x32 | 14x14 | 819 | 25 | **32x** |
| 64 | 1152 | 16x16 | 6x6 | 68,917 | 376 | **183x** |
| 64 | 1152 | 32x32 | 14x14 | 100,807 | 454 | **222x** |
| 64 | 1152 | 64x64 | 32x32 | 99,188 | 1,255 | **79x** |
| 256 | 1152 | 16x16 | 6x6 | 309,810 | 625 | **495x** |
| 256 | 1152 | 32x32 | 14x14 | 402,735 | 1,174 | **343x** |
| 256 | 768 | 16x16 | 6x6 | 205,985 | 492 | **419x** |

For output sizes > 18432 elements with low channel counts, the parallel kernel can be slower than the original. So in those cases we run the original kernel.
<details> <summary><b>Full benchmark sweep (all configurations)</b></summary> ``` channels spatial batch original_us parallel_us speedup winner -------------------------------------------------------------------------------- 1 16x16->6x6 1 20.1 19.9 1.0x parallel 3 16x16->6x6 1 20.0 19.9 1.0x parallel 8 16x16->6x6 1 23.0 19.8 1.2x parallel 16 16x16->6x6 1 28.0 19.7 1.4x parallel 32 16x16->6x6 1 39.2 24.8 1.6x parallel 64 16x16->6x6 1 63.6 24.7 2.6x parallel 128 16x16->6x6 1 106.6 24.9 4.3x parallel 256 16x16->6x6 1 189.5 19.7 9.6x parallel 512 16x16->6x6 1 360.9 20.8 17.4x parallel 768 16x16->6x6 1 536.8 22.8 23.5x parallel 1024 16x16->6x6 1 707.7 26.6 26.6x parallel 1152 16x16->6x6 1 794.5 25.3 31.4x parallel 1 16x16->6x6 4 19.1 19.1 1.0x original 3 16x16->6x6 4 25.2 19.4 1.3x parallel 8 16x16->6x6 4 39.4 19.0 2.1x parallel 16 16x16->6x6 4 60.1 19.0 3.2x parallel 32 16x16->6x6 4 103.1 19.0 5.4x parallel 64 16x16->6x6 4 189.2 19.1 9.9x parallel 128 16x16->6x6 4 360.7 20.5 17.6x parallel 256 16x16->6x6 4 708.2 23.9 29.7x parallel 512 16x16->6x6 4 1393.3 31.0 44.9x parallel 768 16x16->6x6 4 2078.4 37.5 55.5x parallel 1024 16x16->6x6 4 2763.8 43.9 63.0x parallel 1152 16x16->6x6 4 3108.0 47.6 65.4x parallel 1 16x16->6x6 64 64.2 19.2 3.3x parallel 3 16x16->6x6 64 150.2 19.1 7.9x parallel 8 16x16->6x6 64 364.9 20.3 18.0x parallel 16 16x16->6x6 64 711.0 24.4 29.1x parallel 32 16x16->6x6 64 1398.8 30.9 45.3x parallel 64 16x16->6x6 64 2768.5 43.7 63.3x parallel 128 16x16->6x6 64 5513.0 70.0 78.7x parallel 256 16x16->6x6 64 11201.7 121.2 92.4x parallel 512 16x16->6x6 64 22646.0 224.6 100.8x parallel 768 16x16->6x6 64 34187.6 327.3 104.5x parallel 1024 16x16->6x6 64 58161.6 352.5 165.0x parallel 1152 16x16->6x6 64 68917.3 376.3 183.2x parallel 1 16x16->6x6 256 205.5 19.3 10.6x parallel 3 16x16->6x6 256 553.3 22.7 24.4x parallel 8 16x16->6x6 256 1411.0 30.9 45.6x parallel 16 16x16->6x6 256 2780.1 43.6 63.7x parallel 32 16x16->6x6 256 5525.9 70.0 79.0x parallel 64 16x16->6x6 
channels  input->output       batch  original(us)  parallel(us)  speedup  faster
  1152    16x16->6x6            256      309809.8         625.3   495.5x  parallel
  1152    32x32->14x14          256      402730.1        1173.6   343.2x  parallel
   768    64x64->32x32          256      264438.5        2522.7   104.8x  parallel
   128    128x128->64x64        256       44638.6        2280.2    19.6x  parallel
    32    256x256->128x128      256       12080.5        2274.0     5.3x  parallel
     8    256x256->512x512      256        5645.3        8573.9     0.7x  original
     3    512x512->1024x1024    256        7989.1       12901.7     0.6x  original
     8    640x480->320x240      256        3365.7        2791.7     1.2x  parallel
     1    1920x1080->960x540    256        1850.6        2367.3     0.8x  original
     8    224x224->518x518      256        5783.1        8820.8     0.7x  original
(representative rows at batch 256; the full sweep over channels 1-1152,
batch 1-256, and shapes from 16x16->6x6 up to 4096x4096->2048x2048 is elided)
```

</details>

## Motivation

Vision-language models (Kimi K2.5, Llama Vision, etc.) use `F.interpolate(mode='bicubic')` to resize position embeddings at inference time. These workloads have many channels (768–1152) but small spatial grids (e.g. Kimi K2.5 interpolates from 64x64). The existing kernel loops over batch and channels sequentially within each thread, launching only `output_height * output_width` threads in total. For example, in the Kimi K2.5 case: 64x64 → 74x74 gives 5,476 output elements → 5,476 threads → ~171 warps across 148 SMs with 64 warps per SM, which is under 2% occupancy.

## Changes

Added a parallel kernel variant (`upsample_bicubic2d_out_frame_parallel`) that spreads batch×channel work across `blockIdx.z`. The launcher dispatches to this kernel when `output_height * output_width <= 18432`, otherwise it uses the original kernel unchanged.
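The occupancy arithmetic and the dispatch cutoff described above can be sketched in plain Python. This is a minimal illustration, not the real implementation (which lives in `UpSampleBicubic2d.cu`); the function names are hypothetical, and the SM/warp figures are the ones quoted in the motivation rather than queried from hardware.

```python
# Illustrative sketch only: the real dispatch logic is in
# aten/src/ATen/native/cuda/UpSampleBicubic2d.cu.
PARALLEL_THRESHOLD = 18432  # output_height * output_width cutoff from the PR
WARP_SIZE = 32

def pick_kernel(out_h: int, out_w: int) -> str:
    """Which kernel variant the launcher selects for a given output size."""
    return "parallel" if out_h * out_w <= PARALLEL_THRESHOLD else "original"

def occupancy_pct(out_h: int, out_w: int, num_sms: int = 148,
                  warps_per_sm: int = 64) -> float:
    """Occupancy of the original kernel, which launches one thread per output pixel."""
    threads = out_h * out_w
    warps = -(-threads // WARP_SIZE)  # ceiling division
    return 100.0 * warps / (num_sms * warps_per_sm)

# Kimi K2.5 position-embedding resize: 64x64 -> 74x74
print(pick_kernel(74, 74))              # parallel
print(f"{occupancy_pct(74, 74):.1f}%")  # 1.8%
```

This makes the heuristic's intent visible: small outputs leave almost all SMs idle under the one-thread-per-pixel scheme, so batch×channel parallelism is spread across `blockIdx.z` instead.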
**Files (3 files, +156/-13):**

- `aten/src/ATen/native/cuda/UpSampleBicubic2d.cu` — new parallel forward kernel + dispatch heuristic
- `benchmarks/operator_benchmark/pt/interpolate_test.py` — added `device` param + CUDA high-channel configs
- `test/test_nn.py` — new `test_upsamplingBicubic2d_many_channels` correctness test

## Testing

```bash
python test/test_nn.py TestNNDeviceTypeCUDA.test_upsamplingBicubic2d_correctness
python test/test_nn.py TestNNDeviceTypeCUDA.test_upsamplingBicubic2d_many_channels
```

Numerics validated: the parallel kernel produces **bit-identical** output to the original across float32 and bfloat16 for all tested configurations.

## BC-breaking?

No. The interpolation math is unchanged. The parallel kernel produces identical results. The dispatch is purely a performance optimization — same API, same numerics.

Pull Request resolved: pytorch#174578
Approved by: https://github.com/eqy, https://github.com/Skylion007
This PR changes the FP8 test skips to the proper mechanism, as the tests were not actually being skipped. It is not clear why `self.skipTest` and `unittest.SkipTest` do not work here, but the skip macros are more appropriate anyway. Pull Request resolved: pytorch#170528 Approved by: https://github.com/albanD
…torch#174990)

Summary: This converts NanCheck into an op so it can be used from outside of ProcessGroupNCCL. This can be used from torchcomms.

Misc changes:
* add CPU implementation
* use CUDA_KERNEL_ASSERT macro so it logs a more helpful message when nancheck fires

Reland since pytorch#174736 was reverted

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/ffe476d1bfe408e9b34fd73d214aff790c9dea09

Test plan from GitHub: CI

```
$ python -c "import torch; torch.ops.c10d.check_for_nan(torch.tensor(float('nan'), device='cuda')); torch.cuda.synchronize()"
(pytorch-3.12) /home/tristanr/pytorch/torch/csrc/distributed/c10d/NanCheck.cu:217: checkForNaN: block: [0,0,0], thread: [0,0,0] Assertion `!isnan(tailPtr[threadIdx.x])` failed.
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/tristanr/pytorch/torch/cuda/__init__.py", line 1165, in synchronize
    return torch._C._cuda_synchronize()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```

Differential Revision: D93267474

Pull Request resolved: pytorch#174990
Approved by: https://github.com/dcci, https://github.com/kapilsh
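A CPU-side analogue of the check is straightforward. The sketch below is illustrative only: the real entry point is `torch.ops.c10d.check_for_nan`, which asserts inside a CUDA kernel for GPU tensors, while this toy version just scans a Python list.

```python
import math

def check_for_nan(values):
    """Raise on the first NaN, loosely mirroring the op's device-side assert (sketch)."""
    for i, v in enumerate(values):
        if math.isnan(v):
            raise ValueError(f"NaN detected at index {i}")

check_for_nan([1.0, 2.0, 3.0])   # clean input: no error
try:
    check_for_nan([1.0, float("nan")])
except ValueError as e:
    print(e)   # NaN detected at index 1
```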
Addresses issue raised in pytorch#175193 Pull Request resolved: pytorch#175237 Approved by: https://github.com/malfet
This reverts commit 54603b1. Reverted pytorch#174793 on behalf of https://github.com/atalman due to Need to revert pytorch#174846 please reland once the signal is clear ([comment](pytorch#174793 (comment)))
This reverts commit fc79a04. Reverted pytorch#174846 on behalf of https://github.com/atalman due to test/test_dynamic_shapes.py::TestUbackedOps::test_unbacked_norm_no_dde [GH job link](https://github.com/pytorch/pytorch/actions/runs/22125637092/job/63956164686) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/fc79a048f03c5749b1e8d975e5ad5f5c6e79d5aa) ([comment](pytorch#174846 (comment)))
Before this change, overlap-preserving bucketing tried to merge all_reduce and reduce_scatter collectives, because the `bucket_key` it used was identical for both collective types. This resulted in failures. Pull Request resolved: pytorch#175150 Approved by: https://github.com/eellison
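The fix can be illustrated with a toy bucketing sketch (the names and dict fields here are hypothetical, not the actual PyTorch data structures): including the collective type in the bucket key keeps all_reduce and reduce_scatter ops in separate buckets.

```python
from collections import defaultdict

def bucket_key(op):
    # Keying on the collective type is the essence of the fix: previously,
    # all_reduce and reduce_scatter ops could share a key and get merged.
    return (op["type"], op["dtype"], op["group"])

def bucket(ops):
    buckets = defaultdict(list)
    for op in ops:
        buckets[bucket_key(op)].append(op)
    return dict(buckets)

ops = [
    {"type": "all_reduce",     "dtype": "bf16", "group": 0},
    {"type": "reduce_scatter", "dtype": "bf16", "group": 0},
    {"type": "all_reduce",     "dtype": "bf16", "group": 0},
]
grouped = bucket(ops)
print(len(grouped))  # 2: the two all_reduce ops share a bucket, reduce_scatter gets its own
```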
…und triton bug (pytorch#175168)" This reverts commit 9eabb4d. Reverted pytorch#175168 on behalf of https://github.com/drisspg due to I wanna measure the impact on flash first where we care more about numbers of instructions ([comment](pytorch#175168 (comment)))
`test_pdl_template_and_delay` is failing on Spark and Thor because they don't have enough SMs for max_autotune/templates. Added a skip decorator for small GPUs. Pull Request resolved: pytorch#174597 Approved by: https://github.com/eqy
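A hedged sketch of such a guard is below. The threshold and names are illustrative, not taken from the PR; on a real machine the SM count would come from `torch.cuda.get_device_properties(0).multi_processor_count`.

```python
import unittest

MIN_SM_COUNT = 60  # hypothetical cutoff; small GPUs like Spark/Thor would fall below it

def skip_if_small_gpu(sm_count):
    """Decorator factory: skip tests that need many SMs for max_autotune templates."""
    return unittest.skipIf(
        sm_count < MIN_SM_COUNT,
        "not enough SMs for max_autotune/template benchmarking",
    )

@skip_if_small_gpu(sm_count=148)  # large-GPU SM count: the test runs
def test_pdl_template_and_delay():
    pass
```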
Jenkins build for 3fb1b1cd9a042d2a282485bd3dbd9531427e2b13 commit finished as FAILURE
rocm_base: cc3acaf
latest upstream commit in this IFU: 7984635