Moving to CUDA after loading is extremely slow #12599

@francois-rozet

Describe the bug

Moving a freshly loaded module to CUDA is extremely slow, much slower than it should be. This seems to be related to the way diffusers loads weights, because cloning the weights before moving them makes the transfer much faster.

Reproduction

import time
import torch

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    variant="fp16",
    torch_dtype=torch.float16,
    low_cpu_memory_usage=False,
)

tic = time.time()

pipe.unet.to("cuda")

tac = time.time()

print(tac - tic, flush=True)  # 20s

If I manually clone the parameters and buffers first, the transfer is much faster.

import time
import torch

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    variant="fp16",
    torch_dtype=torch.float16,
    low_cpu_memory_usage=False,
)

tic = time.time()

for p in (*pipe.unet.parameters(), *pipe.unet.buffers()):
    p.data = p.data.clone()

pipe.unet.to("cuda")

tac = time.time()

print(tac - tic, flush=True)  # 1s
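
The clone-then-move workaround above can be wrapped in a small helper so the comparison is easy to repeat on other modules. This is only a sketch: `clone_tensors_` and `timed_to` are hypothetical names (not part of diffusers or PyTorch), the toy `nn.Sequential` model stands in for `pipe.unet`, and the explicit `torch.cuda.synchronize()` calls are an assumption about how to keep pending GPU work out of the measurement. It does not explain why the freshly loaded weights are slow to move.

```python
import time

import torch
from torch import nn


def clone_tensors_(module: nn.Module) -> nn.Module:
    """Replace every parameter and buffer with a freshly allocated copy, in place."""
    with torch.no_grad():
        for t in (*module.parameters(), *module.buffers()):
            t.data = t.data.clone()
    return module


def timed_to(module: nn.Module, device: str) -> float:
    """Return the seconds taken by module.to(device), synchronizing CUDA if used."""
    use_cuda = device.startswith("cuda")
    if use_cuda:
        torch.cuda.synchronize()  # make sure no earlier GPU work is still pending
    tic = time.perf_counter()
    module.to(device)
    if use_cuda:
        torch.cuda.synchronize()  # wait until the copies have finished
    return time.perf_counter() - tic


if __name__ == "__main__":
    # Falls back to CPU so the script still runs on a machine without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = nn.Sequential(nn.Linear(1024, 1024), nn.BatchNorm1d(1024))
    clone_tensors_(model)
    print(f"{timed_to(model, device):.3f}s", flush=True)
```

With the pipeline from the reproduction, the equivalent calls would be `clone_tensors_(pipe.unet)` followed by `timed_to(pipe.unet, "cuda")`; timing the cloning separately from the transfer shows how much each step costs.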

System Info

- 🤗 Diffusers version: 0.35.2
- Platform: Linux-4.18.0-193.6.3.el8_2.x86_64-x86_64-with-glibc2.28
- Running on Google Colab?: No
- Python version: 3.11.11
- PyTorch version (GPU?): 2.6.0+cu118 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.36.0
- Transformers version: 4.57.1
- Accelerate version: 1.11.0
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.6.2
- xFormers version: not installed
- Accelerator: 4 x NVIDIA GeForce RTX 2080 Ti, 11264 MiB each
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no

Who can help?

@sayakpaul @DN6

Labels

- enhancement: New feature or request
- performance: Anything related to performance improvements, profiling and benchmarking
