Use ATen ops on meta tensor to compute output shapes and strides for ExprEval segments #5082
Review updated until commit d558cca

Description

Relevant files:

| Category | Files |
|---|---|
| Bug fix | 1 file |
| Enhancement | |
| Tests | 5 files |
PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Complex contiguity logic: the `resetAllocationDomainAndContiguity` function contains complex logic for computing contiguity from tensor strides. It handles reductions, broadcasts, and multi-dimensional tensors. The reviewer should verify that the algorithm correctly handles all edge cases and matches PyTorch's contiguity semantics.
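For intuition, here is a minimal sketch of the core stride-to-contiguity rule (this is not the PR's `resetAllocationDomainAndContiguity`, and it omits the reduction handling): a non-broadcast dimension is contiguous when its stride equals the stride of the nearest non-broadcast dimension to its right times that dimension's size, with the innermost non-broadcast dimension requiring stride 1.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Sketch: per-dimension contiguity flags computed from sizes and strides.
// Size-1 (broadcast) dimensions carry no meaningful flag, so they map to
// std::nullopt, mirroring nvFuser's optional contiguity entries.
std::vector<std::optional<bool>> computeContiguity(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides) {
  std::vector<std::optional<bool>> contiguity(sizes.size());
  // Stride that dimension i must have to be contiguous with respect to the
  // nearest non-broadcast dimension to its right.
  int64_t required_stride = 1;
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 1; i >= 0; --i) {
    if (sizes[i] == 1) {
      contiguity[i] = std::nullopt;
      continue;
    }
    contiguity[i] = (strides[i] == required_stride);
    required_stride = strides[i] * sizes[i];
  }
  return contiguity;
}
```

For example, sizes `{4, 3}` with strides `{3, 1}` yield `{true, true}`, while a sliced tensor with strides `{6, 1}` yields `{false, true}`.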
Test failures

- (Medium, 12) NVFuser evaluator assertion failures in stream/multidevice matmul & linear tests (runners: A100, A100 dist., GB200, GB200 dist., H100, H100 dist.):
  - `tests.python.direct.test_stream.test_two_matmuls_inlinable[nvfuser_direct_test=eager]`: failed on 3 runners
  - `tests.python.direct.test_stream.test_two_matmuls_inlinable[nvfuser_direct_test=lru_cache]`: failed on 3 runners
  - `tests.python.multidevice.test_overlap.test_row_parallel_linear_forward`: failed on all 6 runners
- (Medium, 6) Profiler event count mismatch in `test_stream.test_matmul` (nvFuser stream scheduling; runners: A100, GB200, H100):
  - `tests.python.direct.test_stream.test_matmul[nvfuser_direct_test=eager]`: failed on all 3 runners
  - `tests.python.direct.test_stream.test_matmul[nvfuser_direct_test=lru_cache]`: failed on all 3 runners
- (Medium, 1) CUDA out-of-memory in `TmaPointwiseTestF.SplitGridDim2D` (runner: H100)
In preparation for the fix of #4888, I am working on #5082, which requires using `Expr::evaluate` on meta tensors to infer the shapes and strides of fusion segments selected to be scheduled by the `ExprEval` scheduler. As a consequence, all `Expr::evaluate` functions should support the meta device, and the returned output tensor's shape and strides must match those on device type CUDA. According to https://docs.pytorch.org/docs/stable/meta.html:

> In some cases, not all device types (e.g., CPU and CUDA) have exactly the same output metadata for an operation; we typically prefer representing the CUDA behavior faithfully in this situation.

It is therefore generally safe to assume that we can use the meta device to infer shapes and strides for device type CUDA. Unfortunately, not all operators implement the meta device, and `at::_scaled_dot_product_flash_attention` is one example. In this PR, I am adding my own meta-device implementation of `at::_scaled_dot_product_flash_attention`.
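For illustration, here is a minimal sketch of the general technique (the helper name and the choice of `at::slice` are mine, not from this PR): mirror the real input onto the meta device with `at::empty_strided`, then run the ATen op on it; the op computes output shape and strides without allocating or touching real data.

```cpp
#include <ATen/ATen.h>

// Hypothetical helper: infer the output metadata of a slice without running
// it on a real device.
at::Tensor inferSliceOutputMeta(const at::Tensor& real_input) {
  const at::Tensor meta_input = at::empty_strided(
      real_input.sizes(),
      real_input.strides(),
      at::TensorOptions().device(at::kMeta).dtype(real_input.dtype()));
  // Narrowing the innermost dimension keeps the original strides, so the
  // resulting view is non-contiguous in general; its meta-device shape and
  // strides match what CUDA would produce.
  return at::slice(meta_input, /*dim=*/-1, /*start=*/1);
}
```

For a contiguous `{4, 8}` input, the result reports sizes `{4, 7}` with strides `{8, 1}`, exactly the non-contiguous layout a CUDA slice would have.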
13 files reviewed, 1 comment
```cpp
if (is_expr_eval) {
  // For expr evaluated fusion, the striding rules follow that of ATen.
  ExpressionEvaluator eval_fusion;
  for (auto [i, v] : enumerate(group_to_run->inputs())) {
    auto tensor_pv = args_manager.checkTensorMap(v);
    if (tensor_pv.is<at::Tensor>()) {
      const auto t = tensor_pv.as<at::Tensor>();
      if (t.defined()) {
        const auto meta_t = at::empty_strided(
            t.sizes(),
            t.strides(),
            at::TensorOptions().device(at::kMeta).dtype(t.dtype()));
        eval_fusion.bind(fusion_to_run->inputs()[i], meta_t);
      } else {
        eval_fusion.bind(fusion_to_run->inputs()[i], t);
      }
    } else {
      eval_fusion.bind(fusion_to_run->inputs()[i], tensor_pv);
    }
  }
  for (auto v : fusion_to_run->outputs()) {
    auto result = eval_fusion.evaluate(v);
    group_runtime_outputs.push(result);
  }
}
```
style: the meta tensor evaluation logic is duplicated between `prepareInputs` (lines 370-393) and `getMaybeHeuristicsFor` (lines 640-663). Extract it to a helper like `evaluateExprEvalSegmentOutputs(fusion_to_run, group_inputs, args_manager)` to reduce duplication.
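For concreteness, a rough sketch of what such a helper could look like (the signature, return type, and container are assumptions layered on the reviewer's suggested name, not code from the PR):

```cpp
// Hypothetical helper consolidating the duplicated logic: binds each segment
// input as a meta tensor and evaluates the segment's outputs on the meta
// device. Assumes KernelArgumentHolder as the output container.
KernelArgumentHolder evaluateExprEvalSegmentOutputs(
    Fusion* fusion_to_run,
    const std::vector<Val*>& group_inputs,
    ArgumentManager& args_manager) {
  ExpressionEvaluator eval_fusion;
  for (auto [i, v] : enumerate(group_inputs)) {
    auto tensor_pv = args_manager.checkTensorMap(v);
    if (tensor_pv.is<at::Tensor>() && tensor_pv.as<at::Tensor>().defined()) {
      // Mirror defined tensors onto the meta device, preserving strides.
      const auto t = tensor_pv.as<at::Tensor>();
      eval_fusion.bind(
          fusion_to_run->inputs()[i],
          at::empty_strided(
              t.sizes(),
              t.strides(),
              at::TensorOptions().device(at::kMeta).dtype(t.dtype())));
    } else {
      // Scalars and undefined tensors are bound as-is.
      eval_fusion.bind(fusion_to_run->inputs()[i], tensor_pv);
    }
  }
  KernelArgumentHolder outputs;
  for (auto v : fusion_to_run->outputs()) {
    outputs.push(eval_fusion.evaluate(v));
  }
  return outputs;
}
```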
13 files reviewed, no comments
!test
11 files reviewed, no comments
!test
8 files reviewed, no comments
!test
8 files reviewed, no comments
!test
8 files reviewed, 1 comment
```cpp
} else {
  // TODO: inferOutputSizes doesn't seem to strictly require a Fusion for
  // each segment. Consider using the complete fusion instead.
  auto fusion_to_run = segmented_fusion_->makeFusion(group_to_run).second;
```
style: variable shadowing: `fusion_to_run` is declared on line 365 as `Fusion*`, then redeclared here as `std::unique_ptr<Fusion>`. While technically valid (different scopes), this is confusing.
```diff
-auto fusion_to_run = segmented_fusion_->makeFusion(group_to_run).second;
+auto fusion_unique_ptr = segmented_fusion_->makeFusion(group_to_run).second;
+group_runtime_outputs =
+    inferOutputSizes(fusion_unique_ptr.get(), group_runtime_inputs);
```
9 files reviewed, 2 comments
(Quotes the same meta-tensor evaluation snippet shown above.)
style: duplicated between `prepareInputs` and `getMaybeHeuristicsFor`. Extract to a helper like `evaluateExprEvalSegmentOnMeta(fusion, inputs, args_manager)` to eliminate the duplication (see the sketch after the first occurrence of this comment above).
```cpp
auto fusion_to_run = segmented_fusion_->makeFusion(group_to_run).second;
auto group_runtime_outputs =
    inferOutputSizes(fusion_to_run.get(), group_runtime_inputs);
// ...
Fusion* fusion_to_run = group_to_run->getFusion();
```
style: variable `fusion_to_run` shadows line 397, where it's redeclared as `std::unique_ptr<Fusion>`. Rename one (e.g., this one to `fusion_ptr`, or line 397's to `fusion_to_run_unique`) to avoid confusion.
9 files reviewed, 1 comment
(Quotes the same `makeFusion` snippet shown above.)
style: `fusion_to_run` is redeclared here as `std::unique_ptr<Fusion>` and shadows the `Fusion*` on line 365. Rename it (e.g., to `fusion_to_run_ptr`) to avoid confusion.
!test

!test
Fixes #4888. Stacked on #5766.

I used to work on #5082 for the fix, but I hit too many blockers: that PR interacts with many new assumptions/hacks/unfinalized designs around allocation domain, stream-sharded tensors, multidevice, etc., and new commits to the main branch kept breaking #5082. This delayed the PR for a very long time, so I recreated it as this PR, which is more friendly to incremental development.

Today, in the main branch, `FusionExecutorCache` assumes fusion segments always generate contiguous tensors. This is not true for `ExpressionEvaluator` segments; for example, ATen's slice op returns non-contiguous tensors. It is worth mentioning that, because segmentation and scheduler selection depend on inputs, the contiguity of intermediate results also depends on inputs.

This PR adds `FusionKernelRuntime::inferOutputMetaTensor`, which replaces `inferOutputShapeAndContiguousStrides` to infer the output shape and strides of each segment. Both functions store their result as a tensor on the meta device. The difference is that `inferOutputMetaTensor` actually runs the segment on device type meta if the segment is scheduled to run by `ExpressionEvaluator`, whereas `inferOutputShapeAndContiguousStrides` just assumes the output is contiguous.

Because `inferOutputMetaTensor` runs the segment on device type meta, the related ops' `MyOp::evaluate` functions must work for the meta device. There is good news and bad news about this design. The good news: most `MyOp::evaluate` implementations just call `at::` ops, which usually already support the meta device, and [PyTorch designed the meta device to keep its behavior on par with CUDA](https://docs.pytorch.org/docs/stable/meta.html). The bad news: many ops' meta-device implementations are in Python, so running `at::op` for such ops would hang due to the inability to grab Python's GIL (thanks @naoyam for help debugging!). In that case, the corresponding `MyOp::evaluate` must manually compute the shape and strides and use `at::empty_strided(device=meta)` to create the result.

Besides `inferOutputMetaTensor`, this PR also adds `FusionKernelRuntime::updateContiguityOfSegmentOutputs`, which updates the segment output `TensorView`s' contiguity based on the inferred shape and strides.

This PR adds an enable option, "infer-contiguity", to roll this feature out incrementally. When "infer-contiguity" is disabled, `inferOutputMetaTensor` falls back to the behavior of `inferOutputShapeAndContiguousStrides`, and `updateContiguityOfSegmentOutputs` is a no-op. The plan is to merge this PR without setting "infer-contiguity" for the currently failing tests; I will write new PRs fixing the failing tests one by one.

---------

Co-authored-by: Jingyue Wu <[email protected]>
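To illustrate the manual fallback (a minimal sketch under stated assumptions, not the PR's code; the helper name and the assumption that the output is contiguous with the query's sizes are hypothetical):

```cpp
#include <ATen/ATen.h>
#include <cstdint>
#include <vector>

// Hypothetical fallback used by an op's evaluate() on the meta device when
// the ATen op's meta implementation lives in Python and cannot be dispatched
// to from C++: build the result with at::empty_strided by hand.
at::Tensor evaluateOnMeta(const at::Tensor& query) {
  TORCH_CHECK(query.is_meta(), "expected a meta tensor");
  // Assumed shape rule (illustration only): output shares the query's sizes.
  std::vector<int64_t> sizes(query.sizes().begin(), query.sizes().end());
  // Contiguous strides computed by hand: stride[i] = product of sizes[i+1:].
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * sizes[i + 1];
  }
  return at::empty_strided(
      sizes,
      strides,
      at::TensorOptions().device(at::kMeta).dtype(query.dtype()));
}
```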
Superseded by #5772
Fixes #4888

In `FusionExecutorCache`, we were assuming fusion segments always generate contiguous tensors. This is not true for `ExpressionEvaluator` segments. For such segments, we need to actually execute the segment to know the shape and strides we will get. This execution happens on device type meta instead of using real CUDA tensors; according to https://docs.pytorch.org/docs/stable/meta.html, it should produce the same shape and strides as we would see with real tensors.