
[Bug] FSDP FULL_SHARD incorrectly rejects timm models with features_only=True (FeatureListNet) due to overly-strict nn.ModuleDict inheritance check #2609


Description

@wenwwww

Describe the bug
When timm.create_model is called with the features_only=True argument, it returns a FeatureListNet module. This module cannot be correctly wrapped by torch.distributed.fsdp.FullyShardedDataParallel when using the FULL_SHARD strategy.

UserWarning: FSDP will not all-gather parameters for containers that do not implement forward: FeatureListNet(
  (stem_0): Conv2d(3, 128, kernel_size=(4, 4), ...)

FSDP incorrectly identifies FeatureListNet as a container that does not implement forward, even though it does, and wrapping fails with a ValueError.

To Reproduce
Steps to reproduce the behavior:
1. Minimal reproduction code (test1.py):

import timm
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

def main():
    dist.init_process_group(backend="nccl")
    local_rank = dist.get_rank()
    torch.cuda.set_device(local_rank)

    # This returns a FeatureListNet, which inherits from nn.ModuleDict
    encoder = timm.create_model("convnextv2_base", pretrained=False, features_only=True)
    print(type(encoder))

    # This will fail
    model = FSDP(
        encoder,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device(),
    )

    print(f"Rank {local_rank}: FSDP wrapping SUCCEEDED.")

if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        print(f"Rank {dist.get_rank()} FAILED with error: {e}")
    finally:
        if dist.is_initialized():
            rank = dist.get_rank()
            dist.destroy_process_group()
            print(f"Rank {rank} cleaned up process group.")

2. Run command:

torchrun --nproc_per_node=2 test1.py

Expected behavior
The encoder module (FeatureListNet) should be successfully wrapped by FSDP FULL_SHARD without error, as it is a valid nn.Module that implements its own forward method.

Desktop:

  • OS: Ubuntu 22.04
  • timm: 1.0.22 (timm-1.0.22-py3-none-any.whl)
  • PyTorch: 2.6.0+cu124

Additional context
The root cause is a validation check in Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py:

if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
    raise ValueError(
        f"fully_shard does not support containers that do not implement forward: {module}"
    )

In the timm library (Lib/site-packages/timm/models/_features.py), the FeatureListNet and FeatureDictNet classes inherit from nn.ModuleDict.

Crucially, both FeatureListNet and FeatureDictNet implement their own forward methods.
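A quick check (assuming the reproduction environment above) confirms that forward is overridden on the subclass rather than left as the default, unimplemented nn.Module.forward:

import timm
import torch.nn as nn

encoder = timm.create_model("convnextv2_base", pretrained=False, features_only=True)

# FeatureListNet is an nn.ModuleDict subclass, so FSDP's isinstance check matches it ...
print(isinstance(encoder, nn.ModuleDict))           # True
# ... but its forward is NOT the default unimplemented nn.Module.forward
print(type(encoder).forward is nn.Module.forward)   # False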

The FSDP validation is too strict: it only checks whether the module is an instance of nn.ModuleList or nn.ModuleDict and immediately raises the ValueError, without first checking whether the inheriting class implements its own forward method.
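A possible relaxation (a sketch only, not the actual PyTorch implementation; the helper name is hypothetical) would reject only containers whose class still uses the default unimplemented nn.Module.forward:

import torch.nn as nn

def _is_forwardless_container(module: nn.Module) -> bool:
    # Hypothetical helper (illustration only): flag ModuleList/ModuleDict
    # instances only when their class has not overridden nn.Module.forward,
    # so subclasses like FeatureListNet that define forward would pass.
    return (
        isinstance(module, (nn.ModuleList, nn.ModuleDict))
        and type(module).forward is nn.Module.forward
    )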

Labels: bug (Something isn't working)