diff --git a/hackable_diffusion/lib/sampling/__init__.py b/hackable_diffusion/lib/sampling/__init__.py index 1b38266..df96441 100644 --- a/hackable_diffusion/lib/sampling/__init__.py +++ b/hackable_diffusion/lib/sampling/__init__.py @@ -21,15 +21,20 @@ from hackable_diffusion.lib.sampling.base import StepInfo from hackable_diffusion.lib.sampling.base import StepInfoTree from hackable_diffusion.lib.sampling.discrete_step_sampler import AllCorruptedMaskFn +from hackable_diffusion.lib.sampling.discrete_step_sampler import CLEAN from hackable_diffusion.lib.sampling.discrete_step_sampler import CorruptedMaskFn from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteDDIMStep from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteFlowMatchingStep from hackable_diffusion.lib.sampling.discrete_step_sampler import IntegratedDiscreteDDIMStep from hackable_diffusion.lib.sampling.discrete_step_sampler import MaskValueCorruptedMaskFn from hackable_diffusion.lib.sampling.discrete_step_sampler import MaxCappedRemaskingFn +from hackable_diffusion.lib.sampling.discrete_step_sampler import NOISE from hackable_diffusion.lib.sampling.discrete_step_sampler import NoRemaskingFn from hackable_diffusion.lib.sampling.discrete_step_sampler import RemaskingFn from hackable_diffusion.lib.sampling.discrete_step_sampler import RescaledRemaskingFn +from hackable_diffusion.lib.sampling.discrete_step_sampler import RoutingProbPlanner +from hackable_diffusion.lib.sampling.discrete_step_sampler import RoutingProbs +from hackable_diffusion.lib.sampling.discrete_step_sampler import STAY from hackable_diffusion.lib.sampling.discrete_step_sampler import UnMaskingStep from hackable_diffusion.lib.sampling.gaussian_step_sampler import AdjustedDDIMStep from hackable_diffusion.lib.sampling.gaussian_step_sampler import DDIMStep diff --git a/hackable_diffusion/lib/sampling/discrete_step_sampler.py b/hackable_diffusion/lib/sampling/discrete_step_sampler.py index 8a701ca..ec09d7a 100644 --- a/hackable_diffusion/lib/sampling/discrete_step_sampler.py +++ b/hackable_diffusion/lib/sampling/discrete_step_sampler.py @@ -52,6 +52,7 @@ DataArray = hd_typing.DataArray TargetInfo = hd_typing.TargetInfo TimeArray = hd_typing.TimeArray +PRNGKey = hd_typing.PRNGKey DiffusionStep = base.DiffusionStep StepInfo = base.StepInfo @@ -225,6 +226,104 @@ def __call__(self, xt: DataArray) -> DataArray: return xt == mask_value +################################################################################ +# MARK: Routing +################################################################################ + +# Almost all discrete samplers compute a 3-way routing for each token position: +# 0 = stay at current token (xt) +# 1 = sample from invariant distribution (noise) +# 2 = use predicted clean token (x0) +# +# The routing probabilities (p_stay, p_noise, p_clean) are computed by each +# sampler and applied via the shared `_apply_routing` helper. +# IntegratedDiscreteDDIMStep is an exception as it integrates the routing +# probabilities into the update rule. + +STAY = 0 +NOISE = 1 +CLEAN = 2 + + +@dataclasses.dataclass(frozen=True) +class RoutingProbs: + stay: Float['...'] + noise: Float['...'] + clean: Float['...'] + + +def _apply_routing( + routing_probs: RoutingProbs, + xt: DataArray, + x0: DataArray, + x_noise: DataArray, + key: PRNGKey, +) -> DataArray: + """Apply 3-way routing to construct the next state. + + Args: + routing_probs: Routing probabilities. + xt: Current state. Shape ``(*, 1)``. 
+ x0: Predicted clean state. Shape ``(*, 1)``. + x_noise: Sample from invariant distribution. Shape ``(*, 1)``. + key: Random key for categorical sampling. + + Returns: + The new state ``new_xt``. Shape ``(*, 1)``. + """ + probs = jnp.concatenate( + [routing_probs.stay, routing_probs.noise, routing_probs.clean], axis=-1 + ) + action = jax.random.categorical( + key=key, logits=jnp.log(jnp.maximum(probs, 1e-12)) + ) + new_xt = jnp.where( + action[..., None] == CLEAN, + x0, + jnp.where(action[..., None] == NOISE, x_noise, xt), + ) + return new_xt + + +class RoutingProbPlanner(Protocol): + """Protocol for transforming routing probabilities. + + A planner takes the routing probabilities computed by a sampler and + optionally transforms them before they are applied via ``_apply_routing``. + This allows injecting different selection strategies (e.g. greedy top-k) + without modifying the sampler logic. + + When no planner is used (``planner=None``), the routing probabilities are + applied as-is via stochastic categorical sampling. + """ + + def __call__( + self, + routing_probs: RoutingProbs, + logits: Float['... M'], + x0: DataArray, + xt: DataArray, + time: TimeArray, + next_time: TimeArray, + key: PRNGKey, + ) -> RoutingProbs: + """Transforms routing probabilities. + + Args: + routing_probs: Per-position routing probabilities. + logits: Model logits ``(*, M)``. + x0: Sampled clean token ``(*, 1)``. + xt: Current state ``(*, 1)``. + time: Current diffusion time. + next_time: Next diffusion time. + key: Random key. + + Returns: + Transformed routing probabilities. + """ + ... + + ################################################################################ # MARK: UnMasking Step ################################################################################ @@ -234,6 +333,17 @@ def __call__(self, xt: DataArray) -> DataArray: class UnMaskingStep(SamplerStep): """Unmasking step following https://arxiv.org/abs/2406.04329. + This sampler uses the 3-way routing representation. For each token position + we compute the probabilities of three actions: + - STAY (0): keep the current token. + - NOISE (1): sample from the invariant distribution (remasking). + - CLEAN (2): use the predicted clean token x0. + + For masked tokens: + p_clean = prob_up, p_noise = prob_down, p_stay = 1 - prob_up - prob_down. + For unmasked tokens: + p_clean = 0, p_noise = prob_down, p_stay = 1 - prob_down. + Attributes: corruption_process: The corruption process to use. remasking_fn: The remasking function to use, see @@ -246,6 +356,7 @@ class UnMaskingStep(SamplerStep): """ corruption_process: CategoricalProcess + planner: RoutingProbPlanner | None = None remasking_fn: RemaskingFn = NoRemaskingFn() corruption_mask_fn: CorruptedMaskFn = AllCorruptedMaskFn() temperature: float = 1.0 @@ -258,18 +369,6 @@ def __post_init__(self): if not self.corruption_process.is_masking: raise ValueError('UnMaskingStep only supports masking processes.') - @property - def mask_value(self) -> int: - return self.corruption_process.num_categories - 1 - - @property - def unused_token(self) -> int: - return self.corruption_process.unused_token - - @property - def post_corruption_fn(self) -> discrete.PostCorruptionFn: - return self.corruption_process.post_corruption_fn - @kt.typechecked def initialize( self, @@ -303,65 +402,81 @@ def update( current_step_info = current_step.step_info xt = current_step.xt - unused_mask = xt == self.unused_token + unused_mask = xt == self.corruption_process.unused_token # The mask is True if the token is unused. 
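A standalone numeric sketch of the routing probabilities described in the UnMaskingStep docstring above. The concrete values are assumptions for illustration only: a linear schedule alpha(t) = 1 - t and the default NoRemaskingFn (so p_st = 0); nothing here is taken from the library.

import jax.numpy as jnp

t, s = 0.6, 0.5                        # current time t and next time s, with s < t
alpha_t, alpha_s = 1.0 - t, 1.0 - s    # assumed linear schedule
p_st = 0.0                             # NoRemaskingFn: no remasking

prob_up = (alpha_s - (1.0 - p_st) * alpha_t) / (1.0 - alpha_t)   # ~0.167
prob_down = p_st

# Masked token: (p_stay, p_noise, p_clean).
p_masked = jnp.array([1.0 - prob_up - prob_down, prob_down, prob_up])
# Unmasked token: can only stay or be remasked, never unmask again.
p_unmasked = jnp.array([1.0 - prob_down, prob_down, 0.0])

assert jnp.allclose(p_masked.sum(), 1.0)
assert jnp.allclose(p_unmasked.sum(), 1.0)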
time = current_step_info.time next_time = next_step_info.time - time = utils.bcast_right(time, xt.ndim) - next_time = utils.bcast_right(next_time, xt.ndim) + time_bcast = utils.bcast_right(time, xt.ndim) + next_time_bcast = utils.bcast_right(next_time, xt.ndim) key = next_step_info.rng - # Sample from p_{0|t} - + # Get model predictions logits = self.corruption_process.convert_predictions( prediction, xt, - time, + time_bcast, )['logits'] logits = logits / self.temperature - key, subkey = jax.random.split(key) - sample = jax.random.categorical(key=subkey, logits=logits)[..., None] - # (bsz, *seq_len, 1) + _, x0_key, noise_key, plan_key, route_key = jax.random.split(key, 5) - # Split xt into masked and unmasked regions - - currently_masked = self.corruption_mask_fn(xt) - currently_unmasked = jnp.invert(currently_masked) + # Sample candidates + x0 = jax.random.categorical(key=x0_key, logits=logits)[..., None] + x_noise = self.corruption_process.sample_from_invariant( + noise_key, data_spec=xt + ) - # Denoising + currently_masked = self.corruption_mask_fn(xt) # (bsz, seq_len, 1) - alpha_s = self.corruption_process.schedule.alpha(next_time) - alpha_t = self.corruption_process.schedule.alpha(time) + # Denoising rates + alpha_s = self.corruption_process.schedule.alpha(next_time_bcast) + alpha_t = self.corruption_process.schedule.alpha(time_bcast) - p_st = self.remasking_fn(s=next_time, t=time) + p_st = self.remasking_fn(s=next_time_bcast, t=time_bcast) - prob = (alpha_s - (1.0 - p_st) * alpha_t) / (1.0 - alpha_t) + prob_up = (alpha_s - (1.0 - p_st) * alpha_t) / (1.0 - alpha_t) + prob_down = p_st # Denoising probability following https://arxiv.org/abs/2503.00307v1 # If no remasking, p_st = 0, so prob = (alpha_s - alpha_t) / (1.0 - alpha_t) - prob = jnp.broadcast_to(prob, currently_masked.shape) - key, subkey = jax.random.split(key) - to_unmask = currently_masked * jax.random.bernoulli(subkey, prob) - - new_xt = jnp.where(to_unmask, sample, xt) - - # Renoising following https://arxiv.org/abs/2503.00307 - key_noise, key_remask = jax.random.split(key) - noise_sample = self.corruption_process.sample_from_invariant( - key=key_noise, - data_spec=xt, - ) - - p_st = jnp.broadcast_to(p_st, currently_unmasked.shape) - to_remask = currently_unmasked * jax.random.bernoulli(key_remask, p_st) + # Routing probabilities for masked tokens: + # p_stay = 1 - prob_up - prob_down + # p_noise = prob_down + # p_clean = prob_up + p_stay_masked = 1.0 - prob_up - prob_down + p_noise_masked = prob_down + p_clean_masked = prob_up + + # Routing probabilities for unmasked tokens: + # p_stay = 1 - prob_down + # p_noise = prob_down + # p_clean = 0 + p_stay_unmasked = 1.0 - prob_down + p_noise_unmasked = prob_down + p_clean_unmasked = jnp.zeros_like(prob_down) + + # Combine based on masking state + p_stay = jnp.where(currently_masked, p_stay_masked, p_stay_unmasked) + p_noise = jnp.where(currently_masked, p_noise_masked, p_noise_unmasked) + p_clean = jnp.where(currently_masked, p_clean_masked, p_clean_unmasked) + + routing_probs = RoutingProbs(stay=p_stay, noise=p_noise, clean=p_clean) + # (bsz, seq_len, 3) + + # Apply planner transformation (if any) + if self.planner is not None: + routing_probs = self.planner( + routing_probs, logits, x0, xt, time, next_time, plan_key + ) - new_xt = jnp.where(to_remask, noise_sample, new_xt) - new_xt = self.post_corruption_fn(new_xt) + new_xt = _apply_routing(routing_probs, xt, x0, x_noise, route_key) + new_xt = self.corruption_process.post_corruption_fn(new_xt) # Replace the unused tokens 
with the unused_token. - new_xt = jnp.where(unused_mask, self.unused_token, new_xt) + new_xt = jnp.where( + unused_mask, self.corruption_process.unused_token, new_xt + ) return DiffusionStep( xt=new_xt, @@ -398,21 +513,46 @@ class DiscreteDDIMStep(SamplerStep): Diffusion Models in Discrete State-Spaces" (known as D3PM, see https://arxiv.org/abs/2107.03006). - Given the forward process with density p(x_t|x_0) it computes the reverse - process by first sampling from p(x_0|x_t) to obtain x_0. + This sampler uses the 3-way routing representation. Given the forward process + with density p(x_t|x_0), we decompose the reverse posterior + p(x_s|x_t, x_0) into three components: - Then it samples x_s (for s < t) using the following formula: + p(x_s|x_t,x_0) = p_stay * δ_{x_t}(x_s) + p_noise * π(x_s) + + p_clean * δ_{x_0}(x_s) - p(x_s|x_t,x_0) ∝ p(x_s|x_0) * p(x_t|x_s) (1) + where: + - p_stay: probability of staying at x_t + - p_noise: probability of jumping to invariant noise + - p_clean: probability of jumping to the predicted x_0 - In order to compute (1) we recall that for any s, t such that s < t we have: + **Derivation.** Recall that for the forward process: - p(x_t|x_s) = (α_t/α_s) * δ_{x_s}(x_t) + (1 - α_t/α_s) * π(x_t) + p(x_t|x_s) = r * δ_{x_t}(x_s) + (1 - r) * π(x_t) + p(x_s|x_0) = α_s * δ_{x_0}(x_s) + (1 - α_s) * π(x_s) - The computation of the probability happens in the logits space. + where r = α_t/α_s. The posterior is proportional to their product: + + p(x_s|x_t,x_0) ∝ p(x_t|x_s) * p(x_s|x_0) + + Expanding gives four cross-terms: + + (T1) r * α_s * δ_{x_t}(x_s) * δ_{x_0}(x_s) + (T2) r * (1-α_s) * δ_{x_t}(x_s) * π(x_s) = r*(1-α_s)*π(x_t) * δ_{x_t} + (T3) (1-r) * α_s * π(x_t) * δ_{x_0}(x_s) + (T4) (1-r) * (1-α_s) * π(x_t) * π(x_s) + + Collecting by routing outcome: + - p_stay ∝ r * (1-α_s) * π(x_t) [T2: δ_{x_t} · π gives π(x_t)*δ_{x_t}] + - p_noise ∝ (1-r) * (1-α_s) * π(x_t) [T4: π(x_t) · π(x_s)] + - p_clean ∝ (1-r) * α_s * π(x_t) [T3: π(x_t) · δ_{x_0}] + + **The fourth cross-term (T1):** δ_{x_t}(x_s) * δ_{x_0}(x_s) is non-zero + only when x_s = x_t AND x_s = x_0, i.e., when x_0 = x_t. In that case it + contributes r * α_s to the δ_{x_t} (stay) weight. """ corruption_process: CategoricalProcess + planner: RoutingProbPlanner | None = None temperature: float = 1.0 def __post_init__(self): @@ -428,25 +568,149 @@ def __post_init__(self): ' with 0.0 probability mass for any element.' ) - @property - def mask_value(self) -> int: - return self.corruption_process.num_categories - 1 + @kt.typechecked + def initialize( + self, + initial_noise: DataArray, + initial_step_info: StepInfo, + ) -> DiffusionStep: + + init_logits = jnp.repeat( + initial_noise, self.corruption_process.num_categories, axis=-1 + ) + init_logits = jnp.zeros_like(init_logits, dtype=jnp.float32) - jnp.inf + + return DiffusionStep( + xt=initial_noise, + step_info=initial_step_info, + aux={'logits': init_logits}, + ) + # `logits` need to be passed in `aux` dictionary to a performance + # bug when using TPU. Needs to be investigated. 
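A small self-contained check, with assumed schedule values and a uniform invariant distribution, that the three routing weights plus the T1 correction reproduce the exact D3PM posterior derived in the docstring above:

import jax.numpy as jnp

M = 4
pi = jnp.full((M,), 1.0 / M)           # assumed uniform invariant distribution
alpha_t, alpha_s = 0.2, 0.6            # assumed schedule values, s < t
r = alpha_t / alpha_s                  # ratio α_t / α_s
xt, x0 = 1, 3                          # current token and predicted clean token

# Routing weights from the derivation (T2 → stay, T4 → noise, T3 → clean).
p_stay = r * (1.0 - alpha_s) * pi[xt]
p_noise = (1.0 - r) * (1.0 - alpha_s) * pi[xt]
p_clean = (1.0 - r) * alpha_s * pi[xt]
p_stay = p_stay + float(x0 == xt) * r * alpha_s   # T1: only contributes when x0 == xt

# Marginalize the routing action into a distribution over x_s.
dist = p_noise * pi
dist = dist.at[xt].add(p_stay).at[x0].add(p_clean)
dist = dist / dist.sum()

# Exact posterior p(x_s|x_t,x_0) ∝ p(x_t|x_s) * p(x_s|x_0), per candidate x_s.
xs = jnp.arange(M)
exact = (r * (xs == xt) + (1.0 - r) * pi[xt]) * (
    alpha_s * (xs == x0) + (1.0 - alpha_s) * pi
)
exact = exact / exact.sum()

assert jnp.allclose(dist, exact, atol=1e-6)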
+ + @kt.typechecked + def update( + self, + prediction: TargetInfo, + current_step: DiffusionStep, + next_step_info: StepInfo, + ) -> DiffusionStep: + + current_step_info = current_step.step_info + xt = current_step.xt + + unused_mask = xt == self.corruption_process.unused_token + + time = current_step_info.time + next_time = next_step_info.time + time_bcast = utils.bcast_right(time, xt.ndim) + next_time_bcast = utils.bcast_right(next_time, xt.ndim) + key = next_step_info.rng + + # Get model predictions + logits = self.corruption_process.convert_predictions( + prediction, + xt, + time_bcast, + )['logits'] + logits = logits / self.temperature + + _, x0_key, noise_key, plan_key, route_key = jax.random.split(key, 5) + + # Sample candidates + x0 = jax.random.categorical(key=x0_key, logits=logits)[..., None] + x_noise = self.corruption_process.sample_from_invariant( + noise_key, data_spec=xt + ) + + # Schedule + alpha_s = self.corruption_process.schedule.alpha(next_time_bcast) + alpha_t = self.corruption_process.schedule.alpha(time_bcast) + ratio = alpha_t / alpha_s + # (bsz, *seq_len, 1) + + # Routing probabilities (unnormalized). + # See the class docstring for the full derivation of terms T1–T4. + pi_xt = self.corruption_process.invariant_probs_vec[xt[..., 0]][..., None] + + # T2 → stay, T4 → noise, T3 → clean + p_stay = ratio * (1.0 - alpha_s) * pi_xt + p_noise = (1.0 - ratio) * (1.0 - alpha_s) * pi_xt + p_clean = (1.0 - ratio) * alpha_s * pi_xt + + # T1 → adds r * α_s to stay, only when x_0 = x_t (see docstring). + x0_eq_xt = (x0 == xt).astype(jnp.float32) + p_stay = p_stay + x0_eq_xt * ratio * alpha_s + + routing_probs = RoutingProbs(stay=p_stay, noise=p_noise, clean=p_clean) + # (bsz, *seq_len, 3) + + # Apply planner transformation (if any) + if self.planner is not None: + routing_probs = self.planner( + routing_probs, logits, x0, xt, time, next_time, plan_key + ) + + new_xt = _apply_routing(routing_probs, xt, x0, x_noise, route_key) + new_xt = self.corruption_process.post_corruption_fn(new_xt) - @property - def unused_token(self) -> int: - return self.corruption_process.unused_token + # Replace the unused tokens with the unused_token. + new_xt = jnp.where( + unused_mask, self.corruption_process.unused_token, new_xt + ) - @property - def post_corruption_fn(self) -> discrete.PostCorruptionFn: - return self.corruption_process.post_corruption_fn + return DiffusionStep( + xt=new_xt, + step_info=next_step_info, + aux={'logits': logits}, + ) + # `logits` need to be passed in `aux` dictionary to a performance + # bug when using TPU. Needs to be investigated. - @property - def invariant_probs_vec(self) -> Float['M']: - return self.corruption_process.invariant_probs_vec + @kt.typechecked + def finalize( + self, + prediction: TargetInfo, + current_step: DiffusionStep, + last_step_info: StepInfo, + ) -> DiffusionStep: + return self.update( + prediction, + current_step, + last_step_info, + ) + + +################################################################################ +# MARK: Discrete Flow Matching Step +################################################################################ - @property - def process_num_categories(self) -> int: - return self.corruption_process.process_num_categories + +@dataclasses.dataclass(frozen=True, kw_only=True) +class DiscreteFlowMatchingStep(SamplerStep): + """Discrete Flow Matching step following https://arxiv.org/abs/2407.15595. + + This sampler uses the 3-way routing representation. 
The update rule + decomposes naturally into: + + p(x_s) = p_stay * δ_{x_t} + p_up * p_x0 + p_down * π + + where: + - p_stay = 1 - p_up - p_down + - p_up = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff) + - p_down = (α_s - α_t) / α_t * stoch_coeff + + Attributes: + corruption_process: The corruption process to use. + temperature: The temperature to use. + stoch_coeff: The stochasticity coefficient (default 0.0). Higher values + introduce more noise during the denoising process. + """ + + corruption_process: CategoricalProcess + planner: RoutingProbPlanner | None = None + temperature: float = 1.0 + stoch_coeff: float = 0.0 @kt.typechecked def initialize( @@ -465,8 +729,6 @@ def initialize( step_info=initial_step_info, aux={'logits': init_logits}, ) - # `logits` need to be passed in `aux` dictionary to a performance - # bug when using TPU. Needs to be investigated. @kt.typechecked def update( @@ -479,72 +741,75 @@ def update( current_step_info = current_step.step_info xt = current_step.xt - unused_mask = xt == self.unused_token - # The mask is True if the token is unused. + unused_mask = xt == self.corruption_process.unused_token time = current_step_info.time next_time = next_step_info.time - time = utils.bcast_right(time, xt.ndim) - next_time = utils.bcast_right(next_time, xt.ndim) + time_bcast = utils.bcast_right(time, xt.ndim) + next_time_bcast = utils.bcast_right(next_time, xt.ndim) key = next_step_info.rng - # Sample from p_{0|t} + # Get model predictions logits = self.corruption_process.convert_predictions( prediction, xt, - time, + time_bcast, )['logits'] logits = logits / self.temperature - x0 = jax.random.categorical(key=key, logits=logits)[..., None] - # (bsz, *seq_len, 1) - key, _ = jax.random.split(key) - - # Compute the probability vector + _, x0_key, noise_key, plan_key, route_key = jax.random.split(key, 5) - xt_oh = jax.nn.one_hot(xt[..., 0], num_classes=self.process_num_categories) - x0_oh = jax.nn.one_hot(x0[..., 0], num_classes=self.process_num_categories) - # (bsz, *seq_len, M) + # Sample candidates + x0 = jax.random.categorical(key=x0_key, logits=logits)[..., None] + x_noise = self.corruption_process.sample_from_invariant( + noise_key, data_spec=xt + ) - alpha_s = self.corruption_process.schedule.alpha(next_time) - alpha_t = self.corruption_process.schedule.alpha(time) - alpha_s = jnp.broadcast_to(alpha_s, x0_oh.shape) - alpha_t = jnp.broadcast_to(alpha_t, x0_oh.shape) - ratio = alpha_t / alpha_s - # (bsz, *seq_len, M) + # Denoising rates + alpha_s = self.corruption_process.schedule.alpha(next_time_bcast) + alpha_t = self.corruption_process.schedule.alpha(time_bcast) - first_logit = jnp.log( - ratio * xt_oh + (1.0 - ratio) * self.invariant_probs_vec[xt] + prob_up = ( + (alpha_s - alpha_t) + / jnp.maximum(1.0 - alpha_t, 1e-12) + * (1.0 + self.stoch_coeff) ) - second_logit = jnp.log( - alpha_s * x0_oh + (1.0 - alpha_s) * self.invariant_probs_vec + prob_down = ( + (alpha_s - alpha_t) / jnp.maximum(alpha_t, 1e-12) * self.stoch_coeff ) - total_logit = first_logit + second_logit - # Do not use this sampler for masking. - # What could happen is xt is unmasked (assume at first position) so the - # first logits (first_logit) is [value, -inf, ..., -inf]. Then assume - # that the predictionfor x0 is different than xt - # (can never happen in unmasking),assume that the second position is the - # one chosen by the x0 predictor. Then we have for the second logits - # (second_logit): [-inf, value, -inf, ..., -inf, value_mask]. - # So when we add them together we get [-inf, ..., -inf]. 
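A quick numeric sketch of the flow-matching routing probabilities and the clip-and-rescale guard described in the DiscreteFlowMatchingStep docstring, using assumed schedule values rather than anything computed by the library:

import jax.numpy as jnp

alpha_t, alpha_s = 0.3, 0.5    # assumed: alpha grows as corruption is removed (s < t)
stoch_coeff = 0.5

prob_up = (alpha_s - alpha_t) / jnp.maximum(1.0 - alpha_t, 1e-12) * (1.0 + stoch_coeff)
prob_down = (alpha_s - alpha_t) / jnp.maximum(alpha_t, 1e-12) * stoch_coeff

# If the raw jump probabilities summed above 1 they would be scaled down
# proportionally; with these values the sum stays below 1 and the scale is a no-op.
raw_p_up = jnp.maximum(prob_up, 0.0)
raw_p_down = jnp.maximum(prob_down, 0.0)
scale = jnp.maximum(1.0, raw_p_up + raw_p_down)

p_clean = raw_p_up / scale        # ~0.43: jump to the predicted x0
p_noise = raw_p_down / scale      # ~0.33: jump back to invariant noise
p_stay = 1.0 - p_clean - p_noise  # ~0.24: keep the current token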
- # jax.random.categorical will then return the first position. - # This is not what we want and this behavior should not be accepted. - - # Sample from the distribution defined by logits - new_xt = jax.random.categorical(key=key, logits=total_logit)[..., None] - new_xt = self.post_corruption_fn(new_xt) + + # Clip and rescale to ensure valid probabilities + raw_p_up = jnp.maximum(prob_up, 0.0) + raw_p_down = jnp.maximum(prob_down, 0.0) + sum_jumps = raw_p_up + raw_p_down + scale_factor = jnp.maximum(1.0, sum_jumps) + + p_clean = raw_p_up / scale_factor + p_noise = raw_p_down / scale_factor + p_stay = 1.0 - p_clean - p_noise + + routing_probs = RoutingProbs(stay=p_stay, noise=p_noise, clean=p_clean) + # (bsz, *seq_len, 3) + + # Apply planner transformation (if any) + if self.planner is not None: + routing_probs = self.planner( + routing_probs, logits, x0, xt, time, next_time, plan_key + ) + + new_xt = _apply_routing(routing_probs, xt, x0, x_noise, route_key) + new_xt = self.corruption_process.post_corruption_fn(new_xt) # Replace the unused tokens with the unused_token. - new_xt = jnp.where(unused_mask, self.unused_token, new_xt) + new_xt = jnp.where( + unused_mask, self.corruption_process.unused_token, new_xt + ) return DiffusionStep( xt=new_xt, step_info=next_step_info, aux={'logits': logits}, ) - # `logits` need to be passed in `aux` dictionary to a performance - # bug when using TPU. Needs to be investigated. @kt.typechecked def finalize( @@ -563,6 +828,10 @@ def finalize( ################################################################################ # MARK: Integrated DDIM Step ################################################################################ +# Note: IntegratedDiscreteDDIMStep does NOT fit the 3-way routing scheme +# because it marginalizes over x_0 rather than sampling a single x_0. +# It is kept as-is with direct categorical sampling. +################################################################################ @dataclasses.dataclass(frozen=True, kw_only=True) @@ -593,7 +862,8 @@ class IntegratedDiscreteDDIMStep(SamplerStep): In particular, we use the following formula: - p(x_s|x_t) = p(x_t|x_s) * sum_{x_0} (p(x_0|x_t) / p(x_t|x_0)) p(x_s|x_0) (2) + p(x_s|x_t) = p(x_t|x_s) * sum_{x_0} (p(x_0|x_t) / p(x_t|x_0)) p(x_s|x_0) + (2) Denoting w(x_0, x_t) = p(x_0|x_t) / p(x_t|x_0) and W(x_t) = sum_{x_0} w(x_0, x_t) we have: @@ -619,26 +889,6 @@ def __post_init__(self): ' with 0.0 probability mass for any element.' 
) - @property - def mask_value(self) -> int: - return self.corruption_process.num_categories - 1 - - @property - def unused_token(self) -> int: - return self.corruption_process.unused_token - - @property - def post_corruption_fn(self) -> discrete.PostCorruptionFn: - return self.corruption_process.post_corruption_fn - - @property - def invariant_probs_vec(self) -> Float['M']: - return self.corruption_process.invariant_probs_vec - - @property - def process_num_categories(self) -> int: - return self.corruption_process.process_num_categories - @kt.typechecked def initialize( self, @@ -668,7 +918,7 @@ def update( ) -> DiffusionStep: xt = current_step.xt - unused_mask = xt == self.unused_token + unused_mask = xt == self.corruption_process.unused_token time = utils.bcast_right(current_step.step_info.time, xt.ndim) next_time = utils.bcast_right(next_step_info.time, xt.ndim) @@ -683,7 +933,9 @@ def update( # (bsz, *seq_len, M) # One-hot encoding for the current state - xt_oh = jax.nn.one_hot(xt[..., 0], num_classes=self.process_num_categories) + xt_oh = jax.nn.one_hot( + xt[..., 0], num_classes=self.corruption_process.process_num_categories + ) # (bsz, *seq_len, M) # Calculate schedule alphas. @@ -695,7 +947,7 @@ def update( # (bsz, *seq_len, M) # Extract invariant probabilities. - pi = self.invariant_probs_vec + pi = self.corruption_process.invariant_probs_vec pi_xt = pi[xt[..., 0]][..., None] # The prior prob of the current token # (bsz, *seq_len, 1) @@ -721,15 +973,15 @@ def update( p_xs = q_xt_given_xs * expected_xs_given_x0 # (bsz, *seq_len, M) - # Convert back to logits for safe categorical sampling + # Convert to logits and sample. total_logit = jnp.log(jnp.clip(p_xs, min=1e-12)) - - # Sample and format the new state new_xt = jax.random.categorical(key=key, logits=total_logit)[..., None] - new_xt = self.post_corruption_fn(new_xt) + new_xt = self.corruption_process.post_corruption_fn(new_xt) # Replace the unused tokens with the unused_token. - new_xt = jnp.where(unused_mask, self.unused_token, new_xt) + new_xt = jnp.where( + unused_mask, self.corruption_process.unused_token, new_xt + ) return DiffusionStep( xt=new_xt, @@ -751,166 +1003,3 @@ def finalize( current_step, last_step_info, ) - - -################################################################################ -# MARK: Discrete Flow Matching Step -################################################################################ - - -@dataclasses.dataclass(frozen=True, kw_only=True) -class DiscreteFlowMatchingStep(SamplerStep): - """Discrete Flow Matching step following https://arxiv.org/abs/2407.15595. - - This sampler is the simplest variant of Algorithm 1 in Discrete Flow Matching, - Gat et. al., 2024, https://arxiv.org/abs/2407.15595. It implements the - update rule based on the probability velocity derived for the probability - path family in (9). - - The update rule is: - x_{t-dt} ~ (1 - prob_jump) * delta_{x_t} + prob_jump * prediction - - where prob_jump = (alpha_s - alpha_t) / (1 - alpha_t). Note that alpha(t) in - this codebase is the probability of keeping the original value, which - corresponds to 1 - kappa(t) in the paper if the time is reversed. - - Attributes: - corruption_process: The corruption process to use. - temperature: The temperature to use. - gamma: The corrector term (default 0.0). Higher values introduce more noise - during the denoising process, which can improve sample quality. 
- """ - - corruption_process: CategoricalProcess - temperature: float = 1.0 - gamma: float = 0.0 - - @property - def unused_token(self) -> int: - return self.corruption_process.unused_token - - @property - def post_corruption_fn(self) -> discrete.PostCorruptionFn: - return self.corruption_process.post_corruption_fn - - @kt.typechecked - def initialize( - self, - initial_noise: DataArray, - initial_step_info: StepInfo, - ) -> DiffusionStep: - - init_logits = jnp.repeat( - initial_noise, self.corruption_process.num_categories, axis=-1 - ) - init_logits = jnp.zeros_like(init_logits, dtype=jnp.float32) - jnp.inf - - return DiffusionStep( - xt=initial_noise, - step_info=initial_step_info, - aux={'logits': init_logits}, - ) - - @kt.typechecked - def update( - self, - prediction: TargetInfo, - current_step: DiffusionStep, - next_step_info: StepInfo, - ) -> DiffusionStep: - - current_step_info = current_step.step_info - xt = current_step.xt - - unused_mask = xt == self.unused_token - - time = current_step_info.time - next_time = next_step_info.time - time_bcast = utils.bcast_right(time, xt.ndim) - next_time_bcast = utils.bcast_right(next_time, xt.ndim) - key = next_step_info.rng - - # Sample from p_{0|t} - logits = self.corruption_process.convert_predictions( - prediction, - xt, - time_bcast, - )['logits'] - logits = logits / self.temperature - - _, sample_key, noise_key, jump_key = jax.random.split(key, 4) - sample = jax.random.categorical(key=sample_key, logits=logits)[..., None] - noise_sample = self.corruption_process.sample_from_invariant( - noise_key, data_spec=xt - ) - - # Denoising - alpha_s = self.corruption_process.schedule.alpha(next_time_bcast) - alpha_t = self.corruption_process.schedule.alpha(time_bcast) - - # prob_up is the probability of switching from the current state to the - # predicted data state. Following the paper's formula (24): - # u_fwd = (dot_kappa / (1 - kappa)) * (p_data - delta_xt) - # prob_down is the probability of switching back to noise (corrector logic): - # u_bwd = (dot_kappa / kappa) * (delta_xt - p_noise) - # Following the paper's formula (26), the combined velocity is: - # u_bar = (1 + gamma) * u_fwd - gamma * u_bwd. - # Note that since u_bwd (u^(0) in the paper) involves (delta_xt - p_noise), - # it has negative jump rates back to noise. Subtracting it (-gamma * u_bwd) - # results in positive jump probabilities in the discretization. - - # We discretize this as a jump process where each token has probability - # prob_up of jumping to data and prob_down of jumping to noise. 
- - prob_up = ( - (alpha_s - alpha_t) - / jnp.maximum(1.0 - alpha_t, 1e-12) - * (1.0 + self.gamma) - ) - prob_down = (alpha_s - alpha_t) / jnp.maximum(alpha_t, 1e-12) * self.gamma - - # Calculate raw, unclipped probabilities - raw_p_up = jnp.maximum(prob_up, 0.0) - raw_p_down = jnp.maximum(prob_down, 0.0) - sum_jumps = raw_p_up + raw_p_down - - # If the sum exceeds 1.0, scale them down proportionally to maintain their - # ratio - scale_factor = jnp.maximum(1.0, sum_jumps) - - p_up = raw_p_up / scale_factor - p_down = raw_p_down / scale_factor - p_stay = 1.0 - p_up - p_down - - probs = jnp.stack([p_stay, p_up, p_down], axis=-1) - probs = jnp.broadcast_to(probs, xt.shape + (3,)) - jump_type = jax.random.categorical( - jump_key, logits=jnp.log(jnp.maximum(probs, 1e-12)) - ) - - # 0: stay, 1: jump to data, 2: jump to noise - new_xt = jnp.where(jump_type == 1, sample, xt) - new_xt = jnp.where(jump_type == 2, noise_sample, new_xt) - new_xt = self.post_corruption_fn(new_xt) - - # Replace the unused tokens with the unused_token. - new_xt = jnp.where(unused_mask, self.unused_token, new_xt) - - return DiffusionStep( - xt=new_xt, - step_info=next_step_info, - aux={'logits': logits}, - ) - - @kt.typechecked - def finalize( - self, - prediction: TargetInfo, - current_step: DiffusionStep, - last_step_info: StepInfo, - ) -> DiffusionStep: - return self.update( - prediction, - current_step, - last_step_info, - ) diff --git a/hackable_diffusion/lib/sampling/discrete_step_sampler_test.py b/hackable_diffusion/lib/sampling/discrete_step_sampler_test.py index 1b0522b..ac335a0 100644 --- a/hackable_diffusion/lib/sampling/discrete_step_sampler_test.py +++ b/hackable_diffusion/lib/sampling/discrete_step_sampler_test.py @@ -649,7 +649,7 @@ def test_update_with_gamma(self): # Use gamma that won't clip. dfm_step_gamma = discrete_step_sampler.DiscreteFlowMatchingStep( - corruption_process=self.process, gamma=1.0 + corruption_process=self.process, stoch_coeff=1.0 ) next_step_info = StepInfo( @@ -667,5 +667,153 @@ def test_update_with_gamma(self): self.assertTrue(jnp.any(next_step.xt != 1)) +class DDIMRoutingEquivalenceTest(absltest.TestCase): + """Verify routing-based DDIM matches the original logit-space computation. + + The original DiscreteDDIMStep computed the reverse posterior in full + M-dimensional logit space: + + first_logit[k] = log(r * 1[k=xt] + (1-r) * π(xt)) + second_logit[k] = log(αs * 1[k=x0] + (1-αs) * π(k)) + total_logit = first_logit + second_logit + + The routing reformulation decomposes this into 3-way routing weights. + This test checks that both produce exactly the same distribution over + output tokens, including the edge case where x0 == xt. + """ + + def _posterior_distribution( + self, xt, x0, alpha_s, alpha_t, invariant_probs_vec + ): + """Compute the exact posterior in probability space. + + p(x_s | x_t, x_0) ∝ p(x_t | x_s) * p(x_s | x_0) + + Evaluated for every x_s in {0, ..., M-1}. + + Args: + xt: Current token. + x0: Predicted clean token. + alpha_s: Diffusion schedule value at time s. + alpha_t: Diffusion schedule value at time t. + invariant_probs_vec: Invariant distribution. + + Returns: + The M-dimensional posterior distribution. + """ + voc_size = int(invariant_probs_vec.shape[0]) + ratio = alpha_t / alpha_s + + # Build unnormalized weight for each x_s value. 
+ weights = [] + for xs in range(voc_size): + # p(x_t | x_s) = r * 1[xs=xt] + (1-r) * π(xt) + p_xt_given_xs = ratio * float(xs == xt) + (1.0 - ratio) * float( + invariant_probs_vec[xt] + ) + # p(x_s | x_0) = α_s * 1[xs=x0] + (1-α_s) * π(xs) + p_xs_given_x0 = alpha_s * float(xs == x0) + (1.0 - alpha_s) * float( + invariant_probs_vec[xs] + ) + weights.append(p_xt_given_xs * p_xs_given_x0) + + weights = jnp.array(weights) + return weights / jnp.sum(weights) + + def _routing_distribution( + self, xt, x0, alpha_s, alpha_t, invariant_probs_vec + ): + """Compute the routing-based posterior distribution. + + Mirrors the actual code in DiscreteDDIMStep.update. + + Args: + xt: Current token. + x0: Predicted clean token. + alpha_s: Diffusion schedule value at time s. + alpha_t: Diffusion schedule value at time t. + invariant_probs_vec: Invariant distribution. + + Returns: + The M-dimensional posterior distribution. + """ + ratio = alpha_t / alpha_s + pi_xt = float(invariant_probs_vec[xt]) + + # T2 → stay, T4 → noise, T3 → clean + p_stay = ratio * (1.0 - alpha_s) * pi_xt + p_noise = (1.0 - ratio) * (1.0 - alpha_s) * pi_xt + p_clean = (1.0 - ratio) * alpha_s * pi_xt + + # T1: fourth cross-term, only when x0 == xt. + if x0 == xt: + p_stay = p_stay + ratio * alpha_s + + total = p_stay + p_noise + p_clean + p_stay_norm = p_stay / total + p_noise_norm = p_noise / total + p_clean_norm = p_clean / total + + # Build the M-dimensional output distribution by marginalizing + # over the routing action: + # P(output=k) = P(STAY)*1[k=xt] + P(NOISE)*π(k) + P(CLEAN)*1[k=x0] + inv_probs = [float(p) for p in invariant_probs_vec] + dist = [p_noise_norm * inv_probs[k] for k in range(len(inv_probs))] + dist[xt] += p_stay_norm + dist[x0] += p_clean_norm + return jnp.array(dist) + + def test_equivalence_x0_neq_xt(self): + """Test routing matches posterior when x0 != xt.""" + voc_size = 5 + invariant_probs = jnp.array([0.1, 0.3, 0.2, 0.25, 0.15]) + + for xt_val in range(voc_size): + for x0_val in range(voc_size): + if x0_val == xt_val: + continue + for alpha_s_val in [0.2, 0.5, 0.8]: + alpha_t = 0.05 + p_exact = self._posterior_distribution( + xt_val, x0_val, alpha_s_val, alpha_t, invariant_probs + ) + p_route = self._routing_distribution( + xt_val, x0_val, alpha_s_val, alpha_t, invariant_probs + ) + chex.assert_trees_all_close(p_exact, p_route, atol=1e-6) + + def test_equivalence_x0_eq_xt(self): + """Test routing matches posterior when x0 == xt (the T1 cross-term).""" + voc_size = 5 + invariant_probs = jnp.array([0.1, 0.3, 0.2, 0.25, 0.15]) + + for xt_val in range(voc_size): + x0_val = xt_val + for alpha_s_val in [0.2, 0.5, 0.8]: + alpha_t = 0.05 + p_exact = self._posterior_distribution( + xt_val, x0_val, alpha_s_val, alpha_t, invariant_probs + ) + p_route = self._routing_distribution( + xt_val, x0_val, alpha_s_val, alpha_t, invariant_probs + ) + chex.assert_trees_all_close(p_exact, p_route, atol=1e-6) + + def test_equivalence_nonuniform_invariant(self): + """Test with a highly non-uniform invariant distribution.""" + voc_size = 3 + invariant_probs = jnp.array([0.01, 0.01, 0.98]) + + for xt_val in range(voc_size): + for x0_val in range(voc_size): + p_exact = self._posterior_distribution( + xt_val, x0_val, 0.3, 0.7, invariant_probs + ) + p_route = self._routing_distribution( + xt_val, x0_val, 0.3, 0.7, invariant_probs + ) + chex.assert_trees_all_close(p_exact, p_route, atol=1e-6) + + if __name__ == '__main__': absltest.main() diff --git a/hackable_diffusion/lib/sampling/planner_test.py 
b/hackable_diffusion/lib/sampling/planner_test.py new file mode 100644 index 0000000..9521d49 --- /dev/null +++ b/hackable_diffusion/lib/sampling/planner_test.py @@ -0,0 +1,153 @@ +# Copyright 2026 Hackable Diffusion Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for routing in discrete samplers.""" + +import chex +from hackable_diffusion.lib.sampling import discrete_step_sampler +import jax +import jax.numpy as jnp + +from absl.testing import absltest + + +class ApplyRoutingTest(absltest.TestCase): + """Tests for the _apply_routing helper.""" + + def test_deterministic_stay(self): + # routing_probs = [1, 0, 0] means stay. + routing_probs = discrete_step_sampler.RoutingProbs( + stay=jnp.array([[[1.0], [1.0]]]), + noise=jnp.array([[[0.0], [0.0]]]), + clean=jnp.array([[[0.0], [0.0]]]), + ) + xt = jnp.array([[[3], [5]]]) + x0 = jnp.array([[[0], [1]]]) + x_noise = jnp.array([[[2], [2]]]) + key = jax.random.PRNGKey(0) + + new_xt = discrete_step_sampler._apply_routing( + routing_probs, xt, x0, x_noise, key + ) + chex.assert_trees_all_equal(new_xt, xt) + + def test_deterministic_clean(self): + # routing_probs = [0, 0, 1] means jump to x0. + routing_probs = discrete_step_sampler.RoutingProbs( + stay=jnp.array([[[0.0], [0.0]]]), + noise=jnp.array([[[0.0], [0.0]]]), + clean=jnp.array([[[1.0], [1.0]]]), + ) + xt = jnp.array([[[3], [5]]]) + x0 = jnp.array([[[0], [1]]]) + x_noise = jnp.array([[[2], [2]]]) + key = jax.random.PRNGKey(0) + + new_xt = discrete_step_sampler._apply_routing( + routing_probs, xt, x0, x_noise, key + ) + chex.assert_trees_all_equal(new_xt, x0) + + def test_deterministic_noise(self): + # routing_probs = [0, 1, 0] means jump to noise. + routing_probs = discrete_step_sampler.RoutingProbs( + stay=jnp.array([[[0.0], [0.0]]]), + noise=jnp.array([[[1.0], [1.0]]]), + clean=jnp.array([[[0.0], [0.0]]]), + ) + xt = jnp.array([[[3], [5]]]) + x0 = jnp.array([[[0], [1]]]) + x_noise = jnp.array([[[2], [2]]]) + key = jax.random.PRNGKey(0) + + new_xt = discrete_step_sampler._apply_routing( + routing_probs, xt, x0, x_noise, key + ) + chex.assert_trees_all_equal(new_xt, x_noise) + + def test_mixed_routing(self): + # Position 0: deterministic stay, Position 1: deterministic clean. + routing_probs = discrete_step_sampler.RoutingProbs( + stay=jnp.array([[[1.0], [0.0]]]), + noise=jnp.array([[[0.0], [0.0]]]), + clean=jnp.array([[[0.0], [1.0]]]), + ) + xt = jnp.array([[[3], [5]]]) + x0 = jnp.array([[[0], [1]]]) + x_noise = jnp.array([[[2], [2]]]) + key = jax.random.PRNGKey(0) + + new_xt = discrete_step_sampler._apply_routing( + routing_probs, xt, x0, x_noise, key + ) + expected = jnp.array([[[3], [1]]]) + chex.assert_trees_all_equal(new_xt, expected) + + def test_stochastic_routing(self): + # 50/50 stay vs clean — results should vary across seeds. 
+ routing_probs = discrete_step_sampler.RoutingProbs( + stay=jnp.array([[[0.5]]]), + noise=jnp.array([[[0.0]]]), + clean=jnp.array([[[0.5]]]), + ) + xt = jnp.array([[[3]]]) + x0 = jnp.array([[[0]]]) + x_noise = jnp.array([[[2]]]) + + results = set() + for seed in range(50): + new_xt = discrete_step_sampler._apply_routing( + routing_probs, xt, x0, x_noise, jax.random.PRNGKey(seed) + ) + results.add(int(new_xt[0, 0, 0])) + + # Should see both stay (3) and clean (0). + self.assertIn(3, results) + self.assertIn(0, results) + + def test_routing_constants(self): + self.assertEqual(discrete_step_sampler.STAY, 0) + self.assertEqual(discrete_step_sampler.NOISE, 1) + self.assertEqual(discrete_step_sampler.CLEAN, 2) + + +class PlannerProtocolTest(absltest.TestCase): + + def test_identity_planner(self): + + class IdentityPlanner: + + def __call__(self, routing_probs, logits, x0, xt, time, next_time, key): + return routing_probs + + planner = IdentityPlanner() + routing_probs = discrete_step_sampler.RoutingProbs( + stay=jnp.array([[[0.2]]]), + noise=jnp.array([[[0.3]]]), + clean=jnp.array([[[0.5]]]), + ) + # dummy args + logits = jnp.zeros((1, 1, 5)) + x0 = jnp.zeros((1, 1, 1)) + xt = jnp.zeros((1, 1, 1)) + time = jnp.array([1.0]) + next_time = jnp.array([0.5]) + key = jax.random.PRNGKey(0) + + out = planner(routing_probs, logits, x0, xt, time, next_time, key) + chex.assert_trees_all_equal(out, routing_probs) + + +if __name__ == '__main__': + absltest.main()
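To close the loop, a hedged usage sketch of the new planner hook. `TopKConfidencePlanner` and `build_step` are hypothetical names invented for this example, not part of the library; they only illustrate how a RoutingProbPlanner can reshape the routing probabilities before `_apply_routing` runs.

import dataclasses

import jax
import jax.numpy as jnp

from hackable_diffusion.lib.sampling import discrete_step_sampler


@dataclasses.dataclass(frozen=True)
class TopKConfidencePlanner:
  """Hypothetical planner: only the k most confident positions may go CLEAN.

  Assumes inputs shaped (bsz, seq_len, ...) with a single 1-D sequence axis.
  """

  k: int = 8

  def __call__(self, routing_probs, logits, x0, xt, time, next_time, key):
    probs = jax.nn.softmax(logits, axis=-1)
    # Model confidence in its own sampled x0 at each position: (bsz, seq_len, 1).
    conf = jnp.take_along_axis(probs, x0, axis=-1)
    # Threshold at the k-th largest confidence per sequence.
    kth = jnp.sort(conf[..., 0], axis=-1)[..., -self.k, None, None]
    allow = (conf >= kth).astype(routing_probs.clean.dtype)
    # Suppress CLEAN outside the top-k and return its mass to STAY.
    return dataclasses.replace(
        routing_probs,
        stay=routing_probs.stay + (1.0 - allow) * routing_probs.clean,
        clean=allow * routing_probs.clean,
    )


def build_step(process):
  # `process` is whatever CategoricalProcess the surrounding pipeline builds.
  return discrete_step_sampler.UnMaskingStep(
      corruption_process=process,
      planner=TopKConfidencePlanner(k=8),
  )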