Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/onnx_ir/passes/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"DeduplicateHashedInitializersPass",
"DeduplicateInitializersPass",
"IdentityEliminationPass",
"OutputFixPass",
"InlinePass",
"LiftConstantsToInitializersPass",
"LiftSubgraphInitializersToMainGraphPass",
Expand Down Expand Up @@ -43,6 +44,9 @@
from onnx_ir.passes.common.inliner import InlinePass
from onnx_ir.passes.common.naming import NameFixPass
from onnx_ir.passes.common.onnx_checker import CheckerPass
from onnx_ir.passes.common.output_fix import (
OutputFixPass,
)
from onnx_ir.passes.common.shape_inference import ShapeInferencePass
from onnx_ir.passes.common.topological_sort import TopologicalSortPass
from onnx_ir.passes.common.unused_removal import (
Expand Down
84 changes: 84 additions & 0 deletions src/onnx_ir/passes/common/output_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Identity fix pass for adding Identity nodes when graph inputs are directly used as outputs."""

from __future__ import annotations

__all__ = [
"OutputFixPass",
]

import logging

import onnx_ir as ir

logger = logging.getLogger(__name__)


class OutputFixPass(ir.passes.InPlacePass):
"""Pass for adding Identity nodes when graph inputs are directly used as outputs.

This pass adds Identity nodes according to the following rule:

If a graph input is directly used as a graph output (without any intermediate nodes),
insert an Identity node between them. This turns an invalid ONNX graph into a valid one.

Example transformation:
Before: input -> (direct connection) -> output
After: input -> Identity -> output

This is required because ONNX specification does not allow a graph input to be
directly used as a graph output without any processing nodes in between.
"""

def call(self, model: ir.Model) -> ir.passes.PassResult:
"""Main entry point for the identity fix pass."""
modified = False

# Process the main graph
if self._process_graph(model.graph):
modified = True

# Process functions
for function in model.functions.values():
if self._process_graph(function):
modified = True

return ir.passes.PassResult(model, modified=modified)

def _process_graph(self, graph_like: ir.Graph | ir.Function) -> bool:
"""Process a single graph or function, returning True if modified."""
modified = False

for graph in (graph_like, *graph_like.subgraphs()):
# Check each output to see if it's directly a graph input
outputs_to_fix: list[tuple[ir.Value, int]] = []
for i, output in enumerate(graph.outputs):
if output.is_graph_input():
outputs_to_fix.append((output, i))

# Add Identity nodes for each output that needs fixing
for output, index in outputs_to_fix:
# Create an Identity node
identity_node = ir.node("Identity", inputs=[output])
identity_output = identity_node.outputs[0]

# Copy metadata from the original output
identity_output.name = output.name
identity_output.shape = output.shape
identity_output.type = output.type
identity_output.metadata_props.update(output.metadata_props)
identity_output.doc_string = output.doc_string

# Create a unique name for the old output to avoid name conflicts
# TODO: Use a better unique naming strategy if needed
output.name = f"{output.name}_orig"

# Add the node to the graph
graph.append(identity_node)
graph.outputs[index] = identity_output

logger.debug("Added Identity node for graph input '%s' used as output", output)
modified = True

return modified
Loading