sqlglot.optimizer.qualify_columns
1from __future__ import annotations 2 3import itertools 4import typing as t 5 6from sqlglot import alias, exp 7from sqlglot.dialects.dialect import Dialect, DialectType 8from sqlglot.errors import OptimizeError 9from sqlglot.helper import seq_get, SingleValuedMapping 10from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope 11from sqlglot.optimizer.simplify import simplify_parens 12from sqlglot.schema import Schema, ensure_schema 13 14if t.TYPE_CHECKING: 15 from sqlglot._typing import E 16 17 18def qualify_columns( 19 expression: exp.Expression, 20 schema: t.Dict | Schema, 21 expand_alias_refs: bool = True, 22 expand_stars: bool = True, 23 infer_schema: t.Optional[bool] = None, 24) -> exp.Expression: 25 """ 26 Rewrite sqlglot AST to have fully qualified columns. 27 28 Example: 29 >>> import sqlglot 30 >>> schema = {"tbl": {"col": "INT"}} 31 >>> expression = sqlglot.parse_one("SELECT col FROM tbl") 32 >>> qualify_columns(expression, schema).sql() 33 'SELECT tbl.col AS col FROM tbl' 34 35 Args: 36 expression: Expression to qualify. 37 schema: Database schema. 38 expand_alias_refs: Whether to expand references to aliases. 39 expand_stars: Whether to expand star queries. This is a necessary step 40 for most of the optimizer's rules to work; do not set to False unless you 41 know what you're doing! 42 infer_schema: Whether to infer the schema if missing. 43 44 Returns: 45 The qualified expression. 
46 47 Notes: 48 - Currently only handles a single PIVOT or UNPIVOT operator 49 """ 50 schema = ensure_schema(schema) 51 infer_schema = schema.empty if infer_schema is None else infer_schema 52 pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS 53 54 for scope in traverse_scope(expression): 55 resolver = Resolver(scope, schema, infer_schema=infer_schema) 56 _pop_table_column_aliases(scope.ctes) 57 _pop_table_column_aliases(scope.derived_tables) 58 using_column_tables = _expand_using(scope, resolver) 59 60 if schema.empty and expand_alias_refs: 61 _expand_alias_refs(scope, resolver) 62 63 _qualify_columns(scope, resolver) 64 65 if not schema.empty and expand_alias_refs: 66 _expand_alias_refs(scope, resolver) 67 68 if not isinstance(scope.expression, exp.UDTF): 69 if expand_stars: 70 _expand_stars(scope, resolver, using_column_tables, pseudocolumns) 71 qualify_outputs(scope) 72 73 _expand_group_by(scope) 74 _expand_order_by(scope, resolver) 75 76 return expression 77 78 79def validate_qualify_columns(expression: E) -> E: 80 """Raise an `OptimizeError` if any columns aren't qualified""" 81 all_unqualified_columns = [] 82 for scope in traverse_scope(expression): 83 if isinstance(scope.expression, exp.Select): 84 unqualified_columns = scope.unqualified_columns 85 86 if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: 87 column = scope.external_columns[0] 88 for_table = f" for table: '{column.table}'" if column.table else "" 89 raise OptimizeError(f"Column '{column}' could not be resolved{for_table}") 90 91 if unqualified_columns and scope.pivots and scope.pivots[0].unpivot: 92 # New columns produced by the UNPIVOT can't be qualified, but there may be columns 93 # under the UNPIVOT's IN clause that can and should be qualified. We recompute 94 # this list here to ensure those in the former category will be excluded. 
95 unpivot_columns = set(_unpivot_columns(scope.pivots[0])) 96 unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns] 97 98 all_unqualified_columns.extend(unqualified_columns) 99 100 if all_unqualified_columns: 101 raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}") 102 103 return expression 104 105 106def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]: 107 name_column = [] 108 field = unpivot.args.get("field") 109 if isinstance(field, exp.In) and isinstance(field.this, exp.Column): 110 name_column.append(field.this) 111 112 value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column)) 113 return itertools.chain(name_column, value_columns) 114 115 116def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None: 117 """ 118 Remove table column aliases. 119 120 For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2) 121 """ 122 for derived_table in derived_tables: 123 if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive: 124 continue 125 table_alias = derived_table.args.get("alias") 126 if table_alias: 127 table_alias.args.pop("columns", None) 128 129 130def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]: 131 joins = list(scope.find_all(exp.Join)) 132 names = {join.alias_or_name for join in joins} 133 ordered = [key for key in scope.selected_sources if key not in names] 134 135 # Mapping of automatically joined column names to an ordered set of source names (dict). 
136 column_tables: t.Dict[str, t.Dict[str, t.Any]] = {} 137 138 for join in joins: 139 using = join.args.get("using") 140 141 if not using: 142 continue 143 144 join_table = join.alias_or_name 145 146 columns = {} 147 148 for source_name in scope.selected_sources: 149 if source_name in ordered: 150 for column_name in resolver.get_source_columns(source_name): 151 if column_name not in columns: 152 columns[column_name] = source_name 153 154 source_table = ordered[-1] 155 ordered.append(join_table) 156 join_columns = resolver.get_source_columns(join_table) 157 conditions = [] 158 159 for identifier in using: 160 identifier = identifier.name 161 table = columns.get(identifier) 162 163 if not table or identifier not in join_columns: 164 if (columns and "*" not in columns) and join_columns: 165 raise OptimizeError(f"Cannot automatically join: {identifier}") 166 167 table = table or source_table 168 conditions.append( 169 exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table)) 170 ) 171 172 # Set all values in the dict to None, because we only care about the key ordering 173 tables = column_tables.setdefault(identifier, {}) 174 if table not in tables: 175 tables[table] = None 176 if join_table not in tables: 177 tables[join_table] = None 178 179 join.args.pop("using") 180 join.set("on", exp.and_(*conditions, copy=False)) 181 182 if column_tables: 183 for column in scope.columns: 184 if not column.table and column.name in column_tables: 185 tables = column_tables[column.name] 186 coalesce = [exp.column(column.name, table=table) for table in tables] 187 replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]) 188 189 # Ensure selects keep their output name 190 if isinstance(column.parent, exp.Select): 191 replacement = alias(replacement, alias=column.name, copy=False) 192 193 scope.replace(column, replacement) 194 195 return column_tables 196 197 198def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: 199 expression = 
scope.expression 200 201 if not isinstance(expression, exp.Select): 202 return 203 204 alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {} 205 206 def replace_columns( 207 node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False 208 ) -> None: 209 if not node: 210 return 211 212 for column in walk_in_scope(node, prune=lambda node: node.is_star): 213 if not isinstance(column, exp.Column): 214 continue 215 216 table = resolver.get_table(column.name) if resolve_table and not column.table else None 217 alias_expr, i = alias_to_expression.get(column.name, (None, 1)) 218 double_agg = ( 219 ( 220 alias_expr.find(exp.AggFunc) 221 and ( 222 column.find_ancestor(exp.AggFunc) 223 and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window) 224 ) 225 ) 226 if alias_expr 227 else False 228 ) 229 230 if table and (not alias_expr or double_agg): 231 column.set("table", table) 232 elif not column.table and alias_expr and not double_agg: 233 if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table): 234 if literal_index: 235 column.replace(exp.Literal.number(i)) 236 else: 237 column = column.replace(exp.paren(alias_expr)) 238 simplified = simplify_parens(column) 239 if simplified is not column: 240 column.replace(simplified) 241 242 for i, projection in enumerate(scope.expression.selects): 243 replace_columns(projection) 244 245 if isinstance(projection, exp.Alias): 246 alias_to_expression[projection.alias] = (projection.this, i + 1) 247 248 replace_columns(expression.args.get("where")) 249 replace_columns(expression.args.get("group"), literal_index=True) 250 replace_columns(expression.args.get("having"), resolve_table=True) 251 replace_columns(expression.args.get("qualify"), resolve_table=True) 252 253 scope.clear_cache() 254 255 256def _expand_group_by(scope: Scope) -> None: 257 expression = scope.expression 258 group = expression.args.get("group") 259 if not group: 260 return 261 262 
group.set("expressions", _expand_positional_references(scope, group.expressions)) 263 expression.set("group", group) 264 265 266def _expand_order_by(scope: Scope, resolver: Resolver) -> None: 267 order = scope.expression.args.get("order") 268 if not order: 269 return 270 271 ordereds = order.expressions 272 for ordered, new_expression in zip( 273 ordereds, 274 _expand_positional_references(scope, (o.this for o in ordereds), alias=True), 275 ): 276 for agg in ordered.find_all(exp.AggFunc): 277 for col in agg.find_all(exp.Column): 278 if not col.table: 279 col.set("table", resolver.get_table(col.name)) 280 281 ordered.set("this", new_expression) 282 283 if scope.expression.args.get("group"): 284 selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects} 285 286 for ordered in ordereds: 287 ordered = ordered.this 288 289 ordered.replace( 290 exp.to_identifier(_select_by_pos(scope, ordered).alias) 291 if ordered.is_int 292 else selects.get(ordered, ordered) 293 ) 294 295 296def _expand_positional_references( 297 scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False 298) -> t.List[exp.Expression]: 299 new_nodes: t.List[exp.Expression] = [] 300 for node in expressions: 301 if node.is_int: 302 select = _select_by_pos(scope, t.cast(exp.Literal, node)) 303 304 if alias: 305 new_nodes.append(exp.column(select.args["alias"].copy())) 306 else: 307 select = select.this 308 309 if isinstance(select, exp.CONSTANTS) or select.find(exp.Explode, exp.Unnest): 310 new_nodes.append(node) 311 else: 312 new_nodes.append(select.copy()) 313 else: 314 new_nodes.append(node) 315 316 return new_nodes 317 318 319def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias: 320 try: 321 return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias) 322 except IndexError: 323 raise OptimizeError(f"Unknown output column: {node.name}") 324 325 326def _qualify_columns(scope: Scope, resolver: Resolver) -> None: 327 """Disambiguate columns, 
ensuring each column specifies a source""" 328 for column in scope.columns: 329 column_table = column.table 330 column_name = column.name 331 332 if column_table and column_table in scope.sources: 333 source_columns = resolver.get_source_columns(column_table) 334 if source_columns and column_name not in source_columns and "*" not in source_columns: 335 raise OptimizeError(f"Unknown column: {column_name}") 336 337 if not column_table: 338 if scope.pivots and not column.find_ancestor(exp.Pivot): 339 # If the column is under the Pivot expression, we need to qualify it 340 # using the name of the pivoted source instead of the pivot's alias 341 column.set("table", exp.to_identifier(scope.pivots[0].alias)) 342 continue 343 344 column_table = resolver.get_table(column_name) 345 346 # column_table can be a '' because bigquery unnest has no table alias 347 if column_table: 348 column.set("table", column_table) 349 elif column_table not in scope.sources and ( 350 not scope.parent 351 or column_table not in scope.parent.sources 352 or not scope.is_correlated_subquery 353 ): 354 # structs are used like tables (e.g. 
"struct"."field"), so they need to be qualified 355 # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...)) 356 357 root, *parts = column.parts 358 359 if root.name in scope.sources: 360 # struct is already qualified, but we still need to change the AST representation 361 column_table = root 362 root, *parts = parts 363 else: 364 column_table = resolver.get_table(root.name) 365 366 if column_table: 367 column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts])) 368 369 for pivot in scope.pivots: 370 for column in pivot.find_all(exp.Column): 371 if not column.table and column.name in resolver.all_columns: 372 column_table = resolver.get_table(column.name) 373 if column_table: 374 column.set("table", column_table) 375 376 377def _expand_stars( 378 scope: Scope, 379 resolver: Resolver, 380 using_column_tables: t.Dict[str, t.Any], 381 pseudocolumns: t.Set[str], 382) -> None: 383 """Expand stars to lists of column selections""" 384 385 new_selections = [] 386 except_columns: t.Dict[int, t.Set[str]] = {} 387 replace_columns: t.Dict[int, t.Dict[str, str]] = {} 388 coalesced_columns = set() 389 390 pivot_output_columns = None 391 pivot_exclude_columns = None 392 393 pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0)) 394 if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names: 395 if pivot.unpivot: 396 pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)] 397 398 field = pivot.args.get("field") 399 if isinstance(field, exp.In): 400 pivot_exclude_columns = { 401 c.output_name for e in field.expressions for c in e.find_all(exp.Column) 402 } 403 else: 404 pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column)) 405 406 pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])] 407 if not pivot_output_columns: 408 pivot_output_columns = [c.alias_or_name for c in pivot.expressions] 409 410 for expression in scope.expression.selects: 411 if 
isinstance(expression, exp.Star): 412 tables = list(scope.selected_sources) 413 _add_except_columns(expression, tables, except_columns) 414 _add_replace_columns(expression, tables, replace_columns) 415 elif expression.is_star and not isinstance(expression, exp.Dot): 416 tables = [expression.table] 417 _add_except_columns(expression.this, tables, except_columns) 418 _add_replace_columns(expression.this, tables, replace_columns) 419 else: 420 new_selections.append(expression) 421 continue 422 423 for table in tables: 424 if table not in scope.sources: 425 raise OptimizeError(f"Unknown table: {table}") 426 427 columns = resolver.get_source_columns(table, only_visible=True) 428 columns = columns or scope.outer_columns 429 430 if pseudocolumns: 431 columns = [name for name in columns if name.upper() not in pseudocolumns] 432 433 if not columns or "*" in columns: 434 return 435 436 table_id = id(table) 437 columns_to_exclude = except_columns.get(table_id) or set() 438 439 if pivot: 440 if pivot_output_columns and pivot_exclude_columns: 441 pivot_columns = [c for c in columns if c not in pivot_exclude_columns] 442 pivot_columns.extend(pivot_output_columns) 443 else: 444 pivot_columns = pivot.alias_column_names 445 446 if pivot_columns: 447 new_selections.extend( 448 alias(exp.column(name, table=pivot.alias), name, copy=False) 449 for name in pivot_columns 450 if name not in columns_to_exclude 451 ) 452 continue 453 454 for name in columns: 455 if name in columns_to_exclude or name in coalesced_columns: 456 continue 457 if name in using_column_tables and table in using_column_tables[name]: 458 coalesced_columns.add(name) 459 tables = using_column_tables[name] 460 coalesce = [exp.column(name, table=table) for table in tables] 461 462 new_selections.append( 463 alias( 464 exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), 465 alias=name, 466 copy=False, 467 ) 468 ) 469 else: 470 alias_ = replace_columns.get(table_id, {}).get(name, name) 471 column = exp.column(name, 
table=table) 472 new_selections.append( 473 alias(column, alias_, copy=False) if alias_ != name else column 474 ) 475 476 # Ensures we don't overwrite the initial selections with an empty list 477 if new_selections and isinstance(scope.expression, exp.Select): 478 scope.expression.set("expressions", new_selections) 479 480 481def _add_except_columns( 482 expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]] 483) -> None: 484 except_ = expression.args.get("except") 485 486 if not except_: 487 return 488 489 columns = {e.name for e in except_} 490 491 for table in tables: 492 except_columns[id(table)] = columns 493 494 495def _add_replace_columns( 496 expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]] 497) -> None: 498 replace = expression.args.get("replace") 499 500 if not replace: 501 return 502 503 columns = {e.this.name: e.alias for e in replace} 504 505 for table in tables: 506 replace_columns[id(table)] = columns 507 508 509def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: 510 """Ensure all output columns are aliased""" 511 if isinstance(scope_or_expression, exp.Expression): 512 scope = build_scope(scope_or_expression) 513 if not isinstance(scope, Scope): 514 return 515 else: 516 scope = scope_or_expression 517 518 new_selections = [] 519 for i, (selection, aliased_column) in enumerate( 520 itertools.zip_longest(scope.expression.selects, scope.outer_columns) 521 ): 522 if selection is None: 523 break 524 525 if isinstance(selection, exp.Subquery): 526 if not selection.output_name: 527 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 528 elif not isinstance(selection, exp.Alias) and not selection.is_star: 529 selection = alias( 530 selection, 531 alias=selection.output_name or f"_col_{i}", 532 copy=False, 533 ) 534 if aliased_column: 535 selection.set("alias", exp.to_identifier(aliased_column)) 536 537 new_selections.append(selection) 538 539 if 
isinstance(scope.expression, exp.Select): 540 scope.expression.set("expressions", new_selections) 541 542 543def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: 544 """Makes sure all identifiers that need to be quoted are quoted.""" 545 return expression.transform( 546 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 547 ) # type: ignore 548 549 550def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: 551 """ 552 Pushes down the CTE alias columns into the projection, 553 554 This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. 555 556 Example: 557 >>> import sqlglot 558 >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") 559 >>> pushdown_cte_alias_columns(expression).sql() 560 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y' 561 562 Args: 563 expression: Expression to pushdown. 564 565 Returns: 566 The expression with the CTE aliases pushed down into the projection. 567 """ 568 for cte in expression.find_all(exp.CTE): 569 if cte.alias_column_names: 570 new_expressions = [] 571 for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): 572 if isinstance(projection, exp.Alias): 573 projection.set("alias", _alias) 574 else: 575 projection = alias(projection, alias=_alias) 576 new_expressions.append(projection) 577 cte.this.set("expressions", new_expressions) 578 579 return expression 580 581 582class Resolver: 583 """ 584 Helper for resolving columns. 585 586 This is a class so we can lazily load some things and easily share them across functions. 
587 """ 588 589 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 590 self.scope = scope 591 self.schema = schema 592 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 593 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 594 self._all_columns: t.Optional[t.Set[str]] = None 595 self._infer_schema = infer_schema 596 597 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 598 """ 599 Get the table for a column name. 600 601 Args: 602 column_name: The column name to find the table for. 603 Returns: 604 The table name if it can be found/inferred. 605 """ 606 if self._unambiguous_columns is None: 607 self._unambiguous_columns = self._get_unambiguous_columns( 608 self._get_all_source_columns() 609 ) 610 611 table_name = self._unambiguous_columns.get(column_name) 612 613 if not table_name and self._infer_schema: 614 sources_without_schema = tuple( 615 source 616 for source, columns in self._get_all_source_columns().items() 617 if not columns or "*" in columns 618 ) 619 if len(sources_without_schema) == 1: 620 table_name = sources_without_schema[0] 621 622 if table_name not in self.scope.selected_sources: 623 return exp.to_identifier(table_name) 624 625 node, _ = self.scope.selected_sources.get(table_name) 626 627 if isinstance(node, exp.Query): 628 while node and node.alias != table_name: 629 node = node.parent 630 631 node_alias = node.args.get("alias") 632 if node_alias: 633 return exp.to_identifier(node_alias.this) 634 635 return exp.to_identifier(table_name) 636 637 @property 638 def all_columns(self) -> t.Set[str]: 639 """All available columns of all sources in this scope""" 640 if self._all_columns is None: 641 self._all_columns = { 642 column for columns in self._get_all_source_columns().values() for column in columns 643 } 644 return self._all_columns 645 646 def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: 647 """Resolve the source columns for a 
given source `name`.""" 648 if name not in self.scope.sources: 649 raise OptimizeError(f"Unknown table: {name}") 650 651 source = self.scope.sources[name] 652 653 if isinstance(source, exp.Table): 654 columns = self.schema.column_names(source, only_visible) 655 elif isinstance(source, Scope) and isinstance(source.expression, exp.Values): 656 columns = source.expression.alias_column_names 657 else: 658 columns = source.expression.named_selects 659 660 node, _ = self.scope.selected_sources.get(name) or (None, None) 661 if isinstance(node, Scope): 662 column_aliases = node.expression.alias_column_names 663 elif isinstance(node, exp.Expression): 664 column_aliases = node.alias_column_names 665 else: 666 column_aliases = [] 667 668 if column_aliases: 669 # If the source's columns are aliased, their aliases shadow the corresponding column names. 670 # This can be expensive if there are lots of columns, so only do this if column_aliases exist. 671 return [ 672 alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases) 673 ] 674 return columns 675 676 def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]: 677 if self._source_columns is None: 678 self._source_columns = { 679 source_name: self.get_source_columns(source_name) 680 for source_name, source in itertools.chain( 681 self.scope.selected_sources.items(), self.scope.lateral_sources.items() 682 ) 683 } 684 return self._source_columns 685 686 def _get_unambiguous_columns( 687 self, source_columns: t.Dict[str, t.Sequence[str]] 688 ) -> t.Mapping[str, str]: 689 """ 690 Find all the unambiguous columns in sources. 691 692 Args: 693 source_columns: Mapping of names to source columns. 694 695 Returns: 696 Mapping of column name to source name. 
697 """ 698 if not source_columns: 699 return {} 700 701 source_columns_pairs = list(source_columns.items()) 702 703 first_table, first_columns = source_columns_pairs[0] 704 705 if len(source_columns_pairs) == 1: 706 # Performance optimization - avoid copying first_columns if there is only one table. 707 return SingleValuedMapping(first_columns, first_table) 708 709 unambiguous_columns = {col: first_table for col in first_columns} 710 all_columns = set(unambiguous_columns) 711 712 for table, columns in source_columns_pairs[1:]: 713 unique = set(columns) 714 ambiguous = all_columns.intersection(unique) 715 all_columns.update(columns) 716 717 for column in ambiguous: 718 unambiguous_columns.pop(column, None) 719 for column in unique.difference(ambiguous): 720 unambiguous_columns[column] = table 721 722 return unambiguous_columns
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Qualify the columns of a sqlglot AST, scope by scope.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema, either as a dict or as a `Schema` instance.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. Most optimizer rules rely on
            this step, so only disable it if you know what you're doing!
        infer_schema: Whether to infer the schema if missing. When left as `None`,
            it defaults to whether the given schema is empty.

    Returns:
        The qualified expression (the object that was passed in).

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # With an empty schema, alias references are expanded before the columns
        # are qualified; with a populated schema they are expanded afterwards.
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, resolver)

        # Star expansion and output aliasing are skipped for UDTF scopes.
        is_udtf = isinstance(scope.expression, exp.UDTF)
        if expand_stars and not is_udtf:
            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
        if not is_udtf:
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Verify that the columns in every SELECT scope of `expression` are qualified.

    Args:
        expression: The expression to validate.

    Returns:
        The given expression, when validation passes.

    Raises:
        OptimizeError: If a SELECT scope that is neither a correlated subquery nor
            pivoted has external columns (the first one is reported), or if any
            unqualified columns remain once all scopes have been inspected.
    """
    leftovers = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        pending = scope.unqualified_columns

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            column = scope.external_columns[0]
            table_hint = ""
            if column.table:
                table_hint = f" for table: '{column.table}'"
            raise OptimizeError(f"Column '{column}' could not be resolved{table_hint}")

        if pending and scope.pivots and scope.pivots[0].unpivot:
            # Columns newly produced by the UNPIVOT can't be qualified, whereas columns
            # under its IN clause can and should be. Drop the former from the list so
            # that only genuinely unqualified columns get reported.
            produced = set(_unpivot_columns(scope.pivots[0]))
            pending = [col for col in pending if col not in produced]

        leftovers.extend(pending)

    if leftovers:
        raise OptimizeError(f"Ambiguous columns: {leftovers}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Make sure each projection of a query carries an alias.

    Unaliased, non-star projections are aliased to their output name, falling back to
    `_col_<position>` (zero-based) when they have none. Subquery projections without
    an output name get a `_col_<position>` table alias. Names found in the scope's
    `outer_columns` take precedence and overwrite the alias at the same position.

    Args:
        scope_or_expression: A scope, or an expression from which a scope is built.
            Nothing happens if no scope can be built from the expression.
    """
    scope = scope_or_expression
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return

    aliased = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_columns)

    for position, (projection, outer_name) in enumerate(pairs):
        # zip_longest pads with None: there are more outer columns than projections.
        if projection is None:
            break

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                fallback = exp.to_identifier(f"_col_{position}")
                projection.set("alias", exp.TableAlias(this=fallback))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            output_name = projection.output_name or f"_col_{position}"
            projection = alias(projection, alias=output_name, copy=False)

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        aliased.append(projection)

    # Only SELECT expressions have their projection list overwritten.
    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", aliased)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Quote all identifiers in `expression` that need to be quoted.

    Args:
        expression: The expression to transform; it is transformed without copying.
        dialect: The dialect whose `quote_identifier` is applied to the nodes.
        identify: Forwarded as-is to the dialect's `quote_identifier`.

    Returns:
        The result of transforming the expression.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Only CTEs whose body is a SELECT are rewritten. Alias columns are matched to the
    projections by position; projections beyond the last alias column are kept untouched.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_names = cte.alias_column_names
        query = cte.this

        # Skip non-SELECT bodies (e.g. set operations): they have no projection
        # list of their own, so writing `expressions` on them would be wrong.
        if not column_names or not isinstance(query, exp.Select):
            continue

        new_expressions = []
        for _alias, projection in itertools.zip_longest(column_names, query.expressions):
            if projection is None:
                # More alias columns than projections: nothing left to alias.
                break

            if _alias is not None:
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)

            # Projections without a matching alias column are preserved as they are,
            # rather than being dropped from the SELECT.
            new_expressions.append(projection)

        query.set("expressions", new_expressions)

    return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    Implemented as a class so that derived lookups (per-source columns, the
    unambiguous column mapping, the set of all columns) are computed on first
    use, cached on the instance and shared across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        # Inputs, stored as given.
        self.scope = scope
        self.schema = schema
        self._infer_schema = infer_schema

        # Caches; `None` means "not computed yet".
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        unambiguous = self._unambiguous_columns
        if unambiguous is None:
            unambiguous = self._get_unambiguous_columns(self._get_all_source_columns())
            self._unambiguous_columns = unambiguous

        table_name = unambiguous.get(column_name)

        if self._infer_schema and not table_name:
            # Sources with no known columns (or a "*") could contain any column; if
            # exactly one such source exists, attribute the column to it.
            candidates = [
                source_name
                for source_name, source_columns in self._get_all_source_columns().items()
                if not source_columns or "*" in source_columns
            ]
            if len(candidates) == 1:
                table_name = candidates[0]

        selected_sources = self.scope.selected_sources
        if table_name not in selected_sources:
            return exp.to_identifier(table_name)

        node, _ = selected_sources[table_name]

        if isinstance(node, exp.Query):
            # Climb the tree until reaching the node that carries the matching alias.
            while node and node.alias != table_name:
                node = node.parent

        alias_node = node.args.get("alias")
        return exp.to_identifier(alias_node.this if alias_node else table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            per_source = self._get_all_source_columns().values()
            self._all_columns = set(itertools.chain.from_iterable(per_source))
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        sources = self.scope.sources
        if name not in sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        selected = self.scope.selected_sources.get(name)
        node = selected[0] if selected else None

        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if not column_aliases:
            return columns

        # If the source's columns are aliased, their aliases shadow the corresponding column
        # names. This can be expensive if there are lots of columns, so it is only done when
        # column aliases exist.
        return [
            col_alias or col_name
            for col_name, col_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Return the cached source-name -> columns mapping, building it on first use."""
        if self._source_columns is None:
            resolved = {}
            for source_name, _ in itertools.chain(
                self.scope.selected_sources.items(), self.scope.lateral_sources.items()
            ):
                resolved[source_name] = self.get_source_columns(source_name)
            self._source_columns = resolved
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        remaining = iter(source_columns.items())
        first_table, first_columns = next(remaining)

        if len(source_columns) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous = dict.fromkeys(first_columns, first_table)
        seen = set(unambiguous)

        for table, columns in remaining:
            current = set(columns)
            clashes = seen & current
            seen |= current

            # A column seen in an earlier source is ambiguous from now on.
            for column in clashes:
                unambiguous.pop(column, None)
            for column in current - clashes:
                unambiguous[column] = table

        return unambiguous
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
590 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 591 self.scope = scope 592 self.schema = schema 593 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 594 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 595 self._all_columns: t.Optional[t.Set[str]] = None 596 self._infer_schema = infer_schema
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    unambiguous = self._unambiguous_columns
    if unambiguous is None:
        unambiguous = self._get_unambiguous_columns(self._get_all_source_columns())
        self._unambiguous_columns = unambiguous

    table_name = unambiguous.get(column_name)

    if self._infer_schema and not table_name:
        # Sources with no known columns (or a "*") could contain any column; if
        # exactly one such source exists, attribute the column to it.
        candidates = [
            source_name
            for source_name, source_columns in self._get_all_source_columns().items()
            if not source_columns or "*" in source_columns
        ]
        if len(candidates) == 1:
            table_name = candidates[0]

    selected_sources = self.scope.selected_sources
    if table_name not in selected_sources:
        return exp.to_identifier(table_name)

    node, _ = selected_sources[table_name]

    if isinstance(node, exp.Query):
        # Climb the tree until reaching the node that carries the matching alias.
        while node and node.alias != table_name:
            node = node.parent

    alias_node = node.args.get("alias")
    return exp.to_identifier(alias_node.this if alias_node else table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
638 @property 639 def all_columns(self) -> t.Set[str]: 640 """All available columns of all sources in this scope""" 641 if self._all_columns is None: 642 self._all_columns = { 643 column for columns in self._get_all_source_columns().values() for column in columns 644 } 645 return self._all_columns
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
    """
    Resolve the source columns for a given source `name`.

    Args:
        name: Key of the source in `self.scope.sources`.
        only_visible: Forwarded to `schema.column_names` when the source is a table.

    Returns:
        The source's column names, with column aliases taking precedence
        position by position when the selected source defines any.

    Raises:
        OptimizeError: If `name` is not one of the scope's sources.
    """
    sources = self.scope.sources
    if name not in sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = sources[name]

    if isinstance(source, exp.Table):
        columns = self.schema.column_names(source, only_visible)
    elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
        columns = source.expression.alias_column_names
    else:
        columns = source.expression.named_selects

    selected = self.scope.selected_sources.get(name)
    node = selected[0] if selected else None

    if isinstance(node, Scope):
        column_aliases = node.expression.alias_column_names
    elif isinstance(node, exp.Expression):
        column_aliases = node.alias_column_names
    else:
        column_aliases = []

    if not column_aliases:
        return columns

    # If the source's columns are aliased, their aliases shadow the corresponding column
    # names. This can be expensive if there are lots of columns, so it is only done when
    # column aliases exist.
    return [
        col_alias or col_name
        for col_name, col_alias in itertools.zip_longest(columns, column_aliases)
    ]
Resolve the source columns for a given source `name`.