8 changes: 2 additions & 6 deletions hackable_diffusion/lib/architecture/dit_blocks_test.py
@@ -75,12 +75,8 @@ def test_variable_shapes(self):
             'kernel': (self.c, self.d),
             'bias': (self.d,),
         },
-        'ConditionalNorm': {
-            'Dense_0': {
-                'kernel': (self.c, self.d * 2),
-                'bias': (self.d * 2,),
-            },
-        },
+        # Single ConditionalNorm is shared between attention and MLP
+        # branches, so it appears only once in the param tree.
         'ConditionalNorm': {
             'Dense_0': {
                 'kernel': (self.c, self.d * 2),
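Reviewer note: the comment added above reflects standard Flax linen sharing semantics, where one module instance called from several branches contributes a single parameter subtree. A minimal sketch under that assumption (`Block` and `Shared` are illustrative names, not from this repo):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class Block(nn.Module):

  @nn.compact
  def __call__(self, x):
    shared = nn.Dense(4, name='Shared')  # one instance...
    return shared(x) + shared(jnp.tanh(x))  # ...called from two branches


params = Block().init(jax.random.key(0), jnp.ones((1, 4)))['params']
print(list(params))  # ['Shared'] -- one param subtree despite two call sites
```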
8 changes: 5 additions & 3 deletions hackable_diffusion/lib/corruption/discrete.py
@@ -275,11 +275,13 @@ def corrupt(
     assert alpha_bcast.shape == x0.shape
     # Get the mask of the corruption process. It is true if the token is not
     # corrupted and False if it is corrupted.
-    is_not_corrupted = jax.random.bernoulli(key, p=alpha_bcast, mode=self.mode)
-    key, _ = jax.random.split(key)
+    mask_key, noise_key = jax.random.split(key)
+    is_not_corrupted = jax.random.bernoulli(
+        mask_key, p=alpha_bcast, mode=self.mode
+    )

     # compute noise vector
-    noise = self.sample_from_invariant(key, data_spec=x0)
+    noise = self.sample_from_invariant(noise_key, data_spec=x0)

     # noise x0 with probability alpha
     xt = jnp.where(is_not_corrupted, x0, noise)  # is_not_corrupted = (xt == x0)
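Reviewer note: the fix follows standard JAX PRNG discipline: never feed the same key to two different samplers; split first and give each consumer its own subkey. A self-contained sketch of the before/after pattern (shapes and distributions illustrative only):

```python
import jax

key = jax.random.key(0)

# Anti-pattern (old code): one key drives two samplers, so the mask and the
# noise are deterministically coupled rather than independent.
mask_bad = jax.random.bernoulli(key, p=0.5, shape=(4,))
noise_bad = jax.random.normal(key, shape=(4,))

# Pattern (new code): split once, then use each subkey exactly once.
mask_key, noise_key = jax.random.split(key)
mask = jax.random.bernoulli(mask_key, p=0.5, shape=(4,))
noise = jax.random.normal(noise_key, shape=(4,))
```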
4 changes: 2 additions & 2 deletions hackable_diffusion/lib/diffusion_network.py
@@ -425,11 +425,11 @@ def initialize_variables(
   @kt.typechecked
   def __call__(
       self,
-      xt: DataTree,
       time: TimeTree,
+      xt: DataTree,
       conditioning: Conditioning | None,
       is_training: bool,
-  ):
+  ) -> TargetInfoTree:
     if self.time_rescaler is not None:
       time_rescaled = utils.lenient_map(
           lambda time, time_rescaler: time_rescaler(time)
12 changes: 6 additions & 6 deletions hackable_diffusion/lib/inference/guidance.py
@@ -104,13 +104,13 @@ def __post_init__(self):
   @kt.typechecked
   def __call__(
       self,
-      xt: DataArray,
+      xt: DataTree,
       conditioning: Conditioning,
-      time: TimeArray,
-      cond_outputs: TargetInfo,
-      uncond_outputs: TargetInfo,
-  ) -> TargetInfo:
-    """Simple scalar guidance function."""
+      time: TimeTree,
+      cond_outputs: TargetInfoTree,
+      uncond_outputs: TargetInfoTree,
+  ) -> TargetInfoTree:
+    """Limited interval guidance function."""
     del conditioning  # unused
     time = utils.bcast_right(time, xt.ndim)
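Reviewer note: the hunk does not show the body, so as background only: "limited interval" guidance usually means applying the guidance weight inside a time window and passing the conditional output through elsewhere (cf. Kynkäänniemi et al., 2024). A generic sketch of that idea, not this class's actual implementation; `scale`, `lo`, and `hi` are illustrative:

```python
import jax.numpy as jnp


def limited_interval_guidance(cond, uncond, t, scale=2.0, lo=0.1, hi=0.9):
  # Classifier-free guidance applied only for t in [lo, hi]; outside the
  # interval the conditional prediction is used unchanged.
  guided = uncond + scale * (cond - uncond)
  in_interval = (t >= lo) & (t <= hi)
  return jnp.where(in_interval, guided, cond)
```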
11 changes: 9 additions & 2 deletions hackable_diffusion/lib/inference/wrappers.py
@@ -95,9 +95,16 @@ def __call__(

 @dataclasses.dataclass(kw_only=True, frozen=True)
 class FlaxNNXInferenceFn(InferenceFn):
-  """Inference function protocol with a diffusion network given by nn.Module."""
+  """Inference function protocol with a diffusion network given by nn.Module.
+
+  Note: ``inference_seed`` is used for any stochastic layers (e.g., dropout)
+  that remain active at inference time. Since ``is_training=False`` is always
+  passed, dropout layers are typically disabled and this seed has no effect.
+  If you need stochastic inference (e.g., MC dropout), provide different seeds.
+  """

   nnx_network: ConvertedNNXDiffusionNetwork
+  inference_seed: int = 0

   @kt.typechecked
   def __call__(
@@ -112,7 +119,7 @@ def __call__(
         xt=xt,
         conditioning=conditioning,
         is_training=False,
-        rngs=nnx.Rngs(0),
+        rngs=nnx.Rngs(self.inference_seed),
     )
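Reviewer note: a hypothetical usage sketch for the new `inference_seed` field, assuming a converted network `net` and inputs `xt`, `time`, `cond` already in scope, keyword names following the `InferenceFn` protocol used in `sampling.py`, and dropout layers that stay active at inference (otherwise the seed is inert, as the docstring says):

```python
# Monte Carlo dropout: one inference fn per seed, then average predictions.
fns = [FlaxNNXInferenceFn(nnx_network=net, inference_seed=s) for s in range(8)]
preds = [fn(xt=xt, time=time, conditioning=cond) for fn in fns]
```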
33 changes: 28 additions & 5 deletions hackable_diffusion/lib/manifolds.py
@@ -71,11 +71,34 @@ def safe_norm(
     keepdims: bool = True,
     eps: float = 1e-9,
 ) -> Array:
-  """Computes norm safely to avoid NaN gradients at zero."""
-  is_zero = jnp.all(x == 0, axis=axis, keepdims=keepdims)
-  safe_x = jnp.where(is_zero, eps, x)
-  n = jnp.linalg.norm(safe_x, axis=axis, keepdims=keepdims)
-  return jnp.where(is_zero, 0.0, n)
+  """Computes norm safely to avoid NaN gradients at zero.
+
+  Uses ``sqrt(max(sum(x^2), eps^2))`` instead of the standard ``norm``.
+
+  Why not ``jnp.where(is_zero, eps, norm(x))``? JAX evaluates *both* branches
+  of ``jnp.where`` and combines their gradients as ``cond * grad_a +
+  (1 - cond) * grad_b``. At ``x = 0``, ``grad(norm(x)) = x / ||x|| = 0/0 = NaN``,
+  and ``0 * NaN = NaN`` under IEEE 754, poisoning the result even though the
+  "safe" branch was selected.
+
+  This implementation avoids the problem entirely: there is a single
+  differentiable code path with no branching.
+
+  - Forward: returns ``eps`` at ``x = 0``, not ``0``.
+  - Backward: the gradient ``sum(x * dx) / sqrt(max(..., eps^2))`` evaluates
+    to ``0 / eps = 0`` at ``x = 0``. No NaN, no spurious non-zero gradient.
+
+  Args:
+    x: Input tensor.
+    axis: Axes over which to compute the norm.
+    keepdims: Whether to keep the original number of dimensions.
+    eps: Epsilon value to avoid NaN gradients at zero.
+
+  Returns:
+    The safe norm of x.
+  """
+  sq = jnp.sum(jnp.square(x), axis=axis, keepdims=keepdims)
+  return jnp.sqrt(jnp.maximum(sq, eps * eps))


 def transpose(x: DataArray) -> DataArray:
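Reviewer note: the failure mode described in the new docstring is easy to reproduce. A self-contained repro (`where_norm` is an illustrative stand-in for the pattern the docstring warns about, not code from this library):

```python
import jax
import jax.numpy as jnp


def where_norm(x, eps=1e-9):
  # norm(x) is still differentiated at x = 0, where its gradient is
  # x / ||x|| = 0/0 = NaN, and 0 * NaN = NaN leaks through the "safe" branch.
  return jnp.where(jnp.all(x == 0), eps, jnp.linalg.norm(x))


def safe_norm(x, eps=1e-9):
  # Single differentiable path, as in the patch: sqrt(max(sum(x^2), eps^2)).
  return jnp.sqrt(jnp.maximum(jnp.sum(jnp.square(x)), eps * eps))


x0 = jnp.zeros(3)
print(jax.grad(where_norm)(x0))  # [nan nan nan]
print(jax.grad(safe_norm)(x0))   # [0. 0. 0.]
```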
4 changes: 2 additions & 2 deletions hackable_diffusion/lib/multimodal.py
@@ -342,14 +342,14 @@ def __call__(
       time: TimeTree,
   ) -> LossOutputTree:
     return jax.tree.map(
-        lambda loss, target, pred, t: loss(
+        lambda loss, pred, target, t: loss(
             preds=pred,
             targets=target,
             time=t,
         ),
         self.losses,
-        targets,
         preds,
+        targets,
         time,
     )
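Reviewer note: the bug class addressed here comes from `jax.tree.map`'s positional binding: leaves are passed in the order the trees are given, so the lambda's parameter order must line up with the tree order or predictions and targets silently swap. A small sketch:

```python
import jax

losses = {
    'image': lambda pred, target: (pred - target) ** 2,
    'text': lambda pred, target: abs(pred - target),
}
preds = {'image': 1.0, 'text': 2.0}
targets = {'image': 0.5, 'text': 0.0}

# Leaves bind positionally: (losses, preds, targets) -> (loss, pred, target).
out = jax.tree.map(
    lambda loss, pred, target: loss(pred, target), losses, preds, targets
)
print(out)  # {'image': 0.25, 'text': 2.0}
```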
5 changes: 0 additions & 5 deletions hackable_diffusion/lib/sampling/base.py
@@ -61,13 +61,9 @@
 is called to produce the final clean output sample.
 """

-import dataclasses
 from typing import Protocol
 import flax.struct
 from hackable_diffusion.lib import hd_typing
-import jax
-import kauldron.ktyping as kt


 #################################################################################
 # MARK: Type Aliases
@@ -116,7 +112,6 @@ class DiffusionStep:

   Attributes:
     xt: The noisy data at the current step.
-    conditioning: The conditioning data from the prediction model.
     step_info: The `StepInfo` used to compute the current step.
     aux: Additional data computed by the sampler.
   """
2 changes: 1 addition & 1 deletion hackable_diffusion/lib/sampling/discrete_step_sampler.py
@@ -162,7 +162,7 @@ def __post_init__(self):
     )

   schedule: DiscreteSchedule
-  rescale_factor: float = 1.0
+  rescale_factor: float
   switch_min: float = 0.0
   switch_max: float = 1.0
54 changes: 24 additions & 30 deletions hackable_diffusion/lib/sampling/gaussian_step_sampler.py
@@ -574,24 +574,21 @@ def first_step(
     )["velocity"]
     new_xt = xt - dt * velocity

-    aux = current_step.aux
     # Update the internal counter and current_velocity_step_one.
+    # Create a new aux dict to avoid mutating the frozen dataclass's dict.
     internal_counter = jnp.mod(
-        aux["internal_counter"] + 1, self.num_internal_steps
+        current_step.aux["internal_counter"] + 1, self.num_internal_steps
     )
-    current_velocity_step_one = velocity
-    aux.update(
-        dict(
-            internal_counter=internal_counter,
-            current_velocity_step_one=current_velocity_step_one,
-        )
+    new_aux = dict(
+        internal_counter=internal_counter,
+        current_update=current_step.aux["current_update"],
+        current_velocity_step_one=velocity,
     )

     # Note that we output next_next_step_info and not next_step_info.
     return DiffusionStep(
         xt=new_xt,
         step_info=next_next_step_info,
-        aux=aux,
+        aux=new_aux,
     )

   @kt.typechecked
@@ -601,7 +598,7 @@ def second_step(
       current_step: DiffusionStep,
       next_step_info: StepInfo,
   ) -> DiffusionStep:
-    """First step of the Heun sampler.
+    """Second step of the Heun sampler.

     This is the second internal step. It should be called when the internal
     counter is 1.
@@ -614,16 +611,15 @@ def second_step(
     Returns:
       The next step.
     """
-    aux = current_step.aux
     xt = current_step.xt

-    current_update = aux["current_update"]
-    old_time = current_update.step_info.time
+    prev_update = current_step.aux["current_update"]
+    old_time = prev_update.step_info.time
     next_time = next_step_info.time
     old_time = utils.bcast_right(old_time, xt.ndim)
     next_time = utils.bcast_right(next_time, xt.ndim)

-    old_velocity = aux["current_velocity_step_one"]
+    old_velocity = current_step.aux["current_velocity_step_one"]
     intermediate_velocity = self.corruption_process.convert_predictions(
         prediction=prediction,
         xt=xt,
@@ -634,30 +630,28 @@ def second_step(

     # Perform the final step.

-    old_xt = current_update.xt
+    old_xt = prev_update.xt
     dt = old_time - next_time
     new_xt = old_xt - dt * (old_velocity + intermediate_velocity) / 2

-    # Update the internal counter and the current_update
+    # Create a new aux dict to avoid mutating the frozen dataclass's dict.
     internal_counter = jnp.mod(
-        aux["internal_counter"] + 1, self.num_internal_steps
+        current_step.aux["internal_counter"] + 1, self.num_internal_steps
     )
-    current_update = DiffusionStep(
-        xt=new_xt,
-        step_info=next_step_info,
-        aux=dict(),
-    )
-    aux.update(
-        dict(
-            internal_counter=internal_counter,
-            current_update=current_update,
-        )
+    new_aux = dict(
+        internal_counter=internal_counter,
+        current_update=DiffusionStep(
+            xt=new_xt,
+            step_info=next_step_info,
+            aux=dict(),
+        ),
+        current_velocity_step_one=current_step.aux["current_velocity_step_one"],
     )

     return DiffusionStep(
         xt=new_xt,
         step_info=next_step_info,
-        aux=aux,
+        aux=new_aux,
     )

   @kt.typechecked
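Reviewer note: the `new_aux` pattern matters because `frozen=True` only blocks attribute *assignment*; a dict stored on the dataclass is still mutable, so the old `aux.update(...)` silently aliased state across steps. A minimal repro with a hypothetical `Step` class standing in for `DiffusionStep`:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Step:  # hypothetical stand-in for DiffusionStep
  aux: dict


s1 = Step(aux={'counter': 0})
s2 = Step(aux=s1.aux)  # both steps share the same dict object

s2.aux['counter'] += 1  # no FrozenInstanceError: mutation is not assignment
print(s1.aux['counter'])  # 1 -- s1 was silently modified too

# The patch's fix: build a fresh dict instead of mutating in place.
s3 = Step(aux={**s1.aux, 'counter': s1.aux['counter'] + 1})
print(s1.aux['counter'], s3.aux['counter'])  # 1 2
```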
42 changes: 33 additions & 9 deletions hackable_diffusion/lib/sampling/sampling.py
@@ -62,7 +62,7 @@ def __call__(
       rng: PRNGKey,
       initial_noise: DataTree,
       conditioning: Conditioning,
-  ) -> tuple[DiffusionStepTree, DiffusionStepTree]:
+  ) -> tuple[DiffusionStepTree, DiffusionStepTree | None]:
     ...


@@ -130,11 +130,16 @@ class DiffusionSampler(SampleFn):
     time_schedule: Defines the sequence of time steps for the process.
     stepper: The sampling algorithm (e.g., DDIM) that updates the state.
     num_steps: The total number of denoising steps.
+    store_trajectory: If True (default), returns the full trajectory of all
+      intermediate steps. If False, only returns the final step and None for
+      the trajectory, significantly reducing memory usage for high-resolution
+      or many-step sampling.
   """

   time_schedule: TimeSchedule
   stepper: SamplerStep
   num_steps: int
+  store_trajectory: bool = True

   @kt.typechecked
   def __call__(
@@ -143,7 +148,7 @@ def __call__(
       self,
       rng: PRNGKey,
       initial_noise: DataTree,
       conditioning: Conditioning | None = None,
-  ) -> tuple[DiffusionStepTree, DiffusionStepTree]:
+  ) -> tuple[DiffusionStepTree, DiffusionStepTree | None]:
     """Performs a full reverse diffusion sampling loop for a single sample.

     This function orchestrates the denoising process, starting from an initial
@@ -159,7 +164,7 @@ def __call__(
       A tuple containing:
       - The final `DiffusionStepTree` of the sampling process.
       - A `DiffusionStepTree` PyTree containing the full trajectory of all
-        steps.
+        steps, or None if `store_trajectory` is False.
     """
     if self.num_steps < 2:
       raise ValueError(
@@ -179,7 +184,7 @@ def __call__(
         first_step_info,
     )

-    def scan_body(step_carry: DiffusionStepTree, next_step_info: StepInfoTree):
+    def _step(step_carry: DiffusionStepTree, next_step_info: StepInfoTree):
       xt, time = _get_input_inference_fn(step_carry)
       prediction = inference_fn(
           xt=xt,
@@ -191,11 +196,27 @@ def scan_body(step_carry: DiffusionStepTree, next_step_info: StepInfoTree):
           step_carry,
           next_step_info,
       )
-      return next_step, next_step  # ('carryover', 'accumulated')
+      return next_step

-    before_last_step, intermediate_steps = jax.lax.scan(
-        scan_body, first_step, next_step_infos
-    )
+    intermediate_steps = None
+    if self.store_trajectory:
+
+      def scan_body(step_carry, next_step_info):
+        next_step = _step(step_carry, next_step_info)
+        return next_step, next_step  # ('carryover', 'accumulated')
+
+      before_last_step, intermediate_steps = jax.lax.scan(
+          scan_body, first_step, next_step_infos
+      )
+    else:
+
+      def scan_body_no_accum(step_carry, next_step_info):
+        next_step = _step(step_carry, next_step_info)
+        return next_step, None  # no accumulation
+
+      before_last_step, _ = jax.lax.scan(
+          scan_body_no_accum, first_step, next_step_infos
+      )

     xt, time = _get_input_inference_fn(before_last_step)
     last_prediction = inference_fn(
@@ -210,5 +231,8 @@ def __call__(
         last_step_info,
     )

-    all_steps = _concat_pytree(first_step, intermediate_steps, last_step)
+    if self.store_trajectory:
+      all_steps = _concat_pytree(first_step, intermediate_steps, last_step)
+    else:
+      all_steps = None
     return last_step, all_steps
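Reviewer note: the memory saving comes from `jax.lax.scan`'s accumulator: returning `None` as the scanned output means no per-step buffer of shape `(num_steps, *state.shape)` is ever allocated. A self-contained sketch of the mechanism:

```python
import jax
import jax.numpy as jnp


def step(carry, _):
  return carry * 0.99, None  # accumulate nothing


final, ys = jax.lax.scan(step, jnp.ones((1024, 1024)), xs=None, length=1000)
print(ys)  # None -- no (1000, 1024, 1024) trajectory buffer exists
```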
7 changes: 7 additions & 0 deletions hackable_diffusion/lib/sampling/time_scheduling.py
@@ -147,6 +147,13 @@ class EDMTimeSchedule(TimeScheduleBaseClass):

   rho: float = 1.0

+  def __post_init__(self):
+    if self.rho <= 0:
+      raise ValueError(
+          f"rho must be positive, got {self.rho}. rho=0 causes division by"
+          " zero in the schedule computation."
+      )
+
   @kt.typechecked
   def all_step_infos(
       self, rng: PRNGKey, num_steps: int, data_spec: DataArray
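Reviewer note: the hunk does not show where `rho` is used, but in the EDM schedule of Karras et al. (2022), which this class appears to follow, the exponent enters as `1/rho`, which is exactly the division the new check guards. A reference sketch under that assumption (function and parameter names illustrative):

```python
import jax.numpy as jnp


def karras_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
  # sigma_i = (smax^(1/rho) + i/(n-1) * (smin^(1/rho) - smax^(1/rho)))^rho
  ramp = jnp.linspace(0.0, 1.0, n)
  min_inv, max_inv = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
  return (max_inv + ramp * (min_inv - max_inv)) ** rho


print(karras_sigmas(5))  # decreasing noise levels from sigma_max to sigma_min
```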