Skip to content

Commit 26b5ed3

Browse files
justinchubyCopilotCopilot
authored
Add OutputFixPass to fix invalid graph outputs (#269)
This pull request introduces a new pass, `OutputFixPass`, to the ONNX IR transformation pipeline. The main purpose of this pass is to ensure that the model's outputs conform to ONNX specifications by inserting Identity nodes where necessary. This helps to fix cases where graph inputs are directly used as outputs or where the same value is used multiple times as an output, both of which are not allowed by ONNX. The most important changes include: ### New Pass Implementation * Added a new file, `output_fix.py`, implementing `OutputFixPass`. This pass automatically inserts Identity nodes in two scenarios: when a graph input is directly used as an output, and when a value is used multiple times as an output. The pass applies to both the main graph and all subgraphs/functions. ### AI use I used copilot to create the tests and some of the logic. --------- Signed-off-by: Justin Chu <[email protected]> Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: justinchuby <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 7283e95 commit 26b5ed3

File tree

3 files changed

+1007
-3
lines changed

3 files changed

+1007
-3
lines changed

src/onnx_ir/passes/common/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"DeduplicateHashedInitializersPass",
1010
"DeduplicateInitializersPass",
1111
"IdentityEliminationPass",
12+
"OutputFixPass",
1213
"InlinePass",
1314
"LiftConstantsToInitializersPass",
1415
"LiftSubgraphInitializersToMainGraphPass",
@@ -33,16 +34,15 @@
3334
LiftSubgraphInitializersToMainGraphPass,
3435
RemoveInitializersFromInputsPass,
3536
)
36-
from onnx_ir.passes.common.identity_elimination import (
37-
IdentityEliminationPass,
38-
)
37+
from onnx_ir.passes.common.identity_elimination import IdentityEliminationPass
3938
from onnx_ir.passes.common.initializer_deduplication import (
4039
DeduplicateHashedInitializersPass,
4140
DeduplicateInitializersPass,
4241
)
4342
from onnx_ir.passes.common.inliner import InlinePass
4443
from onnx_ir.passes.common.naming import NameFixPass
4544
from onnx_ir.passes.common.onnx_checker import CheckerPass
45+
from onnx_ir.passes.common.output_fix import OutputFixPass
4646
from onnx_ir.passes.common.shape_inference import ShapeInferencePass
4747
from onnx_ir.passes.common.topological_sort import TopologicalSortPass
4848
from onnx_ir.passes.common.unused_removal import (
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright (c) ONNX Project Contributors
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""Output fix pass for adding Identity nodes.
4+
5+
This pass inserts Identity nodes to fix invalid output configurations when:
- Graph inputs are directly used as outputs (without any intermediate nodes).
6+
- A value is used multiple times as a graph output (ensuring each output is unique).
7+
8+
This ensures compliance with the ONNX specification for valid output configurations.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
__all__ = [
14+
"OutputFixPass",
15+
]
16+
17+
import logging
18+
19+
import onnx_ir as ir
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
class OutputFixPass(ir.passes.InPlacePass):
    """Repair invalid graph output configurations by inserting Identity nodes.

    Two situations are handled:

    - A graph input appears directly in the graph outputs (no node in between).
      An Identity node is placed between the input and the output, because the
      ONNX specification does not allow a graph input to be used directly as a
      graph output.
    - The same value appears more than once in the graph outputs. Every
      occurrence after the first is routed through its own Identity node (the
      first occurrence is left untouched), so that each output value is unique
      as the ONNX specification requires.

    The main graph, every function of the model, and all of their subgraphs
    (e.g. bodies of control flow operators) are processed.

    Example transformations:
        Direct input-to-output:
            Before: input -> (direct connection) -> output
            After:  input -> Identity -> output

        Duplicate outputs:
            Before: value -> [output1, output2]
            After:  value -> output1, value -> Identity -> output2
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run both output fixes on the main graph and then on each function.

        Args:
            model: The model to fix in place.

        Returns:
            A PassResult wrapping the same model, with ``modified`` set to True
            when at least one Identity node was inserted.
        """
        changed = False

        # For each graph-like, duplicated outputs are aliased first and direct
        # input-to-output connections second. Both helpers must always run, so
        # the results are accumulated with ``|=`` rather than a short-circuiting ``or``.
        for graph_like in (model.graph, *model.functions.values()):
            changed |= _alias_multi_used_outputs(graph_like)
            changed |= _alias_direct_outputs(graph_like)

        return ir.passes.PassResult(model, modified=changed)
66+
67+
68+
def _alias_multi_used_outputs(graph_like: ir.Graph | ir.Function) -> bool:
    """Give every repeated graph output its own Identity-produced value.

    The given graph (or function) and each of its subgraphs are scanned. The
    first time a value shows up in an output list it is kept as is; each later
    occurrence is replaced by the output of a new Identity node that consumes
    the value.

    Args:
        graph_like: The graph or function whose outputs (and whose subgraphs'
            outputs) should be made unique.

    Returns:
        True if at least one Identity node was inserted, False otherwise.
    """
    changed = False

    for graph in (graph_like, *graph_like.subgraphs()):
        # Values already encountered in this graph's output list.
        already_output: set[ir.Value] = set()

        for position, value in enumerate(graph.outputs):
            if value in already_output:
                # Repeated occurrence: route it through a fresh Identity node.
                alias_node = ir.node("Identity", inputs=[value])
                alias = alias_node.outputs[0]

                # Mirror the original value's metadata on the alias.
                # TODO: Use a better unique naming strategy if needed
                alias.name = f"{value.name}_alias_{position}"
                alias.shape = value.shape
                alias.type = value.type
                alias.metadata_props.update(value.metadata_props)
                alias.doc_string = value.doc_string

                # Register the node and swap the duplicate output slot.
                graph.append(alias_node)
                graph.outputs[position] = alias
                logger.debug(
                    "Added Identity node for graph output '%s' used multiple times", value
                )
                changed = True
            else:
                # First occurrence stays untouched; just remember it.
                already_output.add(value)

    return changed
103+
104+
105+
def _alias_direct_outputs(graph_like: ir.Graph | ir.Function) -> bool:
    """Break direct input-to-output connections with Identity nodes.

    The given graph (or function) and each of its subgraphs are scanned for
    outputs whose value reports ``is_graph_input()``. Each such output is
    replaced by the output of a new Identity node fed by that value. The new
    value takes over the original name, and the original value is renamed
    with an ``_orig`` suffix.

    Args:
        graph_like: The graph or function to fix, together with its subgraphs.

    Returns:
        True if at least one Identity node was inserted, False otherwise.
    """
    changed = False

    for graph in (graph_like, *graph_like.subgraphs()):
        # Collect the offending output slots before touching the output list.
        direct_slots: list[tuple[int, ir.Value]] = [
            (position, value)
            for position, value in enumerate(graph.outputs)
            if value.is_graph_input()
        ]

        for position, source in direct_slots:
            passthrough = ir.node("Identity", inputs=[source])
            replacement = passthrough.outputs[0]

            # The new output inherits the original name and metadata so the
            # graph's output name is preserved.
            replacement.name = source.name
            replacement.shape = source.shape
            replacement.type = source.type
            replacement.metadata_props.update(source.metadata_props)
            replacement.doc_string = source.doc_string

            # The old value gets a new name so the two do not share one.
            # TODO: Use a better unique naming strategy if needed
            source.name = f"{source.name}_orig"

            # Register the node and point the output slot at its result.
            graph.append(passthrough)
            graph.outputs[position] = replacement

            logger.debug("Added Identity node for graph input '%s' used as output", source)
            changed = True

    return changed

0 commit comments

Comments
 (0)