91 changes: 89 additions & 2 deletions xtuner/v1/ray/config/worker.py
@@ -1,8 +1,11 @@
import json
import os
import socket
from pathlib import Path
from typing import List, Literal, Optional, Union

from cyclopts import Group, Parameter
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from typing_extensions import Annotated


@@ -67,6 +70,8 @@ class RolloutConfig(BaseModel):
)
"""

model_config = ConfigDict(extra="forbid")

# base config
env: Annotated[
str,
@@ -167,15 +172,97 @@ class RolloutConfig(BaseModel):
help="Context length for the rollout worker.",
),
] = None
rollout_engine_launch_args: Annotated[
Optional[BaseModel],
Parameter(
group=infer_group,
help="Path to the rollout backend configuration file.",
),
] = None
extra_rollout_config: Annotated[
dict,
Parameter(
group=infer_group,
            help='Extra configuration for the rollout backend; backend-specific keys are prefixed with the backend name, e.g. "vllm_" or "lmdeploy_".',
),
] = {"lmdeploy_log_level": "CRITICAL", "lmdeploy_uvicorn_log_level": "CRITICAL"}
] = {"lmdeploy_log_level": "CRITICAL", "lmdeploy_uvicorn_log_level": "CRITICAL", "lmdeploy_backend": "pytorch"}
worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir"

    def __init__(self, **kwargs):
        # Derive model_name from config.json's "model_type" when it is not given,
        # falling back to the directory name of model_path.
        if "model_name" not in kwargs:
model_name_from_config = None
model_path = Path(kwargs["model_path"])
config_json_path = model_path / "config.json"
try:
with open(config_json_path, encoding="utf-8") as f:
config_data = json.load(f)
model_name_from_config = config_data.get("model_type")
except (json.JSONDecodeError, OSError):
pass

if model_name_from_config:
kwargs["model_name"] = model_name_from_config
else:
kwargs["model_name"] = model_path.name

if "tokenizer_path" not in kwargs:
kwargs["tokenizer_path"] = str(kwargs["model_path"])

port = kwargs.get("api_port", 8000)
while True:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind(("localhost", port))
break
except OSError:
port += 1
kwargs["api_port"] = port

if "device" in kwargs and kwargs["device"] == "NPU":
kwargs["gpus_per_node"] = 16

rollout_backend = ""
if os.environ.get("XTUNER_USE_SGLANG", "0") == "1":
rollout_backend = "sglang"
elif os.environ.get("XTUNER_USE_VLLM", "0") == "1":
rollout_backend = "vllm"
elif os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "1":
rollout_backend = "lmdeploy"

assert rollout_backend in ["sglang", "vllm", "lmdeploy"], (
f"Unsupported rollout backend: {rollout_backend}. Please set XTUNER_USE_SGLANG, XTUNER_USE_VLLM, or XTUNER_USE_LMDEPLOY to 1."
)
if rollout_backend == "sglang":
kwargs["launch_server_method"] = "multiprocessing"
kwargs["rollout_cross_node_comm"] = False
if "rollout_engine_launch_args" not in kwargs or kwargs["rollout_engine_launch_args"] is None:
from xtuner.v1.ray.rollout.config.sglang_launch_config import SGLangDefaultServerArgs

kwargs["rollout_engine_launch_args"] = SGLangDefaultServerArgs()
elif rollout_backend == "lmdeploy":
kwargs["launch_server_method"] = "ray"
kwargs["rollout_cross_node_comm"] = True
if "rollout_engine_launch_args" not in kwargs or kwargs["rollout_engine_launch_args"] is None:
if (
"lmdeploy_backend" in kwargs.get("extra_rollout_config", {})
and kwargs["extra_rollout_config"]["lmdeploy_backend"] == "turbomind"
):
from xtuner.v1.ray.rollout.config.lmdeploy_launch_config import (
LMDeployDefaultTurbomindEngineConfig,
)

kwargs["rollout_engine_launch_args"] = LMDeployDefaultTurbomindEngineConfig()
else:
# default to pytorch backend
from xtuner.v1.ray.rollout.config.lmdeploy_launch_config import LMDeployDefaultPytorchEngineConfig

kwargs["rollout_engine_launch_args"] = LMDeployDefaultPytorchEngineConfig()
else:
kwargs["launch_server_method"] = "ray"
kwargs["rollout_cross_node_comm"] = True
super().__init__(**kwargs)
self.worker_log_dir.mkdir(parents=True, exist_ok=True)


if __name__ == "__main__":
from cyclopts import App, Group, Parameter
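A minimal usage sketch of the constructor logic above (the env-var switches, the extra_rollout_config keys, and the derived fields are taken from this diff; the checkpoint path is hypothetical, and fields not shown in this hunk, such as api_port, are assumed to be declared on RolloutConfig with defaults):

import os

from xtuner.v1.ray.config.worker import RolloutConfig

# Choose the LMDeploy backend via the environment switch checked in __init__.
os.environ["XTUNER_USE_LMDEPLOY"] = "1"

cfg = RolloutConfig(
    model_path="/models/demo-checkpoint",  # hypothetical path; model_name is read from its config.json
    extra_rollout_config={"lmdeploy_backend": "turbomind"},  # omit to fall back to the pytorch engine
)

# model_name defaults to config.json's "model_type" (or the directory name),
# tokenizer_path defaults to model_path, and api_port is bumped to a free port.
print(cfg.model_name, cfg.api_port, type(cfg.rollout_engine_launch_args).__name__)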
129 changes: 129 additions & 0 deletions xtuner/v1/ray/rollout/config/lmdeploy_launch_config.py
@@ -0,0 +1,129 @@
import inspect
from typing import Any, Dict, List, Literal, Optional

from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig
from pydantic import BaseModel, ConfigDict

from xtuner.v1.utils import get_logger


logger = get_logger()


class LMDeployDefaultPytorchEngineConfig(BaseModel):
model_config = ConfigDict(extra="forbid")

dtype: str = "auto"
tp: int = 1
dp: int = 1
dp_rank: int = 0
ep: int = 1
session_len: int | None = None
max_batch_size: int | None = None
cache_max_entry_count: float = 0.8
prefill_interval: int = 16
block_size: int = 64
num_cpu_blocks: int = 0
num_gpu_blocks: int = 0
adapters: Dict[str, str] | None = None
max_prefill_token_num: int = 4096
thread_safe: bool = False
enable_prefix_caching: bool = False
device_type: str = "cuda"
eager_mode: bool = False
custom_module_map: Dict[str, str] | None = None
download_dir: str | None = None
revision: str | None = None
quant_policy: Literal[0, 4, 8] = 0
distributed_executor_backend: str | None = None
empty_init: bool = False
enable_microbatch: bool = False
enable_eplb: bool = False
enable_mp_engine: bool = False
mp_engine_backend: str = "mp"
model_format: str | None = None
enable_metrics: bool = False
    hf_overrides: Optional[Dict[str, Any]] = None
disable_vision_encoder: bool = False
logprobs_mode: str | None = None

role: str = "Hybrid"
migration_backend: str = "DLSlime"

def to_lmdeploy_engine_config(self) -> PytorchEngineConfig:
server_args_params = set(inspect.signature(PytorchEngineConfig).parameters.keys())
default_server_args_fields = set(self.model_fields.keys())
missing_params = server_args_params - default_server_args_fields
if missing_params:
logger.warning("Parameters in SGLang ServerArgs but not initialized in Xtuner DefaultServerArgs:")
for param in sorted(missing_params):
logger.info(f"- {param}")

default_args_dict = self.model_dump()
filtered_args = {key: value for key, value in default_args_dict.items() if key in server_args_params}
from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend

if filtered_args.get("role") == "Hybrid":
filtered_args["role"] = EngineRole.Hybrid
elif filtered_args.get("role") == "Prefill":
filtered_args["role"] = EngineRole.Prefill
elif filtered_args.get("role") == "Decode":
filtered_args["role"] = EngineRole.Decode
else:
logger.warning(f"Unknown role {filtered_args.get('role')}, defaulting to Hybrid")
filtered_args["role"] = EngineRole.Hybrid
if filtered_args.get("migration_backend") == "DLSlime":
filtered_args["migration_backend"] = MigrationBackend.DLSlime
else:
logger.warning(
f"Unknown migration_backend {filtered_args.get('migration_backend')}, defaulting to DLSlime"
)
filtered_args["migration_backend"] = MigrationBackend.DLSlime
return PytorchEngineConfig(**filtered_args)


class LMDeployDefaultTurbomindEngineConfig(BaseModel):
model_config = ConfigDict(extra="forbid")

dtype: str = "auto"
model_format: Optional[str] = None
tp: int = 1
dp: int = 1
device_num: int | None = None
attn_tp_size: int | None = None
attn_dp_size: int | None = None
mlp_tp_size: int | None = None
mlp_dp_size: int | None = None
outer_dp_size: int | None = None
    session_len: Optional[int] = None
max_batch_size: int | None = None
cache_max_entry_count: float = 0.8
cache_chunk_size: int = -1
cache_block_seq_len: int = 64
enable_prefix_caching: bool = False
quant_policy: int = 0
rope_scaling_factor: float = 0.0
use_logn_attn: bool = False
    download_dir: Optional[str] = None
    revision: Optional[str] = None
max_prefill_token_num: int = 8192
num_tokens_per_iter: int = 0
max_prefill_iters: int = 1
    devices: Optional[List[int]] = None
empty_init: bool = False
communicator: str = "nccl"
    hf_overrides: Optional[Dict[str, Any]] = None
enable_metrics: bool = False

def to_lmdeploy_engine_config(self) -> TurbomindEngineConfig:
server_args_params = set(inspect.signature(TurbomindEngineConfig).parameters.keys())
default_server_args_fields = set(self.model_fields.keys())
missing_params = server_args_params - default_server_args_fields
if missing_params:
logger.info("Parameters in SGLang ServerArgs but not initialized in Xtuner DefaultServerArgs:")
for param in sorted(missing_params):
logger.info(f"- {param}")

default_args_dict = self.model_dump()
filtered_args = {key: value for key, value in default_args_dict.items() if key in server_args_params}
return TurbomindEngineConfig(**filtered_args)
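A short sketch of how these default-args models are consumed (the classes and to_lmdeploy_engine_config come from this file; the specific overrides are illustrative only):

from xtuner.v1.ray.rollout.config.lmdeploy_launch_config import (
    LMDeployDefaultPytorchEngineConfig,
    LMDeployDefaultTurbomindEngineConfig,
)

# Override only what you need; extra="forbid" rejects unknown keys at construction time.
pt_args = LMDeployDefaultPytorchEngineConfig(tp=2, enable_prefix_caching=True)
pt_engine_cfg = pt_args.to_lmdeploy_engine_config()  # lmdeploy.PytorchEngineConfig, with role/migration_backend mapped to enums
print(pt_engine_cfg.tp)

tm_args = LMDeployDefaultTurbomindEngineConfig(tp=2, cache_max_entry_count=0.6)
print(tm_args.to_lmdeploy_engine_config())  # lmdeploy.TurbomindEngineConfig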