Describe the bug
When timm.create_model is called with the features_only=True argument, it returns a FeatureListNet module. This module cannot be correctly wrapped by torch.distributed.fsdp.FullyShardedDataParallel when using the FULL_SHARD strategy.
UserWarning: FSDP will not all-gather parameters for containers that do not implement forward: FeatureListNet(
  (stem_0): Conv2d(3, 128, kernel_size=(4, 4), ...)

FSDP incorrectly identifies FeatureListNet as a container that does not implement forward, even though it does. This results in a ValueError.
To Reproduce
Steps to reproduce the behavior:
1. Minimal reproduction code (test1.py):
import timm
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

def main():
    dist.init_process_group(backend="nccl")
    local_rank = dist.get_rank()
    torch.cuda.set_device(local_rank)

    # This returns a FeatureListNet, which inherits from nn.ModuleDict
    encoder = timm.create_model("convnextv2_base", pretrained=False, features_only=True)
    print(type(encoder))

    # This will fail
    model = FSDP(
        encoder,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device(),
    )
    print(f"Rank {local_rank}: FSDP wrapping SUCCEEDED.")

if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        print(f"Rank {dist.get_rank()} FAILED with error: {e}")
    finally:
        if dist.is_initialized():
            rank = dist.get_rank()
            dist.destroy_process_group()
            print(f"Rank {rank} cleaned up process group.")
2. Run command:

torchrun --nproc_per_node=2 test1.py

Expected behavior
The encoder module (FeatureListNet) should be wrapped successfully by FSDP with the FULL_SHARD strategy, as it is a valid nn.Module that implements its own forward method.
Desktop:
- OS: Ubuntu 22.04
- timm: 1.0.22
- PyTorch: 2.6.0+cu124
Additional context
The root cause is a validation check in Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py:

if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
    raise ValueError(
        f"fully_shard does not support containers that do not implement forward: {module}"
    )

In the timm library (Lib/site-packages/timm/models/_features.py), the FeatureListNet and FeatureDictNet classes inherit from nn.ModuleDict.
Crucially, both FeatureListNet and FeatureDictNet implement their own forward methods.
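This is easy to verify, since plain nn.ModuleList/nn.ModuleDict inherit nn.Module.forward unchanged, while FeatureListNet overrides it (a quick check, assuming a standard timm install):

import timm
import torch.nn as nn

encoder = timm.create_model("convnextv2_base", pretrained=False, features_only=True)

# FeatureListNet defines its own forward, so it differs from the base nn.Module.forward
print(type(encoder).forward is nn.Module.forward)  # False: forward is overridden
print(nn.ModuleDict.forward is nn.Module.forward)  # True: plain ModuleDict has no forward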
The FSDP validation is too strict: it only checks whether the module is an instance of nn.ModuleList or nn.ModuleDict and immediately raises the ValueError, without first checking whether the inheriting class actually implements its own forward method.
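Until this is relaxed upstream, one possible workaround is to hide the ModuleDict subclass behind a plain nn.Module wrapper so the isinstance check never triggers (a minimal sketch; the wrapper class is hypothetical, not part of timm or PyTorch):

import torch.nn as nn

class FeatureExtractorWrapper(nn.Module):
    """Hypothetical shim: delegates to a FeatureListNet without subclassing nn.ModuleDict."""

    def __init__(self, net: nn.Module):
        super().__init__()
        self.net = net

    def forward(self, x):
        return self.net(x)

# Usage in the reproduction above:
# encoder = FeatureExtractorWrapper(
#     timm.create_model("convnextv2_base", pretrained=False, features_only=True)
# )
# model = FSDP(encoder, sharding_strategy=ShardingStrategy.FULL_SHARD,
#              device_id=torch.cuda.current_device())

Since the wrapper is a plain nn.Module with its own forward, it should pass FSDP's isinstance check while still exposing the FeatureListNet's parameters as submodule parameters for sharding.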