Skip to content

Commit db1e7e9

Browse files
Addressed comments
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1 parent b250583 commit db1e7e9

File tree

1 file changed

+75
-54
lines changed

1 file changed

+75
-54
lines changed

python/torch_mlir/extras/fx_importer.py

Lines changed: 75 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,9 @@ def __init__(
566566
self._hooks = hooks or FxImporterHooks()
567567
self.symbol_table = SymbolTable(self._m.operation)
568568
self._hooks.prepare_module(self._m.operation)
569+
# Used specifically in HOPs to map module IDs to function names
569570
self._graph_module_to_func_name: Dict[int, str] = {}
571+
# Handles collision of function names in the same module
570572
self._func_name_counter: int = 0
571573

572574
def _config_check(self):
@@ -834,7 +836,7 @@ def import_program(
834836
# HOP operator graphs are stateless graphs with no mutation, and it is correct
835837
# to import them as stateless graphs.
836838
self._import_all_child_modules(
837-
prog, func_name, import_symbolic_shape_expressions
839+
prog.graph.owning_module, func_name, import_symbolic_shape_expressions
838840
)
839841

840842
# Import all nodes and return.
@@ -853,49 +855,27 @@ def import_program(
853855

854856
def _import_all_child_modules(
855857
self,
856-
prog_or_module: Union[torch.export.ExportedProgram, GraphModule],
858+
module: GraphModule,
857859
parent_name: str,
858860
import_symbolic_shape_expressions: bool = False,
859861
):
860-
"""Recursively import all child modules that have graphs.
862+
"""Import all child modules by delegating to import_graph_module.
861863
862-
This simple approach imports all submodules recursively, which is sufficient
863-
for HOP operations since they only reference existing submodules.
864+
This is a thin wrapper that extracts the owning module and delegates to
865+
import_graph_module for each child.
866+
867+
Note: This only imports children, not the parent module itself.
864868
"""
865-
# Get the owning module from either ExportedProgram or GraphModule
866-
if isinstance(prog_or_module, GraphModule):
867-
owning_module = prog_or_module
868-
else:
869-
owning_module = prog_or_module.graph.owning_module
870869

871-
for child_name, child_module in owning_module.named_children():
870+
for child_name, child_module in module.named_children():
872871
if isinstance(child_module, GraphModule) and hasattr(child_module, "graph"):
873-
# Check if we've already assigned a name to this module
874-
module_id = id(child_module)
875-
# Module already imported, skip it
876-
if module_id in self._graph_module_to_func_name:
877-
continue
878-
# Use the child_name directly - PyTorch already provides unique names
879-
child_func_name = child_name
880-
# Handle collision by adding counter suffix if name already exists
881-
if child_func_name in self._graph_module_to_func_name.values():
882-
child_func_name = f"{child_name}_{self._func_name_counter}"
883-
self._func_name_counter += 1
884-
# Store the mapping for future lookups
885-
self._graph_module_to_func_name[module_id] = child_func_name
886-
# Import the child as a stateless graph (private function)
887-
self.import_stateless_graph(
888-
child_module.graph,
889-
func_name=child_func_name,
872+
self.import_graph_module(
873+
child_module,
874+
func_name=child_name,
890875
func_visibility="private",
891876
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
892877
)
893878

894-
# Recursively import its children
895-
self._import_all_child_modules(
896-
child_module, child_func_name, import_symbolic_shape_expressions
897-
)
898-
899879
def import_frozen_program(
900880
self,
901881
prog: torch.export.ExportedProgram,
@@ -993,13 +973,56 @@ def import_frozen_program(
993973
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
994974
)
995975

996-
def import_graph_module(self, gm: GraphModule) -> Operation:
976+
def import_graph_module(
977+
self,
978+
gm: GraphModule,
979+
*,
980+
func_name: str = "main",
981+
func_visibility: Optional[str] = None,
982+
import_symbolic_shape_expressions: bool = False,
983+
) -> Operation:
997984
"""Low-level import of a GraphModule assuming that it has been functionalized.
998985
986+
This method recursively imports all child GraphModules first, then imports
987+
the provided GraphModule itself. This ensures that any higher-order operations
988+
that reference child modules will find them already imported.
989+
999990
TODO: This mechanism is deprecated by the `import_program` entry-point and
1000991
it should be removed when no longer required for backwards compatibility.
992+
993+
Note: This method should only be used for HOPs.
1001994
"""
1002-
return self.import_stateless_graph(gm.graph)
995+
# Store the mapping for this module itself (HOPs will need to look this up)
996+
module_id = id(gm)
997+
if module_id not in self._graph_module_to_func_name:
998+
# Ensure the func_name is unique
999+
final_func_name = func_name
1000+
if func_name in self._graph_module_to_func_name.values():
1001+
final_func_name = f"{func_name}_{self._func_name_counter}"
1002+
self._func_name_counter += 1
1003+
self._graph_module_to_func_name[module_id] = final_func_name
1004+
else:
1005+
# Module already imported, use existing name
1006+
final_func_name = self._graph_module_to_func_name[module_id]
1007+
1008+
# First, recursively import all child modules
1009+
for child_name, child_module in gm.named_children():
1010+
if isinstance(child_module, GraphModule) and hasattr(child_module, "graph"):
1011+
# Recursively import this child (which will handle its own mapping)
1012+
self.import_graph_module(
1013+
child_module,
1014+
func_name=child_name,
1015+
func_visibility="private",
1016+
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
1017+
)
1018+
1019+
# Then import this module's own graph
1020+
return self.import_stateless_graph(
1021+
gm.graph,
1022+
func_name=final_func_name,
1023+
func_visibility=func_visibility,
1024+
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
1025+
)
10031026

10041027
def import_stateless_graph(
10051028
self,
@@ -1033,11 +1056,9 @@ def import_stateless_graph(
10331056
entry_block,
10341057
)
10351058

1036-
# Import child modules (for HOPs) before importing nodes
1037-
if hasattr(g, 'owning_module') and g.owning_module is not None:
1038-
self._import_all_child_modules(
1039-
g.owning_module, func_name, import_symbolic_shape_expressions
1040-
)
1059+
# Note: Child module importing is handled by import_graph_module, which is
1060+
# the recommended entry point. This method is deprecated and should only be
1061+
# used for stateless graphs that truly have no child modules.
10411062

10421063
node_importer.import_nodes(
10431064
g.nodes, import_symbolic_shape_expressions=import_symbolic_shape_expressions
@@ -1068,10 +1089,11 @@ def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]:
10681089
# for single returns it might be a single Node.
10691090
output_arg = node.args[0]
10701091
# Handle both single Node and tuple/list of Nodes
1071-
if isinstance(output_arg, (list, tuple)):
1072-
result_nodes = output_arg
1073-
else:
1074-
result_nodes = [output_arg]
1092+
result_nodes = (
1093+
output_arg
1094+
if isinstance(output_arg, (list, tuple))
1095+
else [output_arg]
1096+
)
10751097

10761098
for result_node in result_nodes:
10771099
if result_node is None:
@@ -1586,11 +1608,11 @@ def import_nodes(
15861608
# results.
15871609
output_arg = node.args[0]
15881610
# Handle both single Node and tuple/list of Nodes
1589-
if isinstance(output_arg, (list, tuple)):
1590-
result_nodes = output_arg
1591-
else:
1592-
result_nodes = [output_arg]
1593-
1611+
result_nodes = (
1612+
output_arg
1613+
if isinstance(output_arg, (list, tuple))
1614+
else [output_arg]
1615+
)
15941616
operands = [self._import_argument(loc, arg) for arg in result_nodes]
15951617
func_dialect.ReturnOp(operands, loc=loc)
15961618

@@ -1722,8 +1744,8 @@ def _import_hop_while_loop(
17221744
body_fn_module = getattr(root_module, body_fn_node.target, None)
17231745

17241746
# Generate function names with module IDs for uniqueness
1725-
cond_fn_name = self.fx_importer._graph_module_to_func_name.get(id(cond_fn_module))
1726-
body_fn_name = self.fx_importer._graph_module_to_func_name.get(id(body_fn_module))
1747+
cond_fn_name = self.fx_importer._graph_module_to_func_name[id(cond_fn_module)]
1748+
body_fn_name = self.fx_importer._graph_module_to_func_name[id(body_fn_module)]
17271749

17281750
# Import the carries (loop state variables)
17291751
carry_values = []
@@ -1765,7 +1787,7 @@ def _import_hop_while_loop(
17651787
with loc:
17661788
max_iter = _make_constant_op(
17671789
"torch.constant.int",
1768-
self._cc.integer_attr(9223372036854775807, 64),
1790+
torch.iinfo(torch.int64).max,
17691791
self._cc.torch_int_type,
17701792
)
17711793

@@ -2153,7 +2175,7 @@ def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value:
21532175

21542176
def _import_tuple_argument(
21552177
self, loc: Location, arg: tuple, expected_jit_type
2156-
) -> List[Value]:
2178+
) -> list[Value]:
21572179
"""Import a tuple argument by importing each element separately."""
21582180
# For tuples in while_loop carries, treat each element as a separate argument
21592181
return [self._import_argument(loc, elem, expected_jit_type) for elem in arg]
@@ -2268,7 +2290,6 @@ def _import_getitem(self, loc: Location, node: torch.fx.Node):
22682290
# NOTE: the length of the list must be knowable at compile time.
22692291
if ref_node not in self._unpack_list_values:
22702292
node_result = self.resolve_node_value(ref_node, 0)
2271-
node_val = ref_node.meta.get("val")
22722293

22732294
if str(node_result.type) in TORCH_LIST_TYPES:
22742295
result_types = [

0 commit comments

Comments
 (0)