@@ -566,7 +566,9 @@ def __init__(
566566 self ._hooks = hooks or FxImporterHooks ()
567567 self .symbol_table = SymbolTable (self ._m .operation )
568568 self ._hooks .prepare_module (self ._m .operation )
569+ # Used specifically in HOPs to map module IDs to function names
569570 self ._graph_module_to_func_name : Dict [int , str ] = {}
571+ # Handles collision of function names in the same module
570572 self ._func_name_counter : int = 0
571573
572574 def _config_check (self ):
@@ -834,7 +836,7 @@ def import_program(
834836 # HOP operator graphs are stateless graphs with no mutation, and it is correct
835837 # to import them as stateless graphs.
836838 self ._import_all_child_modules (
837- prog , func_name , import_symbolic_shape_expressions
839+ prog . graph . owning_module , func_name , import_symbolic_shape_expressions
838840 )
839841
840842 # Import all nodes and return.
@@ -853,49 +855,27 @@ def import_program(
853855
854856 def _import_all_child_modules (
855857 self ,
856- prog_or_module : Union [ torch . export . ExportedProgram , GraphModule ] ,
858+ module : GraphModule ,
857859 parent_name : str ,
858860 import_symbolic_shape_expressions : bool = False ,
859861 ):
860- """Recursively import all child modules that have graphs .
862+ """Import all child modules by delegating to import_graph_module .
861863
862- This simple approach imports all submodules recursively, which is sufficient
863- for HOP operations since they only reference existing submodules.
864+    This is a thin wrapper that iterates the given module's children and
865+    delegates each child GraphModule to import_graph_module.
866+
867+ Note: This only imports children, not the parent module itself.
864868 """
865- # Get the owning module from either ExportedProgram or GraphModule
866- if isinstance (prog_or_module , GraphModule ):
867- owning_module = prog_or_module
868- else :
869- owning_module = prog_or_module .graph .owning_module
870869
871- for child_name , child_module in owning_module .named_children ():
870+ for child_name , child_module in module .named_children ():
872871 if isinstance (child_module , GraphModule ) and hasattr (child_module , "graph" ):
873- # Check if we've already assigned a name to this module
874- module_id = id (child_module )
875- # Module already imported, skip it
876- if module_id in self ._graph_module_to_func_name :
877- continue
878- # Use the child_name directly - PyTorch already provides unique names
879- child_func_name = child_name
880- # Handle collision by adding counter suffix if name already exists
881- if child_func_name in self ._graph_module_to_func_name .values ():
882- child_func_name = f"{ child_name } _{ self ._func_name_counter } "
883- self ._func_name_counter += 1
884- # Store the mapping for future lookups
885- self ._graph_module_to_func_name [module_id ] = child_func_name
886- # Import the child as a stateless graph (private function)
887- self .import_stateless_graph (
888- child_module .graph ,
889- func_name = child_func_name ,
872+ self .import_graph_module (
873+ child_module ,
874+ func_name = child_name ,
890875 func_visibility = "private" ,
891876 import_symbolic_shape_expressions = import_symbolic_shape_expressions ,
892877 )
893878
894- # Recursively import its children
895- self ._import_all_child_modules (
896- child_module , child_func_name , import_symbolic_shape_expressions
897- )
898-
899879 def import_frozen_program (
900880 self ,
901881 prog : torch .export .ExportedProgram ,
@@ -993,13 +973,56 @@ def import_frozen_program(
993973 import_symbolic_shape_expressions = import_symbolic_shape_expressions ,
994974 )
995975
996- def import_graph_module (self , gm : GraphModule ) -> Operation :
976+ def import_graph_module (
977+ self ,
978+ gm : GraphModule ,
979+ * ,
980+ func_name : str = "main" ,
981+ func_visibility : Optional [str ] = None ,
982+ import_symbolic_shape_expressions : bool = False ,
983+ ) -> Operation :
997984 """Low-level import of a GraphModule assuming that it has been functionalized.
998985
986+ This method recursively imports all child GraphModules first, then imports
987+ the provided GraphModule itself. This ensures that any higher-order operations
988+ that reference child modules will find them already imported.
989+
999990 TODO: This mechanism is deprecated by the `import_program` entry-point and
1000991 it should be removed when no longer required for backwards compatibility.
992+
993+ Note: This method should only be used for HOPs.
1001994 """
1002- return self .import_stateless_graph (gm .graph )
995+ # Store the mapping for this module itself (HOPs will need to look this up)
996+ module_id = id (gm )
997+ if module_id not in self ._graph_module_to_func_name :
998+ # Ensure the func_name is unique
999+ final_func_name = func_name
1000+ if func_name in self ._graph_module_to_func_name .values ():
1001+ final_func_name = f"{ func_name } _{ self ._func_name_counter } "
1002+ self ._func_name_counter += 1
1003+ self ._graph_module_to_func_name [module_id ] = final_func_name
1004+ else :
1005+ # Module already imported, use existing name
1006+ final_func_name = self ._graph_module_to_func_name [module_id ]
1007+
1008+ # First, recursively import all child modules
1009+ for child_name , child_module in gm .named_children ():
1010+ if isinstance (child_module , GraphModule ) and hasattr (child_module , "graph" ):
1011+ # Recursively import this child (which will handle its own mapping)
1012+ self .import_graph_module (
1013+ child_module ,
1014+ func_name = child_name ,
1015+ func_visibility = "private" ,
1016+ import_symbolic_shape_expressions = import_symbolic_shape_expressions ,
1017+ )
1018+
1019+ # Then import this module's own graph
1020+ return self .import_stateless_graph (
1021+ gm .graph ,
1022+ func_name = final_func_name ,
1023+ func_visibility = func_visibility ,
1024+ import_symbolic_shape_expressions = import_symbolic_shape_expressions ,
1025+ )
10031026
10041027 def import_stateless_graph (
10051028 self ,
@@ -1033,11 +1056,9 @@ def import_stateless_graph(
10331056 entry_block ,
10341057 )
10351058
1036- # Import child modules (for HOPs) before importing nodes
1037- if hasattr (g , 'owning_module' ) and g .owning_module is not None :
1038- self ._import_all_child_modules (
1039- g .owning_module , func_name , import_symbolic_shape_expressions
1040- )
1059+        # Note: Child module importing is handled by import_graph_module, which is
1060+        # the recommended entry point. This method assumes the graph has no child
1061+        # modules that need importing.
10411062
10421063 node_importer .import_nodes (
10431064 g .nodes , import_symbolic_shape_expressions = import_symbolic_shape_expressions
@@ -1068,10 +1089,11 @@ def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]:
10681089 # for single returns it might be a single Node.
10691090 output_arg = node .args [0 ]
10701091 # Handle both single Node and tuple/list of Nodes
1071- if isinstance (output_arg , (list , tuple )):
1072- result_nodes = output_arg
1073- else :
1074- result_nodes = [output_arg ]
1092+ result_nodes = (
1093+ output_arg
1094+ if isinstance (output_arg , (list , tuple ))
1095+ else [output_arg ]
1096+ )
10751097
10761098 for result_node in result_nodes :
10771099 if result_node is None :
@@ -1586,11 +1608,11 @@ def import_nodes(
15861608 # results.
15871609 output_arg = node .args [0 ]
15881610 # Handle both single Node and tuple/list of Nodes
1589- if isinstance ( output_arg , ( list , tuple )):
1590- result_nodes = output_arg
1591- else :
1592- result_nodes = [output_arg ]
1593-
1611+ result_nodes = (
1612+ output_arg
1613+ if isinstance ( output_arg , ( list , tuple ))
1614+ else [output_arg ]
1615+ )
15941616 operands = [self ._import_argument (loc , arg ) for arg in result_nodes ]
15951617 func_dialect .ReturnOp (operands , loc = loc )
15961618
@@ -1722,8 +1744,8 @@ def _import_hop_while_loop(
17221744 body_fn_module = getattr (root_module , body_fn_node .target , None )
17231745
17241746 # Generate function names with module IDs for uniqueness
1725- cond_fn_name = self .fx_importer ._graph_module_to_func_name . get ( id (cond_fn_module ))
1726- body_fn_name = self .fx_importer ._graph_module_to_func_name . get ( id (body_fn_module ))
1747+ cond_fn_name = self .fx_importer ._graph_module_to_func_name [ id (cond_fn_module )]
1748+ body_fn_name = self .fx_importer ._graph_module_to_func_name [ id (body_fn_module )]
17271749
17281750 # Import the carries (loop state variables)
17291751 carry_values = []
@@ -1765,7 +1787,7 @@ def _import_hop_while_loop(
17651787 with loc :
17661788 max_iter = _make_constant_op (
17671789 "torch.constant.int" ,
1768- self . _cc . integer_attr ( 9223372036854775807 , 64 ) ,
1790+ torch . iinfo ( torch . int64 ). max ,
17691791 self ._cc .torch_int_type ,
17701792 )
17711793
@@ -2153,7 +2175,7 @@ def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value:
21532175
21542176 def _import_tuple_argument (
21552177 self , loc : Location , arg : tuple , expected_jit_type
2156- ) -> List [Value ]:
2178+ ) -> list [Value ]:
21572179 """Import a tuple argument by importing each element separately."""
21582180 # For tuples in while_loop carries, treat each element as a separate argument
21592181 return [self ._import_argument (loc , elem , expected_jit_type ) for elem in arg ]
@@ -2268,7 +2290,6 @@ def _import_getitem(self, loc: Location, node: torch.fx.Node):
22682290 # NOTE: the length of the list must be knowable at compile time.
22692291 if ref_node not in self ._unpack_list_values :
22702292 node_result = self .resolve_node_value (ref_node , 0 )
2271- node_val = ref_node .meta .get ("val" )
22722293
22732294 if str (node_result .type ) in TORCH_LIST_TYPES :
22742295 result_types = [
0 commit comments