Skip to content

Commit 7c1be9f

Browse files
Simplified recursive import_stateless_graph call
1 parent 4c5f940 commit 7c1be9f

File tree

1 file changed

+23
-51
lines changed

1 file changed

+23
-51
lines changed

python/torch_mlir/extras/fx_importer.py

Lines changed: 23 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -840,72 +840,44 @@ def import_program(
840840
# Even though import_stateless_graph is deprecated as an entrypoint mechanism,
841841
# HOP operator graphs are stateless graphs with no mutation, and it is correct
842842
# to import them as stateless graphs.
843-
self._import_child_graph_modules(
843+
self._import_all_child_modules(
844844
prog,
845845
func_name,
846846
import_symbolic_shape_expressions
847847
)
848848

849849
return func_op
850850

851-
def _import_child_graph_modules(
851+
def _import_all_child_modules(
852852
self,
853853
prog: torch.export.ExportedProgram,
854854
parent_name: str,
855855
import_symbolic_shape_expressions: bool = False
856856
):
857-
"""Scan graph for get_attr nodes and import referenced submodules as stateless graphs.
857+
"""Recursively import all child modules that have graphs.
858858
859-
This is used to import HOP subgraphs. When we encounter a get_attr node, it references
860-
a submodule that needs to be imported as a private function that can be called
861-
by the HOP operation.
859+
This simple approach imports all submodules recursively, which is sufficient
860+
for HOP operations since they only reference existing submodules.
862861
"""
863-
# Collect all get_attr nodes by recursively searching through all graphs
864-
get_attr_nodes = []
865-
self._collect_get_attr_nodes(prog.graph, get_attr_nodes, set())
866-
# Import each referenced module as a stateless graph
867-
for node in get_attr_nodes:
868-
attr_name = node.target
869-
870-
# Find the module that contains this attribute
871-
# The node's graph should have an owning_module that contains the submodule
872-
if hasattr(node.graph, 'owning_module'):
873-
parent_module = node.graph.owning_module
874-
875-
# Check if the parent module has the referenced attribute
876-
if hasattr(parent_module, attr_name):
877-
child_module = getattr(parent_module, attr_name)
878-
879-
# Only import if it's a GraphModule with a graph
880-
if isinstance(child_module, GraphModule) and hasattr(child_module, 'graph'):
881-
# Generate function name: parent_childname_id for uniqueness
882-
child_func_name = f"{parent_name}_{attr_name}_{id(child_module)}"
883-
884-
# Import the child as a stateless graph (private function)
885-
self.import_stateless_graph(
886-
child_module.graph,
887-
func_name=child_func_name,
888-
func_visibility="private",
889-
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
890-
)
891-
892-
def _collect_get_attr_nodes(self, graph: Graph, get_attr_nodes: List[Node], visited: Set[int] = None):
893-
"""Recursively collect all get_attr nodes from a graph and its subgraphs."""
894-
if visited is None:
895-
visited = set()
896-
897-
graph_id = id(graph)
898-
if graph_id in visited:
899-
return # Already processed this graph
900-
visited.add(graph_id)
901-
902-
for node in graph.nodes:
903-
if node.op == "get_attr":
904-
get_attr_nodes.append(node)
862+
for child_name, child_module in prog.graph.owning_module.named_children():
863+
if isinstance(child_module, GraphModule) and hasattr(child_module, 'graph'):
864+
# Generate function name: parent_childname
865+
child_func_name = f"{parent_name}_{child_name}_{id(child_module)}"
866+
867+
# Import the child as a stateless graph (private function)
868+
self.import_stateless_graph(
869+
child_module.graph,
870+
func_name=child_func_name,
871+
func_visibility="private",
872+
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
873+
)
905874

906-
# If this node has a subgraph (like HOP operations), search it too
907-
if hasattr(node, 'graph') and node.graph is not None:
908-
self._collect_get_attr_nodes(node.graph, get_attr_nodes, visited)
875+
# Recursively import its children
876+
self._import_all_child_modules(
877+
child_module,
878+
child_func_name,
879+
import_symbolic_shape_expressions
880+
)
909881

910882
def import_frozen_program(
911883
self,

0 commit comments

Comments
 (0)