Commit c69c96c

Handle escaping when converting delimiters to regex expr
1 parent d7a54f6 commit c69c96c

7 files changed (+175, -60)

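Note on the change: the DuckDB INITCAP rewrite interpolates the delimiter characters into a regex character class ('[' || delimiters || ']'), so characters that are special inside a class (backslash, '-', '^', '[', ']') must be escaped before interpolation. A minimal standalone sketch of the literal-escaping step, mirroring the REGEX_LITERAL_ESCAPES table introduced in sqlglot/dialects/duckdb.py below (the helper name escape_for_char_class is illustrative, not part of sqlglot):

    # Characters that are special inside a regex character class [...]
    REGEX_LITERAL_ESCAPES = {
        "\\": "\\\\",  # a literal backslash needs two backslashes inside []
        "-": "\\-",
        "^": "\\^",
        "[": "\\[",
        "]": "\\]",
    }

    def escape_for_char_class(delimiters: str) -> str:
        # Escape metacharacters; leave every other character untouched.
        return "".join(REGEX_LITERAL_ESCAPES.get(ch, ch) for ch in delimiters)

    assert escape_for_char_class(" -@") == " \\-@"
    assert escape_for_char_class("a]b") == "a\\]b"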

sqlglot/dialects/bigquery.py

Lines changed: 1 addition & 1 deletion
@@ -360,7 +360,7 @@ class BigQuery(Dialect):
 
     # https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#initcap
     INITCAP_SUPPORTS_CUSTOM_DELIMITERS = True
-    INITCAP_DEFAULT_DELIMITER_CHARS = r' \t\n\r\f\v\[\](){}/|\<>!?@"^#$&~_,.:;*%+\-'
+    INITCAP_DEFAULT_DELIMITER_CHARS = ' \t\n\r\f\v\\[\\](){}/|\<>!?@"^#$&~_,.:;*%+\\-'
 
     # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
     NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE

sqlglot/dialects/dialect.py

Lines changed: 1 addition & 1 deletion
@@ -540,7 +540,7 @@ class Dialect(metaclass=_Dialect):
     # Whether the INITCAP function supports custom delimiter characters as the second argument
     # Default delimiter characters for INITCAP function: whitespace and non-alphanumeric characters
     INITCAP_SUPPORTS_CUSTOM_DELIMITERS = False
-    INITCAP_DEFAULT_DELIMITER_CHARS = r" \t\n\r\f\v!\"#$%&'()*+,\-./:;<=>?@\[\\]^_`{|}~"
+    INITCAP_DEFAULT_DELIMITER_CHARS = " \t\n\r\f\v!\"#$%&'()*+,\\-./:;<=>?@\\[\\]^_`{|}~"
 
     BYTE_STRING_IS_BYTES_TYPE: bool = False
     """

sqlglot/dialects/duckdb.py

Lines changed: 95 additions & 39 deletions
@@ -268,54 +268,110 @@ def _json_extract_value_array_sql(
     return self.sql(exp.cast(json_extract, to=exp.DataType.build(data_type)))
 
 
-def _initcap_sql(self: DuckDB.Generator, expression: exp.Initcap) -> str:
-    def build_capitalize_sql(
-        value_to_split: str, delimiters_sql: str, convert_delim_to_regex: bool = True
-    ) -> str:
-        # empty string delimiter --> treat value as one word, no need to split
-        if delimiters_sql == "''":
-            return f"UPPER(LEFT({value_to_split}, 1)) || LOWER(SUBSTR({value_to_split}, 2))"
-
-        delim_regex_sql = delimiters_sql
-        split_regex_sql = delimiters_sql
-        if convert_delim_to_regex:
-            delim_regex_sql = f"CONCAT('[', {delimiters_sql}, ']')"
-            split_regex_sql = f"CONCAT('([', {delimiters_sql}, ']+|[^', {delimiters_sql}, ']+)')"
-
-        # REGEXP_EXTRACT_ALL produces a list of string segments, alternating between delimiter and non-delimiter segments.
-        # We do not know whether the first segment is a delimiter or not, so we check the first character of the string
-        # with REGEXP_MATCHES. If the first char is a delimiter, we capitalize even list indexes, otherwise capitalize odd.
-        return self.func(
-            "ARRAY_TO_STRING",
-            exp.case()
-            .when(
-                f"REGEXP_MATCHES(LEFT({value_to_split}, 1), {delim_regex_sql})",
-                self.func(
-                    "LIST_TRANSFORM",
-                    self.func("REGEXP_EXTRACT_ALL", value_to_split, split_regex_sql),
-                    "(seg, idx) -> CASE WHEN idx % 2 = 0 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTR(seg, 2)) ELSE seg END",
-                ),
-            )
-            .else_(
-                self.func(
-                    "LIST_TRANSFORM",
-                    self.func("REGEXP_EXTRACT_ALL", value_to_split, split_regex_sql),
-                    "(seg, idx) -> CASE WHEN idx % 2 = 1 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTR(seg, 2)) ELSE seg END",
-                ),
+def _escape_regex_metachars(
+    self: DuckDB.Generator, delimiters: t.Optional[exp.Expression], delimiters_sql: str
+) -> str:
+    if not delimiters:
+        return delimiters_sql
+
+    REGEX_LITERAL_ESCAPES = {
+        "\\": "\\\\",  # literals need two slashes inside []
+        "-": "\\-",
+        "^": "\\^",
+        "[": "\\[",
+        "]": "\\]",
+    }
+
+    if isinstance(delimiters, exp.Literal) and delimiters.is_string:
+        literal_value = delimiters.this
+        escaped_literal = "".join(REGEX_LITERAL_ESCAPES.get(ch, ch) for ch in literal_value)
+        return self.sql(exp.Literal.string(escaped_literal))
+
+    REGEX_ESCAPE_REPLACEMENTS = (
+        ("\\", "\\\\"),
+        ("-", r"\-"),
+        ("^", r"\^"),
+        ("[", r"\["),
+        ("]", r"\]"),
+    )
+
+    escaped_sql = delimiters_sql
+    for raw, escaped in REGEX_ESCAPE_REPLACEMENTS:
+        escaped_sql = self.func(
+            "REPLACE",
+            escaped_sql,
+            self.sql(exp.Literal.string(raw)),
+            self.sql(exp.Literal.string(escaped)),
+        )
+
+    return escaped_sql
+
+
+def _build_capitalization_sql(
+    self: DuckDB.Generator,
+    value_to_split: str,
+    raw_delimiters_sql: str,
+    escaped_delimiters_sql: t.Optional[str] = None,
+    convert_delim_to_regex: bool = True,
+) -> str:
+    # empty string delimiter --> treat value as one word, no need to split
+    if raw_delimiters_sql == "''":
+        return f"UPPER(LEFT({value_to_split}, 1)) || LOWER(SUBSTR({value_to_split}, 2))"
+
+    regex_ready_sql = escaped_delimiters_sql or raw_delimiters_sql
+    delim_regex_sql = regex_ready_sql
+    split_regex_sql = regex_ready_sql
+    if convert_delim_to_regex:
+        delim_regex_sql = f"CONCAT('[', {regex_ready_sql}, ']')"
+        split_regex_sql = f"CONCAT('([', {regex_ready_sql}, ']+|[^', {regex_ready_sql}, ']+)')"
+
+    # REGEXP_EXTRACT_ALL produces a list of string segments, alternating between delimiter and non-delimiter segments.
+    # We do not know whether the first segment is a delimiter or not, so we check the first character of the string
+    # with REGEXP_MATCHES. If the first char is a delimiter, we capitalize even list indexes, otherwise capitalize odd.
+    return self.func(
+        "ARRAY_TO_STRING",
+        exp.case()
+        .when(
+            f"REGEXP_MATCHES(LEFT({value_to_split}, 1), {delim_regex_sql})",
+            self.func(
+                "LIST_TRANSFORM",
+                self.func("REGEXP_EXTRACT_ALL", value_to_split, split_regex_sql),
+                "(seg, idx) -> CASE WHEN idx % 2 = 0 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTR(seg, 2)) ELSE seg END",
             ),
-            "''",
         )
+        .else_(
+            self.func(
+                "LIST_TRANSFORM",
+                self.func("REGEXP_EXTRACT_ALL", value_to_split, split_regex_sql),
+                "(seg, idx) -> CASE WHEN idx % 2 = 1 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTR(seg, 2)) ELSE seg END",
+            ),
+        ),
+        "''",
+    )
 
+
+def _initcap_sql(self: DuckDB.Generator, expression: exp.Initcap) -> str:
     this_sql = self.sql(expression, "this")
     delimiters = expression.args.get("expression")
     delimiters_sql = self.sql(delimiters)
+    escaped_delimiters_sql = (
+        _escape_regex_metachars(self, delimiters, delimiters_sql)
+        if not isinstance(delimiters, exp.Null)
+        else delimiters_sql
+    )
 
     if delimiters and (isinstance(delimiters, exp.Literal) and delimiters.is_string):
-        return f"CASE WHEN {this_sql} IS NULL THEN NULL ELSE {build_capitalize_sql(this_sql, delimiters_sql)} END"
+        return (
+            f"CASE WHEN {this_sql} IS NULL THEN NULL ELSE "
+            f"{_build_capitalization_sql(self, this_sql, delimiters_sql, escaped_delimiters_sql)} END"
        )
 
-    # delimiters arg is SQL expression or NULL
-    capitalize_sql = build_capitalize_sql(
-        this_sql, delimiters_sql, convert_delim_to_regex=not isinstance(delimiters, exp.Null)
+    capitalize_sql = _build_capitalization_sql(
+        self,
+        this_sql,
+        delimiters_sql,
+        escaped_delimiters_sql,
+        convert_delim_to_regex=not isinstance(delimiters, exp.Null),
     )
     return f"CASE WHEN {this_sql} IS NULL OR {delimiters_sql} IS NULL THEN NULL ELSE {capitalize_sql} END"
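To make the two code paths above concrete: string-literal delimiters are escaped in Python and re-emitted as a SQL literal, while non-literal delimiters (columns, subqueries) are wrapped in a chain of REPLACE() calls so the escaping happens at query time, before the value is concatenated into the '[' ... ']' character class. A hedged standalone sketch (simplified: it quotes with plain f-strings where the real code builds exp.Literal nodes; the helper names are illustrative only):

    def escape_literal_delimiters(value: str) -> str:
        # Python-side path for INITCAP(col, '<string literal>') delimiters.
        escapes = {"\\": "\\\\", "-": "\\-", "^": "\\^", "[": "\\[", "]": "\\]"}
        return "".join(escapes.get(ch, ch) for ch in value)

    def replace_chain(delimiters_sql: str) -> str:
        # SQL-side path for arbitrary delimiter expressions: nest one REPLACE
        # per metacharacter around the original expression.
        escaped_sql = delimiters_sql
        for raw, escaped in (("\\", "\\\\"), ("-", "\\-"), ("^", "\\^"), ("[", "\\["), ("]", "\\]")):
            escaped_sql = f"REPLACE({escaped_sql}, '{raw}', '{escaped}')"
        return escaped_sql

    print(escape_literal_delimiters("-@"))  # \-@
    print(replace_chain("(SELECT delimiter FROM settings LIMIT 1)"))
    # nested REPLACE calls wrapping the subquery, one per metacharacter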

sqlglot/dialects/snowflake.py

Lines changed: 1 addition & 1 deletion
@@ -558,7 +558,7 @@ class Snowflake(Dialect):
 
     # https://docs.snowflake.com/en/en/sql-reference/functions/initcap
     INITCAP_SUPPORTS_CUSTOM_DELIMITERS = True
-    INITCAP_DEFAULT_DELIMITER_CHARS = r' \t\n\r\f\v!?@"^#$&~_,.:;+\-*%/|\[\](){}<>'
+    INITCAP_DEFAULT_DELIMITER_CHARS = ' \t\n\r\f\v!?@"^#$&~_,.:;+\\-*%/|\\[\\](){}<>'
 
     TIME_MAPPING = {
         "YYYY": "%Y",

sqlglot/dialects/spark2.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ class Spark2(Hive):
 
     # https://spark.apache.org/docs/latest/api/sql/index.html#initcap
     # https://docs.databricks.com/aws/en/sql/language-manual/functions/initcap
-    INITCAP_DEFAULT_DELIMITER_CHARS = r" \t\n\r\f\v"
+    INITCAP_DEFAULT_DELIMITER_CHARS = " \t\n\r\f\v"
 
     class Tokenizer(Hive.Tokenizer):
         HEX_STRINGS = [("X'", "'"), ("x'", "'")]

tests/dialects/test_dialect.py

Lines changed: 52 additions & 14 deletions
@@ -4313,6 +4313,18 @@ def test_initcap(self):
             "spark": Spark2.INITCAP_DEFAULT_DELIMITER_CHARS,
         }
 
+        REGEX_LITERAL_ESCAPES = {
+            "\\": "\\\\",
+            "-": "\\-",
+            "^": "\\^",
+            "[": "\\[",
+            "]": "\\]",
+        }
+
+        def duckdb_regex_literal_sql(delimiters: str) -> str:
+            escaped_literal = "".join(REGEX_LITERAL_ESCAPES.get(ch, ch) for ch in delimiters)
+            return exp.Literal.string(escaped_literal).sql("duckdb")
+
         # default delimiters not present in roundtrip
         for dialect in delimiter_chars.keys():
             with self.subTest(
@@ -4338,13 +4350,13 @@ def test_initcap(self):
 
         for dialect, default_delimiters in delimiter_chars.items():
             with self.subTest(f"DuckDB rewrite for {dialect or 'default'} default delimiters"):
-                literal = exp.Literal.string(default_delimiters).sql()
+                escaped_literal = duckdb_regex_literal_sql(default_delimiters)
                 expected = (
                     "CASE WHEN col IS NULL THEN NULL ELSE ARRAY_TO_STRING("
-                    f"CASE WHEN REGEXP_MATCHES(LEFT(col, 1), '[' || {literal} || ']') "
-                    f"THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || {literal} || ']+|[^' || {literal} || ']+)'), "
+                    f"CASE WHEN REGEXP_MATCHES(LEFT(col, 1), '[' || {escaped_literal} || ']') "
+                    f"THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || {escaped_literal} || ']+|[^' || {escaped_literal} || ']+)'), "
                     f"(seg, idx) -> CASE WHEN idx % 2 = 0 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
-                    f"ELSE LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || {literal} || ']+|[^' || {literal} || ']+)'), "
+                    f"ELSE LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || {escaped_literal} || ']+|[^' || {escaped_literal} || ']+)'), "
                     f"(seg, idx) -> CASE WHEN idx % 2 = 1 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
                     "END, '') END"
                 )
@@ -4372,35 +4384,61 @@ def test_initcap(self):
                     "END, '') END",
                 )
 
-            for custom_delimiter in (" ", "@", " _@"):
+            for custom_delimiter in (" ", "@", " _@", r"\\"):
                 with self.subTest(
                     f"DuckDB generation for INITCAP(col, {custom_delimiter}) from {dialect}"
                 ):
+                    literal_sql = exp.Literal.string(custom_delimiter).sql(dialect)
+                    expression = parse_one(f"INITCAP(col, {literal_sql})", read=dialect)
+                    duckdb_sql = expression.sql("duckdb")
+                    escaped_custom_delimiter = duckdb_regex_literal_sql(custom_delimiter)
                     self.assertEqual(
-                        parse_one(f"INITCAP(col, '{custom_delimiter}')", read=dialect).sql(
-                            "duckdb"
-                        ),
+                        duckdb_sql,
                         "CASE WHEN col IS NULL THEN NULL ELSE ARRAY_TO_STRING("
-                        f"CASE WHEN REGEXP_MATCHES(LEFT(col, 1), '[' || '{custom_delimiter}' || ']') "
-                        f"THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || '{custom_delimiter}' || ']+|[^' || '{custom_delimiter}' || ']+)'), "
+                        f"CASE WHEN REGEXP_MATCHES(LEFT(col, 1), '[' || {escaped_custom_delimiter} || ']') "
+                        f"THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || {escaped_custom_delimiter} || ']+|[^' || {escaped_custom_delimiter} || ']+)'), "
                         f"(seg, idx) -> CASE WHEN idx % 2 = 0 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
-                        f"ELSE LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || '{custom_delimiter}' || ']+|[^' || '{custom_delimiter}' || ']+)'), "
+                        f"ELSE LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || {escaped_custom_delimiter} || ']+|[^' || {escaped_custom_delimiter} || ']+)'), "
                        f"(seg, idx) -> CASE WHEN idx % 2 = 1 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
                         "END, '') END",
                     )
 
+            def escape_expression_sql(sql: str) -> str:
+                escaped_sql = sql
+                for raw, escaped in (
+                    ("\\", "\\\\"),
+                    ("-", r"\-"),
+                    ("^", r"\^"),
+                    ("[", r"\["),
+                    ("]", r"\]"),
+                ):
+                    raw_sql = exp.Literal.string(raw).sql()
+                    escaped_literal_sql = exp.Literal.string(escaped).sql()
+                    escaped_sql = f"REPLACE({escaped_sql}, {raw_sql}, {escaped_literal_sql})"
+
+                return escaped_sql
+
             with self.subTest(
                 f"DuckDB generation for INITCAP subquery as custom delimiter arg from {dialect}"
             ):
+                escaped_subquery = escape_expression_sql("(SELECT delimiter FROM settings LIMIT 1)")
                 self.assertEqual(
                     parse_one(
                         "INITCAP(col, (SELECT delimiter FROM settings LIMIT 1))", read=dialect
                     ).sql("duckdb"),
                     "CASE WHEN col IS NULL OR (SELECT delimiter FROM settings LIMIT 1) IS NULL THEN NULL ELSE ARRAY_TO_STRING("
-                    + "CASE WHEN REGEXP_MATCHES(LEFT(col, 1), '[' || (SELECT delimiter FROM settings LIMIT 1) || ']') "
-                    "THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || (SELECT delimiter FROM settings LIMIT 1) || ']+|[^' || (SELECT delimiter FROM settings LIMIT 1) || ']+)'), "
+                    + f"CASE WHEN REGEXP_MATCHES(LEFT(col, 1), '[' || {escaped_subquery} || ']') "
+                    "THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || "
+                    + escaped_subquery
+                    + " || ']+|[^' || "
+                    + escaped_subquery
+                    + " || ']+)'), "
                     "(seg, idx) -> CASE WHEN idx % 2 = 0 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
-                    "ELSE LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || (SELECT delimiter FROM settings LIMIT 1) || ']+|[^' || (SELECT delimiter FROM settings LIMIT 1) || ']+)'), "
+                    "ELSE LIST_TRANSFORM(REGEXP_EXTRACT_ALL(col, '([' || "
+                    + escaped_subquery
+                    + " || ']+|[^' || "
+                    + escaped_subquery
+                    + " || ']+)'), "
                     "(seg, idx) -> CASE WHEN idx % 2 = 1 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
                     "END, '') END",
                 )
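For a quick end-to-end check of the behavior the tests above pin down, the same rewrite can be observed through sqlglot.transpile. A hedged usage sketch (expected substrings are described in comments rather than asserted, since the full generated expression is long; it assumes the reading dialect accepts the two-argument INITCAP, as Snowflake does per this commit):

    import sqlglot

    # Literal delimiter containing a regex metacharacter: the '-' should appear
    # escaped (as \-) inside the generated '[' || ... || ']' character class.
    print(sqlglot.transpile("INITCAP(col, '-')", read="snowflake", write="duckdb")[0])

    # Non-literal delimiter: the generated SQL should wrap the subquery in the
    # REPLACE(...) chain shown in the expected strings above.
    print(
        sqlglot.transpile(
            "INITCAP(col, (SELECT delimiter FROM settings LIMIT 1))",
            read="snowflake",
            write="duckdb",
        )[0]
    )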

tests/dialects/test_hive.py

Lines changed: 24 additions & 3 deletions
@@ -1,6 +1,6 @@
 from tests.dialects.test_dialect import Validator
-
 from sqlglot import exp
+from sqlglot.dialects import Hive
 
 
 class TestHive(Validator):
@@ -685,11 +685,32 @@ def test_hive(self):
                 "spark": "LOCATE('a', x, 3)",
             },
         )
+
+        REGEX_LITERAL_ESCAPES = {
+            "\\": "\\\\",
+            "-": "\\-",
+            "^": "\\^",
+            "[": "\\[",
+            "]": "\\]",
+        }
+
+        def duckdb_regex_literal_sql(delimiters: str) -> str:
+            escaped_literal = "".join(REGEX_LITERAL_ESCAPES.get(ch, ch) for ch in delimiters)
+            return exp.Literal.string(escaped_literal).sql("duckdb")
+
+        hive_escaped_delimiters = duckdb_regex_literal_sql(Hive.INITCAP_DEFAULT_DELIMITER_CHARS)
         self.validate_all(
             "INITCAP('new york')",
             write={
-                "duckdb": r"CASE WHEN 'new york' IS NULL THEN NULL ELSE ARRAY_TO_STRING(CASE WHEN REGEXP_MATCHES(LEFT('new york', 1), '[' || ' \t\n\r\f\v!\"#$%&''()*+,\-./:;<=>?@\[\\]^_`{|}~' || ']') THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL('new york', '([' || ' \t\n\r\f\v!\"#$%&''()*+,\-./:;<=>?@\[\\]^_`{|}~' || ']+|[^' || ' \t\n\r\f\v!\"#$%&''()*+,\-./:;<=>?@\[\\]^_`{|}~' || ']+)'), (seg, idx) -> CASE WHEN idx % 2 = 0 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) ELSE LIST_TRANSFORM(REGEXP_EXTRACT_ALL('new york', '([' || ' \t\n\r\f\v!\"#$%&''()*+,\-./:;<=>?@\[\\]^_`{|}~' || ']+|[^' || ' \t\n\r\f\v!\"#$%&''()*+,\-./:;<=>?@\[\\]^_`{|}~' || ']+)'), (seg, idx) -> CASE WHEN idx % 2 = 1 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) END, '') END",
-                "presto": r"REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))",
+                "duckdb": (
+                    "CASE WHEN 'new york' IS NULL THEN NULL ELSE ARRAY_TO_STRING("
+                    f"CASE WHEN REGEXP_MATCHES(LEFT('new york', 1), '[' || {hive_escaped_delimiters} || ']') "
+                    f"THEN LIST_TRANSFORM(REGEXP_EXTRACT_ALL('new york', '([' || {hive_escaped_delimiters} || ']+|[^' || {hive_escaped_delimiters} || ']+)'), "
+                    "(seg, idx) -> CASE WHEN idx % 2 = 0 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
+                    f"ELSE LIST_TRANSFORM(REGEXP_EXTRACT_ALL('new york', '([' || {hive_escaped_delimiters} || ']+|[^' || {hive_escaped_delimiters} || ']+)'), "
+                    "(seg, idx) -> CASE WHEN idx % 2 = 1 THEN UPPER(LEFT(seg, 1)) || LOWER(SUBSTRING(seg, 2)) ELSE seg END) "
+                    "END, '') END"
+                ),
                 "hive": "INITCAP('new york')",
                 "spark": "INITCAP('new york')",
             },
