19 changes: 19 additions & 0 deletions deepspeed/runtime/compiler.py
@@ -4,8 +4,10 @@
# DeepSpeed Team

import torch
import contextlib
import functools
from deepspeed.utils.torch import required_torch_version
from deepspeed.accelerator import get_accelerator

try:
    from torch.compiler import is_compiling as torch_is_compiling
@@ -16,6 +18,11 @@
    # Torch does not have compiler support
    torch_is_compiling = lambda: False

# torch >= 2.6 only exposes this compiled-autograd hook as the private `_enable`;
# older releases expose it as `enable`.
if required_torch_version(min_version="2.6.0a"):
    from torch._dynamo.compiled_autograd import _enable as compiled_autograd_enable
else:
    from torch._dynamo.compiled_autograd import enable as compiled_autograd_enable


def is_compile_supported():
    return required_torch_version(min_version=2.1)
@@ -71,3 +78,15 @@ def wrapper(*args, **kwargs):

def is_compiling():
    return torch_is_compiling()


@contextlib.contextmanager
def compiled_autograd(enabled, kwargs):
    # No-op when `enabled` is False; otherwise run the wrapped block (typically the
    # backward pass) under torch._dynamo compiled autograd, with the autograd graph
    # compiled by the accelerator's backend using the given torch.compile kwargs.
    try:
        if enabled:
            with compiled_autograd_enable(torch.compile(backend=get_accelerator().get_compile_backend(), **kwargs)):
                yield
        else:
            yield
    finally:
        pass
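
A minimal usage sketch of the new helper, not part of this diff: with `enabled=False` the context manager is a no-op, otherwise everything inside it (typically the backward pass) runs under torch._dynamo compiled autograd, with the autograd graph compiled by the accelerator's backend. The toy model, loss, and the `dynamic` kwarg below are illustrative placeholders, and the example assumes a PyTorch build that ships compiled autograd.

import torch
from deepspeed.runtime.compiler import compiled_autograd

# Toy model and loss purely for illustration.
model = torch.nn.Linear(4, 4)
loss = model(torch.randn(2, 4)).sum()

# `kwargs` is forwarded to torch.compile(); the backend comes from get_accelerator().
with compiled_autograd(enabled=True, kwargs={"dynamic": False}):
    loss.backward()  # the autograd graph is captured and compiled here
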
23 changes: 19 additions & 4 deletions deepspeed/runtime/engine.py
@@ -105,7 +105,7 @@

from .pipe.module import PipelineModule
from .utils import get_ma_status
from .compiler import is_compile_supported
from .compiler import is_compile_supported, compiled_autograd
from ..ops.adam import FusedAdam
from ..moe.sharded_moe import TopKGate, MOELayer
from ..moe.layer import MoE
@@ -420,6 +420,9 @@ def __init__(self,
        self.register_compile_pass(selective_gather.NAME, selective_gather.selective_gather)
        self.register_compile_pass(offload_adam_states.NAME, offload_adam_states.move_opt_states)

        self._is_compiled_autograd_enabled = False
        self._compile_kwargs = {}

    def _optimized_linear_offload_setup(self):
        self.optimized_linear_base_weight_sharding = False
        self.optimized_linear_lora_enabled = False
@@ -2359,8 +2362,9 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True):

        self._start_timers(self.engine_timers.backward_timers)
        loss = self._backward_prologue(loss, scale_wrt_gas)
        self._do_optimizer_backward(loss, retain_graph)
        self._backward_epilogue()
        with compiled_autograd(self._is_compiled_autograd_enabled, self._compile_kwargs):
            self._do_optimizer_backward(loss, retain_graph)
            self._backward_epilogue()
        self._stop_timers(self.engine_timers.backward_timers)

        return loss
@@ -4078,7 +4082,11 @@ def empty_partition_cache(self):
        gc.collect()
        get_accelerator().empty_cache()

    def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}, schedule=None) -> None:
    def compile(self,
                backend=get_accelerator().get_compile_backend(),
                compile_kwargs={},
                schedule=None,
                compiled_autograd_enabled=False) -> None:
        """Compile the module using the specified backend and kwargs.
        If a compiler_fn is set, it will be used instead of torch.compile().
        """
@@ -4144,6 +4152,13 @@ def passes_name_to_fn(passes):
            raise

        self._is_compiled = True
        self._compile_kwargs = compile_kwargs
        if compiled_autograd_enabled:
            if not self._deepcompile_active:
                self._is_compiled_autograd_enabled = compiled_autograd_enabled
            else:
                logger.warning("Compiled autograd is not compatible with DeepCompile, disabling compiled autograd.")
                self._is_compiled_autograd_enabled = False

    def _set_deepcompile_active(self, active: bool) -> None:
        """Toggle DeepCompile runtime state and manage forward hooks accordingly."""
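
End to end, the new flag is surfaced through `engine.compile()`. A hedged sketch of how a training script might opt in, assuming a single-GPU setup; the model, DeepSpeed config, and batch below are placeholders, not taken from this PR.

import torch
import deepspeed

# Placeholder model and config purely for illustration.
model = torch.nn.Linear(8, 8)
ds_config = {
    "train_batch_size": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config)

# Compile the module and request compiled autograd for the backward pass.
# If DeepCompile is active, the engine logs a warning and leaves it disabled.
engine.compile(compile_kwargs={"dynamic": False}, compiled_autograd_enabled=True)

batch = torch.randn(1, 8, device=engine.device)
loss = engine(batch).sum()
engine.backward(loss)  # now runs inside compiled_autograd(...) per the change above
engine.step()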