Skip to content

Commit a9041a1

Browse files
committed
Update SDK code ai_we_0
1 parent 585ff75 commit a9041a1

File tree

7 files changed

+453
-221
lines changed

7 files changed

+453
-221
lines changed

sdk/ai/azure-ai-agents/azure/ai/agents/_utils/model_base.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pylint: disable=line-too-long,useless-suppression,too-many-lines
12
# coding=utf-8
23
# --------------------------------------------------------------------------
34
# Copyright (c) Microsoft Corporation. All rights reserved.
@@ -637,6 +638,10 @@ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self:
637638
if not rf._rest_name_input:
638639
rf._rest_name_input = attr
639640
cls._attr_to_rest_field: dict[str, _RestField] = dict(attr_to_rest_field.items())
641+
cls._backcompat_attr_to_rest_field: dict[str, _RestField] = {
642+
Model._get_backcompat_attribute_name(cls._attr_to_rest_field, attr): rf
643+
for attr, rf in cls._attr_to_rest_field.items()
644+
}
640645
cls._calculated.add(f"{cls.__module__}.{cls.__qualname__}")
641646

642647
return super().__new__(cls)
@@ -646,6 +651,16 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None:
646651
if hasattr(base, "__mapping__"):
647652
base.__mapping__[discriminator or cls.__name__] = cls # type: ignore
648653

654+
@classmethod
655+
def _get_backcompat_attribute_name(cls, attr_to_rest_field: dict[str, "_RestField"], attr_name: str) -> str:
656+
rest_field_obj = attr_to_rest_field.get(attr_name) # pylint: disable=protected-access
657+
if rest_field_obj is None:
658+
return attr_name
659+
original_tsp_name = getattr(rest_field_obj, "_original_tsp_name", None) # pylint: disable=protected-access
660+
if original_tsp_name:
661+
return original_tsp_name
662+
return attr_name
663+
649664
@classmethod
650665
def _get_discriminator(cls, exist_discriminators) -> typing.Optional["_RestField"]:
651666
for v in cls.__dict__.values():
@@ -971,6 +986,7 @@ def _failsafe_deserialize_xml(
971986
return None
972987

973988

989+
# pylint: disable=too-many-instance-attributes
974990
class _RestField:
975991
def __init__(
976992
self,
@@ -983,6 +999,7 @@ def __init__(
983999
format: typing.Optional[str] = None,
9841000
is_multipart_file_input: bool = False,
9851001
xml: typing.Optional[dict[str, typing.Any]] = None,
1002+
original_tsp_name: typing.Optional[str] = None,
9861003
):
9871004
self._type = type
9881005
self._rest_name_input = name
@@ -994,10 +1011,15 @@ def __init__(
9941011
self._format = format
9951012
self._is_multipart_file_input = is_multipart_file_input
9961013
self._xml = xml if xml is not None else {}
1014+
self._original_tsp_name = original_tsp_name
9971015

9981016
@property
9991017
def _class_type(self) -> typing.Any:
1000-
return getattr(self._type, "args", [None])[0]
1018+
result = getattr(self._type, "args", [None])[0]
1019+
# type may be wrapped by nested functools.partial so we need to check for that
1020+
if isinstance(result, functools.partial):
1021+
return getattr(result, "args", [None])[0]
1022+
return result
10011023

10021024
@property
10031025
def _rest_name(self) -> str:
@@ -1045,6 +1067,7 @@ def rest_field(
10451067
format: typing.Optional[str] = None,
10461068
is_multipart_file_input: bool = False,
10471069
xml: typing.Optional[dict[str, typing.Any]] = None,
1070+
original_tsp_name: typing.Optional[str] = None,
10481071
) -> typing.Any:
10491072
return _RestField(
10501073
name=name,
@@ -1054,6 +1077,7 @@ def rest_field(
10541077
format=format,
10551078
is_multipart_file_input=is_multipart_file_input,
10561079
xml=xml,
1080+
original_tsp_name=original_tsp_name,
10571081
)
10581082

10591083

0 commit comments

Comments
 (0)