@@ -4352,13 +4352,13 @@ def duckdb_regex_literal_sql(delimiters: str) -> str:
43524352 with self .subTest (f"DuckDB rewrite for { dialect or 'default' } default delimiters" ):
43534353 escaped_literal = duckdb_regex_literal_sql (default_delimiters )
43544354 expected = (
4355- "CASE WHEN col IS NULL THEN NULL ELSE ARRAY_TO_STRING("
4355+ "ARRAY_TO_STRING("
43564356 f"CASE WHEN REGEXP_MATCHES(LEFT(col, 1), '[' || { escaped_literal } || ']') "
43574357 f"THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || { escaped_literal } || ']+|[^' || { escaped_literal } || ']+)'), "
43584358 f"(seg, idx) -> CASE WHEN idx % 2 = 0 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
43594359 f"ELSE LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || { escaped_literal } || ']+|[^' || { escaped_literal } || ']+)'), "
43604360 f"(seg, idx) -> CASE WHEN idx % 2 = 1 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
4361- "END, '') END "
4361+ "END, '')"
43624362 )
43634363 self .assertEqual (parse_one ("INITCAP(col)" , read = dialect ).sql ("duckdb" ), expected )
43644364
@@ -4368,20 +4368,20 @@ def duckdb_regex_literal_sql(delimiters: str) -> str:
43684368 with self .subTest (f"Testing DuckDB generation for { query } from { dialect } " ):
43694369 self .assertEqual (
43704370 parse_one (query , read = dialect ).sql ("duckdb" ),
4371- "CASE WHEN col IS NULL THEN NULL ELSE UPPER(LEFT(col, 1)) || LOWER(SUBSTR(col, 2)) END " ,
4371+ "UPPER(LEFT(col, 1)) || LOWER(SUBSTR(col, 2))" ,
43724372 )
43734373
43744374 query = "INITCAP(col, NULL)"
43754375 with self .subTest (f"DuckDB generation for { query } from { dialect } " ):
43764376 self .assertEqual (
43774377 parse_one (query , read = dialect ).sql ("duckdb" ),
4378- "CASE WHEN col IS NULL OR NULL IS NULL THEN NULL ELSE ARRAY_TO_STRING("
4378+ "ARRAY_TO_STRING("
43794379 "CASE WHEN REGEXP_MATCHES(LEFT(col, 1), NULL) "
43804380 "THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, NULL), "
43814381 "(seg, idx) -> CASE WHEN idx % 2 = 0 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
43824382 "ELSE LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, NULL), "
43834383 "(seg, idx) -> CASE WHEN idx % 2 = 1 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
4384- "END, '') END " ,
4384+ "END, '')" ,
43854385 )
43864386
43874387 for custom_delimiter in (" " , "@" , " _@" , r"\\" ):
@@ -4394,24 +4394,18 @@ def duckdb_regex_literal_sql(delimiters: str) -> str:
43944394 escaped_custom_delimiter = duckdb_regex_literal_sql (custom_delimiter )
43954395 self .assertEqual (
43964396 duckdb_sql ,
4397- "CASE WHEN col IS NULL THEN NULL ELSE ARRAY_TO_STRING("
4397+ "ARRAY_TO_STRING("
43984398 f"CASE WHEN REGEXP_MATCHES(LEFT(col, 1), '[' || { escaped_custom_delimiter } || ']') "
43994399 f"THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || { escaped_custom_delimiter } || ']+|[^' || { escaped_custom_delimiter } || ']+)'), "
44004400 f"(seg, idx) -> CASE WHEN idx % 2 = 0 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
44014401 f"ELSE LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || { escaped_custom_delimiter } || ']+|[^' || { escaped_custom_delimiter } || ']+)'), "
44024402 f"(seg, idx) -> CASE WHEN idx % 2 = 1 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
4403- "END, '') END " ,
4403+ "END, '')" ,
44044404 )
44054405
44064406 def escape_expression_sql (sql : str ) -> str :
44074407 escaped_sql = sql
4408- for raw , escaped in (
4409- ("\\ " , "\\ \\ " ),
4410- ("-" , r"\-" ),
4411- ("^" , r"\^" ),
4412- ("[" , r"\[" ),
4413- ("]" , r"\]" ),
4414- ):
4408+ for raw , escaped in REGEX_LITERAL_ESCAPES .items ():
44154409 raw_sql = exp .Literal .string (raw ).sql ()
44164410 escaped_literal_sql = exp .Literal .string (escaped ).sql ()
44174411 escaped_sql = f"REPLACE({ escaped_sql } , { raw_sql } , { escaped_literal_sql } )"
@@ -4426,7 +4420,7 @@ def escape_expression_sql(sql: str) -> str:
44264420 parse_one (
44274421 "INITCAP(col, (SELECT delimiter FROM settings LIMIT 1))" , read = dialect
44284422 ).sql ("duckdb" ),
4429- "CASE WHEN col IS NULL OR (SELECT delimiter FROM settings LIMIT 1) IS NULL THEN NULL ELSE ARRAY_TO_STRING("
4423+ "ARRAY_TO_STRING("
44304424 + f"CASE WHEN REGEXP_MATCHES(LEFT(col, 1), '[' || { escaped_subquery } || ']') "
44314425 "THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || "
44324426 + escaped_subquery
@@ -4440,7 +4434,7 @@ def escape_expression_sql(sql: str) -> str:
44404434 + escaped_subquery
44414435 + " || ']+)'), "
44424436 "(seg, idx) -> CASE WHEN idx % 2 = 1 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
4443- "END, '') END " ,
4437+ "END, '')" ,
44444438 )
44454439
44464440 def test_initcap_custom_delimiter_warning (self ):
0 commit comments