From d0553fdd945cc09e004855494600943d58d505cd Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Fri, 24 Apr 2026 17:03:13 -0700 Subject: [PATCH 01/26] feat: support SD1.5/SDXL/Dreamshaper alongside SD-Turbo Five fixes that together let LCM-LoRA'd SD1.5, SDXL, SDXL-Turbo, and Dreamshaper variants produce sharp deterministic txt2img output: 1. Auto-fuse the matching LCM LoRA (lcm-lora-sdv1-5 / lcm-lora-sdxl) for non-Turbo bases. The schema declared use_lcm_lora but nothing wired it up, so non-Turbo models were running LCMScheduler with un-distilled UNet weights and outputting yellow/black blobs. 2. Swap SDXL's stock VAE for madebyollin/sdxl-vae-fp16-fix on load. The stock VAE decodes NaN in fp16, so every SDXL frame was pure black. 3. SDXL conditioning (add_text_embeds, add_time_ids) now broadcasts to the current batch size when t_index_list has multiple entries. 4. Per-family default num_inference_steps: 1 for sd-turbo proper, 4 for everything else. Single-step at t=999 only converges for the model distilled for that exact regime; SDXL-Turbo / Dreamshaper-XL-Turbo / non-Turbo + LCM LoRA are blurry at 1 step and sharp at 4. Exposed as the "Inference Steps" UI slider with an "Auto Inference Steps" toggle to defer to per-family suggestion. 5. Two text-mode bugs in __call__: - Image-loopback was implicit ("video missing AND prev_image_result exists"), making each frame feed its previous output back as input and drift to over-saturated abstract patterns. Now opt-in only. - Input latent used unseeded torch.randn each call, so seed=42 still produced a different scene per frame. Now reuses the seeded init_noise[0:1] for stable, deterministic output. Verified across sd-turbo, SD1.5, Dreamshaper-8, SDXL-Turbo, SDXL-Base, and Dreamshaper-XL-v2-Turbo at 512 / 1024. --- src/scope_streamdiffusion/pipeline.py | 133 ++++++++++++++++++++------ src/scope_streamdiffusion/schema.py | 12 ++- 2 files changed, 112 insertions(+), 33 deletions(-) diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index 769c43b..99bdebb 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -80,10 +80,26 @@ def __init__( # Load the base model print(f"Loading model: {model_id}") + self.model_id = model_id + self.sd_turbo = "turbo" in model_id.lower() self.pipe = self._load_model(model_id) print(f"Model loaded: {self.pipe.__class__.__name__}") - # Model components + # Check if SDXL (needed before LCM LoRA selection) + self.sdxl: bool = type(self.pipe) is StableDiffusionXLPipeline + + # SDXL's default VAE overflows in fp16 and decodes NaN. Swap to the + # community fp16-fix VAE so the full-quality decode path works without + # forcing TAESD. + if self.sdxl and self.dtype == torch.float16: + self._install_sdxl_fp16_vae() + + # Non-turbo models need LCM LoRA to denoise correctly with LCMScheduler + # at 1–4 steps. Turbo/Lightning models are already distilled for this. + if not self.sd_turbo: + self._attach_lcm_lora() + + # Model components (grabbed after LoRA fuse so fused weights are live) self.text_encoder = self.pipe.text_encoder self.unet = self.pipe.unet self.vae = self.pipe.vae @@ -91,9 +107,6 @@ def __init__( self._taesd_vae = None self._using_taesd = False - # Check if SDXL - self.sdxl: bool = type(self.pipe) is StableDiffusionXLPipeline - # Setup scheduler self.scheduler: LCMScheduler = LCMScheduler.from_config( self.pipe.scheduler.config @@ -185,6 +198,49 @@ def _load_model(self, model_id: str) -> DiffusionPipeline: print(f"Failed to load model {model_id}: {e}") raise + def _install_sdxl_fp16_vae(self) -> None: + """Swap SDXL's default VAE for madebyollin/sdxl-vae-fp16-fix. + + Stability AI's SDXL VAE overflows on certain inputs in fp16 and decodes + to NaN — even from a perfectly valid UNet prediction. The community + fp16-fix VAE is a drop-in replacement with the same architecture and + quality, retuned to be numerically stable in fp16. + """ + from diffusers import AutoencoderKL + + try: + print("[StreamDiffusion] Installing madebyollin/sdxl-vae-fp16-fix") + new_vae = AutoencoderKL.from_pretrained( + "madebyollin/sdxl-vae-fp16-fix", torch_dtype=self.dtype + ).to(self.device) + self.pipe.vae = new_vae + print("[StreamDiffusion] SDXL fp16-fix VAE installed") + except Exception as e: + print(f"[StreamDiffusion] Failed to install fp16-fix VAE: {e}") + + def _attach_lcm_lora(self) -> None: + """Load and fuse the appropriate LCM LoRA for a non-turbo SD/SDXL base. + + LCMScheduler at 1–4 steps only produces usable output on models that + have been distilled for low-step inference — Turbo, Lightning, or LCM. + For plain SD 1.5 / SDXL bases, we attach the matching LCM LoRA so the + scheduler path works the same as it does for Turbo. + """ + lcm_lora_id = ( + "latent-consistency/lcm-lora-sdxl" + if self.sdxl + else "latent-consistency/lcm-lora-sdv1-5" + ) + print(f"[StreamDiffusion] Loading LCM LoRA: {lcm_lora_id}") + try: + self.pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm") + self.pipe.fuse_lora(lora_scale=1.0, adapter_names=["lcm"]) + self.pipe.unload_lora_weights() + print("[StreamDiffusion] LCM LoRA fused") + except Exception as e: + print(f"[StreamDiffusion] Failed to load LCM LoRA {lcm_lora_id}: {e}") + raise + def _set_taesd(self, enabled: bool) -> None: """Switch between TAESD (fast) and full VAE decoder.""" if enabled == self._using_taesd: @@ -274,6 +330,14 @@ def _prepare_runtime_state( self.latent_height = int(height // self.pipe.vae_scale_factor) self.latent_width = int(width // self.pipe.vae_scale_factor) + # --- Scheduler defaults --- + # `num_inference_steps` is the user-facing sharpness lever: more steps = + # sharper detail (SD-Turbo proper is the exception — it's distilled for + # 1 step). When the caller didn't pin a `t_index_list`, walk every step + # in the schedule so the UNet sees the full LCM timestep range. + if t_index_list is None: + t_index_list = list(range(num_inference_steps)) + # --- Cheap scalar assignments --- self.strength = strength self.guidance_scale = guidance_scale @@ -723,16 +787,15 @@ def _set_timesteps(self, num_inference_steps: int, strength: float): # Calculate alpha/beta values alpha_prod_t_sqrt_list = [] beta_prod_t_sqrt_list = [] + ac = self.scheduler.alphas_cumprod + last_idx = len(ac) - 1 for timestep in self.sub_timesteps: - if timestep >= len(self.scheduler.alphas_cumprod): - print( - f"Warning: timestep {timestep} is greater than the number of timesteps {len(self.scheduler.alphas_cumprod)}" - ) - continue - alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt() - beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt() - alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt) - beta_prod_t_sqrt_list.append(beta_prod_t_sqrt) + # Clamp into range instead of skipping — skipping would make the + # downstream .view(len(t_list), 1, 1, 1) reshape fail when any + # timestep happened to land out of range for this scheduler. + idx = min(int(timestep), last_idx) + alpha_prod_t_sqrt_list.append(ac[idx].sqrt()) + beta_prod_t_sqrt_list.append((1 - ac[idx]).sqrt()) alpha_prod_t_sqrt = ( torch.stack(alpha_prod_t_sqrt_list) @@ -945,10 +1008,14 @@ def _predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: (self.init_noise[0:1], self.stock_noise[:-1]), dim=0 ) if self.sdxl: - added_cond_kwargs = { - "text_embeds": self.add_text_embeds.to(self.device), - "time_ids": self.add_time_ids.to(self.device), - } + batch = x_t_latent.shape[0] + te = self.add_text_embeds.to(self.device) + ti = self.add_time_ids.to(self.device) + if te.shape[0] != batch: + te = te[:1].expand(batch, -1) + if ti.shape[0] != batch: + ti = ti[:1].expand(batch, -1) + added_cond_kwargs = {"text_embeds": te, "time_ids": ti} x_t_latent = x_t_latent.to(self.device) t_list = t_list.to(self.device) @@ -1086,9 +1153,13 @@ def get_param(key, default): prompt_interpolation_method = get_param("prompt_interpolation_method", "linear") guidance_scale = get_param("guidance_scale", 0.0) - # SD Turbo: Use single timestep (t_index_list=[0]) but set schedule length - # This matches your working project setup - num_inference_steps = get_param("num_inference_steps", 3) + num_inference_steps = get_param("num_inference_steps", 4) + use_suggested_steps = get_param("use_suggested_num_inference_steps", True) + if use_suggested_steps: + # Per-family suggestion: SD-Turbo proper is 1-step distilled; every + # other family (SDXL-Turbo, SDXL-Turbo fine-tunes, SD1.5/SDXL + + # LCM LoRA) needs 4 steps to converge to sharp output. + num_inference_steps = 1 if (self.sd_turbo and not self.sdxl) else 4 # For img2img with SD Turbo, need higher strength for visible changes # 0.5-0.7 = moderate, 0.8-0.95 = heavy transformation @@ -1170,10 +1241,12 @@ def get_param(key, default): frame = None - # Process input - if image_loopback or ( - (video is None or len(video) == 0) and self.prev_image_result is not None - ): + # Process input. Note: image_loopback must be opt-in. Falling back to + # `prev_image_result` whenever video is missing turns text mode into a + # recursive feedback loop (each frame uses the previous frame's + # output as input), which drifts to over-saturated/abstract patterns + # within a few frames. + if image_loopback and self.prev_image_result is not None: frame = self.prev_image_result elif video is not None and len(video) > 0: # Convert Scope tensor format to pipeline format @@ -1222,12 +1295,12 @@ def get_param(key, default): input_latent = self._encode_image(input_tensor) else: - # Text-to-image mode - input_latent = torch.randn( - (1, 4, self.latent_height, self.latent_width), - device=self.device, - dtype=self.dtype, - ) + # Text-to-image mode — use the seeded `init_noise` instead of a + # fresh unseeded randn. With a fresh randn per call, every frame + # would generate a different scene; the seeded buffer keeps the + # output stable across frames for the same seed (and lets the + # user reseed deterministically by changing `seed`). + input_latent = self.init_noise[0:1].clone() # Run diffusion x_0_pred_out = self._predict_x0_batch(input_latent) diff --git a/src/scope_streamdiffusion/schema.py b/src/scope_streamdiffusion/schema.py index c937709..12612d1 100644 --- a/src/scope_streamdiffusion/schema.py +++ b/src/scope_streamdiffusion/schema.py @@ -171,11 +171,17 @@ class StreamDiffusionConfig(BasePipelineConfig): ) num_inference_steps: int = Field( - default=2, + default=4, ge=1, le=50, - description="Number of denoising steps", - # json_schema_extra=ui_field_config(order=21, label="Inference Steps"), + description="Number of LCM denoising steps. Main sharpness lever: more steps = sharper detail. SD-Turbo (sd-turbo) is distilled for 1 step; SDXL-Turbo / fine-tunes / non-turbo + LCM LoRA all want 4–8.", + json_schema_extra=ui_field_config(order=21, label="Inference Steps"), + ) + + use_suggested_num_inference_steps: bool = Field( + default=True, + description="When ON, the pipeline picks the inference-step count per model family (1 for SD-Turbo, 4 for everything else) and ignores the slider. Toggle OFF to drive the slider yourself.", + json_schema_extra=ui_field_config(order=22, label="Auto Inference Steps"), ) strength: float = Field( From 633f427345b1a0c4f01001f7881985e602e3baf7 Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Fri, 24 Apr 2026 17:16:12 -0700 Subject: [PATCH 02/26] feat: model_id_or_path as enum dropdown of tested models Limits the UI to the six model IDs verified to produce sharp output via this pipeline's auto-LCM-LoRA + fp16-fix-VAE plumbing. Also adds the field to the UI surface (was previously schema-only). --- src/scope_streamdiffusion/schema.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/scope_streamdiffusion/schema.py b/src/scope_streamdiffusion/schema.py index 12612d1..552d158 100644 --- a/src/scope_streamdiffusion/schema.py +++ b/src/scope_streamdiffusion/schema.py @@ -52,9 +52,17 @@ class StreamDiffusionConfig(BasePipelineConfig): # Model Configuration # ======================================== - model_id_or_path: str = Field( + model_id_or_path: Literal[ + "stabilityai/sd-turbo", + "stabilityai/sdxl-turbo", + "stable-diffusion-v1-5/stable-diffusion-v1-5", + "stabilityai/stable-diffusion-xl-base-1.0", + "Lykon/dreamshaper-8", + "Lykon/dreamshaper-xl-v2-turbo", + ] = Field( default="stabilityai/sd-turbo", - description="Model ID from HuggingFace or local path to model", + description="HuggingFace model ID. Tested set; non-Turbo entries auto-attach the matching LCM LoRA, SDXL entries auto-swap to madebyollin/sdxl-vae-fp16-fix.", + json_schema_extra=ui_field_config(order=8, label="Model"), ) acceleration: Literal["none", "xformers", "tensorrt"] = Field( From 43a1d8f57919b74a5114841f0b9dc11e3e2590ce Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Fri, 24 Apr 2026 17:18:11 -0700 Subject: [PATCH 03/26] fix: accept model_id_or_path kwarg so the UI dropdown actually loads MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The schema field is named ``model_id_or_path`` and Scope's pipeline_manager merges schema defaults into __init__ kwargs by their declared name, but __init__ only read ``model_id`` — so picking a model in the UI was silently ignored and the default reloaded every time. --- src/scope_streamdiffusion/pipeline.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index 99bdebb..fbdb8a8 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -56,7 +56,8 @@ def get_config_class(cls) -> type["BasePipelineConfig"]: def __init__( self, device: Optional[torch.device] = None, - model_id: str = "stabilityai/sd-turbo", + model_id: Optional[str] = None, + model_id_or_path: Optional[str] = None, torch_dtype: torch.dtype = torch.float16, **kwargs, # noqa: ARG002 ) -> None: @@ -64,7 +65,9 @@ def __init__( Args: device: Torch device to use - model_id: Model ID or path to load + model_id / model_id_or_path: Model ID or path to load. The schema + field is ``model_id_or_path``; ``model_id`` is accepted as an + alias so older callers keep working. torch_dtype: Data type for tensors """ self.device = ( @@ -78,6 +81,13 @@ def __init__( self.config = kwargs.get("config") or kwargs.get("pipeline_config") print(f"Init - Config object: {self.config}") + # The schema's field is ``model_id_or_path``. Scope's pipeline_manager + # merges schema defaults into the init kwargs by their declared name, + # so we have to accept that spelling — accepting only ``model_id`` + # silently drops the user's selection and reloads the default every + # time. Resolve in order: explicit model_id > model_id_or_path > default. + model_id = model_id or model_id_or_path or "stabilityai/sd-turbo" + # Load the base model print(f"Loading model: {model_id}") self.model_id = model_id From 788f7a548cc3f6a4dfbc0def84f345c69988d489 Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Fri, 24 Apr 2026 17:22:00 -0700 Subject: [PATCH 04/26] feat: hot-swap model when model_id_or_path changes at runtime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Scope routes model_id_or_path through setNodeParams (the runtime/kwargs path), not through pipeline/load. Previously __call__ ignored the incoming value, so picking a different model in the UI updated logs but left the original weights loaded. Detect a mismatch against self.model_id and reload the weights in place — re-attaching the LCM LoRA / fp16-fix VAE per family, freeing the old pipe first to avoid 2x VRAM, and invalidating prompt / timestep / noise caches so the next frame rebuilds against the new model. --- src/scope_streamdiffusion/pipeline.py | 66 +++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index fbdb8a8..1167a2d 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -208,6 +208,64 @@ def _load_model(self, model_id: str) -> DiffusionPipeline: print(f"Failed to load model {model_id}: {e}") raise + def _swap_model(self, new_model_id: str) -> None: + """Replace the loaded model in place. + + Scope routes ``model_id_or_path`` through both the load-time path + (which would reinit the pipeline cleanly) and the runtime + ``setNodeParams`` path (which only updates kwargs and never touches + ``__init__``). When the runtime kwarg disagrees with what we loaded, + rebuild the model parts here so picking a model in the UI actually + swaps it. Stalls the frame loop while loading — same as a fresh load. + """ + print(f"[StreamDiffusion] Swapping model: {self.model_id} -> {new_model_id}") + # Free the current model's GPU memory before bringing the next one in + # so we don't peak at 2x weights. + old = getattr(self, "pipe", None) + if old is not None: + del old + self.pipe = None + self._taesd_vae = None + self._full_vae = None + torch.cuda.empty_cache() + + self.model_id = new_model_id + self.sd_turbo = "turbo" in new_model_id.lower() + self.pipe = self._load_model(new_model_id) + print(f"[StreamDiffusion] Model loaded: {self.pipe.__class__.__name__}") + self.sdxl = type(self.pipe) is StableDiffusionXLPipeline + if self.sdxl and self.dtype == torch.float16: + self._install_sdxl_fp16_vae() + if not self.sd_turbo: + self._attach_lcm_lora() + + self.text_encoder = self.pipe.text_encoder + self.unet = self.pipe.unet + self.vae = self.pipe.vae + self._full_vae = self.vae + self._using_taesd = False + self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config) + self.image_processor = VaeImageProcessor(self.pipe.vae_scale_factor) + + # Invalidate runtime caches so the next __call__ rebuilds prompt + # embeddings, timestep schedule, and noise buffers against the new + # model — text encoder + UNet config differ between SD1.5 and SDXL. + self._schedule_key = None + self._noise_shape = None + self._prompts_key = None + self._cached_base_embed = None + self._previous_prompt_embeddings = None + self.prev_image_result = None + self._last_transition_id = None + self._pooled_source = None + self._pooled_target = None + self._transition_total_steps = 0 + if hasattr(self, "embedding_blender"): + try: + self.embedding_blender.cancel_transition() + except Exception: + pass + def _install_sdxl_fp16_vae(self) -> None: """Swap SDXL's default VAE for madebyollin/sdxl-vae-fp16-fix. @@ -1103,6 +1161,14 @@ def __call__(self, **kwargs) -> dict: # Extract parameters - handle Scope's parameter format video = kwargs.get("video", None) + # Hot-swap the model when the runtime kwarg disagrees with what's + # loaded. Scope sends model_id_or_path through setNodeParams, not + # through pipeline/load, so this is the only place a UI-driven model + # change actually takes effect. + requested_model = kwargs.get("model_id_or_path") or kwargs.get("model_id") + if requested_model and requested_model != self.model_id: + self._swap_model(requested_model) + # Bypass: pass input through unchanged when disabled enabled = kwargs.get("enabled", True) if not enabled: From c4151c838f1deeb7cf279a4da12555143464add0 Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Fri, 24 Apr 2026 17:29:33 -0700 Subject: [PATCH 05/26] refactor: read model_id_or_path via get_param like every other runtime field Use the same config/kwargs lookup path that strength, seed, etc. use, instead of a hand-rolled kwargs.get() ahead of the rest of __call__. --- src/scope_streamdiffusion/pipeline.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index 1167a2d..6642a59 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -1161,14 +1161,6 @@ def __call__(self, **kwargs) -> dict: # Extract parameters - handle Scope's parameter format video = kwargs.get("video", None) - # Hot-swap the model when the runtime kwarg disagrees with what's - # loaded. Scope sends model_id_or_path through setNodeParams, not - # through pipeline/load, so this is the only place a UI-driven model - # change actually takes effect. - requested_model = kwargs.get("model_id_or_path") or kwargs.get("model_id") - if requested_model and requested_model != self.model_id: - self._swap_model(requested_model) - # Bypass: pass input through unchanged when disabled enabled = kwargs.get("enabled", True) if not enabled: @@ -1225,6 +1217,14 @@ def get_param(key, default): # Finally use default return default + # Hot-swap when the model selection changes at runtime. Scope routes + # model_id_or_path through setNodeParams (kwargs / config), not + # through pipeline/load — so this is the only spot where a UI-driven + # change actually takes effect. + requested_model = get_param("model_id_or_path", None) + if requested_model and requested_model != self.model_id: + self._swap_model(requested_model) + # Extract all parameters with config fallback prompt_interpolation_method = get_param("prompt_interpolation_method", "linear") guidance_scale = get_param("guidance_scale", 0.0) From 8315c82009c7ed85561d1912c2ec6d4e707add3f Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Fri, 24 Apr 2026 18:54:15 -0700 Subject: [PATCH 06/26] fix: sequential denoising in txt2img mode so multi-step doesn't flash channels StreamDiffusion's batch denoising emits one frame per __call__ but each frame is at a different t_index in the cycle (frame i -> t_index i mod N). Across video that smooths out; for a steady text prompt it shows up as N different denoising stages flashing one after another. Switch to sequential denoising (all N steps inside one __call__) when there's no video input and the schedule has >1 step, so each frame is one fully denoised image. --- src/scope_streamdiffusion/pipeline.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index 6642a59..f98e177 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -1158,6 +1158,7 @@ def __call__(self, **kwargs) -> dict: Returns: dict: {"video": output_tensor} where output_tensor is (T, H, W, C) in [0, 1] """ + print('.....') # Extract parameters - handle Scope's parameter format video = kwargs.get("video", None) @@ -1198,6 +1199,9 @@ def __call__(self, **kwargs) -> dict: # Get config instance - Scope should pass this # Try different ways Scope might pass config config = kwargs.get("config") or kwargs.get("pipeline_config") + + print(config) + print(kwargs) # If no config found, try to get it from the pipeline if config is None: @@ -1222,6 +1226,7 @@ def get_param(key, default): # through pipeline/load — so this is the only spot where a UI-driven # change actually takes effect. requested_model = get_param("model_id_or_path", None) + print(f"Requested model: {requested_model}, current model: {self.model_id}") if requested_model and requested_model != self.model_id: self._swap_model(requested_model) @@ -1247,6 +1252,16 @@ def get_param(key, default): height = get_param("height", 512) use_denoising_batch = get_param("use_denoising_batch", True) do_add_noise = get_param("do_add_noise", True) + + # Batch denoising is StreamDiffusion's streaming-video trick: each + # __call__ returns one frame at a different denoising stage in the + # cycle (frame i at t_index i mod N). Across consecutive video frames + # that smooths out; in steady-prompt txt2img it shows up as N + # distinct outputs flashing in sequence. Run sequential denoising + # (all N steps inside one __call__ -> one fully denoised frame) when + # there's no video input and the schedule has more than one step. + if (video is None or len(video) == 0) and num_inference_steps > 1: + use_denoising_batch = False similar_image_filter_enabled = get_param("similar_image_filter_enabled", False) image_loopback = get_param("image_loopback", False) controlnet_mode = get_param("controlnet_mode", "none") From 9fdbb7e910b32e999137f060849dbe2862a94afb Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Fri, 24 Apr 2026 19:06:56 -0700 Subject: [PATCH 07/26] Revert "fix: sequential denoising in txt2img mode so multi-step doesn't flash channels" This reverts commit 8315c82009c7ed85561d1912c2ec6d4e707add3f. --- src/scope_streamdiffusion/pipeline.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index f98e177..6642a59 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -1158,7 +1158,6 @@ def __call__(self, **kwargs) -> dict: Returns: dict: {"video": output_tensor} where output_tensor is (T, H, W, C) in [0, 1] """ - print('.....') # Extract parameters - handle Scope's parameter format video = kwargs.get("video", None) @@ -1199,9 +1198,6 @@ def __call__(self, **kwargs) -> dict: # Get config instance - Scope should pass this # Try different ways Scope might pass config config = kwargs.get("config") or kwargs.get("pipeline_config") - - print(config) - print(kwargs) # If no config found, try to get it from the pipeline if config is None: @@ -1226,7 +1222,6 @@ def get_param(key, default): # through pipeline/load — so this is the only spot where a UI-driven # change actually takes effect. requested_model = get_param("model_id_or_path", None) - print(f"Requested model: {requested_model}, current model: {self.model_id}") if requested_model and requested_model != self.model_id: self._swap_model(requested_model) @@ -1252,16 +1247,6 @@ def get_param(key, default): height = get_param("height", 512) use_denoising_batch = get_param("use_denoising_batch", True) do_add_noise = get_param("do_add_noise", True) - - # Batch denoising is StreamDiffusion's streaming-video trick: each - # __call__ returns one frame at a different denoising stage in the - # cycle (frame i at t_index i mod N). Across consecutive video frames - # that smooths out; in steady-prompt txt2img it shows up as N - # distinct outputs flashing in sequence. Run sequential denoising - # (all N steps inside one __call__ -> one fully denoised frame) when - # there's no video input and the schedule has more than one step. - if (video is None or len(video) == 0) and num_inference_steps > 1: - use_denoising_batch = False similar_image_filter_enabled = get_param("similar_image_filter_enabled", False) image_loopback = get_param("image_loopback", False) controlnet_mode = get_param("controlnet_mode", "none") From 7872f9186bf0e23a75e7991a43811054a96ce614 Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Fri, 24 Apr 2026 19:27:39 -0700 Subject: [PATCH 08/26] feat: serial denoise path for txt2img and image-loopback modes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds _predict_x0_serial as a sibling of _predict_x0_batch and routes to it when num_inference_steps > 1 in steady-prompt modes (no video input, or explicit image_loopback) with ControlNet off. Walks the full N-step LCM schedule inside one __call__, so each emitted frame is one fully denoised image instead of one slot of the rolling N-track buffer cycle that otherwise flashes N different attractors at the camera. The batch path still owns: - num_inference_steps == 1 (degenerates to one UNet call anyway, and it's the path SD-Turbo and the depth/scribble ControlNet pre-passes expect) - video input / v2v streams (where the buffer reuse trick actually amortises across consecutive related frames — its design point) - ControlNet streams (same reasoning) Routing decision is a single boolean (`use_serial`) computed alongside the other extracted params; the rest of __call__ branches on it exactly twice — once to skip auto-noising the encoded image (serial adds its own noise based on `strength`) and once to pick the predict function. Batch path is untouched. --- src/scope_streamdiffusion/pipeline.py | 96 ++++++++++++++++++++++++++- 1 file changed, 93 insertions(+), 3 deletions(-) diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index 6642a59..8680653 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -1136,6 +1136,63 @@ def _predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: return x_0_pred_out + @torch.no_grad() + def _predict_x0_serial( + self, + latent: torch.Tensor, + num_inference_steps: int, + strength: float = 1.0, + is_img2img: bool = False, + ) -> torch.Tensor: + """Run a clean N-step LCM denoise loop on a single latent. + + Sibling of :meth:`_predict_x0_batch` for modes where the streaming + rolling-buffer trick gets in the way (steady-prompt txt2img and + image-loopback) — runs all N timesteps inside one call so the output + is a single fully-denoised frame, not one slice of a 4-track buffer + cycle. + + For txt2img, ``latent`` is pure noise and we walk every timestep in + the schedule. For img2img / loopback, ``latent`` is the cleanly- + encoded input image; we add fresh noise at the first timestep we + actually run, controlled by ``strength`` (1.0 = full repaint, lower + = preserve more of the input). + """ + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps = self.scheduler.timesteps + + if is_img2img: + skip = max(0, int(round(num_inference_steps * (1.0 - strength)))) + timesteps = timesteps[skip:] + if len(timesteps) == 0: + return latent + noise = torch.randn( + latent.shape, generator=self.generator, + device=self.device, dtype=self.dtype, + ) + latent = self.scheduler.add_noise(latent, noise, timesteps[:1]) + + added_cond_kwargs = {} + if self.sdxl: + added_cond_kwargs = { + "text_embeds": self.add_text_embeds.to(self.device), + "time_ids": self.add_time_ids.to(self.device), + } + + prompt_embeds = self.prompt_embeds[:1] # serial works one frame at a time + for t in timesteps: + noise_pred = self.unet( + latent, + t, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + latent = self.scheduler.step( + noise_pred, t, latent, return_dict=False + )[0] + return latent + @torch.no_grad() def __call__(self, **kwargs) -> dict: """Process input video frame(s) and return generated output. @@ -1258,6 +1315,30 @@ def get_param(key, default): depth_skip_interval = get_param("depth_skip_interval", 3) use_taesd = get_param("use_taesd", False) + # --- Pick denoise path ------------------------------------------------- + # The batch path (StreamDiffusion's rolling-buffer trick) amortises N + # denoising stages across N consecutive video frames. That's a real + # win when the input stream changes every frame (webcam, v2v, moving + # ControlNet) but in steady-prompt txt2img / image-loopback the per- + # slot init_noise drift makes the N buffer slots crystallise into N + # different attractors that flash one after another. + # Use the serial path for those cases; keep the batch path everywhere + # else and at num_inference_steps=1 (where it degenerates to one + # UNet call per frame anyway). + has_video_input_eval = video is not None and len(video) > 0 + is_steady_prompt_mode = (not has_video_input_eval) or image_loopback + use_serial = ( + num_inference_steps > 1 + and is_steady_prompt_mode + and controlnet_mode == "none" + ) + if use_serial: + # Serial denoise wants prompt_embeds sized for batch=1 and doesn't + # care about the LCM coefficient pre-compute. Force the runtime + # state into a 1-track configuration so _prepare_runtime_state's + # caches match what _predict_x0_serial will read. + use_denoising_batch = False + # --- Safeguard: prevent invalid strength / num_inference_steps combos --- # LCM scheduler requires: floor(original_steps * strength) >= num_inference_steps # original_steps defaults to 50 in the scheduler. @@ -1367,8 +1448,9 @@ def get_param(key, default): return {"video": output.permute(0, 2, 3, 1).clamp(0, 1)} input_tensor = filtered - # Encode to latent space - input_latent = self._encode_image(input_tensor) + # Encode to latent space. Serial img2img adds its own noise based + # on the requested strength, so don't double-noise here. + input_latent = self._encode_image(input_tensor, add_noise=not use_serial) else: # Text-to-image mode — use the seeded `init_noise` instead of a @@ -1379,7 +1461,15 @@ def get_param(key, default): input_latent = self.init_noise[0:1].clone() # Run diffusion - x_0_pred_out = self._predict_x0_batch(input_latent) + if use_serial: + x_0_pred_out = self._predict_x0_serial( + input_latent, + num_inference_steps=num_inference_steps, + strength=strength, + is_img2img=frame is not None, + ) + else: + x_0_pred_out = self._predict_x0_batch(input_latent) # Decode to image space x_output = self._decode_image(x_0_pred_out).detach().clone() # Normalize from [-1, 1] to [0, 1] (VAE outputs in range [-1, 1]) From 25478b759612b91de387a7a5807343b93b0d0f73 Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Sun, 3 May 2026 09:14:07 -0700 Subject: [PATCH 09/26] feat(trt): node-id-keyed adapter cache to survive graph edits MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Scope rebuilds the plugin instance on every graph edit, which clears the in-memory `_trt_*_built` flags and forces a per-engine deserialize/bind cycle (visible stalls of hundreds of ms to seconds, plus the rare full ONNX→TRT compile). Hold the built adapters at module scope keyed by the graph node id so the new instance can swap them straight back in. - New `_trt_cache.py`: `CachedTRTState` (cuda_stream, unet_adapter, unet_has_controlnet, cn_adapters dict, taesd_adapter) keyed by `node:`, with signature `(model_id, height, width)` so a real config change still triggers a clean rebuild. - `pipeline.py`: read `node_id` from kwargs (Scope must pass it through; until that lands, falls back to `_anon_` — correct for the single-SD-node case). At first `_ensure_trt_*` call, look up the cache; on hit, swap `self.unet` / `self.controlnet` / `self.vae` to the cached adapter and skip the build. On miss, build then write back. --- src/scope_streamdiffusion/_trt_cache.py | 83 +++++++++++++++++++++++++ src/scope_streamdiffusion/pipeline.py | 73 ++++++++++++++++++++++ 2 files changed, 156 insertions(+) create mode 100644 src/scope_streamdiffusion/_trt_cache.py diff --git a/src/scope_streamdiffusion/_trt_cache.py b/src/scope_streamdiffusion/_trt_cache.py new file mode 100644 index 0000000..bbcfd3b --- /dev/null +++ b/src/scope_streamdiffusion/_trt_cache.py @@ -0,0 +1,83 @@ +"""Process-wide cache of built TRT adapters keyed by graph node id. + +Scope rebuilds plugin instances on every graph edit (see +`scope/src/scope/server/graph_executor.py`); a fresh `StreamDiffusionPipeline` +loses its in-memory `_trt_*_built` flags and rebuilds engines on first call, +even when the on-disk engine cache hits. Loading and binding a TRT engine +context costs ~hundreds of ms per engine, and ONNX→TRT compile costs minutes +when the disk cache misses. Both are visible stalls during graph edits. + +This module holds the built adapters at module scope so a new pipeline +instance for the same logical node can swap them straight back in without +touching the engine builder. + +Cache key: the user-supplied graph node id when Scope passes it through +`__init__` kwargs. Until that upstream change lands the plugin falls back to +`_anon_`, which is correct for the common single-SD-node setup but +collides if two SD nodes ever coexist with different engine signatures. + +Engines are tied to (model_id, image_height, image_width); changing any of +those invalidates the cached state and forces a clean rebuild. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class CachedTRTState: + signature: tuple # (model_id, height, width) + cuda_stream: Any | None = None + unet_adapter: Any | None = None + unet_has_controlnet: bool = False + cn_adapters: dict[str, Any] = field(default_factory=dict) + taesd_adapter: Any | None = None + + +_CACHE: dict[str, CachedTRTState] = {} + + +def cache_key(node_id: str | None, model_id: str) -> str: + """Return the cache key for this pipeline instance. + + Prefers the node id (stable across graph edits, unique per logical node); + falls back to a model-id-scoped anon key for compatibility with Scope + versions that don't yet pass node_id to plugin __init__. + """ + if node_id: + return f"node:{node_id}" + return f"_anon_{model_id}" + + +def get_or_create(key: str, signature: tuple) -> tuple[CachedTRTState, bool]: + """Look up an entry; return (state, restored). + + `restored=True` means the cached signature matched and the caller should + reuse `state.*_adapter`. `restored=False` means either no entry existed or + the signature changed (engines built for different dims/model); the entry + is reset to a fresh state so callers can populate it after building. + """ + existing = _CACHE.get(key) + if existing is not None and existing.signature == signature: + return existing, True + fresh = CachedTRTState(signature=signature) + _CACHE[key] = fresh + return fresh, False + + +def peek(key: str) -> CachedTRTState | None: + return _CACHE.get(key) + + +def clear(key: str | None = None) -> None: + """Drop one entry, or the whole cache when key is None. + + Adapters hold CUDA memory; clearing here releases the only strong ref + once the previous pipeline instance is also gone. + """ + if key is None: + _CACHE.clear() + else: + _CACHE.pop(key, None) diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index e59b4f6..aa9f822 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -17,6 +17,7 @@ from scope.core.pipelines.interface import Pipeline, Requirements from scope.core.pipelines.blending import EmbeddingBlender, parse_transition_config +from . import _trt_cache from .controlnet import ControlNetHandler from .schema import StreamDiffusionConfig @@ -126,6 +127,14 @@ def __init__( if self._acceleration_mode == "trt": print(f"[TRT] acceleration_mode='trt' detected at init") + # Identify this pipeline instance for the cross-instance TRT adapter + # cache. node_id is the user-supplied graph node id from Scope and is + # stable across graph edits; until upstream Scope passes it through, + # we fall back to a model-scoped anon key (correct for a single SD + # node, the common case). + self._node_id: str | None = kwargs.get("node_id") + self._trt_cache_key: str = _trt_cache.cache_key(self._node_id, model_id) + # Check if SDXL self.sdxl: bool = type(self.pipe) is StableDiffusionXLPipeline @@ -208,6 +217,23 @@ def _ensure_trt_taesd(self) -> None: return if self._taesd_vae is None: return + + signature = (self._model_id_for_trt, int(self.height), int(self.width)) + cache_state, restored = _trt_cache.get_or_create(self._trt_cache_key, signature) + if self._trt_cuda_stream is None and cache_state.cuda_stream is not None: + self._trt_cuda_stream = cache_state.cuda_stream + if restored and cache_state.taesd_adapter is not None: + self._trt_eager_taesd = self._taesd_vae + self._taesd_vae = cache_state.taesd_adapter + if self._using_taesd: + self.vae = cache_state.taesd_adapter + self._trt_taesd_built = True + print( + f"[TRT] TAESD adapter restored from cache (key={self._trt_cache_key})", + flush=True, + ) + return + self._trt_taesd_built = True # prevent retry on failure from .trt_engines import ( @@ -217,6 +243,7 @@ def _ensure_trt_taesd(self) -> None: ) if self._trt_cuda_stream is None: self._trt_cuda_stream = make_cuda_stream() + cache_state.cuda_stream = self._trt_cuda_stream print( "[TRT] Preparing TAESD engines — first build takes ~1 min, cached after", flush=True, @@ -245,6 +272,7 @@ def _ensure_trt_taesd(self) -> None: self._taesd_vae = adapter if self._using_taesd: self.vae = adapter + cache_state.taesd_adapter = adapter print(f"[TRT] TAESD engines active: enc={enc_path.name}, dec={dec_path.name}", flush=True) def _ensure_trt_controlnet(self, mode: str) -> None: @@ -265,6 +293,24 @@ def _ensure_trt_controlnet(self, mode: str) -> None: self._trt_eager_controlnets[mode] = self._cn.model self.controlnet = adapter return + + signature = (self._model_id_for_trt, int(self.height), int(self.width)) + cache_state, restored = _trt_cache.get_or_create(self._trt_cache_key, signature) + if self._trt_cuda_stream is None and cache_state.cuda_stream is not None: + self._trt_cuda_stream = cache_state.cuda_stream + cached_cn = cache_state.cn_adapters.get(mode) if restored else None + if cached_cn is not None: + self._trt_eager_controlnets[mode] = self._cn.model + self._trt_cn_engines[mode] = cached_cn + self._trt_cn_built_modes.add(mode) + self.controlnet = cached_cn + print( + f"[TRT] ControlNet adapter restored from cache " + f"(mode={mode}, key={self._trt_cache_key})", + flush=True, + ) + return + self._trt_cn_built_modes.add(mode) # mark before build to prevent retry storm from .trt_engines import ( @@ -275,6 +321,7 @@ def _ensure_trt_controlnet(self, mode: str) -> None: if self._trt_cuda_stream is None: self._trt_cuda_stream = make_cuda_stream() + cache_state.cuda_stream = self._trt_cuda_stream # ControlNet ONNX export needs default attention too (same xformers # issue as the UNet path). @@ -304,6 +351,7 @@ def _ensure_trt_controlnet(self, mode: str) -> None: self._trt_eager_controlnets[mode] = self._cn.model self._trt_cn_engines[mode] = adapter self.controlnet = adapter + cache_state.cn_adapters[mode] = adapter print(f"[TRT] ControlNet engine active ({mode}): {engine_path}", flush=True) def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: @@ -331,6 +379,26 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: f"(had={self._trt_unet_has_controlnet}, want={want_control}); rebuilding" ) + signature = (self._model_id_for_trt, int(self.height), int(self.width)) + cache_state, restored = _trt_cache.get_or_create(self._trt_cache_key, signature) + if self._trt_cuda_stream is None and cache_state.cuda_stream is not None: + self._trt_cuda_stream = cache_state.cuda_stream + if ( + restored + and cache_state.unet_adapter is not None + and cache_state.unet_has_controlnet == want_control + ): + self._trt_eager_unet = self.pipe.unet + self.unet = cache_state.unet_adapter + self._trt_unet_built = True + self._trt_unet_has_controlnet = want_control + print( + f"[TRT] UNet adapter restored from cache " + f"(want_control={want_control}, key={self._trt_cache_key})", + flush=True, + ) + return + # Set sticky flags before the build so failures don't retry every frame. self._trt_unet_built = True self._trt_unet_has_controlnet = want_control @@ -361,6 +429,7 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: if self._trt_cuda_stream is None: self._trt_cuda_stream = make_cuda_stream() + cache_state.cuda_stream = self._trt_cuda_stream if want_control: print( @@ -380,6 +449,8 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: self.unet = TRTUNetWithControlAdapter( engine_path, self._trt_cuda_stream, num_down_residuals=12, ) + cache_state.unet_adapter = self.unet + cache_state.unet_has_controlnet = True print(f"[TRT] UNet+ctrl engine active: {engine_path}", flush=True) else: print( @@ -396,6 +467,8 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: ) self._trt_eager_unet = self.pipe.unet self.unet = TRTUNetAdapter(engine_path, self._trt_cuda_stream) + cache_state.unet_adapter = self.unet + cache_state.unet_has_controlnet = False print(f"[TRT] UNet engine active: {engine_path}", flush=True) def _load_model(self, model_id: str) -> DiffusionPipeline: From 7c89f3d83b4d22c53214dffabe87fa21d1da8949 Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Mon, 4 May 2026 21:59:28 -0700 Subject: [PATCH 10/26] feat(schema): width/height as Resolution IntEnum with field_validator Replaces Literal[256, 320, ...] tuple on width/height with a Resolution IntEnum and a `mode='before'` field_validator that coerces ints into enum members and raises a clear error listing all allowed values otherwise. Pipeline code already wraps width/height in `int()`, so behavior is unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/scope_streamdiffusion/schema.py | 97 +++++++++++++++++++---------- 1 file changed, 63 insertions(+), 34 deletions(-) diff --git a/src/scope_streamdiffusion/schema.py b/src/scope_streamdiffusion/schema.py index d872037..6284c85 100644 --- a/src/scope_streamdiffusion/schema.py +++ b/src/scope_streamdiffusion/schema.py @@ -1,8 +1,9 @@ """Configuration schema for StreamDiffusion pipeline.""" +from enum import IntEnum from typing import Literal -from pydantic import Field +from pydantic import Field, field_validator from scope.core.pipelines.base_schema import ( BasePipelineConfig, InputMode, @@ -11,6 +12,29 @@ ) +class Resolution(IntEnum): + """Allowed pixel dimensions for width/height. + + Multiples of 64 in [256, 1024]. UNet downsamples latents 3x and ControlNet + residuals land at /8 in latent space, so pixel dims must divide by 64. TRT + engines are built for this dynamic range. + """ + + R256 = 256 + R320 = 320 + R384 = 384 + R448 = 448 + R512 = 512 + R576 = 576 + R640 = 640 + R704 = 704 + R768 = 768 + R832 = 832 + R896 = 896 + R960 = 960 + R1024 = 1024 + + class StreamDiffusionConfig(BasePipelineConfig): """Configuration for the StreamDiffusion pipeline.""" @@ -44,13 +68,13 @@ class StreamDiffusionConfig(BasePipelineConfig): enabled: bool = Field( default=True, description="Enable pipeline processing. When disabled, input video is passed through unchanged.", - json_schema_extra=ui_field_config(order=0, label="Enabled"), + #json_schema_extra=ui_field_config(order=0, label="Enabled"), ) input_mode: InputMode = Field( default="text", description="Input mode: 'text' generates from prompts only, 'video' transforms input frames", - json_schema_extra=ui_field_config(order=1, label="Input Mode"), + #json_schema_extra=ui_field_config(order=1, label="Input Mode"), ) # ======================================== @@ -71,13 +95,13 @@ class StreamDiffusionConfig(BasePipelineConfig): "requires pipeline reload. Engines support dynamic resolution 256-1024 " "and batch 1-4." ), - json_schema_extra=ui_field_config(order=2, label="Acceleration"), + #json_schema_extra=ui_field_config(order=2, label="Acceleration"), ) use_taesd: bool = Field( default=True, description="Use Tiny AutoEncoder (TAESD) for ~10x faster VAE decoding at slight quality cost", - json_schema_extra=ui_field_config(order=2, label="Use TAESD"), + #json_schema_extra=ui_field_config(order=2, label="Use TAESD"), ) controlnet_mode: Literal["none", "depth", "scribble"] = Field( @@ -99,7 +123,7 @@ class StreamDiffusionConfig(BasePipelineConfig): ge=0.0, le=1, description="Minimum depth value for ControlNet", - json_schema_extra=ui_field_config(order=5, label="Depth Min"), + #json_schema_extra=ui_field_config(order=5, label="Depth Min"), ) depth_max: float = Field( @@ -107,27 +131,27 @@ class StreamDiffusionConfig(BasePipelineConfig): ge=0.0, le=1, description="Maximum depth value for ControlNet", - json_schema_extra=ui_field_config(order=6, label="Depth Max"), + #json_schema_extra=ui_field_config(order=6, label="Depth Max"), ) depth_skip_interval: int = Field( - default=3, + default=2, ge=1, le=10, description="Run depth model every Nth frame; reuse cached depth map on intermediate frames. Higher = less GPU cost, more temporal lag.", - json_schema_extra=ui_field_config(order=7, label="Depth Skip Interval"), + #json_schema_extra=ui_field_config(order=7, label="Depth Skip Interval"), ) depth_input_size: Literal[252, 364, 518] = Field( - default=518, + default=252, description="Resolution the depth model runs at (must be multiple of 14). Lower = faster but coarser depth. 252 ≈ 4× faster than 518; the depth map is bilinear-upsampled to controlnet resolution either way.", - json_schema_extra=ui_field_config(order=8, label="Depth Input Size"), + #json_schema_extra=ui_field_config(order=8, label="Depth Input Size"), ) depth_temporal_cache: bool = Field( default=True, description="Use the video model's temporal hidden-state cache for inter-frame consistency. Disabling skips the temporal motion modules entirely (faster, slightly more flicker). Combined with skip interval > 1 the cache buys little, so toggle off for speed.", - json_schema_extra=ui_field_config(order=9, label="Depth Temporal Cache"), + #json_schema_extra=ui_field_config(order=9, label="Depth Temporal Cache"), ) controlnet_temporal_smoothing: float = Field( @@ -135,7 +159,7 @@ class StreamDiffusionConfig(BasePipelineConfig): ge=0.0, le=1.0, description="Temporal blending of the ControlNet conditioning map. 0.0 = fully smoothed (previous frame only), 1.0 = no smoothing (current frame only). Lower values reduce flicker; higher values reduce latency.", - json_schema_extra=ui_field_config(order=5, label="ControlNet Smoothing"), + #json_schema_extra=ui_field_config(order=5, label="ControlNet Smoothing"), ) # ======================================== @@ -146,7 +170,7 @@ class StreamDiffusionConfig(BasePipelineConfig): negative_prompt: str = Field( default="", description="Negative prompt — what to avoid in the generated image", - json_schema_extra=ui_field_config(order=11, label="Negative Prompt"), + #json_schema_extra=ui_field_config(order=11, label="Negative Prompt"), ) negative_prompt_scale: float = Field( @@ -154,7 +178,7 @@ class StreamDiffusionConfig(BasePipelineConfig): ge=0.0, le=2.0, description="Strength of embedding-space negative guidance (used when guidance_scale=0). Subtracts the negative prompt embedding from the positive. 0 = disabled, 1 = full subtraction.", - json_schema_extra=ui_field_config(order=12, label="Negative Scale"), + #json_schema_extra=ui_field_config(order=12, label="Negative Scale"), ) prompt_interpolation_method: Literal["linear", "slerp"] = Field( @@ -171,7 +195,7 @@ class StreamDiffusionConfig(BasePipelineConfig): "0 = hard cut (can cause garbage frames); 8-30 is typical for smooth " "prompt morphs. Ignored when an explicit transition dict is sent." ), - json_schema_extra=ui_field_config(order=10, label="Transition Steps"), + #json_schema_extra=ui_field_config(order=10, label="Transition Steps"), ) seed: int = Field( @@ -195,7 +219,7 @@ class StreamDiffusionConfig(BasePipelineConfig): ) num_inference_steps: int = Field( - default=2, + default=1, ge=1, le=50, description="Number of denoising steps", @@ -265,7 +289,7 @@ class StreamDiffusionConfig(BasePipelineConfig): image_loopback: bool = Field( default=False, description="Use last frame as input for the next generation", - json_schema_extra=ui_field_config(order=49, label="Image Loopback"), + #json_schema_extra=ui_field_config(order=49, label="Image Loopback"), ) # ======================================== @@ -279,7 +303,7 @@ class StreamDiffusionConfig(BasePipelineConfig): "mask. SD output goes where mask=1, original goes where mask=0. " "Flip directions by toggling the upstream segmenter's Invert Mask." ), - json_schema_extra=ui_field_config(order=55, label="Mask Compositing"), + #json_schema_extra=ui_field_config(order=55, label="Mask Compositing"), ) mask_feather: float = Field( @@ -290,7 +314,7 @@ class StreamDiffusionConfig(BasePipelineConfig): "Soft mask edges (pixels). 0 = hard edge. Cheap box-blur applied " "to the mask before compositing." ), - json_schema_extra=ui_field_config(order=56, label="Mask Feather"), + #json_schema_extra=ui_field_config(order=56, label="Mask Feather"), ) mask_strength: float = Field( @@ -301,24 +325,29 @@ class StreamDiffusionConfig(BasePipelineConfig): "Overall mask blend strength. 0 disables compositing, 1 is full effect. " "Use intermediate values to ghost the original through the SD output." ), - json_schema_extra=ui_field_config(order=57, label="Mask Strength"), + #json_schema_extra=ui_field_config(order=57, label="Mask Strength"), ) - # Resolution settings — must be a multiple of 64 (UNet downsamples latents - # 3x; ControlNet residuals go to /8 in latent space, so pixel dim has to - # divide by 64). TRT engines are built for the 256-1024 dynamic range. - width: Literal[ - 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024 - ] = Field( - default=512, + width: Resolution = Field( + default=Resolution.R512, description="Output width (multiple of 64, 256-1024)", - json_schema_extra=ui_field_config(order=60, label="Width"), + #json_schema_extra=ui_field_config(order=60, label="Width"), ) - height: Literal[ - 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024 - ] = Field( - default=512, + height: Resolution = Field( + default=Resolution.R512, description="Output height (multiple of 64, 256-1024)", - json_schema_extra=ui_field_config(order=61, label="Height"), + #json_schema_extra=ui_field_config(order=61, label="Height"), ) + + @field_validator("width", "height", mode="before") + @classmethod + def _validate_resolution(cls, v: object) -> Resolution: + try: + return Resolution(int(v)) # type: ignore[arg-type] + except (ValueError, TypeError) as e: + allowed = ", ".join(str(r.value) for r in Resolution) + raise ValueError( + f"Resolution must be one of: {allowed} (multiples of 64 in [256, 1024]); got {v!r}" + ) from e + From 29f1ea7c125ca965f1b23d786e32fa3446426f7f Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Tue, 5 May 2026 12:35:53 -0700 Subject: [PATCH 11/26] feat(schema): restrict to Turbo-only, drop num_inference_steps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Trim model_id_or_path enum to stabilityai/sd-turbo and stabilityai/sdxl-turbo — both 1-step distillations. Drops Dreamshaper, SD 1.5 base, SDXL base, and the Dreamshaper SDXL Turbo variant: keeping the multi-step models meant carrying LCM LoRA fusion + a serial denoise path that we no longer need. Removes num_inference_steps and use_suggested_num_inference_steps fields: both are dead now that step count is fixed at 1 for every supported model. LoRA-based step distillation (Hyper-SD / Lightning) on arbitrary checkpoints is the better path forward — tracked separately, not in this change. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/scope_streamdiffusion/schema.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/src/scope_streamdiffusion/schema.py b/src/scope_streamdiffusion/schema.py index 00a07f1..a811e92 100644 --- a/src/scope_streamdiffusion/schema.py +++ b/src/scope_streamdiffusion/schema.py @@ -84,13 +84,9 @@ class StreamDiffusionConfig(BasePipelineConfig): model_id_or_path: Literal[ "stabilityai/sd-turbo", "stabilityai/sdxl-turbo", - "stable-diffusion-v1-5/stable-diffusion-v1-5", - "stabilityai/stable-diffusion-xl-base-1.0", - "Lykon/dreamshaper-8", - "Lykon/dreamshaper-xl-v2-turbo", ] = Field( default="stabilityai/sd-turbo", - description="HuggingFace model ID. Tested set; non-Turbo entries auto-attach the matching LCM LoRA, SDXL entries auto-swap to madebyollin/sdxl-vae-fp16-fix.", + description="HuggingFace model ID. Both entries are 1-step distillations; SDXL-Turbo additionally swaps in madebyollin/sdxl-vae-fp16-fix.", json_schema_extra=ui_field_config(order=8, label="Model"), ) @@ -226,20 +222,6 @@ class StreamDiffusionConfig(BasePipelineConfig): # json_schema_extra=ui_field_config(order=20, label="Guidance Scale"), ) - num_inference_steps: int = Field( - default=4, - ge=1, - le=50, - description="Number of LCM denoising steps. Main sharpness lever: more steps = sharper detail. SD-Turbo (sd-turbo) is distilled for 1 step; SDXL-Turbo / fine-tunes / non-turbo + LCM LoRA all want 4–8.", - json_schema_extra=ui_field_config(order=21, label="Inference Steps"), - ) - - use_suggested_num_inference_steps: bool = Field( - default=True, - description="When ON, the pipeline picks the inference-step count per model family (1 for SD-Turbo, 4 for everything else) and ignores the slider. Toggle OFF to drive the slider yourself.", - json_schema_extra=ui_field_config(order=22, label="Auto Inference Steps"), - ) - strength: float = Field( default=0.99, ge=0.0, From 1015f4be0f88f33b023c4791d055772f07b9b6cb Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Tue, 5 May 2026 12:36:06 -0700 Subject: [PATCH 12/26] refactor(pipeline): drop non-Turbo paths, hardcode 1-step denoise MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Now that the schema only allows SD-Turbo and SDXL-Turbo, the runtime can shed everything that existed to make non-Turbo models usable at low step counts: - self.sd_turbo flag (everything is Turbo now) and the per-family step-count branch in __call__ - _attach_lcm_lora() and its call sites in __init__ / _swap_model (LCM LoRA was only fused for non-Turbo SD 1.5 / SDXL bases) - _predict_x0_serial() and the use_serial branch in __call__ — serial denoise was added for steady-prompt txt2img / image-loopback on multi-step models; with 1-step Turbo it never fires - denoising_steps_num > 1 dead branches in _prepare_runtime_state and _predict_x0_batch (always 1 now) - num_inference_steps plumbing — pinned at 1 in __call__ Untouched: TRT engine swap, ControlNet handling, prompt transitions, RCFG, mask compositing, hot-swap between sd-turbo and sdxl-turbo, and the SDXL fp16-fix VAE swap. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/scope_streamdiffusion/pipeline.py | 177 ++------------------------ 1 file changed, 8 insertions(+), 169 deletions(-) diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index 9144f4c..e684237 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -92,11 +92,9 @@ def __init__( # Load the base model print(f"Loading model: {model_id}") self.model_id = model_id - self.sd_turbo = "turbo" in model_id.lower() self.pipe = self._load_model(model_id) print(f"Model loaded: {self.pipe.__class__.__name__}") - # Check if SDXL (needed before LCM LoRA selection) self.sdxl: bool = type(self.pipe) is StableDiffusionXLPipeline # SDXL's default VAE overflows in fp16 and decodes NaN. Swap to the @@ -105,12 +103,7 @@ def __init__( if self.sdxl and self.dtype == torch.float16: self._install_sdxl_fp16_vae() - # Non-turbo models need LCM LoRA to denoise correctly with LCMScheduler - # at 1–4 steps. Turbo/Lightning models are already distilled for this. - if not self.sd_turbo: - self._attach_lcm_lora() - - # Model components (grabbed after LoRA fuse so fused weights are live) + # Model components self.text_encoder = self.pipe.text_encoder self.unet = self.pipe.unet self.vae = self.pipe.vae @@ -537,14 +530,11 @@ def _swap_model(self, new_model_id: str) -> None: torch.cuda.empty_cache() self.model_id = new_model_id - self.sd_turbo = "turbo" in new_model_id.lower() self.pipe = self._load_model(new_model_id) print(f"[StreamDiffusion] Model loaded: {self.pipe.__class__.__name__}") self.sdxl = type(self.pipe) is StableDiffusionXLPipeline if self.sdxl and self.dtype == torch.float16: self._install_sdxl_fp16_vae() - if not self.sd_turbo: - self._attach_lcm_lora() self.text_encoder = self.pipe.text_encoder self.unet = self.pipe.unet @@ -593,29 +583,6 @@ def _install_sdxl_fp16_vae(self) -> None: except Exception as e: print(f"[StreamDiffusion] Failed to install fp16-fix VAE: {e}") - def _attach_lcm_lora(self) -> None: - """Load and fuse the appropriate LCM LoRA for a non-turbo SD/SDXL base. - - LCMScheduler at 1–4 steps only produces usable output on models that - have been distilled for low-step inference — Turbo, Lightning, or LCM. - For plain SD 1.5 / SDXL bases, we attach the matching LCM LoRA so the - scheduler path works the same as it does for Turbo. - """ - lcm_lora_id = ( - "latent-consistency/lcm-lora-sdxl" - if self.sdxl - else "latent-consistency/lcm-lora-sdv1-5" - ) - print(f"[StreamDiffusion] Loading LCM LoRA: {lcm_lora_id}") - try: - self.pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm") - self.pipe.fuse_lora(lora_scale=1.0, adapter_names=["lcm"]) - self.pipe.unload_lora_weights() - print("[StreamDiffusion] LCM LoRA fused") - except Exception as e: - print(f"[StreamDiffusion] Failed to load LCM LoRA {lcm_lora_id}: {e}") - raise - def _set_taesd(self, enabled: bool) -> None: """Switch between TAESD (fast) and full VAE decoder.""" if enabled == self._using_taesd: @@ -750,19 +717,7 @@ def _prepare_runtime_state( self._last_seed = seed if seed_changed or shape_changed: - if self.denoising_steps_num > 1: - self.x_t_latent_buffer = torch.zeros( - ( - (self.denoising_steps_num - 1) * self.frame_bff_size, - 4, - self.latent_height, - self.latent_width, - ), - dtype=self.dtype, - device=self.device, - ) - else: - self.x_t_latent_buffer = None + self.x_t_latent_buffer = None self._initialize_noise() self._noise_shape = noise_shape @@ -1376,11 +1331,6 @@ def _predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: if self.use_denoising_batch: t_list = self.sub_timesteps_tensor - if self.denoising_steps_num > 1: - x_t_latent = torch.cat((x_t_latent, prev_latent_batch), dim=0) - self.stock_noise = torch.cat( - (self.init_noise[0:1], self.stock_noise[:-1]), dim=0 - ) if self.sdxl: batch = x_t_latent.shape[0] te = self.add_text_embeds.to(self.device) @@ -1397,20 +1347,8 @@ def _predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: x_t_latent, t_list, added_cond_kwargs=added_cond_kwargs ) - if self.denoising_steps_num > 1: - x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0) - if self.do_add_noise: - self.x_t_latent_buffer = ( - self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] - + self.beta_prod_t_sqrt[1:] * self.init_noise[1:] - ) - else: - self.x_t_latent_buffer = ( - self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] - ) - else: - x_0_pred_out = x_0_pred_batch - self.x_t_latent_buffer = None + x_0_pred_out = x_0_pred_batch + self.x_t_latent_buffer = None else: self.init_noise = x_t_latent for idx, t in enumerate(self.sub_timesteps_tensor): @@ -1442,63 +1380,6 @@ def _predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: return x_0_pred_out - @torch.no_grad() - def _predict_x0_serial( - self, - latent: torch.Tensor, - num_inference_steps: int, - strength: float = 1.0, - is_img2img: bool = False, - ) -> torch.Tensor: - """Run a clean N-step LCM denoise loop on a single latent. - - Sibling of :meth:`_predict_x0_batch` for modes where the streaming - rolling-buffer trick gets in the way (steady-prompt txt2img and - image-loopback) — runs all N timesteps inside one call so the output - is a single fully-denoised frame, not one slice of a 4-track buffer - cycle. - - For txt2img, ``latent`` is pure noise and we walk every timestep in - the schedule. For img2img / loopback, ``latent`` is the cleanly- - encoded input image; we add fresh noise at the first timestep we - actually run, controlled by ``strength`` (1.0 = full repaint, lower - = preserve more of the input). - """ - self.scheduler.set_timesteps(num_inference_steps, device=self.device) - timesteps = self.scheduler.timesteps - - if is_img2img: - skip = max(0, int(round(num_inference_steps * (1.0 - strength)))) - timesteps = timesteps[skip:] - if len(timesteps) == 0: - return latent - noise = torch.randn( - latent.shape, generator=self.generator, - device=self.device, dtype=self.dtype, - ) - latent = self.scheduler.add_noise(latent, noise, timesteps[:1]) - - added_cond_kwargs = {} - if self.sdxl: - added_cond_kwargs = { - "text_embeds": self.add_text_embeds.to(self.device), - "time_ids": self.add_time_ids.to(self.device), - } - - prompt_embeds = self.prompt_embeds[:1] # serial works one frame at a time - for t in timesteps: - noise_pred = self.unet( - latent, - t, - encoder_hidden_states=prompt_embeds, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] - latent = self.scheduler.step( - noise_pred, t, latent, return_dict=False - )[0] - return latent - @torch.no_grad() def __call__(self, **kwargs) -> dict: """Process input video frame(s) and return generated output. @@ -1589,13 +1470,8 @@ def get_param(key, default): prompt_interpolation_method = get_param("prompt_interpolation_method", "linear") guidance_scale = get_param("guidance_scale", 0.0) - num_inference_steps = get_param("num_inference_steps", 4) - use_suggested_steps = get_param("use_suggested_num_inference_steps", True) - if use_suggested_steps: - # Per-family suggestion: SD-Turbo proper is 1-step distilled; every - # other family (SDXL-Turbo, SDXL-Turbo fine-tunes, SD1.5/SDXL + - # LCM LoRA) needs 4 steps to converge to sharp output. - num_inference_steps = 1 if (self.sd_turbo and not self.sdxl) else 4 + # SD-Turbo and SDXL-Turbo are both 1-step distillations. + num_inference_steps = 1 # For img2img with SD Turbo, need higher strength for visible changes # 0.5-0.7 = moderate, 0.8-0.95 = heavy transformation @@ -1626,30 +1502,6 @@ def get_param(key, default): # don't change it because TRT engines can't be hot-swapped. acceleration_mode = self._acceleration_mode - # --- Pick denoise path ------------------------------------------------- - # The batch path (StreamDiffusion's rolling-buffer trick) amortises N - # denoising stages across N consecutive video frames. That's a real - # win when the input stream changes every frame (webcam, v2v, moving - # ControlNet) but in steady-prompt txt2img / image-loopback the per- - # slot init_noise drift makes the N buffer slots crystallise into N - # different attractors that flash one after another. - # Use the serial path for those cases; keep the batch path everywhere - # else and at num_inference_steps=1 (where it degenerates to one - # UNet call per frame anyway). - has_video_input_eval = video is not None and len(video) > 0 - is_steady_prompt_mode = (not has_video_input_eval) or image_loopback - use_serial = ( - num_inference_steps > 1 - and is_steady_prompt_mode - and controlnet_mode == "none" - ) - if use_serial: - # Serial denoise wants prompt_embeds sized for batch=1 and doesn't - # care about the LCM coefficient pre-compute. Force the runtime - # state into a 1-track configuration so _prepare_runtime_state's - # caches match what _predict_x0_serial will read. - use_denoising_batch = False - # --- Safeguard: prevent invalid strength / num_inference_steps combos --- # LCM scheduler requires: floor(original_steps * strength) >= num_inference_steps # original_steps defaults to 50 in the scheduler. @@ -1790,9 +1642,7 @@ def get_param(key, default): return {"video": output.permute(0, 2, 3, 1).clamp(0, 1)} input_tensor = filtered - # Encode to latent space. Serial img2img adds its own noise based - # on the requested strength, so don't double-noise here. - input_latent = self._encode_image(input_tensor, add_noise=not use_serial) + input_latent = self._encode_image(input_tensor, add_noise=True) else: # Text-to-image mode — use the seeded `init_noise` instead of a @@ -1802,16 +1652,7 @@ def get_param(key, default): # user reseed deterministically by changing `seed`). input_latent = self.init_noise[0:1].clone() - # Run diffusion - if use_serial: - x_0_pred_out = self._predict_x0_serial( - input_latent, - num_inference_steps=num_inference_steps, - strength=strength, - is_img2img=frame is not None, - ) - else: - x_0_pred_out = self._predict_x0_batch(input_latent) + x_0_pred_out = self._predict_x0_batch(input_latent) # Decode to image space x_output = self._decode_image(x_0_pred_out).detach().clone() # Normalize from [-1, 1] to [0, 1] (VAE outputs in range [-1, 1]) @@ -1868,7 +1709,6 @@ def main(): test_params = { "prompt": "A beautiful sunset over mountains", "negative_prompt": "ugly, blurry, low quality", - "num_inference_steps": 4, "guidance_scale": 0.0, "strength": 0.99, "seed": 42, @@ -1881,7 +1721,6 @@ def main(): print("\nTest parameters:") print(f" Prompt: {test_params['prompt']}") - print(f" Steps: {test_params['num_inference_steps']}") print(f" Size: {test_params['width']}x{test_params['height']}") print("\nRunning pipeline 10 times...\n") From cc82a98abdbe6af47a7463d88e8c458793588148 Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Tue, 5 May 2026 12:37:18 -0700 Subject: [PATCH 13/26] docs: handoff notes for the Turbo-only simplification Co-Authored-By: Claude Opus 4.7 (1M context) --- HANDOFF_TURBO_ONLY.md | 66 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 HANDOFF_TURBO_ONLY.md diff --git a/HANDOFF_TURBO_ONLY.md b/HANDOFF_TURBO_ONLY.md new file mode 100644 index 0000000..f612337 --- /dev/null +++ b/HANDOFF_TURBO_ONLY.md @@ -0,0 +1,66 @@ +# Handoff — Turbo-only simplification + +PR: https://github.com/happyFish/scope-stream_diffusion_v1/pull/2 (`sd-multi-model` → `main`). + +## What landed + +1. **Merge of `main` into `sd-multi-model`** (commit `d8d2fbd`) — brings the + full TRT subsystem (refittable engines, node-id-keyed adapter cache, fp16 + VAE TRT path) under the multi-model dropdown. +2. **`feat(schema)` (`29f1ea7`)** — `model_id_or_path` enum trimmed to + `stabilityai/sd-turbo` + `stabilityai/sdxl-turbo`. `num_inference_steps` + and `use_suggested_num_inference_steps` fields removed. +3. **`refactor(pipeline)` (`1015f4b`)** — drops `self.sd_turbo`, + `_attach_lcm_lora`, `_predict_x0_serial`, the `use_serial` branch in + `__call__`, and the dead `denoising_steps_num > 1` branches. Step count + is hardcoded to 1 in `__call__`. + +Net: +9 / -188 LoC across `pipeline.py` and `schema.py`. + +## Verified locally + +- `python -m py_compile` clean on all .py files. +- Module imports cleanly from this worktree (`PYTHONPATH=src`) under Scope's + venv (`~/Projects/daydreamlive-scope/.venv`). +- Schema reflects the trim: `Literal['stabilityai/sd-turbo', + 'stabilityai/sdxl-turbo']`, no `num_inference_steps` field. + +## Not verified — please run before merging + +Scope was not running on `localhost:8000` during this session, so the +hot-reload smoke test from the original handoff was skipped. To validate: + +```bash +curl -X POST http://localhost:8000/api/v1/plugins/scope_streamdiffusion/reload \ + -H 'Content-Type: application/json' -d '{"force":true}' +``` + +Then in the UI: + +- Render at `acceleration_mode=none` with `model_id_or_path=stabilityai/sd-turbo`. +- Swap to `stabilityai/sdxl-turbo` via the dropdown — confirm hot-swap path + and SDXL fp16-fix VAE install. +- Render at `acceleration_mode=trt` for both variants. +- Run a moth dev session: scenes trigger, oscillators drive params, + ControlNet (depth) and mask compositing still work. + +## Out of scope (explicit) + +- LoRA hot-swap (separate spec at `~/Projects/moth/docs/specs/lora-support.md`, + Phase 4 of `streamdiffusion-trt.md`). +- Hyper-SD / Lightning step-LoRA fusion at load — the better path to + 1-step inference on arbitrary SD 1.5 / SDXL checkpoints; future PR. +- SD 3 / 3.5 — MMDiT, not UNet, incompatible with the current TRT path. +- Moth-side UI changes — none needed; the dropdown is schema-driven. + +## Notes for next agent + +- Editable install lives in the **main worktree path** + (`~/Projects/moth-scope/plugins/scope-stream_dffusion_v1`), not this + worktree. Once this PR merges into `main`, pulling main in the main + worktree will pick the changes up automatically; until then, force- + importing from `src/` is the only way to verify the worktree's code in + Python. +- The 2-commit split (schema then pipeline) is intentional. The handoff + asked for 2-3 commits; bundling pipeline.py into one was cleaner than + trying to interleave the LCM LoRA / serial / dead-branch removals. From 206300f1685d4a6953334f630688b06b55019b7a Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Tue, 5 May 2026 14:55:15 -0700 Subject: [PATCH 14/26] feat: DMD2-SDXL-1step preset via extensible MODEL_PRESETS MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a curated 1-step model option that isn't a direct HuggingFace repo: SDXL-base with the DMD2-distilled UNet (tianweiy/DMD2) swapped in. DMD2 generally outperforms SDXL-Turbo on FID/CLIP per the paper, while staying on the same LCMScheduler at 1 step that all our existing TRT/runtime infra is built around. Introduces a MODEL_PRESETS dict at module scope as the extension point for future Turbo-class additions: - 'unet_swap' shape — base pipeline + distilled UNet checkpoint. Used here for DMD2; DMD2 retrained the UNet via distribution matching, so it ships as a UNet, not a LoRA. - Future shapes documented inline: 'lora' (Hyper-SD / SDXL-Lightning step-distillation LoRAs), 'scheduler' override, 'timesteps_override'. Hyper-SD-1step / Lightning-1step both need TCD / Euler schedulers, which require a `_set_timesteps` refactor (the current path calls LCM-specific `get_scalings_for_boundary_condition_discrete` and reads `scheduler.alphas_cumprod` directly). That refactor is out of scope for this PR. The fp16-fix VAE swap, TRT cache keying, hot-swap, and rolling-buffer denoise math are all untouched — DMD2's UNet is architecturally an SDXL UNet, so everything downstream of `_load_model` is identical. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/scope_streamdiffusion/pipeline.py | 80 ++++++++++++++++++++++++--- src/scope_streamdiffusion/schema.py | 8 ++- 2 files changed, 80 insertions(+), 8 deletions(-) diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index e684237..25234b5 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -9,6 +9,7 @@ DiffusionPipeline, LCMScheduler, StableDiffusionXLPipeline, + UNet2DConditionModel, ) from diffusers.image_processor import VaeImageProcessor from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( @@ -25,6 +26,31 @@ from scope.core.pipelines.base_schema import BasePipelineConfig +# Curated presets — model_id strings that aren't direct HuggingFace repos but +# describe a (base, distillation) recipe. Extending this dict is how we add +# new 1-step / few-step models to the dropdown without exposing the user to +# the underlying repo plumbing. +# +# Schema currently exposes the `unet_swap` shape. Future shapes: +# "lora": (lora_repo, lora_filename) — fuse a step-LoRA at scale=1.0 onto +# the base. Works for Hyper-SD-1step / SDXL-Lightning-1step ONLY +# after `_set_timesteps` is taught about TCD / Euler schedulers +# (it currently calls LCM-specific +# `scheduler.get_scalings_for_boundary_condition_discrete`). +# "scheduler": SchedulerClass — override the LCMScheduler default in +# _swap_model. Same caveat as above re: `_set_timesteps`. +# "timesteps_override": [int, ...] — pin specific timesteps (Hyper-SD-1step +# wants [800] with TCD). +MODEL_PRESETS: Dict[str, dict] = { + "dmd2-sdxl-1step": { + "base": "stabilityai/stable-diffusion-xl-base-1.0", + # tianweiy/DMD2 ships several distilled UNet checkpoints; the + # 1-step fp16 variant is the SDXL-Turbo equivalent. + "unet_swap": ("tianweiy/DMD2", "dmd2_sdxl_1step_unet_fp16.bin"), + }, +} + + # Import or inline the helper utilities class SimilarImageFilter: """Simple similar image filter implementation.""" @@ -486,17 +512,25 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: print(f"[TRT] UNet engine active: {engine_path}", flush=True) def _load_model(self, model_id: str) -> DiffusionPipeline: - """Load the diffusion model.""" + """Load the diffusion model. + + For HuggingFace model IDs, loads via DiffusionPipeline.from_pretrained + directly. For curated presets in MODEL_PRESETS, follows the preset's + recipe (base load + UNet swap, etc.). + """ try: - pipe = DiffusionPipeline.from_pretrained( - model_id, - torch_dtype=self.dtype, - variant="fp16" if self.dtype == torch.float16 else None, - ) + preset = MODEL_PRESETS.get(model_id) + if preset is not None: + pipe = self._load_preset(preset) + else: + pipe = DiffusionPipeline.from_pretrained( + model_id, + torch_dtype=self.dtype, + variant="fp16" if self.dtype == torch.float16 else None, + ) pipe = pipe.to(self.device) # Enable xformers memory-efficient attention if available. - # The schema declares acceleration="xformers" but this was never called. try: pipe.enable_xformers_memory_efficient_attention() print("[StreamDiffusion] xformers memory-efficient attention enabled") @@ -508,6 +542,38 @@ def _load_model(self, model_id: str) -> DiffusionPipeline: print(f"Failed to load model {model_id}: {e}") raise + def _load_preset(self, preset: dict) -> DiffusionPipeline: + """Build a DiffusionPipeline from a MODEL_PRESETS recipe. + + Currently supports the ``unet_swap`` shape — load the base pipeline, + then replace its UNet with a distilled checkpoint. Other recipe + shapes (LoRA fuse, scheduler override, timesteps_override) will land + alongside the `_set_timesteps` refactor needed to support + non-LCM schedulers. + """ + base = preset["base"] + print(f"[StreamDiffusion] Loading preset: base={base}") + + unet_swap = preset.get("unet_swap") + if unet_swap is not None: + unet_repo, unet_file = unet_swap + print(f"[StreamDiffusion] Loading distilled UNet: {unet_repo}/{unet_file}") + unet = UNet2DConditionModel.from_pretrained( + unet_repo, weight_name=unet_file, torch_dtype=self.dtype + ) + pipe = DiffusionPipeline.from_pretrained( + base, + unet=unet, + torch_dtype=self.dtype, + variant="fp16" if self.dtype == torch.float16 else None, + ) + return pipe + + # Other preset shapes (LoRA fuse, etc.) land here once supported. + raise NotImplementedError( + f"MODEL_PRESETS recipe shape not yet implemented: {preset}" + ) + def _swap_model(self, new_model_id: str) -> None: """Replace the loaded model in place. diff --git a/src/scope_streamdiffusion/schema.py b/src/scope_streamdiffusion/schema.py index a811e92..a34f9db 100644 --- a/src/scope_streamdiffusion/schema.py +++ b/src/scope_streamdiffusion/schema.py @@ -84,9 +84,15 @@ class StreamDiffusionConfig(BasePipelineConfig): model_id_or_path: Literal[ "stabilityai/sd-turbo", "stabilityai/sdxl-turbo", + "dmd2-sdxl-1step", ] = Field( default="stabilityai/sd-turbo", - description="HuggingFace model ID. Both entries are 1-step distillations; SDXL-Turbo additionally swaps in madebyollin/sdxl-vae-fp16-fix.", + description=( + "Model selection. All entries are 1-step distillations. " + "'dmd2-sdxl-1step' is SDXL-base with the DMD2 distilled UNet " + "(tianweiy/DMD2) swapped in — quality bump over SDXL-Turbo per " + "the DMD2 paper. SDXL-derived entries auto-install the fp16-fix VAE." + ), json_schema_extra=ui_field_config(order=8, label="Model"), ) From 819ec3db5d4cd82a9bc0a6e661059b991691fa37 Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Tue, 5 May 2026 15:22:28 -0700 Subject: [PATCH 15/26] fix: load DMD2 UNet via state_dict override, not from_pretrained MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Distilled-UNet repos like tianweiy/DMD2 ship weights only — no config.json — because the architecture is identical to the base UNet. UNet2DConditionModel.from_pretrained needs a config and bails with 'tianweiy/DMD2 does not appear to have a file named config.json'. Switch to: load the base SDXL pipeline (gets a correctly-configured UNet module), download the DMD2 checkpoint via hf_hub_download, then override the UNet's state_dict in place. Verified end-to-end with a 300-frame sun→moon morph render at fp16, no acceleration: 6 fps eager, output matches expected DMD2 quality. Same pattern works for SDXL-Lightning's 1-step UNet variant once the scheduler refactor lands. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/scope_streamdiffusion/pipeline.py | 43 ++++++++++++++++----------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index 25234b5..e9e0bfa 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -546,33 +546,40 @@ def _load_preset(self, preset: dict) -> DiffusionPipeline: """Build a DiffusionPipeline from a MODEL_PRESETS recipe. Currently supports the ``unet_swap`` shape — load the base pipeline, - then replace its UNet with a distilled checkpoint. Other recipe - shapes (LoRA fuse, scheduler override, timesteps_override) will land - alongside the `_set_timesteps` refactor needed to support + then override its UNet weights from a distilled checkpoint. Other + recipe shapes (LoRA fuse, scheduler override, timesteps_override) + will land alongside the `_set_timesteps` refactor needed to support non-LCM schedulers. """ base = preset["base"] - print(f"[StreamDiffusion] Loading preset: base={base}") + print(f"[StreamDiffusion] Loading preset base: {base}") + pipe = DiffusionPipeline.from_pretrained( + base, + torch_dtype=self.dtype, + variant="fp16" if self.dtype == torch.float16 else None, + ) unet_swap = preset.get("unet_swap") if unet_swap is not None: + from huggingface_hub import hf_hub_download + unet_repo, unet_file = unet_swap - print(f"[StreamDiffusion] Loading distilled UNet: {unet_repo}/{unet_file}") - unet = UNet2DConditionModel.from_pretrained( - unet_repo, weight_name=unet_file, torch_dtype=self.dtype - ) - pipe = DiffusionPipeline.from_pretrained( - base, - unet=unet, - torch_dtype=self.dtype, - variant="fp16" if self.dtype == torch.float16 else None, - ) + print(f"[StreamDiffusion] Downloading distilled UNet: {unet_repo}/{unet_file}") + ckpt_path = hf_hub_download(unet_repo, unet_file) + # Distilled-UNet repos (DMD2, SDXL-Lightning, etc.) often ship + # weights only — no config.json — because the architecture is + # identical to the base UNet. Reuse the base pipeline's UNet + # module and override its state_dict. + if unet_file.endswith(".safetensors"): + from safetensors.torch import load_file + state_dict = load_file(ckpt_path) + else: + state_dict = torch.load(ckpt_path, map_location="cpu") + pipe.unet.load_state_dict(state_dict) + print("[StreamDiffusion] Distilled UNet weights loaded") return pipe - # Other preset shapes (LoRA fuse, etc.) land here once supported. - raise NotImplementedError( - f"MODEL_PRESETS recipe shape not yet implemented: {preset}" - ) + return pipe def _swap_model(self, new_model_id: str) -> None: """Replace the loaded model in place. From 1f676a4bddc5a6b66079cb480fb4df4be78eaa76 Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Tue, 5 May 2026 16:24:41 -0700 Subject: [PATCH 16/26] fix(dmd2): pin training timestep [399] via timesteps_override MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DMD2-1step is distilled at a single specific timestep. Letting LCMScheduler pick the default 1-step (~979, near pure-noise endpoint) feeds the model a timestep it was never trained on and produces garbage — visually a blurry monochrome blob with no recognizable features. Add a `timesteps_override` field to MODEL_PRESETS and have `_set_timesteps` honor it when present. With the override pinned at [399] (the DMD2 paper's documented training timestep for SDXL 1-step), the model produces clean photographic output: a recognizable sun / moon with proper composition, contrast, and detail. Same mechanism will land Hyper-SDXL-1step (timesteps=[800]) once the broader scheduler-class refactor on feat/scheduler-refactor catches up; this commit just gets DMD2 to a usable state. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/scope_streamdiffusion/pipeline.py | 33 +++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index e9e0bfa..f0d5255 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -47,6 +47,9 @@ # tianweiy/DMD2 ships several distilled UNet checkpoints; the # 1-step fp16 variant is the SDXL-Turbo equivalent. "unet_swap": ("tianweiy/DMD2", "dmd2_sdxl_1step_unet_fp16.bin"), + # DMD2 was distilled at this specific timestep — feeding it + # LCMScheduler's default 1-step pick (~979) produces noise. + "timesteps_override": [399], }, } @@ -118,6 +121,7 @@ def __init__( # Load the base model print(f"Loading model: {model_id}") self.model_id = model_id + self._timesteps_override = MODEL_PRESETS.get(model_id, {}).get("timesteps_override") self.pipe = self._load_model(model_id) print(f"Model loaded: {self.pipe.__class__.__name__}") @@ -603,6 +607,7 @@ def _swap_model(self, new_model_id: str) -> None: torch.cuda.empty_cache() self.model_id = new_model_id + self._timesteps_override = MODEL_PRESETS.get(new_model_id, {}).get("timesteps_override") self.pipe = self._load_model(new_model_id) print(f"[StreamDiffusion] Model loaded: {self.pipe.__class__.__name__}") self.sdxl = type(self.pipe) is StableDiffusionXLPipeline @@ -1143,11 +1148,29 @@ def _encode_prompts_array( ), blended_pooled_embeds def _set_timesteps(self, num_inference_steps: int, strength: float): - """Set the timesteps for the diffusion process.""" - self.scheduler.set_timesteps( - num_inference_steps, self.device, strength=strength - ) - self.timesteps = self.scheduler.timesteps.to(self.device) + """Set the timesteps for the diffusion process. + + Honors `MODEL_PRESETS[...]["timesteps_override"]` when present. + Distilled 1-step models (DMD2, Hyper-SD, Lightning) are trained at + a specific timestep and produce garbage at any other one — letting + LCMScheduler pick the default would feed them ~t=979 (near max + noise) where they were never trained. + """ + if self._timesteps_override is not None: + # Pin the override; still call set_timesteps so the scheduler + # internals (timestep_scaling, etc.) are populated for any + # downstream lookups. + self.scheduler.set_timesteps( + num_inference_steps, self.device, strength=strength + ) + self.timesteps = torch.tensor( + self._timesteps_override, device=self.device, dtype=torch.long + ) + else: + self.scheduler.set_timesteps( + num_inference_steps, self.device, strength=strength + ) + self.timesteps = self.scheduler.timesteps.to(self.device) # Make sub timesteps list self.sub_timesteps = [] From 8763277294e8b9484090e7b1db677a9cddce9173 Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Tue, 5 May 2026 22:07:28 -0700 Subject: [PATCH 17/26] feat(trt): SDXL UNet engine support (Turbo + DMD2 verified) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the missing SDXL-shaped TRT path so acceleration_mode='trt' works on SDXL-Turbo and DMD2-distilled UNets. Eager-only on SDXL was a pre-existing limitation: ONNX export crashed in get_aug_embed because the export wrapper passed added_cond_kwargs=None instead of the SDXL {text_embeds, time_ids} dict. End-to-end this commit: * UNetSDXL I/O spec — 5 inputs (sample, timestep, encoder_hidden_states, text_embeds dim=1280, time_ids dim=6) instead of SD 1.5's 3. * UNetSDXLExportWrapper — wraps the diffusers UNet so text_embeds/time_ids are positional args for ONNX trace, reconstructed into added_cond_kwargs at the inner forward. * UNet2DConditionModelSDXLEngine — runtime engine wrapper feeding all 5 named inputs to the TRT context. * compile_unet_sdxl — same shape as compile_unet but routes through the SDXL wrapper. Skips the polygraphy ONNX optimizer (passes the same path twice for raw + "opt") because polygraphy's optimizer OOMs on the ~5 GB SDXL ONNX; TRT's builder does its own graph optimization. * export_onnx — adds use_external_data flag (torch 2.9 param `external_data`) so SDXL UNet's >2 GB ONNX serializes correctly. Post-processes the raw export to consolidate ~1500 per-tensor sidecar files into one weights.bin: pytorch's per-tensor location-only entries trip TRT's WeightsContextMemoryMap on certain initializers ("Failed to open file"). * build_unet_sdxl_engine + TRTUNetSDXLAdapter — build/load. Engine is static-shape (build_dynamic_shape=False) and static-batch (max=1). SDXL's tactic exploration over a dynamic shape envelope OOMs even on 24 GB VRAM; static-shape collapses the search space enough to fit. Engine is only valid at the (h,w) it was built for — resolution changes will rebuild. * _ConfigShim — gains an `sdxl=True` mode returning the SDXL cross_attention_dim=2048 and addition_time_embed_dim=256 the pipeline reads to size add_time_ids. TRTUNetSDXLAdapter also fakes an `add_embedding.linear_1.in_features=2816` shim because the SDXL pipeline introspects that attribute on UNet. * pipeline._ensure_trt_unet — accepts explicit image_height/width args. Static engines need the *real* runtime dims at build time; self.height/self.width are still init defaults (512x512) when this method runs because _prepare_runtime_state hasn't executed yet. Pre-emptively setting self.{height,width} would block dims_changed in _prepare_runtime_state and leave self.latent_{height,width} at init defaults — engine and inference would mismatch in the other direction. * SDXL build flow moves VAE + text encoders to CPU during the TRT build to free VRAM for the builder's TACTIC_DRAM allocation, then moves them back. UNet stays on GPU (the ONNX tracer needs it there). Verified end-to-end on a 4090: - SDXL-Turbo @ 1024x1024: 91 ms/frame eager → 11 ms/frame TRT (8.3x) - DMD2-SDXL-1step @ 1024x1024: 91 ms/frame eager → 11 ms/frame TRT (8.3x) - Output is byte-different but visually equivalent to eager, confirming correct numerical behavior. Build-time prerequisites the wheel install model alone doesn't satisfy (documented in trt_engines.py header): - LD_LIBRARY_PATH must include the venv's tensorrt_libs at process exec time. Loader's lazy dlopen of the per-SM kernel library (libnvinfer_builder_resource_smXX.so.10.x) bypasses ldconfig because those libs have a do_not_link_against_* SONAME, so cache lookup by filename fails. ctypes preload from inside Python is too late — the dynamic linker reads LD_LIBRARY_PATH at exec time only. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/scope_streamdiffusion/_trt/__init__.py | 42 ++++++ src/scope_streamdiffusion/_trt/builder.py | 4 +- src/scope_streamdiffusion/_trt/engine.py | 54 ++++++++ src/scope_streamdiffusion/_trt/models.py | 135 ++++++++++++++++++ src/scope_streamdiffusion/_trt/utilities.py | 49 +++++++ src/scope_streamdiffusion/pipeline.py | 95 ++++++++++++- src/scope_streamdiffusion/trt_engines.py | 145 +++++++++++++++++++- 7 files changed, 512 insertions(+), 12 deletions(-) diff --git a/src/scope_streamdiffusion/_trt/__init__.py b/src/scope_streamdiffusion/_trt/__init__.py index 16c7d3e..725b128 100644 --- a/src/scope_streamdiffusion/_trt/__init__.py +++ b/src/scope_streamdiffusion/_trt/__init__.py @@ -24,6 +24,7 @@ AutoencoderKLEngine, ControlNetEngine, UNet2DConditionModelEngine, + UNet2DConditionModelSDXLEngine, UNet2DConditionModelWithControlEngine, ) from .models import ( @@ -33,6 +34,8 @@ ControlNetExportWrapper, UNet, UNetExportWrapperWithControl, + UNetSDXL, + UNetSDXLExportWrapper, UNetWithControlInputs, VAEEncoder, ) @@ -109,14 +112,18 @@ def compile_unet( "TorchVAEEncoder", "UNet", "UNet2DConditionModelEngine", + "UNet2DConditionModelSDXLEngine", "UNet2DConditionModelWithControlEngine", "UNetExportWrapperWithControl", + "UNetSDXL", + "UNetSDXLExportWrapper", "UNetWithControlInputs", "VAE", "VAEEncoder", "build_engine", "compile_controlnet", "compile_unet", + "compile_unet_sdxl", "compile_unet_with_control", "compile_vae_decoder", "compile_vae_encoder", @@ -126,6 +133,41 @@ def compile_unet( ] +def compile_unet_sdxl( + unet: UNet2DConditionModel, + model_data: BaseModel, + onnx_path: str, + onnx_opt_path: str, # noqa: ARG001 — kept for API symmetry; SDXL skips the polygraphy optimizer + engine_path: str, + opt_batch_size: int = 1, + engine_build_options: dict = {}, +): + """Compile an SDXL UNet to TRT — wraps text_embeds/time_ids as positional inputs. + + Differs from `compile_unet` in one important way: **the polygraphy + ONNX optimization pass is skipped**. SDXL's UNet exports to a ~5 GB + ONNX file, and the polygraphy `optimize_onnx` step runs onnxruntime + shape-inference + Unsqueeze elimination passes that load the entire + graph into RAM with ~3-5× overhead — peaks at 20-25 GB and OOMs on + a 32 GB host. TensorRT's builder does its own graph optimization + during engine construction, so the pre-pass is double-work anyway. + + Mechanism: pass the same path for both raw and "optimized" ONNX. + After export the file exists at `onnx_opt_path`, so the EngineBuilder + skips the optimize step and feeds the raw ONNX directly to TRT. + """ + wrapped = UNetSDXLExportWrapper(unet).to( + torch.device("cuda"), dtype=torch.float16 + ).eval() + builder = EngineBuilder(model_data, wrapped, device=torch.device("cuda")) + builder.build( + onnx_path, onnx_path, engine_path, # same path twice: skip polygraphy optimizer + opt_batch_size=opt_batch_size, + use_external_data=True, # SDXL UNet fp16 is ~2.6 GB → must use external-data format + **engine_build_options, + ) + + def compile_unet_with_control( unet, model_data: BaseModel, diff --git a/src/scope_streamdiffusion/_trt/builder.py b/src/scope_streamdiffusion/_trt/builder.py index b2e7154..5c78c33 100644 --- a/src/scope_streamdiffusion/_trt/builder.py +++ b/src/scope_streamdiffusion/_trt/builder.py @@ -47,6 +47,7 @@ def build( force_engine_build: bool = False, force_onnx_export: bool = False, force_onnx_optimize: bool = False, + use_external_data: bool = False, ): if not force_onnx_export and os.path.exists(onnx_path): print(f"Found cached model: {onnx_path}") @@ -58,7 +59,7 @@ def build( self.network, self.controlnet_model ) - + export_onnx( self.network, onnx_path=onnx_path, @@ -67,6 +68,7 @@ def build( opt_image_width=opt_image_width, opt_batch_size=opt_batch_size, onnx_opset=onnx_opset, + use_external_data=use_external_data, ) del self.network gc.collect() diff --git a/src/scope_streamdiffusion/_trt/engine.py b/src/scope_streamdiffusion/_trt/engine.py index 20a4f6d..375c6cf 100644 --- a/src/scope_streamdiffusion/_trt/engine.py +++ b/src/scope_streamdiffusion/_trt/engine.py @@ -56,6 +56,60 @@ def forward(self, *args, **kwargs): pass +class UNet2DConditionModelSDXLEngine: + """SDXL UNet engine — adds text_embeds + time_ids inputs to the plain UNet engine.""" + + def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False): + self.engine = Engine(filepath) + self.stream = stream + self.use_cuda_graph = use_cuda_graph + self.engine.load() + self.engine.activate() + + def __call__( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + text_embeds: torch.Tensor, + time_ids: torch.Tensor, + **kwargs, + ) -> Any: + if timestep.dtype != torch.float32: + timestep = timestep.float() + + self.engine.allocate_buffers( + shape_dict={ + "sample": latent_model_input.shape, + "timestep": timestep.shape, + "encoder_hidden_states": encoder_hidden_states.shape, + "text_embeds": text_embeds.shape, + "time_ids": time_ids.shape, + "latent": latent_model_input.shape, + }, + device=latent_model_input.device, + ) + + noise_pred = self.engine.infer( + { + "sample": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "text_embeds": text_embeds, + "time_ids": time_ids, + }, + self.stream, + use_cuda_graph=self.use_cuda_graph, + )["latent"] + return UNet2DConditionOutput(sample=noise_pred) + + def to(self, *args, **kwargs): + pass + + def forward(self, *args, **kwargs): + pass + + class UNet2DConditionModelWithControlEngine: """UNet engine variant that accepts ControlNet residuals as runtime inputs. diff --git a/src/scope_streamdiffusion/_trt/models.py b/src/scope_streamdiffusion/_trt/models.py index 32a8692..693f15f 100644 --- a/src/scope_streamdiffusion/_trt/models.py +++ b/src/scope_streamdiffusion/_trt/models.py @@ -484,6 +484,141 @@ def get_sample_input(self, batch_size, image_height, image_width): ) +class UNetSDXL(BaseModel): + """TRT I/O spec for the SDXL UNet. + + Adds the SDXL-specific aug-conditioning inputs that the SD1.5 UNet + doesn't have: + - text_embeds: pooled output of the 2nd text encoder (dim=1280) + - time_ids: resolution conditioning (dim=6 = orig_size[2] + + crops_top_left[2] + target_size[2]) + + Without these, SDXL UNet's `get_aug_embed` raises `TypeError: argument + of type 'NoneType' is not iterable` because it expects a dict at + `added_cond_kwargs`. + + SDXL standard config: + cross_attention_dim = 2048 + addition_time_embed_dim = 256 + projection_class_embeddings_input_dim = 2816 (= 1280 + 256*6) + """ + + def __init__( + self, + fp16=False, + device="cuda", + max_batch_size=16, + min_batch_size=1, + embedding_dim=2048, + text_maxlen=77, + unet_dim=4, + text_embeds_dim=1280, + time_ids_dim=6, + ): + super(UNetSDXL, self).__init__( + fp16=fp16, + device=device, + max_batch_size=max_batch_size, + min_batch_size=min_batch_size, + embedding_dim=embedding_dim, + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + self.text_embeds_dim = text_embeds_dim + self.time_ids_dim = time_ids_dim + self.name = "UNetSDXL" + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states", "text_embeds", "time_ids"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "timestep": {0: "2B"}, + "encoder_hidden_states": {0: "2B"}, + "text_embeds": {0: "2B"}, + "time_ids": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, max_batch, _, _, _, _, + min_latent_height, max_latent_height, min_latent_width, max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + return { + "sample": [ + (min_batch, self.unet_dim, min_latent_height, min_latent_width), + (batch_size, self.unet_dim, latent_height, latent_width), + (max_batch, self.unet_dim, max_latent_height, max_latent_width), + ], + "timestep": [(min_batch,), (batch_size,), (max_batch,)], + "encoder_hidden_states": [ + (min_batch, self.text_maxlen, self.embedding_dim), + (batch_size, self.text_maxlen, self.embedding_dim), + (max_batch, self.text_maxlen, self.embedding_dim), + ], + "text_embeds": [ + (min_batch, self.text_embeds_dim), + (batch_size, self.text_embeds_dim), + (max_batch, self.text_embeds_dim), + ], + "time_ids": [ + (min_batch, self.time_ids_dim), + (batch_size, self.time_ids_dim), + (max_batch, self.time_ids_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "timestep": (2 * batch_size,), + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "text_embeds": (2 * batch_size, self.text_embeds_dim), + "time_ids": (2 * batch_size, self.time_ids_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn(2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device), + torch.ones((2 * batch_size,), dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + torch.randn(2 * batch_size, self.text_embeds_dim, dtype=dtype, device=self.device), + torch.randn(2 * batch_size, self.time_ids_dim, dtype=dtype, device=self.device), + ) + + +class UNetSDXLExportWrapper(torch.nn.Module): + """Wraps the SDXL UNet so text_embeds/time_ids are positional args. + + Diffusers' UNet expects them inside an `added_cond_kwargs` dict, but + ONNX export prefers positional inputs. Reconstruct the dict here. + """ + + def __init__(self, unet): + super().__init__() + self.unet = unet + + def forward(self, sample, timestep, encoder_hidden_states, text_embeds, time_ids): + out = self.unet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids}, + return_dict=False, + ) + return out[0] + + class UNetExportWrapperWithControl(torch.nn.Module): """Wraps the diffusers UNet so ControlNet residuals are positional inputs. diff --git a/src/scope_streamdiffusion/_trt/utilities.py b/src/scope_streamdiffusion/_trt/utilities.py index 6e82bbc..04997e4 100644 --- a/src/scope_streamdiffusion/_trt/utilities.py +++ b/src/scope_streamdiffusion/_trt/utilities.py @@ -409,12 +409,18 @@ def export_onnx( opt_image_width: int, opt_batch_size: int, onnx_opset: int, + use_external_data: bool = False, ): with torch.inference_mode(), torch.autocast("cuda"): inputs = model_data.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) print('exporting onnx') print(model_data.get_input_names()) print(model_data.get_dynamic_axes()) + # use_external_data: SDXL UNet fp16 is ~2.6 GB, exceeding protobuf's + # 2 GB single-message limit. Without external-data format, the export + # produces malformed ONNX that TRT's parser rejects with + # `Invalid Engine`. SD 1.5 fits in 2 GB so we leave the default off + # for that path to keep the existing single-file cache layout. torch.onnx.export( model, inputs, @@ -428,11 +434,54 @@ def export_onnx( dynamo=False, # force legacy trace-based exporter — dynamo path # produces ONNX with op variants that polygraphy's version_converter # can't migrate (e.g. Resize). Legacy is what prism was tested on. + external_data=use_external_data, # torch 2.9+ name (was use_external_data_format) ) del model gc.collect() torch.cuda.empty_cache() + if use_external_data: + # PyTorch's exporter writes one external file per tensor with only a + # `location` field — no explicit `offset` / `length`. Some tensors + # then trip TRT's parser: + # [E] WeightsContextMemoryMap.cpp:124: Failed to open file: ... + # Consolidate into a single weights.bin with explicit offsets so + # TRT's mmap path sees the canonical layout. Bonus: cache dir goes + # from ~1500 files to 2. + import os, shutil + onnx_dir = os.path.dirname(onnx_path) + weights_name = os.path.basename(onnx_path) + ".weights" + m = onnx.load(onnx_path, load_external_data=True) + # Stage rewrite into a sibling temp dir so we can clear the original + # sidecars cleanly without colliding with the in-progress write. + staging = os.path.join(onnx_dir, "_consolidate_staging") + if os.path.isdir(staging): + shutil.rmtree(staging) + os.makedirs(staging) + staged_onnx = os.path.join(staging, os.path.basename(onnx_path)) + onnx.save_model( + m, staged_onnx, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=weights_name, + size_threshold=1024, + convert_attribute=False, + ) + # Remove the per-tensor sidecar files (everything in onnx_dir except + # the master .onnx and the new weights file we're about to move in). + for entry in os.listdir(onnx_dir): + full = os.path.join(onnx_dir, entry) + if full == staging or full == onnx_path: + continue + if os.path.isfile(full): + os.unlink(full) + # Move staged master + weights into place, then drop staging. + os.replace(staged_onnx, onnx_path) + os.replace(os.path.join(staging, weights_name), os.path.join(onnx_dir, weights_name)) + shutil.rmtree(staging) + del m + gc.collect() + def optimize_onnx( onnx_path: str, diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index f0d5255..94445e8 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -398,7 +398,12 @@ def _ensure_trt_controlnet(self, mode: str) -> None: cache_state.cn_adapters[mode] = adapter print(f"[TRT] ControlNet engine active ({mode}): {engine_path}", flush=True) - def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: + def _ensure_trt_unet( + self, + controlnet_mode: str = "none", + image_height: int | None = None, + image_width: int | None = None, + ) -> None: """Build TRT engine for the UNet and swap self.unet to the adapter. Two variants depending on controlnet_mode: @@ -409,6 +414,12 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: residuals get silently dropped and ControlNet conditioning has no effect on the output. + ``image_height`` / ``image_width`` should be the runtime spatial + dims for this build. Falls back to ``self.height`` / ``self.width`` + when omitted, but caller should pass them explicitly because + ``_prepare_runtime_state`` (which sets ``self.{height,width}``) + normally runs *after* this method. + Engines are cached separately on disk because they have different signatures. Switching modes mid-process may trigger a rebuild. """ @@ -423,7 +434,9 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: f"(had={self._trt_unet_has_controlnet}, want={want_control}); rebuilding" ) - signature = (self.model_id, int(self.height), int(self.width)) + eff_h = int(image_height if image_height is not None else self.height) + eff_w = int(image_width if image_width is not None else self.width) + signature = (self.model_id, eff_h, eff_w) cache_state, restored = _trt_cache.get_or_create(self._trt_cache_key, signature) if self._trt_cuda_stream is None and cache_state.cuda_stream is not None: self._trt_cuda_stream = cache_state.cuda_stream @@ -465,8 +478,10 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: from .trt_engines import ( TRTUNetAdapter, + TRTUNetSDXLAdapter, TRTUNetWithControlAdapter, build_unet_engine, + build_unet_sdxl_engine, build_unet_with_control_engine, make_cuda_stream, ) @@ -476,6 +491,17 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: cache_state.cuda_stream = self._trt_cuda_stream if want_control: + if self.sdxl: + # SDXL + ControlNet + TRT: not yet wired. The ControlNet + # path uses UNetWithControlInputs which assumes SD1.5 + # signature (no text_embeds/time_ids). Falling through to + # eager keeps SDXL+ControlNet working until that variant + # gets the same SDXL aug-conditioning treatment as + # build_unet_sdxl_engine. + raise NotImplementedError( + "SDXL + ControlNet + TRT not yet supported. Use " + "acceleration_mode='none' with controlnet on SDXL models." + ) print( "[TRT] Preparing UNet+ctrl engine — first build takes 5-10 min, cached after", flush=True, @@ -483,8 +509,8 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: engine_path = build_unet_with_control_engine( self.pipe.unet, model_id=self.model_id, - image_height=int(self.height), - image_width=int(self.width), + image_height=eff_h, + image_width=eff_w, min_batch_size=1, max_batch_size=4, num_down_residuals=12, @@ -496,6 +522,50 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: cache_state.unet_adapter = self.unet cache_state.unet_has_controlnet = True print(f"[TRT] UNet+ctrl engine active: {engine_path}", flush=True) + elif self.sdxl: + print( + "[TRT] Preparing SDXL UNet engine — first build takes 5-10 min, cached after", + flush=True, + ) + # TRT's builder TACTIC_DRAM allocator can race our resident + # pipeline allocations during engine build. Free what we can + # (VAE + text encoders, ~5 GB combined) without disturbing + # the UNet — the ONNX tracer needs it on GPU. Restore after + # the build completes. + print("[TRT] Moving VAE + text encoders to CPU during build", flush=True) + cpu_components = [] + for attr in ("vae", "text_encoder", "text_encoder_2"): + comp = getattr(self.pipe, attr, None) + if comp is not None and hasattr(comp, "to"): + try: + comp.to("cpu") + cpu_components.append((attr, comp)) + except Exception as e: + print(f"[TRT] could not move {attr} to CPU: {e}", flush=True) + torch.cuda.empty_cache() + # batch=1 + static shape (set in build_unet_sdxl_engine) make + # TRT's tactic search bounded enough to fit on a 24 GB card. + # Engine is only valid at the (height, width, batch=1) profile + # it was built for; resolution changes will trigger a rebuild. + engine_path = build_unet_sdxl_engine( + self.pipe.unet, + model_id=self.model_id, + image_height=eff_h, + image_width=eff_w, + min_batch_size=1, + max_batch_size=1, + ) + print("[TRT] Restoring VAE + text encoders to GPU", flush=True) + for attr, comp in cpu_components: + try: + comp.to(self.device) + except Exception as e: + print(f"[TRT] could not restore {attr} to {self.device}: {e}", flush=True) + self._trt_eager_unet = self.pipe.unet + self.unet = TRTUNetSDXLAdapter(engine_path, self._trt_cuda_stream) + cache_state.unet_adapter = self.unet + cache_state.unet_has_controlnet = False + print(f"[TRT] SDXL UNet engine active: {engine_path}", flush=True) else: print( "[TRT] Preparing UNet engine — first build takes 5-10 min, cached after", @@ -504,8 +574,8 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: engine_path = build_unet_engine( self.pipe.unet, model_id=self.model_id, - image_height=int(self.height), - image_width=int(self.width), + image_height=eff_h, + image_width=eff_w, min_batch_size=1, max_batch_size=4, ) @@ -1639,9 +1709,20 @@ def get_param(key, default): # TRT engine swap — UNet always, ControlNet additionally when active. # Two separate engines (each <2 GB ONNX) instead of a single combined # graph that hits TRT's cask-convolution bug. + # Pass runtime dims explicitly: _prepare_runtime_state (which sets + # self.height/self.width and the latent dims) runs *after* this + # block. Static-shape SDXL engines need the real runtime dims at + # build time, otherwise they're sized for the __init__ defaults + # (512x512) and mismatch at inference. Setting self.{height,width} + # preemptively here is wrong — it'd block dims_changed in + # _prepare_runtime_state, leaving self.latent_{height,width} stale. if acceleration_mode == "trt": try: - self._ensure_trt_unet(controlnet_mode) + self._ensure_trt_unet( + controlnet_mode, + image_height=int(height), + image_width=int(width), + ) except Exception as e: print(f"[TRT] UNet engine swap failed, falling back to eager: {e}") import traceback diff --git a/src/scope_streamdiffusion/trt_engines.py b/src/scope_streamdiffusion/trt_engines.py index 74d81d7..aedf8ff 100644 --- a/src/scope_streamdiffusion/trt_engines.py +++ b/src/scope_streamdiffusion/trt_engines.py @@ -369,6 +369,133 @@ def build_unet_engine( return engine_path +def build_unet_sdxl_engine( + unet: UNet2DConditionModel, + *, + model_id: str, + image_height: int = 1024, + image_width: int = 1024, + min_batch_size: int = 1, + max_batch_size: int = 4, +) -> Path: + """Build (or reuse) a TRT engine for an SDXL UNet. + + Differs from `build_unet_engine` only in the I/O spec — adds + `text_embeds` and `time_ids` as engine inputs so SDXL's `get_aug_embed` + has the kwargs it expects. Without these the ONNX export crashes with + `TypeError: argument of type 'NoneType' is not iterable`. + """ + from ._trt import UNetSDXL, compile_unet_sdxl, create_onnx_path + + suffix = f"unet_sdxl_b{min_batch_size}-{max_batch_size}_h{image_height}_w{image_width}" + cache_dir = _model_cache_dir(model_id, suffix) + onnx_dir = cache_dir / "onnx" + onnx_dir.mkdir(parents=True, exist_ok=True) + engine_path = cache_dir / "unet_sdxl.engine" + + if engine_path.exists(): + logger.info(f"[TRT] Reusing cached SDXL UNet engine: {engine_path}") + return engine_path + + logger.info(f"[TRT] Building SDXL UNet engine -> {engine_path} (5-10 min on first build)") + + unet_model = UNetSDXL( + fp16=True, + device=str(unet.device) if unet.device.type != "meta" else "cuda", + max_batch_size=max_batch_size, + min_batch_size=min_batch_size, + embedding_dim=unet.config.cross_attention_dim, + unet_dim=unet.config.in_channels, + ) + # SDXL UNet build is memory-heavy; locking to static shape and static + # batch keeps TRT's tactic exploration bounded and the build fits in + # VRAM. Loses dynamic-shape flexibility (engine only valid at one + # resolution) but unblocks getting an engine produced at all. Dynamic + # shapes are a follow-up once the static path is verified. + compile_unet_sdxl( + unet, unet_model, + str(create_onnx_path("unet_sdxl", str(onnx_dir), opt=False)), + str(create_onnx_path("unet_sdxl", str(onnx_dir), opt=True)), + str(engine_path), + opt_batch_size=min_batch_size, + engine_build_options={ + "build_dynamic_shape": False, + "build_static_batch": True, + # Static shape engines are valid only at the (h,w) they were + # built for. Pass the actual runtime dims so the engine + # matches what inference will request. + "opt_image_height": image_height, + "opt_image_width": image_width, + }, + ) + import shutil + if onnx_dir.exists(): + shutil.rmtree(onnx_dir, ignore_errors=True) + logger.info(f"[TRT] SDXL UNet engine built: {engine_path}") + return engine_path + + +class TRTUNetSDXLAdapter: + """Drop-in for diffusers SDXL UNet — accepts added_cond_kwargs.""" + + def __init__(self, engine_path: Path, cuda_stream, *, use_cuda_graph: bool = False): + from ._trt import UNet2DConditionModelSDXLEngine + self.engine = UNet2DConditionModelSDXLEngine( + str(engine_path), cuda_stream, use_cuda_graph=use_cuda_graph, + ) + self._use_cuda_graph = use_cuda_graph + self.config = _ConfigShim(sdxl=True) + # SDXL pipelines read `unet.add_embedding.linear_1.in_features` to + # size the add_time_ids tensor (must equal the original UNet's + # projection_class_embeddings_input_dim — 2816 for stock SDXL = + # 1280 text_embeds + 6 * 256 addition_time_embed_dim). + class _AddEmbeddingShim: + class _Linear1Shim: + in_features = 2816 + linear_1 = _Linear1Shim() + self.add_embedding = _AddEmbeddingShim() + + def __call__( + self, + sample: torch.Tensor, + timestep, + encoder_hidden_states: torch.Tensor, + added_cond_kwargs: dict | None = None, + return_dict: bool = True, + **kwargs, + ): + if not isinstance(timestep, torch.Tensor): + timestep = torch.tensor(timestep, device=sample.device) + if timestep.ndim == 0: + timestep = timestep.unsqueeze(0) + + if ( + added_cond_kwargs is None + or "text_embeds" not in added_cond_kwargs + or "time_ids" not in added_cond_kwargs + ): + raise RuntimeError( + "TRTUNetSDXLAdapter requires added_cond_kwargs with 'text_embeds' and 'time_ids'." + ) + + out = self.engine( + latent_model_input=sample, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + text_embeds=added_cond_kwargs["text_embeds"], + time_ids=added_cond_kwargs["time_ids"], + ) + if return_dict: + return out + return (out.sample,) + + def to(self, *args, **kwargs): + return self + + def eval(self): + return self + + class TRTUNetAdapter: """Thin wrapper for the vendored UNet engine. @@ -435,13 +562,23 @@ def eval(self): class _ConfigShim: - """Diffusers config object surface — read by pipeline._prepare_runtime_state.""" + """Diffusers config object surface — read by pipeline._prepare_runtime_state. + + Defaults match SD 1.5/2.1: addition_time_embed_dim=None (no aug + conditioning), cross_attention_dim=1024. Pass `sdxl=True` for SDXL + where the pipeline reads addition_time_embed_dim to size add_time_ids + and cross_attention_dim for the encoder hidden state. + """ - def __init__(self): - self.addition_time_embed_dim = None + def __init__(self, sdxl: bool = False): + if sdxl: + self.addition_time_embed_dim = 256 # SDXL standard + self.cross_attention_dim = 2048 + else: + self.addition_time_embed_dim = None + self.cross_attention_dim = 1024 self.in_channels = 4 self.out_channels = 4 - self.cross_attention_dim = 1024 def build_taesd_engines( From c56a5817702d1e80e999cfcd349f1c4e77e7a31d Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Wed, 6 May 2026 07:41:43 -0700 Subject: [PATCH 18/26] feat(trt): SDXL UNet dynamic shape over 512-1024 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Static-shape engines locked the build to a single (h, w) — any resolution or aspect-ratio change required a 5–10 min rebuild. Replaced with a dynamic-shape build over the [512, 1024] envelope on both axes. Same cached engine now serves any in-range resolution. Verified end-to-end: a single cached engine handles 1024x1024, 1024x768 (landscape), and 768x1024 (portrait) without a rebuild. Composition adapts to the aspect (wide horizon vs. tall cloud column). Trade-offs vs. static-shape: - Steady-state at the opt point (1024x1024): 11 ms/frame → 14 ms/frame. ~27% slowdown for the flexibility, expected. - Build memory: 512-1024 envelope on a 24 GB card with VAE+text-encoders on CPU during build → fits cleanly. Wider envelopes (256-1024) blew past the budget; 512-1024 is the practical sweet spot. - Engine size: ~5.2 GB on disk (similar to static). Cache key now encodes the resolution range (`unet_sdxl_b1-1_h512-1024_w512-1024`) instead of the opt point, so engines don't collide across resolution choices and any in-range run hits the same cached file. Static batch (max=1) is kept — guidance_scale=0 is the only mode for Turbo / DMD2, so dynamic batch would just double workspace cost for no inference benefit. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/scope_streamdiffusion/trt_engines.py | 36 ++++++++++++++++-------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/src/scope_streamdiffusion/trt_engines.py b/src/scope_streamdiffusion/trt_engines.py index aedf8ff..5ed74db 100644 --- a/src/scope_streamdiffusion/trt_engines.py +++ b/src/scope_streamdiffusion/trt_engines.py @@ -376,7 +376,9 @@ def build_unet_sdxl_engine( image_height: int = 1024, image_width: int = 1024, min_batch_size: int = 1, - max_batch_size: int = 4, + max_batch_size: int = 1, + min_image_resolution: int = 512, + max_image_resolution: int = 1024, ) -> Path: """Build (or reuse) a TRT engine for an SDXL UNet. @@ -384,10 +386,26 @@ def build_unet_sdxl_engine( `text_embeds` and `time_ids` as engine inputs so SDXL's `get_aug_embed` has the kwargs it expects. Without these the ONNX export crashes with `TypeError: argument of type 'NoneType' is not iterable`. + + Dynamic-shape build over [min_image_resolution, max_image_resolution] + on both axes — runtime can pick any resolution in that range without + triggering a rebuild. The opt point (image_height, image_width) is + where TRT's tactic selection is centered; runs at the opt size are + fastest, runs at min/max get slightly suboptimal tactics. + + Default range 512–1024 is the tightest envelope that covers SDXL's + sweet spot. Wider ranges (256–1024) blow past the builder's memory + budget on 24 GB cards. Static batch (max=1) is kept because + guidance_scale=0 (default for Turbo / DMD2) means inference never + uses batch>1; allowing batch>1 doubles the workspace. """ from ._trt import UNetSDXL, compile_unet_sdxl, create_onnx_path - suffix = f"unet_sdxl_b{min_batch_size}-{max_batch_size}_h{image_height}_w{image_width}" + suffix = ( + f"unet_sdxl_b{min_batch_size}-{max_batch_size}_" + f"h{min_image_resolution}-{max_image_resolution}_" + f"w{min_image_resolution}-{max_image_resolution}" + ) cache_dir = _model_cache_dir(model_id, suffix) onnx_dir = cache_dir / "onnx" onnx_dir.mkdir(parents=True, exist_ok=True) @@ -407,11 +425,6 @@ def build_unet_sdxl_engine( embedding_dim=unet.config.cross_attention_dim, unet_dim=unet.config.in_channels, ) - # SDXL UNet build is memory-heavy; locking to static shape and static - # batch keeps TRT's tactic exploration bounded and the build fits in - # VRAM. Loses dynamic-shape flexibility (engine only valid at one - # resolution) but unblocks getting an engine produced at all. Dynamic - # shapes are a follow-up once the static path is verified. compile_unet_sdxl( unet, unet_model, str(create_onnx_path("unet_sdxl", str(onnx_dir), opt=False)), @@ -419,13 +432,14 @@ def build_unet_sdxl_engine( str(engine_path), opt_batch_size=min_batch_size, engine_build_options={ - "build_dynamic_shape": False, + "build_dynamic_shape": True, "build_static_batch": True, - # Static shape engines are valid only at the (h,w) they were - # built for. Pass the actual runtime dims so the engine - # matches what inference will request. + # opt point — TRT's tactic selection is centered here. "opt_image_height": image_height, "opt_image_width": image_width, + # Min/max bounds for the dynamic shape envelope. + "min_image_resolution": min_image_resolution, + "max_image_resolution": max_image_resolution, }, ) import shutil From 12609ddba0b8746d14a0a86ebd3056daf7899618 Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Wed, 6 May 2026 08:16:25 -0700 Subject: [PATCH 19/26] chore: refresh PR docs to reflect expanded scope - Update acceleration_mode description: actual measured speedup is 2-8x (was "~2-3x"), and SDXL engines have a different envelope (512-1024, batch=1) than SD 1.5 (256-1024, batch 1-4) due to the 24 GB build budget. Also call out the SDXL + ControlNet + TRT NotImplementedError so users hit it via doc rather than runtime surprise. - Remove HANDOFF_TURBO_ONLY.md. The PR scope expanded well past "Turbo-only simplification": now covers DMD2 preset, scheduler timestep override, full SDXL TRT path with dynamic shape. Earlier handoff text is misleading. Co-Authored-By: Claude Opus 4.7 (1M context) --- HANDOFF_TURBO_ONLY.md | 66 ----------------------------- src/scope_streamdiffusion/schema.py | 15 ++++--- 2 files changed, 10 insertions(+), 71 deletions(-) delete mode 100644 HANDOFF_TURBO_ONLY.md diff --git a/HANDOFF_TURBO_ONLY.md b/HANDOFF_TURBO_ONLY.md deleted file mode 100644 index f612337..0000000 --- a/HANDOFF_TURBO_ONLY.md +++ /dev/null @@ -1,66 +0,0 @@ -# Handoff — Turbo-only simplification - -PR: https://github.com/happyFish/scope-stream_diffusion_v1/pull/2 (`sd-multi-model` → `main`). - -## What landed - -1. **Merge of `main` into `sd-multi-model`** (commit `d8d2fbd`) — brings the - full TRT subsystem (refittable engines, node-id-keyed adapter cache, fp16 - VAE TRT path) under the multi-model dropdown. -2. **`feat(schema)` (`29f1ea7`)** — `model_id_or_path` enum trimmed to - `stabilityai/sd-turbo` + `stabilityai/sdxl-turbo`. `num_inference_steps` - and `use_suggested_num_inference_steps` fields removed. -3. **`refactor(pipeline)` (`1015f4b`)** — drops `self.sd_turbo`, - `_attach_lcm_lora`, `_predict_x0_serial`, the `use_serial` branch in - `__call__`, and the dead `denoising_steps_num > 1` branches. Step count - is hardcoded to 1 in `__call__`. - -Net: +9 / -188 LoC across `pipeline.py` and `schema.py`. - -## Verified locally - -- `python -m py_compile` clean on all .py files. -- Module imports cleanly from this worktree (`PYTHONPATH=src`) under Scope's - venv (`~/Projects/daydreamlive-scope/.venv`). -- Schema reflects the trim: `Literal['stabilityai/sd-turbo', - 'stabilityai/sdxl-turbo']`, no `num_inference_steps` field. - -## Not verified — please run before merging - -Scope was not running on `localhost:8000` during this session, so the -hot-reload smoke test from the original handoff was skipped. To validate: - -```bash -curl -X POST http://localhost:8000/api/v1/plugins/scope_streamdiffusion/reload \ - -H 'Content-Type: application/json' -d '{"force":true}' -``` - -Then in the UI: - -- Render at `acceleration_mode=none` with `model_id_or_path=stabilityai/sd-turbo`. -- Swap to `stabilityai/sdxl-turbo` via the dropdown — confirm hot-swap path - and SDXL fp16-fix VAE install. -- Render at `acceleration_mode=trt` for both variants. -- Run a moth dev session: scenes trigger, oscillators drive params, - ControlNet (depth) and mask compositing still work. - -## Out of scope (explicit) - -- LoRA hot-swap (separate spec at `~/Projects/moth/docs/specs/lora-support.md`, - Phase 4 of `streamdiffusion-trt.md`). -- Hyper-SD / Lightning step-LoRA fusion at load — the better path to - 1-step inference on arbitrary SD 1.5 / SDXL checkpoints; future PR. -- SD 3 / 3.5 — MMDiT, not UNet, incompatible with the current TRT path. -- Moth-side UI changes — none needed; the dropdown is schema-driven. - -## Notes for next agent - -- Editable install lives in the **main worktree path** - (`~/Projects/moth-scope/plugins/scope-stream_dffusion_v1`), not this - worktree. Once this PR merges into `main`, pulling main in the main - worktree will pick the changes up automatically; until then, force- - importing from `src/` is the only way to verify the worktree's code in - Python. -- The 2-commit split (schema then pipeline) is intentional. The handoff - asked for 2-3 commits; bundling pipeline.py into one was cleaner than - trying to interleave the LCM LoRA / serial / dead-branch removals. diff --git a/src/scope_streamdiffusion/schema.py b/src/scope_streamdiffusion/schema.py index a34f9db..7f57f64 100644 --- a/src/scope_streamdiffusion/schema.py +++ b/src/scope_streamdiffusion/schema.py @@ -99,11 +99,16 @@ class StreamDiffusionConfig(BasePipelineConfig): acceleration_mode: Literal["none", "trt"] = Field( default="trt", description=( - "TRT-compile UNet (and ControlNet) for ~2-3x denoising speedup. " - "First build per (model, batch range) takes 5-10 min and caches to " - "~/.cache/scope-streamdiffusion-trt/. Set at session start; changing " - "requires pipeline reload. Engines support dynamic resolution 256-1024 " - "and batch 1-4." + "TRT-compile UNet (and ControlNet on SD 1.5) for 2-8x denoising " + "speedup. First build per model takes 5-10 min and caches to " + "~/.cache/scope-streamdiffusion-trt/. Set at session start; " + "changing requires pipeline reload. SD 1.5 engines support " + "dynamic resolution 256-1024 and batch 1-4. SDXL engines " + "(sdxl-turbo, dmd2-sdxl-1step) support dynamic resolution " + "512-1024 with static batch=1 — different envelope to fit a " + "24 GB VRAM build budget. SDXL + ControlNet + TRT is not yet " + "supported (raises NotImplementedError); use acceleration='none' " + "with controlnet on SDXL until that lands." ), #json_schema_extra=ui_field_config(order=2, label="Acceleration"), ) From 54c4ad6c77b4577c5b542182b4b9b945646ba74f Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Wed, 6 May 2026 14:59:42 -0700 Subject: [PATCH 20/26] fix(schema): ModelId as StrEnum so HF URL formatting uses values `class ModelId(str, Enum)` stringifies to 'ModelId.SDXL_TURBO' under Python 3.12, which leaks through f-string formatting in `DiffusionPipeline.from_pretrained` and produces a 404 against huggingface.co/api/models/ModelId.SDXL_TURBO. StrEnum's __str__/__format__ return the enum value, so HF gets the correct repo path. Also picks the SDXL TAESD variant (madebyollin/taesdxl) when the loaded model is SDXL, refreshes the plugin's CLAUDE.md to match the current multi-model layout. Co-Authored-By: Claude Opus 4.7 (1M context) --- CLAUDE.md | 377 +++++++++++++++++--------- src/scope_streamdiffusion/pipeline.py | 3 +- src/scope_streamdiffusion/schema.py | 18 +- 3 files changed, 261 insertions(+), 137 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index fde9b1b..326a4ec 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,182 +1,301 @@ -# CLAUDE.md +# scope-streamdiffusion Plugin — Claude Code Guide -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. +Real-time Stable Diffusion pipeline for Daydream Scope using StreamDiffusion. Supports SD 1.5, SDXL, Turbo models with LCM scheduling, ControlNet, TensorRT acceleration, and multi-model orchestration. -## Scope Plugin Expertise +## Design -You are an expert at building Daydream Scope plugins/nodes/pipelines. Reference documentation: -- **https://docs.daydream.live/scope/tutorials/build-video-effects-plugin** -- **https://docs.daydream.live/scope/tutorials/vibe-code-a-scope-plugin** +This is a Scope `Pipeline` subclass that wraps diffusion inference. The plugin is **entry-point discovered** via `pyproject.toml` and loads into Scope's pipeline selector automatically. -## Project Overview +### Core Principles -This is a **Daydream Scope plugin** that integrates StreamDiffusion for real-time Stable Diffusion video generation. It's not a standalone application—it's designed to be installed and discovered by the Daydream Scope framework via Python entry points. +- **Init/Runtime separation.** `__init__()` loads models once; `__call__(**kwargs)` handles per-frame params (prompt, seed, guidance scale, strength). Parameters can change frame-to-frame without reloading. +- **Tensor format aware.** Scope uses `(T, H, W, C)` in [0, 1]; diffusion expects `(B, C, H, W)`. Conversions happen in `__call__()`. +- **Schema-driven config.** All parameters defined in `schema.py` using Pydantic. UI fields auto-generated via `ui_field_config(order=N, label="...")`. +- **Lazy model loading.** Models load on first init; subsequent calls reuse weights. Model changes trigger full reinitialization. -## Installation & Development +## Project Structure -### Install Plugin (Development Mode) -```bash -pip install -e . -``` -Development mode allows code changes to take effect immediately without reinstalling. - -### Verify Plugin Registration -```bash -python -c "import scope_streamdiffusion; print('Plugin loaded successfully')" -``` - -### Testing in Scope -The plugin is automatically discovered by Scope once installed. Start Scope and look for "StreamDiffusion" in the pipeline selector. - -## Architecture - -### Plugin Structure ``` -src/scope_streamdiffusion/ -├── __init__.py # Plugin registration via @hookimpl -├── schema.py # Configuration schema (UI fields + validation) -└── pipeline.py # Pipeline implementation (model + inference) +. +├── CLAUDE.md # This file +├── README.md # User-facing features and usage +├── ADAPTATION_NOTES.md # How StreamDiffusion was adapted to Scope +├── INSTALL.md # Quick install guide +├── pyproject.toml # Package config, entry point, deps +│ +├── src/scope_streamdiffusion/ +│ ├── __init__.py # Plugin registration via hookimpl +│ ├── schema.py # StreamDiffusionConfig (Pydantic + UI) +│ ├── pipeline.py # StreamDiffusionPipeline (main logic) +│ ├── controlnet.py # ControlNet handler for multi-ControlNet support +│ ├── trt_engines.py # TensorRT engine discovery/caching +│ ├── _trt_cache.py # TensorRT compile cache management +│ │ +│ └── _trt/ # TensorRT backend utilities +│ ├── __init__.py +│ ├── models.py # TRT model configs +│ ├── engine.py # Engine compilation and inference +│ ├── builder.py # ONNX → TRT conversion +│ └── utilities.py # Device/precision helpers +│ +└── (no tests directory yet) ``` -### Entry Point System -The plugin is discovered via `pyproject.toml` entry point: -```toml -[project.entry-points."scope"] -scope_streamdiffusion = "scope_streamdiffusion" -``` -Scope automatically loads all registered plugins at startup. +## Key Files -## Critical Architectural Patterns +**Schema & Configuration:** +- `schema.py` — `StreamDiffusionConfig`: Pydantic model with 50+ fields defining model, scheduler, sampler, guidance, seed, ControlNet setup, TensorRT flags. Fields use `ui_field_config()` for Scope UI auto-generation. -### 1. Initialization vs Runtime Separation +**Pipeline Implementation:** +- `pipeline.py` — `StreamDiffusionPipeline`: implements `Pipeline` interface. Methods: + - `get_config_class()`: returns `StreamDiffusionConfig` + - `prepare(**kwargs) → Requirements`: returns resource hints + - `__call__(**kwargs) → dict`: main inference loop; returns `{"video": tensor}` -**This is the most important pattern in the codebase.** +**ControlNet:** +- `controlnet.py` — `ControlNetHandler`: manages multi-ControlNet attachment, caching, and inference integration. Supports Canny, pose, depth, etc. -- **`__init__()`**: One-time model loading, GPU setup, component initialization - - Loads diffusion model from HuggingFace/local path - - Sets up VAE, UNet, text encoder, scheduler - - Initializes Compel for prompt weighting - - NO runtime parameters here +**TensorRT:** +- `trt_engines.py` — discovers cached engines, auto-selects by device/precision +- `_trt/engine.py` — compiles ONNX models to TensorRT `.engine` format with dynamic shapes +- `_trt_cache.py` — caches compiled engines locally for rapid reuse -- **`__call__(**kwargs)`**: Per-frame processing with runtime parameters - - Receives all generation params (prompt, seed, strength, etc.) from kwargs - - Calls `_prepare_runtime_state()` to set up state from kwargs - - Processes frame and returns `{"video": tensor}` - - Parameters can change between frames without reloading model +**Entry Point:** +- `__init__.py` — `@hookimpl` function that Scope's plugin loader calls at discovery -**Why:** Enables efficient real-time streaming where the model stays loaded but parameters can change dynamically. +## Architecture -### 2. Configuration Schema Pattern +### Inference Flow -All pipeline parameters are defined in `schema.py` using: -```python -class StreamDiffusionConfig(BasePipelineConfig): - param_name: type = Field( - default=value, - description="...", - json_schema_extra=ui_field_config(order=N, label="Display Name") - ) +``` +Input frame (from Scope) + ↓ +[Tensor format conversion] (T,H,W,C) → (B,C,H,W) + ↓ +[Load/prepare model] (on first call; cached after) + ↓ +[Encode prompt] (Compel for weighting; cache embeddings) + ↓ +[VAE encode] frame → latent + ↓ +[ControlNet encode] (if enabled; pre-compute for all conditions) + ↓ +[Denoising loop] + for each step in scheduler: + - Add noise if img2img + - Denoise with UNet + - Apply ControlNet + - Apply guidance + ↓ +[VAE decode] latent → image + ↓ +[Tensor format conversion] (B,C,H,W) → (T,H,W,C) + ↓ +Output (PIL.Image or tensor dict) ``` -- Inherits from `BasePipelineConfig` (provided by Scope) -- Each field gets `ui_field_config()` for UI generation -- `order` determines UI layout order -- Validation happens automatically via Pydantic +### Model Loading Lifecycle -### 3. Tensor Format Conversions +1. **First `__init__`:** Load diffusion model from HuggingFace or local path. Setup VAE, UNet, text encoder, scheduler. Warm up GPU. +2. **Subsequent `__init__` calls:** Reuse loaded weights (unless model_id changed). +3. **Model change:** Trigger full reload (detected via signature comparison). +4. **ControlNet attach:** Load and fuse ControlNet weights; cache encoders. -**Scope's tensor format:** `(T, H, W, C)` normalized to [0, 1] -**Diffusion model format:** `(B, C, H, W)` for processing +### Parameter Handling -Conversions happen in `__call__()`: -```python -# Input: Scope format → Model format -frame = video[0] # (H, W, C) -input_tensor = frame.permute(2, 0, 1).unsqueeze(0) # (1, C, H, W) +**Initialization-time (requires model reload):** +- `model_id`: changes which model to load +- `torch_dtype`: precision (float16 vs float32) +- `acceleration`: xformers vs none -# Output: Model format → Scope format -output = result.permute(0, 2, 3, 1).clamp(0, 1) # (T, H, W, C) -``` +**Runtime (can change per-frame):** +- `prompt`: text input (re-encoded each frame or cached if unchanged) +- `seed`: random seed +- `guidance_scale`: classifier-free guidance strength +- `strength`: how much to denoise (img2img) +- `num_inference_steps`: denoising steps +- `scheduler`: LCM, DDPM, etc. (some require reinit) -### 4. Pipeline Interface Methods +### TensorRT Compilation -Must implement: -- `get_config_class()`: Returns config schema class -- `prepare()`: Returns `Requirements` (e.g., input size) -- `__call__(**kwargs)`: Main processing method +- Disabled by default (requires additional setup). +- When enabled: UNet compiled to device-specific `.engine` file. +- Compilation happens on first inference (slow; ~1-5 min depending on model). +- Cached engines reused on subsequent runs (instant load). +- Cache dir: `~/.cache/scope-streamdiffusion/trt/` -## Key Implementation Details +### ControlNet Support -### StreamDiffusion Specifics -- Uses LCM (Latent Consistency Models) scheduler for fast inference -- Supports batch denoising for better performance -- Single-step denoising with `t_index_list = [0]` for Turbo models -- Delta parameter controls temporal consistency in streams +- Multi-ControlNet: attach multiple conditions (e.g., Canny + pose). +- Conditions pre-computed once per prompt. +- Inference: scales applied per denoising step. +- Encoder caching to avoid re-encoding images. -### Model Support -- SD 1.5, SDXL, SD Turbo, SDXL Turbo -- Auto-detects SDXL vs SD 1.5 for proper prompt encoding -- Supports LoRA loading via `load_lora()` and `fuse_lora()` +## Development Workflow -### Prompt Encoding -- Uses Compel library for advanced prompt weighting -- SDXL requires pooled embeddings + add_time_ids -- SD 1.5 uses standard CLIP embeddings +### Before Starting -### ControlNet -- Basic support implemented but not exposed in UI -- Set via `self.controlnet` and `self.controlnet_pipeline` -- To expose: add fields to `StreamDiffusionConfig` +1. **Read `ADAPTATION_NOTES.md`** — explains how the original StreamDiffusion code was adapted to Scope's architecture. +2. **Understand init/runtime separation** — this is the foundation of how parameters flow. +3. **Check `schema.py`** for existing fields — don't add duplicate params. -## Adding New Parameters +### Adding a New Parameter 1. **Add to schema** (`schema.py`): ```python new_param: float = Field( default=1.0, ge=0.0, - le=2.0, - description="Parameter description", - json_schema_extra=ui_field_config(order=99, label="New Parameter"), + le=10.0, + description="What this does", + json_schema_extra=ui_field_config(order=50, label="New Param"), ) ``` 2. **Use in pipeline** (`pipeline.py`): + - If runtime-safe (doesn't require model reload): read from `kwargs` in `__call__()` + - If initialization-time (e.g., model architecture): pass to `__init__()` and track via signature + +3. **Test in Scope:** + - Run Scope (`SCOPE_REPL` or `scope serve`) + - Select StreamDiffusion pipeline + - Parameter should appear in UI with label and order + +### Adding ControlNet Support + +1. **Update schema** to expose ControlNet config: ```python - def __call__(self, **kwargs) -> dict: - new_param = kwargs.get("new_param", 1.0) - # Use new_param in processing... + controlnet_id: Optional[str] = Field( + default=None, + description="ControlNet model ID", + json_schema_extra=ui_field_config(order=40, label="ControlNet"), + ) + controlnet_conditioning: Optional[str] = Field( + default=None, + description="Encoded ControlNet condition", + ) ``` -3. **(Optional) Add to `_prepare_runtime_state()`** if it affects state initialization +2. **Update pipeline** to attach ControlNet: + ```python + if kwargs.get("controlnet_id"): + controlnet = ControlNetHandler(kwargs["controlnet_id"], device=self.device) + # Attach to diffusion pipeline + ``` -## Important Files Referenced +3. **Reference `controlnet.py`** for handler patterns. -- `ADAPTATION_NOTES.md`: Detailed explanation of how original StreamDiffusion code was adapted to Scope's architecture -- `README.md`: User-facing documentation with features and usage -- `INSTALL.md`: Quick installation guide +### TensorRT Integration -## Dependencies +Only attempt if you have CUDA 12+ and understand TensorRT compilation: -Core dependencies defined in `pyproject.toml`: -- `torch`: PyTorch (requires CUDA for GPU) -- `diffusers`: Stable Diffusion models and pipelines -- `compel`: Advanced prompt weighting -- `logfire`: Logging (Scope requirement) -- `numpy`, `pillow`: Image processing +1. Update `_trt/builder.py` if you need new precision/shape configs. +2. Call `_trt_cache.get_engine()` to auto-compile and cache. +3. Swap UNet for TRT engine in inference loop. +4. See `trt_engines.py` for caching logic. -## Debugging +## Testing -Common issues: -- **Plugin not appearing in Scope**: Check entry point registration -- **Model loading fails**: Verify model path and GPU availability -- **Import errors**: Ensure Scope framework is installed -- **Performance issues**: Enable xformers acceleration, reduce inference steps, or use Turbo models +No test suite exists yet. Manual testing approach: -## Development Workflow +```bash +# 1. Install in dev mode +pip install -e . + +# 2. Start Scope +SCOPE_REPL # or: scope serve + +# 3. In Scope, select StreamDiffusion pipeline + +# 4. Set parameters and verify: +# - Prompt changes take effect immediately +# - Model changes trigger reload (check logs) +# - Output looks reasonable +# - Performance is acceptable +``` + +### Debugging + +```bash +# Check plugin is discovered: +python -c "import scope_streamdiffusion; print('OK')" + +# Check config loads: +python -c "from scope_streamdiffusion import StreamDiffusionConfig; print(StreamDiffusionConfig.__fields__.keys())" + +# Test pipeline init: +from scope_streamdiffusion.pipeline import StreamDiffusionPipeline +p = StreamDiffusionPipeline() +print("Pipeline initialized") +``` + +## Important Constraints + +- **Model reloads are expensive.** Changing `model_id`, `torch_dtype`, or `acceleration` causes full reload (10-30s). +- **VRAM is limited.** Default to float16 and xformers acceleration. SDXL needs 8GB+ VRAM. +- **Scheduler matters.** LCM is fast (1-4 steps); DDPM is slow but more flexible (20-50 steps). Some model + scheduler combos don't work well. +- **TensorRT engines are device-specific.** Moving to a different GPU requires recompilation. +- **Prompt encoding is cached.** If prompt doesn't change, embeddings are reused (fast). If it does, encoding happens every frame (slower). + +## Dependencies & xformers + +**Core deps** (in `pyproject.toml`): +- `torch` — deep learning framework +- `diffusers` — HuggingFace diffusion models +- `logfire` — Scope logging integration +- `numpy`, `pillow` — image processing + +**Optional (xformers acceleration):** +xformers is NOT in dependencies because it ships with strict (often wrong) torch pins that break Scope's GPU stack. + +Install manually after setup, choosing the version for your torch: +```bash +torch 2.9.x → uv pip install --no-deps xformers==0.0.33.post2 +torch 2.10.x → uv pip install --no-deps xformers==0.0.34 +``` + +Use `--no-deps` to skip xformers' bogus torch pin. + +## Scope Integration Points + +**Entry point discovery:** +```toml +[project.entry-points."scope"] +scope_streamdiffusion = "scope_streamdiffusion" +``` +Scope calls `hookimpl()` function in `__init__.py` to register the pipeline. + +**Config schema:** +Schema fields with `ui_field_config()` are discovered by Scope and rendered in the pipeline UI. Changes to schema are reflected on next Scope restart. + +**Requirements:** +`prepare()` returns `Requirements` (e.g., minimum VRAM, input resolution). Scope uses this for validation. + +**Tensor I/O:** +`__call__()` receives `video` tensor from Scope in `(T, H, W, C)` format; must return same format. + +## Common Issues + +| Problem | Cause | Solution | +|---------|-------|----------| +| Plugin doesn't appear in Scope | Entry point not registered | Run `pip install -e .` again; restart Scope | +| Model loading fails | HF auth needed or model not found | Check HF cache; verify internet; login to HuggingFace if needed | +| OOM errors | Model too big for GPU | Use SDXL Turbo instead of base SDXL; reduce batch size; enable xformers | +| Slow inference | No GPU acceleration | Install xformers; check `torch.cuda.is_available()` returns True | +| ControlNet not working | Handler not attached properly | Review `controlnet.py` logic; check config passes condition tensor | +| TensorRT compile fails | CUDA version mismatch | Ensure CUDA 12+; check triton compatibility | + +## Code Style & Conventions + +- Type hints required on all public functions. +- Docstrings on classes and complex methods. +- Config validation via Pydantic (no manual validation). +- Logging via `logfire` (not print). +- No magic constants — all tunable params go in schema. + +## References -1. Make code changes in `src/scope_streamdiffusion/` -2. Changes are immediately available (development mode) -3. Restart Scope to reload plugin -4. Test in Scope UI with various parameters -5. Check Scope logs for errors/warnings +- **Scope plugin tutorials:** https://docs.daydream.live/scope/tutorials/build-video-effects-plugin +- **Diffusers docs:** https://huggingface.co/docs/diffusers +- **StreamDiffusion:** https://github.com/cumulo-autumn/StreamDiffusion +- **Scope Pydantic patterns:** Check other Scope pipelines in `daydreamlive-scope` repo diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index 94445e8..ab868e8 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -293,9 +293,10 @@ def _ensure_trt_taesd(self) -> None: flush=True, ) try: + taesd_model_id = "madebyollin/taesdxl" if self.sdxl else "madebyollin/taesd" enc_path, dec_path = build_taesd_engines( self._taesd_vae, - model_id="madebyollin/taesd", + model_id=taesd_model_id, image_height=int(self.height), image_width=int(self.width), min_batch_size=1, diff --git a/src/scope_streamdiffusion/schema.py b/src/scope_streamdiffusion/schema.py index 7f57f64..26e39f3 100644 --- a/src/scope_streamdiffusion/schema.py +++ b/src/scope_streamdiffusion/schema.py @@ -1,6 +1,6 @@ """Configuration schema for StreamDiffusion pipeline.""" -from enum import IntEnum +from enum import IntEnum, StrEnum from typing import Literal from pydantic import Field, field_validator @@ -12,6 +12,14 @@ ) +class ModelId(StrEnum): + """Supported StreamDiffusion models (all 1-step distillations).""" + + SD_TURBO = "stabilityai/sd-turbo" + SDXL_TURBO = "stabilityai/sdxl-turbo" + DMD2_SDXL_1STEP = "dmd2-sdxl-1step" + + class Resolution(IntEnum): """Allowed pixel dimensions for width/height. @@ -81,12 +89,8 @@ class StreamDiffusionConfig(BasePipelineConfig): # Model Configuration # ======================================== - model_id_or_path: Literal[ - "stabilityai/sd-turbo", - "stabilityai/sdxl-turbo", - "dmd2-sdxl-1step", - ] = Field( - default="stabilityai/sd-turbo", + model_id_or_path: ModelId = Field( + default=ModelId.SD_TURBO, description=( "Model selection. All entries are 1-step distillations. " "'dmd2-sdxl-1step' is SDXL-base with the DMD2 distilled UNet " From 393fcabe04dcaafc3e117fe6b3abd7a71c651413 Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Fri, 8 May 2026 12:09:28 -0700 Subject: [PATCH 21/26] Testing sd multi --- .claude/scheduled_tasks.lock | 1 + libndi-get.sh | 103 +++++ src/scope_streamdiffusion/controlnet.py | 22 + src/scope_streamdiffusion/pipeline.py | 531 +++++++++++++++++++----- src/scope_streamdiffusion/schema.py | 33 +- 5 files changed, 585 insertions(+), 105 deletions(-) create mode 100644 .claude/scheduled_tasks.lock create mode 100755 libndi-get.sh diff --git a/.claude/scheduled_tasks.lock b/.claude/scheduled_tasks.lock new file mode 100644 index 0000000..d72653b --- /dev/null +++ b/.claude/scheduled_tasks.lock @@ -0,0 +1 @@ +{"sessionId":"d8c05a9f-ecd4-4918-ba19-035cdd531a1a","pid":805816,"acquiredAt":1778002935354} \ No newline at end of file diff --git a/libndi-get.sh b/libndi-get.sh new file mode 100755 index 0000000..6362887 --- /dev/null +++ b/libndi-get.sh @@ -0,0 +1,103 @@ +#!/bin/bash +set -e + +# This script downloads and installs the NDI SDK for Linux. +# By default it downloads the NDI SDK v6 for Linux and extracts it to a temporary directory. +# +# Add argument "install" to install the library files to your system. +# Usage: ./libndi-get.sh install + + +LIBNDI_INSTALLER_NAME="Install_NDI_SDK_v6_Linux" +LIBNDI_INSTALLER="$LIBNDI_INSTALLER_NAME.tar.gz" +LIBNDI_INSTALLER_URL="https://downloads.ndi.tv/SDK/NDI_SDK_Linux/$LIBNDI_INSTALLER" + +# Use temporary directory +LIBNDI_TMP=$(mktemp --tmpdir -d ndidisk.XXXXXXX) + +# Check if the temp directory exists and is a directory. +if [[ -d "$LIBNDI_TMP" ]]; then + echo "Temporary directory created at $LIBNDI_TMP" +else + echo "Failed to create a temporary directory." + exit 1 +fi + +# While most of the command are with the folder path, this is needed for the libndi install script to run properly +pushd "$LIBNDI_TMP" + +# Download LIBNDI +# The follwoing should work with tmp folder in the user home directory - but not always... So we do not use it. +# curl -o "$LIBNDI_TMP/$LIBNDI_INSTALLER" $LIBNDI_INSTALLER_URL -f --retry 5 + +# The following is required if the temp directory is not in the user home directory. +curl -L "$LIBNDI_INSTALLER_URL" -f --retry 5 > "$LIBNDI_TMP/$LIBNDI_INSTALLER" + + +# Check if download was successful +if [ $? -ne 0 ]; then + echo "Download failed." + exit 1 +fi + +echo "Download complete." + +# Step 3: Uncompress the file. +echo "Uncompressing..." +tar -xzvf "$LIBNDI_TMP/$LIBNDI_INSTALLER" -C "$LIBNDI_TMP" + +# Check if uncompression was successful +if [ $? -ne 0 ]; then + echo "Uncompression failed." + exit 1 +fi + +echo "Uncompression complete." + + +yes | PAGER="cat" sh "$LIBNDI_INSTALLER_NAME.sh" + + +rm -rf "$LIBNDI_TMP/ndisdk" +echo "Moving things to a folder with no space" +mv "$LIBNDI_TMP/NDI SDK for Linux" "$LIBNDI_TMP/ndisdk" +echo +echo "Contents of $LIBNDI_TMP/ndisdk/lib:" +ls -la "$LIBNDI_TMP/ndisdk/lib" +echo +echo "Contents of $LIBNDI_TMP/ndisdk/lib/x86_64-linux-gnu:" +ls -la "$LIBNDI_TMP/ndisdk/lib/x86_64-linux-gnu" +echo + +popd + +if [ "$1" == "install" ]; then + echo "Copying the library files to the long-term location. You might be prompted for authentication." + sudo cp -P "$LIBNDI_TMP/ndisdk/lib/x86_64-linux-gnu/"* /usr/local/lib/ + sudo ldconfig + + echo "libndi installed to /usr/local/lib" + ls -la "/usr/local/lib/"libndi* + + echo "Adding backward compatibility tweaks for older plugins version to work with NDI v6" + sudo ln -s /usr/local/lib/libndi.so.6 /usr/local/lib/libndi.so.5 + + echo "Clean-up : Removing temporary folder" + rm -rf "$LIBNDI_TMP" + if [[ ! -d "$LIBNDI_TMP" ]]; then + echo "Temporary directory $LIBNDI_TMP does not exist anymore (good!)" + else + echo "Failed to clean-up temporary directory." + echo "Please clean this up manually - All should be in $LIBNDI_TMP" + exit 1 + fi + echo "Installation complete." +else + # Allow to keep the temporary files (to use with libndi-package.sh) + echo "No installation requested. The library files are in $LIBNDI_TMP/ndisdk/lib/x86_64-linux-gnu/" + echo "You can copy them manually to your system if needed." + ls -la "$LIBNDI_TMP/ndisdk/lib/x86_64-linux-gnu/libndi"* +fi + +echo "Script execution Complete." +exit 0 diff --git a/src/scope_streamdiffusion/controlnet.py b/src/scope_streamdiffusion/controlnet.py index a3b9de2..0303e68 100644 --- a/src/scope_streamdiffusion/controlnet.py +++ b/src/scope_streamdiffusion/controlnet.py @@ -85,6 +85,28 @@ def __init__(self, device: torch.device, dtype: torch.dtype): self.input: Optional[torch.Tensor] = None self.scale: float = 1.0 + def release(self) -> None: + """Drop all GPU-resident models and tensors held by this handler. + + Call before swapping the diffusion model — otherwise SD1.5 ControlNets, + depth-anything, and scribble weights stay resident across the swap and + contend with the new model's allocation. Caller is expected to run + ``torch.cuda.empty_cache()`` after this returns. + """ + self._controlnet_cache.clear() + self._depth_model = None + self._depth_hidden_state = None + self._last_depth_shape = None + self._depth_min_ema = None + self._depth_max_ema = None + self._prev_depth_input = None + self._scribble_model = None + self._prev_scribble_input = None + self._depth_norm_mean = None + self._depth_norm_std = None + self.model = None + self.input = None + def update( self, mode: str, diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index ab868e8..5c9c678 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -113,33 +113,30 @@ def __init__( # The schema's field is ``model_id_or_path``. Scope's pipeline_manager # merges schema defaults into the init kwargs by their declared name, - # so we have to accept that spelling — accepting only ``model_id`` - # silently drops the user's selection and reloads the default every - # time. Resolve in order: explicit model_id > model_id_or_path > default. - model_id = model_id or model_id_or_path or "stabilityai/sd-turbo" - - # Load the base model - print(f"Loading model: {model_id}") + # so what we see at __init__ is the *schema default*, not the user's + # UI selection — that only arrives via runtime kwargs/config on the + # first __call__. To avoid a spurious "load SD-Turbo, then immediately + # swap to the user's pick" on every startup, defer the actual model + # load to ``_ensure_pipe_loaded`` (called from __call__ once we have + # the runtime selection). The init-time arg is just a tentative + # default in case nothing more authoritative shows up at runtime. + config_model = getattr(self.config, "model_id_or_path", None) if self.config else None + model_id = model_id or config_model or model_id_or_path or "stabilityai/sd-turbo" self.model_id = model_id self._timesteps_override = MODEL_PRESETS.get(model_id, {}).get("timesteps_override") - self.pipe = self._load_model(model_id) - print(f"Model loaded: {self.pipe.__class__.__name__}") - - self.sdxl: bool = type(self.pipe) is StableDiffusionXLPipeline + print(f"[StreamDiffusion] Tentative model: {model_id} (load deferred to first __call__)") - # SDXL's default VAE overflows in fp16 and decodes NaN. Swap to the - # community fp16-fix VAE so the full-quality decode path works without - # forcing TAESD. - if self.sdxl and self.dtype == torch.float16: - self._install_sdxl_fp16_vae() - - # Model components - self.text_encoder = self.pipe.text_encoder - self.unet = self.pipe.unet - self.vae = self.pipe.vae - self._full_vae = self.vae # keep reference for toggling + # Model-dependent attrs are populated by ``_ensure_pipe_loaded``. + self.pipe = None + self.sdxl: bool = False + self.text_encoder = None + self.unet = None + self.vae = None + self._full_vae = None # populated on load self._taesd_vae = None self._using_taesd = False + self.scheduler = None + self.image_processor = None # legacy torch.compile flag — kept so other code paths that read # `_unet_compiled` (e.g. _ensure_trt_unet's "restore eager" branch) @@ -163,6 +160,10 @@ def __init__( self._trt_eager_controlnets: dict[str, Any] = {} # mode -> diffusers ControlNetModel (fallback) self._trt_cuda_stream = None self._trt_eager_unet = None # original; kept for fallback + # (height, width, controlnet_mode, use_taesd) of the last _setup_trt call. + # __call__ compares the current values against this and re-runs setup + # only on real divergence — otherwise the per-frame TRT block is a no-op. + self._trt_setup_signature: tuple | None = None # Read acceleration_mode at init from schema defaults / load_params. # The runtime kwargs path is unreliable because moth's 30fps param flood @@ -182,15 +183,8 @@ def __init__( self._node_id: str | None = kwargs.get("node_id") self._trt_cache_key: str = _trt_cache.cache_key(self._node_id, model_id) - # Setup scheduler - self.scheduler: LCMScheduler = LCMScheduler.from_config( - self.pipe.scheduler.config - ) - - # Setup image processor - self.image_processor: VaeImageProcessor = VaeImageProcessor( - self.pipe.vae_scale_factor - ) + # Scheduler / image_processor are model-dependent — populated by + # ``_ensure_pipe_loaded`` on the first __call__. # Setup embedding blender for prompt weighting and interpolation self.embedding_blender = EmbeddingBlender( @@ -245,10 +239,60 @@ def __init__( self._pooled_target: torch.Tensor | None = None self._transition_total_steps: int = 0 + # Seed transition state — when seed_transition_steps > 0, lerp + # `init_noise` from the previous seed's tensor to the new seed's + # tensor over N frames instead of hard-swapping. SDXL-Turbo / + # DMD2-1step have weaker stock_noise feedback than SD-Turbo, so + # without this seed changes read as hard cuts. + self._seed_transition_source: torch.Tensor | None = None + self._seed_transition_target: torch.Tensor | None = None + self._seed_transition_progress: int = 0 + self._seed_transition_total: int = 0 + # Mode-transition tracking — detect video↔text switches without a pipeline reload self._last_mode: str | None = None - print("StreamDiffusion pipeline initialized") + # TRT setup is deferred along with the model load — engines need + # ``self.pipe.unet`` to exist. ``_ensure_pipe_loaded`` runs + # ``_setup_trt`` immediately after loading when acceleration_mode + # is 'trt', so the first frame still pays the build cost up-front + # rather than mid-stream. + + print("StreamDiffusion pipeline initialized (model load deferred)") + + def _ensure_pipe_loaded(self, model_id: str) -> None: + """Load the diffusion model and populate model-dependent state. + + Called once from the first ``__call__`` with the user's actual + ``model_id_or_path`` from runtime kwargs/config. Doing the load here + instead of in ``__init__`` avoids a wasted "load schema default, + immediately swap to user's pick" cycle, since Scope's + pipeline_manager only forwards schema defaults at __init__ time. + Subsequent runtime model changes go through ``_swap_model``. + """ + if self.pipe is not None: + return + print(f"[StreamDiffusion] Loading model: {model_id}") + self.model_id = model_id + self._timesteps_override = MODEL_PRESETS.get(model_id, {}).get("timesteps_override") + self._trt_cache_key = _trt_cache.cache_key(self._node_id, model_id) + self.pipe = self._load_model(model_id) + print(f"[StreamDiffusion] Model loaded: {self.pipe.__class__.__name__}") + + self.sdxl = type(self.pipe) is StableDiffusionXLPipeline + if self.sdxl and self.dtype == torch.float16: + self._install_sdxl_fp16_vae() + + self.text_encoder = self.pipe.text_encoder + self.unet = self.pipe.unet + self.vae = self.pipe.vae + self._full_vae = self.vae + self._using_taesd = False + self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config) + self.image_processor = VaeImageProcessor(self.pipe.vae_scale_factor) + + if self._acceleration_mode == "trt": + self._setup_trt(**self._trt_setup_args_from_config()) def _ensure_trt_taesd(self) -> None: """Build TRT engines for the TAESD encoder + decoder, swap self.vae. @@ -586,6 +630,139 @@ def _ensure_trt_unet( cache_state.unet_has_controlnet = False print(f"[TRT] UNet engine active: {engine_path}", flush=True) + def _setup_trt( + self, + *, + height: int, + width: int, + controlnet_mode: str, + use_taesd: bool, + ) -> None: + """Build or attach TRT engines for the current model. + + Called at load time (``__init__`` and ``_swap_model``) so the first + frame doesn't stall on a 5-10 minute compile, and again from + ``__call__`` only when ``(height, width, controlnet_mode, use_taesd)`` + diverges from the last setup. The inner ``_ensure_trt_*`` methods + short-circuit when nothing needs to change. + """ + if self._acceleration_mode != "trt": + return + try: + self._ensure_trt_unet( + controlnet_mode, + image_height=int(height), + image_width=int(width), + ) + except Exception as e: + print(f"[TRT] UNet engine swap failed, falling back to eager: {e}") + import traceback + traceback.print_exc() + if self._trt_eager_unet is not None: + self.unet = self._trt_eager_unet + if controlnet_mode in ("depth", "scribble"): + try: + self._ensure_trt_controlnet(controlnet_mode) + except Exception as e: + print( + f"[TRT] ControlNet engine swap failed for {controlnet_mode}, using eager: {e}" + ) + import traceback + traceback.print_exc() + if use_taesd: + try: + self._ensure_trt_taesd() + except Exception as e: + print(f"[TRT] TAESD engine swap failed, using eager: {e}") + import traceback + traceback.print_exc() + self._trt_setup_signature = ( + int(height), + int(width), + controlnet_mode, + bool(use_taesd), + ) + + def _reset_trt_state(self, new_model_id: str) -> None: + """Invalidate TRT sticky state so the next ``_setup_trt`` rebuilds. + + Called from ``_swap_model`` before loading the new model. Without + this, the sticky ``_trt_unet_built`` / ``_trt_taesd_built`` flags + cause subsequent ``_ensure_trt_*`` calls to short-circuit and the + new model runs eager regardless of ``acceleration_mode``. + """ + # Drop the module-scope cache entry for the previous model. Without + # this its ``unet_adapter`` / ``cn_adapters`` / ``taesd_adapter`` + # references stay live in ``_trt_cache._CACHE`` and pin engine + # memory across the swap — direct cause of OOM on a 24 GB card + # when going SD1.5 → SDXL with TRT on. + old_key = getattr(self, "_trt_cache_key", None) + if old_key: + _trt_cache.clear(old_key) + self._trt_unet_built = False + self._trt_unet_has_controlnet = False + self._trt_taesd_built = False + self._trt_eager_unet = None + self._trt_eager_taesd = None + self._trt_cn_built_modes.clear() + self._trt_cn_engines.clear() + self._trt_eager_controlnets.clear() + self._trt_cache_key = _trt_cache.cache_key(self._node_id, new_model_id) + self._trt_setup_signature = None + + def _set_acceleration_mode(self, mode: str) -> None: + """Swap between TRT-accelerated and eager modules at runtime. + + TRT engines themselves are immutable after build, but the choice of + which UNet / ControlNet / TAESD module ``self.*`` points at *can* be + flipped per frame. Cached adapters (in ``_trt_cache._CACHE`` and on + the instance) stay alive across the swap so toggling back to 'trt' + is instant after the first build. + """ + if mode not in ("none", "trt") or mode == self._acceleration_mode: + return + print( + f"[StreamDiffusion] acceleration_mode swap: " + f"{self._acceleration_mode} -> {mode}" + ) + if mode == "none": + self._deactivate_trt() + self._acceleration_mode = "none" + else: + self._acceleration_mode = "trt" + self._setup_trt(**self._trt_setup_args_from_config()) + + def _deactivate_trt(self) -> None: + """Restore eager UNet / ControlNet / TAESD; keep adapters cached. + + Resets the sticky ``_trt_*_built`` flags so a future ``_setup_trt`` + re-enters the cache-restore path and re-attaches the same adapters + without rebuilding. + """ + if self._trt_eager_unet is not None and self.unet is not self._trt_eager_unet: + self.unet = self._trt_eager_unet + if self._trt_eager_taesd is not None: + self._taesd_vae = self._trt_eager_taesd + if self._using_taesd: + self.vae = self._taesd_vae + if self._cn.model is not None: + self.controlnet = self._cn.model + self._trt_unet_built = False + self._trt_unet_has_controlnet = False + self._trt_taesd_built = False + self._trt_cn_built_modes.clear() + self._trt_setup_signature = None + + def _trt_setup_args_from_config(self) -> dict: + """Resolve _setup_trt args from self.config, with schema-default fallbacks.""" + cfg = self.config + return { + "height": int(getattr(cfg, "height", 512)) if cfg else 512, + "width": int(getattr(cfg, "width", 512)) if cfg else 512, + "controlnet_mode": getattr(cfg, "controlnet_mode", "none") if cfg else "none", + "use_taesd": bool(getattr(cfg, "use_taesd", True)) if cfg else True, + } + def _load_model(self, model_id: str) -> DiffusionPipeline: """Load the diffusion model. @@ -637,10 +814,17 @@ def _load_preset(self, preset: dict) -> DiffusionPipeline: unet_swap = preset.get("unet_swap") if unet_swap is not None: from huggingface_hub import hf_hub_download + from huggingface_hub.utils import LocalEntryNotFoundError unet_repo, unet_file = unet_swap - print(f"[StreamDiffusion] Downloading distilled UNet: {unet_repo}/{unet_file}") - ckpt_path = hf_hub_download(unet_repo, unet_file) + # Probe the local cache first so we can log accurately. The unconditional + # "Downloading" print was misleading on every cached load. + try: + ckpt_path = hf_hub_download(unet_repo, unet_file, local_files_only=True) + print(f"[StreamDiffusion] Loading cached distilled UNet: {unet_repo}/{unet_file}") + except LocalEntryNotFoundError: + print(f"[StreamDiffusion] Downloading distilled UNet: {unet_repo}/{unet_file}") + ckpt_path = hf_hub_download(unet_repo, unet_file) # Distilled-UNet repos (DMD2, SDXL-Lightning, etc.) often ship # weights only — no config.json — because the architecture is # identical to the base UNet. Reuse the base pipeline's UNet @@ -656,6 +840,74 @@ def _load_preset(self, preset: dict) -> DiffusionPipeline: return pipe + def _release_pipe_state(self) -> None: + """Drop every GPU-resident reference owned by the pipeline. + + Called from :meth:`_swap_model` before loading the new model. + Clears module references (``unet`` / ``vae`` / ``text_encoder`` / + any TRT adapter still pinned in ``self.unet``), per-step cached + tensors, prompt-embedding caches, and the ControlNet handler's + sub-models. Caller (``_swap_model``) is expected to have already + run :meth:`_reset_trt_state` so the cache-state-held adapter + references are gone too. Finishes with a ``gc.collect`` + + ``torch.cuda.empty_cache`` so the next allocation starts clean. + """ + import gc + + # Module references — these are the big-ticket allocations. + # self.unet may be a TRT adapter that owns engine memory; nulling + # it here is what actually releases the engine. + self.unet = None + self.vae = None + self.text_encoder = None + self._taesd_vae = None + self._full_vae = None + self.controlnet = None + self.controlnet_input = None + if hasattr(self, "_cn") and self._cn is not None: + self._cn.release() + + # Cached per-step tensors. + for attr in ( + "init_noise", + "stock_noise", + "x_t_latent_buffer", + "prev_image_result", + "prompt_embeds", + "add_text_embeds", + "add_time_ids", + "alpha_prod_t_sqrt", + "beta_prod_t_sqrt", + "c_skip", + "c_out", + "sub_timesteps_tensor", + "timesteps", + ): + if hasattr(self, attr): + setattr(self, attr, None) + + # Embedding / transition caches. + self._cached_base_embed = None + self._previous_prompt_embeddings = None + self._pooled_source = None + self._pooled_target = None + self._seed_transition_source = None + self._seed_transition_target = None + if hasattr(self, "embedding_blender"): + try: + self.embedding_blender.cancel_transition() + except Exception: + pass + + # Drop the pipeline last so any of the above that aliased its + # submodules have already been nulled. + self.pipe = None + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + def _swap_model(self, new_model_id: str) -> None: """Replace the loaded model in place. @@ -667,15 +919,16 @@ def _swap_model(self, new_model_id: str) -> None: swaps it. Stalls the frame loop while loading — same as a fresh load. """ print(f"[StreamDiffusion] Swapping model: {self.model_id} -> {new_model_id}") - # Free the current model's GPU memory before bringing the next one in - # so we don't peak at 2x weights. - old = getattr(self, "pipe", None) - if old is not None: - del old - self.pipe = None - self._taesd_vae = None - self._full_vae = None - torch.cuda.empty_cache() + # Reset TRT sticky state for the new model. Without this the + # ``_trt_unet_built`` / ``_trt_taesd_built`` flags from the previous + # model cause the next ``_setup_trt`` to short-circuit, leaving the + # new model running eager regardless of acceleration_mode. + self._reset_trt_state(new_model_id) + # Tear down everything the old model holds on the GPU before loading + # the new one — without this we peak at 2x model weights + engines + # and OOM on large models (SDXL UNet alone is ~5 GB fp16, plus a + # 2 GB+ TRT engine, plus VAE / text encoders / cached tensors). + self._release_pipe_state() self.model_id = new_model_id self._timesteps_override = MODEL_PRESETS.get(new_model_id, {}).get("timesteps_override") @@ -711,6 +964,11 @@ def _swap_model(self, new_model_id: str) -> None: self.embedding_blender.cancel_transition() except Exception: pass + self._cancel_seed_transition() + + # Build TRT engines for the new model now so the next frame doesn't stall. + if self._acceleration_mode == "trt": + self._setup_trt(**self._trt_setup_args_from_config()) def _install_sdxl_fp16_vae(self) -> None: """Swap SDXL's default VAE for madebyollin/sdxl-vae-fp16-fix. @@ -803,6 +1061,7 @@ def _prepare_runtime_state( do_add_noise: bool, transition: Optional[dict] = None, transition_steps: int = 0, + seed_transition_steps: int = 0, cfg_type: Literal["none", "full", "self", "initialize"] = "self", t_index_list: Optional[List[int]] = None, ): @@ -861,14 +1120,21 @@ def _prepare_runtime_state( seed_changed = seed != self._last_seed shape_changed = noise_shape != self._noise_shape or dims_changed - if seed_changed: + if shape_changed: + # Different latent shape can't be lerped against the old buffer; + # hard-reset and cancel any in-flight seed transition. self.generator.manual_seed(seed) self._last_seed = seed - - if seed_changed or shape_changed: + self._cancel_seed_transition() self.x_t_latent_buffer = None self._initialize_noise() self._noise_shape = noise_shape + elif seed_changed: + # Hard cut when seed_transition_steps == 0; multi-frame lerp otherwise. + self._setup_seed_transition(seed, seed_transition_steps) + + # Advance any in-flight seed transition by one frame. No-op when idle. + self._advance_seed_transition() # --- Prompt embeddings & transitions --- # The key includes spatial dims for SDXL because add_time_ids depend on them. @@ -1323,6 +1589,86 @@ def _initialize_noise(self): self.stock_noise = torch.zeros_like(self.init_noise) + def _setup_seed_transition(self, new_seed: int, total_steps: int) -> None: + """Begin a multi-frame lerp from the current init_noise to the new seed. + + Falls back to a hard cut (re-seed + regenerate immediately) when + ``total_steps <= 0`` or no prior ``init_noise`` exists. The first + frame after this runs at the source noise; subsequent frames lerp + toward the target via :meth:`_advance_seed_transition`. + """ + self._cancel_seed_transition() + if total_steps <= 0 or self.init_noise is None: + self.generator.manual_seed(new_seed) + self._last_seed = new_seed + self.x_t_latent_buffer = None + self._initialize_noise() + return + + self._seed_transition_source = self.init_noise.detach().clone() + self.generator.manual_seed(new_seed) + self._seed_transition_target = torch.randn( + self.init_noise.shape, + generator=self.generator, + device=self.device, + dtype=self.dtype, + ) + self._seed_transition_progress = 0 + self._seed_transition_total = total_steps + self._last_seed = new_seed + # Match the hard-cut path's stock_noise reset so the StreamDiffusion + # feedback term doesn't carry the previous seed's accumulator. + self.stock_noise = torch.zeros_like(self.init_noise) + self.x_t_latent_buffer = None + + @staticmethod + def _slerp_noise(a: torch.Tensor, b: torch.Tensor, t: float) -> torch.Tensor: + """Spherical interpolation between two noise tensors. + + Linear interpolation drops the variance of standard-normal noise to + ``(1-t)² + t²`` mid-blend (0.5 at t=0.5), which the diffusion model + renders as washed-out / blurry output. Slerp keeps the result on the + same hypersphere as the endpoints, preserving variance and producing + a perceptually smooth crossfade between scenes. + """ + a_flat = a.flatten().float() + b_flat = b.flatten().float() + a_norm = a_flat.norm() + b_norm = b_flat.norm() + cos_omega = (a_flat @ b_flat) / (a_norm * b_norm + 1e-8) + cos_omega = cos_omega.clamp(-1.0, 1.0) + omega = torch.acos(cos_omega) + sin_omega = torch.sin(omega) + # Collinear endpoints — degenerate to lerp to avoid divide-by-zero. + if sin_omega.abs() < 1e-6: + return torch.lerp(a, b, t) + w_a = torch.sin((1.0 - t) * omega) / sin_omega + w_b = torch.sin(t * omega) / sin_omega + return (w_a * a + w_b * b).to(dtype=a.dtype) + + def _advance_seed_transition(self) -> None: + """Slerp ``init_noise`` one step toward the target. No-op when idle.""" + if self._seed_transition_total <= 0: + return + self._seed_transition_progress += 1 + if self._seed_transition_progress >= self._seed_transition_total: + self.init_noise = self._seed_transition_target.clone() + self._cancel_seed_transition() + return + t = self._seed_transition_progress / self._seed_transition_total + self.init_noise = self._slerp_noise( + self._seed_transition_source, + self._seed_transition_target, + t, + ) + + def _cancel_seed_transition(self) -> None: + """Drop any in-flight seed transition without snapping init_noise.""" + self._seed_transition_source = None + self._seed_transition_target = None + self._seed_transition_progress = 0 + self._seed_transition_total = 0 + def _get_add_time_ids( self, original_size, @@ -1625,12 +1971,16 @@ def get_param(key, default): # Finally use default return default - # Hot-swap when the model selection changes at runtime. Scope routes - # model_id_or_path through setNodeParams (kwargs / config), not - # through pipeline/load — so this is the only spot where a UI-driven - # change actually takes effect. - requested_model = get_param("model_id_or_path", None) - if requested_model and requested_model != self.model_id: + # Resolve the user's model selection from runtime kwargs/config. + # On the first call the pipe isn't loaded yet — __init__ defers the + # load specifically so we can pick the *real* selection here instead + # of the schema default that pipeline_manager hands us at __init__. + # On subsequent calls a runtime change (UI swap) routes through + # _swap_model. + requested_model = get_param("model_id_or_path", None) or self.model_id + if self.pipe is None: + self._ensure_pipe_loaded(requested_model) + elif requested_model and requested_model != self.model_id: self._swap_model(requested_model) # Extract all parameters with config fallback @@ -1645,6 +1995,7 @@ def get_param(key, default): strength = get_param("strength", 0.9) seed = get_param("seed", 42) + seed_transition_steps = get_param("seed_transition_steps", 0) delta = get_param("delta", 1.0) width = get_param("width", 512) height = get_param("height", 512) @@ -1665,8 +2016,14 @@ def get_param(key, default): # at ~40 ms/call vs TAESD's ~5 ms. Big perf cliff if the param isn't # propagated from moth (e.g. queue-drop or absent from project file). use_taesd = get_param("use_taesd", True) - # acceleration_mode is locked at init (see __init__) — runtime updates - # don't change it because TRT engines can't be hot-swapped. + # acceleration_mode is hot-swappable: the engines themselves can't be + # rebuilt at runtime, but the module references (self.unet etc.) can + # flip between TRT adapters and eager modules. _set_acceleration_mode + # swaps; first 'trt' activation builds (slow), subsequent ones hit + # the cached adapters (instant). + requested_mode = get_param("acceleration_mode", self._acceleration_mode) + if requested_mode != self._acceleration_mode: + self._set_acceleration_mode(requested_mode) acceleration_mode = self._acceleration_mode # --- Safeguard: prevent invalid strength / num_inference_steps combos --- @@ -1707,44 +2064,20 @@ def get_param(key, default): self.controlnet = self._cn.model self.controlnet_input = self._cn.input - # TRT engine swap — UNet always, ControlNet additionally when active. - # Two separate engines (each <2 GB ONNX) instead of a single combined - # graph that hits TRT's cask-convolution bug. - # Pass runtime dims explicitly: _prepare_runtime_state (which sets - # self.height/self.width and the latent dims) runs *after* this - # block. Static-shape SDXL engines need the real runtime dims at - # build time, otherwise they're sized for the __init__ defaults - # (512x512) and mismatch at inference. Setting self.{height,width} - # preemptively here is wrong — it'd block dims_changed in - # _prepare_runtime_state, leaving self.latent_{height,width} stale. + # TRT engines are normally built at load time (in __init__ / + # _swap_model). This guard catches the residual cases where runtime + # values diverge from what was used at load — e.g. the user changes + # resolution, toggles controlnet on/off, or flips use_taesd in the + # UI. Fast no-op when nothing changed. if acceleration_mode == "trt": - try: - self._ensure_trt_unet( - controlnet_mode, - image_height=int(height), - image_width=int(width), + sig = (int(height), int(width), controlnet_mode, bool(use_taesd)) + if sig != self._trt_setup_signature: + self._setup_trt( + height=int(height), + width=int(width), + controlnet_mode=controlnet_mode, + use_taesd=bool(use_taesd), ) - except Exception as e: - print(f"[TRT] UNet engine swap failed, falling back to eager: {e}") - import traceback - traceback.print_exc() - if self._trt_eager_unet is not None: - self.unet = self._trt_eager_unet - if controlnet_mode in ("depth", "scribble"): - try: - self._ensure_trt_controlnet(controlnet_mode) - except Exception as e: - print(f"[TRT] ControlNet engine swap failed for {controlnet_mode}, using eager: {e}") - import traceback - traceback.print_exc() - # TAESD TRT — saves ~3-5 ms vs eager TAESD (which is already fast) - if use_taesd: - try: - self._ensure_trt_taesd() - except Exception as e: - print(f"[TRT] TAESD engine swap failed, using eager: {e}") - import traceback - traceback.print_exc() self.controlnet_conditioning_scale = self._cn.scale # Extract transition (explicit transition overrides auto-transition) @@ -1766,16 +2099,22 @@ def get_param(key, default): do_add_noise=do_add_noise, transition=transition, transition_steps=transition_steps, + seed_transition_steps=seed_transition_steps, ) frame = None - # Process input. Note: image_loopback must be opt-in. Falling back to - # `prev_image_result` whenever video is missing turns text mode into a - # recursive feedback loop (each frame uses the previous frame's - # output as input), which drifts to over-saturated/abstract patterns - # within a few frames. - if image_loopback and self.prev_image_result is not None: + # Process input. In text-only mode (no video stream) we fall back to + # the previous frame's output as input — the implicit-loopback path. + # This is what gives txt2img its iterative refinement: frame 1 is a + # cold t2i pass and frames 2+ are img2img on the previous output, so + # SD-Turbo's single-step recovery sharpens detail across frames. + # Removing the fallback (strict opt-in via image_loopback) loses the + # refinement and txt2img output stays at single-shot quality forever. + # The drift this can cause on long runs is the cost of admission. + if image_loopback or ( + (video is None or len(video) == 0) and self.prev_image_result is not None + ): frame = self.prev_image_result elif video is not None and len(video) > 0: # Convert Scope tensor format to pipeline format diff --git a/src/scope_streamdiffusion/schema.py b/src/scope_streamdiffusion/schema.py index 26e39f3..c5284cb 100644 --- a/src/scope_streamdiffusion/schema.py +++ b/src/scope_streamdiffusion/schema.py @@ -105,16 +105,18 @@ class StreamDiffusionConfig(BasePipelineConfig): description=( "TRT-compile UNet (and ControlNet on SD 1.5) for 2-8x denoising " "speedup. First build per model takes 5-10 min and caches to " - "~/.cache/scope-streamdiffusion-trt/. Set at session start; " - "changing requires pipeline reload. SD 1.5 engines support " - "dynamic resolution 256-1024 and batch 1-4. SDXL engines " - "(sdxl-turbo, dmd2-sdxl-1step) support dynamic resolution " - "512-1024 with static batch=1 — different envelope to fit a " - "24 GB VRAM build budget. SDXL + ControlNet + TRT is not yet " - "supported (raises NotImplementedError); use acceleration='none' " - "with controlnet on SDXL until that lands." + "~/.cache/scope-streamdiffusion-trt/. Hot-swappable at runtime: " + "toggling restores cached engines from process-scope cache " + "(instant) or builds them on first activation (stalls the " + "stream). SD 1.5 engines support dynamic resolution 256-1024 " + "and batch 1-4. SDXL engines (sdxl-turbo, dmd2-sdxl-1step) " + "support dynamic resolution 512-1024 with static batch=1 — " + "different envelope to fit a 24 GB VRAM build budget. SDXL + " + "ControlNet + TRT is not yet supported (raises " + "NotImplementedError); use acceleration_mode='none' with " + "controlnet on SDXL until that lands." ), - #json_schema_extra=ui_field_config(order=2, label="Acceleration"), + json_schema_extra=ui_field_config(order=2, label="Acceleration"), ) use_taesd: bool = Field( @@ -225,6 +227,19 @@ class StreamDiffusionConfig(BasePipelineConfig): json_schema_extra=ui_field_config(order=13, label="Seed"), ) + seed_transition_steps: int = Field( + default=0, + ge=0, + le=240, + description=( + "Lerp the seed noise toward the new seed over N frames on each " + "seed change. 0 = hard cut. SDXL-Turbo and DMD2-1step have less " + "natural frame-to-frame correlation than SD-Turbo; this gives a " + "deterministic settle independent of the model." + ), + json_schema_extra=ui_field_config(order=14, label="Seed Transition Steps"), + ) + # ======================================== # Diffusion Parameters # ======================================== From 51d00839e1cf03b504a4aacaf384a05bbab30f6b Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Fri, 8 May 2026 23:39:23 -0700 Subject: [PATCH 22/26] feat(loopback): per-model implicit_loopback flag for CFG-distilled models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The implicit txt2img→img2img fallback (frame 1 is t2i, frames 2+ are img2img on the previous output) gives non-CFG-distilled Turbo models their iterative-refinement look. Re-running it on DMD2 — which has high-CFG behavior baked into its weights — applies CFG-shape every frame and the feedback loop blows up within a handful of iterations. Adds an `implicit_loopback` key to MODEL_PRESETS, defaults True for unknown models, sets False for `dmd2-sdxl-1step`. Read at __init__, _ensure_pipe_loaded, and _swap_model so runtime model swaps pick up the right behavior. Explicit image_loopback=True still overrides for manual testing of the DMD2 runaway look. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/scope_streamdiffusion/pipeline.py | 38 ++++++++++++++++++++------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index 5c9c678..bf831ed 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -50,6 +50,12 @@ # DMD2 was distilled at this specific timestep — feeding it # LCMScheduler's default 1-step pick (~979) produces noise. "timesteps_override": [399], + # DMD2 has CFG distilled into its weights, so its single-shot + # output already looks like a guidance-shaped result. Implicit + # txt2img→img2img loopback re-applies that CFG-shape every frame + # and the chain blows up within a few iterations. Skip the + # implicit fallback; explicit image_loopback=True still works. + "implicit_loopback": False, }, } @@ -123,7 +129,14 @@ def __init__( config_model = getattr(self.config, "model_id_or_path", None) if self.config else None model_id = model_id or config_model or model_id_or_path or "stabilityai/sd-turbo" self.model_id = model_id - self._timesteps_override = MODEL_PRESETS.get(model_id, {}).get("timesteps_override") + preset = MODEL_PRESETS.get(model_id, {}) + self._timesteps_override = preset.get("timesteps_override") + # CFG-distilled models (DMD2, future Hyper-SD / Lightning) explode in + # the implicit txt2img→img2img loopback because each iteration re- + # applies the model's baked-in guidance shaping. Default True for + # everything else (SD-Turbo, SDXL-Turbo) since the iterative + # refinement is what gives those models their polished t2i look. + self._implicit_loopback: bool = preset.get("implicit_loopback", True) print(f"[StreamDiffusion] Tentative model: {model_id} (load deferred to first __call__)") # Model-dependent attrs are populated by ``_ensure_pipe_loaded``. @@ -274,7 +287,9 @@ def _ensure_pipe_loaded(self, model_id: str) -> None: return print(f"[StreamDiffusion] Loading model: {model_id}") self.model_id = model_id - self._timesteps_override = MODEL_PRESETS.get(model_id, {}).get("timesteps_override") + preset = MODEL_PRESETS.get(model_id, {}) + self._timesteps_override = preset.get("timesteps_override") + self._implicit_loopback = preset.get("implicit_loopback", True) self._trt_cache_key = _trt_cache.cache_key(self._node_id, model_id) self.pipe = self._load_model(model_id) print(f"[StreamDiffusion] Model loaded: {self.pipe.__class__.__name__}") @@ -931,7 +946,9 @@ def _swap_model(self, new_model_id: str) -> None: self._release_pipe_state() self.model_id = new_model_id - self._timesteps_override = MODEL_PRESETS.get(new_model_id, {}).get("timesteps_override") + preset = MODEL_PRESETS.get(new_model_id, {}) + self._timesteps_override = preset.get("timesteps_override") + self._implicit_loopback = preset.get("implicit_loopback", True) self.pipe = self._load_model(new_model_id) print(f"[StreamDiffusion] Model loaded: {self.pipe.__class__.__name__}") self.sdxl = type(self.pipe) is StableDiffusionXLPipeline @@ -2109,12 +2126,15 @@ def get_param(key, default): # This is what gives txt2img its iterative refinement: frame 1 is a # cold t2i pass and frames 2+ are img2img on the previous output, so # SD-Turbo's single-step recovery sharpens detail across frames. - # Removing the fallback (strict opt-in via image_loopback) loses the - # refinement and txt2img output stays at single-shot quality forever. - # The drift this can cause on long runs is the cost of admission. - if image_loopback or ( - (video is None or len(video) == 0) and self.prev_image_result is not None - ): + # Disabled per-model for CFG-distilled checkpoints (DMD2) where the + # baked-in guidance shaping compounds catastrophically across the + # feedback loop. Explicit image_loopback=True still wins regardless, + # so the user can force loopback on DMD2 if they want the stylized + # divergence (or for testing). + implicit_ok = self._implicit_loopback and ( + video is None or len(video) == 0 + ) and self.prev_image_result is not None + if image_loopback or implicit_ok: frame = self.prev_image_result elif video is not None and len(video) > 0: # Convert Scope tensor format to pipeline format From 2c9142761ebeba602dba3bcd51c1c81285446199 Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Fri, 8 May 2026 23:41:45 -0700 Subject: [PATCH 23/26] feat(negative): embedding-space negative subtraction for single-pass models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wires up the inert negative_prompt / negative_prompt_scale schema fields so they actually do something on Turbo / DMD2, where standard CFG isn't viable (would require a second UNet pass on a model designed for one). Approach: encode the negative prompt once, cache it, and on each frame shift prompt_embeds in the direction (pos - scale * neg). Renormalize each token back to its original L2 norm so the conditioning vector stays inside the UNet's training distribution. Raw subtraction without norm-preservation blows up the magnitude and the UNet predicts pure noise — that was the first failure mode. Same shift applied to SDXL's pooled add_text_embeds. add_time_ids stay put (size-derived, not text-derived). Cache invalidates on model swap and pipeline release because text-encoder dim changes between SD1.5 and SDXL. Schema default dropped from 1.0 to 0.5 with a description that flags 0.3-0.7 as the typical range and >1.0 as the danger zone. Effective on DMD2 (where it dampens the baked-in CFG look). Weak to ineffective on Turbo models — their distillation flattens the conditioning space, leaving little directional flex for embedding shifts to land. Behavior, not bug. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/scope_streamdiffusion/pipeline.py | 90 +++++++++++++++++++++++++++ src/scope_streamdiffusion/schema.py | 14 ++++- 2 files changed, 101 insertions(+), 3 deletions(-) diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index bf831ed..95b6236 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -262,6 +262,16 @@ def __init__( self._seed_transition_progress: int = 0 self._seed_transition_total: int = 0 + # Negative-prompt embedding cache. We don't run a second UNet pass + # for CFG (DMD2 / SD-Turbo are CFG-distilled), so "negative prompt" + # here means embedding-space subtraction: prompt_embeds -= scale * + # neg_embed before the single UNet call. Cache the encoded negative + # so we don't re-encode every frame; invalidate when the text or + # the model changes. + self._cached_negative_text: str | None = None + self._cached_negative_embed: torch.Tensor | None = None + self._cached_negative_pooled: torch.Tensor | None = None + # Mode-transition tracking — detect video↔text switches without a pipeline reload self._last_mode: str | None = None @@ -908,6 +918,10 @@ def _release_pipe_state(self) -> None: self._pooled_target = None self._seed_transition_source = None self._seed_transition_target = None + # Negative-embed cache is text-encoder-specific; swap invalidates it. + self._cached_negative_text = None + self._cached_negative_embed = None + self._cached_negative_pooled = None if hasattr(self, "embedding_blender"): try: self.embedding_blender.cancel_transition() @@ -971,6 +985,9 @@ def _swap_model(self, new_model_id: str) -> None: self._prompts_key = None self._cached_base_embed = None self._previous_prompt_embeddings = None + self._cached_negative_text = None + self._cached_negative_embed = None + self._cached_negative_pooled = None self.prev_image_result = None self._last_transition_id = None self._pooled_source = None @@ -1428,6 +1445,71 @@ def _encode_single_prompt( return prompt_embeds, pooled_embeds + def _apply_negative_subtraction( + self, + negative_prompt: str, + negative_prompt_scale: float, + ) -> None: + """Norm-preserving negative subtraction in embedding space. + + Single-pass models (Turbo, DMD2) can't use standard CFG without + doubling UNet cost. Embedding subtraction is the cheap alternative, + but raw ``pos - scale * neg`` blows up the L2 norm of each token, + knocking the conditioning out of the training distribution and the + UNet predicts pure noise. + + We do the subtraction directionally and then renormalize each + token's embedding back to the original L2 norm. Result: direction + shifts away from the negative concept, magnitude is preserved. + Same treatment applied to SDXL's pooled ``add_text_embeds``. + + ``add_time_ids`` are positional / size-derived, not text-derived, + so they stay put. + + Encoded negative is cached on text; empty text or scale 0 is a + no-op. Cache invalidates on model swap (text-encoder dim changes). + """ + if negative_prompt_scale <= 0 or not negative_prompt: + return + if ( + self._cached_negative_text != negative_prompt + or self._cached_negative_embed is None + ): + neg_embed, neg_pooled = self._encode_single_prompt(negative_prompt) + self._cached_negative_text = negative_prompt + self._cached_negative_embed = neg_embed.detach() + self._cached_negative_pooled = ( + neg_pooled.detach() if neg_pooled is not None else None + ) + + self.prompt_embeds = self._norm_preserving_subtract( + self.prompt_embeds, self._cached_negative_embed, negative_prompt_scale + ) + if self.sdxl and self._cached_negative_pooled is not None: + self.add_text_embeds = self._norm_preserving_subtract( + self.add_text_embeds, + self._cached_negative_pooled, + negative_prompt_scale, + ) + + @staticmethod + def _norm_preserving_subtract( + positive: torch.Tensor, negative: torch.Tensor, scale: float + ) -> torch.Tensor: + """Subtract ``scale * negative`` from ``positive`` then rescale to + match positive's original per-row L2 norm. Direction shifts, + magnitude is preserved, UNet stays inside training distribution. + """ + neg = negative.to(device=positive.device, dtype=positive.dtype) + if neg.shape[0] != positive.shape[0]: + neg = neg[:1].expand_as(positive) + # Per-token (or per-row for pooled) norm preservation: keep an + # epsilon to avoid /0 for any zero-magnitude rows. + orig_norm = positive.norm(dim=-1, keepdim=True) + shifted = positive - scale * neg + new_norm = shifted.norm(dim=-1, keepdim=True).clamp(min=1e-6) + return shifted * (orig_norm / new_norm) + def _encode_prompts_array( self, prompt_items: list[dict], @@ -2020,6 +2102,8 @@ def get_param(key, default): do_add_noise = get_param("do_add_noise", True) similar_image_filter_enabled = get_param("similar_image_filter_enabled", False) image_loopback = get_param("image_loopback", False) + negative_prompt = get_param("negative_prompt", "") + negative_prompt_scale = float(get_param("negative_prompt_scale", 1.0)) controlnet_mode = get_param("controlnet_mode", "none") controlnet_scale = get_param("controlnet_scale", 1.0) controlnet_temporal_smoothing = get_param("controlnet_temporal_smoothing", 0.5) @@ -2119,6 +2203,12 @@ def get_param(key, default): seed_transition_steps=seed_transition_steps, ) + # Apply embedding-space negative subtraction *after* prompt embeds + # are settled (including any prompt transition / SDXL pooled + # update). Acts on whatever this frame's conditioning happens to + # be, which is the right thing during transitions too. + self._apply_negative_subtraction(negative_prompt, negative_prompt_scale) + frame = None # Process input. In text-only mode (no video stream) we fall back to diff --git a/src/scope_streamdiffusion/schema.py b/src/scope_streamdiffusion/schema.py index c5284cb..9051f0a 100644 --- a/src/scope_streamdiffusion/schema.py +++ b/src/scope_streamdiffusion/schema.py @@ -195,11 +195,19 @@ class StreamDiffusionConfig(BasePipelineConfig): ) negative_prompt_scale: float = Field( - default=1.0, + default=0.5, ge=0.0, le=2.0, - description="Strength of embedding-space negative guidance (used when guidance_scale=0). Subtracts the negative prompt embedding from the positive. 0 = disabled, 1 = full subtraction.", - #json_schema_extra=ui_field_config(order=12, label="Negative Scale"), + description=( + "Strength of embedding-space negative guidance for single-pass " + "models (Turbo, DMD2) that can't use standard CFG. The negative " + "embedding is subtracted from the positive then rescaled to " + "preserve magnitude — direction shifts but the result stays in " + "the UNet's training distribution. 0 = disabled. 0.3-0.7 is " + "typical; >1.0 starts to push out of distribution and outputs " + "may degrade or go to noise." + ), + json_schema_extra=ui_field_config(order=12, label="Negative Scale"), ) prompt_interpolation_method: Literal["linear", "slerp"] = Field( From 90d1c1f31050e130232e855bc06007f72a6ddfaa Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Fri, 8 May 2026 23:41:58 -0700 Subject: [PATCH 24/26] fix(trt): defensive re-activate of TAESD engines if context is lost Hit a runtime crash in the TAESD decoder where Engine.context was None mid-stream, killing 10 consecutive frames and escalating to fatal. Engine.context is set by activate() and only initialized to None in __init__, so the only way it's None at decode time is if activate() was never run (or its effect was torn down by a sibling teardown path while the cached adapter survived in module-scope _trt_cache). Adds an _ensure_activated check to AutoencoderKLEngine.encode/decode: if context is None, re-load and re-activate the engine before allocate_buffers. Keeps the streaming loop alive instead of dying on an AttributeError. Band-aid, not a root-cause fix. The underlying lifecycle bug is likely a race in acceleration_mode / model swap teardown vs the process-scope adapter cache. Worth investigating if it recurs in a way this band-aid doesn't catch. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/scope_streamdiffusion/_trt/engine.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/scope_streamdiffusion/_trt/engine.py b/src/scope_streamdiffusion/_trt/engine.py index 375c6cf..640c29b 100644 --- a/src/scope_streamdiffusion/_trt/engine.py +++ b/src/scope_streamdiffusion/_trt/engine.py @@ -287,7 +287,24 @@ def __init__( self.encoder.activate() self.decoder.activate() + def _ensure_activated(self, engine): + """Re-activate the TRT engine if its execution context was lost. + + The cached adapter survives across plugin reinit (module-scope + cache in ``_trt_cache``), but rapid acceleration_mode / model + swaps can leave a previously-activated engine with context=None + if its activation was torn down by a sibling teardown path. + Repairing here keeps the streaming loop alive instead of dying + on `AttributeError: 'NoneType' object has no attribute + 'set_input_shape'`. + """ + if engine.context is None: + if engine.engine is None: + engine.load() + engine.activate() + def encode(self, images: torch.Tensor, **kwargs): + self._ensure_activated(self.encoder) self.encoder.allocate_buffers( shape_dict={ "images": images.shape, @@ -308,6 +325,7 @@ def encode(self, images: torch.Tensor, **kwargs): return AutoencoderTinyOutput(latents=latents) def decode(self, latent: torch.Tensor, **kwargs): + self._ensure_activated(self.decoder) self.decoder.allocate_buffers( shape_dict={ "latent": latent.shape, From b1b5478d2b44af58c0f9a496bbd854f3825f6536 Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Sat, 9 May 2026 00:00:11 -0700 Subject: [PATCH 25/26] refactor: extract PromptEncoder into its own module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit First step in breaking up pipeline.py (2393 lines and growing). The prompt-encoding concern is the most self-contained: it has clear inputs (prompts, dims, transition state), clear outputs (prompt_embeds plus SDXL add_text_embeds / add_time_ids), and only depends on the loaded pipe and a small set of caches. Pattern follows ControlNetHandler — a real helper class the pipeline holds as a member, no mixins, no inheritance gymnastics. Pipeline calls self.prompts.encode_for_frame(...) once per __call__ and reads the resulting embeds via self.prompts.{prompt_embeds, add_text_embeds, add_time_ids}. Inference paths updated for the new access path. Moved out of pipeline.py: - EmbeddingBlender ownership and lifecycle - _normalize_prompts, _encode_single_prompt, _encode_prompts_array - _make_prompts_key, _hash_transition, _begin_transition, _advance_pooled_transition, _finish_pooled_transition - _apply_negative_subtraction, _norm_preserving_subtract - _get_add_time_ids (SDXL aug-conditioning, only consumed by the encoder path) - All associated cache state: _cached_base_embed, _previous_prompt_embeddings, _prompts_key, _last_transition_id, _pooled_source/_target, _transition_total_steps, _cached_negative_* PromptEncoder.attach(pipe, sdxl) is the single seam that gets called from _ensure_pipe_loaded and _swap_model after the new pipe is available; it handles cache invalidation since text-encoder hidden dim changes between SD 1.5 and SDXL. _release_pipe_state calls prompts.reset_caches() instead of nulling individual fields. pipeline.py: 2393 -> 1926 lines (20% smaller). prompt_encoder.py: 527. No behavior change — same encode flow, same transition semantics, same negative-subtraction math. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/scope_streamdiffusion/pipeline.py | 543 ++------------------ src/scope_streamdiffusion/prompt_encoder.py | 527 +++++++++++++++++++ 2 files changed, 565 insertions(+), 505 deletions(-) create mode 100644 src/scope_streamdiffusion/prompt_encoder.py diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index 95b6236..848333f 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -16,10 +16,10 @@ retrieve_latents, ) from scope.core.pipelines.interface import Pipeline, Requirements -from scope.core.pipelines.blending import EmbeddingBlender, parse_transition_config from . import _trt_cache from .controlnet import ControlNetHandler +from .prompt_encoder import PromptEncoder, normalize_prompts from .schema import StreamDiffusionConfig if TYPE_CHECKING: @@ -199,15 +199,15 @@ def __init__( # Scheduler / image_processor are model-dependent — populated by # ``_ensure_pipe_loaded`` on the first __call__. - # Setup embedding blender for prompt weighting and interpolation - self.embedding_blender = EmbeddingBlender( - device=self.device, - dtype=self.dtype, - ) + # Prompt encoding (text-encode, blending, transitions, negative + # subtraction) lives on its own helper. ``attach()`` wires it to + # the live pipe at load time and on every model swap. Inference + # reads ``self.prompts.prompt_embeds`` / ``add_text_embeds`` / + # ``add_time_ids`` directly. + self.prompts = PromptEncoder(self.device, self.dtype) # State that will be set during runtime self.generator = torch.Generator(device=self.device) - self._previous_prompt_embeddings = None self.similar_filter = SimilarImageFilter() self.prev_image_result = None self.inference_time_ema = 0 @@ -241,16 +241,6 @@ def __init__( ) self._last_seed: int | None = None self._noise_shape: tuple | None = None # (batch_size, latent_h, latent_w) - self._prompts_key: tuple | None = None - self._cached_base_embed: torch.Tensor | None = None # (1, seq_len, hidden_dim) - - # Transition state — the main embedding queue lives inside - # EmbeddingBlender; the pooled embedding (SDXL only) is interpolated - # linearly in lockstep here so `add_text_embeds` tracks the morph. - self._last_transition_id: str | None = None - self._pooled_source: torch.Tensor | None = None - self._pooled_target: torch.Tensor | None = None - self._transition_total_steps: int = 0 # Seed transition state — when seed_transition_steps > 0, lerp # `init_noise` from the previous seed's tensor to the new seed's @@ -262,16 +252,6 @@ def __init__( self._seed_transition_progress: int = 0 self._seed_transition_total: int = 0 - # Negative-prompt embedding cache. We don't run a second UNet pass - # for CFG (DMD2 / SD-Turbo are CFG-distilled), so "negative prompt" - # here means embedding-space subtraction: prompt_embeds -= scale * - # neg_embed before the single UNet call. Cache the encoded negative - # so we don't re-encode every frame; invalidate when the text or - # the model changes. - self._cached_negative_text: str | None = None - self._cached_negative_embed: torch.Tensor | None = None - self._cached_negative_pooled: torch.Tensor | None = None - # Mode-transition tracking — detect video↔text switches without a pipeline reload self._last_mode: str | None = None @@ -315,6 +295,7 @@ def _ensure_pipe_loaded(self, model_id: str) -> None: self._using_taesd = False self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config) self.image_processor = VaeImageProcessor(self.pipe.vae_scale_factor) + self.prompts.attach(self.pipe, self.sdxl) if self._acceleration_mode == "trt": self._setup_trt(**self._trt_setup_args_from_config()) @@ -898,9 +879,6 @@ def _release_pipe_state(self) -> None: "stock_noise", "x_t_latent_buffer", "prev_image_result", - "prompt_embeds", - "add_text_embeds", - "add_time_ids", "alpha_prod_t_sqrt", "beta_prod_t_sqrt", "c_skip", @@ -911,22 +889,14 @@ def _release_pipe_state(self) -> None: if hasattr(self, attr): setattr(self, attr, None) - # Embedding / transition caches. - self._cached_base_embed = None - self._previous_prompt_embeddings = None - self._pooled_source = None - self._pooled_target = None + # Reset prompt-encoder caches (text-encoder-specific; the new + # model will have a different text encoder). + if hasattr(self, "prompts"): + self.prompts.reset_caches() + + # Seed-transition state. self._seed_transition_source = None self._seed_transition_target = None - # Negative-embed cache is text-encoder-specific; swap invalidates it. - self._cached_negative_text = None - self._cached_negative_embed = None - self._cached_negative_pooled = None - if hasattr(self, "embedding_blender"): - try: - self.embedding_blender.cancel_transition() - except Exception: - pass # Drop the pipeline last so any of the above that aliased its # submodules have already been nulled. @@ -976,28 +946,14 @@ def _swap_model(self, new_model_id: str) -> None: self._using_taesd = False self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config) self.image_processor = VaeImageProcessor(self.pipe.vae_scale_factor) + self.prompts.attach(self.pipe, self.sdxl) - # Invalidate runtime caches so the next __call__ rebuilds prompt - # embeddings, timestep schedule, and noise buffers against the new - # model — text encoder + UNet config differ between SD1.5 and SDXL. + # Invalidate runtime caches so the next __call__ rebuilds the + # timestep schedule and noise buffers against the new model. + # Prompt-encoder caches are reset by ``prompts.attach()`` above. self._schedule_key = None self._noise_shape = None - self._prompts_key = None - self._cached_base_embed = None - self._previous_prompt_embeddings = None - self._cached_negative_text = None - self._cached_negative_embed = None - self._cached_negative_pooled = None self.prev_image_result = None - self._last_transition_id = None - self._pooled_source = None - self._pooled_target = None - self._transition_total_steps = 0 - if hasattr(self, "embedding_blender"): - try: - self.embedding_blender.cancel_transition() - except Exception: - pass self._cancel_seed_transition() # Build TRT engines for the new model now so the next frame doesn't stall. @@ -1171,418 +1127,20 @@ def _prepare_runtime_state( self._advance_seed_transition() # --- Prompt embeddings & transitions --- - # The key includes spatial dims for SDXL because add_time_ids depend on them. - # When an explicit transition dict is present, its target_prompts is the - # authoritative destination; keying against the incoming source prompts - # would make prompts_changed flap during/after the transition and snap - # steady state back to the source. - key_prompts = prompts - if transition is not None: - target_raw = transition.get("target_prompts") - if target_raw: - key_prompts = self._normalize_prompts(target_raw) - new_prompts_key = self._make_prompts_key( - key_prompts, prompt_interpolation_method, width, height - ) - prompts_changed = new_prompts_key != self._prompts_key - - # Hash the explicit transition dict so repeated sends don't restart it. - transition_id = self._hash_transition(transition) if transition else None - new_explicit_transition = ( - transition_id is not None and transition_id != self._last_transition_id - ) - - started_transition = False - - # Cancel any in-flight transition if a new target has arrived so we - # redirect from the current interpolated position rather than snapping - # after the old transition drains. - if self.embedding_blender.is_transitioning() and ( - new_explicit_transition - or (transition is None and transition_steps > 0 and prompts_changed) - ): - self.embedding_blender.cancel_transition() - self._finish_pooled_transition() - - # 1) Explicit transition (transition dict with target_prompts). - if new_explicit_transition and not self.embedding_blender.is_transitioning(): - transition_config = parse_transition_config(transition) - target_prompts_raw = transition.get("target_prompts", []) - if transition_config.num_steps > 0 and target_prompts_raw: - target_prompts = self._normalize_prompts(target_prompts_raw) - started_transition = self._begin_transition( - target_prompts=target_prompts, - interpolation_method=prompt_interpolation_method, - num_steps=transition_config.num_steps, - temporal_method=transition_config.temporal_interpolation_method, - width=width, - height=height, - ) - self._last_transition_id = transition_id - - # 2) Auto-transition when `prompts` changes with transition_steps > 0. - elif ( - transition is None - and transition_steps > 0 - and prompts_changed - and self._previous_prompt_embeddings is not None - and not self.embedding_blender.is_transitioning() - ): - started_transition = self._begin_transition( - target_prompts=prompts, - interpolation_method=prompt_interpolation_method, - num_steps=transition_steps, - temporal_method=prompt_interpolation_method, - width=width, - height=height, - ) - - # --- Produce prompt_embeds for this frame --- - if self.embedding_blender.is_transitioning(): - next_embedding = self.embedding_blender.get_next_embedding() - if next_embedding is not None: - self.prompt_embeds = next_embedding.repeat(self.batch_size, 1, 1) - self._advance_pooled_transition() - else: - self.prompt_embeds = self._cached_base_embed.repeat( - self.batch_size, 1, 1 - ) - self._finish_pooled_transition() - else: - # Steady state — re-encode if prompts changed and we didn't start a - # transition for it (hard cut path, e.g. transition_steps == 0). - if prompts_changed and not started_transition: - raw_embeds, _ = self._encode_prompts_array( - key_prompts, prompt_interpolation_method - ) - self._cached_base_embed = raw_embeds[0:1] - self._prompts_key = new_prompts_key - # Drop the transition-id guard once the explicit dict is gone so a - # later identical dict is treated as a fresh request. - if transition is None: - self._last_transition_id = None - self._finish_pooled_transition() - self.prompt_embeds = self._cached_base_embed.repeat(self.batch_size, 1, 1) - - # Cache embedding as source for the next transition. - self._previous_prompt_embeddings = self.prompt_embeds[0:1].detach() - - def _make_prompts_key( - self, - prompts: list[dict], - interpolation_method: str, - width: int, - height: int, - ) -> tuple: - """Identity key for a prompts payload; SDXL includes dims for add_time_ids.""" - return ( - tuple((p.get("text", ""), p.get("weight", 1.0)) for p in prompts), - interpolation_method, - (width, height) if self.sdxl else (), - ) - - @staticmethod - def _hash_transition(transition: dict) -> str: - """Stable identity for a transition dict so repeated sends don't restart it.""" - import hashlib - import json - - payload = { - "num_steps": int(transition.get("num_steps", 0) or 0), - "method": transition.get("temporal_interpolation_method", "linear"), - "target": [ - { - "text": p.get("text", "") if isinstance(p, dict) else str(p), - "weight": float(p.get("weight", 1.0)) if isinstance(p, dict) else 1.0, - } - for p in (transition.get("target_prompts") or []) - ], - } - encoded = json.dumps(payload, sort_keys=True).encode("utf-8") - return hashlib.sha1(encoded).hexdigest() - - def _begin_transition( - self, - target_prompts: list[dict], - interpolation_method: str, - num_steps: int, - temporal_method: str, - width: int, - height: int, - ) -> bool: - """Start a temporal transition from the last emitted embedding toward - the target prompts. Eagerly advances `_cached_base_embed` and - `_prompts_key` to the target so steady state lands there when the queue - drains. Returns True if a transition was actually started. - """ - source_embedding = self._previous_prompt_embeddings - if source_embedding is None: - return False - - # Encode and blend target in main embedding space + pooled (SDXL). - target_embed, target_pooled = self._encode_prompts_array( - target_prompts, interpolation_method, apply_sdxl_conditioning=False - ) - target_embed_single = target_embed[0:1] - - # Eagerly move the steady-state cache to the target so once the queue - # drains we land on the target prompts with no bounce-back. - self._cached_base_embed = target_embed_single - self._prompts_key = self._make_prompts_key( - target_prompts, interpolation_method, width, height - ) - - # Slerp is not supported here: upstream EmbeddingBlender.slerp runs - # torch.acos on the native dtype; at fp16 the [-1, 1] clamp isn't - # enough to prevent acos(1.0) → NaN at certain token positions, which - # nukes the whole conditioning tensor. Until that's fixed upstream, - # fall back to linear and warn once. - if temporal_method == "slerp": - if not getattr(self, "_slerp_fallback_warned", False): - print( - "[StreamDiffusion] slerp temporal interpolation is not " - "supported (fp16 NaN in upstream blender); falling back " - "to linear." - ) - self._slerp_fallback_warned = True - temporal_method = "linear" - - self.embedding_blender.start_transition( - source_embedding=source_embedding, - target_embedding=target_embed_single, - num_steps=num_steps, - temporal_interpolation_method=temporal_method, - ) - - # Pooled interpolation runs in lockstep with the main queue for SDXL. - if self.sdxl and target_pooled is not None: - self._pooled_source = ( - self.add_text_embeds.detach().clone() - if hasattr(self, "add_text_embeds") and self.add_text_embeds is not None - else target_pooled.clone() - ) - self._pooled_target = target_pooled.clone() - self._transition_total_steps = max(1, num_steps) - else: - self._pooled_source = None - self._pooled_target = None - self._transition_total_steps = 0 - - # start_transition short-circuits when source ≈ target - # (MIN_EMBEDDING_DIFF_THRESHOLD); report accurately so the caller falls - # to steady state instead of assuming a transition is live. - if not self.embedding_blender.is_transitioning(): - self._finish_pooled_transition() - return False - return True - - def _advance_pooled_transition(self) -> None: - """Linearly interpolate `add_text_embeds` toward the target pooled. - - Uses the blender's remaining queue length to compute progress so - pooled and main embeds stay in lockstep even if start_transition - short-circuited. - """ - if not self.sdxl or self._pooled_target is None: - return - if self._transition_total_steps <= 0: - return - remaining = len(self.embedding_blender._transition_queue) - done_steps = self._transition_total_steps - remaining - t = min(1.0, max(0.0, done_steps / self._transition_total_steps)) - source = ( - self._pooled_source - if self._pooled_source is not None - else self._pooled_target - ) - self.add_text_embeds = torch.lerp(source, self._pooled_target, t).to( - dtype=self.dtype, device=self.device - ) - - def _finish_pooled_transition(self) -> None: - """Snap pooled to the target and clear transition state.""" - if self.sdxl and self._pooled_target is not None: - self.add_text_embeds = self._pooled_target.to( - dtype=self.dtype, device=self.device - ) - self._pooled_source = None - self._pooled_target = None - self._transition_total_steps = 0 - - @staticmethod - def _normalize_prompts(prompts: str | list[str] | list[dict]) -> list[dict]: - """Normalize prompts to list[dict] format.""" - if isinstance(prompts, str): - return [{"text": prompts, "weight": 1.0}] - if isinstance(prompts, list): - if len(prompts) == 0: - return [{"text": "", "weight": 1.0}] - # Check if it's a list of strings - if isinstance(prompts[0], str): - return [{"text": text, "weight": 1.0} for text in prompts] - # Already list[dict] - return prompts - return [{"text": str(prompts), "weight": 1.0}] - - def _encode_single_prompt( - self, prompt_text: str - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - """Encode a single prompt string to embeddings. - - Returns: - (prompt_embeds, pooled_embeds) tuple - """ - # Use diffusers' built-in encoding - encoder_output = self.pipe.encode_prompt( - prompt=prompt_text, - device=self.device, - num_images_per_prompt=1, - do_classifier_free_guidance=False, - negative_prompt=None, - ) - prompt_embeds = encoder_output[0] # [1, seq_len, hidden_dim] - pooled_embeds = encoder_output[2] if self.sdxl else None - - return prompt_embeds, pooled_embeds - - def _apply_negative_subtraction( - self, - negative_prompt: str, - negative_prompt_scale: float, - ) -> None: - """Norm-preserving negative subtraction in embedding space. - - Single-pass models (Turbo, DMD2) can't use standard CFG without - doubling UNet cost. Embedding subtraction is the cheap alternative, - but raw ``pos - scale * neg`` blows up the L2 norm of each token, - knocking the conditioning out of the training distribution and the - UNet predicts pure noise. - - We do the subtraction directionally and then renormalize each - token's embedding back to the original L2 norm. Result: direction - shifts away from the negative concept, magnitude is preserved. - Same treatment applied to SDXL's pooled ``add_text_embeds``. - - ``add_time_ids`` are positional / size-derived, not text-derived, - so they stay put. - - Encoded negative is cached on text; empty text or scale 0 is a - no-op. Cache invalidates on model swap (text-encoder dim changes). - """ - if negative_prompt_scale <= 0 or not negative_prompt: - return - if ( - self._cached_negative_text != negative_prompt - or self._cached_negative_embed is None - ): - neg_embed, neg_pooled = self._encode_single_prompt(negative_prompt) - self._cached_negative_text = negative_prompt - self._cached_negative_embed = neg_embed.detach() - self._cached_negative_pooled = ( - neg_pooled.detach() if neg_pooled is not None else None - ) - - self.prompt_embeds = self._norm_preserving_subtract( - self.prompt_embeds, self._cached_negative_embed, negative_prompt_scale - ) - if self.sdxl and self._cached_negative_pooled is not None: - self.add_text_embeds = self._norm_preserving_subtract( - self.add_text_embeds, - self._cached_negative_pooled, - negative_prompt_scale, - ) - - @staticmethod - def _norm_preserving_subtract( - positive: torch.Tensor, negative: torch.Tensor, scale: float - ) -> torch.Tensor: - """Subtract ``scale * negative`` from ``positive`` then rescale to - match positive's original per-row L2 norm. Direction shifts, - magnitude is preserved, UNet stays inside training distribution. - """ - neg = negative.to(device=positive.device, dtype=positive.dtype) - if neg.shape[0] != positive.shape[0]: - neg = neg[:1].expand_as(positive) - # Per-token (or per-row for pooled) norm preservation: keep an - # epsilon to avoid /0 for any zero-magnitude rows. - orig_norm = positive.norm(dim=-1, keepdim=True) - shifted = positive - scale * neg - new_norm = shifted.norm(dim=-1, keepdim=True).clamp(min=1e-6) - return shifted * (orig_norm / new_norm) - - def _encode_prompts_array( - self, - prompt_items: list[dict], - interpolation_method: str = "linear", - apply_sdxl_conditioning: bool = True, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - """Encode multiple weighted prompts and blend them. - - Args: - prompt_items: List of {"text": str, "weight": float} - interpolation_method: "linear" or "slerp" - apply_sdxl_conditioning: When True (default, steady-state encode), - also updates `self.add_text_embeds` and `self.add_time_ids` - for SDXL. Set False when encoding a transition target so the - in-flight pooled/time_ids aren't overwritten mid-morph. - - Returns: - (blended_prompt_embeds, blended_pooled_embeds) tuple - """ - if not prompt_items: - prompt_items = [{"text": "", "weight": 1.0}] - - # Extract texts and weights - texts = [item.get("text", "") for item in prompt_items] - weights = [item.get("weight", 1.0) for item in prompt_items] - - # Encode each prompt - all_prompt_embeds = [] - all_pooled_embeds = [] if self.sdxl else None - - for text in texts: - prompt_embeds, pooled_embeds = self._encode_single_prompt(text) - all_prompt_embeds.append(prompt_embeds) - if self.sdxl and pooled_embeds is not None: - all_pooled_embeds.append(pooled_embeds) - - # Blend embeddings - blended_prompt_embeds = self.embedding_blender.blend( - all_prompt_embeds, - weights, - interpolation_method, - cache_result=True, + # All prompt encoding, blending, transition handling, and SDXL aug- + # conditioning lives in the PromptEncoder helper. After this call, + # ``self.prompts.prompt_embeds`` (and add_text_embeds / add_time_ids + # for SDXL) holds the conditioning for this frame. + self.prompts.encode_for_frame( + prompts=prompts, + interpolation_method=prompt_interpolation_method, + width=width, + height=height, + batch_size=self.batch_size, + transition=transition, + transition_steps=transition_steps, ) - blended_pooled_embeds = None - if self.sdxl and all_pooled_embeds: - blended_pooled_embeds = self.embedding_blender.blend( - all_pooled_embeds, - weights, - interpolation_method, - cache_result=False, - ) - - # Handle SDXL additional embeddings (skipped for transition-target - # encoding so the live pooled/time_ids aren't overwritten mid-morph). - if apply_sdxl_conditioning and self.sdxl and blended_pooled_embeds is not None: - self.add_text_embeds = blended_pooled_embeds - original_size = (self.height, self.width) - crops_coords_top_left = (0, 0) - target_size = (self.height, self.width) - text_encoder_projection_dim = int(self.add_text_embeds.shape[-1]) - self.add_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - dtype=self.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - - return blended_prompt_embeds.repeat( - self.batch_size, 1, 1 - ), blended_pooled_embeds - def _set_timesteps(self, num_inference_steps: int, strength: float): """Set the timesteps for the diffusion process. @@ -1768,31 +1326,6 @@ def _cancel_seed_transition(self) -> None: self._seed_transition_progress = 0 self._seed_transition_total = 0 - def _get_add_time_ids( - self, - original_size, - crops_coords_top_left, - target_size, - dtype, - text_encoder_projection_dim=None, - ): - """Get additional time IDs for SDXL.""" - add_time_ids = list(original_size + crops_coords_top_left + target_size) - - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) - + text_encoder_projection_dim - ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, " - f"but a vector of {passed_add_embed_dim} was created." - ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids def _encode_image( self, image_tensors: torch.Tensor, add_noise: bool = True @@ -1883,7 +1416,7 @@ def _unet_step( down_block_res_samples, mid_block_res_sample = self.controlnet( x_t_latent_plus_uc, t_list, - encoder_hidden_states=self.prompt_embeds, + encoder_hidden_states=self.prompts.prompt_embeds, controlnet_cond=cond_image, conditioning_scale=self.controlnet_conditioning_scale, return_dict=False, @@ -1892,7 +1425,7 @@ def _unet_step( model_pred = self.unet( x_t_latent_plus_uc, t_list, - encoder_hidden_states=self.prompt_embeds, + encoder_hidden_states=self.prompts.prompt_embeds, added_cond_kwargs=added_cond_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, @@ -1945,8 +1478,8 @@ def _predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: t_list = self.sub_timesteps_tensor if self.sdxl: batch = x_t_latent.shape[0] - te = self.add_text_embeds.to(self.device) - ti = self.add_time_ids.to(self.device) + te = self.prompts.add_text_embeds.to(self.device) + ti = self.prompts.add_time_ids.to(self.device) if te.shape[0] != batch: te = te[:1].expand(batch, -1) if ti.shape[0] != batch: @@ -1971,8 +1504,8 @@ def _predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: ) if self.sdxl: added_cond_kwargs = { - "text_embeds": self.add_text_embeds.to(self.device), - "time_ids": self.add_time_ids.to(self.device), + "text_embeds": self.prompts.add_text_embeds.to(self.device), + "time_ids": self.prompts.add_time_ids.to(self.device), } x_0_pred, _model_pred = self._unet_step( x_t_latent, t, idx=idx, added_cond_kwargs=added_cond_kwargs @@ -2043,7 +1576,7 @@ def __call__(self, **kwargs) -> dict: prompts = kwargs.get("prompts", []) # Normalize to list[dict] format prompts = ( - self._normalize_prompts(prompts) + normalize_prompts(prompts) if prompts else [{"text": "", "weight": 1.0}] ) @@ -2207,7 +1740,7 @@ def get_param(key, default): # are settled (including any prompt transition / SDXL pooled # update). Acts on whatever this frame's conditioning happens to # be, which is the right thing during transitions too. - self._apply_negative_subtraction(negative_prompt, negative_prompt_scale) + self.prompts.apply_negative_subtraction(negative_prompt, negative_prompt_scale) frame = None diff --git a/src/scope_streamdiffusion/prompt_encoder.py b/src/scope_streamdiffusion/prompt_encoder.py new file mode 100644 index 0000000..0ff3b1d --- /dev/null +++ b/src/scope_streamdiffusion/prompt_encoder.py @@ -0,0 +1,527 @@ +"""Prompt encoding, blending, transitions, and negative subtraction. + +Owns everything text-encoder-related so the main pipeline doesn't have to. +The pipeline holds an instance as ``self.prompts`` and calls +``encode_for_frame()`` once per ``__call__``, then optionally +``apply_negative_subtraction()``. Inference reads the produced embeds via +``self.prompts.prompt_embeds`` / ``add_text_embeds`` / ``add_time_ids``. + +Lifecycle: ``attach(pipe, sdxl)`` after a model load (or model swap) wires +us to the live pipeline and resets all text-encoder-dependent caches. +``reset_caches()`` is the lighter version called during teardown without +re-attaching. +""" + +from __future__ import annotations + +import hashlib +import json +from typing import Any, Optional + +import torch + +from scope.core.pipelines.blending import EmbeddingBlender, parse_transition_config + + +def normalize_prompts(prompts: str | list[str] | list[dict]) -> list[dict]: + """Coerce a prompts payload into ``list[{"text": str, "weight": float}]``. + + Module-level so the pipeline can call it on raw kwargs before any + PromptEncoder method that expects normalized input. + """ + if isinstance(prompts, str): + return [{"text": prompts, "weight": 1.0}] + if isinstance(prompts, list): + if len(prompts) == 0: + return [{"text": "", "weight": 1.0}] + if isinstance(prompts[0], str): + return [{"text": text, "weight": 1.0} for text in prompts] + return prompts + return [{"text": str(prompts), "weight": 1.0}] + + +class PromptEncoder: + """Per-frame prompt encoding with transitions, caching, and negative + subtraction. + + Attach to a loaded ``DiffusionPipeline`` via ``attach(pipe, sdxl)``; + re-attach on every model swap because the text encoder identity (and + hidden dim, for SDXL) changes between SD 1.5 and SDXL. + """ + + def __init__(self, device: torch.device, dtype: torch.dtype) -> None: + self.device = device + self.dtype = dtype + + # Live pipe references — set via ``attach()``. Until then the encoder + # is inert; calling encode_for_frame would raise. + self.pipe: Any = None + self.sdxl: bool = False + + self.embedding_blender = EmbeddingBlender(device=device, dtype=dtype) + + # Current-frame outputs the inference path reads. Inference accesses + # ``self.prompts.prompt_embeds`` / ``add_text_embeds`` / ``add_time_ids``. + self.prompt_embeds: Optional[torch.Tensor] = None + self.add_text_embeds: Optional[torch.Tensor] = None + self.add_time_ids: Optional[torch.Tensor] = None + + # Per-text-encoder caches. All invalidate on attach() and reset_caches(). + self._cached_base_embed: Optional[torch.Tensor] = None + self._previous_prompt_embeddings: Optional[torch.Tensor] = None + self._prompts_key: Optional[tuple] = None + + # Negative-prompt cache. + self._cached_negative_text: Optional[str] = None + self._cached_negative_embed: Optional[torch.Tensor] = None + self._cached_negative_pooled: Optional[torch.Tensor] = None + + # Pooled (SDXL) transition state — main embedding queue lives in + # ``embedding_blender``; pooled is interpolated linearly in lockstep. + self._pooled_source: Optional[torch.Tensor] = None + self._pooled_target: Optional[torch.Tensor] = None + self._transition_total_steps: int = 0 + + # Transition-id guard so repeated identical explicit transition dicts + # don't restart the transition every frame. + self._last_transition_id: Optional[str] = None + + # One-shot warning when slerp is requested for temporal interpolation + # (fp16 NaN bug in upstream blender). We fall back to linear silently + # after the first warn. + self._slerp_fallback_warned: bool = False + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def attach(self, pipe: Any, sdxl: bool) -> None: + """Wire to a loaded pipeline and reset all caches. + + Call from ``_ensure_pipe_loaded`` and ``_swap_model`` after the new + pipe is available; SD 1.5 and SDXL have different text-encoder + hidden dims, so cached embeds from the prior model would mismatch. + """ + self.pipe = pipe + self.sdxl = sdxl + self.reset_caches() + + def reset_caches(self) -> None: + """Drop every cached tensor and cancel any in-flight transition.""" + self._cached_base_embed = None + self._previous_prompt_embeddings = None + self._prompts_key = None + self._cached_negative_text = None + self._cached_negative_embed = None + self._cached_negative_pooled = None + self._pooled_source = None + self._pooled_target = None + self._transition_total_steps = 0 + self._last_transition_id = None + try: + self.embedding_blender.cancel_transition() + except Exception: + pass + + # ------------------------------------------------------------------ + # Per-frame encode + # ------------------------------------------------------------------ + + def encode_for_frame( + self, + prompts: list[dict], + interpolation_method: str, + width: int, + height: int, + batch_size: int, + transition: Optional[dict] = None, + transition_steps: int = 0, + ) -> None: + """Update ``self.prompt_embeds`` (and SDXL extras) for this frame. + + Handles: prompts-changed re-encoding, explicit transition-dict + starts, auto-transitions on prompt change, blender advance for + in-flight transitions, and pooled (SDXL) lockstep lerp. + """ + # When an explicit transition dict is present, its target_prompts is + # the authoritative destination; keying against the source prompts + # would make prompts_changed flap during/after the transition and + # snap steady state back to the source. + key_prompts = prompts + if transition is not None: + target_raw = transition.get("target_prompts") + if target_raw: + key_prompts = normalize_prompts(target_raw) + new_prompts_key = self._make_prompts_key( + key_prompts, interpolation_method, width, height + ) + prompts_changed = new_prompts_key != self._prompts_key + + transition_id = self._hash_transition(transition) if transition else None + new_explicit_transition = ( + transition_id is not None and transition_id != self._last_transition_id + ) + + started_transition = False + + # Cancel any in-flight transition if a new target has arrived so we + # redirect from the current interpolated position rather than + # snapping after the old transition drains. + if self.embedding_blender.is_transitioning() and ( + new_explicit_transition + or (transition is None and transition_steps > 0 and prompts_changed) + ): + self.embedding_blender.cancel_transition() + self._finish_pooled_transition() + + if new_explicit_transition and not self.embedding_blender.is_transitioning(): + transition_config = parse_transition_config(transition) + target_prompts_raw = transition.get("target_prompts", []) + if transition_config.num_steps > 0 and target_prompts_raw: + target_prompts = normalize_prompts(target_prompts_raw) + started_transition = self._begin_transition( + target_prompts=target_prompts, + interpolation_method=interpolation_method, + num_steps=transition_config.num_steps, + temporal_method=transition_config.temporal_interpolation_method, + width=width, + height=height, + ) + self._last_transition_id = transition_id + elif ( + transition is None + and transition_steps > 0 + and prompts_changed + and self._previous_prompt_embeddings is not None + and not self.embedding_blender.is_transitioning() + ): + started_transition = self._begin_transition( + target_prompts=prompts, + interpolation_method=interpolation_method, + num_steps=transition_steps, + temporal_method=interpolation_method, + width=width, + height=height, + ) + + # --- Produce prompt_embeds for this frame --- + if self.embedding_blender.is_transitioning(): + next_embedding = self.embedding_blender.get_next_embedding() + if next_embedding is not None: + self.prompt_embeds = next_embedding.repeat(batch_size, 1, 1) + self._advance_pooled_transition() + else: + self.prompt_embeds = self._cached_base_embed.repeat(batch_size, 1, 1) + self._finish_pooled_transition() + else: + # Steady state — re-encode if prompts changed and we didn't start + # a transition for it (hard cut path, e.g. transition_steps == 0). + if prompts_changed and not started_transition: + raw_embeds, _ = self._encode_prompts_array( + key_prompts, + interpolation_method, + width=width, + height=height, + batch_size=batch_size, + ) + self._cached_base_embed = raw_embeds[0:1] + self._prompts_key = new_prompts_key + # Drop the transition-id guard once the explicit dict is gone so + # a later identical dict is treated as a fresh request. + if transition is None: + self._last_transition_id = None + self._finish_pooled_transition() + self.prompt_embeds = self._cached_base_embed.repeat(batch_size, 1, 1) + + # Cache embedding as source for the next transition. + self._previous_prompt_embeddings = self.prompt_embeds[0:1].detach() + + # ------------------------------------------------------------------ + # Negative-prompt subtraction (single-pass models) + # ------------------------------------------------------------------ + + def apply_negative_subtraction( + self, negative_prompt: str, negative_prompt_scale: float + ) -> None: + """Norm-preserving negative subtraction in embedding space. + + Single-pass models (Turbo, DMD2) can't use standard CFG without + doubling UNet cost. Embedding subtraction is the cheap alternative, + but raw ``pos - scale * neg`` blows up the L2 norm of each token, + knocking the conditioning out of the training distribution and + the UNet predicts pure noise. + + We do the subtraction directionally and then renormalize each + token's embedding back to the original L2 norm. Same treatment + applied to SDXL's pooled ``add_text_embeds``. ``add_time_ids`` + are positional / size-derived, not text-derived, so they stay put. + + Encoded negative is cached on text; empty text or scale 0 is a + no-op. Cache invalidates on model swap (text-encoder dim changes). + """ + if negative_prompt_scale <= 0 or not negative_prompt: + return + if self.prompt_embeds is None: + return + if ( + self._cached_negative_text != negative_prompt + or self._cached_negative_embed is None + ): + neg_embed, neg_pooled = self._encode_single_prompt(negative_prompt) + self._cached_negative_text = negative_prompt + self._cached_negative_embed = neg_embed.detach() + self._cached_negative_pooled = ( + neg_pooled.detach() if neg_pooled is not None else None + ) + + self.prompt_embeds = _norm_preserving_subtract( + self.prompt_embeds, self._cached_negative_embed, negative_prompt_scale + ) + if self.sdxl and self._cached_negative_pooled is not None and self.add_text_embeds is not None: + self.add_text_embeds = _norm_preserving_subtract( + self.add_text_embeds, + self._cached_negative_pooled, + negative_prompt_scale, + ) + + # ------------------------------------------------------------------ + # Internals + # ------------------------------------------------------------------ + + def _make_prompts_key( + self, + prompts: list[dict], + interpolation_method: str, + width: int, + height: int, + ) -> tuple: + return ( + tuple((p.get("text", ""), p.get("weight", 1.0)) for p in prompts), + interpolation_method, + (width, height) if self.sdxl else (), + ) + + @staticmethod + def _hash_transition(transition: dict) -> str: + payload = { + "num_steps": int(transition.get("num_steps", 0) or 0), + "method": transition.get("temporal_interpolation_method", "linear"), + "target": [ + { + "text": p.get("text", "") if isinstance(p, dict) else str(p), + "weight": float(p.get("weight", 1.0)) if isinstance(p, dict) else 1.0, + } + for p in (transition.get("target_prompts") or []) + ], + } + encoded = json.dumps(payload, sort_keys=True).encode("utf-8") + return hashlib.sha1(encoded).hexdigest() + + def _begin_transition( + self, + target_prompts: list[dict], + interpolation_method: str, + num_steps: int, + temporal_method: str, + width: int, + height: int, + ) -> bool: + source_embedding = self._previous_prompt_embeddings + if source_embedding is None: + return False + + target_embed, target_pooled = self._encode_prompts_array( + target_prompts, + interpolation_method, + apply_sdxl_conditioning=False, + width=width, + height=height, + batch_size=1, + ) + target_embed_single = target_embed[0:1] + + # Eagerly move steady-state cache to the target so once the queue + # drains we land on the target prompts with no bounce-back. + self._cached_base_embed = target_embed_single + self._prompts_key = self._make_prompts_key( + target_prompts, interpolation_method, width, height + ) + + # Slerp NaNs at fp16 in the upstream blender (acos at the [-1, 1] + # boundary) — fall back to linear with a one-shot warn. + if temporal_method == "slerp": + if not self._slerp_fallback_warned: + print( + "[StreamDiffusion] slerp temporal interpolation is not " + "supported (fp16 NaN in upstream blender); falling back " + "to linear." + ) + self._slerp_fallback_warned = True + temporal_method = "linear" + + self.embedding_blender.start_transition( + source_embedding=source_embedding, + target_embedding=target_embed_single, + num_steps=num_steps, + temporal_interpolation_method=temporal_method, + ) + + if self.sdxl and target_pooled is not None: + self._pooled_source = ( + self.add_text_embeds.detach().clone() + if self.add_text_embeds is not None + else target_pooled.clone() + ) + self._pooled_target = target_pooled.clone() + self._transition_total_steps = max(1, num_steps) + else: + self._pooled_source = None + self._pooled_target = None + self._transition_total_steps = 0 + + # start_transition short-circuits when source ≈ target + # (MIN_EMBEDDING_DIFF_THRESHOLD); report accurately so the caller + # falls to steady state instead of assuming a transition is live. + if not self.embedding_blender.is_transitioning(): + self._finish_pooled_transition() + return False + return True + + def _advance_pooled_transition(self) -> None: + """Linearly interpolate ``add_text_embeds`` toward the target pooled.""" + if not self.sdxl or self._pooled_target is None: + return + if self._transition_total_steps <= 0: + return + remaining = len(self.embedding_blender._transition_queue) + done_steps = self._transition_total_steps - remaining + t = min(1.0, max(0.0, done_steps / self._transition_total_steps)) + source = ( + self._pooled_source + if self._pooled_source is not None + else self._pooled_target + ) + self.add_text_embeds = torch.lerp(source, self._pooled_target, t).to( + dtype=self.dtype, device=self.device + ) + + def _finish_pooled_transition(self) -> None: + """Snap pooled to the target and clear transition state.""" + if self.sdxl and self._pooled_target is not None: + self.add_text_embeds = self._pooled_target.to( + dtype=self.dtype, device=self.device + ) + self._pooled_source = None + self._pooled_target = None + self._transition_total_steps = 0 + + def _encode_single_prompt( + self, prompt_text: str + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + encoder_output = self.pipe.encode_prompt( + prompt=prompt_text, + device=self.device, + num_images_per_prompt=1, + do_classifier_free_guidance=False, + negative_prompt=None, + ) + prompt_embeds = encoder_output[0] + pooled_embeds = encoder_output[2] if self.sdxl else None + return prompt_embeds, pooled_embeds + + def _encode_prompts_array( + self, + prompt_items: list[dict], + interpolation_method: str, + *, + width: int, + height: int, + batch_size: int, + apply_sdxl_conditioning: bool = True, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + if not prompt_items: + prompt_items = [{"text": "", "weight": 1.0}] + + texts = [item.get("text", "") for item in prompt_items] + weights = [item.get("weight", 1.0) for item in prompt_items] + + all_prompt_embeds = [] + all_pooled_embeds = [] if self.sdxl else None + + for text in texts: + prompt_embeds, pooled_embeds = self._encode_single_prompt(text) + all_prompt_embeds.append(prompt_embeds) + if self.sdxl and pooled_embeds is not None: + all_pooled_embeds.append(pooled_embeds) + + blended_prompt_embeds = self.embedding_blender.blend( + all_prompt_embeds, + weights, + interpolation_method, + cache_result=True, + ) + + blended_pooled_embeds = None + if self.sdxl and all_pooled_embeds: + blended_pooled_embeds = self.embedding_blender.blend( + all_pooled_embeds, + weights, + interpolation_method, + cache_result=False, + ) + + # SDXL aug-conditioning: write add_text_embeds and add_time_ids for + # the steady-state encode. Skipped for transition-target encodes so + # the in-flight pooled / time_ids aren't overwritten mid-morph. + if apply_sdxl_conditioning and self.sdxl and blended_pooled_embeds is not None: + self.add_text_embeds = blended_pooled_embeds + self.add_time_ids = self._compute_add_time_ids( + width=width, height=height, dtype=self.dtype + ) + + return blended_prompt_embeds.repeat(batch_size, 1, 1), blended_pooled_embeds + + def _compute_add_time_ids( + self, width: int, height: int, dtype: torch.dtype + ) -> torch.Tensor: + """Build SDXL aug-conditioning time_ids from the current dims. + + Reads ``self.pipe.unet.config.addition_time_embed_dim`` and + ``self.pipe.unet.add_embedding.linear_1.in_features`` to validate + the vector length matches what the UNet expects. Raises if not. + """ + original_size = (height, width) + crops_coords_top_left = (0, 0) + target_size = (height, width) + text_encoder_projection_dim = int(self.add_text_embeds.shape[-1]) + + add_time_ids_list = list(original_size + crops_coords_top_left + target_size) + unet = self.pipe.unet + passed_add_embed_dim = ( + unet.config.addition_time_embed_dim * len(add_time_ids_list) + + text_encoder_projection_dim + ) + expected_add_embed_dim = unet.add_embedding.linear_1.in_features + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length " + f"{expected_add_embed_dim}, but a vector of " + f"{passed_add_embed_dim} was created." + ) + return torch.tensor([add_time_ids_list], dtype=dtype) + + +def _norm_preserving_subtract( + positive: torch.Tensor, negative: torch.Tensor, scale: float +) -> torch.Tensor: + """Subtract ``scale * negative`` then rescale to match positive's + original per-row L2 norm. Direction shifts, magnitude is preserved, + UNet stays inside training distribution. + """ + neg = negative.to(device=positive.device, dtype=positive.dtype) + if neg.shape[0] != positive.shape[0]: + neg = neg[:1].expand_as(positive) + orig_norm = positive.norm(dim=-1, keepdim=True) + shifted = positive - scale * neg + new_norm = shifted.norm(dim=-1, keepdim=True).clamp(min=1e-6) + return shifted * (orig_norm / new_norm) From 468e8378b7db41bbcb37e1cc2e26e08531f2b74e Mon Sep 17 00:00:00 2001 From: Chris Justiz Roush Date: Sat, 9 May 2026 08:40:55 -0700 Subject: [PATCH 26/26] docs(plans): hand-off plans for refactor, SDXL ControlNet, and LoRA Three self-contained plans another agent can execute without the originating conversation: pipeline.py decomposition (TRTLifecycle / ModelLoader / InferenceCore), SDXL ControlNet wiring (eager + TRT), and LoRA support including the TRT refit path for live weight swaps. Co-Authored-By: Claude Opus 4.7 --- docs/plans/LORA_PLAN.md | 155 +++++++++++++++++++++++++++++ docs/plans/README.md | 19 ++++ docs/plans/REFACTOR_PLAN.md | 116 +++++++++++++++++++++ docs/plans/SDXL_CONTROLNET_PLAN.md | 63 ++++++++++++ 4 files changed, 353 insertions(+) create mode 100644 docs/plans/LORA_PLAN.md create mode 100644 docs/plans/README.md create mode 100644 docs/plans/REFACTOR_PLAN.md create mode 100644 docs/plans/SDXL_CONTROLNET_PLAN.md diff --git a/docs/plans/LORA_PLAN.md b/docs/plans/LORA_PLAN.md new file mode 100644 index 0000000..34d9e34 --- /dev/null +++ b/docs/plans/LORA_PLAN.md @@ -0,0 +1,155 @@ +# Plan: LoRA Support + +## Context +`schema.py` has `supports_lora = True` already. `pipeline.py` has stub `load_lora` and `fuse_lora` methods that aren't called. Scope has a `download_lora` endpoint already — verify in the parent repo `daydreamlive-scope`. + +## Schema Changes (`src/scope_streamdiffusion/schema.py`) + +Add a `LoraSpec` model and a `loras` list field on `StreamDiffusionConfig`: +```python +class LoraSpec(BaseModel): + repo_id: str # HF repo or local path + weight_name: Optional[str] = None # for repos with multiple files + adapter_name: str # diffusers adapter name; required for stack/swap + scale: float = 1.0 # 0..2 typical + +class StreamDiffusionConfig(BaseModel): + ... + loras: list[LoraSpec] = Field( + default_factory=list, + json_schema_extra=ui_field_config(order=..., label="LoRAs"), + ) +``` + +Order field: place after model selection but before ControlNet config. Reuse Scope's existing LoRA picker UI if one exists in the parent repo's other pipelines. + +## Loader Wiring (`ModelLoader` post-refactor, or `pipeline.py` if pre-refactor) + +LoRAs attach via `pipe.load_lora_weights(repo_id, weight_name=..., adapter_name=...)`. After loading all requested adapters, call `pipe.set_adapters([names...], adapter_weights=[scales...])`. + +**Lifecycle order:** +1. `ModelLoader._load_model` loads the diffusers pipe. +2. SDXL fp16 VAE swap. +3. **LoRA attach.** Iterate `config.loras`, call `pipe.load_lora_weights` per spec. +4. `pipe.set_adapters(...)` with names + scales. +5. **Do NOT call `fuse_lora`.** Keep adapters live so scales/swaps work without reload. Only fuse before TRT compilation (next step). +6. PromptEncoder.attach, ControlNetHandler.attach. +7. TRTLifecycle.attach. **If TRT is enabled, fuse_lora here** before compiling — TRT bakes weights at compile time, so fused-then-compiled is the only correct path (unless using the refit path — see TRT Refit below). + +## Change Detection +Track a "LoRA signature" (sorted tuple of `(repo_id, weight_name, adapter_name, scale)`) on the model loader. On `_swap_model` / `_ensure_pipe_loaded`: +- Same model + same LoRA signature → no-op. +- Same model + different LoRA signature, **eager mode** → call `pipe.unload_lora_weights()`, then re-attach. Cheap, no reload needed. +- Same model + different LoRA signature, **TRT mode without refit** → full reload required. Treat this as a model swap. Surface the cost in the UI — recompiling SDXL UNet is 10+ minutes. +- Same model + different LoRA signature, **TRT mode with refit-capable engine** → refit (see below). 1–10s instead of 10+ min. +- Scale-only change with same adapters loaded, **eager mode** → `pipe.set_adapters(...)` with new weights. No reload. +- Scale-only change, **TRT non-refit** → full reload. **TRT refit** → refit. + +## Cache Coordination with TRT +The TRT cache key (in `_trt_cache.py` / `trt_engines.py`) must include the LoRA signature. Otherwise two different LoRA stacks will collide on the same cache slot and you'll silently load the wrong engine. Hash the sorted signature into the engine filename. + +When using refit, the cache key for the *engine* uses only the base model + refit-capable flag (LoRA signature does NOT affect the engine identity). The fused weights are applied at refit time. The LoRA signature is tracked separately as the "currently refit-applied state" and used only for change detection. + +## Scope Integration +The user mentioned Scope has a `download_lora` endpoint already. Find it in the parent repo (`daydreamlive-scope`) and confirm: +- Whether it returns a local path or a repo_id. +- Whether the UI already has a LoRA picker in other pipelines that we can match. +- Whether LoRA management is per-pipeline or global. + +Match the existing pattern. Don't invent a new one. + +## Testing +1. Eager SD-Turbo + a single style LoRA from CivitAI (download via Scope, attach via config). +2. Live scale change 0.0 → 1.0 → 1.5. Should update without reload. +3. Live LoRA swap (different adapter). Should be fast (unload + load), no model reload. +4. Toggle TRT on with LoRAs attached. Confirm fuse-then-compile path runs and engine is cached with LoRA-aware key. +5. Live LoRA change with TRT on (non-refit) — confirm full reload + recompile triggers and completes. +6. Stack 2 LoRAs simultaneously. Verify `set_adapters` with multiple names works and scales are independent. +7. SDXL + LoRA (eager and TRT). + +## Out of Scope (defer) +- Multi-LoRA blending UI beyond stack-with-scales. +- LoRA training or merging. + +--- + +# Addendum: TRT Refit Path for LoRAs + +The base plan above says "LoRA change with TRT → full reload" — correct but expensive (10+ min for SDXL). TensorRT's **refit** feature lets you update weights in a built engine without rebuilding it. This is the right answer for live LoRA swaps on TRT. + +## What Refit Buys You +- Engine structure (layers, shapes, fusions) stays compiled. +- Only the weight tensors get re-uploaded. +- Typical refit time: **1–10 seconds** for SDXL UNet vs. 10+ minutes for full rebuild. +- Works for scale changes AND adapter swaps, as long as the LoRA targets the same layers. + +## Build-Time Requirements +The engine must be compiled with refit enabled. Two flags in the TRT builder: +- `BuilderFlag.REFIT` — required. +- `BuilderFlag.STRIP_PLAN` (TRT 10+) — optional but recommended; strips weights from the engine file so you ship a smaller cache and refit at load. Trade-off: load is no longer instant — must refit before first inference. + +**Decision:** use `REFIT` only (not `STRIP_PLAN`). Cached engines stay self-sufficient; refit only runs when LoRAs change. The size penalty for `REFIT`-only is small (~5%) and inference perf is unchanged. + +## Implementation Sketch + +### Builder changes (`src/scope_streamdiffusion/_trt/builder.py` or wherever the network config lives) +Add `network_flags` / `builder_config.flags |= 1 << int(trt.BuilderFlag.REFIT)` to all UNet builders (`build_unet_engine`, `build_unet_sdxl_engine`, `build_unet_with_control_engine`, and the new SDXL+control variant). VAE/TAESD/ControlNet engines don't need it — LoRAs target UNet only (cross-attention layers). + +### Refit at runtime (new method on `TRTLifecycle`) +```python +def refit_lora(self, lora_signature): + # 1. Load base UNet weights into a temporary diffusers UNet (CPU OK). + # 2. Apply LoRA stack to that UNet (load_lora_weights + set_adapters + fuse_lora). + # 3. Use trt.Refitter to push the fused weights into the live engine. + # 4. Discard the temp UNet. +``` + +The refitter API: +```python +refitter = trt.Refitter(self._trt_unet_engine.engine, TRT_LOGGER) +for name in refitter.get_all_weights(): # or get_missing() + weights = fused_unet_state_dict[map_trt_name_to_torch(name)] + refitter.set_named_weights(name, weights) +assert refitter.refit_cuda_engine() +``` + +### Name mapping (the hard part) +TRT weight names come from the ONNX export and don't match diffusers' `state_dict` keys 1:1. You need a map. Two approaches: +1. **Build the map at compile time.** During ONNX export, record the `(torch_param_name → onnx_initializer_name)` mapping and persist it next to the engine in the cache. At refit time, load the map and translate. +2. **Reconstruct the map at refit time** by re-running ONNX export on a dummy UNet with the same architecture and reading the resulting initializer names. Slower but simpler. + +Recommend approach 1. Save the map as `.refit_map.json` alongside the engine file. The TRT cache key already covers architecture variants, so the map is valid for the engine. + +### Cache key change +Refit-capable engines and refit-incapable engines are different artifacts. Add `refit=True` to the cache key path component so old (non-refit) cached engines aren't reused. Old engines stay valid for non-LoRA streams; new ones get used when LoRAs are configured. + +## Updated LoRA Lifecycle (replaces "full reload" branch in the base plan) + +| Change | Eager | TRT (refit-capable engine) | TRT (legacy non-refit engine) | +|---|---|---|---| +| Scale only | `set_adapters` | refit | rebuild | +| Adapter swap, same layers | unload + load + `set_adapters` | refit | rebuild | +| Adapter swap, different layers | same | refit (zero out unused) | rebuild | +| Add ControlNet, etc. | rebuild pipeline state | rebuild engine | rebuild | + +"Different layers" case: if a new LoRA targets layers the previous one didn't, those original-weight slots need to be restored to the base model's weights during refit. The fused-state-dict approach handles this naturally since the temp UNet is built from base weights + new LoRA stack. + +## When to Skip Refit +- First time TRT is enabled with LoRAs configured → fuse first, then build (current plan). Refit only helps on subsequent changes. +- Engine compiled before this feature lands → fall back to rebuild. Detect via the cache-key version bump. +- Refitter reports missing weights → log and rebuild. Don't run a partially-refit engine. + +## Testing (Refit-specific) +1. Cold start with one LoRA + TRT. Confirm engine builds with `REFIT` flag (check `engine.refittable`). +2. Live scale change 0.0 → 1.5. Should complete in <10s, no recompile log. +3. Live adapter swap (different LoRA, same target layers). Same speed. +4. Live adapter swap to a LoRA that targets *additional* layers. Confirm refit covers all weights and output is correct. +5. Stress test: 20 rapid scale/adapter changes. Memory should stay stable (the temp UNet must actually free). +6. SDXL refit specifically — name-map size is larger; verify no missing weights. + +## Risk +- TRT refit name mapping is fiddly. Budget time for debugging the ONNX-name ↔ torch-name mapping. +- Some TRT optimizations bake constants. If a LoRA's effective rank changes the optimal kernel choice, refit produces correct but suboptimal output. Acceptable trade-off. +- `STRIP_PLAN` is tempting but adds first-inference latency. Skip it. + +This makes live LoRA swaps on TRT actually viable instead of "technically supported but never used." diff --git a/docs/plans/README.md b/docs/plans/README.md new file mode 100644 index 0000000..c718541 --- /dev/null +++ b/docs/plans/README.md @@ -0,0 +1,19 @@ +# Plans + +Hand-off plans for the next round of work on `sd-multi-model`. Each plan is self-contained and intended to be executed by another agent without needing the originating conversation. + +- [REFACTOR_PLAN.md](REFACTOR_PLAN.md) — decompose `pipeline.py` into helper classes (`TRTLifecycle`, `ModelLoader`, `InferenceCore`) following the `PromptEncoder` / `ControlNetHandler` pattern. +- [SDXL_CONTROLNET_PLAN.md](SDXL_CONTROLNET_PLAN.md) — wire SDXL ControlNet through the eager and TRT paths (currently raises `NotImplementedError` on TRT for SDXL). +- [LORA_PLAN.md](LORA_PLAN.md) — schema, loader wiring, change detection, and the TRT refit path for live LoRA swaps. + +## Recommended order +1. Refactor (lands first — the LoRA plan assumes the `ModelLoader` and `TRTLifecycle` helpers exist). +2. SDXL ControlNet (independent of LoRA). +3. LoRA (depends on refactor; benefits from but does not require ControlNet work). + +## Architectural pattern (read first) +All three plans assume the helper-class composition pattern. The canonical examples in the repo: +- `src/scope_streamdiffusion/prompt_encoder.py` +- `src/scope_streamdiffusion/controlnet.py` + +Helpers take `(device, dtype)` at construction, gain a pipe back-reference via `attach(pipe, sdxl)`, and expose runtime state as instance attributes. diff --git a/docs/plans/REFACTOR_PLAN.md b/docs/plans/REFACTOR_PLAN.md new file mode 100644 index 0000000..34b3382 --- /dev/null +++ b/docs/plans/REFACTOR_PLAN.md @@ -0,0 +1,116 @@ +# Refactor Plan: pipeline.py Decomposition + +## Goal +Reduce `pipeline.py` from ~1900 lines to a thin orchestrator (~400 lines) by extracting cohesive responsibilities into helper classes. Follow the pattern established by `PromptEncoder` (commit `b1b5478`) and the existing `ControlNetHandler`. + +## Architectural Pattern (non-negotiable — already established) +- Helper class lives in its own module under `src/scope_streamdiffusion/`. +- Constructor takes `(device, dtype)` and any static config. +- `attach(pipe, sdxl: bool)` lifecycle method called from `_ensure_pipe_loaded` and `_swap_model` after the diffusers pipeline is loaded. Helpers re-bind to the new pipe here. +- Helper owns its caches and exposes runtime state as instance attributes the pipeline reads through (e.g., `self.prompts.prompt_embeds`). +- Helper has explicit `reset_caches()` / `release()` methods called on model swap or teardown. +- **No mixins.** Composition only. The user explicitly rejected mixins. + +## Reference Files +- `src/scope_streamdiffusion/prompt_encoder.py` — the template. Read this first. +- `src/scope_streamdiffusion/controlnet.py` — second example of the pattern. +- `src/scope_streamdiffusion/pipeline.py` — the source to extract from. + +## Extraction Order (do them in this order, commit between each) + +### Extraction 1: `TRTLifecycle` → `src/scope_streamdiffusion/trt_lifecycle.py` +**Methods to move:** +- `_ensure_trt_taesd` +- `_ensure_trt_controlnet` +- `_ensure_trt_unet` +- `_setup_trt` +- `_reset_trt_state` +- `_set_acceleration_mode` +- `_deactivate_trt` +- `_trt_setup_args_from_config` + +**Compromise to accept:** these methods currently mutate `self.unet`, `self.controlnet`, `self.vae`, `self._taesd_vae` directly. Don't fight it — give the helper a back-reference to the pipeline (`self.pipe = pipe` set in `attach()`) and have it write through. The win is moving 500 lines of TRT-specific lifecycle code out of the orchestrator, not pretending TRT doesn't touch pipeline state. + +**Caches the helper owns:** `_trt_taesd_paths`, `_trt_controlnet_paths`, `_trt_unet_paths`, `_trt_unet_engine`, `_trt_controlnet_engine`, the `acceleration_mode` last-applied value, and the `_trt_cache` adapter handles. The module-scope `_trt_cache._CACHE` stays where it is — it must survive plugin reinit. + +**Pipeline-side after extraction:** +```python +self.trt = TRTLifecycle(device=self.device, dtype=self.dtype) +# in _ensure_pipe_loaded / _swap_model: +self.trt.attach(self, self.sdxl) +# in __call__'s pre-inference setup: +self.trt.ensure_engines(config, want_control=...) +``` + +**Testing checkpoint after this extraction:** +1. Cold-load each model with `acceleration_mode="trt"`: SD-Turbo, SDXL-Turbo, DMD2. +2. Live-swap from SD-Turbo → SDXL-Turbo → DMD2 → SD-Turbo. Confirm no `context=None` crashes (the band-aid `_ensure_activated` in `_trt/engine.py` should still cover this; if it triggers, that's a regression in the swap teardown path). +3. Toggle ControlNet on SD1.5 + SD-Turbo while running. +4. Switch `acceleration_mode` between `none` / `xformers` / `trt` mid-stream. + +--- + +### Extraction 2: `ModelLoader` → `src/scope_streamdiffusion/model_loader.py` +**Methods to move:** +- `_load_model` +- `_load_preset` +- `_release_pipe_state` +- `_swap_model` +- `_install_sdxl_fp16_vae` +- `_set_taesd` +- `load_lora` (currently a stub — leave as-is, the LoRA plan wires it up) +- `fuse_lora` (stub — same) + +**State the helper owns:** the `MODEL_PRESETS` dict (move it to this module), last-loaded `model_id`, last-loaded preset signature, the SDXL fp16 VAE replacement state, TAESD-installed flag. + +**Compromise:** like TRT, this writes through to `self.pipe`, `self.unet`, `self.vae`, `self.text_encoder`, `self.text_encoder_2`, `self.tokenizer`, `self.tokenizer_2`, `self.scheduler`, `self.sdxl`. Use the back-reference; the goal is consolidation, not purity. + +**Order matters in `attach`/swap flow:** ModelLoader runs first, then PromptEncoder.attach, then ControlNetHandler.attach, then TRTLifecycle.attach. Document this in a comment at the top of `pipeline._ensure_pipe_loaded`. + +**Testing checkpoint:** +1. Cold load each preset. +2. Swap each direction. Verify no double-loaded models in VRAM (`nvidia-smi` while swapping). +3. Verify SDXL fp16 VAE replacement still happens on SDXL-Turbo and DMD2. +4. Verify TAESD eager and TRT both still work. + +--- + +### Extraction 3: `InferenceCore` → `src/scope_streamdiffusion/inference_core.py` +**Methods to move:** +- `_set_timesteps` +- `_initialize_noise` +- `_setup_seed_transition` +- `_slerp_noise` +- `_advance_seed_transition` +- `_cancel_seed_transition` +- `_encode_image` +- `_decode_image` +- `_add_noise` +- `_scheduler_step_batch` +- `_unet_step` +- `_predict_x0_batch` + +**State the helper owns:** `alpha_prod_t_sqrt`, `beta_prod_t_sqrt`, `c_skip`, `c_out`, `sub_timesteps_tensor`, `init_noise`, `x_t_latent_buffer`, the seed-transition fields (`_pending_seed`, `_transition_remaining`, etc.). + +**Reads (not writes) from pipeline:** `self.pipe.prompts.prompt_embeds`, `self.pipe.unet`, `self.pipe.controlnet`, `self.pipe.controlnet_input`, `self.pipe.vae`, `self.pipe.scheduler`. Pass these through the back-reference. + +**`__call__` after this extraction shrinks to roughly:** +```python +def __call__(self, **kwargs): + config = self._validate_config(kwargs) + self._prepare_runtime_state(config) + self.prompts.encode_for_frame(...) + if self.controlnet_handler: + self.controlnet_handler.update(...) + latent = self.inference.run_step(video, config) + return {"video": self.inference.to_scope_format(latent)} +``` + +**Testing checkpoint:** full smoke test — every model × (txt2img / img2img / loopback) × (eager / xformers / TRT) × (with/without negative prompt) × seed transitions. + +## Cross-cutting Rules +- **Don't change behavior.** This is a pure move. If you find a bug, note it in a comment — fix it in a separate commit after the refactor lands. +- **Commit per extraction.** Three commits. Each must pass the testing checkpoint before moving to the next. +- **Don't extract `__init__`, `prepare`, `_prepare_runtime_state`, `__call__`, `get_config_class`, or the schema-driven setters.** These are the orchestrator's job. +- **Don't add abstract base classes or interfaces** for the helpers. Three concrete classes is fine. +- **Don't introduce a `BaseHelper` parent class.** They share a pattern, not behavior. diff --git a/docs/plans/SDXL_CONTROLNET_PLAN.md b/docs/plans/SDXL_CONTROLNET_PLAN.md new file mode 100644 index 0000000..7ace9ba --- /dev/null +++ b/docs/plans/SDXL_CONTROLNET_PLAN.md @@ -0,0 +1,63 @@ +# Plan: SDXL ControlNet Support + +## Context +Current state: `_ensure_trt_unet` has `if want_control: if self.sdxl: raise NotImplementedError(...)`. SD1.5 ControlNet (eager + TRT) works. SDXL ControlNet works in eager mode through diffusers but the TRT path is unimplemented. + +Test target model: `diffusers/controlnet-canny-sdxl-1.0` (paired with `stabilityai/stable-diffusion-xl-base-1.0` or SDXL-Turbo). The DMD2 1-step UNet is a swap — SDXL ControlNet against DMD2 is a stretch goal; verify with the base SDXL UNet first. + +## Eager Path (verify first, may already work) +1. Confirm `ControlNetHandler.update()` correctly produces residuals for SDXL-shape inputs. SDXL UNet expects `added_cond_kwargs={"text_embeds": ..., "time_ids": ...}` — the ControlNet model also needs these. Read `diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl` for the canonical wiring. +2. In `_unet_step` (post-extraction: `InferenceCore.unet_step`), when `self.sdxl and self.controlnet`, call ControlNet with the SDXL aug-conditioning, then pass the residuals into `self.unet` along with `added_cond_kwargs`. +3. If this produces correct output, eager SDXL ControlNet is done. Move to TRT. + +## TRT Path + +### Step 1: New ONNX export wrapper +File: `src/scope_streamdiffusion/_trt/models.py`. Add `UNetSDXLWithControlInputs` modeled on the existing `UNetWithControlInputs` (SD1.5) and `UNetSDXL` (SDXL no-control). + +Inputs (in order — must match adapter feed order): +- `sample` (B, 4, H/8, W/8) +- `timestep` (scalar or (B,)) +- `encoder_hidden_states` (B, 77, 2048) +- `text_embeds` (B, 1280) ← SDXL aug +- `time_ids` (B, 6) ← SDXL aug +- `input_control_00` … `input_control_{N-1}` (down residuals) +- `input_control_middle` (mid residual) + +Output: `latent` (same shape as `sample`). + +Forward should call `self.unet(sample, timestep, encoder_hidden_states, added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids}, down_block_additional_residuals=[...], mid_block_additional_residual=...)`. + +### Step 2: New builder +File: `src/scope_streamdiffusion/trt_engines.py`. Add `build_unet_sdxl_with_control_engine(...)` modeled on `build_unet_with_control_engine` + `build_unet_sdxl_engine`. Use the same dynamic-shape ranges as the SDXL UNet build (512–1024). + +**Known constraints:** +- ONNX export of SDXL UNet runs ~5 GB. Use `external_data` format. The existing `build_unet_sdxl_engine` already does this — copy its handling. +- ControlNet residuals are full-resolution feature maps; this multiplies the export size by ~10–20%. Expect 6 GB ONNX. +- Static shape recommended for first cut. Generalize to dynamic only after a static build runs. +- Compile time: 5–15 minutes on a 4090. Cache aggressively. + +### Step 3: Standalone SDXL ControlNet engine +The SD1.5 path uses a separate `ControlNetEngine` (`src/scope_streamdiffusion/_trt/engine.py`) that produces residuals consumed by `UNet2DConditionModelWithControlEngine`. Mirror this for SDXL: +- Add ONNX wrapper for SDXL ControlNet to `_trt/models.py` (it has the same SDXL aug-conditioning inputs as the UNet wrapper). +- Add builder `build_controlnet_sdxl_engine` in `trt_engines.py`. +- The existing `ControlNetEngine` class in `_trt/engine.py` has hard-coded `block_out_channels=(320, 640, 1280, 1280)` for SD1.5. SDXL ControlNet uses `(320, 640, 1280)` (one fewer block) and produces 9 down residuals + 1 mid (versus 12+1 for SD1.5). Add `ControlNetSDXLEngine` or parameterize `ControlNetEngine` by passing `chans` and `spec` at construction. + +### Step 4: New runtime adapter +File: `src/scope_streamdiffusion/trt_engines.py`. Add `TRTUNetSDXLWithControlAdapter` that exposes the diffusers UNet `__call__` signature and dispatches to `UNet2DConditionModelSDXLWithControlEngine` (also new in `_trt/engine.py`) plus the SDXL ControlNet engine. + +### Step 5: Wire into TRTLifecycle +In `_ensure_trt_unet`, replace the `raise NotImplementedError` with the SDXL+control branch. Path resolution and cache key must include both UNet and ControlNet model IDs. + +### Step 6: TAESD +SDXL TAESD (`madebyollin/taesdxl`) already works via the existing `_ensure_trt_taesd` path. No change needed. + +## Testing +1. Eager SDXL + Canny ControlNet on a webcam frame. Output should track edges. +2. TRT SDXL + Canny ControlNet, same prompt. Confirm visual parity within fp16 tolerance. +3. Live-toggle ControlNet on/off mid-stream on SDXL-Turbo. +4. Swap SD-Turbo (SD1.5) ↔ SDXL-Turbo with ControlNet attached. Confirm correct adapter is selected each time. + +## Out of Scope (defer) +- Multi-ControlNet on SDXL (do single first). +- DMD2 + ControlNet (the 1-step UNet swap may not respect ControlNet residuals correctly — needs separate investigation).