diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f122afaf2..b539105792 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,13 +35,13 @@ - `try_hex_decode_string` - `unicode` - `uuid_string` - + - Conditional expressions: - `booland_agg` - `boolxor_agg` - `regr_valy` - `zeroifnull` - + - Numeric expressions: - `cot` - `mod` @@ -65,6 +65,7 @@ #### Bug Fixes +- Fixed a bug in SQL generation when joining two `DataFrame`s created using `DataFrame.alias` while CTE optimization is enabled. - Fixed a bug in `XMLReader` where finding the start position of a row tag could return an incorrect file position. ### Snowpark pandas API Updates @@ -127,13 +128,13 @@ - `str.pad` - `str.len` - `str.ljust` - - `str.rjust` - - `str.split` - - `str.replace` - - `str.strip` - - `str.lstrip` - - `str.rstrip` - - `str.translate` + - `str.rjust` + - `str.split` + - `str.replace` + - `str.strip` + - `str.lstrip` + - `str.rstrip` + - `str.translate` - `dt.tz_localize` - `dt.tz_convert` - `dt.ceil` @@ -142,11 +143,11 @@ - `dt.normalize` - `dt.month_name` - `dt.day_name` - - `dt.strftime` - - `dt.dayofweek` - - `dt.weekday` - - `dt.dayofyear` - - `dt.isocalendar` + - `dt.strftime` + - `dt.dayofweek` + - `dt.weekday` + - `dt.dayofyear` + - `dt.isocalendar` - `rolling.min` - `rolling.max` - `rolling.count` diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 6446fc2da0..fe939a5b14 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -972,7 +972,8 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: for c in logical_plan.children: # post-order traversal of the tree resolved = self.resolve(c) - df_aliased_col_name_to_real_col_name.update(resolved.df_aliased_col_name_to_real_col_name) # type: ignore + for alias, dict_ in resolved.df_aliased_col_name_to_real_col_name.items(): + 
df_aliased_col_name_to_real_col_name[alias].update(dict_) resolved_children[c] = resolved if isinstance(logical_plan, Selectable): @@ -1004,9 +1005,8 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: res = self.do_resolve_with_resolved_children( logical_plan, resolved_children, df_aliased_col_name_to_real_col_name ) - res.df_aliased_col_name_to_real_col_name.update( - df_aliased_col_name_to_real_col_name - ) + for alias, dict_ in df_aliased_col_name_to_real_col_name.items(): + res.df_aliased_col_name_to_real_col_name[alias].update(dict_) return res def do_resolve_with_resolved_children( diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 723277a31d..fed49f976b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -880,7 +880,7 @@ def __init__( self._projection_in_str = None self._query_params = None self.expr_to_alias.update(self.from_.expr_to_alias) - self.df_aliased_col_name_to_real_col_name.update( + self.df_aliased_col_name_to_real_col_name = deepcopy( self.from_.df_aliased_col_name_to_real_col_name ) self.api_calls = ( diff --git a/src/snowflake/snowpark/_internal/compiler/utils.py b/src/snowflake/snowpark/_internal/compiler/utils.py index 898b37d907..b62e986285 100644 --- a/src/snowflake/snowpark/_internal/compiler/utils.py +++ b/src/snowflake/snowpark/_internal/compiler/utils.py @@ -252,9 +252,10 @@ def update_resolvable_node( # df_aliased_col_name_to_real_col_name is updated at the frontend api # layer when alias is called, not produced during code generation. Should # always retain the original value of the map. 
- node.df_aliased_col_name_to_real_col_name.update( + node.df_aliased_col_name_to_real_col_name = copy.deepcopy( node.from_.df_aliased_col_name_to_real_col_name ) + # projection_in_str for SelectStatement runs a analyzer.analyze() which # needs the correct expr_to_alias map setup. This map is setup during # snowflake plan generation and cached for later use. Calling snowflake_plan diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 99328eeca1..90c6cecca0 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -25,6 +25,13 @@ when_matched, to_timestamp, ) +from snowflake.snowpark.types import ( + StructType, + StructField, + IntegerType, + StringType, + TimestampType, +) from tests.integ.scala.test_dataframe_reader_suite import get_reader from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker from tests.utils import IS_IN_STORED_PROC_LOCALFS, TestFiles, Utils @@ -272,6 +279,55 @@ def test_join_with_alias_dataframe(session): assert last_query.count(WITH) == 1 +def test_join_with_alias_dataframe_2(session): + # Reproduced from issue SNOW-2257191 + schema1 = StructType( + [ + StructField("DST_Year", IntegerType(), True), + StructField("DST_Start", TimestampType(), True), + StructField("DST_End", TimestampType(), True), + ] + ) + + schema2 = StructType( + [ + StructField("MATTRANSID", StringType(), True), + StructField("LOADSTARTTIME", TimestampType(), True), + StructField("LOADENDTIME", TimestampType(), True), + StructField("DUMPENDTIME", TimestampType(), True), + StructField("__CURRENT", StringType(), True), + StructField("__DELETED", StringType(), True), + ] + ) + + schema3 = StructType( + [ + StructField("MATTRANSID", StringType(), True), + StructField("DUMPENDTIME", TimestampType(), True), + StructField("LOADENDTIME", TimestampType(), True), + StructField("__CURRENT", StringType(), True), + StructField("__DELETED", StringType(), True), + ] + ) + + df1 = session.create_dataframe([], schema=schema1).cache_result() + df2 
= session.create_dataframe([], schema=schema2).cache_result() + df3 = session.create_dataframe([], schema=schema3).cache_result() + + df4 = df2.alias("d2").join( + df1, col("d2", "LoadStartTime").between(df1.DST_Start, df1.DST_End), "left" + ) + + df5 = df3.alias("d3").join( + df1, col("d3", "LoadEndTime").between(df1.DST_Start, df1.DST_End), "left" + ) + + df6 = df5.join(df4, (df5.MatTransId == df4.MatTransId), "left") + + # Assert that the generated sql compiles + df6.collect() + + def test_join_with_set_operation(session): df1 = session.create_dataframe([[1, 2, 3], [4, 5, 6]], "a: int, b: int, c: int") df2 = session.create_dataframe([[1, 1], [4, 5]], "a: int, b: int")