/*------------------------------------------------------------------------- * * multi_logical_planner.c * * Routines for constructing a logical plan tree from the given Query tree * structure. This new logical plan is based on multi-relational algebra rules. * * Copyright (c) Citus Data, Inc. * * $Id$ * *------------------------------------------------------------------------- */ #include "postgres.h" #include "distributed/pg_version_constants.h" #include "access/heapam.h" #include "access/nbtree.h" #include "catalog/pg_am.h" #include "catalog/pg_class.h" #include "commands/defrem.h" #include "distributed/citus_clauses.h" #include "distributed/colocation_utils.h" #include "distributed/metadata_cache.h" #include "distributed/insert_select_planner.h" #include "distributed/listutils.h" #include "distributed/multi_logical_optimizer.h" #include "distributed/multi_logical_planner.h" #include "distributed/multi_physical_planner.h" #include "distributed/reference_table_utils.h" #include "distributed/relation_restriction_equivalence.h" #include "distributed/query_pushdown_planning.h" #include "distributed/query_utils.h" #include "distributed/multi_router_planner.h" #include "distributed/worker_protocol.h" #include "distributed/version_compat.h" #include "nodes/makefuncs.h" #include "nodes/nodeFuncs.h" #if PG_VERSION_NUM >= PG_VERSION_12 #include "nodes/pathnodes.h" #include "optimizer/optimizer.h" #else #include "nodes/relation.h" #include "optimizer/var.h" #endif #include "optimizer/clauses.h" #include "optimizer/prep.h" #include "optimizer/tlist.h" #include "parser/parsetree.h" #include "utils/builtins.h" #include "utils/datum.h" #include "utils/lsyscache.h" #include "utils/syscache.h" #include "utils/rel.h" #include "utils/relcache.h" /* Struct to differentiate different qualifier types in an expression tree walker */ typedef struct QualifierWalkerContext { List *baseQualifierList; List *outerJoinQualifierList; } QualifierWalkerContext; /* Function pointer type definition for apply join rule functions */ typedef MultiNode *(*RuleApplyFunction) (MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *joinClauses); typedef bool (*CheckNodeFunc)(Node *); static RuleApplyFunction RuleApplyFunctionArray[JOIN_RULE_LAST] = { 0 }; /* join rules */ /* Local functions forward declarations */ static FieldSelect * CompositeFieldRecursive(Expr *expression, Query *query); static Oid NodeTryGetRteRelid(Node *node); static bool FullCompositeFieldList(List *compositeFieldList); static bool HasUnsupportedJoinWalker(Node *node, void *context); static bool ErrorHintRequired(const char *errorHint, Query *queryTree); static bool HasTablesample(Query *queryTree); static bool HasComplexRangeTableType(Query *queryTree); static bool IsReadIntermediateResultFunction(Node *node); static bool IsReadIntermediateResultArrayFunction(Node *node); static bool IsCitusExtraDataContainerFunc(Node *node); static bool IsFunctionWithOid(Node *node, Oid funcOid); static bool IsGroupingFunc(Node *node); static bool ExtractFromExpressionWalker(Node *node, QualifierWalkerContext *walkerContext); static List * MultiTableNodeList(List *tableEntryList, List *rangeTableList); static List * AddMultiCollectNodes(List *tableNodeList); static MultiNode * MultiJoinTree(List *joinOrderList, List *collectTableList, List *joinClauseList); static MultiCollect * CollectNodeForTable(List *collectTableList, uint32 rangeTableId); static MultiSelect * MultiSelectNode(List *whereClauseList); static bool IsSelectClause(Node *clause); /* Local functions forward declarations for applying joins */ static MultiNode * ApplyJoinRule(MultiNode *leftNode, MultiNode *rightNode, JoinRuleType ruleType, List *partitionColumnList, JoinType joinType, List *joinClauseList); static RuleApplyFunction JoinRuleApplyFunction(JoinRuleType ruleType); static MultiNode * ApplyReferenceJoin(MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *joinClauses); static MultiNode * ApplyLocalJoin(MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *joinClauses); static MultiNode * ApplySingleRangePartitionJoin(MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *applicableJoinClauses); static MultiNode * ApplySingleHashPartitionJoin(MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *applicableJoinClauses); static MultiJoin * ApplySinglePartitionJoin(MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *joinClauses); static MultiNode * ApplyDualPartitionJoin(MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *joinClauses); static MultiNode * ApplyCartesianProductReferenceJoin(MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *joinClauses); static MultiNode * ApplyCartesianProduct(MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *joinClauses); /* * MultiLogicalPlanCreate takes in both the original query and its corresponding modified * query tree yield by the standard planner. It uses helper functions to create logical * plan and adds a root node to top of it. The original query is only used for subquery * pushdown planning. * * We also pass queryTree and plannerRestrictionContext to the planner. They * are primarily used to decide whether the subquery is safe to pushdown. * If not, it helps to produce meaningful error messages for subquery * pushdown planning. */ MultiTreeRoot * MultiLogicalPlanCreate(Query *originalQuery, Query *queryTree, PlannerRestrictionContext *plannerRestrictionContext) { MultiNode *multiQueryNode = NULL; if (ShouldUseSubqueryPushDown(originalQuery, queryTree, plannerRestrictionContext)) { multiQueryNode = SubqueryMultiNodeTree(originalQuery, queryTree, plannerRestrictionContext); } else { multiQueryNode = MultiNodeTree(queryTree); } /* add a root node to serve as the permanent handle to the tree */ MultiTreeRoot *rootNode = CitusMakeNode(MultiTreeRoot); SetChild((MultiUnaryNode *) rootNode, multiQueryNode); return rootNode; } /* * FindNodeMatchingCheckFunction finds a node for which the checker function returns true. * * To call this function directly with an RTE, use: * range_table_walker(rte, FindNodeMatchingCheckFunction, checker, QTW_EXAMINE_RTES_BEFORE) */ bool FindNodeMatchingCheckFunction(Node *node, CheckNodeFunc checker) { if (node == NULL) { return false; } if (checker(node)) { return true; } if (IsA(node, RangeTblEntry)) { /* query_tree_walker descends into RTEs */ return false; } else if (IsA(node, Query)) { return query_tree_walker((Query *) node, FindNodeMatchingCheckFunction, checker, QTW_EXAMINE_RTES_BEFORE); } return expression_tree_walker(node, FindNodeMatchingCheckFunction, checker); } /* * TargetListOnPartitionColumn checks if at least one target list entry is on * partition column. */ bool TargetListOnPartitionColumn(Query *query, List *targetEntryList) { bool targetListOnPartitionColumn = false; List *compositeFieldList = NIL; ListCell *targetEntryCell = NULL; foreach(targetEntryCell, targetEntryList) { TargetEntry *targetEntry = (TargetEntry *) lfirst(targetEntryCell); Expr *targetExpression = targetEntry->expr; bool isPartitionColumn = IsPartitionColumn(targetExpression, query); Oid relationId = InvalidOid; Var *column = NULL; FindReferencedTableColumn(targetExpression, NIL, query, &relationId, &column); /* * If the expression belongs to a non-distributed table continue searching for * other partition keys. */ if (IsCitusTableType(relationId, CITUS_TABLE_WITH_NO_DIST_KEY)) { continue; } if (isPartitionColumn) { FieldSelect *compositeField = CompositeFieldRecursive(targetExpression, query); if (compositeField) { compositeFieldList = lappend(compositeFieldList, compositeField); } else { targetListOnPartitionColumn = true; break; } } } /* check composite fields */ if (!targetListOnPartitionColumn) { bool fullCompositeFieldList = FullCompositeFieldList(compositeFieldList); if (fullCompositeFieldList) { targetListOnPartitionColumn = true; } } /* * We could still behave as if the target list is on partition column if * range table entries don't contain a distributed table. */ if (!targetListOnPartitionColumn) { if (!FindNodeMatchingCheckFunctionInRangeTableList(query->rtable, IsDistributedTableRTE)) { targetListOnPartitionColumn = true; } } return targetListOnPartitionColumn; } /* * FindNodeMatchingCheckFunctionInRangeTableList finds a node for which the checker * function returns true. * * FindNodeMatchingCheckFunctionInRangeTableList relies on * FindNodeMatchingCheckFunction() but only considers the range table entries. */ bool FindNodeMatchingCheckFunctionInRangeTableList(List *rtable, CheckNodeFunc checker) { return range_table_walker(rtable, FindNodeMatchingCheckFunction, checker, QTW_EXAMINE_RTES_BEFORE); } /* * NodeTryGetRteRelid returns the relid of the given RTE_RELATION RangeTableEntry. * Returns InvalidOid if any of these assumptions fail for given node. */ static Oid NodeTryGetRteRelid(Node *node) { if (node == NULL) { return InvalidOid; } if (!IsA(node, RangeTblEntry)) { return InvalidOid; } RangeTblEntry *rangeTableEntry = (RangeTblEntry *) node; if (rangeTableEntry->rtekind != RTE_RELATION) { return InvalidOid; } return rangeTableEntry->relid; } /* * IsCitusTableRTE gets a node and returns true if the node is a * range table relation entry that points to a distributed relation. */ bool IsCitusTableRTE(Node *node) { Oid relationId = NodeTryGetRteRelid(node); return relationId != InvalidOid && IsCitusTable(relationId); } /* * IsDistributedTableRTE gets a node and returns true if the node * is a range table relation entry that points to a distributed relation, * returning false still if the relation is a reference table. */ bool IsDistributedTableRTE(Node *node) { Oid relationId = NodeTryGetRteRelid(node); return relationId != InvalidOid && IsCitusTableType(relationId, DISTRIBUTED_TABLE); } /* * IsReferenceTableRTE gets a node and returns true if the node * is a range table relation entry that points to a reference table. */ bool IsReferenceTableRTE(Node *node) { Oid relationId = NodeTryGetRteRelid(node); return relationId != InvalidOid && IsCitusTableType(relationId, REFERENCE_TABLE); } /* * FullCompositeFieldList gets a composite field list, and checks if all fields * of composite type are used in the list. */ static bool FullCompositeFieldList(List *compositeFieldList) { bool fullCompositeFieldList = true; bool *compositeFieldArray = NULL; uint32 compositeFieldCount = 0; ListCell *fieldSelectCell = NULL; foreach(fieldSelectCell, compositeFieldList) { FieldSelect *fieldSelect = (FieldSelect *) lfirst(fieldSelectCell); Expr *fieldExpression = fieldSelect->arg; if (!IsA(fieldExpression, Var)) { continue; } if (compositeFieldArray == NULL) { Var *compositeColumn = (Var *) fieldExpression; Oid compositeTypeId = compositeColumn->vartype; Oid compositeRelationId = get_typ_typrelid(compositeTypeId); /* get composite type attribute count */ Relation relation = relation_open(compositeRelationId, AccessShareLock); compositeFieldCount = relation->rd_att->natts; compositeFieldArray = palloc0(compositeFieldCount * sizeof(bool)); relation_close(relation, AccessShareLock); for (uint32 compositeFieldIndex = 0; compositeFieldIndex < compositeFieldCount; compositeFieldIndex++) { compositeFieldArray[compositeFieldIndex] = false; } } uint32 compositeFieldIndex = fieldSelect->fieldnum - 1; compositeFieldArray[compositeFieldIndex] = true; } for (uint32 fieldIndex = 0; fieldIndex < compositeFieldCount; fieldIndex++) { if (!compositeFieldArray[fieldIndex]) { fullCompositeFieldList = false; } } if (compositeFieldCount == 0) { fullCompositeFieldList = false; } return fullCompositeFieldList; } /* * CompositeFieldRecursive recursively finds composite field in the query tree * referred by given expression. If expression does not refer to a composite * field, then it returns NULL. * * If expression is a field select we directly return composite field. If it is * a column is referenced from a subquery, then we recursively check that subquery * until we reach the source of that column, and find composite field. If this * column is referenced from join range table entry, then we resolve which join * column it refers and recursively use this column with the same query. */ static FieldSelect * CompositeFieldRecursive(Expr *expression, Query *query) { FieldSelect *compositeField = NULL; List *rangetableList = query->rtable; Var *candidateColumn = NULL; if (IsA(expression, FieldSelect)) { compositeField = (FieldSelect *) expression; return compositeField; } if (IsA(expression, Var)) { candidateColumn = (Var *) expression; } else { return NULL; } Index rangeTableEntryIndex = candidateColumn->varno - 1; RangeTblEntry *rangeTableEntry = list_nth(rangetableList, rangeTableEntryIndex); if (rangeTableEntry->rtekind == RTE_SUBQUERY) { Query *subquery = rangeTableEntry->subquery; List *targetEntryList = subquery->targetList; AttrNumber targetEntryIndex = candidateColumn->varattno - 1; TargetEntry *subqueryTargetEntry = list_nth(targetEntryList, targetEntryIndex); Expr *subqueryExpression = subqueryTargetEntry->expr; compositeField = CompositeFieldRecursive(subqueryExpression, subquery); } else if (rangeTableEntry->rtekind == RTE_JOIN) { List *joinColumnList = rangeTableEntry->joinaliasvars; AttrNumber joinColumnIndex = candidateColumn->varattno - 1; Expr *joinColumn = list_nth(joinColumnList, joinColumnIndex); compositeField = CompositeFieldRecursive(joinColumn, query); } return compositeField; } /* * SubqueryEntryList finds the subquery nodes in the range table entry list, and * builds a list of subquery range table entries from these subquery nodes. Range * table entry list also includes subqueries which are pulled up. We don't want * to add pulled up subqueries to list, so we walk over join tree indexes and * check range table entries referenced in the join tree. */ List * SubqueryEntryList(Query *queryTree) { List *rangeTableList = queryTree->rtable; List *subqueryEntryList = NIL; List *joinTreeTableIndexList = NIL; ListCell *joinTreeTableIndexCell = NULL; /* * Extract all range table indexes from the join tree. Note that here we * only walk over range table entries at this level and do not recurse into * subqueries. */ ExtractRangeTableIndexWalker((Node *) queryTree->jointree, &joinTreeTableIndexList); foreach(joinTreeTableIndexCell, joinTreeTableIndexList) { /* * Join tree's range table index starts from 1 in the query tree. But, * list indexes start from 0. */ int joinTreeTableIndex = lfirst_int(joinTreeTableIndexCell); int rangeTableListIndex = joinTreeTableIndex - 1; RangeTblEntry *rangeTableEntry = (RangeTblEntry *) list_nth(rangeTableList, rangeTableListIndex); if (rangeTableEntry->rtekind == RTE_SUBQUERY) { subqueryEntryList = lappend(subqueryEntryList, rangeTableEntry); } } return subqueryEntryList; } /* * MultiNodeTree takes in a parsed query tree and uses that tree to construct a * logical plan. This plan is based on multi-relational algebra. This function * creates the logical plan in several steps. * * First, the function checks if there is a subquery. If there is a subquery * it recursively creates nested multi trees. If this query has a subquery, the * function does not create any join trees and jumps to last step. * * If there is no subquery, the function calculates the join order using tables * in the query and join clauses between the tables. Second, the function * starts building the logical plan from the bottom-up, and begins with the table * and collect nodes. Third, the function builds the join tree using the join * order information and table nodes. * * In the last step, the function adds the select, project, aggregate, sort, * group, and limit nodes if they appear in the original query tree. */ MultiNode * MultiNodeTree(Query *queryTree) { List *rangeTableList = queryTree->rtable; List *targetEntryList = queryTree->targetList; List *joinClauseList = NIL; List *joinOrderList = NIL; List *tableEntryList = NIL; List *tableNodeList = NIL; List *collectTableList = NIL; MultiNode *joinTreeNode = NULL; MultiNode *currentTopNode = NULL; /* verify we can perform distributed planning on this query */ DeferredErrorMessage *unsupportedQueryError = DeferErrorIfQueryNotSupported( queryTree); if (unsupportedQueryError != NULL) { RaiseDeferredError(unsupportedQueryError, ERROR); } /* extract where clause qualifiers and verify we can plan for them */ List *whereClauseList = WhereClauseList(queryTree->jointree); unsupportedQueryError = DeferErrorIfUnsupportedClause(whereClauseList); if (unsupportedQueryError) { RaiseDeferredErrorInternal(unsupportedQueryError, ERROR); } /* * If we have a subquery, build a multi table node for the subquery and * add a collect node on top of the multi table node. */ List *subqueryEntryList = SubqueryEntryList(queryTree); if (subqueryEntryList != NIL) { MultiCollect *subqueryCollectNode = CitusMakeNode(MultiCollect); ListCell *columnCell = NULL; /* we only support single subquery in the entry list */ Assert(list_length(subqueryEntryList) == 1); RangeTblEntry *subqueryRangeTableEntry = (RangeTblEntry *) linitial( subqueryEntryList); Query *subqueryTree = subqueryRangeTableEntry->subquery; /* ensure if subquery satisfies preconditions */ Assert(DeferErrorIfUnsupportedSubqueryRepartition(subqueryTree) == NULL); MultiTable *subqueryNode = CitusMakeNode(MultiTable); subqueryNode->relationId = SUBQUERY_RELATION_ID; subqueryNode->rangeTableId = SUBQUERY_RANGE_TABLE_ID; subqueryNode->partitionColumn = NULL; subqueryNode->alias = NULL; subqueryNode->referenceNames = NULL; /* * We disregard pulled subqueries. This changes order of range table list. * We do not allow subquery joins, so we will have only one range table * entry in range table list after dropping pulled subquery. For this * reason, here we are updating columns in the most outer query for where * clause list and target list accordingly. */ Assert(list_length(subqueryEntryList) == 1); List *whereClauseColumnList = pull_var_clause_default((Node *) whereClauseList); List *targetListColumnList = pull_var_clause_default((Node *) targetEntryList); List *columnList = list_concat(whereClauseColumnList, targetListColumnList); foreach(columnCell, columnList) { Var *column = (Var *) lfirst(columnCell); column->varno = 1; } /* recursively create child nested multitree */ MultiNode *subqueryExtendedNode = MultiNodeTree(subqueryTree); SetChild((MultiUnaryNode *) subqueryCollectNode, (MultiNode *) subqueryNode); SetChild((MultiUnaryNode *) subqueryNode, subqueryExtendedNode); currentTopNode = (MultiNode *) subqueryCollectNode; } else { /* * We calculate the join order using the list of tables in the query and * the join clauses between them. Note that this function owns the table * entry list's memory, and JoinOrderList() shallow copies the list's * elements. */ joinClauseList = JoinClauseList(whereClauseList); tableEntryList = UsedTableEntryList(queryTree); /* build the list of multi table nodes */ tableNodeList = MultiTableNodeList(tableEntryList, rangeTableList); /* add collect nodes on top of the multi table nodes */ collectTableList = AddMultiCollectNodes(tableNodeList); /* find best join order for commutative inner joins */ joinOrderList = JoinOrderList(tableEntryList, joinClauseList); /* build join tree using the join order and collected tables */ joinTreeNode = MultiJoinTree(joinOrderList, collectTableList, joinClauseList); currentTopNode = joinTreeNode; } Assert(currentTopNode != NULL); /* build select node if the query has selection criteria */ MultiSelect *selectNode = MultiSelectNode(whereClauseList); if (selectNode != NULL) { SetChild((MultiUnaryNode *) selectNode, currentTopNode); currentTopNode = (MultiNode *) selectNode; } /* build project node for the columns to project */ MultiProject *projectNode = MultiProjectNode(targetEntryList); SetChild((MultiUnaryNode *) projectNode, currentTopNode); currentTopNode = (MultiNode *) projectNode; /* * We build the extended operator node to capture aggregate functions, group * clauses, sort clauses, limit/offset clauses, and expressions. We need to * distinguish between aggregates and expressions; and we address this later * in the logical optimizer. */ MultiExtendedOp *extendedOpNode = MultiExtendedOpNode(queryTree, queryTree); SetChild((MultiUnaryNode *) extendedOpNode, currentTopNode); currentTopNode = (MultiNode *) extendedOpNode; return currentTopNode; } /* * ContainsReadIntermediateResultFunction determines whether an expresion tree contains * a call to the read_intermediate_result function. */ bool ContainsReadIntermediateResultFunction(Node *node) { return FindNodeMatchingCheckFunction(node, IsReadIntermediateResultFunction); } /* * ContainsReadIntermediateResultArrayFunction determines whether an expresion * tree contains a call to the read_intermediate_results(result_ids, format) * function. */ bool ContainsReadIntermediateResultArrayFunction(Node *node) { return FindNodeMatchingCheckFunction(node, IsReadIntermediateResultArrayFunction); } /* * IsReadIntermediateResultFunction determines whether a given node is a function call * to the read_intermediate_result function. */ static bool IsReadIntermediateResultFunction(Node *node) { return IsFunctionWithOid(node, CitusReadIntermediateResultFuncId()); } /* * IsReadIntermediateResultArrayFunction determines whether a given node is a * function call to the read_intermediate_results(result_ids, format) function. */ static bool IsReadIntermediateResultArrayFunction(Node *node) { return IsFunctionWithOid(node, CitusReadIntermediateResultArrayFuncId()); } /* * IsCitusExtraDataContainerRelation determines whether a range table entry contains a * call to the citus_extradata_container function. */ bool IsCitusExtraDataContainerRelation(RangeTblEntry *rte) { if (rte->rtekind != RTE_FUNCTION || list_length(rte->functions) != 1) { /* avoid more expensive checks below for non-functions */ return false; } if (!CitusHasBeenLoaded() || !CheckCitusVersion(DEBUG5)) { return false; } return FindNodeMatchingCheckFunction((Node *) rte->functions, IsCitusExtraDataContainerFunc); } /* * IsCitusExtraDataContainerFunc determines whether a given node is a function call * to the citus_extradata_container function. */ static bool IsCitusExtraDataContainerFunc(Node *node) { return IsFunctionWithOid(node, CitusExtraDataContainerFuncId()); } /* * IsFunctionWithOid determines whether a given node is a function call * to the read_intermediate_result function. */ static bool IsFunctionWithOid(Node *node, Oid funcOid) { if (IsA(node, FuncExpr)) { FuncExpr *funcExpr = (FuncExpr *) node; if (funcExpr->funcid == funcOid) { return true; } } return false; } /* * IsGroupingFunc returns whether node is a GroupingFunc. */ static bool IsGroupingFunc(Node *node) { return IsA(node, GroupingFunc); } /* * FindIntermediateResultIdIfExists extracts the id of the intermediate result * if the given RTE contains a read_intermediate_results function, NULL otherwise */ char * FindIntermediateResultIdIfExists(RangeTblEntry *rte) { char *resultId = NULL; Assert(rte->rtekind == RTE_FUNCTION); List *functionList = rte->functions; RangeTblFunction *rangeTblfunction = (RangeTblFunction *) linitial(functionList); FuncExpr *funcExpr = (FuncExpr *) rangeTblfunction->funcexpr; if (IsReadIntermediateResultFunction((Node *) funcExpr)) { Const *resultIdConst = linitial(funcExpr->args); if (!resultIdConst->constisnull) { resultId = TextDatumGetCString(resultIdConst->constvalue); } } return resultId; } /* * ErrorIfQueryNotSupported checks that we can perform distributed planning for * the given query. The checks in this function will be removed as we support * more functionality in our distributed planning. */ DeferredErrorMessage * DeferErrorIfQueryNotSupported(Query *queryTree) { char *errorMessage = NULL; bool preconditionsSatisfied = true; const char *errorHint = NULL; const char *joinHint = "Consider joining tables on partition column and have " "equal filter on joining columns."; const char *filterHint = "Consider using an equality filter on the distributed " "table's partition column."; /* * There could be Sublinks in the target list as well. To produce better * error messages we're checking if that's the case. */ if (queryTree->hasSubLinks && TargetListContainsSubquery(queryTree)) { preconditionsSatisfied = false; errorMessage = "could not run distributed query with subquery outside the " "FROM, WHERE and HAVING clauses"; errorHint = filterHint; } if (queryTree->setOperations) { preconditionsSatisfied = false; errorMessage = "could not run distributed query with UNION, INTERSECT, or " "EXCEPT"; errorHint = filterHint; } if (queryTree->hasRecursive) { preconditionsSatisfied = false; errorMessage = "could not run distributed query with RECURSIVE"; errorHint = filterHint; } if (queryTree->cteList) { preconditionsSatisfied = false; errorMessage = "could not run distributed query with common table expressions"; errorHint = filterHint; } if (queryTree->hasForUpdate) { preconditionsSatisfied = false; errorMessage = "could not run distributed query with FOR UPDATE/SHARE commands"; errorHint = filterHint; } if (queryTree->groupingSets) { preconditionsSatisfied = false; errorMessage = "could not run distributed query with GROUPING SETS, CUBE, " "or ROLLUP"; errorHint = filterHint; } if (FindNodeMatchingCheckFunction((Node *) queryTree, IsGroupingFunc)) { preconditionsSatisfied = false; errorMessage = "could not run distributed query with GROUPING"; errorHint = filterHint; } bool hasTablesample = HasTablesample(queryTree); if (hasTablesample) { preconditionsSatisfied = false; errorMessage = "could not run distributed query which use TABLESAMPLE"; errorHint = filterHint; } bool hasUnsupportedJoin = HasUnsupportedJoinWalker((Node *) queryTree->jointree, NULL); if (hasUnsupportedJoin) { preconditionsSatisfied = false; errorMessage = "could not run distributed query with join types other than " "INNER or OUTER JOINS"; errorHint = joinHint; } bool hasComplexRangeTableType = HasComplexRangeTableType(queryTree); if (hasComplexRangeTableType) { preconditionsSatisfied = false; errorMessage = "could not run distributed query with complex table expressions"; errorHint = filterHint; } if (FindNodeMatchingCheckFunction((Node *) queryTree->limitCount, IsNodeSubquery)) { preconditionsSatisfied = false; errorMessage = "subquery in LIMIT is not supported in multi-shard queries"; } if (FindNodeMatchingCheckFunction((Node *) queryTree->limitOffset, IsNodeSubquery)) { preconditionsSatisfied = false; errorMessage = "subquery in OFFSET is not supported in multi-shard queries"; } RTEListProperties *queryRteListProperties = GetRTEListPropertiesForQuery(queryTree); if (queryRteListProperties->hasCitusLocalTable || queryRteListProperties->hasPostgresLocalTable) { preconditionsSatisfied = false; errorMessage = "direct joins between distributed and local tables are " "not supported"; errorHint = LOCAL_TABLE_SUBQUERY_CTE_HINT; } /* finally check and error out if not satisfied */ if (!preconditionsSatisfied) { bool showHint = ErrorHintRequired(errorHint, queryTree); return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED, errorMessage, NULL, showHint ? errorHint : NULL); } return NULL; } /* HasTablesample returns tree if the query contains tablesample */ static bool HasTablesample(Query *queryTree) { List *rangeTableList = queryTree->rtable; ListCell *rangeTableEntryCell = NULL; bool hasTablesample = false; foreach(rangeTableEntryCell, rangeTableList) { RangeTblEntry *rangeTableEntry = lfirst(rangeTableEntryCell); if (rangeTableEntry->tablesample) { hasTablesample = true; break; } } return hasTablesample; } /* * HasUnsupportedJoinWalker returns tree if the query contains an unsupported * join type. We currently support inner, left, right, full and anti joins. * Semi joins are not supported. A full description of these join types is * included in nodes/nodes.h. */ static bool HasUnsupportedJoinWalker(Node *node, void *context) { bool hasUnsupportedJoin = false; if (node == NULL) { return false; } if (IsA(node, JoinExpr)) { JoinExpr *joinExpr = (JoinExpr *) node; JoinType joinType = joinExpr->jointype; bool outerJoin = IS_OUTER_JOIN(joinType); if (!outerJoin && joinType != JOIN_INNER && joinType != JOIN_SEMI) { hasUnsupportedJoin = true; } } if (!hasUnsupportedJoin) { hasUnsupportedJoin = expression_tree_walker(node, HasUnsupportedJoinWalker, NULL); } return hasUnsupportedJoin; } /* * ErrorHintRequired returns true if error hint shold be displayed with the * query error message. Error hint is valid only for queries involving reference * and hash partitioned tables. If more than one hash distributed table is * present we display the hint only if the tables are colocated. If the query * only has reference table(s), then it is handled by router planner. */ static bool ErrorHintRequired(const char *errorHint, Query *queryTree) { List *distributedRelationIdList = DistributedRelationIdList(queryTree); ListCell *relationIdCell = NULL; List *colocationIdList = NIL; if (errorHint == NULL) { return false; } foreach(relationIdCell, distributedRelationIdList) { Oid relationId = lfirst_oid(relationIdCell); if (IsCitusTableType(relationId, REFERENCE_TABLE)) { continue; } else if (IsCitusTableType(relationId, HASH_DISTRIBUTED)) { int colocationId = TableColocationId(relationId); colocationIdList = list_append_unique_int(colocationIdList, colocationId); } else { return false; } } /* do not display the hint if there are more than one colocation group */ if (list_length(colocationIdList) > 1) { return false; } return true; } /* * DeferErrorIfUnsupportedSubqueryRepartition checks that we can perform distributed planning for * the given subquery. If not, a deferred error is returned. The function recursively * does this check to all lower levels of the subquery. */ DeferredErrorMessage * DeferErrorIfUnsupportedSubqueryRepartition(Query *subqueryTree) { char *errorDetail = NULL; bool preconditionsSatisfied = true; List *joinTreeTableIndexList = NIL; if (!subqueryTree->hasAggs) { preconditionsSatisfied = false; errorDetail = "Subqueries without aggregates are not supported yet"; } if (subqueryTree->groupClause == NIL) { preconditionsSatisfied = false; errorDetail = "Subqueries without group by clause are not supported yet"; } if (subqueryTree->sortClause != NULL) { preconditionsSatisfied = false; errorDetail = "Subqueries with order by clause are not supported yet"; } if (subqueryTree->limitCount != NULL) { preconditionsSatisfied = false; errorDetail = "Subqueries with limit are not supported yet"; } if (subqueryTree->limitOffset != NULL) { preconditionsSatisfied = false; errorDetail = "Subqueries with offset are not supported yet"; } if (subqueryTree->hasSubLinks) { preconditionsSatisfied = false; errorDetail = "Subqueries other than from-clause subqueries are unsupported"; } /* finally check and return error if conditions are not satisfied */ if (!preconditionsSatisfied) { return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED, "cannot perform distributed planning on this query", errorDetail, NULL); } /* * Extract all range table indexes from the join tree. Note that sub-queries * that get pulled up by PostgreSQL don't appear in this join tree. */ ExtractRangeTableIndexWalker((Node *) subqueryTree->jointree, &joinTreeTableIndexList); Assert(list_length(joinTreeTableIndexList) == 1); /* continue with the inner subquery */ int rangeTableIndex = linitial_int(joinTreeTableIndexList); RangeTblEntry *rangeTableEntry = rt_fetch(rangeTableIndex, subqueryTree->rtable); if (rangeTableEntry->rtekind == RTE_RELATION) { return NULL; } Assert(rangeTableEntry->rtekind == RTE_SUBQUERY); Query *innerSubquery = rangeTableEntry->subquery; /* recursively continue to the inner subqueries */ return DeferErrorIfUnsupportedSubqueryRepartition(innerSubquery); } /* * HasComplexRangeTableType checks if the given query tree contains any complex * range table types. For this, the function walks over all range tables in the * join tree, and checks if they correspond to simple relations or subqueries. * If they don't, the function assumes the query has complex range tables. */ static bool HasComplexRangeTableType(Query *queryTree) { List *rangeTableList = queryTree->rtable; List *joinTreeTableIndexList = NIL; ListCell *joinTreeTableIndexCell = NULL; bool hasComplexRangeTableType = false; /* * Extract all range table indexes from the join tree. Note that sub-queries * that get pulled up by PostgreSQL don't appear in this join tree. */ ExtractRangeTableIndexWalker((Node *) queryTree->jointree, &joinTreeTableIndexList); foreach(joinTreeTableIndexCell, joinTreeTableIndexList) { /* * Join tree's range table index starts from 1 in the query tree. But, * list indexes start from 0. */ int joinTreeTableIndex = lfirst_int(joinTreeTableIndexCell); int rangeTableListIndex = joinTreeTableIndex - 1; RangeTblEntry *rangeTableEntry = (RangeTblEntry *) list_nth(rangeTableList, rangeTableListIndex); /* * Check if the range table in the join tree is a simple relation or a * subquery or a function. Note that RTE_FUNCTIONs are handled via (sub)query * pushdown. */ if (rangeTableEntry->rtekind != RTE_RELATION && rangeTableEntry->rtekind != RTE_SUBQUERY && rangeTableEntry->rtekind != RTE_FUNCTION) { hasComplexRangeTableType = true; } /* * Check if the subquery range table entry includes children inheritance. * * Note that PostgreSQL flattens out simple union all queries into an * append relation, sets "inh" field of RangeTblEntry to true and deletes * set operations. Here we check this for subqueries. */ if (rangeTableEntry->rtekind == RTE_SUBQUERY && rangeTableEntry->inh) { hasComplexRangeTableType = true; } } return hasComplexRangeTableType; } /* * WhereClauseList walks over the FROM expression in the query tree, and builds * a list of all clauses from the expression tree. The function checks for both * implicitly and explicitly defined clauses, but only selects INNER join * explicit clauses, and skips any outer-join clauses. Explicit clauses are * expressed as "SELECT ... FROM R1 INNER JOIN R2 ON R1.A = R2.A". Implicit * joins differ in that they live in the WHERE clause, and are expressed as * "SELECT ... FROM ... WHERE R1.a = R2.a". */ List * WhereClauseList(FromExpr *fromExpr) { FromExpr *fromExprCopy = copyObject(fromExpr); QualifierWalkerContext *walkerContext = palloc0(sizeof(QualifierWalkerContext)); ExtractFromExpressionWalker((Node *) fromExprCopy, walkerContext); List *whereClauseList = walkerContext->baseQualifierList; return whereClauseList; } /* * QualifierList walks over the FROM expression in the query tree, and builds * a list of all qualifiers from the expression tree. The function checks for * both implicitly and explicitly defined qualifiers. Note that this function * is very similar to WhereClauseList(), but QualifierList() also includes * outer-join clauses. */ List * QualifierList(FromExpr *fromExpr) { FromExpr *fromExprCopy = copyObject(fromExpr); QualifierWalkerContext *walkerContext = palloc0(sizeof(QualifierWalkerContext)); List *qualifierList = NIL; ExtractFromExpressionWalker((Node *) fromExprCopy, walkerContext); qualifierList = list_concat(qualifierList, walkerContext->baseQualifierList); qualifierList = list_concat(qualifierList, walkerContext->outerJoinQualifierList); return qualifierList; } /* * DeferErrorIfUnsupportedClause walks over the given list of clauses, and * checks that we can recognize all the clauses. This function ensures that * we do not drop an unsupported clause type on the floor, and thus prevents * erroneous results. * * Returns a deferred error, caller is responsible for raising the error. */ DeferredErrorMessage * DeferErrorIfUnsupportedClause(List *clauseList) { ListCell *clauseCell = NULL; foreach(clauseCell, clauseList) { Node *clause = (Node *) lfirst(clauseCell); if (!(IsSelectClause(clause) || IsJoinClause(clause) || or_clause(clause))) { return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED, "unsupported clause type", NULL, NULL); } } return NULL; } /* * JoinClauseList finds the join clauses from the given where clause expression * list, and returns them. The function does not iterate into nested OR clauses * and relies on find_duplicate_ors() in the optimizer to pull up factorizable * OR clauses. */ List * JoinClauseList(List *whereClauseList) { List *joinClauseList = NIL; ListCell *whereClauseCell = NULL; foreach(whereClauseCell, whereClauseList) { Node *whereClause = (Node *) lfirst(whereClauseCell); if (IsJoinClause(whereClause)) { joinClauseList = lappend(joinClauseList, whereClause); } } return joinClauseList; } /* * ExtractFromExpressionWalker walks over a FROM expression, and finds all * implicit and explicit qualifiers in the expression. The function looks at * join and from expression nodes to find qualifiers, and returns these * qualifiers. * * Note that we don't want outer join clauses in regular outer join planning, * but we need outer join clauses in subquery pushdown prerequisite checks. * Therefore, outer join qualifiers are returned in a different list than other * qualifiers inside the given walker context. For this reason, we return two * qualifier lists. * * Note that we check if the qualifier node in join and from expression nodes * is a list node. If it is not a list node which is the case for subqueries, * then we run eval_const_expressions(), canonicalize_qual() and make_ands_implicit() * on the qualifier node and get a list of flattened implicitly AND'ed qualifier * list. Actually in the planer phase of PostgreSQL these functions also run on * subqueries but differently from the outermost query, they are run on a copy * of parse tree and changes do not get persisted as modifications to the original * query tree. * * Also this function adds SubLinks to the baseQualifierList when they appear on * the query's WHERE clause. The callers of the function should consider processing * Sublinks as well. */ static bool ExtractFromExpressionWalker(Node *node, QualifierWalkerContext *walkerContext) { if (node == NULL) { return false; } /* * Get qualifier lists of join and from expression nodes. Note that in the * case of subqueries, PostgreSQL can skip simplifying, flattening and * making ANDs implicit. If qualifiers node is not a list, then we run these * preprocess routines on qualifiers node. */ if (IsA(node, JoinExpr)) { List *joinQualifierList = NIL; JoinExpr *joinExpression = (JoinExpr *) node; Node *joinQualifiersNode = joinExpression->quals; JoinType joinType = joinExpression->jointype; if (joinQualifiersNode != NULL) { if (IsA(joinQualifiersNode, List)) { joinQualifierList = (List *) joinQualifiersNode; } else { /* this part of code only run for subqueries */ Node *joinClause = eval_const_expressions(NULL, joinQualifiersNode); joinClause = (Node *) canonicalize_qual((Expr *) joinClause, false); joinQualifierList = make_ands_implicit((Expr *) joinClause); } } /* return outer join clauses in a separate list */ if (joinType == JOIN_INNER || joinType == JOIN_SEMI) { walkerContext->baseQualifierList = list_concat(walkerContext->baseQualifierList, joinQualifierList); } else if (IS_OUTER_JOIN(joinType)) { walkerContext->outerJoinQualifierList = list_concat(walkerContext->outerJoinQualifierList, joinQualifierList); } } else if (IsA(node, FromExpr)) { List *fromQualifierList = NIL; FromExpr *fromExpression = (FromExpr *) node; Node *fromQualifiersNode = fromExpression->quals; if (fromQualifiersNode != NULL) { if (IsA(fromQualifiersNode, List)) { fromQualifierList = (List *) fromQualifiersNode; } else { /* this part of code only run for subqueries */ Node *fromClause = eval_const_expressions(NULL, fromQualifiersNode); fromClause = (Node *) canonicalize_qual((Expr *) fromClause, false); fromQualifierList = make_ands_implicit((Expr *) fromClause); } walkerContext->baseQualifierList = list_concat(walkerContext->baseQualifierList, fromQualifierList); } } bool walkerResult = expression_tree_walker(node, ExtractFromExpressionWalker, (void *) walkerContext); return walkerResult; } /* * IsJoinClause determines if the given node is a join clause according to our * criteria. Our criteria defines a join clause as an equi join operator between * two columns that belong to two different tables. */ bool IsJoinClause(Node *clause) { Var *var = NULL; /* * take all column references from the clause, if we find 2 column references from a * different relation we assume this is a join clause */ List *varList = pull_var_clause_default(clause); if (list_length(varList) <= 0) { /* no column references in query, not describing a join */ return false; } Var *initialVar = castNode(Var, linitial(varList)); foreach_ptr(var, varList) { if (var->varno != initialVar->varno) { /* * this column reference comes from a different relation, hence describing a * join */ return true; } } /* all column references were to the same relation, no join */ return false; } /* * TableEntryList finds the regular relation nodes in the range table entry * list, and builds a list of table entries from these regular relation nodes. */ List * TableEntryList(List *rangeTableList) { List *tableEntryList = NIL; ListCell *rangeTableCell = NULL; uint32 tableId = 1; /* range table indices start at 1 */ foreach(rangeTableCell, rangeTableList) { RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell); if (rangeTableEntry->rtekind == RTE_RELATION) { TableEntry *tableEntry = (TableEntry *) palloc0(sizeof(TableEntry)); tableEntry->relationId = rangeTableEntry->relid; tableEntry->rangeTableId = tableId; tableEntryList = lappend(tableEntryList, tableEntry); } /* * Increment tableId regardless so that table entry's tableId remains * congruent with column's range table reference (varno). */ tableId++; } return tableEntryList; } /* * UsedTableEntryList returns list of relation range table entries * that are referenced within the query. Unused entries due to query * flattening or re-rewriting are ignored. */ List * UsedTableEntryList(Query *query) { List *tableEntryList = NIL; List *rangeTableList = query->rtable; List *joinTreeTableIndexList = NIL; ListCell *joinTreeTableIndexCell = NULL; ExtractRangeTableIndexWalker((Node *) query->jointree, &joinTreeTableIndexList); foreach(joinTreeTableIndexCell, joinTreeTableIndexList) { int joinTreeTableIndex = lfirst_int(joinTreeTableIndexCell); RangeTblEntry *rangeTableEntry = rt_fetch(joinTreeTableIndex, rangeTableList); if (rangeTableEntry->rtekind == RTE_RELATION) { TableEntry *tableEntry = (TableEntry *) palloc0(sizeof(TableEntry)); tableEntry->relationId = rangeTableEntry->relid; tableEntry->rangeTableId = joinTreeTableIndex; tableEntryList = lappend(tableEntryList, tableEntry); } } return tableEntryList; } /* * MultiTableNodeList builds a list of MultiTable nodes from the given table * entry list. A multi table node represents one entry from the range table * list. These entries may belong to the same physical relation in the case of * self-joins. */ static List * MultiTableNodeList(List *tableEntryList, List *rangeTableList) { List *tableNodeList = NIL; ListCell *tableEntryCell = NULL; foreach(tableEntryCell, tableEntryList) { TableEntry *tableEntry = (TableEntry *) lfirst(tableEntryCell); Oid relationId = tableEntry->relationId; uint32 rangeTableId = tableEntry->rangeTableId; Var *partitionColumn = PartitionColumn(relationId, rangeTableId); RangeTblEntry *rangeTableEntry = rt_fetch(rangeTableId, rangeTableList); MultiTable *tableNode = CitusMakeNode(MultiTable); tableNode->subquery = NULL; tableNode->relationId = relationId; tableNode->rangeTableId = rangeTableId; tableNode->partitionColumn = partitionColumn; tableNode->alias = rangeTableEntry->alias; tableNode->referenceNames = rangeTableEntry->eref; tableNodeList = lappend(tableNodeList, tableNode); } return tableNodeList; } /* Adds a MultiCollect node on top of each MultiTable node in the given list. */ static List * AddMultiCollectNodes(List *tableNodeList) { List *collectTableList = NIL; ListCell *tableNodeCell = NULL; foreach(tableNodeCell, tableNodeList) { MultiTable *tableNode = (MultiTable *) lfirst(tableNodeCell); MultiCollect *collectNode = CitusMakeNode(MultiCollect); SetChild((MultiUnaryNode *) collectNode, (MultiNode *) tableNode); collectTableList = lappend(collectTableList, collectNode); } return collectTableList; } /* * MultiJoinTree takes in the join order information and the list of tables, and * builds a join tree by applying the corresponding join rules. The function * builds a left deep tree, as expressed by the join order list. * * The function starts by setting the first table as the top node in the join * tree. Then, the function iterates over the list of tables, and builds a new * join node between the top of the join tree and the next table in the list. * At each iteration, the function sets the top of the join tree to the newly * built list. This results in a left deep join tree, and the function returns * this tree after every table in the list has been joined. */ static MultiNode * MultiJoinTree(List *joinOrderList, List *collectTableList, List *joinWhereClauseList) { MultiNode *currentTopNode = NULL; ListCell *joinOrderCell = NULL; bool firstJoinNode = true; foreach(joinOrderCell, joinOrderList) { JoinOrderNode *joinOrderNode = (JoinOrderNode *) lfirst(joinOrderCell); uint32 joinTableId = joinOrderNode->tableEntry->rangeTableId; MultiCollect *collectNode = CollectNodeForTable(collectTableList, joinTableId); if (firstJoinNode) { currentTopNode = (MultiNode *) collectNode; firstJoinNode = false; } else { JoinRuleType joinRuleType = joinOrderNode->joinRuleType; JoinType joinType = joinOrderNode->joinType; List *partitionColumnList = joinOrderNode->partitionColumnList; List *joinClauseList = joinOrderNode->joinClauseList; /* * Build a join node between the top of our join tree and the next * table in the join order. */ MultiNode *newJoinNode = ApplyJoinRule(currentTopNode, (MultiNode *) collectNode, joinRuleType, partitionColumnList, joinType, joinClauseList); /* the new join node becomes the top of our join tree */ currentTopNode = newJoinNode; } } /* current top node points to the entire left deep join tree */ return currentTopNode; } /* * CollectNodeForTable finds the MultiCollect node whose MultiTable node has the * given range table identifier. Note that this function expects each collect * node in the given list to have one table node as its child. */ static MultiCollect * CollectNodeForTable(List *collectTableList, uint32 rangeTableId) { MultiCollect *collectNodeForTable = NULL; ListCell *collectTableCell = NULL; foreach(collectTableCell, collectTableList) { MultiCollect *collectNode = (MultiCollect *) lfirst(collectTableCell); List *tableIdList = OutputTableIdList((MultiNode *) collectNode); uint32 tableId = (uint32) linitial_int(tableIdList); Assert(list_length(tableIdList) == 1); if (tableId == rangeTableId) { collectNodeForTable = collectNode; break; } } Assert(collectNodeForTable != NULL); return collectNodeForTable; } /* * MultiSelectNode extracts the select clauses from the given where clause list, * and builds a MultiSelect node from these clauses. If the expression tree does * not have any select clauses, the function return null. */ static MultiSelect * MultiSelectNode(List *whereClauseList) { List *selectClauseList = NIL; MultiSelect *selectNode = NULL; ListCell *whereClauseCell = NULL; foreach(whereClauseCell, whereClauseList) { Node *whereClause = (Node *) lfirst(whereClauseCell); if (IsSelectClause(whereClause)) { selectClauseList = lappend(selectClauseList, whereClause); } } if (list_length(selectClauseList) > 0) { selectNode = CitusMakeNode(MultiSelect); selectNode->selectClauseList = selectClauseList; } return selectNode; } /* * IsSelectClause determines if the given node is a select clause according to * our criteria. Our criteria defines a select clause as an expression that has * zero or more columns belonging to only one table. The function assumes that * no sublinks exists in the clause. */ static bool IsSelectClause(Node *clause) { ListCell *columnCell = NULL; bool isSelectClause = true; /* extract columns from the clause */ List *columnList = pull_var_clause_default(clause); if (list_length(columnList) == 0) { return true; } /* get first column's tableId */ Var *firstColumn = (Var *) linitial(columnList); Index firstColumnTableId = firstColumn->varno; /* check if all columns are from the same table */ foreach(columnCell, columnList) { Var *column = (Var *) lfirst(columnCell); if (column->varno != firstColumnTableId) { isSelectClause = false; } } return isSelectClause; } /* * MultiProjectNode builds the project node using the target entry information * from the query tree. The project node only encapsulates projected columns, * and does not include aggregates, group clauses, or project expressions. */ MultiProject * MultiProjectNode(List *targetEntryList) { List *uniqueColumnList = NIL; ListCell *columnCell = NULL; /* extract the list of columns and remove any duplicates */ List *columnList = pull_var_clause_default((Node *) targetEntryList); foreach(columnCell, columnList) { Var *column = (Var *) lfirst(columnCell); uniqueColumnList = list_append_unique(uniqueColumnList, column); } /* create project node with list of columns to project */ MultiProject *projectNode = CitusMakeNode(MultiProject); projectNode->columnList = uniqueColumnList; return projectNode; } /* Builds the extended operator node using fields from the given query tree. */ MultiExtendedOp * MultiExtendedOpNode(Query *queryTree, Query *originalQuery) { MultiExtendedOp *extendedOpNode = CitusMakeNode(MultiExtendedOp); extendedOpNode->targetList = queryTree->targetList; extendedOpNode->groupClauseList = queryTree->groupClause; extendedOpNode->sortClauseList = queryTree->sortClause; extendedOpNode->limitCount = queryTree->limitCount; extendedOpNode->limitOffset = queryTree->limitOffset; #if PG_VERSION_NUM >= PG_VERSION_13 extendedOpNode->limitOption = queryTree->limitOption; #endif extendedOpNode->havingQual = queryTree->havingQual; extendedOpNode->distinctClause = queryTree->distinctClause; extendedOpNode->hasDistinctOn = queryTree->hasDistinctOn; extendedOpNode->hasWindowFuncs = queryTree->hasWindowFuncs; extendedOpNode->windowClause = queryTree->windowClause; extendedOpNode->onlyPushableWindowFunctions = !queryTree->hasWindowFuncs || SafeToPushdownWindowFunction(originalQuery, NULL); return extendedOpNode; } /* Helper function to return the parent node of the given node. */ MultiNode * ParentNode(MultiNode *multiNode) { MultiNode *parentNode = multiNode->parentNode; return parentNode; } /* Helper function to return the child of the given unary node. */ MultiNode * ChildNode(MultiUnaryNode *multiNode) { MultiNode *childNode = multiNode->childNode; return childNode; } /* Helper function to return the grand child of the given unary node. */ MultiNode * GrandChildNode(MultiUnaryNode *multiNode) { MultiNode *childNode = ChildNode(multiNode); MultiNode *grandChildNode = ChildNode((MultiUnaryNode *) childNode); return grandChildNode; } /* Sets the given child node as a child of the given unary parent node. */ void SetChild(MultiUnaryNode *parent, MultiNode *child) { parent->childNode = child; child->parentNode = (MultiNode *) parent; } /* Sets the given child node as a left child of the given parent node. */ void SetLeftChild(MultiBinaryNode *parent, MultiNode *leftChild) { parent->leftChildNode = leftChild; leftChild->parentNode = (MultiNode *) parent; } /* Sets the given child node as a right child of the given parent node. */ void SetRightChild(MultiBinaryNode *parent, MultiNode *rightChild) { parent->rightChildNode = rightChild; rightChild->parentNode = (MultiNode *) parent; } /* Returns true if the given node is a unary operator. */ bool UnaryOperator(MultiNode *node) { bool unaryOperator = false; if (CitusIsA(node, MultiTreeRoot) || CitusIsA(node, MultiTable) || CitusIsA(node, MultiCollect) || CitusIsA(node, MultiSelect) || CitusIsA(node, MultiProject) || CitusIsA(node, MultiPartition) || CitusIsA(node, MultiExtendedOp)) { unaryOperator = true; } return unaryOperator; } /* Returns true if the given node is a binary operator. */ bool BinaryOperator(MultiNode *node) { bool binaryOperator = false; if (CitusIsA(node, MultiJoin) || CitusIsA(node, MultiCartesianProduct)) { binaryOperator = true; } return binaryOperator; } /* * OutputTableIdList finds all table identifiers that are output by the given * multi node, and returns these identifiers in a new list. */ List * OutputTableIdList(MultiNode *multiNode) { List *tableIdList = NIL; List *tableNodeList = FindNodesOfType(multiNode, T_MultiTable); ListCell *tableNodeCell = NULL; foreach(tableNodeCell, tableNodeList) { MultiTable *tableNode = (MultiTable *) lfirst(tableNodeCell); int tableId = (int) tableNode->rangeTableId; if (tableId != SUBQUERY_RANGE_TABLE_ID) { tableIdList = lappend_int(tableIdList, tableId); } } return tableIdList; } /* * FindNodesOfType takes in a given logical plan tree, and recursively traverses * the tree in preorder. The function finds all nodes of requested type during * the traversal, and returns them in a list. */ List * FindNodesOfType(MultiNode *node, int type) { List *nodeList = NIL; /* terminal condition for recursion */ if (node == NULL) { return NIL; } /* current node has expected node type */ int nodeType = CitusNodeTag(node); if (nodeType == type) { nodeList = lappend(nodeList, node); } if (UnaryOperator(node)) { MultiNode *childNode = ((MultiUnaryNode *) node)->childNode; List *childNodeList = FindNodesOfType(childNode, type); nodeList = list_concat(nodeList, childNodeList); } else if (BinaryOperator(node)) { MultiNode *leftChildNode = ((MultiBinaryNode *) node)->leftChildNode; MultiNode *rightChildNode = ((MultiBinaryNode *) node)->rightChildNode; List *leftChildNodeList = FindNodesOfType(leftChildNode, type); List *rightChildNodeList = FindNodesOfType(rightChildNode, type); nodeList = list_concat(nodeList, leftChildNodeList); nodeList = list_concat(nodeList, rightChildNodeList); } return nodeList; } /* * pull_var_clause_default calls pull_var_clause with the most commonly used * arguments for distributed planning. */ List * pull_var_clause_default(Node *node) { /* * PVC_REJECT_PLACEHOLDERS is implicit if PVC_INCLUDE_PLACEHOLDERS * isn't specified. */ List *columnList = pull_var_clause(node, PVC_RECURSE_AGGREGATES | PVC_RECURSE_WINDOWFUNCS); return columnList; } /* * ApplyJoinRule finds the join rule application function that corresponds to * the given join rule, and calls this function to create a new join node that * joins the left and right nodes together. */ static MultiNode * ApplyJoinRule(MultiNode *leftNode, MultiNode *rightNode, JoinRuleType ruleType, List *partitionColumnList, JoinType joinType, List *joinClauseList) { List *leftTableIdList = OutputTableIdList(leftNode); List *rightTableIdList = OutputTableIdList(rightNode); int rightTableIdCount PG_USED_FOR_ASSERTS_ONLY = 0; rightTableIdCount = list_length(rightTableIdList); Assert(rightTableIdCount == 1); /* find applicable join clauses between the left and right data sources */ uint32 rightTableId = (uint32) linitial_int(rightTableIdList); List *applicableJoinClauses = ApplicableJoinClauses(leftTableIdList, rightTableId, joinClauseList); /* call the join rule application function to create the new join node */ RuleApplyFunction ruleApplyFunction = JoinRuleApplyFunction(ruleType); MultiNode *multiNode = (*ruleApplyFunction)(leftNode, rightNode, partitionColumnList, joinType, applicableJoinClauses); if (joinType != JOIN_INNER && CitusIsA(multiNode, MultiJoin)) { MultiJoin *joinNode = (MultiJoin *) multiNode; /* preserve non-join clauses for OUTER joins */ joinNode->joinClauseList = list_copy(joinClauseList); } return multiNode; } /* * JoinRuleApplyFunction returns a function pointer for the rule application * function; this rule application function corresponds to the given rule type. * This function also initializes the rule application function array in a * static code block, if the array has not been initialized. */ static RuleApplyFunction JoinRuleApplyFunction(JoinRuleType ruleType) { static bool ruleApplyFunctionInitialized = false; if (!ruleApplyFunctionInitialized) { RuleApplyFunctionArray[REFERENCE_JOIN] = &ApplyReferenceJoin; RuleApplyFunctionArray[LOCAL_PARTITION_JOIN] = &ApplyLocalJoin; RuleApplyFunctionArray[SINGLE_HASH_PARTITION_JOIN] = &ApplySingleHashPartitionJoin; RuleApplyFunctionArray[SINGLE_RANGE_PARTITION_JOIN] = &ApplySingleRangePartitionJoin; RuleApplyFunctionArray[DUAL_PARTITION_JOIN] = &ApplyDualPartitionJoin; RuleApplyFunctionArray[CARTESIAN_PRODUCT_REFERENCE_JOIN] = &ApplyCartesianProductReferenceJoin; RuleApplyFunctionArray[CARTESIAN_PRODUCT] = &ApplyCartesianProduct; ruleApplyFunctionInitialized = true; } RuleApplyFunction ruleApplyFunction = RuleApplyFunctionArray[ruleType]; Assert(ruleApplyFunction != NULL); return ruleApplyFunction; } /* * ApplyBroadcastJoin creates a new MultiJoin node that joins the left and the * right node. The new node uses the broadcast join rule to perform the join. */ static MultiNode * ApplyReferenceJoin(MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *applicableJoinClauses) { MultiJoin *joinNode = CitusMakeNode(MultiJoin); joinNode->joinRuleType = REFERENCE_JOIN; joinNode->joinType = joinType; joinNode->joinClauseList = applicableJoinClauses; SetLeftChild((MultiBinaryNode *) joinNode, leftNode); SetRightChild((MultiBinaryNode *) joinNode, rightNode); return (MultiNode *) joinNode; } /* * ApplyCartesianProductReferenceJoin creates a new MultiJoin node that joins * the left and the right node. The new node uses the broadcast join rule to * perform the join. */ static MultiNode * ApplyCartesianProductReferenceJoin(MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *applicableJoinClauses) { MultiJoin *joinNode = CitusMakeNode(MultiJoin); joinNode->joinRuleType = CARTESIAN_PRODUCT_REFERENCE_JOIN; joinNode->joinType = joinType; joinNode->joinClauseList = applicableJoinClauses; SetLeftChild((MultiBinaryNode *) joinNode, leftNode); SetRightChild((MultiBinaryNode *) joinNode, rightNode); return (MultiNode *) joinNode; } /* * ApplyLocalJoin creates a new MultiJoin node that joins the left and the right * node. The new node uses the local join rule to perform the join. */ static MultiNode * ApplyLocalJoin(MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *applicableJoinClauses) { MultiJoin *joinNode = CitusMakeNode(MultiJoin); joinNode->joinRuleType = LOCAL_PARTITION_JOIN; joinNode->joinType = joinType; joinNode->joinClauseList = applicableJoinClauses; SetLeftChild((MultiBinaryNode *) joinNode, leftNode); SetRightChild((MultiBinaryNode *) joinNode, rightNode); return (MultiNode *) joinNode; } /* * ApplySingleRangePartitionJoin is a wrapper around ApplySinglePartitionJoin() * which sets the joinRuleType properly. */ static MultiNode * ApplySingleRangePartitionJoin(MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *applicableJoinClauses) { MultiJoin *joinNode = ApplySinglePartitionJoin(leftNode, rightNode, partitionColumnList, joinType, applicableJoinClauses); joinNode->joinRuleType = SINGLE_RANGE_PARTITION_JOIN; return (MultiNode *) joinNode; } /* * ApplySingleHashPartitionJoin is a wrapper around ApplySinglePartitionJoin() * which sets the joinRuleType properly. */ static MultiNode * ApplySingleHashPartitionJoin(MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *applicableJoinClauses) { MultiJoin *joinNode = ApplySinglePartitionJoin(leftNode, rightNode, partitionColumnList, joinType, applicableJoinClauses); joinNode->joinRuleType = SINGLE_HASH_PARTITION_JOIN; return (MultiNode *) joinNode; } /* * ApplySinglePartitionJoin creates a new MultiJoin node that joins the left and * right node. The function also adds a MultiPartition node on top of the node * (left or right) that is not partitioned on the join column. */ static MultiJoin * ApplySinglePartitionJoin(MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *applicableJoinClauses) { Var *partitionColumn = linitial(partitionColumnList); uint32 partitionTableId = partitionColumn->varno; /* create all operator structures up front */ MultiJoin *joinNode = CitusMakeNode(MultiJoin); MultiCollect *collectNode = CitusMakeNode(MultiCollect); MultiPartition *partitionNode = CitusMakeNode(MultiPartition); /* * We first find the appropriate join clause. Then, we compare the partition * column against the join clause's columns. If one of the columns matches, * we introduce a (re-)partition operator for the other column. */ OpExpr *joinClause = SinglePartitionJoinClause(partitionColumnList, applicableJoinClauses); Assert(joinClause != NULL); /* both are verified in SinglePartitionJoinClause to not be NULL, assert is to guard */ Var *leftColumn = LeftColumnOrNULL(joinClause); Var *rightColumn = RightColumnOrNULL(joinClause); Assert(leftColumn != NULL); Assert(rightColumn != NULL); if (equal(partitionColumn, leftColumn)) { partitionNode->partitionColumn = rightColumn; partitionNode->splitPointTableId = partitionTableId; } else if (equal(partitionColumn, rightColumn)) { partitionNode->partitionColumn = leftColumn; partitionNode->splitPointTableId = partitionTableId; } /* determine the node the partition operator goes on top of */ List *rightTableIdList = OutputTableIdList(rightNode); uint32 rightTableId = (uint32) linitial_int(rightTableIdList); Assert(list_length(rightTableIdList) == 1); /* * If the right child node is partitioned on the partition key column, we * add the partition operator on the left child node; and vice versa. Then, * we add a collect operator on top of the partition operator, and always * make sure that we have at most one relation on the right-hand side. */ if (partitionTableId == rightTableId) { SetChild((MultiUnaryNode *) partitionNode, leftNode); SetChild((MultiUnaryNode *) collectNode, (MultiNode *) partitionNode); SetLeftChild((MultiBinaryNode *) joinNode, (MultiNode *) collectNode); SetRightChild((MultiBinaryNode *) joinNode, rightNode); } else { SetChild((MultiUnaryNode *) partitionNode, rightNode); SetChild((MultiUnaryNode *) collectNode, (MultiNode *) partitionNode); SetLeftChild((MultiBinaryNode *) joinNode, leftNode); SetRightChild((MultiBinaryNode *) joinNode, (MultiNode *) collectNode); } /* finally set join operator fields */ joinNode->joinType = joinType; joinNode->joinClauseList = applicableJoinClauses; return joinNode; } /* * ApplyDualPartitionJoin creates a new MultiJoin node that joins the left and * right node. The function also adds two MultiPartition operators on top of * both nodes to repartition these nodes' data on the join clause columns. */ static MultiNode * ApplyDualPartitionJoin(MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *applicableJoinClauses) { /* find the appropriate join clause */ OpExpr *joinClause = DualPartitionJoinClause(applicableJoinClauses); Assert(joinClause != NULL); /* both are verified in DualPartitionJoinClause to not be NULL, assert is to guard */ Var *leftColumn = LeftColumnOrNULL(joinClause); Var *rightColumn = RightColumnOrNULL(joinClause); Assert(leftColumn != NULL); Assert(rightColumn != NULL); List *rightTableIdList = OutputTableIdList(rightNode); uint32 rightTableId = (uint32) linitial_int(rightTableIdList); Assert(list_length(rightTableIdList) == 1); MultiPartition *leftPartitionNode = CitusMakeNode(MultiPartition); MultiPartition *rightPartitionNode = CitusMakeNode(MultiPartition); /* find the partition node each join clause column belongs to */ if (leftColumn->varno == rightTableId) { leftPartitionNode->partitionColumn = rightColumn; rightPartitionNode->partitionColumn = leftColumn; } else { leftPartitionNode->partitionColumn = leftColumn; rightPartitionNode->partitionColumn = rightColumn; } /* add partition operators on top of left and right nodes */ SetChild((MultiUnaryNode *) leftPartitionNode, leftNode); SetChild((MultiUnaryNode *) rightPartitionNode, rightNode); /* add collect operators on top of the two partition operators */ MultiCollect *leftCollectNode = CitusMakeNode(MultiCollect); MultiCollect *rightCollectNode = CitusMakeNode(MultiCollect); SetChild((MultiUnaryNode *) leftCollectNode, (MultiNode *) leftPartitionNode); SetChild((MultiUnaryNode *) rightCollectNode, (MultiNode *) rightPartitionNode); /* add join operator on top of the two collect operators */ MultiJoin *joinNode = CitusMakeNode(MultiJoin); joinNode->joinRuleType = DUAL_PARTITION_JOIN; joinNode->joinType = joinType; joinNode->joinClauseList = applicableJoinClauses; SetLeftChild((MultiBinaryNode *) joinNode, (MultiNode *) leftCollectNode); SetRightChild((MultiBinaryNode *) joinNode, (MultiNode *) rightCollectNode); return (MultiNode *) joinNode; } /* Creates a cartesian product node that joins the left and the right node. */ static MultiNode * ApplyCartesianProduct(MultiNode *leftNode, MultiNode *rightNode, List *partitionColumnList, JoinType joinType, List *applicableJoinClauses) { MultiCartesianProduct *cartesianNode = CitusMakeNode(MultiCartesianProduct); SetLeftChild((MultiBinaryNode *) cartesianNode, leftNode); SetRightChild((MultiBinaryNode *) cartesianNode, rightNode); return (MultiNode *) cartesianNode; } /* * OperatorImplementsEquality returns true if the given opno represents an * equality operator. The function retrieves btree interpretation list for this * opno and check if BTEqualStrategyNumber strategy is present. */ bool OperatorImplementsEquality(Oid opno) { bool equalityOperator = false; List *btreeIntepretationList = get_op_btree_interpretation(opno); ListCell *btreeInterpretationCell = NULL; foreach(btreeInterpretationCell, btreeIntepretationList) { OpBtreeInterpretation *btreeIntepretation = (OpBtreeInterpretation *) lfirst(btreeInterpretationCell); if (btreeIntepretation->strategy == BTEqualStrategyNumber) { equalityOperator = true; break; } } return equalityOperator; }