Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5from functools import reduce
  6
  7from sqlglot import exp
  8from sqlglot._typing import E
  9from sqlglot.errors import ParseError
 10from sqlglot.generator import Generator
 11from sqlglot.helper import flatten, seq_get
 12from sqlglot.parser import Parser
 13from sqlglot.time import format_time
 14from sqlglot.tokens import Token, Tokenizer, TokenType
 15from sqlglot.trie import new_trie
 16
# Type variable constrained to `exp.Binary` subclasses; used by `binary_from_function`
# so the returned builder is typed with the concrete binary expression class.
B = t.TypeVar("B", bound=exp.Binary)
 18
 19
 20class Dialects(str, Enum):
 21    DIALECT = ""
 22
 23    BIGQUERY = "bigquery"
 24    CLICKHOUSE = "clickhouse"
 25    DATABRICKS = "databricks"
 26    DRILL = "drill"
 27    DUCKDB = "duckdb"
 28    HIVE = "hive"
 29    MYSQL = "mysql"
 30    ORACLE = "oracle"
 31    POSTGRES = "postgres"
 32    PRESTO = "presto"
 33    REDSHIFT = "redshift"
 34    SNOWFLAKE = "snowflake"
 35    SPARK = "spark"
 36    SPARK2 = "spark2"
 37    SQLITE = "sqlite"
 38    STARROCKS = "starrocks"
 39    TABLEAU = "tableau"
 40    TERADATA = "teradata"
 41    TRINO = "trino"
 42    TSQL = "tsql"
 43    Doris = "doris"
 44
 45
class _Dialect(type):
    """Metaclass that registers every dialect class and wires up its derived attributes."""

    # Registry of dialect classes, keyed by the matching `Dialects` value or,
    # when no enum member matches, by the lowercased class name.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        """Compare a dialect class with another class, a registry key or a dialect instance."""
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        """Hash on the lowercased class name."""
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        """Return the dialect class registered under `key`; raises KeyError if unknown."""
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        """Return the dialect class registered under `key`, or `default` if there is none."""
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        """Create a dialect class, register it and populate its autofilled attributes."""
        klass = super().__new__(cls, clsname, bases, attrs)
        # Look the class up in the enum by its upper-cased name; fall back to the
        # lowercased class name as registry key when there is no such member.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Tries derived from the time/format mappings (FORMAT_TRIE reuses TIME_TRIE
        # when the dialect declares no FORMAT_MAPPING).
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        # Use the nested Tokenizer/Parser/Generator classes when the dialect defines
        # them, otherwise the base implementations.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first quote / identifier delimiter pair of the tokenizer becomes the
        # dialect's delimiter pair.
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # First (start, end) delimiter pair whose format-string token type matches,
            # or (None, None) when the tokenizer has none. The loop variable `t` only
            # shadows the `typing` alias inside the generator expression.
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        # Data attributes declared directly on this class (not inherited ones, since
        # `vars` is used), excluding callables, classmethods and dunder names.
        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "TOKENIZER_CLASS": klass.tokenizer_class,
        }

        # Every dialect other than the base one ("") and BigQuery gets an empty
        # SELECT_KINDS; `Dialects` members compare equal to their string values.
        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                # Only attributes the target class already exposes are overwritten
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        # NOTE(review): this writes into the parser class's BITWISE mapping in place;
        # confirm whether that mapping can be shared with other parser classes.
        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        klass.generator_class.can_identify = klass.can_identify

        return klass
131
132
class Dialect(metaclass=_Dialect):
    """Base dialect: pairs a tokenizer, parser and generator with the settings they share."""

    # Base index offset used for arrays
    INDEX_OFFSET = 0

    # When true, unnest table aliases are treated purely as column aliases
    UNNEST_COLUMN_ONLY = False

    # Whether the table alias is placed after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Whether unquoted identifiers resolve to uppercase.
    # None means the dialect treats all identifiers as case-insensitive.
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Whether an unquoted identifier may begin with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Whether the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Whether CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Whether user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # How function names are normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Default null ordering method when none is set explicitly.
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings: keys are dialect time formats, values are python time formats
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns auto-generated by the engine behind this dialect.
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        """Compare this instance's class with `other` using the metaclass equality."""
        return self.__class__ == other

    def __hash__(self) -> int:
        """Hash an instance the same way as its class."""
        return hash(self.__class__)

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a dialect name, class or instance to a dialect class.

        A falsy `dialect` resolves to the class this method is called on.

        Raises:
            ValueError: if `dialect` is a name with no registered dialect class.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return type(dialect)

        found = cls.get(dialect)
        if found:
            return found

        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a time format through this dialect's TIME_MAPPING.

        Plain strings are expected to be quoted; string-literal expressions are
        translated from their `this` value. Anything else is returned untouched.
        """
        if isinstance(expression, str):
            # the time formats are quoted, so drop the first and last character
            unquoted = expression[1:-1]
            return exp.Literal.string(format_time(unquoted, cls.TIME_MAPPING, cls.TIME_TRIE))

        if expression and expression.is_string:
            converted = format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)
            return exp.Literal.string(converted)

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        # Quoted identifiers are left alone unless the dialect is fully case-insensitive
        if expression.quoted and cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is not None:
            return expression

        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE:
            normalized = expression.this.upper()
        else:
            normalized = expression.this.lower()

        expression.set("this", normalized)
        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if resolves_upper is None:
            return False

        # Characters in the opposite case of the dialect's resolution are the unsafe ones
        if resolves_upper:
            return any(char.islower() for char in text)
        return any(char.isupper() for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        return identify == "safe" and not cls.case_sensitive(text)

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Set the `quoted` flag of an identifier; other expressions are returned unchanged.

        The flag is set when `identify` is true, the name is case sensitive for
        this dialect, or the name does not match `exp.SAFE_IDENTIFIER_RE`.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        name = expression.this
        expression.set(
            "quoted",
            identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
        )
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it with a parser built from `opts`."""
        parser = self.parser(**opts)
        return parser.parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it into `expression_type` with a parser built from `opts`."""
        parser = self.parser(**opts)
        return parser.parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render `expression` with a generator built from `opts`."""
        generator = self.generator(**opts)
        return generator.generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and render every resulting expression; `opts` go to the generator only."""
        statements = self.parse(sql)
        return [self.generate(statement, **opts) for statement in statements]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize `sql` with this dialect's cached tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """Tokenizer instance, created on first access and reused afterwards."""
        try:
            return self._tokenizer
        except AttributeError:
            self._tokenizer = self.tokenizer_class()
            return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Build a new parser of this dialect's parser class."""
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Build a new generator of this dialect's generator class."""
        return self.generator_class(**opts)
319
320
# Anything `Dialect.get_or_raise` accepts: a dialect name, a dialect instance,
# a dialect class, or None.
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
322
323
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator callback that renders an expression as a call to `name`.

    The callback passes every value of the expression's `args`, flattened, as
    the function arguments.
    """

    def _renamed(self: Generator, expression: exp.Expression) -> str:
        arguments = flatten(expression.args.values())
        return self.func(name, *arguments)

    return _renamed
326
327
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT, flagging an `accuracy` argument as unsupported."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    return self.func("APPROX_COUNT_DISTINCT", expression.this)
332
333
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render a conditional as `IF(condition, true, false)`."""
    condition = expression.this
    when_true = expression.args.get("true")
    when_false = expression.args.get("false")
    return self.func("IF", condition, when_true, when_false)
338
339
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render a JSON extraction with the `->` binary operator."""
    operator = "->"
    return self.binary(expression, operator)
342
343
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render a scalar JSON extraction with the `->>` binary operator."""
    operator = "->>"
    return self.binary(expression, operator)
348
349
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array as a bracketed, comma-separated list of its elements."""
    elements = self.expressions(expression, flat=True)
    return f"[{elements}]"
352
353
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Render ILIKE as a LIKE whose left operand is wrapped in LOWER."""
    # NOTE(review): only the left operand is lowercased; the pattern operand is
    # passed through unchanged — confirm that this is the intended emulation.
    lowered_operand = exp.Lower(this=expression.this.copy())
    pattern = expression.expression.copy()
    return self.like_sql(exp.Like(this=lowered_operand, expression=pattern))
360
361
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, appending the time zone when one is set."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"

    return f"CURRENT_DATE AT TIME ZONE {zone}"
365
366
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause for dialects without recursive CTEs.

    When the clause is marked recursive, an unsupported warning is recorded and
    the clause is rendered as a plain WITH. The flag is cleared on a copy, so the
    caller's expression is left unmodified.
    """
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Copy before clearing the flag: the node passed in belongs to the caller
        expression = expression.copy()
        expression.args["recursive"] = False

    return self.with_sql(expression)
372
373
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Render a safe division as `IF(d <> 0, n / d, NULL)`.

    Operands that are themselves binary expressions are parenthesized so that
    operator precedence cannot change the meaning of the division, e.g. a
    numerator `a + b` renders as `(a + b) / d` rather than `a + b / d`.
    """

    def _operand(key: str) -> str:
        # Parenthesize only binary operands; other operands render unchanged
        rendered = self.sql(expression, key)
        if isinstance(expression.args.get(key), exp.Binary):
            return f"({rendered})"
        return rendered

    n = _operand("this")
    d = _operand("expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"
378
379
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Flag TABLESAMPLE as unsupported and render only the sampled expression."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
383
384
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Flag PIVOT as unsupported and render nothing for it."""
    message = "PIVOT unsupported"
    self.unsupported(message)
    return ""
388
389
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST through the generator's regular cast rendering."""
    rendered = self.cast_sql(expression)
    return rendered
392
393
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Flag properties as unsupported and render nothing for them."""
    message = "Properties unsupported"
    self.unsupported(message)
    return ""
397
398
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Flag a column comment constraint as unsupported and render nothing for it."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
404
405
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Flag MAP_FROM_ENTRIES as unsupported and render nothing for it."""
    message = "MAP_FROM_ENTRIES unsupported"
    self.unsupported(message)
    return ""
409
410
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a substring search with STRPOS, honouring an optional start position.

    Without a start position this is simply `STRPOS(this, substr)`. With one, the
    search runs on `SUBSTR(this, position)` and the match index is shifted back
    by `position - 1`. A result of 0 (substring not found) is passed through
    unshifted, so "not found" is still reported as 0 rather than `position - 1`.
    """
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")

    if not position:
        return f"STRPOS({this}, {substr})"

    offset_match = f"STRPOS(SUBSTR({this}, {position}), {substr})"
    # Only shift real matches; 0 means the substring was not found
    return f"CASE WHEN {offset_match} = 0 THEN 0 ELSE {offset_match} + {position} - 1 END"
418
419
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render a struct field access with dot notation: `<struct>.<field>`."""
    struct = self.sql(expression, "this")
    field = self.sql(exp.to_identifier(expression.expression.name))
    return f"{struct}.{field}"
424
425
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as `map_func_name(key1, value1, key2, value2, ...)`.

    Interleaving is only possible when both keys and values are array
    expressions; otherwise a warning is recorded and the two are passed to the
    function as they are.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
        pairs = zip(keys.expressions, values.expressions)
        # Render key then value for each pair, in order
        interleaved = [self.sql(item) for pair in pairs for item in pair]
        return self.func(map_func_name, *interleaved)

    self.unsupported("Cannot convert array columns into map.")
    return self.func(map_func_name, keys, values)
442
443
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        dialect_class = Dialect[dialect]

        # Use the explicit format argument if given, else fall back to the default
        fmt = seq_get(args, 1)
        if not fmt:
            fmt = dialect_class.TIME_FORMAT if default is True else default or None

        return exp_class(this=seq_get(args, 0), format=dialect_class.format_time(fmt))

    return _format_time
468
469
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a callback that yields an expression's time format unless it is the dialect default."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        rendered = self.format_time(expression)
        if rendered == Dialect.get_or_raise(dialect).TIME_FORMAT:
            return None
        return rendered

    return _time_format
482
483
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    # Only TABLE / VIEW creations that carry a schema can be rewritten
    if not isinstance(expression.this, exp.Schema):
        return self.create_sql(expression)
    if expression.args.get("kind") not in ("TABLE", "VIEW"):
        return self.create_sql(expression)

    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)

    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        schema = expression.this
        partition_names = {v.name.upper() for v in prop.this.expressions}

        # Move the partition columns out of the table schema and into the property
        partitions = [col for col in schema.expressions if col.name.upper() in partition_names]
        remaining = [e for e in schema.expressions if e not in partitions]

        schema.set("expressions", remaining)
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return self.create_sql(expression)
505
506
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser callback for date-delta functions.

    With exactly three arguments the call is read as `(unit, expression, this)`;
    otherwise as `(this, expression)` with a "DAY" unit. When `unit_mapping` is
    given, the unit name is translated through it (looked up lowercased).
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            unit, this = args[0], args[2]
        else:
            unit, this = exp.Literal.string("DAY"), seq_get(args, 0)

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
518
519
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser callback for date-delta functions whose second argument is an INTERVAL.

    The callback returns None when fewer than two arguments are supplied and
    raises ParseError when the second argument is not an interval.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        target, interval = args[0], args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        # A string interval value is turned into a number literal
        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=target, expression=amount, unit=unit)

    return func
541
542
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse `(unit, value)`: a DateTrunc when `value` is a cast to date, else a TimestampTrunc."""
    unit, this = seq_get(args, 0), seq_get(args, 1)

    casts_to_date = isinstance(this, exp.Cast) and this.is_type("date")
    if casts_to_date:
        return exp.DateTrunc(unit=unit, this=this)

    return exp.TimestampTrunc(this=this, unit=unit)
550
551
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a callback rendering `<data_type>_<kind>(this, INTERVAL ...)`.

    The interval unit is the expression's upper-cased unit name, or DAY when
    the expression has no unit.
    """

    def func(self: Generator, expression: exp.Expression) -> str:
        target = self.sql(expression, "this")

        unit_node = expression.args.get("unit")
        unit_name = unit_node.name.upper() if unit_node else "DAY"

        interval = exp.Interval(this=expression.expression.copy(), unit=exp.var(unit_name))
        interval_sql = self.sql(interval)
        return f"{data_type}_{kind}({target}, {interval_sql})"

    return func
563
564
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render a timestamp truncation as `DATE_TRUNC('<unit>', this)`, defaulting the unit to day."""
    unit_name = expression.text("unit") or "day"
    unit = exp.Literal.string(unit_name)
    return self.func("DATE_TRUNC", unit, expression.this)
569
570
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Parse `(substr, this[, position])` arguments into a StrPosition expression."""
    substr = seq_get(args, 0)
    this = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=this, substr=substr, position=position)
575
576
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a StrPosition as `LOCATE(substr, this[, position])`."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
581
582
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Render LEFT(this, n) as a substring that starts at 1 and has length n."""
    node = expression.copy()
    substring = exp.Substring(
        this=node.this,
        start=exp.Literal.number(1),
        length=node.expression,
    )
    return self.sql(substring)
590
591
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Render RIGHT(this, n) as a substring starting at `LENGTH(this) - (n - 1)`.

    The input expression is copied first, and the operand used inside LENGTH is
    a separate copy, so no node is attached to two parents in the built tree.
    """
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            # Separate copy: `expression.this` is already owned by the Substring node
            start=exp.Length(this=expression.this.copy()) - exp.paren(expression.expression - 1),
        )
    )
600
601
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string conversion as a cast to timestamp."""
    as_timestamp = exp.cast(expression.this, "timestamp")
    return self.sql(as_timestamp)
604
605
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string conversion as a cast to date."""
    as_date = exp.cast(expression.this, "date")
    return self.sql(as_date)
608
609
# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an encode/decode call as `name(this[, replace])`.

    A charset other than "utf-8" is flagged as unsupported. The expression's
    `replace` argument is only forwarded when `replace` is true.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
619
620
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render as LEAST when extra expressions are present, otherwise as MIN."""
    if expression.expressions:
        name = "LEAST"
    else:
        name = "MIN"

    render = rename_func(name)
    return render(self, expression)
624
625
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render as GREATEST when extra expressions are present, otherwise as MAX."""
    if expression.expressions:
        name = "GREATEST"
    else:
        name = "MAX"

    render = rename_func(name)
    return render(self, expression)
629
630
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render COUNT_IF(cond) as `sum(if(cond, 1, 0))`.

    A DISTINCT wrapper is not representable: it is flagged as unsupported and
    its first expression is used as the condition.
    """
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        cond = cond.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", cond.copy(), 1, 0)
    return self.func("sum", one_or_zero)
639
640
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with the `TRIM([position ][chars ]FROM target[ COLLATE c])` syntax.

    When neither characters to remove nor a collation are present, rendering is
    delegated to the generator's own `trim_sql`.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (remove_chars or collation):
        return self.trim_sql(expression)

    pieces = []
    if trim_type:
        pieces.append(f"{trim_type} ")
    if remove_chars:
        pieces.append(f"{remove_chars} ")
    if pieces:
        # FROM is only needed when a position or a character set precedes the target
        pieces.append("FROM ")
    pieces.append(target)
    if collation:
        pieces.append(f" COLLATE {collation}")

    return f"TRIM({''.join(pieces)})"
656
657
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as `STRPTIME(this, format)`."""
    fmt = self.format_time(expression)
    return self.func("STRPTIME", expression.this, fmt)
660
661
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a callback that renders TsOrDsToDate as a cast to date for `dialect`.

    When the expression carries a format other than the dialect's default time
    or date format, the value is parsed with STRPTIME before being cast.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target_dialect = Dialect.get_or_raise(dialect)
        fmt = self.format_time(expression)

        if fmt and fmt not in (target_dialect.TIME_FORMAT, target_dialect.DATE_FORMAT):
            source = str_to_time_sql(self, expression)
        else:
            source = self.sql(expression, "this")

        return self.sql(exp.cast(source, "date"))

    return _ts_or_ds_to_date_sql
672
673
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Render CONCAT(a, b, ...) as a left-nested chain of `||` operators."""

    def _pipe(left: exp.Expression, right: exp.Expression) -> exp.DPipe:
        return exp.DPipe(this=left, expression=right)

    operands = expression.copy().expressions
    return self.sql(reduce(_pipe, operands))
677
678
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Render CONCAT_WS(delim, a, b, ...) as `a || delim || b ...` using `||` operators."""
    expression = expression.copy()
    delim, *rest_args = expression.expressions

    def _join(left: exp.Expression, right: exp.Expression) -> exp.DPipe:
        # Each further operand is prefixed with the delimiter before being appended
        return exp.DPipe(this=left, expression=exp.DPipe(this=delim, expression=right))

    return self.sql(reduce(_join, rest_args))
688
689
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render `REGEXP_EXTRACT(this, pattern[, group])`, flagging the other optional args."""
    unsupported_keys = ("position", "occurrence", "parameters")
    bad_args = [key for key in unsupported_keys if expression.args.get(key)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
698
699
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render `REGEXP_REPLACE(this, pattern[, replacement])`.

    The `position`, `occurrence` and `parameters` arguments are flagged as
    unsupported when set. A missing `replacement` is passed to the generator as
    None instead of raising KeyError.
    """
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE",
        expression.this,
        expression.expression,
        # `.get` rather than indexing: the key may be absent from `args`
        expression.args.get("replacement"),
    )
708
709
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Return one column-name fragment per pivot aggregation.

    Aliased aggregations contribute their alias; the others contribute their
    SQL text in `dialect`, with lowercased function names and unquoted identifiers.
    """

    def _unquote(node: exp.Expression) -> exp.Expression:
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # This case corresponds to aggregations without aliases being used as suffixes
        # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
        # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
        # Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested quotes).
        agg_all_unquoted = agg.transform(_unquote)
        names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
730
731
def simplify_literal(expression: E) -> E:
    """Run the optimizer's `simplify` on the `expression` operand unless it is already a literal.

    The return value of `simplify` is not used; the input node itself is returned.
    """
    operand = expression.expression
    if isinstance(operand, exp.Literal):
        return expression

    # Imported here rather than at module level, as in the original code
    from sqlglot.optimizer.simplify import simplify

    simplify(expression.expression)
    return expression
739
740
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser callback that maps a two-argument function call to `expr_type`."""

    def _build(args: t.List) -> B:
        left = seq_get(args, 0)
        right = seq_get(args, 1)
        return expr_type(this=left, expression=right)

    return _build
743
744
# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Parse `(unit, value)` arguments into a TimestampTrunc expression."""
    unit = seq_get(args, 0)
    value = seq_get(args, 1)
    return exp.TimestampTrunc(this=value, unit=unit)
748
749
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render ANY_VALUE(x) as `MAX(x)`."""
    operand = expression.this
    return self.func("MAX", operand)
752
753
# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Render a JSON key/value pair as `key, value`."""
    key = self.sql(expression, "this")
    value = self.sql(expression, "expression")
    return f"{key}, {value}"
757
758
def is_parse_json(expression: exp.Expression) -> bool:
    """Tell whether `expression` is a ParseJSON node or a cast to the json type."""
    if isinstance(expression, exp.ParseJSON):
        return True

    return isinstance(expression, exp.Cast) and expression.is_type("json")
763
764
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Parse ISNULL(x) into a parenthesized `x IS NULL` expression."""
    operand = seq_get(args, 0)
    is_null = exp.Is(this=operand, expression=exp.null())
    return exp.Paren(this=is_null)
class Dialects(builtins.str, enum.Enum):
# Dialect names; each value is the lowercase key of the matching dialect class.
class Dialects(str, Enum):
    # The empty string stands for the base dialect
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    # NOTE(review): the only member name that is not upper-case
    Doris = "doris"

An enumeration.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Doris = <Dialects.Doris: 'doris'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
class Dialect(metaclass=_Dialect):
    """Holds a SQL dialect's settings together with its tokenizer, parser and generator classes."""

    # Base index offset used for arrays
    INDEX_OFFSET = 0

    # When true, unnest table aliases are treated only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Whether the table alias is placed after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Whether unquoted identifiers are resolved as uppercase
    # None means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Whether an unquoted identifier may begin with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Whether the DPIPE token ('||') acts as a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Whether the arguments of CONCAT are required to be strings
    STRICT_STRING_CONCAT = False

    # Whether user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # How function names get normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Default null ordering method, used when none is set explicitly
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings: each key is a dialect time format
    # and each value is the matching python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # the special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns auto-generated by the engine that corresponds to this dialect
    # Such columns may, for example, be excluded from SELECT * queries
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        """Compare this instance's class with `other`, using the class's own equality."""
        own_class = type(self)
        return own_class == other

    def __hash__(self) -> int:
        """Hash an instance by its class."""
        own_class = type(self)
        return hash(own_class)

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve `dialect` to a dialect class.

        Args:
            dialect: a falsy value, a dialect class, a dialect instance, or a key
                that is looked up with `cls.get`.

        Returns:
            `cls` for a falsy value, the class itself for a dialect class, the
            instance's class for a dialect instance, otherwise the looked-up class.

        Raises:
            ValueError: if the lookup yields nothing.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        dialect_class = cls.get(dialect)
        if dialect_class:
            return dialect_class

        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a time format through `TIME_MAPPING` into a string literal.

        A plain `str` has its first and last character dropped first, a string
        expression contributes its `this`; anything else is returned untouched.
        """
        if isinstance(expression, str):
            # the time formats are quoted
            raw_format = expression[1:-1]
        elif expression and expression.is_string:
            raw_format = expression.this
        else:
            return expression

        return exp.Literal.string(format_time(raw_format, cls.TIME_MAPPING, cls.TIME_TRIE))

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Rewrites an identifier's name in upper or lower case, depending on
        `RESOLVES_IDENTIFIERS_AS_UPPERCASE`. Quoted identifiers are left alone, unless
        that setting is None (all identifiers are case-insensitive); anything that is
        not an identifier is returned as is.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        if expression.quoted and cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is not None:
            return expression

        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE:
            normalized = expression.this.upper()
        else:
            normalized = expression.this.lower()

        expression.set("this", normalized)
        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Tells whether any character of text is case sensitive under the dialect's rules."""
        resolves_as_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if resolves_as_upper is None:
            return False

        # A character is unsafe when its case is the opposite of the resolved case
        is_unsafe = str.islower if resolves_as_upper else str.isupper
        for char in text:
            if is_unsafe(char):
                return True

        return False

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Tells whether text can be identified under the given identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: the answer is always true.
                "safe": true only when the text is not case sensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        return identify == "safe" and not cls.case_sensitive(text)

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Set the `quoted` arg of an identifier and return the expression.

        The arg is truthy when `identify` is truthy, when the name is case sensitive,
        or when the name does not match `exp.SAFE_IDENTIFIER_RE`.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        name = expression.this
        needs_quotes = (
            identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
        )
        expression.set("quoted", needs_quotes)
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse the tokens with a parser built from `opts`."""
        parser = self.parser(**opts)
        return parser.parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse the tokens into `expression_type`."""
        parser = self.parser(**opts)
        return parser.parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generate SQL for `expression` with a generator built from `opts`."""
        generator = self.generator(**opts)
        return generator.generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and generate one string per parsed expression (`opts` go to `generate`)."""
        return [self.generate(tree, **opts) for tree in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize `sql` with this instance's tokenizer."""
        tokenizer = self.tokenizer
        return tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """The instance's tokenizer, created from `tokenizer_class` on first access."""
        try:
            return self._tokenizer
        except AttributeError:
            self._tokenizer = self.tokenizer_class()
            return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Create a new `parser_class` instance from `opts`."""
        parser_class = self.parser_class
        return parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Create a new `generator_class` instance from `opts`."""
        generator_class = self.generator_class
        return generator_class(**opts)
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
RESOLVES_IDENTIFIERS_AS_UPPERCASE: Optional[bool] = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
DPIPE_IS_STRING_CONCAT = True
STRICT_STRING_CONCAT = False
SUPPORTS_USER_DEFINED_TYPES = True
NORMALIZE_FUNCTIONS: bool | str = 'upper'
NULL_ORDERING = 'nulls_are_small'
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}
FORMAT_MAPPING: Dict[str, str] = {}
PSEUDOCOLUMNS: Set[str] = set()
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Type[Dialect]:
202    @classmethod
203    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
204        if not dialect:
205            return cls
206        if isinstance(dialect, _Dialect):
207            return dialect
208        if isinstance(dialect, Dialect):
209            return dialect.__class__
210
211        result = cls.get(dialect)
212        if not result:
213            raise ValueError(f"Unknown dialect '{dialect}'")
214
215        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
217    @classmethod
218    def format_time(
219        cls, expression: t.Optional[str | exp.Expression]
220    ) -> t.Optional[exp.Expression]:
221        if isinstance(expression, str):
222            return exp.Literal.string(
223                # the time formats are quoted
224                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
225            )
226
227        if expression and expression.is_string:
228            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
229
230        return expression
@classmethod
def normalize_identifier(cls, expression: ~E) -> ~E:
232    @classmethod
233    def normalize_identifier(cls, expression: E) -> E:
234        """
235        Normalizes an unquoted identifier to either lower or upper case, thus essentially
236        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
237        they will be normalized regardless of being quoted or not.
238        """
239        if isinstance(expression, exp.Identifier) and (
240            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
241        ):
242            expression.set(
243                "this",
244                expression.this.upper()
245                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
246                else expression.this.lower(),
247            )
248
249        return expression

Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized regardless of being quoted or not.

@classmethod
def case_sensitive(cls, text: str) -> bool:
251    @classmethod
252    def case_sensitive(cls, text: str) -> bool:
253        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
254        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
255            return False
256
257        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
258        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

@classmethod
def can_identify(cls, text: str, identify: str | bool = 'safe') -> bool:
260    @classmethod
261    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
262        """Checks if text can be identified given an identify option.
263
264        Args:
265            text: The text to check.
266            identify:
267                "always" or `True`: Always returns true.
268                "safe": True if the identifier is case-insensitive.
269
270        Returns:
271            Whether or not the given text can be identified.
272        """
273        if identify is True or identify == "always":
274            return True
275
276        if identify == "safe":
277            return not cls.case_sensitive(text)
278
279        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.

@classmethod
def quote_identifier(cls, expression: ~E, identify: bool = True) -> ~E:
281    @classmethod
282    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
283        if isinstance(expression, exp.Identifier):
284            name = expression.this
285            expression.set(
286                "quoted",
287                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
288            )
289
290        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
292    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
293        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
295    def parse_into(
296        self, expression_type: exp.IntoType, sql: str, **opts
297    ) -> t.List[t.Optional[exp.Expression]]:
298        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
300    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
301        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
303    def transpile(self, sql: str, **opts) -> t.List[str]:
304        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
306    def tokenize(self, sql: str) -> t.List[Token]:
307        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
315    def parser(self, **opts) -> Parser:
316        return self.parser_class(**opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
318    def generator(self, **opts) -> Generator:
319        return self.generator_class(**opts)
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START = None
BIT_END = None
HEX_START = None
HEX_END = None
BYTE_START = None
BYTE_END = None
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator callback that renders an expression as a call to `name`.

    The call's arguments are the flattened values of the expression's `args`.
    """

    def _renamed(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _renamed
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT over `this`; a set `accuracy` arg is reported as unsupported."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.If) -> str:
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render an `If` expression as an IF call over its `this`, `true` and `false` args."""
    first = expression.this
    when_true = expression.args.get("true")
    when_false = expression.args.get("false")
    return self.func("IF", first, when_true, when_false)
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render the extraction as a binary expression that uses the `->` operator."""
    arrow = "->"
    return self.binary(expression, arrow)
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render the scalar extraction as a binary expression that uses the `->>` operator."""
    double_arrow = "->>"
    return self.binary(expression, double_arrow)
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array as its flat expression list wrapped in square brackets."""
    items = self.expressions(expression, flat=True)
    return f"[{items}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Render ILIKE as LIKE whose left operand is wrapped in LOWER.

    Both operands are copied; the right operand is passed through unchanged.
    """
    lowered = exp.Lower(this=expression.this.copy())
    pattern = expression.expression.copy()
    return self.like_sql(exp.Like(this=lowered, expression=pattern))
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, adding AT TIME ZONE when `this` renders to SQL."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"

    return f"CURRENT_DATE AT TIME ZONE {zone}"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause without the recursive flag.

    When the `recursive` arg is set, this is reported as unsupported and the
    clause is rendered with the flag turned off.

    Args:
        self: the generator doing the rendering.
        expression: the WITH expression; it is never modified.

    Returns:
        The SQL produced by the generator's `with_sql`.
    """
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Clear the flag on a copy, so that generating SQL does not silently
        # strip `recursive` from the caller's expression.
        expression = expression.copy()
        expression.set("recursive", False)

    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Render a safe division as an IF that yields NULL unless the divisor is <> 0."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Report TABLESAMPLE as unsupported and render only the expression's `this`."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report PIVOT as unsupported; the pivot itself renders as an empty string."""
    self.unsupported("PIVOT unsupported")
    nothing = ""
    return nothing
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render a TryCast through the generator's `cast_sql`."""
    rendered = self.cast_sql(expression)
    return rendered
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Report properties as unsupported; they render as an empty string."""
    self.unsupported("Properties unsupported")
    nothing = ""
    return nothing
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Report the comment column constraint as unsupported; it renders as an empty string."""
    self.unsupported("CommentColumnConstraint unsupported")
    nothing = ""
    return nothing
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Report MAP_FROM_ENTRIES as unsupported; it renders as an empty string."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    nothing = ""
    return nothing
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a string position lookup with STRPOS.

    When a `position` is given, the search runs on SUBSTR(this, position) and
    `position - 1` is added to the STRPOS result.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"

    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render a struct field access as `<this>.<field>`, the field being an identifier."""
    struct_sql = self.sql(expression, "this")
    field_sql = self.sql(exp.to_identifier(expression.expression.name))
    return f"{struct_sql}.{field_sql}"
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as a call to `map_func_name` with interleaved key/value arguments.

    If the keys or the values are not an `exp.Array`, this is reported as unsupported
    and the two are passed to the function as they are.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    both_arrays = isinstance(keys, exp.Array) and isinstance(values, exp.Array)
    if not both_arrays:
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # key1, value1, key2, value2, ...
    interleaved = [
        self.sql(node)
        for pair in zip(keys.expressions, values.expressions)
        for node in pair
    ]

    return self.func(map_func_name, *interleaved)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Creates a builder for time expressions.

    Args:
        exp_class: the expression class that gets instantiated.
        dialect: target sql dialect.
        default: the fallback format, where True stands for the dialect's time format.

    Returns:
        A callable which builds the time expression with an appropriately formatted format.
    """

    def _format_time(args: t.List):
        this = seq_get(args, 0)
        dialect_class = Dialect[dialect]

        # The second argument wins; otherwise fall back to the default
        time_format = seq_get(args, 1)
        if not time_format:
            time_format = dialect_class.TIME_FORMAT if default is True else default or None

        return exp_class(this=this, format=dialect_class.format_time(time_format))

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a callback that yields an expression's time format for `dialect`."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Gives back the expression's time format, or None when it is the same
        as the default time format of the dialect of interest.
        """
        fmt = self.format_time(expression)
        default_format = Dialect.get_or_raise(dialect).TIME_FORMAT
        if fmt != default_format:
            return fmt

        return None

    return _time_format
def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    The PARTITIONED BY property of Hive and Spark extends a table's schema. If its value is an
    array of column names, those names are turned into a schema and the matching columns are
    taken out of the create statement. The work happens on a copy of the expression.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
    if not (has_schema and is_partitionable):
        return self.create_sql(expression)

    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)

    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        schema = expression.this
        partition_names = {name.name.upper() for name in prop.this.expressions}
        partitions = [
            column for column in schema.expressions if column.name.upper() in partition_names
        ]
        remaining = [column for column in schema.expressions if column not in partitions]

        schema.set("expressions", remaining)
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser callback for date delta functions.

    With exactly three arguments they are read as (unit, expression, this); otherwise
    as (this, expression) with a "DAY" string literal as the unit. A truthy
    `unit_mapping` translates the lowercased unit name, keeping unmapped names.
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            unit = args[0]
            this = args[2]
        else:
            this = seq_get(args, 0)
            unit = exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser callback for date delta functions whose second argument is an interval.

    The callback returns None for fewer than two arguments and raises a ParseError
    when the second argument is not an `exp.Interval`.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        this, interval = args[0], args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        # A string amount is turned into a number literal
        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=this, expression=amount, unit=unit)

    return func
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Build a truncation from (unit, this) arguments.

    A `this` that is a cast to the "date" type gives a DateTrunc, anything else
    a TimestampTrunc.
    """
    unit = seq_get(args, 0)
    target = seq_get(args, 1)

    casts_to_date = isinstance(target, exp.Cast) and target.is_type("date")
    if casts_to_date:
        return exp.DateTrunc(unit=unit, this=target)

    return exp.TimestampTrunc(this=target, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator callback that renders `<data_type>_<kind>(this, <interval>)`.

    The interval combines a copy of the expression's `expression` with its upper-cased
    unit name, "DAY" being used when there is no unit.
    """

    def func(self: Generator, expression: exp.Expression) -> str:
        this_sql = self.sql(expression, "this")

        unit_arg = expression.args.get("unit")
        unit_name = unit_arg.name.upper() if unit_arg else "DAY"
        interval = exp.Interval(this=expression.expression.copy(), unit=exp.var(unit_name))

        interval_sql = self.sql(interval)
        return f"{data_type}_{kind}({this_sql}, {interval_sql})"

    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render a timestamp truncation as DATE_TRUNC(<unit literal>, this), the unit defaulting to "day"."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Build a StrPosition from arguments given in (substr, this, position) order."""
    substr = seq_get(args, 0)
    this = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=this, substr=substr, position=position)
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a StrPosition as LOCATE(substr, this, position)."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Render LEFT as a Substring that starts at 1, with `expression` as its length.

    The substring is built from a copy of the expression.
    """
    copied = expression.copy()
    substring = exp.Substring(
        this=copied.this,
        start=exp.Literal.number(1),
        length=copied.expression,
    )
    return self.sql(substring)
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Render RIGHT as a Substring that starts at `LENGTH(this) - (expression - 1)`.

    Args:
        self: the generator doing the rendering.
        expression: the RIGHT expression; `this` is the operand and `expression`
            the operand of the subtraction in the start position.

    Returns:
        The SQL of the Substring, which has no explicit length.
    """
    expression = expression.copy()
    operand = expression.this
    return self.sql(
        exp.Substring(
            this=operand,
            # LENGTH gets its own copy of the operand, so that one node is not
            # attached to two parents (the Substring and the Length) at once.
            start=exp.Length(this=operand.copy()) - exp.paren(expression.expression - 1),
        )
    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render the expression's `this` as a cast to "timestamp"."""
    cast = exp.cast(expression.this, "timestamp")
    return self.sql(cast)
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render the expression's `this` as a cast to "date"."""
    cast = exp.cast(expression.this, "date")
    return self.sql(cast)
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an encode/decode expression as a call to `name`.

    A charset whose lowercased name is not "utf-8" is reported as unsupported. The
    expression's `replace` arg is passed to the call only when `replace` is truthy.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render as LEAST when the expression has `expressions`, otherwise as MIN."""
    if expression.expressions:
        render = rename_func("LEAST")
    else:
        render = rename_func("MIN")

    return render(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render as GREATEST when the expression has `expressions`, otherwise as MAX."""
    if expression.expressions:
        render = rename_func("GREATEST")
    else:
        render = rename_func("MAX")

    return render(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render COUNT_IF as a sum over if(<condition>, 1, 0).

    For a DISTINCT argument, its first expression serves as the condition and
    DISTINCT is reported as unsupported.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", condition.copy(), 1, 0)
    return self.func("sum", one_or_zero)
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM in the `TRIM([position ][chars ][FROM ]target[ COLLATE collation])` form.

    The generator's own `trim_sql` is used when there are neither characters to
    remove nor a collation.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (remove_chars or collation):
        return self.trim_sql(expression)

    pieces = []
    if trim_type:
        pieces.append(f"{trim_type} ")
    if remove_chars:
        pieces.append(f"{remove_chars} ")
    if pieces:
        # FROM is only needed after a position and/or the characters to remove
        pieces.append("FROM ")

    pieces.append(target)
    if collation:
        pieces.append(f" COLLATE {collation}")

    inner = "".join(pieces)
    return f"TRIM({inner})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render the expression as STRPTIME(this, <format>), the format coming from the generator."""
    time_format = self.format_time(expression)
    return self.func("STRPTIME", expression.this, time_format)
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator function that renders ``exp.TsOrDsToDate`` as a cast to ``date``.

    The returned function resolves ``dialect`` through ``Dialect.get_or_raise``
    on every call. If the expression carries a time format that differs from
    both the dialect's ``TIME_FORMAT`` and ``DATE_FORMAT``, the value is first
    rendered with ``str_to_time_sql`` and then cast; otherwise the ``this``
    argument is cast directly.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        dialect_cls = Dialect.get_or_raise(dialect)
        fmt = self.format_time(expression)

        if fmt and fmt not in (dialect_cls.TIME_FORMAT, dialect_cls.DATE_FORMAT):
            # Non-default format: render via STRPTIME before casting.
            source = str_to_time_sql(self, expression)
        else:
            source = self.sql(expression, "this")

        return self.sql(exp.cast(source, "date"))

    return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Render a concat expression as a left-nested chain of ``exp.DPipe`` nodes.

    The expression is copied first, then its ``expressions`` are folded
    pairwise from the left into ``DPipe`` nodes and the result is rendered.
    """
    operands = expression.copy().expressions

    def _pipe(left, right):
        return exp.DPipe(this=left, expression=right)

    return self.sql(reduce(_pipe, operands))
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Render a ``ConcatWs`` expression as a chain of ``exp.DPipe`` nodes.

    The first entry of ``expression.expressions`` is taken as the delimiter;
    the remaining entries are folded from the left so that each step yields
    ``DPipe(accumulated, DPipe(delimiter, next))``.
    """
    separator, *operands = expression.copy().expressions

    def _join(left, right):
        return exp.DPipe(this=left, expression=exp.DPipe(this=separator, expression=right))

    return self.sql(reduce(_join, operands))
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render ``REGEXP_EXTRACT(<this>, <expression>[, <group>])``.

    If any of the ``position``, ``occurrence`` or ``parameters`` args is set
    (truthy), ``self.unsupported`` is called with the list of offending names;
    those args are not included in the generated call.
    """
    unsupported_args = [
        name
        for name in ("position", "occurrence", "parameters")
        if expression.args.get(name)
    ]
    if unsupported_args:
        self.unsupported(
            f"REGEXP_EXTRACT does not support the following arg(s): {unsupported_args}"
        )

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render ``REGEXP_REPLACE(<this>, <expression>, <replacement>)``.

    If any of the ``position``, ``occurrence`` or ``parameters`` args is set
    (truthy), ``self.unsupported`` is called with the list of offending names;
    those args are not included in the generated call. The ``replacement``
    arg is read with a plain key lookup, so it must be present in
    ``expression.args``.
    """
    unsupported_args = [
        name
        for name in ("position", "occurrence", "parameters")
        if expression.args.get(name)
    ]
    if unsupported_args:
        self.unsupported(
            f"REGEXP_REPLACE does not support the following arg(s): {unsupported_args}"
        )

    replacement = expression.args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Return one column name per pivot aggregation.

    Aliased aggregations contribute their alias. Any other aggregation
    contributes its SQL text in ``dialect`` with function names lower-cased
    and every identifier unquoted.
    """

    def _unquote(node):
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    column_names = []
    for aggregation in aggregations:
        if isinstance(aggregation, exp.Alias):
            column_names.append(aggregation.alias)
            continue

        # This case corresponds to aggregations without aliases being used as
        # suffixes (e.g. col_avg(foo)). Identifiers are unquoted here because
        # they're going to be quoted in the base parser's `_parse_pivot`
        # method, due to `to_identifier`. Otherwise we'd end up with
        # `col_avg(`foo`)` (notice the doubled quoting).
        unquoted = aggregation.transform(_unquote)
        column_names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return column_names
def simplify_literal(expression: E) -> E:
    """Run the optimizer's ``simplify`` on ``expression.expression`` unless it is a literal.

    The same ``expression`` object that was passed in is always returned.
    """
    if isinstance(expression.expression, exp.Literal):
        return expression

    # NOTE(review): the import is function-local; presumably this avoids a
    # circular import with the optimizer package -- confirm.
    from sqlglot.optimizer.simplify import simplify

    # NOTE(review): the return value is discarded, so this relies on
    # `simplify` updating the tree in place -- confirm against its contract.
    simplify(expression.expression)
    return expression
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Return a builder that turns an argument list into an ``expr_type`` node.

    The builder passes the first argument as ``this`` and the second as
    ``expression``; each is fetched with ``seq_get``.
    """

    def _build(args: t.List) -> B:
        left = seq_get(args, 0)
        right = seq_get(args, 1)
        return expr_type(this=left, expression=right)

    return _build
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Build an ``exp.TimestampTrunc`` from a ``[unit, value]`` argument list.

    The first argument becomes ``unit`` and the second becomes ``this``;
    both are fetched with ``seq_get``.
    """
    unit = seq_get(args, 0)
    operand = seq_get(args, 1)
    return exp.TimestampTrunc(this=operand, unit=unit)
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render an ``AnyValue`` expression as ``MAX(<this>)``."""
    operand = expression.this
    return self.func("MAX", operand)
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Render the ``this`` and ``expression`` args separated by ``", "``."""
    left = self.sql(expression, "this")
    right = self.sql(expression, "expression")
    return f"{left}, {right}"
def is_parse_json(expression: exp.Expression) -> bool:
    """Tell whether ``expression`` is an ``exp.ParseJSON`` or an ``exp.Cast`` of type ``json``."""
    if isinstance(expression, exp.ParseJSON):
        return True

    return isinstance(expression, exp.Cast) and expression.is_type("json")
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Build a parenthesised ``<arg> IS NULL`` expression from the first argument.

    The argument is fetched with ``seq_get(args, 0)``.
    """
    operand = seq_get(args, 0)
    is_null = exp.Is(this=operand, expression=exp.null())
    return exp.Paren(this=is_null)