## Problem Description

The `AscendDevice::H2DBatchSync()` and `AscendDevice::D2HBatchSync()` methods execute batch memory transfers sequentially. This results in much worse performance than `CudaDevice`, which uses parallel CUDA kernels for batch operations.
## Root Cause

The current `AscendDevice::H2DBatchSync()` sequentially calls `AscendDevice::H2DSync()`, which is a blocking operation.
### Current NPU Implementation (Sequential)

```cpp
Status H2DSync(std::byte* dst, const std::byte* src, const size_t count) override
{
    return ASCEND_API(aclrtMemcpy, dst, count, src, count, ACL_MEMCPY_HOST_TO_DEVICE);
}

Status H2DBatchSync(std::byte* dArr[], const std::byte* hArr[], const size_t number, const size_t count) override
{
    for (size_t i = 0; i < number; i++) {
        auto status = this->H2DSync(dArr[i], hArr[i], count);
        if (status.Failure()) { return status; }
    }
    return Status::OK();
}
```

### Comparison: Current CUDA Implementation (Parallel)
```cuda
__global__ void H2DKernel(uintptr_t* dst, const volatile uintptr_t* src, size_t num, size_t size)
{
    auto length = num * size;
    auto offset = (blockIdx.x * blockDim.x + threadIdx.x) * CUDA_TRANS_UNIT_SIZE;
    while (offset + CUDA_TRANS_UNIT_SIZE <= length) {
        auto idx = offset / size;
        auto off = offset % size;
        H2DUnit(((uint8_t*)dst[idx]) + off, ((const uint8_t*)src[idx]) + off);
        offset += CUDA_TRANS_THREAD_NUMBER * CUDA_TRANS_UNIT_SIZE;
    }
}

inline __host__ void H2DBatch(uintptr_t* dst, const volatile uintptr_t* src, size_t num, size_t size, cudaStream_t stream)
{
    H2DKernel<<<CUDA_TRANS_BLOCK_NUMBER, CUDA_TRANS_BLOCK_SIZE, 0, stream>>>(dst, src, num, size);
}

Status H2DBatchSync(std::byte* dArr[], const std::byte* hArr[], const size_t number, const size_t count) override
{
    auto src = MakeDeviceArray((const void**)hArr, number);
    if (!src) { return Status::OutOfMemory(); }
    auto dst = MakeDeviceArray((const void**)dArr, number);
    if (!dst) {
        ReleaseDeviceArray(src);
        return Status::OutOfMemory();
    }
    H2DBatch((uintptr_t*)dst, (const volatile uintptr_t*)src, number, count, this->stream_);
    auto status = this->Synchronized();
    ReleaseDeviceArray(src);
    ReleaseDeviceArray(dst);
    return status;
}
```

## Impact
### Performance Comparison

| Device | Batch Method | Execution Model | Performance |
|---|---|---|---|
| `CudaDevice` | Custom CUDA kernels | Parallel (8192 threads) | ✅ Optimal |
| `AscendDevice` (current) | Blocking sync | Sequential | ❌ Suboptimal |
| `AscendDevice` (simple, does not work) | Single-stream async | Sequential | ❌ Suboptimal |
| `AscendDevice` (ideal but not feasible) | Custom NPU kernels | Parallel | ❌ Infeasible on NPU today |
| `AscendDevice` (fixed) | Multi-stream async | Parallel (N streams) | ✅ Improved |
### Example Scenario

For a batch of 32 transfers:
- Current: 32 sequential operations ≈ 32× single transfer time
- Fixed: 32 operations across N=8 streams ≈ 4× single transfer time (theoretical)
## Solution

### A Simple Async Implementation Does Not Work

A simple approach is to call `AscendDevice::H2DAsync()` inside `AscendDevice::H2DBatchSync()`, as shown below:
```cpp
Status H2DBatchSync(std::byte* dArr[], const std::byte* hArr[], const size_t number,
                    const size_t count) override
{
    for (size_t i = 0; i < number; i++) {
        auto status = this->H2DAsync(dArr[i], hArr[i], count); // All use same stream
        if (status.Failure()) { return status; }
    }
    return this->Synchronized();
}
```

However, since every `H2DAsync()` call uses the same stream (`this->stream_`), all batch transfer operations are queued onto a single stream and execute sequentially. Even though `aclrtMemcpyAsync()` is non-blocking, operations within the same stream still execute in order. This means:
- Transfer j must wait for transfer j-1 to complete
- No parallelism is achieved despite using the async API
- Performance degrades linearly with batch size

> **Warning:** The same problem occurs when application code calls `AscendDevice::H2DAsync()` multiple times. Even though each call is non-blocking, all such memory transfers still run sequentially because they belong to the same stream. Applications should call `H2DBatchSync()` instead, for both `AscendDevice` and `CudaDevice`.
### Proposed Implementation

Instead, we should implement a multi-stream approach using a pool of streams for batch operations:
- Create a pool of streams (e.g., N=8) during `Setup()`
- Distribute batch transfers across the streams round-robin
- Synchronize all streams at completion

The solution is sketched as follows:
```cpp
class AscendDevice : public IBufferedDevice {
private:
    std::vector<void*> batchStreams_; // Pool of streams for parallel batch operations

public:
    AscendDevice(...) {
        batchStreams_.resize(8, nullptr); // Initialize stream pool
    }

    Status Setup() override {
        // ... existing setup ...
        // Create batch streams for parallel operations
        for (auto& bs : this->batchStreams_) {
            if ((status = ASCEND_API(aclrtCreateStream, &bs)).Failure()) {
                return status;
            }
        }
        // ... rest of setup ...
    }

    Status H2DBatchSync(std::byte* dArr[], const std::byte* hArr[],
                        const size_t number, const size_t count) override
    {
        // Distribute transfers across multiple streams for true parallelism
        const size_t numStreams = this->batchStreams_.size();
        for (size_t i = 0; i < number; i++) {
            void* stream = this->batchStreams_[i % numStreams]; // Round-robin
            auto status = ASCEND_API(aclrtMemcpyAsync, dArr[i], count, hArr[i], count,
                                     ACL_MEMCPY_HOST_TO_DEVICE, stream);
            if (status.Failure()) { return status; }
        }
        // Synchronize all batch streams
        for (auto bs : this->batchStreams_) {
            auto status = ASCEND_API(aclrtSynchronizeStream, bs);
            if (status.Failure()) { return status; }
        }
        return Status::OK();
    }

    ~AscendDevice() override {
        // ... existing cleanup ...
        // Destroy batch streams
        for (auto& bs : this->batchStreams_) {
            if (bs) {
                (void)aclrtDestroyStream(bs);
                bs = nullptr;
            }
        }
    }
};
```

### Why This Works
- Multiple streams: Each stream can execute operations independently
- Round-robin distribution: Load balances transfers across streams
- True parallelism: Transfers on different streams execute concurrently
- Single synchronization: Wait for all streams to complete
## Related Code

- File: `ucm/store/device/ascend/ascend_device.cc`
- Methods: `H2DBatchSync()`, `D2HBatchSync()`
- Comparison: `ucm/store/device/cuda/cuda_device.cu` (uses parallel kernels)