diff --git a/hackable_diffusion/lib/architecture/dit_blocks_test.py b/hackable_diffusion/lib/architecture/dit_blocks_test.py
index f6604ff..d8b176a 100644
--- a/hackable_diffusion/lib/architecture/dit_blocks_test.py
+++ b/hackable_diffusion/lib/architecture/dit_blocks_test.py
@@ -75,12 +75,8 @@ def test_variable_shapes(self):
             'kernel': (self.c, self.d),
             'bias': (self.d,),
         },
-        'ConditionalNorm': {
-            'Dense_0': {
-                'kernel': (self.c, self.d * 2),
-                'bias': (self.d * 2,),
-            },
-        },
+        # Single ConditionalNorm is shared between attention and MLP
+        # branches, so it appears only once in the param tree.
         'ConditionalNorm': {
             'Dense_0': {
                 'kernel': (self.c, self.d * 2),
diff --git a/hackable_diffusion/lib/corruption/discrete.py b/hackable_diffusion/lib/corruption/discrete.py
index 3c667f9..8d3c1ac 100644
--- a/hackable_diffusion/lib/corruption/discrete.py
+++ b/hackable_diffusion/lib/corruption/discrete.py
@@ -275,11 +275,13 @@ def corrupt(
     assert alpha_bcast.shape == x0.shape
     # Get the mask of the corruption process. It is True if the token is not
     # corrupted and False if it is corrupted.
-    is_not_corrupted = jax.random.bernoulli(key, p=alpha_bcast, mode=self.mode)
-    key, _ = jax.random.split(key)
+    mask_key, noise_key = jax.random.split(key)
+    is_not_corrupted = jax.random.bernoulli(
+        mask_key, p=alpha_bcast, mode=self.mode
+    )
     # Compute the noise vector.
-    noise = self.sample_from_invariant(key, data_spec=x0)
+    noise = self.sample_from_invariant(noise_key, data_spec=x0)
     # Keep x0 with probability alpha, otherwise replace it with noise.
     xt = jnp.where(is_not_corrupted, x0, noise)
     # is_not_corrupted = (xt == x0)
diff --git a/hackable_diffusion/lib/diffusion_network.py b/hackable_diffusion/lib/diffusion_network.py
index a81d29d..7d0523d 100644
--- a/hackable_diffusion/lib/diffusion_network.py
+++ b/hackable_diffusion/lib/diffusion_network.py
@@ -425,11 +425,11 @@ def initialize_variables(
   @kt.typechecked
   def __call__(
       self,
-      xt: DataTree,
       time: TimeTree,
+      xt: DataTree,
       conditioning: Conditioning | None,
       is_training: bool,
-  ):
+  ) -> TargetInfoTree:
     if self.time_rescaler is not None:
       time_rescaled = utils.lenient_map(
          lambda time, time_rescaler: time_rescaler(time)
diff --git a/hackable_diffusion/lib/inference/guidance.py b/hackable_diffusion/lib/inference/guidance.py
index af1297d..0d312e8 100644
--- a/hackable_diffusion/lib/inference/guidance.py
+++ b/hackable_diffusion/lib/inference/guidance.py
@@ -104,13 +104,13 @@ def __post_init__(self):
   @kt.typechecked
   def __call__(
       self,
-      xt: DataArray,
+      xt: DataTree,
       conditioning: Conditioning,
-      time: TimeArray,
-      cond_outputs: TargetInfo,
-      uncond_outputs: TargetInfo,
-  ) -> TargetInfo:
-    """Simple scalar guidance function."""
+      time: TimeTree,
+      cond_outputs: TargetInfoTree,
+      uncond_outputs: TargetInfoTree,
+  ) -> TargetInfoTree:
+    """Limited interval guidance function."""
     del conditioning  # unused
     time = utils.bcast_right(time, xt.ndim)
diff --git a/hackable_diffusion/lib/inference/wrappers.py b/hackable_diffusion/lib/inference/wrappers.py
index 0fa4721..e669ea2 100644
--- a/hackable_diffusion/lib/inference/wrappers.py
+++ b/hackable_diffusion/lib/inference/wrappers.py
@@ -95,9 +95,16 @@ def __call__(
 
 @dataclasses.dataclass(kw_only=True, frozen=True)
 class FlaxNNXInferenceFn(InferenceFn):
-  """Inference function protocol with a diffusion network given by nn.Module."""
+  """Inference function protocol with a diffusion network given by nn.Module.
+
+  Note: ``inference_seed`` is used for any stochastic layers (e.g., dropout)
+  that remain active at inference time.
+  Since ``is_training=False`` is always passed, dropout layers are typically
+  disabled and this seed has no effect. If you need stochastic inference
+  (e.g., MC dropout), provide a different seed per call.
+  """
 
   nnx_network: ConvertedNNXDiffusionNetwork
+  inference_seed: int = 0
 
   @kt.typechecked
   def __call__(
@@ -112,7 +119,7 @@ def __call__(
         xt=xt,
         conditioning=conditioning,
         is_training=False,
-        rngs=nnx.Rngs(0),
+        rngs=nnx.Rngs(self.inference_seed),
     )
diff --git a/hackable_diffusion/lib/manifolds.py b/hackable_diffusion/lib/manifolds.py
index f5d0800..a08060b 100644
--- a/hackable_diffusion/lib/manifolds.py
+++ b/hackable_diffusion/lib/manifolds.py
@@ -71,11 +71,34 @@ def safe_norm(
     keepdims: bool = True,
     eps: float = 1e-9,
 ) -> Array:
-  """Computes norm safely to avoid NaN gradients at zero."""
-  is_zero = jnp.all(x == 0, axis=axis, keepdims=keepdims)
-  safe_x = jnp.where(is_zero, eps, x)
-  n = jnp.linalg.norm(safe_x, axis=axis, keepdims=keepdims)
-  return jnp.where(is_zero, 0.0, n)
+  """Computes norm safely to avoid NaN gradients at zero.
+
+  Uses ``sqrt(max(sum(x^2), eps^2))`` instead of the standard ``norm``.
+
+  Why not ``jnp.where(is_zero, eps, norm(x))``? JAX traces *both* branches
+  of ``jnp.where`` and combines their gradients as ``cond * grad_a +
+  (1-cond) * grad_b``. At ``x=0``, ``grad(norm(x)) = x/||x|| = 0/0 = NaN``,
+  and ``0 * NaN = NaN`` under IEEE 754, poisoning the result even though the
+  "safe" branch was selected.
+
+  This implementation avoids the problem entirely: there is a single
+  differentiable code path with no branching.
+
+  - Forward: returns ``eps`` at ``x=0``, not ``0``.
+  - Backward: the gradient ``sum(x * dx) / sqrt(max(..., eps^2))`` evaluates
+    to ``0/eps = 0`` at ``x=0``. No NaN, no spurious non-zero gradient.
+
+  Args:
+    x: Input tensor.
+    axis: Axes over which to compute the norm.
+    keepdims: Whether to keep the original number of dimensions.
+    eps: Epsilon value to avoid NaN gradients at zero.
+
+  Returns:
+    The safe norm of x.
+  """
+  sq = jnp.sum(jnp.square(x), axis=axis, keepdims=keepdims)
+  return jnp.sqrt(jnp.maximum(sq, eps * eps))
 
 
 def transpose(x: DataArray) -> DataArray:
diff --git a/hackable_diffusion/lib/multimodal.py b/hackable_diffusion/lib/multimodal.py
index e223559..5fdd686 100644
--- a/hackable_diffusion/lib/multimodal.py
+++ b/hackable_diffusion/lib/multimodal.py
@@ -342,14 +342,14 @@ def __call__(
       time: TimeTree,
   ) -> LossOutputTree:
     return jax.tree.map(
-        lambda loss, target, pred, t: loss(
+        lambda loss, pred, target, t: loss(
            preds=pred,
            targets=target,
            time=t,
        ),
        self.losses,
-        targets,
        preds,
+        targets,
        time,
    )
diff --git a/hackable_diffusion/lib/sampling/base.py b/hackable_diffusion/lib/sampling/base.py
index 415bed4..5ae001f 100644
--- a/hackable_diffusion/lib/sampling/base.py
+++ b/hackable_diffusion/lib/sampling/base.py
@@ -61,13 +61,9 @@
 is called to produce the final clean output sample.
 """
 
-import dataclasses
 from typing import Protocol
 
 import flax.struct
 from hackable_diffusion.lib import hd_typing
-import jax
-import kauldron.ktyping as kt
-
 
 ################################################################################
 # MARK: Type Aliases
@@ -116,7 +112,6 @@ class DiffusionStep:
 
   Attributes:
     xt: The noisy data at the current step.
-    conditioning: The conditioning data from the prediction model.
    step_info: The `StepInfo` used to compute the current step.
    aux: Additional data computed by the sampler.
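+      Note that when a `DiffusionStep` is used as a `jax.lax.scan` carry
+      during sampling, the pytree structure of `aux` must stay identical
+      from one step to the next.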
""" diff --git a/hackable_diffusion/lib/sampling/discrete_step_sampler.py b/hackable_diffusion/lib/sampling/discrete_step_sampler.py index 4e1f3c6..38551fe 100644 --- a/hackable_diffusion/lib/sampling/discrete_step_sampler.py +++ b/hackable_diffusion/lib/sampling/discrete_step_sampler.py @@ -162,7 +162,7 @@ def __post_init__(self): ) schedule: DiscreteSchedule - rescale_factor: float = 1.0 + rescale_factor: float switch_min: float = 0.0 switch_max: float = 1.0 diff --git a/hackable_diffusion/lib/sampling/gaussian_step_sampler.py b/hackable_diffusion/lib/sampling/gaussian_step_sampler.py index 2ac1a42..17d7e13 100644 --- a/hackable_diffusion/lib/sampling/gaussian_step_sampler.py +++ b/hackable_diffusion/lib/sampling/gaussian_step_sampler.py @@ -574,24 +574,21 @@ def first_step( )["velocity"] new_xt = xt - dt * velocity - aux = current_step.aux - # Update the internal counter and current_velocity_step_one. + # Create a new aux dict to avoid mutating the frozen dataclass's dict. internal_counter = jnp.mod( - aux["internal_counter"] + 1, self.num_internal_steps + current_step.aux["internal_counter"] + 1, self.num_internal_steps ) - current_velocity_step_one = velocity - aux.update( - dict( - internal_counter=internal_counter, - current_velocity_step_one=current_velocity_step_one, - ) + new_aux = dict( + internal_counter=internal_counter, + current_update=current_step.aux["current_update"], + current_velocity_step_one=velocity, ) # Note that we output next_next_step_info and not next_step_info. return DiffusionStep( xt=new_xt, step_info=next_next_step_info, - aux=aux, + aux=new_aux, ) @kt.typechecked @@ -601,7 +598,7 @@ def second_step( current_step: DiffusionStep, next_step_info: StepInfo, ) -> DiffusionStep: - """First step of the Heun sampler. + """Second step of the Heun sampler. This is the second internal step. It should be called when the internal counter is 1. @@ -614,16 +611,15 @@ def second_step( Returns: The next step. """ - aux = current_step.aux xt = current_step.xt - current_update = aux["current_update"] - old_time = current_update.step_info.time + prev_update = current_step.aux["current_update"] + old_time = prev_update.step_info.time next_time = next_step_info.time old_time = utils.bcast_right(old_time, xt.ndim) next_time = utils.bcast_right(next_time, xt.ndim) - old_velocity = aux["current_velocity_step_one"] + old_velocity = current_step.aux["current_velocity_step_one"] intermediate_velocity = self.corruption_process.convert_predictions( prediction=prediction, xt=xt, @@ -634,30 +630,28 @@ def second_step( # Perform the final step. - old_xt = current_update.xt + old_xt = prev_update.xt dt = old_time - next_time new_xt = old_xt - dt * (old_velocity + intermediate_velocity) / 2 - # Update the internal counter and the current_update + # Create a new aux dict to avoid mutating the frozen dataclass's dict. 
    internal_counter = jnp.mod(
-        aux["internal_counter"] + 1, self.num_internal_steps
-    )
-    current_update = DiffusionStep(
-        xt=new_xt,
-        step_info=next_step_info,
-        aux=dict(),
-    )
-    aux.update(
-        dict(
-            internal_counter=internal_counter,
-            current_update=current_update,
-        )
+        current_step.aux["internal_counter"] + 1, self.num_internal_steps
+    )
+    new_aux = dict(
+        internal_counter=internal_counter,
+        current_update=DiffusionStep(
+            xt=new_xt,
+            step_info=next_step_info,
+            aux=dict(),
+        ),
+        current_velocity_step_one=current_step.aux["current_velocity_step_one"],
    )
 
    return DiffusionStep(
        xt=new_xt,
        step_info=next_step_info,
-        aux=aux,
+        aux=new_aux,
    )
 
  @kt.typechecked
diff --git a/hackable_diffusion/lib/sampling/sampling.py b/hackable_diffusion/lib/sampling/sampling.py
index de46bc5..59f60ab 100644
--- a/hackable_diffusion/lib/sampling/sampling.py
+++ b/hackable_diffusion/lib/sampling/sampling.py
@@ -62,7 +62,7 @@ def __call__(
      rng: PRNGKey,
      initial_noise: DataTree,
      conditioning: Conditioning,
-  ) -> tuple[DiffusionStepTree, DiffusionStepTree]:
+  ) -> tuple[DiffusionStepTree, DiffusionStepTree | None]:
    ...
 
@@ -130,11 +130,16 @@ class DiffusionSampler(SampleFn):
    time_schedule: Defines the sequence of time steps for the process.
    stepper: The sampling algorithm (e.g., DDIM) that updates the state.
    num_steps: The total number of denoising steps.
+    store_trajectory: If True (default), returns the full trajectory of all
+      intermediate steps. If False, only returns the final step and None for the
+      trajectory, significantly reducing memory usage for high-resolution or
+      many-step sampling.
  """
 
  time_schedule: TimeSchedule
  stepper: SamplerStep
  num_steps: int
+  store_trajectory: bool = True
 
  @kt.typechecked
  def __call__(
@@ -143,7 +148,7 @@ def __call__(
      rng: PRNGKey,
      initial_noise: DataTree,
      conditioning: Conditioning | None = None,
-  ) -> tuple[DiffusionStepTree, DiffusionStepTree]:
+  ) -> tuple[DiffusionStepTree, DiffusionStepTree | None]:
    """Performs a full reverse diffusion sampling loop for a single sample.
 
    This function orchestrates the denoising process, starting from an initial
@@ -159,7 +164,7 @@ def __call__(
    Returns:
      A tuple containing:
      - The final `DiffusionStepTree` of the sampling process.
      - A `DiffusionStepTree` PyTree containing the full trajectory of all
-        steps.
+        steps, or None if `store_trajectory` is False.
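+
+    Example (a sketch; assumes a configured `schedule`, `stepper`,
+    `inference_fn`, `rng`, and `initial_noise` are in scope)::
+
+      sampler = DiffusionSampler(
+          time_schedule=schedule,
+          stepper=stepper,
+          num_steps=50,
+          store_trajectory=False,  # keep only the final step
+      )
+      last_step, trajectory = sampler(
+          inference_fn=inference_fn,
+          rng=rng,
+          initial_noise=initial_noise,
+      )
+      # trajectory is None; last_step.xt holds the final sample.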
""" if self.num_steps < 2: raise ValueError( @@ -179,7 +184,7 @@ def __call__( first_step_info, ) - def scan_body(step_carry: DiffusionStepTree, next_step_info: StepInfoTree): + def _step(step_carry: DiffusionStepTree, next_step_info: StepInfoTree): xt, time = _get_input_inference_fn(step_carry) prediction = inference_fn( xt=xt, @@ -191,11 +196,27 @@ def scan_body(step_carry: DiffusionStepTree, next_step_info: StepInfoTree): step_carry, next_step_info, ) - return next_step, next_step # ('carryover', 'accumulated') + return next_step - before_last_step, intermediate_steps = jax.lax.scan( - scan_body, first_step, next_step_infos - ) + intermediate_steps = None + if self.store_trajectory: + + def scan_body(step_carry, next_step_info): + next_step = _step(step_carry, next_step_info) + return next_step, next_step # ('carryover', 'accumulated') + + before_last_step, intermediate_steps = jax.lax.scan( + scan_body, first_step, next_step_infos + ) + else: + + def scan_body_no_accum(step_carry, next_step_info): + next_step = _step(step_carry, next_step_info) + return next_step, None # no accumulation + + before_last_step, _ = jax.lax.scan( + scan_body_no_accum, first_step, next_step_infos + ) xt, time = _get_input_inference_fn(before_last_step) last_prediction = inference_fn( @@ -210,5 +231,8 @@ def scan_body(step_carry: DiffusionStepTree, next_step_info: StepInfoTree): last_step_info, ) - all_steps = _concat_pytree(first_step, intermediate_steps, last_step) + if self.store_trajectory: + all_steps = _concat_pytree(first_step, intermediate_steps, last_step) + else: + all_steps = None return last_step, all_steps diff --git a/hackable_diffusion/lib/sampling/time_scheduling.py b/hackable_diffusion/lib/sampling/time_scheduling.py index 82934d5..8ed2ffa 100644 --- a/hackable_diffusion/lib/sampling/time_scheduling.py +++ b/hackable_diffusion/lib/sampling/time_scheduling.py @@ -147,6 +147,13 @@ class EDMTimeSchedule(TimeScheduleBaseClass): rho: float = 1.0 + def __post_init__(self): + if self.rho <= 0: + raise ValueError( + f"rho must be positive, got {self.rho}. rho=0 causes division by" + " zero in the schedule computation." 
+      )
+
  @kt.typechecked
  def all_step_infos(
      self, rng: PRNGKey, num_steps: int, data_spec: DataArray
diff --git a/hackable_diffusion/lib/sampling/time_scheduling_test.py b/hackable_diffusion/lib/sampling/time_scheduling_test.py
index d058719..7c966c3 100644
--- a/hackable_diffusion/lib/sampling/time_scheduling_test.py
+++ b/hackable_diffusion/lib/sampling/time_scheduling_test.py
@@ -21,13 +21,14 @@
 import jax.numpy as jnp
 
 from absl.testing import absltest
+from absl.testing import parameterized
 
 ################################################################################
 # MARK: Tests
 ################################################################################
 
-class TimeScheduleTest(absltest.TestCase):
+class TimeScheduleTest(parameterized.TestCase, absltest.TestCase):
 
  # MARK: UniformTimeSchedule tests
 
@@ -136,6 +137,11 @@ def test_edm_all_step_infos_with_rho_one_is_uniform(self):
    ).time
    chex.assert_trees_all_close(uniform_steps, edm_steps)
 
+  @parameterized.parameters(0.0, -1.0)
+  def test_edm_invalid_rho(self, rho):
+    with self.assertRaisesRegex(ValueError, "rho must be positive"):
+      time_scheduling.EDMTimeSchedule(span=utils.SafeSpan(), rho=rho)
+
 
 if __name__ == "__main__":
  absltest.main()
diff --git a/hackable_diffusion/lib/training/gaussian_loss.py b/hackable_diffusion/lib/training/gaussian_loss.py
index 73f2e81..5a46bc5 100644
--- a/hackable_diffusion/lib/training/gaussian_loss.py
+++ b/hackable_diffusion/lib/training/gaussian_loss.py
@@ -106,9 +106,9 @@ def compute_continuous_diffusion_loss(
 
  Returns:
    The batched loss, i.e., a tensor of shape [B,] where B is the batch size. To
-    get the scalar loss use `jnp.mean(loss)`. The loss is returned before the
-    averaging to allow for other operations such as masking of loss values
-    afterwards.
+    get the scalar loss use `jnp.mean(loss)`. Note that all non-batch dimensions
+    are averaged (mean-reduced) internally, so the returned loss holds one
+    scalar per sample and cannot be used for post-hoc spatial masking.
  """
 
  if convert_to_logsnr_schedule or weight_fn:
diff --git a/hackable_diffusion/lib/training/time_sampling.py b/hackable_diffusion/lib/training/time_sampling.py
index d5e085e..c9bdb78 100644
--- a/hackable_diffusion/lib/training/time_sampling.py
+++ b/hackable_diffusion/lib/training/time_sampling.py
@@ -192,6 +192,12 @@ class UniformStratifiedTimeSampler(TimeSampler):
      _minval=0.0, _maxval=1.0, safety_epsilon=0.0
  )
 
+  def __post_init__(self):
+    if 0 not in self.axes:
+      raise ValueError(
+          "axes must include 0. Broadcasting over the batch is not supported."
+      )
+
  @kt.typechecked
  def __call__(self, key: PRNGKey, data_spec: DataArray) -> TimeArray:
    shape = utils.get_broadcastable_shape(data_spec.shape, self.axes)
@@ -240,15 +246,22 @@ def __call__(self, key: PRNGKey, data_spec: DataTree) -> TimeTree:
    shape1 = utils.get_broadcastable_shape(data_spec[self.key1].shape, (0,))
    shape2 = utils.get_broadcastable_shape(data_spec[self.key2].shape, (0,))
 
-    key1, key2, switch_key = jax.random.split(key, 3)
+    random_key1, random_key2, switch_key = jax.random.split(key, 3)
 
-    z1 = jax.random.normal(key1, shape=shape1)
+    z1 = jax.random.normal(random_key1, shape=shape1)
    f = jax.nn.sigmoid(z1) * self.s1 / (1 + (self.s1 - 1) * jax.nn.sigmoid(z1))
 
-    z2 = jax.random.normal(key2, shape=shape2)
+    z2 = jax.random.normal(random_key2, shape=shape2)
    g = jax.nn.sigmoid(z2) * self.s2 / (1 + (self.s2 - 1) * jax.nn.sigmoid(z2))
 
    # With probability p_equal, set g = 1 - f.
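+    # (`jax.lax.select` does not broadcast its arguments: the predicate and
+    # both branches must have identical shapes, hence the explicit
+    # `bcast_right` calls below.)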
-    equal_mask = jax.random.bernoulli(switch_key, p=self.p_equal, shape=shape1)
-    g = jax.lax.select(equal_mask, 1 - f, g)
+    # Use batch-only shape for the mask to avoid broadcasting issues when
+    # shape1 and shape2 differ in spatial dimensions.
+    batch_shape = (data_spec[self.key1].shape[0],)
+    equal_mask = jax.random.bernoulli(
+        switch_key, p=self.p_equal, shape=batch_shape
+    )
+    equal_mask = utils.bcast_right(equal_mask, len(shape2))
+    f_for_g = utils.bcast_right(f.reshape(batch_shape), len(shape2))
+    g = jax.lax.select(equal_mask, 1 - f_for_g, g)
 
    return {self.key1: f, self.key2: g}
diff --git a/hackable_diffusion/lib/training/time_sampling_test.py b/hackable_diffusion/lib/training/time_sampling_test.py
index 222f60e..a2fc659 100644
--- a/hackable_diffusion/lib/training/time_sampling_test.py
+++ b/hackable_diffusion/lib/training/time_sampling_test.py
@@ -87,6 +87,32 @@ def test_from_safety_epsilon(self, sampler_cls):
    self.assertGreaterEqual(jnp.min(time), 0.4)
    self.assertLessEqual(jnp.max(time), 0.6)
 
+  @parameterized.named_parameters(
+      dict(
+          testcase_name="uniform_no_batch",
+          sampler_cls=time_sampling.UniformTimeSampler,
+          axes=(1,),
+      ),
+      dict(
+          testcase_name="logit_normal_no_batch",
+          sampler_cls=time_sampling.LogitNormalTimeSampler,
+          axes=(1,),
+      ),
+      dict(
+          testcase_name="uniform_stratified_no_batch",
+          sampler_cls=time_sampling.UniformStratifiedTimeSampler,
+          axes=(1,),
+      ),
+      dict(
+          testcase_name="uniform_stratified_no_batch_multi",
+          sampler_cls=time_sampling.UniformStratifiedTimeSampler,
+          axes=(1, 2),
+      ),
+  )
+  def test_invalid_axes_missing_batch_dim(self, sampler_cls, axes):
+    with self.assertRaisesRegex(ValueError, "axes must include 0"):
+      sampler_cls(axes=axes)
+
 
 if __name__ == "__main__":
  absltest.main()
diff --git a/hackable_diffusion/lib/utils.py b/hackable_diffusion/lib/utils.py
index bbb0249..82d7347 100644
--- a/hackable_diffusion/lib/utils.py
+++ b/hackable_diffusion/lib/utils.py
@@ -145,6 +145,8 @@ def lenient_map(
    KeyError: If the structures of `tree` and `rest` do not match.
  """
  path_vals, struct = jax.tree.flatten_with_path(tree, is_leaf=is_leaf)
+  if not path_vals:
+    return tree  # Return empty tree unchanged.
  paths, _ = zip(*path_vals)
  restructured_rest = []
  for r in rest:
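
A standalone sanity check of the gradient behavior documented for `safe_norm`
in the manifolds.py hunk above. This is a minimal sketch: `naive_norm` is a
hypothetical contrast case written for illustration, not code from this patch.

    import jax
    import jax.numpy as jnp

    def naive_norm(x):
      # Single-`where` guard. The untaken branch still poisons the gradient,
      # because JAX differentiates both branches and 0 * NaN = NaN.
      return jnp.where(jnp.all(x == 0), 1e-9, jnp.linalg.norm(x))

    def safe_norm(x, eps=1e-9):
      # The patched pattern: a single differentiable path, no branching.
      sq = jnp.sum(jnp.square(x))
      return jnp.sqrt(jnp.maximum(sq, eps * eps))

    x0 = jnp.zeros(3)
    print(jax.grad(naive_norm)(x0))  # [nan nan nan]
    print(jax.grad(safe_norm)(x0))   # [0. 0. 0.]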