@@ -114,6 +114,7 @@ def parameterize_inputs(inputs, prefix=""):
114114 inputs = tree .map_structure (make_tf_tensor_spec , input_signature )
115115 decorated_fn = get_concrete_fn (model , inputs , ** kwargs )
116116 ov_model = ov .convert_model (decorated_fn )
117+ set_names (ov_model , inputs )
117118 elif backend .backend () == "torch" :
118119 import torch
119120
@@ -128,6 +129,7 @@ def parameterize_inputs(inputs, prefix=""):
128129 warnings .filterwarnings ("ignore" , category = torch .jit .TracerWarning )
129130 traced = torch .jit .trace (model , sample_inputs )
130131 ov_model = ov .convert_model (traced )
132+ set_names (ov_model , sample_inputs )
131133 else :
132134 raise NotImplementedError (
133135 "`export_openvino` is only compatible with OpenVINO, "
@@ -140,6 +142,30 @@ def parameterize_inputs(inputs, prefix=""):
140142 io_utils .print_msg (f"Saved OpenVINO IR at '{ filepath } '." )
141143
142144
145+ def collect_names (structure ):
146+ if isinstance (structure , dict ):
147+ for k , v in structure .items ():
148+ if isinstance (v , (dict , list , tuple )):
149+ yield from collect_names (v )
150+ else :
151+ yield k
152+ elif isinstance (structure , (list , tuple )):
153+ for v in structure :
154+ yield from collect_names (v )
155+ else :
156+ if hasattr (structure , "name" ) and structure .name :
157+ yield structure .name
158+ else :
159+ yield "input"
160+
161+
162+ def set_names (model , inputs ):
163+ names = list (collect_names (inputs ))
164+ for ov_input , name in zip (model .inputs , names ):
165+ ov_input .get_node ().set_friendly_name (name )
166+ ov_input .tensor .set_names ({name })
167+
168+
143169def _check_jax_kwargs (kwargs ):
144170 kwargs = kwargs .copy ()
145171 if "is_static" not in kwargs :
0 commit comments