sqlglot.transforms
# Expression-tree transforms used when generating SQL for dialects that lack a given construct.
# Every transform in this module takes an expression node and returns an expression node;
# `preprocess` chains such transforms into a generator-level handler.
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
    from sqlglot.generator import Generator


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        # Alias name -> 1-based position of the aliased projection in the SELECT list.
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            # Only unqualified columns whose name matches a select alias are rewritten.
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        # Detach the DISTINCT node from the query and keep its ON columns for partitioning.
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        # The query's ORDER BY (if any) moves into the window; otherwise order by the ON columns.
        if order:
            window.set("order", order.pop().copy())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection an alias so the outer query can select it by name.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # Built before any helper projections are appended below, so only the original
        # projections are exposed by the outer query.
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        qualify_filters = expression.args["qualify"].pop().this

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                if isinstance(expr.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """Convert cross join unnest into lateral view explode (used in presto -> hive)."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                # NOTE(review): this removes from the same list the enclosing `for` iterates over.
                expression.args["joins"].remove(join)

                # One lateral per (unnested expression, alias column) pair; without an alias the
                # zip is empty and no lateral is added.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression


def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """Convert explode/posexplode into unnest (used in hive -> presto)."""
    if isinstance(expression, exp.Select):
        from sqlglot.optimizer.scope import Scope

        taken_select_names = set(expression.named_selects)
        taken_source_names = {name for name, _ in Scope(expression).references}

        # NOTE(review): the loop body pops from / appends to the select list it is iterating.
        for select in expression.selects:
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            # Unwrap `<udtf> AS alias` / `<udtf> AS (pos, col)` and remember the given names.
            if isinstance(select, exp.Alias):
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                    if is_posexplode:
                        pos_alias = find_new_name(taken_select_names, "pos")
                        taken_select_names.add(pos_alias)

                if is_posexplode:
                    column_names = [explode_alias, pos_alias]
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                # The UNNEST becomes the FROM source if there is none, otherwise a CROSS JOIN.
                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace a WithinGroup node wrapping PercentileCont/PercentileDisc with an ApproxQuantile node.

    The percentile's argument becomes the quantile and the first Ordered node found under the
    WithinGroup supplies the input value. Any other expression is returned unchanged.
    """
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Add an explicit column list to every CTE of a recursive WITH that does not have one.

    Column names come from the projections of the CTE's query (for a Union, of its `this`
    operand); projections without a name get one from `name_sequence("_c_")`.
    """
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                if isinstance(query, exp.Union):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression


def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the operand of a Cast/TryCast named "epoch" (case-insensitive) with the string
    literal '1970-01-01 00:00:00' when the target type is in `exp.DataType.TEMPORAL_TYPES`.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # The first transform receives a copy, so the caller's expression is left untouched.
        expression = transforms[0](expression.copy())
        # NOTE(review): the loop variable `t` shadows the module-level `typing as t` alias here.
        for t in transforms[1:]:
            expression = t(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap GROUP BY references to select aliases for the position of the aliased projection.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select)):
        return expression

    # Alias name -> 1-based position of the aliased projection in the SELECT list.
    positions = {}
    for ordinal, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = ordinal

    for term in expression.expressions:
        # Only unqualified column references can point at a select alias.
        if not isinstance(term, exp.Column) or term.table:
            continue

        ordinal = positions.get(term.name)
        if ordinal is not None:
            term.replace(exp.Literal.number(ordinal))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON as a subquery filtered on a ROW_NUMBER() window.

    The inner query gains a row-number projection partitioned by the DISTINCT ON columns; the
    outer query keeps the original projections and filters on that row number being 1. This is
    useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    # Detach DISTINCT from the query; its ON columns become the window partition.
    distinct.pop()
    partition_cols = on.expressions

    projections = expression.selects
    rn_alias = find_new_name(expression.named_selects, "_row_number")
    ranked = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    # The query's own ORDER BY (if any) moves into the window; otherwise order by the ON columns.
    ordering = expression.args.get("order")
    window_order = (
        ordering.pop().copy()
        if ordering
        else exp.Order(expressions=[col.copy() for col in partition_cols])
    )
    ranked.set("order", window_order)

    expression.select(exp.alias_(ranked, rn_alias), copy=False)

    outer = exp.select(*projections)
    outer = outer.from_(expression.subquery())
    return outer.where(f'"{rn_alias}" = 1')
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a SELECT with a QUALIFY clause as an equivalent filtered subquery.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Window functions used in QUALIFY are added to the inner query as aliased projections, since
    some dialects don't allow them in WHERE; the outer filter then refers to those aliases.
    Columns that QUALIFY references but the query doesn't select are added to the inner query as
    well, otherwise the outer WHERE clause couldn't refer to them.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select) or not expression.args.get("qualify"):
        return expression

    # Every projection needs a name so that the outer query can select it.
    used_names = set(expression.named_selects)
    for projection in expression.selects:
        if projection.alias_or_name:
            continue

        generated = find_new_name(used_names, "_c")
        projection.replace(exp.alias_(projection, generated))
        used_names.add(generated)

    # Built before helper projections are appended, so only the original ones are exposed.
    outer_query = exp.select(*[projection.alias_or_name for projection in expression.selects])
    condition = expression.args["qualify"].pop().this

    wanted = exp.Window if expression.is_star else (exp.Window, exp.Column)
    for node in condition.find_all(wanted):
        if not isinstance(node, exp.Window):
            # A plain column: make sure the inner query projects it.
            if node.name not in expression.named_selects:
                expression.select(node.copy(), copy=False)
            continue

        window_alias = find_new_name(expression.named_selects, "_w")
        expression.select(exp.alias_(node, window_alias), copy=False)
        reference = exp.column(window_alias)

        if isinstance(node.parent, exp.Qualify):
            condition = reference
        else:
            node.replace(reference)

    return outer_query.from_(expression.subquery(alias="_t")).where(condition)
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip the parameters (e.g. precision) from every parameterized data type in the expression.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for param in data_type.expressions:
            if not isinstance(param, exp.DataTypeParam):
                kept.append(param)

        data_type.set("expressions", kept)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an Unnest is removed from the query's joins and
    replaced by one Lateral (with `view=True`) per (unnested expression, alias column) pair.
    Posexplode is used when the Unnest has `ordinality` set, Explode otherwise. An Unnest
    without an alias is removed without adding any lateral.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the live list inside the loop,
        # and removing from a list while iterating it would skip the join that follows.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode (used in presto -> hive).
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each EXPLODE / POSEXPLODE projection of a SELECT is replaced by column references to an
    aliased UNNEST source, which becomes the FROM clause if the query has none and is added as a
    CROSS JOIN otherwise. Missing column aliases are generated ("col" / "pos" based), as is the
    alias of the UNNEST source ("_u" based).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        from sqlglot.optimizer.scope import Scope

        taken_select_names = set(expression.named_selects)
        taken_source_names = {name for name, _ in Scope(expression).references}

        # Iterate over a snapshot: the POSEXPLODE branch pops from and appends to the select
        # list, which would make a live iteration skip the projection after the popped one.
        for select in list(expression.selects):
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            # Unwrap `<udtf> AS alias` / `<udtf> AS (pos, col)` and remember the given names.
            if isinstance(select, exp.Alias):
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                # Checked independently of the value alias: a POSEXPLODE given a single alias
                # would otherwise be left with an empty position column name.
                if is_posexplode and not pos_alias:
                    pos_alias = find_new_name(taken_select_names, "pos")
                    taken_select_names.add(pos_alias)

                if is_posexplode:
                    column_names = [explode_alias, pos_alias]
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression
Convert explode/posexplode into unnest (used in hive -> presto).
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace a WithinGroup over PercentileCont/PercentileDisc with an ApproxQuantile node.

    The percentile's argument becomes the quantile, and the first Ordered node found under the
    WithinGroup supplies the input value.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The replacement ApproxQuantile node, or the expression unchanged if it doesn't match.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    percentile = expression.this
    first_ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx = exp.ApproxQuantile(this=first_ordered.this, quantile=percentile.this)
    return expression.replace(approx)
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH an explicit column list if it doesn't have one.

    Column names are taken from the projections of the CTE's query (for a Union, of its `this`
    operand); unnamed projections get a name from `name_sequence("_c_")`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.With) or not expression.recursive:
        return expression

    fresh_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        body = cte.this
        projections = (body.this if isinstance(body, exp.Union) else body).selects

        column_names = []
        for projection in projections:
            column_names.append(exp.to_identifier(projection.alias_or_name or fresh_name()))

        table_alias.set("columns", column_names)

    return expression
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Turn a cast of "epoch" to a temporal type into a cast of '1970-01-01 00:00:00'.

    Applies to Cast/TryCast nodes whose name is "epoch" (case-insensitive) and whose target type
    is in `exp.DataType.TEMPORAL_TYPES`; the cast operand is replaced in place.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    if expression.name.lower() != "epoch":
        return expression

    if expression.to.this not in exp.DataType.TEMPORAL_TYPES:
        return expression

    expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the (copied) expression is generated as is.

    Returns:
        Function that can be used as a generator transform. When called it raises a ValueError if
        the transformed expression has neither a "_sql" method nor a `TRANSFORMS` entry, or if
        only a `TRANSFORMS` entry exists and the expression's type was left unchanged.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Transforms may mutate the tree, so they operate on a copy of the caller's expression.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate Generator.TRANSFORMS
function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.