Skip to content

Commit 754dd00

Browse files
committed
PR feedback, clean up
1 parent 9584a13 commit 754dd00

File tree

5 files changed

+39
-66
lines changed

5 files changed

+39
-66
lines changed

sqlglot/dialects/dialect.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,9 @@ def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[
255255

256256
klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS
257257

258+
if enum not in ("", "bigquery", "snowflake"):
259+
klass.INITCAP_SUPPORTS_CUSTOM_DELIMITERS = False
260+
258261
if enum not in ("", "bigquery"):
259262
klass.generator_class.SELECT_KINDS = ()
260263

@@ -539,7 +542,7 @@ class Dialect(metaclass=_Dialect):
539542

540543
# Whether the INITCAP function supports custom delimiter characters as the second argument
541544
# Default delimiter characters for INITCAP function: whitespace and non-alphanumeric characters
542-
INITCAP_SUPPORTS_CUSTOM_DELIMITERS = False
545+
INITCAP_SUPPORTS_CUSTOM_DELIMITERS = True
543546
INITCAP_DEFAULT_DELIMITER_CHARS = " \t\n\r\f\v!\"#$%&'()*+,\\-./:;<=>?@\\[\\]^_`{|}~"
544547

545548
BYTE_STRING_IS_BYTES_TYPE: bool = False

sqlglot/dialects/duckdb.py

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,15 @@
4949
# The pattern matches timezone offsets that appear after the time portion
5050
TIMEZONE_PATTERN = re.compile(r":\d{2}.*?[+\-]\d{2}(?::\d{2})?")
5151

52+
# Characters that must be escaped when building regex expressions in INITCAP
53+
REGEX_ESCAPE_REPLACEMENTS = {
54+
"\\": "\\\\",
55+
"-": r"\-",
56+
"^": r"\^",
57+
"[": r"\[",
58+
"]": r"\]",
59+
}
60+
5261

5362
# BigQuery -> DuckDB conversion for the DATE function
5463
def _date_sql(self: DuckDB.Generator, expression: exp.Date) -> str:
@@ -296,29 +305,13 @@ def _escape_regex_metachars(
296305
if not delimiters:
297306
return delimiters_sql
298307

299-
REGEX_LITERAL_ESCAPES = {
300-
"\\": "\\\\", # literals need two slashes inside []
301-
"-": "\\-",
302-
"^": "\\^",
303-
"[": "\\[",
304-
"]": "\\]",
305-
}
306-
307308
if delimiters.is_string:
308309
literal_value = delimiters.this
309-
escaped_literal = "".join(REGEX_LITERAL_ESCAPES.get(ch, ch) for ch in literal_value)
310+
escaped_literal = "".join(REGEX_ESCAPE_REPLACEMENTS.get(ch, ch) for ch in literal_value)
310311
return self.sql(exp.Literal.string(escaped_literal))
311312

312-
REGEX_ESCAPE_REPLACEMENTS = (
313-
("\\", "\\\\"),
314-
("-", r"\-"),
315-
("^", r"\^"),
316-
("[", r"\["),
317-
("]", r"\]"),
318-
)
319-
320313
escaped_sql = delimiters_sql
321-
for raw, escaped in REGEX_ESCAPE_REPLACEMENTS:
314+
for raw, escaped in REGEX_ESCAPE_REPLACEMENTS.items():
322315
escaped_sql = self.func(
323316
"REPLACE",
324317
escaped_sql,
@@ -382,20 +375,13 @@ def _initcap_sql(self: DuckDB.Generator, expression: exp.Initcap) -> str:
382375
else delimiters_sql
383376
)
384377

385-
if delimiters and (isinstance(delimiters, exp.Literal) and delimiters.is_string):
386-
return (
387-
f"CASE WHEN {this_sql} IS NULL THEN NULL ELSE "
388-
f"{_build_capitalization_sql(self, this_sql, delimiters_sql, escaped_delimiters_sql)} END"
389-
)
390-
391-
capitalize_sql = _build_capitalization_sql(
378+
return _build_capitalization_sql(
392379
self,
393380
this_sql,
394381
delimiters_sql,
395382
escaped_delimiters_sql,
396383
convert_delim_to_regex=not isinstance(delimiters, exp.Null),
397384
)
398-
return f"CASE WHEN {this_sql} IS NULL OR {delimiters_sql} IS NULL THEN NULL ELSE {capitalize_sql} END"
399385

400386

401387
class DuckDB(Dialect):

sqlglot/generator.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5412,26 +5412,16 @@ def uuid_sql(self, expression: exp.Uuid) -> str:
54125412

54135413
def initcap_sql(self, expression: exp.Initcap) -> str:
54145414
delimiters = expression.expression
5415-
delimiters_sql = None
54165415

5417-
# do not generate delimiters arg if we are round-tripping from default delimiters
5418-
if (
5419-
delimiters
5420-
and delimiters.is_string
5421-
and delimiters.this == self.dialect.INITCAP_DEFAULT_DELIMITER_CHARS
5422-
):
5423-
delimiters_sql = ""
5424-
5425-
if (
5426-
delimiters
5427-
and delimiters_sql is None
5428-
and not self.dialect.INITCAP_SUPPORTS_CUSTOM_DELIMITERS
5429-
):
5430-
delimiters_sql = ""
5431-
self.unsupported("INITCAP does not support custom delimiters")
5432-
5433-
delimiters_sql = (
5434-
f", {self.sql(delimiters)}" if delimiters and delimiters_sql is None else delimiters_sql
5435-
)
5416+
if delimiters:
5417+
# do not generate delimiters arg if we are round-tripping from default delimiters
5418+
if (
5419+
delimiters.is_string
5420+
and delimiters.this == self.dialect.INITCAP_DEFAULT_DELIMITER_CHARS
5421+
):
5422+
delimiters = None
5423+
elif not self.dialect.INITCAP_SUPPORTS_CUSTOM_DELIMITERS:
5424+
self.unsupported("INITCAP does not support custom delimiters")
5425+
delimiters = None
54365426

5437-
return f"INITCAP({self.sql(expression, 'this')}{delimiters_sql})"
5427+
return self.func("INITCAP", expression.this, delimiters)

tests/dialects/test_dialect.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4352,13 +4352,13 @@ def duckdb_regex_literal_sql(delimiters: str) -> str:
43524352
with self.subTest(f"DuckDB rewrite for {dialect or 'default'} default delimiters"):
43534353
escaped_literal = duckdb_regex_literal_sql(default_delimiters)
43544354
expected = (
4355-
"CASE WHEN col IS NULL THEN NULL ELSE ARRAY_TO_STRING("
4355+
"ARRAY_TO_STRING("
43564356
f"CASE WHEN REGEXP_MATCHES(LEFT(col, 1), '[' || {escaped_literal} || ']') "
43574357
f"THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || {escaped_literal} || ']+|[^' || {escaped_literal} || ']+)'), "
43584358
f"(seg, idx) -> CASE WHEN idx % 2 = 0 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
43594359
f"ELSE LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || {escaped_literal} || ']+|[^' || {escaped_literal} || ']+)'), "
43604360
f"(seg, idx) -> CASE WHEN idx % 2 = 1 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
4361-
"END, '') END"
4361+
"END, '')"
43624362
)
43634363
self.assertEqual(parse_one("INITCAP(col)", read=dialect).sql("duckdb"), expected)
43644364

@@ -4368,20 +4368,20 @@ def duckdb_regex_literal_sql(delimiters: str) -> str:
43684368
with self.subTest(f"Testing DuckDB generation for {query} from {dialect}"):
43694369
self.assertEqual(
43704370
parse_one(query, read=dialect).sql("duckdb"),
4371-
"CASE WHEN col IS NULL THEN NULL ELSE UPPER(LEFT(col, 1)) || LOWER(SUBSTR(col, 2)) END",
4371+
"UPPER(LEFT(col, 1)) || LOWER(SUBSTR(col, 2))",
43724372
)
43734373

43744374
query = "INITCAP(col, NULL)"
43754375
with self.subTest(f"DuckDB generation for {query} from {dialect}"):
43764376
self.assertEqual(
43774377
parse_one(query, read=dialect).sql("duckdb"),
4378-
"CASE WHEN col IS NULL OR NULL IS NULL THEN NULL ELSE ARRAY_TO_STRING("
4378+
"ARRAY_TO_STRING("
43794379
"CASE WHEN REGEXP_MATCHES(LEFT(col, 1), NULL) "
43804380
"THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, NULL), "
43814381
"(seg, idx) -> CASE WHEN idx % 2 = 0 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
43824382
"ELSE LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, NULL), "
43834383
"(seg, idx) -> CASE WHEN idx % 2 = 1 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
4384-
"END, '') END",
4384+
"END, '')",
43854385
)
43864386

43874387
for custom_delimiter in (" ", "@", " _@", r"\\"):
@@ -4394,24 +4394,18 @@ def duckdb_regex_literal_sql(delimiters: str) -> str:
43944394
escaped_custom_delimiter = duckdb_regex_literal_sql(custom_delimiter)
43954395
self.assertEqual(
43964396
duckdb_sql,
4397-
"CASE WHEN col IS NULL THEN NULL ELSE ARRAY_TO_STRING("
4397+
"ARRAY_TO_STRING("
43984398
f"CASE WHEN REGEXP_MATCHES(LEFT(col, 1), '[' || {escaped_custom_delimiter} || ']') "
43994399
f"THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || {escaped_custom_delimiter} || ']+|[^' || {escaped_custom_delimiter} || ']+)'), "
44004400
f"(seg, idx) -> CASE WHEN idx % 2 = 0 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
44014401
f"ELSE LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || {escaped_custom_delimiter} || ']+|[^' || {escaped_custom_delimiter} || ']+)'), "
44024402
f"(seg, idx) -> CASE WHEN idx % 2 = 1 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
4403-
"END, '') END",
4403+
"END, '')",
44044404
)
44054405

44064406
def escape_expression_sql(sql: str) -> str:
44074407
escaped_sql = sql
4408-
for raw, escaped in (
4409-
("\\", "\\\\"),
4410-
("-", r"\-"),
4411-
("^", r"\^"),
4412-
("[", r"\["),
4413-
("]", r"\]"),
4414-
):
4408+
for raw, escaped in REGEX_LITERAL_ESCAPES.items():
44154409
raw_sql = exp.Literal.string(raw).sql()
44164410
escaped_literal_sql = exp.Literal.string(escaped).sql()
44174411
escaped_sql = f"REPLACE({escaped_sql}, {raw_sql}, {escaped_literal_sql})"
@@ -4426,7 +4420,7 @@ def escape_expression_sql(sql: str) -> str:
44264420
parse_one(
44274421
"INITCAP(col, (SELECT delimiter FROM settings LIMIT 1))", read=dialect
44284422
).sql("duckdb"),
4429-
"CASE WHEN col IS NULL OR (SELECT delimiter FROM settings LIMIT 1) IS NULL THEN NULL ELSE ARRAY_TO_STRING("
4423+
"ARRAY_TO_STRING("
44304424
+ f"CASE WHEN REGEXP_MATCHES(LEFT(col, 1), '[' || {escaped_subquery} || ']') "
44314425
"THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || "
44324426
+ escaped_subquery
@@ -4440,7 +4434,7 @@ def escape_expression_sql(sql: str) -> str:
44404434
+ escaped_subquery
44414435
+ " || ']+)'), "
44424436
"(seg, idx) -> CASE WHEN idx % 2 = 1 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
4443-
"END, '') END",
4437+
"END, '')",
44444438
)
44454439

44464440
def test_initcap_custom_delimiter_warning(self):

tests/dialects/test_hive.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -703,13 +703,13 @@ def duckdb_regex_literal_sql(delimiters: str) -> str:
703703
"INITCAP('new york')",
704704
write={
705705
"duckdb": (
706-
"CASE WHEN 'new york' IS NULL THEN NULL ELSE ARRAY_TO_STRING("
706+
"ARRAY_TO_STRING("
707707
f"CASE WHEN REGEXP_MATCHES(LEFT('new york', 1), '[' || {hive_escaped_delimiters} || ']') "
708708
f"THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL('new york', '([' || {hive_escaped_delimiters} || ']+|[^' || {hive_escaped_delimiters} || ']+)'), "
709709
"(seg, idx) -> CASE WHEN idx % 2 = 0 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
710710
f"ELSE LIST_TRANSFORM(REGEXP_EXTRACT_ALL('new york', '([' || {hive_escaped_delimiters} || ']+|[^' || {hive_escaped_delimiters} || ']+)'), "
711711
"(seg, idx) -> CASE WHEN idx % 2 = 1 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
712-
"END, '') END"
712+
"END, '')"
713713
),
714714
"hive": "INITCAP('new york')",
715715
"spark": "INITCAP('new york')",

0 commit comments

Comments
 (0)