diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py
index e5347797..b3f1bc7a 100644
--- a/src/maxdiffusion/max_utils.py
+++ b/src/maxdiffusion/max_utils.py
@@ -489,12 +489,14 @@ def get_precision(config):
     retval = jax.lax.Precision.HIGHEST
   return retval
 
+
 def value_or_none(flash_block_sizes, key):
   if key in flash_block_sizes:
     return flash_block_sizes[key]
   else:
     return None
 
+
 def get_flash_block_sizes(config):
   """Create custom flash attention BlockSizes."""
   flash_block_sizes = None
@@ -508,7 +510,7 @@ def get_flash_block_sizes(config):
         block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
         block_q_dq=value_or_none(config.flash_block_sizes, "block_q_dq"),
         block_kv_dq=value_or_none(config.flash_block_sizes, "block_kv_dq"),
-        use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel")
+        use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel"),
     )
   return flash_block_sizes
 
@@ -528,6 +530,20 @@ def get_memory_allocations():
     )
 
 
+def get_live_arrays():
+  """Log the count, shape, dtype, and devices of all live arrays on the backend."""
+  backend = jax.extend.backend.get_backend()
+  live_arrays = backend.live_arrays()
+
+  max_logging.log(f"Total live arrays: {len(live_arrays)}\n")
+
+  for i, arr in enumerate(live_arrays):
+    max_logging.log(f"Array {i}:")
+    max_logging.log(f"  Shape: {arr.shape}")
+    max_logging.log(f"  Dtype: {arr.dtype}")
+    max_logging.log(f"  Devices: {arr.devices()}")
+
+
 # Taking inspiration from flax's https://flax.readthedocs.io/en/v0.5.3/_modules/flax/linen/summary.html#tabulate
 # to retrieve layer parameters and calculate
 def calculate_model_tflops(module: module_lib.Module, rngs: Union[PRNGKey, RNGSequences], train, **kwargs):
diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index b9892c5f..510d044b 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -215,8 +215,14 @@ def _tpu_flash_attention(
   def wrap_flash_attention(query, key, value):
     uses_fused_kernel = block_sizes.use_fused_bwd_kernel
-    block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv,)
-    block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv,)
+    block_q_sizes = (
+        block_sizes.block_q,
+        block_sizes.block_q_dkv,
+    )
+    block_kv_sizes = (
+        block_sizes.block_kv,
+        block_sizes.block_kv_dkv,
+    )
     if uses_fused_kernel:
       block_q_sizes += (block_sizes.block_q_dkv,)
       block_kv_sizes += (block_sizes.block_kv_dkv,)
 
@@ -455,7 +461,16 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, residual_checkpoint_name=residual_checkpoint_name
+        query,
+        key * scale,
+        value,
+        heads,
+        mesh,
+        axis_names_q,
+        axis_names_kv,
+        flash_block_sizes,
+        dtype,
+        residual_checkpoint_name=residual_checkpoint_name,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
diff --git a/src/maxdiffusion/models/gradient_checkpoint.py b/src/maxdiffusion/models/gradient_checkpoint.py
index 99c514df..9162fbcb 100644
--- a/src/maxdiffusion/models/gradient_checkpoint.py
+++ b/src/maxdiffusion/models/gradient_checkpoint.py
@@ -80,7 +80,7 @@ def to_jax_policy(self, names_which_can_be_saved: list = [], names_which_can_be_
       case GradientCheckpointType.HIDDEN_STATE_WITH_OFFLOAD:
         return jax.checkpoint_policies.save_and_offload_only_these_names(
             names_which_can_be_saved=[],
-            names_which_can_be_offloaded=["hidden_states","self_attn","cross_attn"],
+            names_which_can_be_offloaded=["hidden_states", "self_attn", "cross_attn"],
"self_attn", "cross_attn"], offload_src="device", offload_dst="pinned_host", ) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 46dd7ca5..876fdb04 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -283,7 +283,7 @@ def __init__( precision=precision, attention_kernel=attention, dropout=dropout, - residual_checkpoint_name='self_attn', + residual_checkpoint_name="self_attn", ) # 1. Cross-attention @@ -302,7 +302,7 @@ def __init__( precision=precision, attention_kernel=attention, dropout=dropout, - residual_checkpoint_name='cross_attn', + residual_checkpoint_name="cross_attn", ) assert cross_attn_norm is True self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 3e7ce7bf..55981be0 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -131,9 +131,9 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): # This helps with loading sharded weights directly into the accelerators without fist copying them # all to one device and then distributing them, thus using low HBM memory. if restored_checkpoint: - if "params" in restored_checkpoint["wan_state"]: # if checkpointed with optimizer + if "params" in restored_checkpoint["wan_state"]: # if checkpointed with optimizer params = restored_checkpoint["wan_state"]["params"] - else: # if not checkpointed with optimizer + else: # if not checkpointed with optimizer params = restored_checkpoint["wan_state"] else: params = load_wan_transformer( diff --git a/src/maxdiffusion/tests/wan_checkpointer_test.py b/src/maxdiffusion/tests/wan_checkpointer_test.py index a588aa3d..ab5b5ca3 100644 --- a/src/maxdiffusion/tests/wan_checkpointer_test.py +++ b/src/maxdiffusion/tests/wan_checkpointer_test.py @@ -16,107 +16,106 @@ from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer, WAN_CHECKPOINT + class WanCheckpointerTest(unittest.TestCase): - def setUp(self): - self.config = MagicMock() - self.config.checkpoint_dir = "/tmp/wan_checkpoint_test" - self.config.dataset_type = "test_dataset" - - @patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager') - @patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline') - def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): - mock_manager = MagicMock() - mock_manager.latest_step.return_value = None - mock_create_manager.return_value = mock_manager - - mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance - - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) - pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) - - mock_manager.latest_step.assert_called_once() - mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config) - self.assertEqual(pipeline, mock_pipeline_instance) - self.assertIsNone(opt_state) - self.assertIsNone(step) - - @patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager') - @patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline') - def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager): - mock_manager = MagicMock() - mock_manager.latest_step.return_value = 1 - metadata_mock = MagicMock() - 
-        mock_manager.item_metadata.return_value = metadata_mock
-
-        restored_mock = MagicMock()
-        restored_mock.wan_state = {'params': {}}
-        restored_mock.wan_config = {}
-        restored_mock.keys.return_value = ['wan_state', 'wan_config']
-        def getitem_side_effect(key):
-            if key == 'wan_state':
-                return restored_mock.wan_state
-            raise KeyError(key)
-        restored_mock.__getitem__.side_effect = getitem_side_effect
-        mock_manager.restore.return_value = restored_mock
-
-        mock_create_manager.return_value = mock_manager
-
-        mock_pipeline_instance = MagicMock()
-        mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance
-
-        checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
-        pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
-
-        mock_manager.restore.assert_called_once_with(
-            directory=unittest.mock.ANY,
-            step=1,
-            args=unittest.mock.ANY
-        )
-        mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
-        self.assertEqual(pipeline, mock_pipeline_instance)
-        self.assertIsNone(opt_state)
-        self.assertEqual(step, 1)
-
-    @patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager')
-    @patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline')
-    def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager):
-        mock_manager = MagicMock()
-        mock_manager.latest_step.return_value = 1
-        metadata_mock = MagicMock()
-        metadata_mock.wan_state = {}
-        mock_manager.item_metadata.return_value = metadata_mock
-
-        restored_mock = MagicMock()
-        restored_mock.wan_state = {'params': {}, 'opt_state': {'learning_rate': 0.001}}
-        restored_mock.wan_config = {}
-        restored_mock.keys.return_value = ['wan_state', 'wan_config']
-        def getitem_side_effect(key):
-            if key == 'wan_state':
-                return restored_mock.wan_state
-            raise KeyError(key)
-        restored_mock.__getitem__.side_effect = getitem_side_effect
-        mock_manager.restore.return_value = restored_mock
-
-        mock_create_manager.return_value = mock_manager
-
-        mock_pipeline_instance = MagicMock()
-        mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance
-
-        checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
-        pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
-
-        mock_manager.restore.assert_called_once_with(
-            directory=unittest.mock.ANY,
-            step=1,
-            args=unittest.mock.ANY
-        )
-        mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
-        self.assertEqual(pipeline, mock_pipeline_instance)
-        self.assertIsNotNone(opt_state)
-        self.assertEqual(opt_state['learning_rate'], 0.001)
-        self.assertEqual(step, 1)
+
+  def setUp(self):
+    self.config = MagicMock()
+    self.config.checkpoint_dir = "/tmp/wan_checkpoint_test"
+    self.config.dataset_type = "test_dataset"
+
+  @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager")
+  @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline")
+  def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager):
+    mock_manager = MagicMock()
+    mock_manager.latest_step.return_value = None
+    mock_create_manager.return_value = mock_manager
+
+    mock_pipeline_instance = MagicMock()
+    mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance
+
+    checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
+    pipeline, opt_state, step = checkpointer.load_checkpoint(step=None)
+
+    mock_manager.latest_step.assert_called_once()
+    mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config)
+    self.assertEqual(pipeline, mock_pipeline_instance)
+    self.assertIsNone(opt_state)
+    self.assertIsNone(step)
+
+  @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager")
+  @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline")
+  def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager):
+    mock_manager = MagicMock()
+    mock_manager.latest_step.return_value = 1
+    metadata_mock = MagicMock()
+    metadata_mock.wan_state = {}
+    mock_manager.item_metadata.return_value = metadata_mock
+
+    restored_mock = MagicMock()
+    restored_mock.wan_state = {"params": {}}
+    restored_mock.wan_config = {}
+    restored_mock.keys.return_value = ["wan_state", "wan_config"]
+
+    def getitem_side_effect(key):
+      if key == "wan_state":
+        return restored_mock.wan_state
+      raise KeyError(key)
+
+    restored_mock.__getitem__.side_effect = getitem_side_effect
+    mock_manager.restore.return_value = restored_mock
+
+    mock_create_manager.return_value = mock_manager
+
+    mock_pipeline_instance = MagicMock()
+    mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance
+
+    checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
+    pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
+
+    mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY)
+    mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
+    self.assertEqual(pipeline, mock_pipeline_instance)
+    self.assertIsNone(opt_state)
+    self.assertEqual(step, 1)
+
+  @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager")
+  @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline")
+  def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager):
+    mock_manager = MagicMock()
+    mock_manager.latest_step.return_value = 1
+    metadata_mock = MagicMock()
+    metadata_mock.wan_state = {}
+    mock_manager.item_metadata.return_value = metadata_mock
+
+    restored_mock = MagicMock()
+    restored_mock.wan_state = {"params": {}, "opt_state": {"learning_rate": 0.001}}
+    restored_mock.wan_config = {}
+    restored_mock.keys.return_value = ["wan_state", "wan_config"]
+
+    def getitem_side_effect(key):
+      if key == "wan_state":
+        return restored_mock.wan_state
+      raise KeyError(key)
+
+    restored_mock.__getitem__.side_effect = getitem_side_effect
+    mock_manager.restore.return_value = restored_mock
+
+    mock_create_manager.return_value = mock_manager
+
+    mock_pipeline_instance = MagicMock()
+    mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance
+
+    checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
+    pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
+
+    mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY)
+    mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
+    self.assertEqual(pipeline, mock_pipeline_instance)
+    self.assertIsNotNone(opt_state)
+    self.assertEqual(opt_state["learning_rate"], 0.001)
+    self.assertEqual(step, 1)
+
 
 if __name__ == "__main__":
-    unittest.main()
+  unittest.main()
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index 89981f1a..53743b93 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -213,7 +213,7 @@ def start_training(self):
     pipeline, opt_state, step = self.load_checkpoint()
     restore_args = {}
     if opt_state and step:
-      restore_args = {"opt_state": opt_state, "step":step}
+      restore_args = {"opt_state": opt_state, "step": step}
     del opt_state
     if self.config.enable_ssim:
       # Generate a sample before training to compare against generated sample after training.
@@ -285,17 +285,18 @@ def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, wr
     if writer:
       writer.add_scalar("learning/eval_loss", final_eval_loss, step)
 
-  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args:dict={}):
+  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args: dict = {}):
     mesh = pipeline.mesh
     graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)
 
     with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       state = TrainState.create(
-        apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state)
+          apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state
+      )
       if restore_args:
         step = restore_args.get("step", 0)
         max_logging.log(f"Restoring optimizer and resuming from step {step}")
-        state.replace(opt_state=restore_args.get("opt_state"), step = restore_args.get("step", 0))
+        state = state.replace(opt_state=restore_args.get("opt_state"), step=restore_args.get("step", 0))
         del restore_args["opt_state"]
       del optimizer
       state = jax.tree.map(_to_array, state)
@@ -303,10 +304,11 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
       state = jax.lax.with_sharding_constraint(state, state_spec)
       state_shardings = nnx.get_named_sharding(state, mesh)
       if jax.process_index() == 0 and restore_args:
-          max_logging.log("--- Optimizer State Sharding Spec (opt_state) ---")
-          pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60)
-          max_logging.log(pretty_string)
-          max_logging.log("------------------------------------------------")
+        max_logging.log("--- Optimizer State Sharding Spec (opt_state) ---")
+        pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60)
+        max_logging.log(pretty_string)
+        max_logging.log("------------------------------------------------")
+
       max_utils.delete_pytree(params)
       data_shardings = self.get_data_shardings(mesh)
       eval_data_shardings = self.get_eval_data_shardings(mesh)
@@ -349,9 +351,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
       last_profiling_step = np.clip(
           first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1
       )
-      if restore_args.get("step",0):
-        max_logging.log(f"Resuming training from step {step}")
-        start_step = restore_args.get("step",0)
+      if restore_args.get("step", 0):
+        max_logging.log(f"Resuming training from step {step}")
+        start_step = restore_args.get("step", 0)
       per_device_tflops, _, _ = WanTrainer.calculate_tflops(pipeline)
       scheduler_state = pipeline.scheduler_state
       example_batch = load_next_batch(train_data_iterator, None, self.config)