@@ -840,72 +840,44 @@ def import_program(
840840 # Even though import_stateless_graph is deprecated as an entrypoint mechanism,
841841 # HOP operator graphs are stateless graphs with no mutation, and it is correct
842842 # to import them as stateless graphs.
843- self ._import_child_graph_modules (
843+ self ._import_all_child_modules (
844844 prog ,
845845 func_name ,
846846 import_symbolic_shape_expressions
847847 )
848848
849849 return func_op
850850
851- def _import_child_graph_modules (
851+ def _import_all_child_modules (
852852 self ,
853853 prog : torch .export .ExportedProgram ,
854854 parent_name : str ,
855855 import_symbolic_shape_expressions : bool = False
856856 ):
857- """Scan graph for get_attr nodes and import referenced submodules as stateless graphs.
857+ """Recursively import all child modules that have graphs.
858858
859- This is used to import HOP subgraphs. When we encounter a get_attr node, it references
860- a submodule that needs to be imported as a private function that can be called
861- by the HOP operation.
859+ This simple approach imports all submodules recursively, which is sufficient
860+ for HOP operations since they only reference existing submodules.
862861 """
863- # Collect all get_attr nodes by recursively searching through all graphs
864- get_attr_nodes = []
865- self ._collect_get_attr_nodes (prog .graph , get_attr_nodes , set ())
866- # Import each referenced module as a stateless graph
867- for node in get_attr_nodes :
868- attr_name = node .target
869-
870- # Find the module that contains this attribute
871- # The node's graph should have an owning_module that contains the submodule
872- if hasattr (node .graph , 'owning_module' ):
873- parent_module = node .graph .owning_module
874-
875- # Check if the parent module has the referenced attribute
876- if hasattr (parent_module , attr_name ):
877- child_module = getattr (parent_module , attr_name )
878-
879- # Only import if it's a GraphModule with a graph
880- if isinstance (child_module , GraphModule ) and hasattr (child_module , 'graph' ):
881- # Generate function name: parent_childname_id for uniqueness
882- child_func_name = f"{ parent_name } _{ attr_name } _{ id (child_module )} "
883-
884- # Import the child as a stateless graph (private function)
885- self .import_stateless_graph (
886- child_module .graph ,
887- func_name = child_func_name ,
888- func_visibility = "private" ,
889- import_symbolic_shape_expressions = import_symbolic_shape_expressions ,
890- )
891-
892- def _collect_get_attr_nodes (self , graph : Graph , get_attr_nodes : List [Node ], visited : Set [int ] = None ):
893- """Recursively collect all get_attr nodes from a graph and its subgraphs."""
894- if visited is None :
895- visited = set ()
896-
897- graph_id = id (graph )
898- if graph_id in visited :
899- return # Already processed this graph
900- visited .add (graph_id )
901-
902- for node in graph .nodes :
903- if node .op == "get_attr" :
904- get_attr_nodes .append (node )
862+ for child_name , child_module in prog .graph .owning_module .named_children ():
863+ if isinstance (child_module , GraphModule ) and hasattr (child_module , 'graph' ):
864+ # Generate function name: parent_childname
865+ child_func_name = f"{ parent_name } _{ child_name } _{ id (child_module )} "
866+
867+ # Import the child as a stateless graph (private function)
868+ self .import_stateless_graph (
869+ child_module .graph ,
870+ func_name = child_func_name ,
871+ func_visibility = "private" ,
872+ import_symbolic_shape_expressions = import_symbolic_shape_expressions ,
873+ )
905874
906- # If this node has a subgraph (like HOP operations), search it too
907- if hasattr (node , 'graph' ) and node .graph is not None :
908- self ._collect_get_attr_nodes (node .graph , get_attr_nodes , visited )
875+ # Recursively import its children
876+ self ._import_all_child_modules (
877+ child_module ,
878+ child_func_name ,
879+ import_symbolic_shape_expressions
880+ )
909881
910882 def import_frozen_program (
911883 self ,
0 commit comments