
Conversation

@wujingyue (Collaborator)

No description provided.

@wujingyue wujingyue requested a review from Priya2698 January 13, 2026 06:32
@wujingyue (Collaborator, Author)

!test

@github-actions

Description

  • Move torch.cuda.set_device calls from individual test functions to the fixture setup (a sketch follows this list)

  • Centralize device configuration in conftest.py fixture __init__ method

  • Remove redundant device setup calls from 5 test files

  • Improve code organization and reduce duplication across multidevice tests
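
To make the description concrete, here is a minimal sketch of the centralized setup in conftest.py. Communicator.instance(), local_rank(), torch.manual_seed(0), and the barrier on teardown are taken from the summaries in this thread; the stand-in Communicator class, the LOCAL_RANK lookup, and the fixture scaffolding are illustrative assumptions, not nvFuser's actual code.

    import os

    import pytest
    import torch


    class Communicator:
        # Stand-in for the real communicator used by these tests; the actual
        # class ships with nvFuser and is not shown in this PR.
        _instance = None

        @classmethod
        def instance(cls):
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance

        def local_rank(self):
            # torchrun/mpirun-style environment variable; an assumption for this sketch.
            return int(os.environ.get("LOCAL_RANK", "0"))

        def barrier(self):
            pass  # the real implementation synchronizes all ranks


    class MultideviceTest:
        def __init__(self):
            self._communicator = Communicator.instance()
            # The line this PR adds: pin the process to its local GPU before any
            # other CUDA work, including the seeding on the next line.
            torch.cuda.set_device(self._communicator.local_rank())
            torch.manual_seed(0)

        @property
        def local_rank(self):
            return self._communicator.local_rank()


    @pytest.fixture
    def multidevice_test():
        fixture = MultideviceTest()
        yield fixture
        fixture._communicator.barrier()

Every test that takes the multidevice_test fixture now runs with the device already selected, which is why the per-test torch.cuda.set_device calls listed below could be deleted.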

Changes walkthrough

Relevant files (Enhancement)

conftest.py (tests/python/multidevice/conftest.py): Add device setup to fixture (+2/-0)
  • Add torch.cuda.set_device(self._communicator.local_rank()) call in the fixture __init__ method
  • Centralize device setup logic for all multidevice tests

test_matmul.py (tests/python/multidevice/test_matmul.py): Remove redundant device setup calls (+0/-17)
  • Remove torch.cuda.set_device(multidevice_test.local_rank) calls from 7 test functions
  • Clean up redundant device setup code

test_multidevice.py (tests/python/multidevice/test_multidevice.py): Remove redundant device setup calls (+0/-6)
  • Remove torch.cuda.set_device(multidevice_test.local_rank) calls from 3 test functions
  • Clean up redundant device setup code

test_overlap.py (tests/python/multidevice/test_overlap.py): Remove redundant device setup calls (+0/-4)
  • Remove torch.cuda.set_device(multidevice_test.local_rank) calls from 2 test functions
  • Clean up redundant device setup code

test_transformer.py (tests/python/multidevice/test_transformer.py): Remove redundant device setup calls (+0/-6)
  • Remove torch.cuda.set_device(multidevice_test.local_rank) calls from 2 test functions
  • Clean up redundant device setup code

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ No major issues detected

@greptile-apps (Contributor)

greptile-apps bot commented Jan 13, 2026

Greptile Overview

Greptile Summary

This PR refactors CUDA device initialization in multidevice tests by moving torch.cuda.set_device() from individual test functions to the MultideviceTest fixture's __init__ method.

Changes Made

  • conftest.py: Added torch.cuda.set_device(self._communicator.local_rank()) in MultideviceTest.__init__, right after communicator initialization and before torch.manual_seed(0)
  • test_matmul.py: Removed 8 redundant torch.cuda.set_device() calls from test functions
  • test_multidevice.py: Removed 3 redundant calls
  • test_overlap.py: Removed 2 redundant calls
  • test_transformer.py: Removed 2 redundant calls

Total: 15 redundant calls eliminated across all test files.
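
For illustration, the removal pattern in a single test roughly looks like this; the test name, tensor shapes, and body are hypothetical, and only the torch.cuda.set_device(multidevice_test.local_rank) line is quoted from the diff:

    # Before: each multidevice test selected its device explicitly.
    def test_example_before(multidevice_test):
        torch.cuda.set_device(multidevice_test.local_rank)  # removed by this PR
        inp = torch.randn(16, 32, device="cuda")
        ...  # rest of the (hypothetical) test body

    # After: the fixture has already called torch.cuda.set_device, so the test
    # starts with its tensor work directly.
    def test_example_after(multidevice_test):
        inp = torch.randn(16, 32, device="cuda")
        ...  # rest of the (hypothetical) test body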

Benefits

1. Code Deduplication: Eliminates repetitive boilerplate code across test functions
2. Consistency: Ensures all tests using the multidevice_test fixture have the device set correctly before execution
3. Maintainability: Single source of truth for device initialization logic
4. Correct Ordering: Device is set immediately after communicator initialization, ensuring all subsequent CUDA operations use the correct device

Implementation Analysis

The placement of torch.cuda.set_device() in the fixture is appropriate:

  • Called after Communicator.instance() (which provides the local_rank)
  • Called before torch.manual_seed(0) (which may use CUDA RNG)
  • Ensures proper device context for all test code

The fixture's shard_tensor and shard_tensor_1d methods call .cuda(self.local_rank), which will now correctly use the already-set device.
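
A standalone illustration of why this placement works and why the existing .cuda(self.local_rank) calls remain consistent with it; this is plain PyTorch behavior, independent of this repository, and needs at least two visible GPUs to run with the example rank below:

    import torch

    local_rank = 1  # example value; in the tests it comes from the communicator
    torch.cuda.set_device(local_rank)

    # After set_device, the bare "cuda" device string resolves to cuda:<local_rank>,
    # so tensors the tests create with device="cuda" land on the right GPU.
    a = torch.randn(4, device="cuda")
    assert a.device == torch.device("cuda", local_rank)

    # .cuda(local_rank), as used by shard_tensor and shard_tensor_1d, names the
    # same index explicitly, so it agrees with the device set above.
    b = torch.arange(8).cuda(local_rank)
    assert b.device == torch.device("cuda", local_rank)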

Confidence Score: 5/5

  • This PR is safe to merge with no concerns
  • This is a straightforward refactoring that consolidates device initialization into the test fixture. All removed calls were redundant since the fixture now handles device setting. The change improves code quality without altering test behavior. Tests that create CUDA tensors with device="cuda" will correctly use the device set by the fixture. No logic errors, edge cases, or breaking changes identified.
  • No files require special attention

Important Files Changed

File Analysis

  • tests/python/multidevice/conftest.py (5/5): Added torch.cuda.set_device() call in MultideviceTest.__init__ to centralize device setting for all tests using this fixture
  • tests/python/multidevice/test_matmul.py (5/5): Removed 8 redundant torch.cuda.set_device() calls from test functions, now handled by fixture
  • tests/python/multidevice/test_multidevice.py (5/5): Removed 3 redundant torch.cuda.set_device() calls from test functions, now handled by fixture
  • tests/python/multidevice/test_overlap.py (5/5): Removed 2 redundant torch.cuda.set_device() calls from test functions, now handled by fixture
  • tests/python/multidevice/test_transformer.py (5/5): Removed 2 redundant torch.cuda.set_device() calls from test functions, now handled by fixture

Sequence Diagram

    sequenceDiagram
        participant PyTest
        participant multidevice_test as multidevice_test Fixture
        participant MultideviceTest
        participant Communicator
        participant CUDA
        participant TestFunction
    
        PyTest->>multidevice_test: Request fixture
        multidevice_test->>MultideviceTest: __init__()
        MultideviceTest->>Communicator: instance()
        Communicator-->>MultideviceTest: communicator
        MultideviceTest->>Communicator: local_rank()
        Communicator-->>MultideviceTest: rank_id
        MultideviceTest->>CUDA: set_device(rank_id)
        Note over CUDA: Device set for current process
        MultideviceTest->>CUDA: manual_seed(0)
        MultideviceTest-->>multidevice_test: instance
        multidevice_test-->>PyTest: fixture ready
        PyTest->>TestFunction: Execute test(multidevice_test)
        Note over TestFunction: All CUDA operations use<br/>correct device automatically
        TestFunction-->>PyTest: Test complete
        PyTest->>multidevice_test: Teardown
        multidevice_test->>Communicator: barrier()
        Note over multidevice_test: Fixture cleaned up
    

@greptile-apps greptile-apps bot left a comment

No files reviewed, no comments

@wujingyue wujingyue merged commit 5b41331 into main Jan 13, 2026
61 of 62 checks passed
@wujingyue wujingyue deleted the wjy/device branch January 13, 2026 21:22