diff --git a/csrc/options.cpp b/csrc/options.cpp
index e2ad832a46e..453c1394cec 100644
--- a/csrc/options.cpp
+++ b/csrc/options.cpp
@@ -185,6 +185,7 @@ const std::unordered_map<std::string, EnableOption>& getEnableOptions() {
       {"p2p_protocol", EnableOption::P2pProtocol},
       {"multicast_protocol", EnableOption::MulticastProtocol},
       {"parallel_serde", EnableOption::ParallelSerde},
+      {"infer_contiguity", EnableOption::InferContiguity},
   };
   return available_options;
 }
diff --git a/csrc/options.h b/csrc/options.h
index 3f21c3d9392..3f70e8ad992 100644
--- a/csrc/options.h
+++ b/csrc/options.h
@@ -130,6 +130,7 @@ enum class EnableOption {
   MulticastProtocol, //! Prescribe multicast protocol:
                      //! memcpy|multimem|batch_memcpy
   ParallelSerde, //! Enable deserializing FusionExecutorCache in parallel
+  InferContiguity, //! Enable contiguity inference
   EndOfOption //! Placeholder for counting the number of elements
 };
 
diff --git a/csrc/runtime/allocations.cpp b/csrc/runtime/allocations.cpp
index a83a70cf15e..ad6d75b4b02 100644
--- a/csrc/runtime/allocations.cpp
+++ b/csrc/runtime/allocations.cpp
@@ -21,12 +21,12 @@
 namespace nvfuser {
 
-KernelArgumentHolder inferOutputShapeAndContiguousStrides(
+KernelArgumentHolder inferContiguousOutputMetaTensor(
     Fusion* fusion,
     const KernelArgumentHolder& args,
     PrecomputedValues* evaluator_precomputed_values) {
   FUSER_PERF_SCOPE(
-      "fusion_executor::allocations::inferOutputShapeAndContiguousStrides");
+      "fusion_executor::allocations::inferContiguousOutputMetaTensor");
 
   ExpressionEvaluator expr_eval;
   std::unique_ptr<PrecomputedValues> evaluator_precomputed_values_up = nullptr;
diff --git a/csrc/runtime/allocations.h b/csrc/runtime/allocations.h
index 981d3a6bc17..5e9b2da0ae5 100644
--- a/csrc/runtime/allocations.h
+++ b/csrc/runtime/allocations.h
@@ -46,7 +46,7 @@ struct GlobalBufferInfo {
 //!    pushing scalar int 0 as a place-holder.
 //! 2. This API does not allocate output in memory, but only returns the
 //!    inferred output sizes. Used in runtime/fusion_executor_cache.cpp.
-KernelArgumentHolder inferOutputShapeAndContiguousStrides(
+KernelArgumentHolder inferContiguousOutputMetaTensor(
     Fusion* fusion,
     const KernelArgumentHolder& args,
     PrecomputedValues* evaluator_precomputed_values = nullptr);
diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp
index f3ea7866e58..d9c0c277cdf 100644
--- a/csrc/runtime/fusion_kernel_runtime.cpp
+++ b/csrc/runtime/fusion_kernel_runtime.cpp
@@ -337,6 +337,49 @@ KernelArgumentHolder FusionKernelRuntime::runWithInputs(
   return fusion_outputs;
 }
 
+KernelArgumentHolder FusionKernelRuntime::inferOutputMetaTensor(
+    HeuristicParamsList* heuristics,
+    SegmentedGroup* group_to_run,
+    const KernelArgumentHolder& group_runtime_inputs,
+    PrecomputedValues* evaluator_precomputed_values) const {
+  FUSER_PERF_SCOPE("FusionKernelRuntime::inferOutputMetaTensor");
+  NVF_ERROR(heuristics != nullptr);
+  Fusion* fusion_to_run = group_to_run->getFusion();
+  const auto& heuristic_params = heuristics->at(group_to_run->groupId());
+  const bool is_expr_eval =
+      heuristic_params->scheduler_type == SchedulerType::ExprEval;
+  if (!(is_expr_eval && isOptionEnabled(EnableOption::InferContiguity))) {
+    return inferContiguousOutputMetaTensor(
+        fusion_to_run, group_runtime_inputs, evaluator_precomputed_values);
+  }
+
+  // For an expr-evaluated fusion, the striding rules follow those of ATen.
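+  // Binding meta-device copies of the inputs lets the evaluator run each
+  // ATen op for its metadata only: output sizes and strides come from
+  // ATen's own layout rules (e.g. matmul's choice of output striding)
+  // while no device memory is allocated.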
+  ExpressionEvaluator eval_fusion;
+  for (const auto& [in, tensor_pv] :
+       zip(fusion_to_run->inputs(), group_runtime_inputs)) {
+    if (tensor_pv.is<at::Tensor>()) {
+      const auto& t = tensor_pv.as<at::Tensor>();
+      if (t.defined()) {
+        const auto meta_t = at::empty_strided(
+            t.sizes(),
+            t.strides(),
+            at::TensorOptions().device(at::kMeta).dtype(t.dtype()));
+        eval_fusion.bind(in, meta_t);
+      } else {
+        eval_fusion.bind(in, t);
+      }
+    } else {
+      eval_fusion.bind(in, tensor_pv);
+    }
+  }
+  KernelArgumentHolder group_runtime_outputs;
+  for (Val* v : fusion_to_run->outputs()) {
+    auto result = eval_fusion.evaluate(v);
+    group_runtime_outputs.push(result);
+  }
+  return group_runtime_outputs;
+}
+
 std::vector<KernelArgumentHolder> FusionKernelRuntime::prepareInputs(
     const KernelArgumentHolder& args) const {
   std::vector<KernelArgumentHolder> all_runtime_inputs;
@@ -362,12 +405,8 @@ std::vector<KernelArgumentHolder> FusionKernelRuntime::prepareInputs(
       group_runtime_inputs.setCacheId(group_cache_id.value());
     }
 
-    // TODO: inferOutputShapeAndContiguousStrides doesn't seem to strictly
-    // require a Fusion for each segment. Consider using the complete fusion
-    // instead.
-    auto fusion_to_run = segmented_fusion_->makeFusion(group_to_run).second;
-    auto group_runtime_outputs = inferOutputShapeAndContiguousStrides(
-        fusion_to_run.get(), group_runtime_inputs);
+    auto group_runtime_outputs = inferOutputMetaTensor(
+        heuristics_.get(), group_to_run, group_runtime_inputs);
 
     // map output args to tensor map
     args_manager.updateWithSegmentOutputs(
@@ -599,8 +638,9 @@ std::optional<std::unique_ptr<HeuristicParamsList>> FusionKernelRuntime::
   }
 
   // Generate metadata for the fusion's outputs
-  auto group_runtime_outputs = inferOutputShapeAndContiguousStrides(
-      fusion_to_run,
+  auto group_runtime_outputs = inferOutputMetaTensor(
+      heuristics.get(),
+      group_to_run,
       group_runtime_inputs,
       evaluator_precomputed_values.get());
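The ExprEval branch above leans on ATen's meta device: a meta tensor carries sizes, strides, and dtype but no storage, so running an op computes output metadata without allocating device memory. A minimal standalone sketch of that behavior (illustration only, not part of this patch):

import torch

# Meta tensors have metadata but no storage; operators on them only
# compute output sizes/strides, following ATen's normal layout rules.
x = torch.empty_strided((2, 3), (3, 1), dtype=torch.half, device="meta")
y = torch.empty_strided((3, 5), (5, 1), dtype=torch.half, device="meta")
z = torch.matmul(x, y)  # no numeric computation happens
print(z.shape, z.stride())  # torch.Size([2, 5]) with ATen-chosen strides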
diff --git a/csrc/runtime/fusion_kernel_runtime.h b/csrc/runtime/fusion_kernel_runtime.h
index 31965df07c2..e8b9e7bb0ac 100644
--- a/csrc/runtime/fusion_kernel_runtime.h
+++ b/csrc/runtime/fusion_kernel_runtime.h
@@ -173,6 +173,16 @@ class FusionKernelRuntime {
   //! Access the list of schedulers maintained in this runtime instance
   const std::vector<std::unique_ptr<SchedulerEntry>>& schedulers() const;
 
+  //! Infer the output shapes and strides of the fusion as tensors on the
+  //! Meta device. If the group is scheduled to be evaluated with ExprEval,
+  //! the outputs are inferred by running ExprEval on the meta device.
+  //! Otherwise, the outputs are inferred assuming they are contiguous.
+  KernelArgumentHolder inferOutputMetaTensor(
+      HeuristicParamsList* heuristics,
+      SegmentedGroup* group_to_run,
+      const KernelArgumentHolder& group_runtime_inputs,
+      PrecomputedValues* evaluator_precomputed_values = nullptr) const;
+
   // Create KernelArgumentHolders for all of the segments. Sorted in
   // the run order.
   std::vector<KernelArgumentHolder> prepareInputs(
diff --git a/tests/cpp/test_alias.cpp b/tests/cpp/test_alias.cpp
index 0f419f53d97..3fc9d68d1b7 100644
--- a/tests/cpp/test_alias.cpp
+++ b/tests/cpp/test_alias.cpp
@@ -497,6 +497,9 @@ TEST_F(AliasTest, Issue1452) {
 }
 
 TEST_F(AliasTest, AliasOutputBeforeNonAliasOutput) {
+  EnableOptionsGuard opt_guard;
+  EnableOptionsGuard::getCurOptions().unset(EnableOption::InferContiguity);
+
   auto fusion = std::make_unique<Fusion>();
   FusionGuard fg(fusion.get());
diff --git a/tests/cpp/test_indexing_advanced.cpp b/tests/cpp/test_indexing_advanced.cpp
index 92cbbe49783..a0614fcb426 100644
--- a/tests/cpp/test_indexing_advanced.cpp
+++ b/tests/cpp/test_indexing_advanced.cpp
@@ -26,6 +26,7 @@ class AdvancedIndexingTest : public NVFuserFixtureParamTest<bool> {
     } else {
       EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel);
     }
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 };
 
@@ -33,6 +34,7 @@
 class AdvancedIndexingIdModelTest : public NVFuserTest {
  protected:
   void SetUp() override {
     EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 };
diff --git a/tests/cpp/test_layout_op.cpp b/tests/cpp/test_layout_op.cpp
index 8ac5e9df436..1942b43d7ee 100644
--- a/tests/cpp/test_layout_op.cpp
+++ b/tests/cpp/test_layout_op.cpp
@@ -88,6 +88,7 @@ class LayoutOpTest : public NVFuserTest {
   void SetUp() override {
     NVFuserTest::SetUp();
     EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+    EnableOptionsGuard::getCurOptions().unset(EnableOption::InferContiguity);
   }
 };
diff --git a/tests/cpp/test_loop_domain_scheduling.cpp b/tests/cpp/test_loop_domain_scheduling.cpp
index b3b775aa776..bd34b31556a 100644
--- a/tests/cpp/test_loop_domain_scheduling.cpp
+++ b/tests/cpp/test_loop_domain_scheduling.cpp
@@ -41,6 +41,7 @@ class LoopDomainSchedulingTest : public NVFuserTest {
  protected:
   void SetUp() override {
     EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 };
diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp
index 68a2f1d0f9b..3c2b4a1f058 100644
--- a/tests/cpp/test_low_precision_recipe.cpp
+++ b/tests/cpp/test_low_precision_recipe.cpp
@@ -972,7 +972,13 @@ TEST_F(BlockQuantizationValidationTest, MergesMustBeContiguous) {
 
 class BlockQuantizationSchedulingTest
     : public BlackwellBase,
       public ::testing::WithParamInterface<
-          std::tuple<DataType, std::vector<int64_t>, bool, bool>> {};
+          std::tuple<DataType, std::vector<int64_t>, bool, bool>> {
+ protected:
+  void SetUp() override {
+    BlackwellBase::SetUp();
+    EnableOptionsGuard::getCurOptions().unset(EnableOption::InferContiguity);
+  }
+};
 
 TEST_P(BlockQuantizationSchedulingTest, AutoScheduleSingleOp) {
   const auto data_type = std::get<0>(GetParam());
diff --git a/tests/cpp/test_matmul_aten_evaluation.cpp b/tests/cpp/test_matmul_aten_evaluation.cpp
index 150646ff697..68d574b5c0c 100644
--- a/tests/cpp/test_matmul_aten_evaluation.cpp
+++ b/tests/cpp/test_matmul_aten_evaluation.cpp
@@ -371,37 +371,4 @@ INSTANTIATE_TEST_SUITE_P(
         testing::Values(Sizes({n, 1})),
         testing::Values(Sizes({n}))));
 
-using MatmulNodeTest = NVFuserTest;
-
-TEST_F(MatmulNodeTest, OutputStrides) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
-
-  TensorView* x = makeSymbolicTensor(2, DataType::Half);
-  TensorView* y = makeSymbolicTensor(2, DataType::Half);
-  TensorView* z = matmul(x, y);
-
-  fusion->addInput(x);
-  fusion->addInput(y);
-  fusion->addOutput(z);
-
-  z->setAllocationDomain({z->axis(1), z->axis(0), z->axis(2)}, true);
-
-  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
-  at::Tensor x_tensor = at::randn({2, 3}, options);
-  at::Tensor y_tensor = at::randn({3, 5}, options);
-
-  FusionExecutorCache executor_cache(std::move(fusion));
-  auto outs = executor_cache.runFusionWithInputs({x_tensor, y_tensor});
-  at::Tensor z_tensor = outs[0].as<at::Tensor>();
-  testValidate(
-      executor_cache.fusion(),
-      {z_tensor},
-      {x_tensor, y_tensor},
-      __LINE__,
-      __FILE__);
-
-  EXPECT_THAT(z_tensor.strides(), ElementsAre(1, 2));
-}
-
 } // namespace nvfuser
diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp
index c9d85072960..277c9c8335b 100644
--- a/tests/cpp/test_matmul_scheduler.cpp
+++ b/tests/cpp/test_matmul_scheduler.cpp
@@ -2802,6 +2802,7 @@ class MatmulFusionTest
       EnableOptionsGuard::getCurOptions().set(
           EnableOption::FuseMultipleMatmuls);
     }
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 
   bool fusion_enabled = GetParam().first;
diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp
index 1b60405dbed..bc181a07718 100644
--- a/tests/cpp/test_pointwise.cpp
+++ b/tests/cpp/test_pointwise.cpp
@@ -26,6 +26,7 @@ class PointwiseTest : public NVFuserTest {
  protected:
   void SetUp() override {
     EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 };
diff --git a/tests/cpp/test_rng.cpp b/tests/cpp/test_rng.cpp
index d0e7e56218a..d7c32ced22e 100644
--- a/tests/cpp/test_rng.cpp
+++ b/tests/cpp/test_rng.cpp
@@ -69,6 +69,7 @@ at::Tensor generate_normal(int64_t size, at::ScalarType dtype) {
 
 class RNGTest : public NVFuserTest {
   void SetUp() override {
     EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 };
diff --git a/tests/cpp/utils.cpp b/tests/cpp/utils.cpp
index e09614e9278..a74950e703b 100644
--- a/tests/cpp/utils.cpp
+++ b/tests/cpp/utils.cpp
@@ -59,6 +59,7 @@ void NVFuserTest::SetUp() {
     GTEST_SKIP() << "skipping tests on pre-PASCAL GPUs";
   }
   EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+  EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
 }
 
 NVFuserTest::~NVFuserTest() {
diff --git a/tests/python/direct/conftest.py b/tests/python/direct/conftest.py
index 1e8465beef5..5c9e812cb48 100644
--- a/tests/python/direct/conftest.py
+++ b/tests/python/direct/conftest.py
@@ -28,6 +28,8 @@ def exec_nvfuser(
     new_fusion_expected=True,
     expected_fd_str=None,
     device=None,
+    enable_options=None,
+    disable_options=None,
     validate_results=False,
 ):
     # Copy inputs because aliased outputs can modify inputs when running
@@ -64,12 +66,20 @@ def exec_nvfuser(
     if validate_results:
         out = fd.validate(inputs)
     else:
+        if enable_options is None:
+            enable_options = []
+        if disable_options is None:
+            disable_options = []
         out = fd.execute(
             inputs,
             device=device,
+            _enable_options=enable_options,
+            _disable_options=disable_options,
         )
 
-    assert check_captured_python_definition(out, fd, inputs_captured, device)
+    assert check_captured_python_definition(
+        out, fd, inputs_captured, device, enable_options, disable_options
+    )
     assert expected_fd_str is None or expected_fd_str in repr(fd)
     return out, fd
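For illustration, a hypothetical call through the updated harness (the keyword names mirror the parameters added above, and the option string is the one registered in csrc/options.cpp):

# Hypothetical usage sketch: run one fusion with contiguity inference
# enabled for just this execution.
out, fd = nvfuser_direct_test.exec_nvfuser(
    fusion_func,
    inputs,
    enable_options=["infer_contiguity"],
)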
diff --git a/tests/python/direct/test_python_frontend.py b/tests/python/direct/test_python_frontend.py
index c576ac98053..01c692123c8 100644
--- a/tests/python/direct/test_python_frontend.py
+++ b/tests/python/direct/test_python_frontend.py
@@ -2763,3 +2763,101 @@ def fusion_func(fd: FusionDefinition):
 
     out, _ = nvfuser_direct_test.exec_nvfuser(fusion_func, inputs)
     nvfuser_direct_test.assertEqual(out[0], inputs[0])
+
+
+def test_issue4888(nvfuser_direct_test):
+    # https://github.com/NVIDIA/Fuser/issues/4888
+    def nvfuser_fusion_id2(fd: FusionDefinition) -> None:
+        T0 = fd.define_tensor(
+            shape=[4096, 4097],
+            contiguity=[True, True],
+            dtype=DataType.BFloat16,
+            is_cpu=False,
+            stride_order=[1, 0],
+        )
+        T1 = fd.define_tensor(
+            shape=[4096, 4097],
+            contiguity=[True, True],
+            dtype=DataType.Bool,
+            is_cpu=False,
+            stride_order=[1, 0],
+        )
+        T2 = fd.define_tensor(
+            shape=[4096, 4097],
+            contiguity=[True, True],
+            dtype=DataType.Bool,
+            is_cpu=False,
+            stride_order=[1, 0],
+        )
+        T3 = fd.define_tensor(
+            shape=[1, 32, 4096, 4096],
+            contiguity=[None, True, True, True],
+            dtype=DataType.BFloat16,
+            is_cpu=False,
+            stride_order=[3, 2, 1, 0],
+        )
+        T4 = fd.ops.cast(T0, dtype=DataType.Float)
+        T5 = fd.ops.bitwise_or(T1, T2)
+        T6 = fd.ops.set(T5)
+        fd.add_output(T6, T1)
+        T7 = fd.ops.cast(T6, dtype=DataType.Float)
+        T8 = fd.ops.mul(T4, T7)
+        T9 = fd.ops.cast(T8, dtype=DataType.BFloat16)
+        T10 = fd.ops.set(T9)
+        fd.add_output(T10, T0)
+        T15 = fd.ops.broadcast_in_dim(
+            T10, shape=[1, 4096, 4097], broadcast_dims=[1, 2]
+        )
+        T21 = fd.ops.broadcast_in_dim(
+            T15, shape=[1, 1, 4096, 4097], broadcast_dims=[0, 2, 3]
+        )
+        T27 = fd.ops.broadcast_in_dim(
+            T21, shape=[1, 1, 4096, 4097], broadcast_dims=[0, 1, 2, 3]
+        )
+        T43 = fd.ops.slice(
+            T27,
+            start_indices=[0, 0, 0, 0],
+            end_indices=[1, 1, 4096, 4096],
+            strides=[1, 1, 1, 1],
+            manual_normalization=0,
+        )
+        T49 = fd.ops.broadcast_in_dim(
+            T43, shape=[1, 32, 4096, 4096], broadcast_dims=[0, 1, 2, 3]
+        )
+        T50 = fd.ops.cast(T49, dtype=DataType.Float)
+        T51 = fd.ops.cast(T3, dtype=DataType.Float)
+        S52 = fd.define_scalar(0.0883883, dtype=DataType.Double)
+        T53 = fd.ops.mul(T51, S52)
+        T54 = fd.ops.add(T53, T50)
+        T55 = fd.ops.max(T54, dims=[3], keepdim=False, dtype=DataType.Null)
+        T61 = fd.ops.broadcast_in_dim(
+            T55, shape=[1, 32, 4096, 1], broadcast_dims=[0, 1, 2]
+        )
+        T67 = fd.ops.broadcast_in_dim(
+            T61, shape=[1, 32, 4096, 4096], broadcast_dims=[0, 1, 2, 3]
+        )
+        T68 = fd.ops.sub(T54, T67)
+        T69 = fd.ops.exp(T68)
+        T70 = fd.ops.sum(T69, dims=[3], keepdim=False, dtype=DataType.Null)
+        T76 = fd.ops.broadcast_in_dim(
+            T70, shape=[1, 32, 4096, 1], broadcast_dims=[0, 1, 2]
+        )
+        T82 = fd.ops.broadcast_in_dim(
+            T76, shape=[1, 32, 4096, 4096], broadcast_dims=[0, 1, 2, 3]
+        )
+        T83 = fd.ops.reciprocal(T82)
+        T84 = fd.ops.mul(T69, T83)
+        T85 = fd.ops.cast(T84, dtype=DataType.BFloat16)
+        fd.add_output(T49)
+        fd.add_output(T84)
+        fd.add_output(T85)
+
+    inputs = [
+        torch.testing.make_tensor((4096, 4097), dtype=torch.bfloat16, device="cuda:0"),
+        torch.testing.make_tensor((4096, 4097), dtype=torch.bool, device="cuda:0"),
+        torch.testing.make_tensor((4096, 4097), dtype=torch.bool, device="cuda:0"),
+        torch.testing.make_tensor(
+            (1, 32, 4096, 4096), dtype=torch.bfloat16, device="cuda:0"
+        ),
+    ]
+    nvfuser_direct_test.exec_nvfuser(
+        nvfuser_fusion_id2, inputs, enable_options=["infer_contiguity"]
+    )
diff --git a/tests/python/direct_utils/utils.py b/tests/python/direct_utils/utils.py
index f5eb652a39e..6ee05e13b46 100644
--- a/tests/python/direct_utils/utils.py
+++ b/tests/python/direct_utils/utils.py
@@ -38,7 +38,14 @@ def is_pre_blackwell():
 # Get string representation for FusionDefinition
 # Run captured python definition
 # Check that the result of captured python definition matches original results
-def check_captured_python_definition(reference_outputs, fd, inputs, device=None):
+def check_captured_python_definition(
+    reference_outputs,
+    fd,
+    inputs,
+    device=None,
+    enable_options=None,
+    disable_options=None,
+):
     try:
         fd_str = fd.__repr__()
         func_name = "nvfuser_fusion"
@@ -49,7 +56,16 @@ def check_captured_python_definition(reference_outputs, fd, inputs, device=None):
             eval(func_name)(fd_cap)
 
     torch.manual_seed(0)
-    captured_outputs = fd_cap.execute(inputs, device=device)
+    if enable_options is None:
+        enable_options = []
+    if disable_options is None:
+        disable_options = []
+    captured_outputs = fd_cap.execute(
+        inputs,
+        device=device,
+        _enable_options=enable_options,
+        _disable_options=disable_options,
+    )
 
     if len(reference_outputs) != len(captured_outputs):
         return False
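Besides the per-execution keywords used by the tests, EnableOptions are normally toggled process-wide through the NVFUSER_ENABLE environment variable; assuming the new option follows that convention, a sketch:

import os

# Assumption: like the other entries registered in getEnableOptions(),
# "infer_contiguity" is parsed from NVFUSER_ENABLE, so it can be turned on
# for the whole process before the library is loaded.
os.environ["NVFUSER_ENABLE"] = "infer_contiguity"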