From 5f61585091868c78b61b27c95497ae02949c0845 Mon Sep 17 00:00:00 2001 From: Pedro Miguel Reis Bento Paredes Date: Sun, 25 Mar 2018 22:01:48 -0400 Subject: [PATCH 01/26] Initial new join operator for optimizer --- src/include/common/internal_types.h | 1 + src/include/optimizer/child_stats_deriver.h | 1 + src/include/optimizer/operator_node.h | 1 + src/include/optimizer/operator_visitor.h | 1 + src/include/optimizer/operators.h | 19 ++++++ src/include/optimizer/rule_impls.h | 15 +++++ src/include/optimizer/stats_calculator.h | 1 + src/optimizer/child_stats_deriver.cpp | 11 +++ src/optimizer/operators.cpp | 67 ++++++++++++++++++- .../query_to_operator_transformer.cpp | 2 +- src/optimizer/rule.cpp | 1 + src/optimizer/rule_impls.cpp | 67 ++++++++++++++++++- src/optimizer/stats_calculator.cpp | 59 ++++++++++++++++ 13 files changed, 241 insertions(+), 5 deletions(-) diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 17020512944..132d74aeb75 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1341,6 +1341,7 @@ enum class RuleType : uint32_t { INSERT_SELECT_TO_PHYSICAL, AGGREGATE_TO_HASH_AGGREGATE, AGGREGATE_TO_PLAIN_AGGREGATE, + JOIN_TO_NL_JOIN, INNER_JOIN_TO_NL_JOIN, INNER_JOIN_TO_HASH_JOIN, IMPLEMENT_DISTINCT, diff --git a/src/include/optimizer/child_stats_deriver.h b/src/include/optimizer/child_stats_deriver.h index d0c72f9bf9b..c3513faa832 100644 --- a/src/include/optimizer/child_stats_deriver.h +++ b/src/include/optimizer/child_stats_deriver.h @@ -32,6 +32,7 @@ class ChildStatsDeriver : public OperatorVisitor { ExprSet required_cols, Memo *memo); void Visit(const LogicalQueryDerivedGet *) override; + void Visit(const LogicalJoin *) override; void Visit(const LogicalInnerJoin *) override; void Visit(const LogicalLeftJoin *) override; void Visit(const LogicalRightJoin *) override; diff --git a/src/include/optimizer/operator_node.h b/src/include/optimizer/operator_node.h index cb20c163bbe..78c816750f2 100644 --- a/src/include/optimizer/operator_node.h +++ b/src/include/optimizer/operator_node.h @@ -33,6 +33,7 @@ enum class OpType { LogicalMarkJoin, LogicalDependentJoin, LogicalSingleJoin, + LogicalJoin, InnerJoin, LeftJoin, RightJoin, diff --git a/src/include/optimizer/operator_visitor.h b/src/include/optimizer/operator_visitor.h index 75b0a9f9c67..1bf56b0c285 100644 --- a/src/include/optimizer/operator_visitor.h +++ b/src/include/optimizer/operator_visitor.h @@ -58,6 +58,7 @@ class OperatorVisitor { virtual void Visit(const LogicalMarkJoin *) {} virtual void Visit(const LogicalSingleJoin *) {} virtual void Visit(const LogicalDependentJoin *) {} + virtual void Visit(const LogicalJoin *) {} virtual void Visit(const LogicalInnerJoin *) {} virtual void Visit(const LogicalLeftJoin *) {} virtual void Visit(const LogicalRightJoin *) {} diff --git a/src/include/optimizer/operators.h b/src/include/optimizer/operators.h index a745439251a..2f33ff675e4 100644 --- a/src/include/optimizer/operators.h +++ b/src/include/optimizer/operators.h @@ -162,6 +162,25 @@ class LogicalSingleJoin : public OperatorNode { std::vector join_predicates; }; +//===--------------------------------------------------------------------===// +// Join (Inner + Outer Joins) +//===--------------------------------------------------------------------===// +class LogicalJoin : public OperatorNode { + public: + enum class JoinType { Inner, FullOuter, LeftOuter, RightOuter }; + + static Operator make(JoinType _type); + + static Operator make(JoinType _type, std::vector &conditions); + + bool operator==(const BaseOperatorNode &r) override; + + hash_t Hash() const override; + + std::vector join_predicates; + JoinType type; +}; + //===--------------------------------------------------------------------===// // InnerJoin //===--------------------------------------------------------------------===// diff --git a/src/include/optimizer/rule_impls.h b/src/include/optimizer/rule_impls.h index 2c40e3f3c81..fa6634d55c3 100644 --- a/src/include/optimizer/rule_impls.h +++ b/src/include/optimizer/rule_impls.h @@ -209,6 +209,21 @@ class LogicalAggregateToPhysical : public Rule { OptimizeContext *context) const override; }; +/** + * @brief (Logical Join -> Nested-Loop Join) + */ +class JoinToNLJoin : public Rule { + public: + JoinToNLJoin(); + + bool Check(std::shared_ptr plan, + OptimizeContext *context) const override; + + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + /** * @brief (Logical Inner Join -> Inner Nested-Loop Join) */ diff --git a/src/include/optimizer/stats_calculator.h b/src/include/optimizer/stats_calculator.h index 5aed2902671..3e9d43c7eb0 100644 --- a/src/include/optimizer/stats_calculator.h +++ b/src/include/optimizer/stats_calculator.h @@ -31,6 +31,7 @@ class StatsCalculator : public OperatorVisitor { void Visit(const LogicalGet *) override; void Visit(const LogicalQueryDerivedGet *) override; + void Visit(const LogicalJoin *) override; void Visit(const LogicalInnerJoin *) override; void Visit(const LogicalLeftJoin *) override; void Visit(const LogicalRightJoin *) override; diff --git a/src/optimizer/child_stats_deriver.cpp b/src/optimizer/child_stats_deriver.cpp index 0833d55a0f0..d3730d4e870 100644 --- a/src/optimizer/child_stats_deriver.cpp +++ b/src/optimizer/child_stats_deriver.cpp @@ -33,6 +33,17 @@ vector ChildStatsDeriver::DeriveInputStats(GroupExpression *gexpr, // TODO(boweic): support stats derivation for derivedGet void ChildStatsDeriver::Visit(const LogicalQueryDerivedGet *) {} +void ChildStatsDeriver::Visit(const LogicalJoin *op) { + PassDownRequiredCols(); + for (auto &annotated_expr : op->join_predicates) { + auto predicate = annotated_expr.expr.get(); + ExprSet expr_set; + expression::ExpressionUtil::GetTupleValueExprs(expr_set, predicate); + for (auto &col : expr_set) { + PassDownColumn(col); + } + } +} void ChildStatsDeriver::Visit(const LogicalInnerJoin *op) { PassDownRequiredCols(); for (auto &annotated_expr : op->join_predicates) { diff --git a/src/optimizer/operators.cpp b/src/optimizer/operators.cpp index 78c34d16257..43b70c563c2 100644 --- a/src/optimizer/operators.cpp +++ b/src/optimizer/operators.cpp @@ -230,6 +230,67 @@ bool LogicalSingleJoin::operator==(const BaseOperatorNode &r) { return true; } +//===--------------------------------------------------------------------===// +// Join (Inner + Outer Joins) +//===--------------------------------------------------------------------===// +Operator LogicalJoin::make(JoinType _type) { + LogicalJoin *join = new LogicalJoin; + join->join_predicates = {}; + join->type = _type; + return Operator(join); +} + +Operator LogicalJoin::make(JoinType _type, std::vector &conditions) { + LogicalJoin *join = new LogicalJoin; + join->join_predicates = std::move(conditions); + join->type = _type; + return Operator(join); +} + +hash_t LogicalJoin::Hash() const { + hash_t hash = BaseOperatorNode::Hash(); + for (auto &pred : join_predicates) + hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); + return hash; +} + +bool LogicalJoin::operator==(const BaseOperatorNode &r) { + switch (r.GetType()) { + case OpType::InnerJoin: + if (type != JoinType::Inner) { + return false; + } + break; + case OpType::OuterJoin: + if (type != JoinType::FullOuter) { + return false; + } + break; + case OpType::LeftJoin: + if (type != JoinType::LeftOuter) { + return false; + } + break; + case OpType::RightJoin: + if (type != JoinType::RightOuter) { + return false; + } + break; + default: + return false; + break; + } + + const LogicalJoin &node = *static_cast(&r); + if (join_predicates.size() != node.join_predicates.size()) return false; + for (size_t i = 0; i < join_predicates.size(); i++) { + if (!join_predicates[i].expr->ExactlyEquals( + *node.join_predicates[i].expr.get())) + return false; + } + return true; +} + //===--------------------------------------------------------------------===// // InnerJoin //===--------------------------------------------------------------------===// @@ -295,7 +356,7 @@ Operator LogicalOuterJoin::make(expression::AbstractExpression *condition) { } //===--------------------------------------------------------------------===// -// OuterJoin +// SemiJoin //===--------------------------------------------------------------------===// Operator LogicalSemiJoin::make(expression::AbstractExpression *condition) { LogicalSemiJoin *join = new LogicalSemiJoin; @@ -859,6 +920,8 @@ std::string OperatorNode::name_ = "LogicalSingleJoin"; template <> std::string OperatorNode::name_ = "LogicalDependentJoin"; template <> +std::string OperatorNode::name_ = "LogicalJoin"; +template <> std::string OperatorNode::name_ = "LogicalInnerJoin"; template <> std::string OperatorNode::name_ = "LogicalLeftJoin"; @@ -950,6 +1013,8 @@ OpType OperatorNode::type_ = OpType::LogicalSingleJoin; template <> OpType OperatorNode::type_ = OpType::LogicalDependentJoin; template <> +OpType OperatorNode::type_ = OpType::LogicalJoin; +template <> OpType OperatorNode::type_ = OpType::InnerJoin; template <> OpType OperatorNode::type_ = OpType::LeftJoin; diff --git a/src/optimizer/query_to_operator_transformer.cpp b/src/optimizer/query_to_operator_transformer.cpp index ff75140d5f5..bd8ddda612c 100644 --- a/src/optimizer/query_to_operator_transformer.cpp +++ b/src/optimizer/query_to_operator_transformer.cpp @@ -134,7 +134,7 @@ void QueryToOperatorTransformer::Visit(parser::JoinDefinition *node) { case JoinType::INNER: { predicates_ = CollectPredicates(node->condition.get(), predicates_); join_expr = - std::make_shared(LogicalInnerJoin::make()); + std::make_shared(LogicalJoin::make(LogicalJoin::JoinType::Inner)); break; } case JoinType::OUTER: { diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index 1e81799147d..1a36ebc8358 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -40,6 +40,7 @@ RuleSet::RuleSet() { AddImplementationRule(new GetToSeqScan()); AddImplementationRule(new GetToIndexScan()); AddImplementationRule(new LogicalQueryDerivedGetToPhysical()); + AddImplementationRule(new JoinToNLJoin()); AddImplementationRule(new InnerJoinToInnerNLJoin()); AddImplementationRule(new InnerJoinToInnerHashJoin()); AddImplementationRule(new ImplementDistinct()); diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index e540555c9e3..bb85519aeed 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -611,6 +611,67 @@ void LogicalAggregateToPhysical::Transform( transformed.push_back(result); } +/////////////////////////////////////////////////////////////////////////////// +/// JoinToNLJoin +JoinToNLJoin::JoinToNLJoin() { + type_ = RuleType::JOIN_TO_NL_JOIN; + + // TODO NLJoin currently only support left deep tree + std::shared_ptr left_child(std::make_shared(OpType::Leaf)); + std::shared_ptr right_child(std::make_shared(OpType::Leaf)); + + // Initialize a pattern for optimizer to match + match_pattern = std::make_shared(OpType::LogicalJoin); + + // Add node - we match join relation R and S + match_pattern->AddChild(left_child); + match_pattern->AddChild(right_child); + + return; +} + +bool JoinToNLJoin::Check(std::shared_ptr plan, + OptimizeContext *context) const { + (void)context; + (void)plan; + return true; +} + +void JoinToNLJoin::Transform( + std::shared_ptr input, + std::vector> &transformed, + UNUSED_ATTRIBUTE OptimizeContext *context) const { + // first build an expression representing hash join + const LogicalJoin *inner_join = input->Op().As(); + + auto children = input->Children(); + PL_ASSERT(children.size() == 2); + auto left_group_id = children[0]->Op().As()->origin_group; + auto right_group_id = children[1]->Op().As()->origin_group; + auto &left_group_alias = + context->metadata->memo.GetGroupByID(left_group_id)->GetTableAliases(); + auto &right_group_alias = + context->metadata->memo.GetGroupByID(right_group_id)->GetTableAliases(); + std::vector> left_keys; + std::vector> right_keys; + + util::ExtractEquiJoinKeys(inner_join->join_predicates, left_keys, right_keys, + left_group_alias, right_group_alias); + + PL_ASSERT(right_keys.size() == left_keys.size()); + std::shared_ptr result_plan = + std::make_shared(PhysicalInnerNLJoin::make( + inner_join->join_predicates, left_keys, right_keys)); + + // Then push all children into the child list of the new operator + result_plan->PushChild(children[0]); + result_plan->PushChild(children[1]); + + transformed.push_back(result_plan); + + return; +} + /////////////////////////////////////////////////////////////////////////////// /// InnerJoinToInnerNLJoin InnerJoinToInnerNLJoin::InnerJoinToInnerNLJoin() { @@ -807,7 +868,7 @@ PushFilterThroughJoin::PushFilterThroughJoin() { type_ = RuleType::PUSH_FILTER_THROUGH_JOIN; // Make three node types for pattern matching - std::shared_ptr child(std::make_shared(OpType::InnerJoin)); + std::shared_ptr child(std::make_shared(OpType::LogicalJoin)); child->AddChild(std::make_shared(OpType::Leaf)); child->AddChild(std::make_shared(OpType::Leaf)); @@ -862,12 +923,12 @@ void PushFilterThroughJoin::Transform( // Construct join operator auto pre_join_predicate = - join_op_expr->Op().As()->join_predicates; + join_op_expr->Op().As()->join_predicates; join_predicates.insert(join_predicates.end(), pre_join_predicate.begin(), pre_join_predicate.end()); std::shared_ptr output = std::make_shared( - LogicalInnerJoin::make(join_predicates)); + LogicalJoin::make(join_op_expr->Op().As()->type, join_predicates)); // Construct left filter if any if (!left_predicates.empty()) { diff --git a/src/optimizer/stats_calculator.cpp b/src/optimizer/stats_calculator.cpp index 3cdb34c4d9d..59beb15800f 100644 --- a/src/optimizer/stats_calculator.cpp +++ b/src/optimizer/stats_calculator.cpp @@ -96,6 +96,65 @@ void StatsCalculator::Visit(const LogicalQueryDerivedGet *) { } } +void StatsCalculator::Visit(const LogicalJoin *op) { + // Check if there's join condition + PL_ASSERT(gexpr_->GetChildrenGroupsSize() == 2); + auto left_child_group = memo_->GetGroupByID(gexpr_->GetChildGroupId(0)); + auto right_child_group = memo_->GetGroupByID(gexpr_->GetChildGroupId(1)); + auto root_group = memo_->GetGroupByID(gexpr_->GetGroupID()); + // Calculate output num rows first + if (root_group->GetNumRows() == -1) { + size_t curr_rows = + left_child_group->GetNumRows() * right_child_group->GetNumRows(); + for (auto &annotated_expr : op->join_predicates) { + // See if there are join conditions + if (annotated_expr.expr->GetExpressionType() == + ExpressionType::COMPARE_EQUAL && + annotated_expr.expr->GetChild(0)->GetExpressionType() == + ExpressionType::VALUE_TUPLE && + annotated_expr.expr->GetChild(1)->GetExpressionType() == + ExpressionType::VALUE_TUPLE) { + auto left_child = + reinterpret_cast( + annotated_expr.expr->GetChild(0)); + auto right_child = + reinterpret_cast( + annotated_expr.expr->GetChild(1)); + if ((left_child_group->HasColumnStats(left_child->GetColFullName()) && + right_child_group->HasColumnStats( + right_child->GetColFullName())) || + (left_child_group->HasColumnStats(right_child->GetColFullName()) && + right_child_group->HasColumnStats(left_child->GetColFullName()))) { + curr_rows /= std::max(std::max(left_child_group->GetNumRows(), + right_child_group->GetNumRows()), + 1); + } + } + } + root_group->SetNumRows(curr_rows); + } + size_t num_rows = root_group->GetNumRows(); + for (auto &col : required_cols_) { + PL_ASSERT(col->GetExpressionType() == ExpressionType::VALUE_TUPLE); + auto tv_expr = reinterpret_cast(col); + std::shared_ptr column_stats; + // Make a copy from the child stats + if (left_child_group->HasColumnStats(tv_expr->GetColFullName())) { + column_stats = std::make_shared( + *left_child_group->GetStats(tv_expr->GetColFullName())); + } else { + PL_ASSERT(right_child_group->HasColumnStats(tv_expr->GetColFullName())); + column_stats = std::make_shared( + *right_child_group->GetStats(tv_expr->GetColFullName())); + } + // Reset num_rows + column_stats->num_rows = num_rows; + root_group->AddStats(tv_expr->GetColFullName(), column_stats); + } + // TODO(boweic): calculate stats based on predicates other than join + // conditions +} + void StatsCalculator::Visit(const LogicalInnerJoin *op) { // Check if there's join condition PELOTON_ASSERT(gexpr_->GetChildrenGroupsSize() == 2); From 2c6e6c27f34467a08ac55d641086d48efb396eca Mon Sep 17 00:00:00 2001 From: Pedro Miguel Reis Bento Paredes Date: Tue, 27 Mar 2018 16:48:39 -0400 Subject: [PATCH 02/26] Added full outer join and refactored physical join operators --- src/codegen/query_compiler.cpp | 5 +- src/executor/nested_loop_join_executor.cpp | 2 +- .../optimizer/child_property_deriver.h | 2 + src/include/optimizer/cost_calculator.h | 2 + src/include/optimizer/input_column_deriver.h | 4 + src/include/optimizer/operator_node.h | 2 + src/include/optimizer/operator_visitor.h | 2 + src/include/optimizer/operators.h | 46 +++++- src/include/optimizer/plan_generator.h | 4 + src/include/optimizer/rule_impls.h | 15 ++ src/optimizer/child_property_deriver.cpp | 6 + src/optimizer/cost_calculator.cpp | 16 +++ src/optimizer/input_column_deriver.cpp | 16 ++- src/optimizer/operators.cpp | 136 ++++++++++++++---- src/optimizer/plan_generator.cpp | 82 +++++++++++ .../query_to_operator_transformer.cpp | 7 +- src/optimizer/rule.cpp | 1 + src/optimizer/rule_impls.cpp | 72 +++++++++- 18 files changed, 378 insertions(+), 42 deletions(-) diff --git a/src/codegen/query_compiler.cpp b/src/codegen/query_compiler.cpp index d4698d7007a..734e3039522 100644 --- a/src/codegen/query_compiler.cpp +++ b/src/codegen/query_compiler.cpp @@ -73,9 +73,10 @@ bool QueryCompiler::IsSupported(const planner::AbstractPlan &plan) { case PlanNodeType::HASHJOIN: { const auto &join = static_cast(plan); // Right now, only support inner joins - if (join.GetJoinType() == JoinType::INNER) { - break; + if (join.GetJoinType() != JoinType::INNER) { + return false; } + break; } case PlanNodeType::HASH: { break; diff --git a/src/executor/nested_loop_join_executor.cpp b/src/executor/nested_loop_join_executor.cpp index 6f1bca1eb36..4bec12cdc10 100644 --- a/src/executor/nested_loop_join_executor.cpp +++ b/src/executor/nested_loop_join_executor.cpp @@ -101,7 +101,7 @@ bool NestedLoopJoinExecutor::DExecute() { // If we have already retrieved all left child's results in buffer if (left_child_done_ == true) { LOG_TRACE("Left is done which means all join comparison completes"); - return false; + return BuildOuterJoinOutput(); } // If left tile result is not done, continue the left tuples diff --git a/src/include/optimizer/child_property_deriver.h b/src/include/optimizer/child_property_deriver.h index bd4aeb7b933..152ef60207e 100644 --- a/src/include/optimizer/child_property_deriver.h +++ b/src/include/optimizer/child_property_deriver.h @@ -42,6 +42,8 @@ class ChildPropertyDeriver : public OperatorVisitor { void Visit(const QueryDerivedScan *op) override; void Visit(const PhysicalOrderBy *) override; void Visit(const PhysicalLimit *) override; + void Visit(const PhysicalNLJoin *) override; + void Visit(const PhysicalHashJoin *) override; void Visit(const PhysicalInnerNLJoin *) override; void Visit(const PhysicalLeftNLJoin *) override; void Visit(const PhysicalRightNLJoin *) override; diff --git a/src/include/optimizer/cost_calculator.h b/src/include/optimizer/cost_calculator.h index 442f386fc5f..c3c0a31e6e0 100644 --- a/src/include/optimizer/cost_calculator.h +++ b/src/include/optimizer/cost_calculator.h @@ -30,6 +30,8 @@ class CostCalculator : public OperatorVisitor { void Visit(const QueryDerivedScan *) override; void Visit(const PhysicalOrderBy *) override; void Visit(const PhysicalLimit *) override; + void Visit(const PhysicalNLJoin *) override; + void Visit(const PhysicalHashJoin *) override; void Visit(const PhysicalInnerNLJoin *) override; void Visit(const PhysicalLeftNLJoin *) override; void Visit(const PhysicalRightNLJoin *) override; diff --git a/src/include/optimizer/input_column_deriver.h b/src/include/optimizer/input_column_deriver.h index fa1ec6ca5a1..d105aa06561 100644 --- a/src/include/optimizer/input_column_deriver.h +++ b/src/include/optimizer/input_column_deriver.h @@ -59,6 +59,10 @@ class InputColumnDeriver : public OperatorVisitor { void Visit(const PhysicalLimit *) override; + void Visit(const PhysicalNLJoin *) override; + + void Visit(const PhysicalHashJoin *) override; + void Visit(const PhysicalInnerNLJoin *) override; void Visit(const PhysicalLeftNLJoin *) override; diff --git a/src/include/optimizer/operator_node.h b/src/include/optimizer/operator_node.h index 78c816750f2..a9ff68dca23 100644 --- a/src/include/optimizer/operator_node.h +++ b/src/include/optimizer/operator_node.h @@ -56,6 +56,8 @@ enum class OpType { OrderBy, PhysicalLimit, Distinct, + NLJoin, + HashJoin, InnerNLJoin, LeftNLJoin, RightNLJoin, diff --git a/src/include/optimizer/operator_visitor.h b/src/include/optimizer/operator_visitor.h index 1bf56b0c285..f39bfc88d4c 100644 --- a/src/include/optimizer/operator_visitor.h +++ b/src/include/optimizer/operator_visitor.h @@ -32,6 +32,8 @@ class OperatorVisitor { virtual void Visit(const QueryDerivedScan *) {} virtual void Visit(const PhysicalOrderBy *) {} virtual void Visit(const PhysicalLimit *) {} + virtual void Visit(const PhysicalNLJoin *) {} + virtual void Visit(const PhysicalHashJoin *) {} virtual void Visit(const PhysicalInnerNLJoin *) {} virtual void Visit(const PhysicalLeftNLJoin *) {} virtual void Visit(const PhysicalRightNLJoin *) {} diff --git a/src/include/optimizer/operators.h b/src/include/optimizer/operators.h index 2f33ff675e4..d30eb9b196f 100644 --- a/src/include/optimizer/operators.h +++ b/src/include/optimizer/operators.h @@ -167,8 +167,6 @@ class LogicalSingleJoin : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalJoin : public OperatorNode { public: - enum class JoinType { Inner, FullOuter, LeftOuter, RightOuter }; - static Operator make(JoinType _type); static Operator make(JoinType _type, std::vector &conditions); @@ -426,6 +424,50 @@ class PhysicalLimit : public OperatorNode { int64_t limit; }; +//===--------------------------------------------------------------------===// +// NLJoin (Inner + Outer Joins) +//===--------------------------------------------------------------------===// +class PhysicalNLJoin : public OperatorNode { + public: + static Operator make( + JoinType _type, + std::vector conditions, + std::vector> &left_keys, + std::vector> &right_keys); + + bool operator==(const BaseOperatorNode &r) override; + + hash_t Hash() const override; + + std::vector> left_keys; + std::vector> right_keys; + + std::vector join_predicates; + JoinType type; +}; + +//===--------------------------------------------------------------------===// +// HashJoin (Inner + Outer Joins) +//===--------------------------------------------------------------------===// +class PhysicalHashJoin : public OperatorNode { + public: + static Operator make( + JoinType _type, + std::vector conditions, + std::vector> &left_keys, + std::vector> &right_keys); + + bool operator==(const BaseOperatorNode &r) override; + + hash_t Hash() const override; + + std::vector> left_keys; + std::vector> right_keys; + + std::vector join_predicates; + JoinType type; +}; + //===--------------------------------------------------------------------===// // InnerNLJoin //===--------------------------------------------------------------------===// diff --git a/src/include/optimizer/plan_generator.h b/src/include/optimizer/plan_generator.h index c0a21259bc6..20935377c3e 100644 --- a/src/include/optimizer/plan_generator.h +++ b/src/include/optimizer/plan_generator.h @@ -60,6 +60,10 @@ class PlanGenerator : public OperatorVisitor { void Visit(const PhysicalLimit *) override; + void Visit(const PhysicalNLJoin *) override; + + void Visit(const PhysicalHashJoin *) override; + void Visit(const PhysicalInnerNLJoin *) override; void Visit(const PhysicalLeftNLJoin *) override; diff --git a/src/include/optimizer/rule_impls.h b/src/include/optimizer/rule_impls.h index fa6634d55c3..e0e67db73c6 100644 --- a/src/include/optimizer/rule_impls.h +++ b/src/include/optimizer/rule_impls.h @@ -224,6 +224,21 @@ class JoinToNLJoin : public Rule { OptimizeContext *context) const override; }; +/** + * @brief (Logical Join -> Hash Join) + */ +class JoinToHashJoin : public Rule { + public: + JoinToHashJoin(); + + bool Check(std::shared_ptr plan, + OptimizeContext *context) const override; + + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + /** * @brief (Logical Inner Join -> Inner Nested-Loop Join) */ diff --git a/src/optimizer/child_property_deriver.cpp b/src/optimizer/child_property_deriver.cpp index 1df06b3ea50..a72f64883a7 100644 --- a/src/optimizer/child_property_deriver.cpp +++ b/src/optimizer/child_property_deriver.cpp @@ -144,6 +144,12 @@ void ChildPropertyDeriver::Visit(const PhysicalDistinct *) { output_.push_back(make_pair(requirements_, move(child_input_properties))); } void ChildPropertyDeriver::Visit(const PhysicalOrderBy *) {} +void ChildPropertyDeriver::Visit(const PhysicalNLJoin *) { + DeriveForJoin(); +} +void ChildPropertyDeriver::Visit(const PhysicalHashJoin *) { + DeriveForJoin(); +} void ChildPropertyDeriver::Visit(const PhysicalInnerNLJoin *) { DeriveForJoin(); } diff --git a/src/optimizer/cost_calculator.cpp b/src/optimizer/cost_calculator.cpp index 5dda9e67c8a..1b40743473b 100644 --- a/src/optimizer/cost_calculator.cpp +++ b/src/optimizer/cost_calculator.cpp @@ -72,6 +72,22 @@ void CostCalculator::Visit(const PhysicalLimit *op) { output_cost_ = std::min((size_t)child_num_rows, (size_t)op->limit) * DEFAULT_TUPLE_COST; } +void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalNLJoin *op) { + auto left_child_rows = + memo_->GetGroupByID(gexpr_->GetChildGroupId(0))->GetNumRows(); + auto right_child_rows = + memo_->GetGroupByID(gexpr_->GetChildGroupId(1))->GetNumRows(); + + output_cost_ = left_child_rows * right_child_rows * DEFAULT_TUPLE_COST; +} +void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalHashJoin *op) { + auto left_child_rows = + memo_->GetGroupByID(gexpr_->GetChildGroupId(0))->GetNumRows(); + auto right_child_rows = + memo_->GetGroupByID(gexpr_->GetChildGroupId(1))->GetNumRows(); + + output_cost_ = left_child_rows * right_child_rows * DEFAULT_TUPLE_COST; +} void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalInnerNLJoin *op) { auto left_child_rows = memo_->GetGroupByID(gexpr_->GetChildGroupId(0))->GetNumRows(); diff --git a/src/optimizer/input_column_deriver.cpp b/src/optimizer/input_column_deriver.cpp index 7819f81afb9..c321fcc182f 100644 --- a/src/optimizer/input_column_deriver.cpp +++ b/src/optimizer/input_column_deriver.cpp @@ -123,6 +123,14 @@ void InputColumnDeriver::Visit(const PhysicalAggregate *op) { void InputColumnDeriver::Visit(const PhysicalDistinct *) { Passdown(); } +void InputColumnDeriver::Visit(const PhysicalNLJoin *op) { + JoinHelper(op); +} + +void InputColumnDeriver::Visit(const PhysicalHashJoin *op) { + JoinHelper(op); +} + void InputColumnDeriver::Visit(const PhysicalInnerNLJoin *op) { JoinHelper(op); } @@ -246,13 +254,13 @@ void InputColumnDeriver::JoinHelper(const BaseOperatorNode *op) { const vector> *left_keys = nullptr; const vector> *right_keys = nullptr; - if (op->GetType() == OpType::InnerHashJoin) { - auto join_op = reinterpret_cast(op); + if (op->GetType() == OpType::HashJoin) { + auto join_op = reinterpret_cast(op); join_conds = &(join_op->join_predicates); left_keys = &(join_op->left_keys); right_keys = &(join_op->right_keys); - } else if (op->GetType() == OpType::InnerNLJoin) { - auto join_op = reinterpret_cast(op); + } else if (op->GetType() == OpType::NLJoin) { + auto join_op = reinterpret_cast(op); join_conds = &(join_op->join_predicates); left_keys = &(join_op->left_keys); right_keys = &(join_op->right_keys); diff --git a/src/optimizer/operators.cpp b/src/optimizer/operators.cpp index 43b70c563c2..7ff78610cb1 100644 --- a/src/optimizer/operators.cpp +++ b/src/optimizer/operators.cpp @@ -255,31 +255,7 @@ hash_t LogicalJoin::Hash() const { } bool LogicalJoin::operator==(const BaseOperatorNode &r) { - switch (r.GetType()) { - case OpType::InnerJoin: - if (type != JoinType::Inner) { - return false; - } - break; - case OpType::OuterJoin: - if (type != JoinType::FullOuter) { - return false; - } - break; - case OpType::LeftJoin: - if (type != JoinType::LeftOuter) { - return false; - } - break; - case OpType::RightJoin: - if (type != JoinType::RightOuter) { - return false; - } - break; - default: - return false; - break; - } + if (r.GetType() != OpType::LogicalJoin) return false; const LogicalJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size()) return false; @@ -614,6 +590,108 @@ Operator PhysicalLimit::make(int64_t offset, int64_t limit) { return Operator(limit_op); } +//===--------------------------------------------------------------------===// +// NLJoin (Inner + Outer Joins) +//===--------------------------------------------------------------------===// +Operator PhysicalNLJoin::make( + JoinType _type, + std::vector conditions, + std::vector> &left_keys, + std::vector> &right_keys) { + PhysicalNLJoin *join = new PhysicalNLJoin(); + join->join_predicates = std::move(conditions); + join->left_keys = std::move(left_keys); + join->right_keys = std::move(right_keys); + join->type = _type; + + return Operator(join); +} + +hash_t PhysicalNLJoin::Hash() const { + hash_t hash = BaseOperatorNode::Hash(); + for (auto &expr : left_keys) + hash = HashUtil::CombineHashes(hash, expr->Hash()); + for (auto &expr : right_keys) + hash = HashUtil::CombineHashes(hash, expr->Hash()); + for (auto &pred : join_predicates) + hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); + return hash; +} + +bool PhysicalNLJoin::operator==(const BaseOperatorNode &r) { + if (r.GetType() != OpType::NLJoin) return false; + + const PhysicalNLJoin &node = + *static_cast(&r); + if (join_predicates.size() != node.join_predicates.size() || + left_keys.size() != node.left_keys.size() || + right_keys.size() != node.right_keys.size()) + return false; + + for (size_t i = 0; i < left_keys.size(); i++) { + if (!left_keys[i]->ExactlyEquals(*node.left_keys[i].get())) return false; + } + for (size_t i = 0; i < right_keys.size(); i++) { + if (!right_keys[i]->ExactlyEquals(*node.right_keys[i].get())) return false; + } + for (size_t i = 0; i < join_predicates.size(); i++) { + if (!join_predicates[i].expr->ExactlyEquals( + *node.join_predicates[i].expr.get())) + return false; + } + return true; +} + +//===--------------------------------------------------------------------===// +// HashJoin +//===--------------------------------------------------------------------===// +Operator PhysicalHashJoin::make( + JoinType _type, + std::vector conditions, + std::vector> &left_keys, + std::vector> &right_keys) { + PhysicalHashJoin *join = new PhysicalHashJoin(); + join->join_predicates = std::move(conditions); + join->left_keys = std::move(left_keys); + join->right_keys = std::move(right_keys); + join->type = _type; + return Operator(join); +} + +hash_t PhysicalHashJoin::Hash() const { + hash_t hash = BaseOperatorNode::Hash(); + for (auto &expr : left_keys) + hash = HashUtil::CombineHashes(hash, expr->Hash()); + for (auto &expr : right_keys) + hash = HashUtil::CombineHashes(hash, expr->Hash()); + for (auto &pred : join_predicates) + hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); + return hash; +} + +bool PhysicalHashJoin::operator==(const BaseOperatorNode &r) { + if (r.GetType() != OpType::HashJoin) return false; + + const PhysicalHashJoin &node = + *static_cast(&r); + if (join_predicates.size() != node.join_predicates.size() || + left_keys.size() != node.left_keys.size() || + right_keys.size() != node.right_keys.size()) + return false; + for (size_t i = 0; i < left_keys.size(); i++) { + if (!left_keys[i]->ExactlyEquals(*node.left_keys[i].get())) return false; + } + for (size_t i = 0; i < right_keys.size(); i++) { + if (!right_keys[i]->ExactlyEquals(*node.right_keys[i].get())) return false; + } + for (size_t i = 0; i < join_predicates.size(); i++) { + if (!join_predicates[i].expr->ExactlyEquals( + *node.join_predicates[i].expr.get())) + return false; + } + return true; +} + //===--------------------------------------------------------------------===// // InnerNLJoin //===--------------------------------------------------------------------===// @@ -959,6 +1037,10 @@ std::string OperatorNode::name_ = "PhysicalOrderBy"; template <> std::string OperatorNode::name_ = "PhysicalLimit"; template <> +std::string OperatorNode::name_ = "PhysicalNLJoin"; +template <> +std::string OperatorNode::name_ = "PhysicalHashJoin"; +template <> std::string OperatorNode::name_ = "PhysicalInnerNLJoin"; template <> std::string OperatorNode::name_ = "PhysicalLeftNLJoin"; @@ -1054,6 +1136,10 @@ OpType OperatorNode::type_ = OpType::Distinct; template <> OpType OperatorNode::type_ = OpType::PhysicalLimit; template <> +OpType OperatorNode::type_ = OpType::HashJoin; +template <> +OpType OperatorNode::type_ = OpType::NLJoin; +template <> OpType OperatorNode::type_ = OpType::InnerNLJoin; template <> OpType OperatorNode::type_ = OpType::LeftNLJoin; diff --git a/src/optimizer/plan_generator.cpp b/src/optimizer/plan_generator.cpp index 95bfe48db04..e48ce443d30 100644 --- a/src/optimizer/plan_generator.cpp +++ b/src/optimizer/plan_generator.cpp @@ -174,6 +174,88 @@ void PlanGenerator::Visit(const PhysicalDistinct *) { output_plan_ = move(hash_plan); } +void PlanGenerator::Visit(const PhysicalNLJoin *op) { + std::unique_ptr proj_info; + std::shared_ptr proj_schema; + GenerateProjectionForJoin(proj_info, proj_schema); + + auto join_predicate = + expression::ExpressionUtil::JoinAnnotatedExprs(op->join_predicates); + expression::ExpressionUtil::EvaluateExpression(children_expr_map_, + join_predicate.get()); + expression::ExpressionUtil::ConvertToTvExpr(join_predicate.get(), + children_expr_map_); + + vector left_keys; + vector right_keys; + for (auto &expr : op->left_keys) { + PL_ASSERT(children_expr_map_[0].find(expr.get()) != + children_expr_map_[0].end()); + left_keys.push_back(children_expr_map_[0][expr.get()]); + } + for (auto &expr : op->right_keys) { + PL_ASSERT(children_expr_map_[1].find(expr.get()) != + children_expr_map_[1].end()); + right_keys.emplace_back(children_expr_map_[1][expr.get()]); + } + + unique_ptr join_plan = + unique_ptr(new planner::NestedLoopJoinPlan( + op->type, move(join_predicate), move(proj_info), proj_schema, + left_keys, right_keys)); + + join_plan->AddChild(move(children_plans_[0])); + join_plan->AddChild(move(children_plans_[1])); + output_plan_ = move(join_plan); +} + +void PlanGenerator::Visit(const PhysicalHashJoin *op) { + std::unique_ptr proj_info; + std::shared_ptr proj_schema; + GenerateProjectionForJoin(proj_info, proj_schema); + + auto join_predicate = + expression::ExpressionUtil::JoinAnnotatedExprs(op->join_predicates); + expression::ExpressionUtil::EvaluateExpression(children_expr_map_, + join_predicate.get()); + expression::ExpressionUtil::ConvertToTvExpr(join_predicate.get(), + children_expr_map_); + + vector> left_keys; + vector> right_keys; + vector l_child_map{move(children_expr_map_[0])}; + vector r_child_map{move(children_expr_map_[1])}; + for (auto &expr : op->left_keys) { + auto left_key = expr->Copy(); + expression::ExpressionUtil::EvaluateExpression(l_child_map, left_key); + left_keys.emplace_back(left_key); + } + for (auto &expr : op->right_keys) { + auto right_key = expr->Copy(); + expression::ExpressionUtil::EvaluateExpression(r_child_map, right_key); + right_keys.emplace_back(right_key); + } + // Evaluate Expr for hash plan + vector> hash_keys; + for (auto &expr : op->right_keys) { + auto hash_key = expr->Copy(); + expression::ExpressionUtil::EvaluateExpression(r_child_map, hash_key); + hash_keys.emplace_back(hash_key); + } + + unique_ptr hash_plan(new planner::HashPlan(hash_keys)); + hash_plan->AddChild(move(children_plans_[1])); + + auto join_plan = unique_ptr(new planner::HashJoinPlan( + op->type, move(join_predicate), move(proj_info), proj_schema, + left_keys, right_keys, settings::SettingsManager::GetBool( + settings::SettingId::hash_join_bloom_filter))); + + join_plan->AddChild(move(children_plans_[0])); + join_plan->AddChild(move(hash_plan)); + output_plan_ = move(join_plan); +} + void PlanGenerator::Visit(const PhysicalInnerNLJoin *op) { std::unique_ptr proj_info; std::shared_ptr proj_schema; diff --git a/src/optimizer/query_to_operator_transformer.cpp b/src/optimizer/query_to_operator_transformer.cpp index bd8ddda612c..0bd68f61198 100644 --- a/src/optimizer/query_to_operator_transformer.cpp +++ b/src/optimizer/query_to_operator_transformer.cpp @@ -134,12 +134,13 @@ void QueryToOperatorTransformer::Visit(parser::JoinDefinition *node) { case JoinType::INNER: { predicates_ = CollectPredicates(node->condition.get(), predicates_); join_expr = - std::make_shared(LogicalJoin::make(LogicalJoin::JoinType::Inner)); + std::make_shared(LogicalJoin::make(JoinType::INNER)); break; } case JoinType::OUTER: { - join_expr = std::make_shared( - LogicalOuterJoin::make(node->condition->Copy())); + predicates_ = CollectPredicates(node->condition.get(), predicates_); + join_expr = + std::make_shared(LogicalJoin::make(JoinType::OUTER)); break; } case JoinType::LEFT: { diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index 1a36ebc8358..a5598e0d11f 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -41,6 +41,7 @@ RuleSet::RuleSet() { AddImplementationRule(new GetToIndexScan()); AddImplementationRule(new LogicalQueryDerivedGetToPhysical()); AddImplementationRule(new JoinToNLJoin()); + AddImplementationRule(new JoinToHashJoin()); AddImplementationRule(new InnerJoinToInnerNLJoin()); AddImplementationRule(new InnerJoinToInnerHashJoin()); AddImplementationRule(new ImplementDistinct()); diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index bb85519aeed..5d3e7ac4968 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -642,7 +642,7 @@ void JoinToNLJoin::Transform( std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { // first build an expression representing hash join - const LogicalJoin *inner_join = input->Op().As(); + const LogicalJoin *join = input->Op().As(); auto children = input->Children(); PL_ASSERT(children.size() == 2); @@ -655,13 +655,14 @@ void JoinToNLJoin::Transform( std::vector> left_keys; std::vector> right_keys; - util::ExtractEquiJoinKeys(inner_join->join_predicates, left_keys, right_keys, + util::ExtractEquiJoinKeys(join->join_predicates, left_keys, right_keys, left_group_alias, right_group_alias); PL_ASSERT(right_keys.size() == left_keys.size()); - std::shared_ptr result_plan = - std::make_shared(PhysicalInnerNLJoin::make( - inner_join->join_predicates, left_keys, right_keys)); + std::shared_ptr result_plan; + + result_plan = std::make_shared(PhysicalNLJoin::make( + join->type, join->join_predicates, left_keys, right_keys)); // Then push all children into the child list of the new operator result_plan->PushChild(children[0]); @@ -671,6 +672,67 @@ void JoinToNLJoin::Transform( return; } + + /////////////////////////////////////////////////////////////////////////////// +/// InnerJoinToInnerHashJoin +JoinToHashJoin::JoinToHashJoin() { + type_ = RuleType::INNER_JOIN_TO_HASH_JOIN; + + // Make three node types for pattern matching + std::shared_ptr left_child(std::make_shared(OpType::Leaf)); + std::shared_ptr right_child(std::make_shared(OpType::Leaf)); + + // Initialize a pattern for optimizer to match + match_pattern = std::make_shared(OpType::LogicalJoin); + + // Add node - we match join relation R and S as well as the predicate exp + match_pattern->AddChild(left_child); + match_pattern->AddChild(right_child); + + return; +} + +bool JoinToHashJoin::Check(std::shared_ptr plan, + OptimizeContext *context) const { + (void)context; + (void)plan; + return true; +} + +void JoinToHashJoin::Transform( + std::shared_ptr input, + std::vector> &transformed, + UNUSED_ATTRIBUTE OptimizeContext *context) const { + // first build an expression representing hash join + const LogicalJoin *join = input->Op().As(); + + auto children = input->Children(); + PL_ASSERT(children.size() == 2); + auto left_group_id = children[0]->Op().As()->origin_group; + auto right_group_id = children[1]->Op().As()->origin_group; + auto &left_group_alias = + context->metadata->memo.GetGroupByID(left_group_id)->GetTableAliases(); + auto &right_group_alias = + context->metadata->memo.GetGroupByID(right_group_id)->GetTableAliases(); + std::vector> left_keys; + std::vector> right_keys; + + util::ExtractEquiJoinKeys(join->join_predicates, left_keys, right_keys, + left_group_alias, right_group_alias); + + PL_ASSERT(right_keys.size() == left_keys.size()); + if (!left_keys.empty()) { + auto result_plan = + std::make_shared(PhysicalHashJoin::make( + join->type, join->join_predicates, left_keys, right_keys)); + + // Then push all children into the child list of the new operator + result_plan->PushChild(children[0]); + result_plan->PushChild(children[1]); + + transformed.push_back(result_plan); + } +} /////////////////////////////////////////////////////////////////////////////// /// InnerJoinToInnerNLJoin From 9023039049e56379a675af5c1b868dab20acac24 Mon Sep 17 00:00:00 2001 From: Pedro Miguel Reis Bento Paredes Date: Tue, 27 Mar 2018 21:16:28 -0400 Subject: [PATCH 03/26] Support left and right outer joins --- src/optimizer/query_to_operator_transformer.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/optimizer/query_to_operator_transformer.cpp b/src/optimizer/query_to_operator_transformer.cpp index 0bd68f61198..36722a98c98 100644 --- a/src/optimizer/query_to_operator_transformer.cpp +++ b/src/optimizer/query_to_operator_transformer.cpp @@ -144,13 +144,15 @@ void QueryToOperatorTransformer::Visit(parser::JoinDefinition *node) { break; } case JoinType::LEFT: { - join_expr = std::make_shared( - LogicalLeftJoin::make(node->condition->Copy())); + predicates_ = CollectPredicates(node->condition.get(), predicates_); + join_expr = + std::make_shared(LogicalJoin::make(JoinType::LEFT)); break; } case JoinType::RIGHT: { - join_expr = std::make_shared( - LogicalRightJoin::make(node->condition->Copy())); + predicates_ = CollectPredicates(node->condition.get(), predicates_); + join_expr = + std::make_shared(LogicalJoin::make(JoinType::RIGHT)); break; } case JoinType::SEMI: { From 8a7ab6b06452964f13c341f808cdfcef15160766 Mon Sep 17 00:00:00 2001 From: Pedro Miguel Reis Bento Paredes Date: Wed, 28 Mar 2018 14:43:11 -0400 Subject: [PATCH 04/26] Add commutative rule to new join operator --- src/include/common/internal_types.h | 3 +- src/include/optimizer/rule_impls.h | 15 +++++++++ src/optimizer/rule_impls.cpp | 47 +++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 1 deletion(-) diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 132d74aeb75..412837609d0 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1324,7 +1324,8 @@ std::ostream &operator<<(std::ostream &os, const PropertyType &type); enum class RuleType : uint32_t { // Transformation rules (logical -> logical) - INNER_JOIN_COMMUTE = 0, + JOIN_COMMUTE = 0, + INNER_JOIN_COMMUTE, INNER_JOIN_ASSOCIATE, // Don't move this one diff --git a/src/include/optimizer/rule_impls.h b/src/include/optimizer/rule_impls.h index e0e67db73c6..649422b7962 100644 --- a/src/include/optimizer/rule_impls.h +++ b/src/include/optimizer/rule_impls.h @@ -23,6 +23,21 @@ namespace optimizer { // Transformation rules //===--------------------------------------------------------------------===// +/** + * @brief (A join B) -> (B join A) + */ +class JoinCommutativity : public Rule { + public: + JoinCommutativity(); + + bool Check(std::shared_ptr plan, + OptimizeContext *context) const override; + + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + /** * @brief (A join B) -> (B join A) */ diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index 5d3e7ac4968..168d0c92cf5 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -29,6 +29,53 @@ namespace optimizer { // Transformation rules //===--------------------------------------------------------------------===// +/////////////////////////////////////////////////////////////////////////////// +/// JoinCommutativity +JoinCommutativity::JoinCommutativity() { + type_ = RuleType::JOIN_COMMUTE; + + std::shared_ptr left_child(std::make_shared(OpType::Leaf)); + std::shared_ptr right_child(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared(OpType::LogicalJoin); + match_pattern->AddChild(left_child); + match_pattern->AddChild(right_child); +} + +bool JoinCommutativity::Check(std::shared_ptr expr, + OptimizeContext *context) const { + (void)context; + (void)expr; + return true; +} + +void JoinCommutativity::Transform( + std::shared_ptr input, + std::vector> &transformed, + UNUSED_ATTRIBUTE OptimizeContext *context) const { + auto join_op = input->Op().As(); + auto join_predicates = + std::vector(join_op->join_predicates); + + auto join_type = join_op->type; + if (join_type == JoinType::LEFT) { + join_type = JoinType::RIGHT; + } else if (join_type == JoinType::RIGHT) { + join_type = JoinType::LEFT; + } + + auto result_plan = std::make_shared( + LogicalJoin::make(join_type, join_predicates)); + std::vector> children = input->Children(); + PL_ASSERT(children.size() == 2); + LOG_TRACE( + "Reorder left child with op %s and right child with op %s for inner join", + children[0]->Op().GetName().c_str(), children[1]->Op().GetName().c_str()); + result_plan->PushChild(children[1]); + result_plan->PushChild(children[0]); + + transformed.push_back(result_plan); +} + /////////////////////////////////////////////////////////////////////////////// /// InnerJoinCommutativity InnerJoinCommutativity::InnerJoinCommutativity() { From a8be02ebe55d7b353f170487a57d54bce6a1312c Mon Sep 17 00:00:00 2001 From: Pedro Miguel Reis Bento Paredes Date: Thu, 29 Mar 2018 14:06:49 -0400 Subject: [PATCH 05/26] Fix semi and mark join rules to new inner join operator --- src/include/common/internal_types.h | 1 + src/optimizer/rule.cpp | 1 + src/optimizer/rule_impls.cpp | 12 ++++++------ 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 412837609d0..3f889374689 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1343,6 +1343,7 @@ enum class RuleType : uint32_t { AGGREGATE_TO_HASH_AGGREGATE, AGGREGATE_TO_PLAIN_AGGREGATE, JOIN_TO_NL_JOIN, + JOIN_TO_HASH_JOIN, INNER_JOIN_TO_NL_JOIN, INNER_JOIN_TO_HASH_JOIN, IMPLEMENT_DISTINCT, diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index a5598e0d11f..8945a4c94ba 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -28,6 +28,7 @@ int Rule::Promise(GroupExpression *group_expr, OptimizeContext *context) const { } RuleSet::RuleSet() { + AddTransformationRule(new JoinCommutativity()); AddTransformationRule(new InnerJoinCommutativity()); AddTransformationRule(new InnerJoinAssociativity()); AddImplementationRule(new LogicalDeleteToPhysical()); diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index 168d0c92cf5..a0bba6360c8 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -68,7 +68,7 @@ void JoinCommutativity::Transform( std::vector> children = input->Children(); PL_ASSERT(children.size() == 2); LOG_TRACE( - "Reorder left child with op %s and right child with op %s for inner join", + "Reorder left child with op %s and right child with op %s for join", children[0]->Op().GetName().c_str(), children[1]->Op().GetName().c_str()); result_plan->PushChild(children[1]); result_plan->PushChild(children[0]); @@ -720,10 +720,10 @@ void JoinToNLJoin::Transform( return; } - /////////////////////////////////////////////////////////////////////////////// -/// InnerJoinToInnerHashJoin +/////////////////////////////////////////////////////////////////////////////// +/// JoinToInnerHashJoin JoinToHashJoin::JoinToHashJoin() { - type_ = RuleType::INNER_JOIN_TO_HASH_JOIN; + type_ = RuleType::JOIN_TO_HASH_JOIN; // Make three node types for pattern matching std::shared_ptr left_child(std::make_shared(OpType::Leaf)); @@ -1257,7 +1257,7 @@ void MarkJoinToInnerJoin::Transform( PELOTON_ASSERT(mark_join->join_predicates.empty()); std::shared_ptr output = - std::make_shared(LogicalInnerJoin::make()); + std::make_shared(LogicalJoin::make(JoinType::INNER)); output->PushChild(join_children[0]); output->PushChild(join_children[1]); @@ -1308,7 +1308,7 @@ void SingleJoinToInnerJoin::Transform( PELOTON_ASSERT(single_join->join_predicates.empty()); std::shared_ptr output = - std::make_shared(LogicalInnerJoin::make()); + std::make_shared(LogicalJoin::make(JoinType::INNER)); output->PushChild(join_children[0]); output->PushChild(join_children[1]); From d31a39d01a8e0acfe6ae1d305d07329b0b5de733 Mon Sep 17 00:00:00 2001 From: Guoquan Zhao Date: Thu, 5 Apr 2018 17:15:17 -0400 Subject: [PATCH 06/26] fix assertion macro --- src/optimizer/plan_generator.cpp | 4 ++-- src/optimizer/rule_impls.cpp | 10 +++++----- src/optimizer/stats_calculator.cpp | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/optimizer/plan_generator.cpp b/src/optimizer/plan_generator.cpp index e48ce443d30..c0fe4efe1ea 100644 --- a/src/optimizer/plan_generator.cpp +++ b/src/optimizer/plan_generator.cpp @@ -189,12 +189,12 @@ void PlanGenerator::Visit(const PhysicalNLJoin *op) { vector left_keys; vector right_keys; for (auto &expr : op->left_keys) { - PL_ASSERT(children_expr_map_[0].find(expr.get()) != + PELOTON_ASSERT(children_expr_map_[0].find(expr.get()) != children_expr_map_[0].end()); left_keys.push_back(children_expr_map_[0][expr.get()]); } for (auto &expr : op->right_keys) { - PL_ASSERT(children_expr_map_[1].find(expr.get()) != + PELOTON_ASSERT(children_expr_map_[1].find(expr.get()) != children_expr_map_[1].end()); right_keys.emplace_back(children_expr_map_[1][expr.get()]); } diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index a0bba6360c8..09fe8ad8242 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -66,7 +66,7 @@ void JoinCommutativity::Transform( auto result_plan = std::make_shared( LogicalJoin::make(join_type, join_predicates)); std::vector> children = input->Children(); - PL_ASSERT(children.size() == 2); + PELOTON_ASSERT(children.size() == 2); LOG_TRACE( "Reorder left child with op %s and right child with op %s for join", children[0]->Op().GetName().c_str(), children[1]->Op().GetName().c_str()); @@ -692,7 +692,7 @@ void JoinToNLJoin::Transform( const LogicalJoin *join = input->Op().As(); auto children = input->Children(); - PL_ASSERT(children.size() == 2); + PELOTON_ASSERT(children.size() == 2); auto left_group_id = children[0]->Op().As()->origin_group; auto right_group_id = children[1]->Op().As()->origin_group; auto &left_group_alias = @@ -705,7 +705,7 @@ void JoinToNLJoin::Transform( util::ExtractEquiJoinKeys(join->join_predicates, left_keys, right_keys, left_group_alias, right_group_alias); - PL_ASSERT(right_keys.size() == left_keys.size()); + PELOTON_ASSERT(right_keys.size() == left_keys.size()); std::shared_ptr result_plan; result_plan = std::make_shared(PhysicalNLJoin::make( @@ -754,7 +754,7 @@ void JoinToHashJoin::Transform( const LogicalJoin *join = input->Op().As(); auto children = input->Children(); - PL_ASSERT(children.size() == 2); + PELOTON_ASSERT(children.size() == 2); auto left_group_id = children[0]->Op().As()->origin_group; auto right_group_id = children[1]->Op().As()->origin_group; auto &left_group_alias = @@ -767,7 +767,7 @@ void JoinToHashJoin::Transform( util::ExtractEquiJoinKeys(join->join_predicates, left_keys, right_keys, left_group_alias, right_group_alias); - PL_ASSERT(right_keys.size() == left_keys.size()); + PELOTON_ASSERT(right_keys.size() == left_keys.size()); if (!left_keys.empty()) { auto result_plan = std::make_shared(PhysicalHashJoin::make( diff --git a/src/optimizer/stats_calculator.cpp b/src/optimizer/stats_calculator.cpp index 59beb15800f..94a22a30b7d 100644 --- a/src/optimizer/stats_calculator.cpp +++ b/src/optimizer/stats_calculator.cpp @@ -98,7 +98,7 @@ void StatsCalculator::Visit(const LogicalQueryDerivedGet *) { void StatsCalculator::Visit(const LogicalJoin *op) { // Check if there's join condition - PL_ASSERT(gexpr_->GetChildrenGroupsSize() == 2); + PELOTON_ASSERT(gexpr_->GetChildrenGroupsSize() == 2); auto left_child_group = memo_->GetGroupByID(gexpr_->GetChildGroupId(0)); auto right_child_group = memo_->GetGroupByID(gexpr_->GetChildGroupId(1)); auto root_group = memo_->GetGroupByID(gexpr_->GetGroupID()); @@ -135,7 +135,7 @@ void StatsCalculator::Visit(const LogicalJoin *op) { } size_t num_rows = root_group->GetNumRows(); for (auto &col : required_cols_) { - PL_ASSERT(col->GetExpressionType() == ExpressionType::VALUE_TUPLE); + PELOTON_ASSERT(col->GetExpressionType() == ExpressionType::VALUE_TUPLE); auto tv_expr = reinterpret_cast(col); std::shared_ptr column_stats; // Make a copy from the child stats @@ -143,7 +143,7 @@ void StatsCalculator::Visit(const LogicalJoin *op) { column_stats = std::make_shared( *left_child_group->GetStats(tv_expr->GetColFullName())); } else { - PL_ASSERT(right_child_group->HasColumnStats(tv_expr->GetColFullName())); + PELOTON_ASSERT(right_child_group->HasColumnStats(tv_expr->GetColFullName())); column_stats = std::make_shared( *right_child_group->GetStats(tv_expr->GetColFullName())); } From f653ac49c0bac8283e50658b6cc2d9982d7aae4f Mon Sep 17 00:00:00 2001 From: Irene Qiuwen Kai Date: Sat, 7 Apr 2018 00:49:24 -0400 Subject: [PATCH 07/26] Add naive associativity rule for LogicalJoin. --- src/include/common/internal_types.h | 1 + src/include/optimizer/rule_impls.h | 15 ++++ src/optimizer/rule.cpp | 1 + src/optimizer/rule_impls.cpp | 107 ++++++++++++++++++++++++++++ 4 files changed, 124 insertions(+) diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 3f889374689..58aa75b09cc 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1325,6 +1325,7 @@ std::ostream &operator<<(std::ostream &os, const PropertyType &type); enum class RuleType : uint32_t { // Transformation rules (logical -> logical) JOIN_COMMUTE = 0, + JOIN_ASSOCIATE, INNER_JOIN_COMMUTE, INNER_JOIN_ASSOCIATE, diff --git a/src/include/optimizer/rule_impls.h b/src/include/optimizer/rule_impls.h index 649422b7962..a749675ca2c 100644 --- a/src/include/optimizer/rule_impls.h +++ b/src/include/optimizer/rule_impls.h @@ -38,6 +38,21 @@ class JoinCommutativity : public Rule { OptimizeContext *context) const override; }; +/** + * @brief (A join B) join C -> A join (B join C) + */ +class JoinAssociativity : public Rule { + public: + JoinAssociativity(); + + bool Check(std::shared_ptr plan, + OptimizeContext *context) const override; + + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + /** * @brief (A join B) -> (B join A) */ diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index 8945a4c94ba..36d5a7e27b7 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -29,6 +29,7 @@ int Rule::Promise(GroupExpression *group_expr, OptimizeContext *context) const { RuleSet::RuleSet() { AddTransformationRule(new JoinCommutativity()); + AddTransformationRule(new JoinAssociativity()); AddTransformationRule(new InnerJoinCommutativity()); AddTransformationRule(new InnerJoinAssociativity()); AddImplementationRule(new LogicalDeleteToPhysical()); diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index 09fe8ad8242..85612c82cd0 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -76,6 +76,113 @@ void JoinCommutativity::Transform( transformed.push_back(result_plan); } +/////////////////////////////////////////////////////////////////////////////// +/// JoinAssociativity +JoinAssociativity::JoinAssociativity() { + type_ = RuleType::JOIN_ASSOCIATE; + + // Create left nested join + auto left_child = std::make_shared(OpType::LogicalJoin); + left_child->AddChild(std::make_shared(OpType::Leaf)); + left_child->AddChild(std::make_shared(OpType::Leaf)); + + std::shared_ptr right_child(std::make_shared(OpType::Leaf)); + + match_pattern = std::make_shared(OpType::LogicalJoin); + match_pattern->AddChild(left_child); + match_pattern->AddChild(right_child); +} + +bool JoinAssociativity::Check(std::shared_ptr expr, + OptimizeContext *context) const { + (void)context; + auto parent_join = expr->Op().As(); + std::vector> children = expr->Children(); + auto child_join = children[0]->Op().As(); + return (parent_join->type == child_join->type); +} + +void JoinAssociativity::Transform( + std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + // NOTE: Transforms (left JOIN middle) JOIN right -> left JOIN (middle JOIN + // right) Variables are named accordingly to above transformation + auto parent_join = input->Op().As(); + std::vector> children = input->Children(); + PELOTON_ASSERT(children.size() == 2); + PELOTON_ASSERT(children[0]->Op().GetType() == OpType::LogicalJoin); + PELOTON_ASSERT(children[0]->Children().size() == 2); + auto child_join = children[0]->Op().As(); + auto left = children[0]->Children()[0]; + auto middle = children[0]->Children()[1]; + auto right = children[1]; + + LOG_DEBUG("Reordered join structured: (%s JOIN %s) JOIN %s", + left->Op().GetName().c_str(), middle->Op().GetName().c_str(), + right->Op().GetName().c_str()); + + // Get Alias sets + auto &memo = context->metadata->memo; + auto middle_group_id = middle->Op().As()->origin_group; + auto right_group_id = right->Op().As()->origin_group; + + const auto &middle_group_aliases_set = + memo.GetGroupByID(middle_group_id)->GetTableAliases(); + const auto &right_group_aliases_set = + memo.GetGroupByID(right_group_id)->GetTableAliases(); + + // Union Predicates into single alias set for new child join + std::unordered_set right_join_aliases_set; + right_join_aliases_set.insert(middle_group_aliases_set.begin(), + middle_group_aliases_set.end()); + right_join_aliases_set.insert(right_group_aliases_set.begin(), + right_group_aliases_set.end()); + + // Redistribute predicates + auto parent_join_predicates = + std::vector(parent_join->join_predicates); + auto child_join_predicates = + std::vector(child_join->join_predicates); + + std::vector predicates; + predicates.insert(predicates.end(), parent_join_predicates.begin(), + parent_join_predicates.end()); + predicates.insert(predicates.end(), child_join_predicates.begin(), + child_join_predicates.end()); + + std::vector new_child_join_predicates; + std::vector new_parent_join_predicates; + + for (auto predicate : predicates) { + if (util::IsSubset(right_join_aliases_set, predicate.table_alias_set)) { + new_child_join_predicates.emplace_back(predicate); + } else { + new_parent_join_predicates.emplace_back(predicate); + } + } + + JoinType new_parent_join_type; + JoinType new_child_join_type; + new_parent_join_type = parent_join->type; + new_child_join_type = child_join->type; + // Construct new child join operator + std::shared_ptr new_child_join = + std::make_shared( + LogicalJoin::make(new_child_join_type, new_child_join_predicates)); + new_child_join->PushChild(middle); + new_child_join->PushChild(right); + + // Construct new parent join operator + std::shared_ptr new_parent_join = + std::make_shared( + LogicalJoin::make(new_parent_join_type, new_parent_join_predicates)); + new_parent_join->PushChild(left); + new_parent_join->PushChild(new_child_join); + + transformed.push_back(new_parent_join); +} + /////////////////////////////////////////////////////////////////////////////// /// InnerJoinCommutativity InnerJoinCommutativity::InnerJoinCommutativity() { From 8b08e7d73f598c156f70cd380b71ff50b7c4aa12 Mon Sep 17 00:00:00 2001 From: Guoquan Zhao Date: Sat, 7 Apr 2018 18:58:52 -0400 Subject: [PATCH 08/26] create junit test caes (setup) --- script/testing/junit/OptimizerTest.java | 81 ++++++++++++++++++++++++ script/testing/junit/create_tables_1.sql | 8 +++ 2 files changed, 89 insertions(+) create mode 100644 script/testing/junit/OptimizerTest.java create mode 100644 script/testing/junit/create_tables_1.sql diff --git a/script/testing/junit/OptimizerTest.java b/script/testing/junit/OptimizerTest.java new file mode 100644 index 00000000000..cfab0917ccd --- /dev/null +++ b/script/testing/junit/OptimizerTest.java @@ -0,0 +1,81 @@ +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.BufferedReader; +import java.io.FileNotFoundException; +import java.io.FileReader; +import java.io.IOException; +import java.sql.*; + +import static java.sql.Statement.EXECUTE_FAILED; +import static java.sql.Statement.SUCCESS_NO_INFO; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Created by Guoquan Zhao on 4/7/18. + */ +public class OptimizerTest extends PLTestBase { + private static final String[] SQL_DROP_TABLES = + {"DROP TABLE IF EXISTS t1;", + "DROP TABLE IF EXISTS t2;"}; + private Connection conn; + private void initTables1() throws FileNotFoundException, SQLException { + try(BufferedReader reader = new BufferedReader(new FileReader("create_tables_1.sql")); + Statement stmt = conn.createStatement();){ + reader.lines().forEach(s -> { + try { + stmt.addBatch(s); + } catch (SQLException e) { + e.printStackTrace(); + } + }); + int[] results = stmt.executeBatch(); + for (int i = 0; i < results.length; i++) { + assertTrue("batch failed.", (results[i] >= 0 || results[i] == SUCCESS_NO_INFO) && results[i] != EXECUTE_FAILED); + } + ResultSet resultSet = stmt.executeQuery("SELECT COUNT(*) FROM t1;"); + assertEquals(3, resultSet.getInt(0)); + resultSet.close(); + resultSet = stmt.executeQuery("SELECT COUNT(*) FROM t2;"); + assertEquals(3, resultSet.getInt(0)); + resultSet.close(); + } catch (IOException e) { + e.printStackTrace(); + } + + + } + + + @Before + public void Setup() { + try { + conn = makeDefaultConnection(); + conn.setAutoCommit(true); + initTables1(); + } catch (SQLException ex) { + DumpSQLException(ex); + // throw ex; + } catch (FileNotFoundException e) { + e.printStackTrace(); + } + } + + @After + public void Teardown() throws SQLException { + Statement stmt = conn.createStatement(); + for (String s : SQL_DROP_TABLES) { + stmt.execute(s); + } + } + + + @Test + public void testJoin1() throws SQLException { + + + } + +} diff --git a/script/testing/junit/create_tables_1.sql b/script/testing/junit/create_tables_1.sql new file mode 100644 index 00000000000..b5b876889f5 --- /dev/null +++ b/script/testing/junit/create_tables_1.sql @@ -0,0 +1,8 @@ +CREATE TABLE t1(a INT,b INT,c INT); +INSERT INTO t1 VALUES(1,2,3); +INSERT INTO t1 VALUES(2,3,4); +INSERT INTO t1 VALUES(3,4,5); +CREATE TABLE t2(b INT,c INT,d INT); +INSERT INTO t2 VALUES(1,2,3); +INSERT INTO t2 VALUES(2,3,4); +INSERT INTO t2 VALUES(3,4,5); \ No newline at end of file From 4a7d997507834dcd9a399b6e77b9081b5f1f77a3 Mon Sep 17 00:00:00 2001 From: Guoquan Zhao Date: Sun, 8 Apr 2018 00:55:35 -0400 Subject: [PATCH 09/26] add junit test cases --- script/testing/junit/OptimizerTest.java | 126 ++++++++++++++++++++++-- 1 file changed, 118 insertions(+), 8 deletions(-) diff --git a/script/testing/junit/OptimizerTest.java b/script/testing/junit/OptimizerTest.java index cfab0917ccd..5db13511273 100644 --- a/script/testing/junit/OptimizerTest.java +++ b/script/testing/junit/OptimizerTest.java @@ -10,8 +10,7 @@ import static java.sql.Statement.EXECUTE_FAILED; import static java.sql.Statement.SUCCESS_NO_INFO; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * Created by Guoquan Zhao on 4/7/18. @@ -21,9 +20,10 @@ public class OptimizerTest extends PLTestBase { {"DROP TABLE IF EXISTS t1;", "DROP TABLE IF EXISTS t2;"}; private Connection conn; + private void initTables1() throws FileNotFoundException, SQLException { - try(BufferedReader reader = new BufferedReader(new FileReader("create_tables_1.sql")); - Statement stmt = conn.createStatement();){ + try (BufferedReader reader = new BufferedReader(new FileReader("create_tables_1.sql")); + Statement stmt = conn.createStatement();) { reader.lines().forEach(s -> { try { stmt.addBatch(s); @@ -36,10 +36,12 @@ private void initTables1() throws FileNotFoundException, SQLException { assertTrue("batch failed.", (results[i] >= 0 || results[i] == SUCCESS_NO_INFO) && results[i] != EXECUTE_FAILED); } ResultSet resultSet = stmt.executeQuery("SELECT COUNT(*) FROM t1;"); - assertEquals(3, resultSet.getInt(0)); + resultSet.next(); + assertEquals(3, resultSet.getInt(1)); resultSet.close(); resultSet = stmt.executeQuery("SELECT COUNT(*) FROM t2;"); - assertEquals(3, resultSet.getInt(0)); + resultSet.next(); + assertEquals(3, resultSet.getInt(1)); resultSet.close(); } catch (IOException e) { e.printStackTrace(); @@ -57,7 +59,6 @@ public void Setup() { initTables1(); } catch (SQLException ex) { DumpSQLException(ex); - // throw ex; } catch (FileNotFoundException e) { e.printStackTrace(); } @@ -73,9 +74,118 @@ public void Teardown() throws SQLException { @Test - public void testJoin1() throws SQLException { + public void testInnerJoin() throws SQLException { + try ( + Statement stmt = conn.createStatement(); + ResultSet resultSet = stmt.executeQuery("SELECT t1.a FROM t1 INNER JOIN t2 ON (t1.b = t2.b) ORDER BY t1.a;");) { + assertTrue(resultSet.next()); + assertEquals(1, resultSet.getInt(1)); + assertTrue(resultSet.next()); + assertEquals(2, resultSet.getInt(1)); + assertFalse(resultSet.next()); + } catch (Exception e) { + e.printStackTrace(); + fail(); + } + try ( + Statement stmt = conn.createStatement(); + ResultSet resultSet = stmt.executeQuery("SELECT x.a FROM t1 AS x INNER JOIN t2 ON(x.b = t2.b AND x.c = t2.c) ORDER BY x.a;");) { + assertTrue(resultSet.next()); + assertEquals(1, resultSet.getInt(1)); + assertTrue(resultSet.next()); + assertEquals(2, resultSet.getInt(1)); + assertFalse(resultSet.next()); + } catch (Exception e) { + e.printStackTrace(); + fail(); + } } + @Test + public void testLeftOuterJoin() throws SQLException { + try ( + Statement stmt = conn.createStatement(); + ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT JOIN t2 ON t1.a=t2.d;");) { + assertTrue(resultSet.next()); + assertEquals(3, resultSet.getInt(4)); + assertEquals(4, resultSet.getInt(5)); + assertEquals(5, resultSet.getInt(6)); + assertEquals(1, resultSet.getInt(1)); + assertEquals(2, resultSet.getInt(2)); + assertEquals(3, resultSet.getInt(3)); + assertTrue(resultSet.next()); + assertEquals(null, resultSet.getObject(1)); + assertEquals(null, resultSet.getObject(2)); + assertEquals(null, resultSet.getObject(3)); + assertTrue(resultSet.next()); + assertEquals(null, resultSet.getObject(1)); + assertEquals(null, resultSet.getObject(2)); + assertEquals(null, resultSet.getObject(3)); + assertFalse(resultSet.next()); + } catch (Exception e) { + e.printStackTrace(); + fail(); + } + try ( + Statement stmt = conn.createStatement(); + ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT JOIN t2 ON t1.a=t2.d WHERE t1.a>1")) { + assertTrue(resultSet.next()); + assertEquals(3, resultSet.getInt(4)); + assertEquals(4, resultSet.getInt(5)); + assertEquals(5, resultSet.getInt(6)); + assertEquals(1, resultSet.getInt(1)); + assertEquals(2, resultSet.getInt(2)); + assertEquals(3, resultSet.getInt(3)); + assertTrue(resultSet.next()); + assertEquals(null, resultSet.getObject(1)); + assertEquals(null, resultSet.getObject(2)); + assertEquals(null, resultSet.getObject(3)); + assertFalse(resultSet.next()); + } catch (Exception e) { + e.printStackTrace(); + fail(); + } + try ( + Statement stmt = conn.createStatement(); + ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT OUTER JOIN t2 ON t1.a=t2.d WHERE t1.a>1")) { + assertTrue(resultSet.next()); + assertEquals(3, resultSet.getInt(4)); + assertEquals(4, resultSet.getInt(5)); + assertEquals(5, resultSet.getInt(6)); + assertEquals(1, resultSet.getInt(1)); + assertEquals(2, resultSet.getInt(2)); + assertEquals(3, resultSet.getInt(3)); + assertTrue(resultSet.next()); + assertEquals(null, resultSet.getObject(1)); + assertEquals(null, resultSet.getObject(2)); + assertEquals(null, resultSet.getObject(3)); + assertFalse(resultSet.next()); + } catch (Exception e) { + e.printStackTrace(); + fail(); + } + + } + + @Test + public void testLeftOuterJoinWhere() { + try ( + Statement stmt = conn.createStatement(); + ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT JOIN t2 ON t1.a=t2.d WHERE t2.b IS NULL OR t2.b>1")) { + // expected result is + // t1 t2 + // 1 2 3 {} {} {} + // 2 3 4 {} {} {} + assertTrue(resultSet.next()); + assertTrue(resultSet.next()); + assertFalse(resultSet.next()); + } catch (Exception e) { + e.printStackTrace(); + fail(); + } + } + + } From 46cb4471d143e3e148ad6259cceda0e4f5258ee1 Mon Sep 17 00:00:00 2001 From: Guoquan Zhao Date: Tue, 10 Apr 2018 02:00:03 -0400 Subject: [PATCH 10/26] * remove Logical operator related to "Inner" and "Outer" since we use one operator with a JoinType to represent them all. * Clang-format. * Update the optimizer_rule_test and optimizer_test to use new operators. --- src/include/common/internal_types.h | 2 - src/include/optimizer/child_stats_deriver.h | 11 +- src/include/optimizer/operator_node.h | 4 - src/include/optimizer/operator_visitor.h | 4 - src/include/optimizer/operators.h | 89 ++---- src/include/optimizer/rule_impls.h | 64 +--- src/include/optimizer/stats_calculator.h | 14 +- src/optimizer/child_stats_deriver.cpp | 15 +- src/optimizer/operators.cpp | 105 +------ .../query_to_operator_transformer.cpp | 20 +- src/optimizer/rule.cpp | 4 - src/optimizer/rule_impls.cpp | 297 +----------------- src/optimizer/stats_calculator.cpp | 94 +----- test/optimizer/optimizer_rule_test.cpp | 36 +-- test/optimizer/optimizer_test.cpp | 9 +- 15 files changed, 111 insertions(+), 657 deletions(-) diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 58aa75b09cc..dee1ffd7ac3 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1326,8 +1326,6 @@ enum class RuleType : uint32_t { // Transformation rules (logical -> logical) JOIN_COMMUTE = 0, JOIN_ASSOCIATE, - INNER_JOIN_COMMUTE, - INNER_JOIN_ASSOCIATE, // Don't move this one LogicalPhysicalDelimiter, diff --git a/src/include/optimizer/child_stats_deriver.h b/src/include/optimizer/child_stats_deriver.h index c3513faa832..cca76ba7071 100644 --- a/src/include/optimizer/child_stats_deriver.h +++ b/src/include/optimizer/child_stats_deriver.h @@ -27,22 +27,17 @@ class Memo; // expression class ChildStatsDeriver : public OperatorVisitor { public: - std::vector DeriveInputStats( - GroupExpression *gexpr, - ExprSet required_cols, Memo *memo); + std::vector DeriveInputStats(GroupExpression *gexpr, + ExprSet required_cols, Memo *memo); void Visit(const LogicalQueryDerivedGet *) override; void Visit(const LogicalJoin *) override; - void Visit(const LogicalInnerJoin *) override; - void Visit(const LogicalLeftJoin *) override; - void Visit(const LogicalRightJoin *) override; - void Visit(const LogicalOuterJoin *) override; void Visit(const LogicalSemiJoin *) override; void Visit(const LogicalAggregateAndGroupBy *) override; private: void PassDownRequiredCols(); - void PassDownColumn(expression::AbstractExpression* col); + void PassDownColumn(expression::AbstractExpression *col); ExprSet required_cols_; GroupExpression *gexpr_; Memo *memo_; diff --git a/src/include/optimizer/operator_node.h b/src/include/optimizer/operator_node.h index a9ff68dca23..3692d629f44 100644 --- a/src/include/optimizer/operator_node.h +++ b/src/include/optimizer/operator_node.h @@ -34,10 +34,6 @@ enum class OpType { LogicalDependentJoin, LogicalSingleJoin, LogicalJoin, - InnerJoin, - LeftJoin, - RightJoin, - OuterJoin, SemiJoin, LogicalAggregateAndGroupBy, LogicalInsert, diff --git a/src/include/optimizer/operator_visitor.h b/src/include/optimizer/operator_visitor.h index f39bfc88d4c..1644e26c1ff 100644 --- a/src/include/optimizer/operator_visitor.h +++ b/src/include/optimizer/operator_visitor.h @@ -61,10 +61,6 @@ class OperatorVisitor { virtual void Visit(const LogicalSingleJoin *) {} virtual void Visit(const LogicalDependentJoin *) {} virtual void Visit(const LogicalJoin *) {} - virtual void Visit(const LogicalInnerJoin *) {} - virtual void Visit(const LogicalLeftJoin *) {} - virtual void Visit(const LogicalRightJoin *) {} - virtual void Visit(const LogicalOuterJoin *) {} virtual void Visit(const LogicalSemiJoin *) {} virtual void Visit(const LogicalAggregateAndGroupBy *) {} virtual void Visit(const LogicalInsert *) {} diff --git a/src/include/optimizer/operators.h b/src/include/optimizer/operators.h index d30eb9b196f..0482afba5aa 100644 --- a/src/include/optimizer/operators.h +++ b/src/include/optimizer/operators.h @@ -31,7 +31,7 @@ class UpdateClause; } namespace catalog { - class TableCatalogObject; +class TableCatalogObject; } namespace optimizer { @@ -51,10 +51,10 @@ class LeafOperator : OperatorNode { //===--------------------------------------------------------------------===// class LogicalGet : public OperatorNode { public: - static Operator make(oid_t get_id = 0, - std::vector predicates = {}, - std::shared_ptr table = nullptr, - std::string alias = "", bool update = false); + static Operator make( + oid_t get_id = 0, std::vector predicates = {}, + std::shared_ptr table = nullptr, + std::string alias = "", bool update = false); bool operator==(const BaseOperatorNode &r) override; @@ -169,7 +169,8 @@ class LogicalJoin : public OperatorNode { public: static Operator make(JoinType _type); - static Operator make(JoinType _type, std::vector &conditions); + static Operator make(JoinType _type, + std::vector &conditions); bool operator==(const BaseOperatorNode &r) override; @@ -179,52 +180,6 @@ class LogicalJoin : public OperatorNode { JoinType type; }; -//===--------------------------------------------------------------------===// -// InnerJoin -//===--------------------------------------------------------------------===// -class LogicalInnerJoin : public OperatorNode { - public: - static Operator make(); - - static Operator make(std::vector &conditions); - - bool operator==(const BaseOperatorNode &r) override; - - hash_t Hash() const override; - - std::vector join_predicates; -}; - -//===--------------------------------------------------------------------===// -// LeftJoin -//===--------------------------------------------------------------------===// -class LogicalLeftJoin : public OperatorNode { - public: - static Operator make(expression::AbstractExpression *condition = nullptr); - - std::shared_ptr join_predicate; -}; - -//===--------------------------------------------------------------------===// -// RightJoin -//===--------------------------------------------------------------------===// -class LogicalRightJoin : public OperatorNode { - public: - static Operator make(expression::AbstractExpression *condition = nullptr); - - std::shared_ptr join_predicate; -}; - -//===--------------------------------------------------------------------===// -// OuterJoin -//===--------------------------------------------------------------------===// -class LogicalOuterJoin : public OperatorNode { - public: - static Operator make(expression::AbstractExpression *condition = nullptr); - - std::shared_ptr join_predicate; -}; - //===--------------------------------------------------------------------===// // SemiJoin //===--------------------------------------------------------------------===// @@ -263,7 +218,8 @@ class LogicalAggregateAndGroupBy class LogicalInsert : public OperatorNode { public: static Operator make( - std::shared_ptr target_table, const std::vector *columns, + std::shared_ptr target_table, + const std::vector *columns, const std::vector>> *values); @@ -275,7 +231,8 @@ class LogicalInsert : public OperatorNode { class LogicalInsertSelect : public OperatorNode { public: - static Operator make(std::shared_ptr target_table); + static Operator make( + std::shared_ptr target_table); std::shared_ptr target_table; }; @@ -303,7 +260,8 @@ class LogicalLimit : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalDelete : public OperatorNode { public: - static Operator make(std::shared_ptr target_table); + static Operator make( + std::shared_ptr target_table); std::shared_ptr target_table; }; @@ -334,7 +292,8 @@ class DummyScan : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalSeqScan : public OperatorNode { public: - static Operator make(oid_t get_id, std::shared_ptr table, + static Operator make(oid_t get_id, + std::shared_ptr table, std::string alias, std::vector predicates, bool update); @@ -356,7 +315,8 @@ class PhysicalSeqScan : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalIndexScan : public OperatorNode { public: - static Operator make(oid_t get_id, std::shared_ptr table, + static Operator make(oid_t get_id, + std::shared_ptr table, std::string alias, std::vector predicates, bool update, oid_t index_id, std::vector key_column_id_list, @@ -430,8 +390,7 @@ class PhysicalLimit : public OperatorNode { class PhysicalNLJoin : public OperatorNode { public: static Operator make( - JoinType _type, - std::vector conditions, + JoinType _type, std::vector conditions, std::vector> &left_keys, std::vector> &right_keys); @@ -452,8 +411,7 @@ class PhysicalNLJoin : public OperatorNode { class PhysicalHashJoin : public OperatorNode { public: static Operator make( - JoinType _type, - std::vector conditions, + JoinType _type, std::vector conditions, std::vector> &left_keys, std::vector> &right_keys); @@ -574,7 +532,8 @@ class PhysicalOuterHashJoin : public OperatorNode { class PhysicalInsert : public OperatorNode { public: static Operator make( - std::shared_ptr target_table, const std::vector *columns, + std::shared_ptr target_table, + const std::vector *columns, const std::vector>> *values); @@ -586,7 +545,8 @@ class PhysicalInsert : public OperatorNode { class PhysicalInsertSelect : public OperatorNode { public: - static Operator make(std::shared_ptr target_table); + static Operator make( + std::shared_ptr target_table); std::shared_ptr target_table; }; @@ -596,7 +556,8 @@ class PhysicalInsertSelect : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalDelete : public OperatorNode { public: - static Operator make(std::shared_ptr target_table); + static Operator make( + std::shared_ptr target_table); std::shared_ptr target_table; }; diff --git a/src/include/optimizer/rule_impls.h b/src/include/optimizer/rule_impls.h index a749675ca2c..b0faa1aabdb 100644 --- a/src/include/optimizer/rule_impls.h +++ b/src/include/optimizer/rule_impls.h @@ -53,37 +53,6 @@ class JoinAssociativity : public Rule { OptimizeContext *context) const override; }; -/** - * @brief (A join B) -> (B join A) - */ -class InnerJoinCommutativity : public Rule { - public: - InnerJoinCommutativity(); - - bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; - - void Transform(std::shared_ptr input, - std::vector> &transformed, - OptimizeContext *context) const override; -}; - -/** - * @brief (A join B) join C -> A join (B join C) - */ - -class InnerJoinAssociativity : public Rule { - public: - InnerJoinAssociativity(); - - bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; - - void Transform(std::shared_ptr input, - std::vector> &transformed, - OptimizeContext *context) const override; -}; - //===--------------------------------------------------------------------===// // Implementation rules //===--------------------------------------------------------------------===// @@ -269,36 +238,6 @@ class JoinToHashJoin : public Rule { OptimizeContext *context) const override; }; -/** - * @brief (Logical Inner Join -> Inner Nested-Loop Join) - */ -class InnerJoinToInnerNLJoin : public Rule { - public: - InnerJoinToInnerNLJoin(); - - bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; - - void Transform(std::shared_ptr input, - std::vector> &transformed, - OptimizeContext *context) const override; -}; - -/** - * @brief (Logical Inner Join -> Inner Hash Join) - */ -class InnerJoinToInnerHashJoin : public Rule { - public: - InnerJoinToInnerHashJoin(); - - bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; - - void Transform(std::shared_ptr input, - std::vector> &transformed, - OptimizeContext *context) const override; -}; - /** * @brief (Logical Distinct -> Physical Distinct) */ @@ -401,7 +340,8 @@ class EmbedFilterIntoGet : public Rule { /////////////////////////////////////////////////////////////////////////////// /// Unnesting rules // We use this promise to determine which rules should be applied first if -// multiple rules are applicable, we need to first pull filters up through mark-join +// multiple rules are applicable, we need to first pull filters up through +// mark-join // then turn mark-join into a regular join operator enum class UnnestPromise { Low = 1, High }; // TODO(boweic): MarkJoin and SingleJoin should not be transformed into inner diff --git a/src/include/optimizer/stats_calculator.h b/src/include/optimizer/stats_calculator.h index 3e9d43c7eb0..224e0faa909 100644 --- a/src/include/optimizer/stats_calculator.h +++ b/src/include/optimizer/stats_calculator.h @@ -26,16 +26,12 @@ class TableStats; */ class StatsCalculator : public OperatorVisitor { public: - void CalculateStats(GroupExpression *gexpr, ExprSet required_cols, - Memo *memo, concurrency::TransactionContext* txn); + void CalculateStats(GroupExpression *gexpr, ExprSet required_cols, Memo *memo, + concurrency::TransactionContext *txn); void Visit(const LogicalGet *) override; void Visit(const LogicalQueryDerivedGet *) override; void Visit(const LogicalJoin *) override; - void Visit(const LogicalInnerJoin *) override; - void Visit(const LogicalLeftJoin *) override; - void Visit(const LogicalRightJoin *) override; - void Visit(const LogicalOuterJoin *) override; void Visit(const LogicalSemiJoin *) override; void Visit(const LogicalAggregateAndGroupBy *) override; void Visit(const LogicalLimit *) override; @@ -65,8 +61,8 @@ class StatsCalculator : public OperatorVisitor { */ void UpdateStatsForFilter( size_t num_rows, - std::unordered_map> - &predicate_stats, + std::unordered_map> & + predicate_stats, const std::vector &predicates); double CalculateSelectivityForPredicate( @@ -76,7 +72,7 @@ class StatsCalculator : public OperatorVisitor { GroupExpression *gexpr_; ExprSet required_cols_; Memo *memo_; - concurrency::TransactionContext* txn_; + concurrency::TransactionContext *txn_; }; } // namespace optimizer diff --git a/src/optimizer/child_stats_deriver.cpp b/src/optimizer/child_stats_deriver.cpp index d3730d4e870..5aa581b0ce6 100644 --- a/src/optimizer/child_stats_deriver.cpp +++ b/src/optimizer/child_stats_deriver.cpp @@ -44,20 +44,7 @@ void ChildStatsDeriver::Visit(const LogicalJoin *op) { } } } -void ChildStatsDeriver::Visit(const LogicalInnerJoin *op) { - PassDownRequiredCols(); - for (auto &annotated_expr : op->join_predicates) { - auto predicate = annotated_expr.expr.get(); - ExprSet expr_set; - expression::ExpressionUtil::GetTupleValueExprs(expr_set, predicate); - for (auto &col : expr_set) { - PassDownColumn(col); - } - } -} -void ChildStatsDeriver::Visit(UNUSED_ATTRIBUTE const LogicalLeftJoin *) {} -void ChildStatsDeriver::Visit(UNUSED_ATTRIBUTE const LogicalRightJoin *) {} -void ChildStatsDeriver::Visit(UNUSED_ATTRIBUTE const LogicalOuterJoin *) {} + void ChildStatsDeriver::Visit(const LogicalSemiJoin *) {} // TODO(boweic): support stats of aggregation void ChildStatsDeriver::Visit(const LogicalAggregateAndGroupBy *) { diff --git a/src/optimizer/operators.cpp b/src/optimizer/operators.cpp index 7ff78610cb1..4ad8414f872 100644 --- a/src/optimizer/operators.cpp +++ b/src/optimizer/operators.cpp @@ -51,7 +51,7 @@ hash_t LogicalGet::Hash() const { } bool LogicalGet::operator==(const BaseOperatorNode &r) { - if (r.GetType()!= OpType::Get) return false; + if (r.GetType() != OpType::Get) return false; const LogicalGet &node = *static_cast(&r); if (predicates.size() != node.predicates.size()) return false; for (size_t i = 0; i < predicates.size(); i++) { @@ -240,7 +240,8 @@ Operator LogicalJoin::make(JoinType _type) { return Operator(join); } -Operator LogicalJoin::make(JoinType _type, std::vector &conditions) { +Operator LogicalJoin::make(JoinType _type, + std::vector &conditions) { LogicalJoin *join = new LogicalJoin; join->join_predicates = std::move(conditions); join->type = _type; @@ -266,70 +267,6 @@ bool LogicalJoin::operator==(const BaseOperatorNode &r) { } return true; } - -//===--------------------------------------------------------------------===// -// InnerJoin -//===--------------------------------------------------------------------===// -Operator LogicalInnerJoin::make() { - LogicalInnerJoin *join = new LogicalInnerJoin; - join->join_predicates = {}; - return Operator(join); -} - -Operator LogicalInnerJoin::make(std::vector &conditions) { - LogicalInnerJoin *join = new LogicalInnerJoin; - join->join_predicates = std::move(conditions); - return Operator(join); -} - -hash_t LogicalInnerJoin::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); - for (auto &pred : join_predicates) - hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); - return hash; -} - -bool LogicalInnerJoin::operator==(const BaseOperatorNode &r) { - if (r.GetType() != OpType::InnerJoin) return false; - const LogicalInnerJoin &node = *static_cast(&r); - if (join_predicates.size() != node.join_predicates.size()) return false; - for (size_t i = 0; i < join_predicates.size(); i++) { - if (!join_predicates[i].expr->ExactlyEquals( - *node.join_predicates[i].expr.get())) - return false; - } - return true; -} - -//===--------------------------------------------------------------------===// -// LeftJoin -//===--------------------------------------------------------------------===// -Operator LogicalLeftJoin::make(expression::AbstractExpression *condition) { - LogicalLeftJoin *join = new LogicalLeftJoin; - join->join_predicate = - std::shared_ptr(condition); - return Operator(join); -} - -//===--------------------------------------------------------------------===// -// RightJoin -//===--------------------------------------------------------------------===// -Operator LogicalRightJoin::make(expression::AbstractExpression *condition) { - LogicalRightJoin *join = new LogicalRightJoin; - join->join_predicate = - std::shared_ptr(condition); - return Operator(join); -} - -//===--------------------------------------------------------------------===// -// OuterJoin -//===--------------------------------------------------------------------===// -Operator LogicalOuterJoin::make(expression::AbstractExpression *condition) { - LogicalOuterJoin *join = new LogicalOuterJoin; - join->join_predicate = - std::shared_ptr(condition); - return Operator(join); -} //===--------------------------------------------------------------------===// // SemiJoin @@ -422,8 +359,8 @@ Operator LogicalDelete::make( //===--------------------------------------------------------------------===// Operator LogicalUpdate::make( std::shared_ptr target_table, - const std::vector> - *updates) { + const std::vector> * + updates) { LogicalUpdate *update_op = new LogicalUpdate; update_op->target_table = target_table; update_op->updates = updates; @@ -594,8 +531,7 @@ Operator PhysicalLimit::make(int64_t offset, int64_t limit) { // NLJoin (Inner + Outer Joins) //===--------------------------------------------------------------------===// Operator PhysicalNLJoin::make( - JoinType _type, - std::vector conditions, + JoinType _type, std::vector conditions, std::vector> &left_keys, std::vector> &right_keys) { PhysicalNLJoin *join = new PhysicalNLJoin(); @@ -621,8 +557,7 @@ hash_t PhysicalNLJoin::Hash() const { bool PhysicalNLJoin::operator==(const BaseOperatorNode &r) { if (r.GetType() != OpType::NLJoin) return false; - const PhysicalNLJoin &node = - *static_cast(&r); + const PhysicalNLJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size() || left_keys.size() != node.left_keys.size() || right_keys.size() != node.right_keys.size()) @@ -646,8 +581,7 @@ bool PhysicalNLJoin::operator==(const BaseOperatorNode &r) { // HashJoin //===--------------------------------------------------------------------===// Operator PhysicalHashJoin::make( - JoinType _type, - std::vector conditions, + JoinType _type, std::vector conditions, std::vector> &left_keys, std::vector> &right_keys) { PhysicalHashJoin *join = new PhysicalHashJoin(); @@ -672,8 +606,7 @@ hash_t PhysicalHashJoin::Hash() const { bool PhysicalHashJoin::operator==(const BaseOperatorNode &r) { if (r.GetType() != OpType::HashJoin) return false; - const PhysicalHashJoin &node = - *static_cast(&r); + const PhysicalHashJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size() || left_keys.size() != node.left_keys.size() || right_keys.size() != node.right_keys.size()) @@ -887,8 +820,8 @@ Operator PhysicalDelete::make( //===--------------------------------------------------------------------===// Operator PhysicalUpdate::make( std::shared_ptr target_table, - const std::vector> - *updates) { + const std::vector> * + updates) { PhysicalUpdate *update = new PhysicalUpdate; update->target_table = target_table; update->updates = updates; @@ -1000,14 +933,6 @@ std::string OperatorNode::name_ = "LogicalDependentJoin"; template <> std::string OperatorNode::name_ = "LogicalJoin"; template <> -std::string OperatorNode::name_ = "LogicalInnerJoin"; -template <> -std::string OperatorNode::name_ = "LogicalLeftJoin"; -template <> -std::string OperatorNode::name_ = "LogicalRightJoin"; -template <> -std::string OperatorNode::name_ = "LogicalOuterJoin"; -template <> std::string OperatorNode::name_ = "LogicalSemiJoin"; template <> std::string OperatorNode::name_ = @@ -1097,14 +1022,6 @@ OpType OperatorNode::type_ = OpType::LogicalDependentJoin; template <> OpType OperatorNode::type_ = OpType::LogicalJoin; template <> -OpType OperatorNode::type_ = OpType::InnerJoin; -template <> -OpType OperatorNode::type_ = OpType::LeftJoin; -template <> -OpType OperatorNode::type_ = OpType::RightJoin; -template <> -OpType OperatorNode::type_ = OpType::OuterJoin; -template <> OpType OperatorNode::type_ = OpType::SemiJoin; template <> OpType OperatorNode::type_ = diff --git a/src/optimizer/query_to_operator_transformer.cpp b/src/optimizer/query_to_operator_transformer.cpp index 36722a98c98..a86c52265a7 100644 --- a/src/optimizer/query_to_operator_transformer.cpp +++ b/src/optimizer/query_to_operator_transformer.cpp @@ -133,26 +133,26 @@ void QueryToOperatorTransformer::Visit(parser::JoinDefinition *node) { switch (node->type) { case JoinType::INNER: { predicates_ = CollectPredicates(node->condition.get(), predicates_); - join_expr = - std::make_shared(LogicalJoin::make(JoinType::INNER)); + join_expr = std::make_shared( + LogicalJoin::make(JoinType::INNER)); break; } case JoinType::OUTER: { predicates_ = CollectPredicates(node->condition.get(), predicates_); - join_expr = - std::make_shared(LogicalJoin::make(JoinType::OUTER)); + join_expr = std::make_shared( + LogicalJoin::make(JoinType::OUTER)); break; } case JoinType::LEFT: { predicates_ = CollectPredicates(node->condition.get(), predicates_); - join_expr = - std::make_shared(LogicalJoin::make(JoinType::LEFT)); + join_expr = std::make_shared( + LogicalJoin::make(JoinType::LEFT)); break; } case JoinType::RIGHT: { predicates_ = CollectPredicates(node->condition.get(), predicates_); - join_expr = - std::make_shared(LogicalJoin::make(JoinType::RIGHT)); + join_expr = std::make_shared( + LogicalJoin::make(JoinType::RIGHT)); break; } case JoinType::SEMI: { @@ -204,8 +204,8 @@ void QueryToOperatorTransformer::Visit(parser::TableRef *node) { // Build a left deep join tree for (size_t i = 1; i < node->list.size(); i++) { node->list.at(i)->Accept(this); - auto join_expr = - std::make_shared(LogicalInnerJoin::make()); + auto join_expr = std::make_shared( + LogicalJoin::make(JoinType::INNER)); join_expr->PushChild(prev_expr); join_expr->PushChild(output_expr_); PELOTON_ASSERT(join_expr->Children().size() == 2); diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index 36d5a7e27b7..cca9bd0497f 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -30,8 +30,6 @@ int Rule::Promise(GroupExpression *group_expr, OptimizeContext *context) const { RuleSet::RuleSet() { AddTransformationRule(new JoinCommutativity()); AddTransformationRule(new JoinAssociativity()); - AddTransformationRule(new InnerJoinCommutativity()); - AddTransformationRule(new InnerJoinAssociativity()); AddImplementationRule(new LogicalDeleteToPhysical()); AddImplementationRule(new LogicalUpdateToPhysical()); AddImplementationRule(new LogicalInsertToPhysical()); @@ -44,8 +42,6 @@ RuleSet::RuleSet() { AddImplementationRule(new LogicalQueryDerivedGetToPhysical()); AddImplementationRule(new JoinToNLJoin()); AddImplementationRule(new JoinToHashJoin()); - AddImplementationRule(new InnerJoinToInnerNLJoin()); - AddImplementationRule(new InnerJoinToInnerHashJoin()); AddImplementationRule(new ImplementDistinct()); AddImplementationRule(new ImplementLimit()); diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index 85612c82cd0..b73af2bd7eb 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -42,7 +42,7 @@ JoinCommutativity::JoinCommutativity() { } bool JoinCommutativity::Check(std::shared_ptr expr, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)expr; return true; @@ -67,9 +67,9 @@ void JoinCommutativity::Transform( LogicalJoin::make(join_type, join_predicates)); std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 2); - LOG_TRACE( - "Reorder left child with op %s and right child with op %s for join", - children[0]->Op().GetName().c_str(), children[1]->Op().GetName().c_str()); + LOG_TRACE("Reorder left child with op %s and right child with op %s for join", + children[0]->Op().GetName().c_str(), + children[1]->Op().GetName().c_str()); result_plan->PushChild(children[1]); result_plan->PushChild(children[0]); @@ -183,147 +183,6 @@ void JoinAssociativity::Transform( transformed.push_back(new_parent_join); } -/////////////////////////////////////////////////////////////////////////////// -/// InnerJoinCommutativity -InnerJoinCommutativity::InnerJoinCommutativity() { - type_ = RuleType::INNER_JOIN_COMMUTE; - - std::shared_ptr left_child(std::make_shared(OpType::Leaf)); - std::shared_ptr right_child(std::make_shared(OpType::Leaf)); - match_pattern = std::make_shared(OpType::InnerJoin); - match_pattern->AddChild(left_child); - match_pattern->AddChild(right_child); -} - -bool InnerJoinCommutativity::Check(std::shared_ptr expr, - OptimizeContext *context) const { - (void)context; - (void)expr; - return true; -} - -void InnerJoinCommutativity::Transform( - std::shared_ptr input, - std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { - auto join_op = input->Op().As(); - auto join_predicates = - std::vector(join_op->join_predicates); - auto result_plan = std::make_shared( - LogicalInnerJoin::make(join_predicates)); - std::vector> children = input->Children(); - PELOTON_ASSERT(children.size() == 2); - LOG_TRACE( - "Reorder left child with op %s and right child with op %s for inner join", - children[0]->Op().GetName().c_str(), children[1]->Op().GetName().c_str()); - result_plan->PushChild(children[1]); - result_plan->PushChild(children[0]); - - transformed.push_back(result_plan); -} - -/////////////////////////////////////////////////////////////////////////////// -/// InnerJoinAssociativity -InnerJoinAssociativity::InnerJoinAssociativity() { - type_ = RuleType::INNER_JOIN_ASSOCIATE; - - // Create left nested join - auto left_child = std::make_shared(OpType::InnerJoin); - left_child->AddChild(std::make_shared(OpType::Leaf)); - left_child->AddChild(std::make_shared(OpType::Leaf)); - - std::shared_ptr right_child(std::make_shared(OpType::Leaf)); - - match_pattern = std::make_shared(OpType::InnerJoin); - match_pattern->AddChild(left_child); - match_pattern->AddChild(right_child); -} - -// TODO: As far as I know, theres nothing else that needs to be checked -bool InnerJoinAssociativity::Check(std::shared_ptr expr, - OptimizeContext *context) const { - (void)context; - (void)expr; - return true; -} - -void InnerJoinAssociativity::Transform( - std::shared_ptr input, - std::vector> &transformed, - OptimizeContext *context) const { - // NOTE: Transforms (left JOIN middle) JOIN right -> left JOIN (middle JOIN - // right) Variables are named accordingly to above transformation - auto parent_join = input->Op().As(); - std::vector> children = input->Children(); - PELOTON_ASSERT(children.size() == 2); - PELOTON_ASSERT(children[0]->Op().GetType() == OpType::InnerJoin); - PELOTON_ASSERT(children[0]->Children().size() == 2); - auto child_join = children[0]->Op().As(); - auto left = children[0]->Children()[0]; - auto middle = children[0]->Children()[1]; - auto right = children[1]; - - LOG_DEBUG("Reordered join structured: (%s JOIN %s) JOIN %s", - left->Op().GetName().c_str(), middle->Op().GetName().c_str(), - right->Op().GetName().c_str()); - - // Get Alias sets - auto &memo = context->metadata->memo; - auto middle_group_id = middle->Op().As()->origin_group; - auto right_group_id = right->Op().As()->origin_group; - - const auto &middle_group_aliases_set = - memo.GetGroupByID(middle_group_id)->GetTableAliases(); - const auto &right_group_aliases_set = - memo.GetGroupByID(right_group_id)->GetTableAliases(); - - // Union Predicates into single alias set for new child join - std::unordered_set right_join_aliases_set; - right_join_aliases_set.insert(middle_group_aliases_set.begin(), - middle_group_aliases_set.end()); - right_join_aliases_set.insert(right_group_aliases_set.begin(), - right_group_aliases_set.end()); - - // Redistribute predicates - auto parent_join_predicates = - std::vector(parent_join->join_predicates); - auto child_join_predicates = - std::vector(child_join->join_predicates); - - std::vector predicates; - predicates.insert(predicates.end(), parent_join_predicates.begin(), - parent_join_predicates.end()); - predicates.insert(predicates.end(), child_join_predicates.begin(), - child_join_predicates.end()); - - std::vector new_child_join_predicates; - std::vector new_parent_join_predicates; - - for (auto predicate : predicates) { - if (util::IsSubset(right_join_aliases_set, predicate.table_alias_set)) { - new_child_join_predicates.emplace_back(predicate); - } else { - new_parent_join_predicates.emplace_back(predicate); - } - } - - // Construct new child join operator - std::shared_ptr new_child_join = - std::make_shared( - LogicalInnerJoin::make(new_child_join_predicates)); - new_child_join->PushChild(middle); - new_child_join->PushChild(right); - - // Construct new parent join operator - std::shared_ptr new_parent_join = - std::make_shared( - LogicalInnerJoin::make(new_parent_join_predicates)); - new_parent_join->PushChild(left); - new_parent_join->PushChild(new_child_join); - - transformed.push_back(new_parent_join); -} - //===--------------------------------------------------------------------===// // Implementation rules //===--------------------------------------------------------------------===// @@ -429,9 +288,8 @@ void GetToIndexScan::Transform( sort_by_asc_base_column = false; break; } - auto bound_oids = - reinterpret_cast(expr) - ->GetBoundOid(); + auto bound_oids = reinterpret_cast( + expr)->GetBoundOid(); sort_col_ids.push_back(std::get<2>(bound_oids)); } // Check whether any index can fulfill sort property @@ -512,20 +370,16 @@ void GetToIndexScan::Transform( if (value_expr->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { value_list.push_back( reinterpret_cast( - value_expr) - ->GetValue()); + value_expr)->GetValue()); LOG_TRACE("Value Type: %d", static_cast( reinterpret_cast( - expr->GetModifiableChild(1)) - ->GetValueType())); + expr->GetModifiableChild(1))->GetValueType())); } else { value_list.push_back( type::ValueFactory::GetParameterOffsetValue( reinterpret_cast( - value_expr) - ->GetValueIdx()) - .Copy()); + value_expr)->GetValueIdx()).Copy()); LOG_TRACE("Parameter offset: %s", (*value_list.rbegin()).GetInfo().c_str()); } @@ -785,7 +639,7 @@ JoinToNLJoin::JoinToNLJoin() { } bool JoinToNLJoin::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; return true; @@ -847,7 +701,7 @@ JoinToHashJoin::JoinToHashJoin() { } bool JoinToHashJoin::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; return true; @@ -887,128 +741,6 @@ void JoinToHashJoin::Transform( transformed.push_back(result_plan); } } - -/////////////////////////////////////////////////////////////////////////////// -/// InnerJoinToInnerNLJoin -InnerJoinToInnerNLJoin::InnerJoinToInnerNLJoin() { - type_ = RuleType::INNER_JOIN_TO_NL_JOIN; - - // TODO NLJoin currently only support left deep tree - std::shared_ptr left_child(std::make_shared(OpType::Leaf)); - std::shared_ptr right_child(std::make_shared(OpType::Leaf)); - - // Initialize a pattern for optimizer to match - match_pattern = std::make_shared(OpType::InnerJoin); - - // Add node - we match join relation R and S - match_pattern->AddChild(left_child); - match_pattern->AddChild(right_child); - - return; -} - -bool InnerJoinToInnerNLJoin::Check(std::shared_ptr plan, - OptimizeContext *context) const { - (void)context; - (void)plan; - return true; -} - -void InnerJoinToInnerNLJoin::Transform( - std::shared_ptr input, - std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { - // first build an expression representing hash join - const LogicalInnerJoin *inner_join = input->Op().As(); - - auto children = input->Children(); - PELOTON_ASSERT(children.size() == 2); - auto left_group_id = children[0]->Op().As()->origin_group; - auto right_group_id = children[1]->Op().As()->origin_group; - auto &left_group_alias = - context->metadata->memo.GetGroupByID(left_group_id)->GetTableAliases(); - auto &right_group_alias = - context->metadata->memo.GetGroupByID(right_group_id)->GetTableAliases(); - std::vector> left_keys; - std::vector> right_keys; - - util::ExtractEquiJoinKeys(inner_join->join_predicates, left_keys, right_keys, - left_group_alias, right_group_alias); - - PELOTON_ASSERT(right_keys.size() == left_keys.size()); - auto result_plan = - std::make_shared(PhysicalInnerNLJoin::make( - inner_join->join_predicates, left_keys, right_keys)); - - // Then push all children into the child list of the new operator - result_plan->PushChild(children[0]); - result_plan->PushChild(children[1]); - - transformed.push_back(result_plan); - - return; -} - -/////////////////////////////////////////////////////////////////////////////// -/// InnerJoinToInnerHashJoin -InnerJoinToInnerHashJoin::InnerJoinToInnerHashJoin() { - type_ = RuleType::INNER_JOIN_TO_HASH_JOIN; - - // Make three node types for pattern matching - std::shared_ptr left_child(std::make_shared(OpType::Leaf)); - std::shared_ptr right_child(std::make_shared(OpType::Leaf)); - - // Initialize a pattern for optimizer to match - match_pattern = std::make_shared(OpType::InnerJoin); - - // Add node - we match join relation R and S as well as the predicate exp - match_pattern->AddChild(left_child); - match_pattern->AddChild(right_child); - - return; -} - -bool InnerJoinToInnerHashJoin::Check(std::shared_ptr plan, - OptimizeContext *context) const { - (void)context; - (void)plan; - return true; -} - -void InnerJoinToInnerHashJoin::Transform( - std::shared_ptr input, - std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { - // first build an expression representing hash join - const LogicalInnerJoin *inner_join = input->Op().As(); - - auto children = input->Children(); - PELOTON_ASSERT(children.size() == 2); - auto left_group_id = children[0]->Op().As()->origin_group; - auto right_group_id = children[1]->Op().As()->origin_group; - auto &left_group_alias = - context->metadata->memo.GetGroupByID(left_group_id)->GetTableAliases(); - auto &right_group_alias = - context->metadata->memo.GetGroupByID(right_group_id)->GetTableAliases(); - std::vector> left_keys; - std::vector> right_keys; - - util::ExtractEquiJoinKeys(inner_join->join_predicates, left_keys, right_keys, - left_group_alias, right_group_alias); - - PELOTON_ASSERT(right_keys.size() == left_keys.size()); - if (!left_keys.empty()) { - auto result_plan = - std::make_shared(PhysicalInnerHashJoin::make( - inner_join->join_predicates, left_keys, right_keys)); - - // Then push all children into the child list of the new operator - result_plan->PushChild(children[0]); - result_plan->PushChild(children[1]); - - transformed.push_back(result_plan); - } -} /////////////////////////////////////////////////////////////////////////////// /// ImplementDistinct @@ -1084,7 +816,8 @@ PushFilterThroughJoin::PushFilterThroughJoin() { type_ = RuleType::PUSH_FILTER_THROUGH_JOIN; // Make three node types for pattern matching - std::shared_ptr child(std::make_shared(OpType::LogicalJoin)); + std::shared_ptr child( + std::make_shared(OpType::LogicalJoin)); child->AddChild(std::make_shared(OpType::Leaf)); child->AddChild(std::make_shared(OpType::Leaf)); @@ -1143,8 +876,8 @@ void PushFilterThroughJoin::Transform( join_predicates.insert(join_predicates.end(), pre_join_predicate.begin(), pre_join_predicate.end()); std::shared_ptr output = - std::make_shared( - LogicalJoin::make(join_op_expr->Op().As()->type, join_predicates)); + std::make_shared(LogicalJoin::make( + join_op_expr->Op().As()->type, join_predicates)); // Construct left filter if any if (!left_predicates.empty()) { diff --git a/src/optimizer/stats_calculator.cpp b/src/optimizer/stats_calculator.cpp index 94a22a30b7d..e6ea7d3346e 100644 --- a/src/optimizer/stats_calculator.cpp +++ b/src/optimizer/stats_calculator.cpp @@ -42,8 +42,8 @@ void StatsCalculator::Visit(const LogicalGet *op) { return; } auto table_stats = std::dynamic_pointer_cast( - StatsStorage::GetInstance()->GetTableStats(op->table->GetDatabaseOid(), - op->table->GetTableOid(), txn_)); + StatsStorage::GetInstance()->GetTableStats( + op->table->GetDatabaseOid(), op->table->GetTableOid(), txn_)); // First, get the required stats of the base table std::unordered_map> required_stats; for (auto &col : required_cols_) { @@ -97,65 +97,6 @@ void StatsCalculator::Visit(const LogicalQueryDerivedGet *) { } void StatsCalculator::Visit(const LogicalJoin *op) { - // Check if there's join condition - PELOTON_ASSERT(gexpr_->GetChildrenGroupsSize() == 2); - auto left_child_group = memo_->GetGroupByID(gexpr_->GetChildGroupId(0)); - auto right_child_group = memo_->GetGroupByID(gexpr_->GetChildGroupId(1)); - auto root_group = memo_->GetGroupByID(gexpr_->GetGroupID()); - // Calculate output num rows first - if (root_group->GetNumRows() == -1) { - size_t curr_rows = - left_child_group->GetNumRows() * right_child_group->GetNumRows(); - for (auto &annotated_expr : op->join_predicates) { - // See if there are join conditions - if (annotated_expr.expr->GetExpressionType() == - ExpressionType::COMPARE_EQUAL && - annotated_expr.expr->GetChild(0)->GetExpressionType() == - ExpressionType::VALUE_TUPLE && - annotated_expr.expr->GetChild(1)->GetExpressionType() == - ExpressionType::VALUE_TUPLE) { - auto left_child = - reinterpret_cast( - annotated_expr.expr->GetChild(0)); - auto right_child = - reinterpret_cast( - annotated_expr.expr->GetChild(1)); - if ((left_child_group->HasColumnStats(left_child->GetColFullName()) && - right_child_group->HasColumnStats( - right_child->GetColFullName())) || - (left_child_group->HasColumnStats(right_child->GetColFullName()) && - right_child_group->HasColumnStats(left_child->GetColFullName()))) { - curr_rows /= std::max(std::max(left_child_group->GetNumRows(), - right_child_group->GetNumRows()), - 1); - } - } - } - root_group->SetNumRows(curr_rows); - } - size_t num_rows = root_group->GetNumRows(); - for (auto &col : required_cols_) { - PELOTON_ASSERT(col->GetExpressionType() == ExpressionType::VALUE_TUPLE); - auto tv_expr = reinterpret_cast(col); - std::shared_ptr column_stats; - // Make a copy from the child stats - if (left_child_group->HasColumnStats(tv_expr->GetColFullName())) { - column_stats = std::make_shared( - *left_child_group->GetStats(tv_expr->GetColFullName())); - } else { - PELOTON_ASSERT(right_child_group->HasColumnStats(tv_expr->GetColFullName())); - column_stats = std::make_shared( - *right_child_group->GetStats(tv_expr->GetColFullName())); - } - // Reset num_rows - column_stats->num_rows = num_rows; - root_group->AddStats(tv_expr->GetColFullName(), column_stats); - } - // TODO(boweic): calculate stats based on predicates other than join - // conditions -} - -void StatsCalculator::Visit(const LogicalInnerJoin *op) { // Check if there's join condition PELOTON_ASSERT(gexpr_->GetChildrenGroupsSize() == 2); auto left_child_group = memo_->GetGroupByID(gexpr_->GetChildGroupId(0)); @@ -202,7 +143,8 @@ void StatsCalculator::Visit(const LogicalInnerJoin *op) { column_stats = std::make_shared( *left_child_group->GetStats(tv_expr->GetColFullName())); } else { - PELOTON_ASSERT(right_child_group->HasColumnStats(tv_expr->GetColFullName())); + PELOTON_ASSERT( + right_child_group->HasColumnStats(tv_expr->GetColFullName())); column_stats = std::make_shared( *right_child_group->GetStats(tv_expr->GetColFullName())); } @@ -213,9 +155,7 @@ void StatsCalculator::Visit(const LogicalInnerJoin *op) { // TODO(boweic): calculate stats based on predicates other than join // conditions } -void StatsCalculator::Visit(UNUSED_ATTRIBUTE const LogicalLeftJoin *op) {} -void StatsCalculator::Visit(UNUSED_ATTRIBUTE const LogicalRightJoin *op) {} -void StatsCalculator::Visit(UNUSED_ATTRIBUTE const LogicalOuterJoin *op) {} + void StatsCalculator::Visit(UNUSED_ATTRIBUTE const LogicalSemiJoin *op) {} void StatsCalculator::Visit(const LogicalAggregateAndGroupBy *) { // TODO(boweic): For now we just pass the stats needed without any @@ -294,8 +234,8 @@ void StatsCalculator::AddBaseTableStats( void StatsCalculator::UpdateStatsForFilter( size_t num_rows, - std::unordered_map> - &predicate_stats, + std::unordered_map> & + predicate_stats, const std::vector &predicates) { // First, construct the table stats as the interface needed it to compute // selectivity @@ -344,10 +284,10 @@ double StatsCalculator::CalculateSelectivityForPredicate( : 0; auto left_expr = expr->GetChild(1 - right_index); - PELOTON_ASSERT(left_expr->GetExpressionType() == ExpressionType::VALUE_TUPLE); - auto col_name = - reinterpret_cast(left_expr) - ->GetColFullName(); + PELOTON_ASSERT(left_expr->GetExpressionType() == + ExpressionType::VALUE_TUPLE); + auto col_name = reinterpret_cast( + left_expr)->GetColFullName(); auto expr_type = expr->GetExpressionType(); if (right_index == 0) { @@ -373,14 +313,12 @@ double StatsCalculator::CalculateSelectivityForPredicate( if (expr->GetChild(right_index)->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { value = reinterpret_cast( - expr->GetModifiableChild(right_index)) - ->GetValue(); + expr->GetModifiableChild(right_index))->GetValue(); } else { - value = type::ValueFactory::GetParameterOffsetValue( - reinterpret_cast( - expr->GetModifiableChild(right_index)) - ->GetValueIdx()) - .Copy(); + value = + type::ValueFactory::GetParameterOffsetValue( + reinterpret_cast( + expr->GetModifiableChild(right_index))->GetValueIdx()).Copy(); } ValueCondition condition(col_name, expr_type, value); selectivity = diff --git a/test/optimizer/optimizer_rule_test.cpp b/test/optimizer/optimizer_rule_test.cpp index 12d047ad51a..ac54ca9df2d 100644 --- a/test/optimizer/optimizer_rule_test.cpp +++ b/test/optimizer/optimizer_rule_test.cpp @@ -36,6 +36,7 @@ #include "planner/update_plan.h" #include "sql/testing_sql_util.h" #include "type/value_factory.h" +#include "common/internal_types.h" namespace peloton { namespace test { @@ -52,12 +53,13 @@ TEST_F(OptimizerRuleTests, SimpleCommutativeRuleTest) { // Build op plan node to match rule auto left_get = std::make_shared(LogicalGet::make()); auto right_get = std::make_shared(LogicalGet::make()); - auto join = std::make_shared(LogicalInnerJoin::make()); + auto join = + std::make_shared(LogicalJoin::make(JoinType::INNER)); join->PushChild(left_get); join->PushChild(right_get); // Setup rule - InnerJoinCommutativity rule; + JoinCommutativity rule; EXPECT_TRUE(rule.Check(join, nullptr)); @@ -113,7 +115,7 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { child_join_predicates.push_back(pred); auto child_join = std::make_shared( - LogicalInnerJoin::make(child_join_predicates)); + LogicalJoin::make(JoinType::INNER, child_join_predicates)); child_join->PushChild(left_leaf); child_join->PushChild(middle_leaf); optimizer.GetMetadata().memo.InsertExpression( @@ -126,7 +128,7 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { parent_join_predicates.push_back(pred); auto parent_join = std::make_shared( - LogicalInnerJoin::make(parent_join_predicates)); + LogicalJoin::make(JoinType::INNER, parent_join_predicates)); parent_join->PushChild(child_join); parent_join->PushChild(right_leaf); @@ -138,15 +140,14 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { EXPECT_EQ(left_leaf, parent_join->Children()[0]->Children()[0]); EXPECT_EQ(middle_leaf, parent_join->Children()[0]->Children()[1]); EXPECT_EQ(right_leaf, parent_join->Children()[1]); - EXPECT_EQ(1, - parent_join->Op().As()->join_predicates.size()); + EXPECT_EQ(1, parent_join->Op().As()->join_predicates.size()); EXPECT_EQ(1, parent_join->Children()[0] ->Op() - .As() + .As() ->join_predicates.size()); // Setup rule - InnerJoinAssociativity rule; + JoinAssociativity rule; EXPECT_TRUE(rule.Check(parent_join, root_context)); std::vector> outputs; @@ -159,8 +160,8 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { EXPECT_EQ(middle_leaf, output_join->Children()[1]->Children()[0]); EXPECT_EQ(right_leaf, output_join->Children()[1]->Children()[1]); - auto parent_join_op = output_join->Op().As(); - auto child_join_op = output_join->Children()[1]->Op().As(); + auto parent_join_op = output_join->Op().As(); + auto child_join_op = output_join->Children()[1]->Op().As(); EXPECT_EQ(2, parent_join_op->join_predicates.size()); EXPECT_EQ(0, child_join_op->join_predicates.size()); delete root_context; @@ -201,7 +202,7 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest2) { // Make Child Join auto child_join = - std::make_shared(LogicalInnerJoin::make()); + std::make_shared(LogicalJoin::make(JoinType::INNER)); child_join->PushChild(left_leaf); child_join->PushChild(middle_leaf); optimizer.GetMetadata().memo.InsertExpression( @@ -221,7 +222,7 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest2) { parent_join_predicates.push_back(pred2); auto parent_join = std::make_shared( - LogicalInnerJoin::make(parent_join_predicates)); + LogicalJoin::make(JoinType::INNER, parent_join_predicates)); parent_join->PushChild(child_join); parent_join->PushChild(right_leaf); @@ -233,15 +234,14 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest2) { EXPECT_EQ(left_leaf, parent_join->Children()[0]->Children()[0]); EXPECT_EQ(middle_leaf, parent_join->Children()[0]->Children()[1]); EXPECT_EQ(right_leaf, parent_join->Children()[1]); - EXPECT_EQ(2, - parent_join->Op().As()->join_predicates.size()); + EXPECT_EQ(2, parent_join->Op().As()->join_predicates.size()); EXPECT_EQ(0, parent_join->Children()[0] ->Op() - .As() + .As() ->join_predicates.size()); // Setup rule - InnerJoinAssociativity rule; + JoinAssociativity rule; EXPECT_TRUE(rule.Check(parent_join, root_context)); std::vector> outputs; @@ -254,8 +254,8 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest2) { EXPECT_EQ(middle_leaf, output_join->Children()[1]->Children()[0]); EXPECT_EQ(right_leaf, output_join->Children()[1]->Children()[1]); - auto parent_join_op = output_join->Op().As(); - auto child_join_op = output_join->Children()[1]->Op().As(); + auto parent_join_op = output_join->Op().As(); + auto child_join_op = output_join->Children()[1]->Op().As(); EXPECT_EQ(1, parent_join_op->join_predicates.size()); EXPECT_EQ(1, child_join_op->join_predicates.size()); delete root_context; diff --git a/test/optimizer/optimizer_test.cpp b/test/optimizer/optimizer_test.cpp index dae410999d6..525d6c89b18 100644 --- a/test/optimizer/optimizer_test.cpp +++ b/test/optimizer/optimizer_test.cpp @@ -362,8 +362,8 @@ TEST_F(OptimizerTests, PushFilterThroughJoinTest) { // Check join in the root auto group_expr = GetSingleGroupExpression(memo, head_gexpr.get(), 0); - EXPECT_EQ(OpType::InnerJoin, group_expr->Op().GetType()); - auto join_op = group_expr->Op().As(); + EXPECT_EQ(OpType::LogicalJoin, group_expr->Op().GetType()); + auto join_op = group_expr->Op().As(); EXPECT_EQ(1, join_op->join_predicates.size()); EXPECT_TRUE(join_op->join_predicates[0].expr->ExactlyEquals(*predicates[0])); @@ -449,8 +449,9 @@ TEST_F(OptimizerTests, PredicatePushDownRewriteTest) { // Check join in the root auto group_expr = GetSingleGroupExpression(memo, head_gexpr.get(), 0); - EXPECT_EQ(OpType::InnerJoin, group_expr->Op().GetType()); - auto join_op = group_expr->Op().As(); + EXPECT_EQ(OpType::LogicalJoin, group_expr->Op().GetType()); + auto join_op = group_expr->Op().As(); + EXPECT_EQ(JoinType::INNER, join_op->type); EXPECT_EQ(1, join_op->join_predicates.size()); EXPECT_TRUE(join_op->join_predicates[0].expr->ExactlyEquals(*predicates[0])); From 15948802575ea60f6c5714ba07c67f1ebe948cd3 Mon Sep 17 00:00:00 2001 From: Pedro Miguel Reis Bento Paredes Date: Tue, 10 Apr 2018 20:59:11 -0400 Subject: [PATCH 11/26] Attempt at fixing pushdown filter --- src/include/planner/abstract_join_plan.h | 4 ++++ src/include/planner/hash_join_plan.h | 4 +++- src/optimizer/optimizer_task.cpp | 5 +++++ src/optimizer/query_to_operator_transformer.cpp | 11 +++++++---- src/optimizer/rule.cpp | 2 +- src/optimizer/rule_impls.cpp | 9 +++++++-- 6 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/include/planner/abstract_join_plan.h b/src/include/planner/abstract_join_plan.h index 4c7734cfaba..3172c4079e2 100644 --- a/src/include/planner/abstract_join_plan.h +++ b/src/include/planner/abstract_join_plan.h @@ -87,6 +87,10 @@ class AbstractJoinPlan : public AbstractPlan { virtual void HandleSubplanBinding(bool from_left, const BindingContext &input) = 0; + const std::string GetPredicateInfo() const { + return predicate_ != nullptr ? predicate_->GetInfo() : ""; + } + private: /** @brief The type of join that we're going to perform */ JoinType join_type_; diff --git a/src/include/planner/hash_join_plan.h b/src/include/planner/hash_join_plan.h index cf02b77a4e8..b79d50de6ce 100644 --- a/src/include/planner/hash_join_plan.h +++ b/src/include/planner/hash_join_plan.h @@ -57,7 +57,9 @@ class HashJoinPlan : public AbstractJoinPlan { void SetBloomFilterFlag(bool flag) { build_bloomfilter_ = flag; } - const std::string GetInfo() const override { return "HashJoin"; } + const std::string GetInfo() const override { + return "HashJoin(" + GetPredicateInfo() + ")"; + } const std::vector &GetOuterHashIds() const { return outer_column_ids_; diff --git a/src/optimizer/optimizer_task.cpp b/src/optimizer/optimizer_task.cpp index f0a489906ae..8654f92982c 100644 --- a/src/optimizer/optimizer_task.cpp +++ b/src/optimizer/optimizer_task.cpp @@ -423,6 +423,11 @@ void TopDownRewrite::execute() { r.rule->GetMatchPattern()); if (iterator.HasNext()) { auto before = iterator.Next(); + + if (!r.rule->Check(before, context_.get())) { + continue; + } + PELOTON_ASSERT(!iterator.HasNext()); std::vector> after; r.rule->Transform(before, after, context_.get()); diff --git a/src/optimizer/query_to_operator_transformer.cpp b/src/optimizer/query_to_operator_transformer.cpp index a86c52265a7..d71cd0cf4a0 100644 --- a/src/optimizer/query_to_operator_transformer.cpp +++ b/src/optimizer/query_to_operator_transformer.cpp @@ -120,6 +120,8 @@ void QueryToOperatorTransformer::Visit(parser::SelectStatement *op) { predicates_ = std::move(pre_predicates); } void QueryToOperatorTransformer::Visit(parser::JoinDefinition *node) { + auto pre_predicates = std::move(predicates_); + // Get left operator node->left->Accept(this); auto left_expr = output_expr_; @@ -134,25 +136,25 @@ void QueryToOperatorTransformer::Visit(parser::JoinDefinition *node) { case JoinType::INNER: { predicates_ = CollectPredicates(node->condition.get(), predicates_); join_expr = std::make_shared( - LogicalJoin::make(JoinType::INNER)); + LogicalJoin::make(JoinType::INNER, predicates_)); break; } case JoinType::OUTER: { predicates_ = CollectPredicates(node->condition.get(), predicates_); join_expr = std::make_shared( - LogicalJoin::make(JoinType::OUTER)); + LogicalJoin::make(JoinType::OUTER, predicates_)); break; } case JoinType::LEFT: { predicates_ = CollectPredicates(node->condition.get(), predicates_); join_expr = std::make_shared( - LogicalJoin::make(JoinType::LEFT)); + LogicalJoin::make(JoinType::LEFT, predicates_)); break; } case JoinType::RIGHT: { predicates_ = CollectPredicates(node->condition.get(), predicates_); join_expr = std::make_shared( - LogicalJoin::make(JoinType::RIGHT)); + LogicalJoin::make(JoinType::RIGHT, predicates_)); break; } case JoinType::SEMI: { @@ -168,6 +170,7 @@ void QueryToOperatorTransformer::Visit(parser::JoinDefinition *node) { join_expr->PushChild(right_expr); output_expr_ = join_expr; + predicates_ = std::move(pre_predicates); } void QueryToOperatorTransformer::Visit(parser::TableRef *node) { if (node->select != nullptr) { diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index cca9bd0497f..1cce07ee125 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -46,7 +46,7 @@ RuleSet::RuleSet() { AddImplementationRule(new ImplementLimit()); AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN, - new PushFilterThroughJoin()); + new PushFilterThroughJoin()); AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN, new PushFilterThroughAggregation()); AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN, diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index b73af2bd7eb..581f31a8617 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -853,6 +853,11 @@ void PushFilterThroughJoin::Transform( std::vector right_predicates; std::vector join_predicates; + auto join_type = join_op_expr->Op().As()->type; + bool outer_push = (join_type == JoinType::OUTER || + join_type == JoinType::LEFT || + join_type == JoinType::RIGHT); + // Loop over all predicates, check each of them if they can be pushed down to // either the left child or the right child to be evaluated // All predicates in this loop follow conjunction relationship because we @@ -860,10 +865,10 @@ void PushFilterThroughJoin::Transform( // E.g. An expression (test.a = test1.b and test.a = 5) would become // {test.a = test1.b, test.a = 5} for (auto &predicate : predicates) { - if (util::IsSubset(left_group_aliases_set, predicate.table_alias_set)) { + if (util::IsSubset(left_group_aliases_set, predicate.table_alias_set) && !outer_push) { left_predicates.emplace_back(predicate); } else if (util::IsSubset(right_group_aliases_set, - predicate.table_alias_set)) { + predicate.table_alias_set) && !outer_push) { right_predicates.emplace_back(predicate); } else { join_predicates.emplace_back(predicate); From 9a420b76ad3e9c32d8710ed7ef373c1b48b1d177 Mon Sep 17 00:00:00 2001 From: Guoquan Zhao Date: Wed, 11 Apr 2018 12:11:46 -0400 Subject: [PATCH 12/26] add integration test for [INNER | LEFT | RIGHT | OUTER] JoinCommutativeRules --- test/optimizer/optimizer_rule_test.cpp | 75 ++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/test/optimizer/optimizer_rule_test.cpp b/test/optimizer/optimizer_rule_test.cpp index ac54ca9df2d..d8ec828035a 100644 --- a/test/optimizer/optimizer_rule_test.cpp +++ b/test/optimizer/optimizer_rule_test.cpp @@ -73,6 +73,81 @@ TEST_F(OptimizerRuleTests, SimpleCommutativeRuleTest) { EXPECT_EQ(output_join->Children()[1], left_get); } +TEST_F(OptimizerRuleTests, LeftJoinCommutativeRuleTest) { + // Build op plan node to match rule + auto left_get = std::make_shared(LogicalGet::make()); + auto right_get = std::make_shared(LogicalGet::make()); + auto join = + std::make_shared(LogicalJoin::make(JoinType::LEFT)); + join->PushChild(left_get); + join->PushChild(right_get); + + // Setup rule + JoinCommutativity rule; + + EXPECT_TRUE(rule.Check(join, nullptr)); + + std::vector> outputs; + rule.Transform(join, outputs, nullptr); + EXPECT_EQ(outputs.size(), 1); + + auto output_join = outputs[0]; + + EXPECT_EQ(output_join->Children()[0], right_get); + EXPECT_EQ(output_join->Children()[1], left_get); + EXPECT_EQ(output_join->Op().As()->type, JoinType::RIGHT); +} + +TEST_F(OptimizerRuleTests, RightJoinCommutativeRuleTest) { + // Build op plan node to match rule + auto left_get = std::make_shared(LogicalGet::make()); + auto right_get = std::make_shared(LogicalGet::make()); + auto join = + std::make_shared(LogicalJoin::make(JoinType::RIGHT)); + join->PushChild(left_get); + join->PushChild(right_get); + + // Setup rule + JoinCommutativity rule; + + EXPECT_TRUE(rule.Check(join, nullptr)); + + std::vector> outputs; + rule.Transform(join, outputs, nullptr); + EXPECT_EQ(outputs.size(), 1); + + auto output_join = outputs[0]; + + EXPECT_EQ(output_join->Children()[0], right_get); + EXPECT_EQ(output_join->Children()[1], left_get); + EXPECT_EQ(output_join->Op().As()->type, JoinType::LEFT); +} + +TEST_F(OptimizerRuleTests, OuterJoinCommutativeRuleTest) { + // Build op plan node to match rule + auto left_get = std::make_shared(LogicalGet::make()); + auto right_get = std::make_shared(LogicalGet::make()); + auto join = + std::make_shared(LogicalJoin::make(JoinType::OUTER)); + join->PushChild(left_get); + join->PushChild(right_get); + + // Setup rule + JoinCommutativity rule; + + EXPECT_TRUE(rule.Check(join, nullptr)); + + std::vector> outputs; + rule.Transform(join, outputs, nullptr); + EXPECT_EQ(outputs.size(), 1); + + auto output_join = outputs[0]; + + EXPECT_EQ(output_join->Children()[0], right_get); + EXPECT_EQ(output_join->Children()[1], left_get); + EXPECT_EQ(output_join->Op().As()->type, JoinType::OUTER); +} + TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { // Start Join Structure: (left JOIN middle) JOIN right // End Join Structure: left JOIN (middle JOIN right) From 9946f1eb13dec51992bb49b2f7025134bf2acddf Mon Sep 17 00:00:00 2001 From: Guoquan Zhao Date: Wed, 11 Apr 2018 15:14:18 -0400 Subject: [PATCH 13/26] fix a bug that the ApplyRule with JOIN_COMMUTE task never gets executed even if we actually push it to the stack. --- src/optimizer/group_expression.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimizer/group_expression.cpp b/src/optimizer/group_expression.cpp index 4d874bd27ef..379f00c000f 100644 --- a/src/optimizer/group_expression.cpp +++ b/src/optimizer/group_expression.cpp @@ -95,7 +95,7 @@ bool GroupExpression::operator==(const GroupExpression &r) { } void GroupExpression::SetRuleExplored(Rule *rule) { - rule_mask_.set(rule->GetRuleIdx()) = true; + rule_mask_.set(rule->GetRuleIdx()); } bool GroupExpression::HasRuleExplored(Rule *rule) { From b691aa6fc62c1552e22a4b2f4f040b2865ea2065 Mon Sep 17 00:00:00 2001 From: Irene Qiuwen Kai Date: Thu, 12 Apr 2018 23:15:15 -0400 Subject: [PATCH 14/26] Revert "fix a bug that the ApplyRule with JOIN_COMMUTE task never gets executed even if we actually push it to the stack." This reverts commit 078c0607dd4b0eac2638ae474adcd118128930a6. --- src/optimizer/group_expression.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimizer/group_expression.cpp b/src/optimizer/group_expression.cpp index 379f00c000f..4d874bd27ef 100644 --- a/src/optimizer/group_expression.cpp +++ b/src/optimizer/group_expression.cpp @@ -95,7 +95,7 @@ bool GroupExpression::operator==(const GroupExpression &r) { } void GroupExpression::SetRuleExplored(Rule *rule) { - rule_mask_.set(rule->GetRuleIdx()); + rule_mask_.set(rule->GetRuleIdx()) = true; } bool GroupExpression::HasRuleExplored(Rule *rule) { From c76b5bb48026978eb1ebd65a3b4e7985a6c903e4 Mon Sep 17 00:00:00 2001 From: Irene Qiuwen Kai Date: Thu, 12 Apr 2018 23:16:46 -0400 Subject: [PATCH 15/26] Add more associativity rules. --- src/include/optimizer/rule_impls.h | 3 ++ src/optimizer/rule_impls.cpp | 46 ++++++++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/src/include/optimizer/rule_impls.h b/src/include/optimizer/rule_impls.h index b0faa1aabdb..2220247418e 100644 --- a/src/include/optimizer/rule_impls.h +++ b/src/include/optimizer/rule_impls.h @@ -51,6 +51,9 @@ class JoinAssociativity : public Rule { void Transform(std::shared_ptr input, std::vector> &transformed, OptimizeContext *context) const override; + private: + bool StrongPredicate(std::shared_ptr plan, + OptimizeContext *context) const; }; //===--------------------------------------------------------------------===// diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index 581f31a8617..5476d9bb672 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -96,10 +96,41 @@ JoinAssociativity::JoinAssociativity() { bool JoinAssociativity::Check(std::shared_ptr expr, OptimizeContext *context) const { (void)context; + // Associativity rules taken from + // http://15721.courses.cs.cmu.edu/spring2017/papers/14-optimizer1/p539-moerkotte.pdf + /* + * (A inner B) inner C = A inner (B inner C) + * (A right B) inner C = A right (B inner C) + * (A inner B) left C = A inner (B left C) + * (A left B) left C = A left (B left C) with Strong Predicates PBC + * (A full B) left C = A full (B left C) with Strong Predicate PBC + * (A right B) right C = A right (B right C) with Strong Predicate PAB + * (A right B) full C = A right (B full C) with Strong Predicate PAB + * (A full B) full C = A full (B full C) with Strong Predicate PAB & PBC + */ auto parent_join = expr->Op().As(); std::vector> children = expr->Children(); auto child_join = children[0]->Op().As(); - return (parent_join->type == child_join->type); + if (parent_join->type == JoinType::INNER) { + if (child_join->type == JoinType::INNER || child_join->type == JoinType::RIGHT) { + return true; + } + } else if (parent_join->type == JoinType::LEFT) { + if (child_join->type == JoinType::INNER) { + return true; + } else if (child_join->type == JoinType::LEFT || child_join->type == JoinType::OUTER) { + return StrongPredicate(expr, context); + } + } else if (parent_join->type == JoinType::RIGHT) { + if (child_join->type == JoinType::RIGHT) { + return StrongPredicate(expr, context); + } + } else if (parent_join->type == JoinType::OUTER) { + if (child_join->type == JoinType::RIGHT || child_join->type == JoinType::OUTER) { + return StrongPredicate(expr, context); + } + } + return false; } void JoinAssociativity::Transform( @@ -164,8 +195,8 @@ void JoinAssociativity::Transform( JoinType new_parent_join_type; JoinType new_child_join_type; - new_parent_join_type = parent_join->type; - new_child_join_type = child_join->type; + new_parent_join_type = child_join->type; + new_child_join_type = parent_join->type; // Construct new child join operator std::shared_ptr new_child_join = std::make_shared( @@ -183,6 +214,15 @@ void JoinAssociativity::Transform( transformed.push_back(new_parent_join); } +// TODO: some associativity rules can only be applied when the predicate is strong +// To check if the predicate is strong or not is non-trivial +bool JoinAssociativity::StrongPredicate(std::shared_ptr expr, + OptimizeContext *context) const { + (void)context; + (void)expr; + return false; +} + //===--------------------------------------------------------------------===// // Implementation rules //===--------------------------------------------------------------------===// From 53a08861e86352a7b4c1c2f81f18db4b85b6a012 Mon Sep 17 00:00:00 2001 From: Irene Qiuwen Kai Date: Sun, 29 Apr 2018 23:29:28 -0400 Subject: [PATCH 16/26] Initial modifications according to first Code Review 1. Refacored code, added necessary comments. 2. Removed PhysicalInnerNLJoin and PhysicalInnerHashJoin (left/right/outer as well). --- script/testing/junit/OptimizerTest.java | 19 +- .../optimizer/child_property_deriver.h | 8 - src/include/optimizer/cost_calculator.h | 8 - src/include/optimizer/input_column_deriver.h | 16 -- src/include/optimizer/operator_visitor.h | 8 - src/include/optimizer/operators.h | 108 +-------- src/include/optimizer/plan_generator.h | 16 -- src/optimizer/child_property_deriver.cpp | 18 +- src/optimizer/cost_calculator.cpp | 30 +-- src/optimizer/input_column_deriver.cpp | 20 -- src/optimizer/operators.cpp | 220 ++---------------- src/optimizer/plan_generator.cpp | 2 +- src/optimizer/rule_impls.cpp | 2 +- test/optimizer/optimizer_rule_test.cpp | 12 + 14 files changed, 62 insertions(+), 425 deletions(-) diff --git a/script/testing/junit/OptimizerTest.java b/script/testing/junit/OptimizerTest.java index 5db13511273..382af03737e 100644 --- a/script/testing/junit/OptimizerTest.java +++ b/script/testing/junit/OptimizerTest.java @@ -74,7 +74,7 @@ public void Teardown() throws SQLException { @Test - public void testInnerJoin() throws SQLException { + public void testInnerJoin1() throws SQLException { try ( Statement stmt = conn.createStatement(); ResultSet resultSet = stmt.executeQuery("SELECT t1.a FROM t1 INNER JOIN t2 ON (t1.b = t2.b) ORDER BY t1.a;");) { @@ -87,6 +87,10 @@ public void testInnerJoin() throws SQLException { e.printStackTrace(); fail(); } + } + + @Test + public void testInnerJoin2() throws SQLException { try ( Statement stmt = conn.createStatement(); ResultSet resultSet = stmt.executeQuery("SELECT x.a FROM t1 AS x INNER JOIN t2 ON(x.b = t2.b AND x.c = t2.c) ORDER BY x.a;");) { @@ -99,12 +103,10 @@ public void testInnerJoin() throws SQLException { e.printStackTrace(); fail(); } - - } @Test - public void testLeftOuterJoin() throws SQLException { + public void testLeftOuterJoin1() throws SQLException { try ( Statement stmt = conn.createStatement(); ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT JOIN t2 ON t1.a=t2.d;");) { @@ -128,6 +130,10 @@ public void testLeftOuterJoin() throws SQLException { e.printStackTrace(); fail(); } + } + + @Test + public void testLeftOuterJoin2() throws SQLException { try ( Statement stmt = conn.createStatement(); ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT JOIN t2 ON t1.a=t2.d WHERE t1.a>1")) { @@ -147,6 +153,10 @@ public void testLeftOuterJoin() throws SQLException { e.printStackTrace(); fail(); } + } + + @Test + public void testLeftOuterJoin3() throws SQLException { try ( Statement stmt = conn.createStatement(); ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT OUTER JOIN t2 ON t1.a=t2.d WHERE t1.a>1")) { @@ -166,7 +176,6 @@ public void testLeftOuterJoin() throws SQLException { e.printStackTrace(); fail(); } - } @Test diff --git a/src/include/optimizer/child_property_deriver.h b/src/include/optimizer/child_property_deriver.h index 152ef60207e..c01c37389c7 100644 --- a/src/include/optimizer/child_property_deriver.h +++ b/src/include/optimizer/child_property_deriver.h @@ -44,14 +44,6 @@ class ChildPropertyDeriver : public OperatorVisitor { void Visit(const PhysicalLimit *) override; void Visit(const PhysicalNLJoin *) override; void Visit(const PhysicalHashJoin *) override; - void Visit(const PhysicalInnerNLJoin *) override; - void Visit(const PhysicalLeftNLJoin *) override; - void Visit(const PhysicalRightNLJoin *) override; - void Visit(const PhysicalOuterNLJoin *) override; - void Visit(const PhysicalInnerHashJoin *) override; - void Visit(const PhysicalLeftHashJoin *) override; - void Visit(const PhysicalRightHashJoin *) override; - void Visit(const PhysicalOuterHashJoin *) override; void Visit(const PhysicalInsert *) override; void Visit(const PhysicalInsertSelect *) override; void Visit(const PhysicalDelete *) override; diff --git a/src/include/optimizer/cost_calculator.h b/src/include/optimizer/cost_calculator.h index c3c0a31e6e0..87410d0bb07 100644 --- a/src/include/optimizer/cost_calculator.h +++ b/src/include/optimizer/cost_calculator.h @@ -32,14 +32,6 @@ class CostCalculator : public OperatorVisitor { void Visit(const PhysicalLimit *) override; void Visit(const PhysicalNLJoin *) override; void Visit(const PhysicalHashJoin *) override; - void Visit(const PhysicalInnerNLJoin *) override; - void Visit(const PhysicalLeftNLJoin *) override; - void Visit(const PhysicalRightNLJoin *) override; - void Visit(const PhysicalOuterNLJoin *) override; - void Visit(const PhysicalInnerHashJoin *) override; - void Visit(const PhysicalLeftHashJoin *) override; - void Visit(const PhysicalRightHashJoin *) override; - void Visit(const PhysicalOuterHashJoin *) override; void Visit(const PhysicalInsert *) override; void Visit(const PhysicalInsertSelect *) override; void Visit(const PhysicalDelete *) override; diff --git a/src/include/optimizer/input_column_deriver.h b/src/include/optimizer/input_column_deriver.h index d105aa06561..792ebe52bfb 100644 --- a/src/include/optimizer/input_column_deriver.h +++ b/src/include/optimizer/input_column_deriver.h @@ -63,22 +63,6 @@ class InputColumnDeriver : public OperatorVisitor { void Visit(const PhysicalHashJoin *) override; - void Visit(const PhysicalInnerNLJoin *) override; - - void Visit(const PhysicalLeftNLJoin *) override; - - void Visit(const PhysicalRightNLJoin *) override; - - void Visit(const PhysicalOuterNLJoin *) override; - - void Visit(const PhysicalInnerHashJoin *) override; - - void Visit(const PhysicalLeftHashJoin *) override; - - void Visit(const PhysicalRightHashJoin *) override; - - void Visit(const PhysicalOuterHashJoin *) override; - void Visit(const PhysicalInsert *) override; void Visit(const PhysicalInsertSelect *) override; diff --git a/src/include/optimizer/operator_visitor.h b/src/include/optimizer/operator_visitor.h index 1644e26c1ff..4638c1f74ce 100644 --- a/src/include/optimizer/operator_visitor.h +++ b/src/include/optimizer/operator_visitor.h @@ -34,14 +34,6 @@ class OperatorVisitor { virtual void Visit(const PhysicalLimit *) {} virtual void Visit(const PhysicalNLJoin *) {} virtual void Visit(const PhysicalHashJoin *) {} - virtual void Visit(const PhysicalInnerNLJoin *) {} - virtual void Visit(const PhysicalLeftNLJoin *) {} - virtual void Visit(const PhysicalRightNLJoin *) {} - virtual void Visit(const PhysicalOuterNLJoin *) {} - virtual void Visit(const PhysicalInnerHashJoin *) {} - virtual void Visit(const PhysicalLeftHashJoin *) {} - virtual void Visit(const PhysicalRightHashJoin *) {} - virtual void Visit(const PhysicalOuterHashJoin *) {} virtual void Visit(const PhysicalInsert *) {} virtual void Visit(const PhysicalInsertSelect *) {} virtual void Visit(const PhysicalDelete *) {} diff --git a/src/include/optimizer/operators.h b/src/include/optimizer/operators.h index 0482afba5aa..2b2bcfcc43f 100644 --- a/src/include/optimizer/operators.h +++ b/src/include/optimizer/operators.h @@ -167,9 +167,9 @@ class LogicalSingleJoin : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalJoin : public OperatorNode { public: - static Operator make(JoinType _type); + static Operator make(JoinType type); - static Operator make(JoinType _type, + static Operator make(JoinType type, std::vector &conditions); bool operator==(const BaseOperatorNode &r) override; @@ -390,7 +390,7 @@ class PhysicalLimit : public OperatorNode { class PhysicalNLJoin : public OperatorNode { public: static Operator make( - JoinType _type, std::vector conditions, + JoinType type, std::vector conditions, std::vector> &left_keys, std::vector> &right_keys); @@ -411,7 +411,7 @@ class PhysicalNLJoin : public OperatorNode { class PhysicalHashJoin : public OperatorNode { public: static Operator make( - JoinType _type, std::vector conditions, + JoinType type, std::vector conditions, std::vector> &left_keys, std::vector> &right_keys); @@ -426,106 +426,6 @@ class PhysicalHashJoin : public OperatorNode { JoinType type; }; -//===--------------------------------------------------------------------===// -// InnerNLJoin -//===--------------------------------------------------------------------===// -class PhysicalInnerNLJoin : public OperatorNode { - public: - static Operator make( - std::vector conditions, - std::vector> &left_keys, - std::vector> &right_keys); - - bool operator==(const BaseOperatorNode &r) override; - - hash_t Hash() const override; - - std::vector> left_keys; - std::vector> right_keys; - - std::vector join_predicates; -}; - -//===--------------------------------------------------------------------===// -// LeftNLJoin -//===--------------------------------------------------------------------===// -class PhysicalLeftNLJoin : public OperatorNode { - public: - std::shared_ptr join_predicate; - static Operator make( - std::shared_ptr join_predicate); -}; - -//===--------------------------------------------------------------------===// -// RightNLJoin -//===--------------------------------------------------------------------===// -class PhysicalRightNLJoin : public OperatorNode { - public: - std::shared_ptr join_predicate; - static Operator make( - std::shared_ptr join_predicate); -}; - -//===--------------------------------------------------------------------===// -// OuterNLJoin -//===--------------------------------------------------------------------===// -class PhysicalOuterNLJoin : public OperatorNode { - public: - std::shared_ptr join_predicate; - static Operator make( - std::shared_ptr join_predicate); -}; - -//===--------------------------------------------------------------------===// -// InnerHashJoin -//===--------------------------------------------------------------------===// -class PhysicalInnerHashJoin : public OperatorNode { - public: - static Operator make( - std::vector conditions, - std::vector> &left_keys, - std::vector> &right_keys); - - bool operator==(const BaseOperatorNode &r) override; - - hash_t Hash() const override; - - std::vector> left_keys; - std::vector> right_keys; - - std::vector join_predicates; -}; - -//===--------------------------------------------------------------------===// -// LeftHashJoin -//===--------------------------------------------------------------------===// -class PhysicalLeftHashJoin : public OperatorNode { - public: - std::shared_ptr join_predicate; - static Operator make( - std::shared_ptr join_predicate); -}; - -//===--------------------------------------------------------------------===// -// RightHashJoin -//===--------------------------------------------------------------------===// -class PhysicalRightHashJoin : public OperatorNode { - public: - std::shared_ptr join_predicate; - static Operator make( - std::shared_ptr join_predicate); -}; - -//===--------------------------------------------------------------------===// -// OuterHashJoin -//===--------------------------------------------------------------------===// -class PhysicalOuterHashJoin : public OperatorNode { - public: - std::shared_ptr join_predicate; - static Operator make( - std::shared_ptr join_predicate); -}; - //===--------------------------------------------------------------------===// // PhysicalInsert //===--------------------------------------------------------------------===// diff --git a/src/include/optimizer/plan_generator.h b/src/include/optimizer/plan_generator.h index 20935377c3e..11184a05df6 100644 --- a/src/include/optimizer/plan_generator.h +++ b/src/include/optimizer/plan_generator.h @@ -64,22 +64,6 @@ class PlanGenerator : public OperatorVisitor { void Visit(const PhysicalHashJoin *) override; - void Visit(const PhysicalInnerNLJoin *) override; - - void Visit(const PhysicalLeftNLJoin *) override; - - void Visit(const PhysicalRightNLJoin *) override; - - void Visit(const PhysicalOuterNLJoin *) override; - - void Visit(const PhysicalInnerHashJoin *) override; - - void Visit(const PhysicalLeftHashJoin *) override; - - void Visit(const PhysicalRightHashJoin *) override; - - void Visit(const PhysicalOuterHashJoin *) override; - void Visit(const PhysicalInsert *) override; void Visit(const PhysicalInsertSelect *) override; diff --git a/src/optimizer/child_property_deriver.cpp b/src/optimizer/child_property_deriver.cpp index a72f64883a7..b216d49dbd8 100644 --- a/src/optimizer/child_property_deriver.cpp +++ b/src/optimizer/child_property_deriver.cpp @@ -143,43 +143,37 @@ void ChildPropertyDeriver::Visit(const PhysicalDistinct *) { output_.push_back(make_pair(requirements_, move(child_input_properties))); } + void ChildPropertyDeriver::Visit(const PhysicalOrderBy *) {} + void ChildPropertyDeriver::Visit(const PhysicalNLJoin *) { DeriveForJoin(); } + void ChildPropertyDeriver::Visit(const PhysicalHashJoin *) { DeriveForJoin(); } -void ChildPropertyDeriver::Visit(const PhysicalInnerNLJoin *) { - DeriveForJoin(); -} -void ChildPropertyDeriver::Visit(const PhysicalLeftNLJoin *) {} -void ChildPropertyDeriver::Visit(const PhysicalRightNLJoin *) {} -void ChildPropertyDeriver::Visit(const PhysicalOuterNLJoin *) {} -void ChildPropertyDeriver::Visit(const PhysicalInnerHashJoin *) { - DeriveForJoin(); -} -void ChildPropertyDeriver::Visit(const PhysicalLeftHashJoin *) {} -void ChildPropertyDeriver::Visit(const PhysicalRightHashJoin *) {} -void ChildPropertyDeriver::Visit(const PhysicalOuterHashJoin *) {} void ChildPropertyDeriver::Visit(const PhysicalInsert *) { vector> child_input_properties; output_.push_back(make_pair(requirements_, move(child_input_properties))); } + void ChildPropertyDeriver::Visit(const PhysicalInsertSelect *) { // Let child fulfil all the required properties vector> child_input_properties{requirements_}; output_.push_back(make_pair(requirements_, move(child_input_properties))); } + void ChildPropertyDeriver::Visit(const PhysicalUpdate *) { // Let child fulfil all the required properties vector> child_input_properties{requirements_}; output_.push_back(make_pair(requirements_, move(child_input_properties))); } + void ChildPropertyDeriver::Visit(const PhysicalDelete *) { // Let child fulfil all the required properties vector> child_input_properties{requirements_}; diff --git a/src/optimizer/cost_calculator.cpp b/src/optimizer/cost_calculator.cpp index 1b40743473b..98f95013294 100644 --- a/src/optimizer/cost_calculator.cpp +++ b/src/optimizer/cost_calculator.cpp @@ -72,6 +72,7 @@ void CostCalculator::Visit(const PhysicalLimit *op) { output_cost_ = std::min((size_t)child_num_rows, (size_t)op->limit) * DEFAULT_TUPLE_COST; } + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalNLJoin *op) { auto left_child_rows = memo_->GetGroupByID(gexpr_->GetChildGroupId(0))->GetNumRows(); @@ -80,6 +81,7 @@ void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalNLJoin *op) { output_cost_ = left_child_rows * right_child_rows * DEFAULT_TUPLE_COST; } + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalHashJoin *op) { auto left_child_rows = memo_->GetGroupByID(gexpr_->GetChildGroupId(0))->GetNumRows(); @@ -88,46 +90,32 @@ void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalHashJoin *op) { output_cost_ = left_child_rows * right_child_rows * DEFAULT_TUPLE_COST; } -void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalInnerNLJoin *op) { - auto left_child_rows = - memo_->GetGroupByID(gexpr_->GetChildGroupId(0))->GetNumRows(); - auto right_child_rows = - memo_->GetGroupByID(gexpr_->GetChildGroupId(1))->GetNumRows(); - output_cost_ = left_child_rows * right_child_rows * DEFAULT_TUPLE_COST; -} -void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalLeftNLJoin *op) {} -void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalRightNLJoin *op) {} -void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalOuterNLJoin *op) {} -void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalInnerHashJoin *op) { - auto left_child_rows = - memo_->GetGroupByID(gexpr_->GetChildGroupId(0))->GetNumRows(); - auto right_child_rows = - memo_->GetGroupByID(gexpr_->GetChildGroupId(1))->GetNumRows(); - // TODO(boweic): Build (left) table should have different cost to probe table - output_cost_ = (left_child_rows + right_child_rows) * DEFAULT_TUPLE_COST; -} -void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalLeftHashJoin *op) {} -void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalRightHashJoin *op) {} -void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalOuterHashJoin *op) {} void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalInsert *op) {} + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalInsertSelect *op) {} + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalDelete *op) {} + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalUpdate *op) {} + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalHashGroupBy *op) { // TODO(boweic): Integrate hash in groupby may cause us to miss the // opportunity to further optimize some query where the child output is // already hashed by the GroupBy key, we'll do a hash anyway output_cost_ = HashCost() + GroupByCost(); } + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalSortGroupBy *op) { // Sort group by does not sort the tuples, it requires input columns to be // sorted output_cost_ = GroupByCost(); } + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalDistinct *op) { output_cost_ = HashCost(); } + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalAggregate *op) { // TODO(boweic): Ditto, separate groupby operator and implementation(e.g. // hash, sort) may enable opportunity for further optimization diff --git a/src/optimizer/input_column_deriver.cpp b/src/optimizer/input_column_deriver.cpp index c321fcc182f..c942d973de9 100644 --- a/src/optimizer/input_column_deriver.cpp +++ b/src/optimizer/input_column_deriver.cpp @@ -131,26 +131,6 @@ void InputColumnDeriver::Visit(const PhysicalHashJoin *op) { JoinHelper(op); } -void InputColumnDeriver::Visit(const PhysicalInnerNLJoin *op) { - JoinHelper(op); -} - -void InputColumnDeriver::Visit(const PhysicalLeftNLJoin *) {} - -void InputColumnDeriver::Visit(const PhysicalRightNLJoin *) {} - -void InputColumnDeriver::Visit(const PhysicalOuterNLJoin *) {} - -void InputColumnDeriver::Visit(const PhysicalInnerHashJoin *op) { - JoinHelper(op); -} - -void InputColumnDeriver::Visit(const PhysicalLeftHashJoin *) {} - -void InputColumnDeriver::Visit(const PhysicalRightHashJoin *) {} - -void InputColumnDeriver::Visit(const PhysicalOuterHashJoin *) {} - void InputColumnDeriver::Visit(const PhysicalInsert *) { output_input_cols_ = pair, vector>>{ diff --git a/src/optimizer/operators.cpp b/src/optimizer/operators.cpp index 4ad8414f872..232dbb51349 100644 --- a/src/optimizer/operators.cpp +++ b/src/optimizer/operators.cpp @@ -233,18 +233,18 @@ bool LogicalSingleJoin::operator==(const BaseOperatorNode &r) { //===--------------------------------------------------------------------===// // Join (Inner + Outer Joins) //===--------------------------------------------------------------------===// -Operator LogicalJoin::make(JoinType _type) { +Operator LogicalJoin::make(JoinType type) { LogicalJoin *join = new LogicalJoin; join->join_predicates = {}; - join->type = _type; + join->type = type; return Operator(join); } -Operator LogicalJoin::make(JoinType _type, +Operator LogicalJoin::make(JoinType type, std::vector &conditions) { LogicalJoin *join = new LogicalJoin; join->join_predicates = std::move(conditions); - join->type = _type; + join->type = type; return Operator(join); } @@ -531,25 +531,25 @@ Operator PhysicalLimit::make(int64_t offset, int64_t limit) { // NLJoin (Inner + Outer Joins) //===--------------------------------------------------------------------===// Operator PhysicalNLJoin::make( - JoinType _type, std::vector conditions, + JoinType type, std::vector conditions, std::vector> &left_keys, std::vector> &right_keys) { PhysicalNLJoin *join = new PhysicalNLJoin(); join->join_predicates = std::move(conditions); join->left_keys = std::move(left_keys); join->right_keys = std::move(right_keys); - join->type = _type; + join->type = type; return Operator(join); } hash_t PhysicalNLJoin::Hash() const { hash_t hash = BaseOperatorNode::Hash(); - for (auto &expr : left_keys) + for (const auto &expr : left_keys) hash = HashUtil::CombineHashes(hash, expr->Hash()); - for (auto &expr : right_keys) + for (const auto &expr : right_keys) hash = HashUtil::CombineHashes(hash, expr->Hash()); - for (auto &pred : join_predicates) + for (const auto &pred : join_predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); return hash; } @@ -581,24 +581,24 @@ bool PhysicalNLJoin::operator==(const BaseOperatorNode &r) { // HashJoin //===--------------------------------------------------------------------===// Operator PhysicalHashJoin::make( - JoinType _type, std::vector conditions, + JoinType type, std::vector conditions, std::vector> &left_keys, std::vector> &right_keys) { PhysicalHashJoin *join = new PhysicalHashJoin(); join->join_predicates = std::move(conditions); join->left_keys = std::move(left_keys); join->right_keys = std::move(right_keys); - join->type = _type; + join->type = type; return Operator(join); } hash_t PhysicalHashJoin::Hash() const { hash_t hash = BaseOperatorNode::Hash(); - for (auto &expr : left_keys) + for (const auto &expr : left_keys) hash = HashUtil::CombineHashes(hash, expr->Hash()); - for (auto &expr : right_keys) + for (const auto &expr : right_keys) hash = HashUtil::CombineHashes(hash, expr->Hash()); - for (auto &pred : join_predicates) + for (const auto &pred : join_predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); return hash; } @@ -625,161 +625,6 @@ bool PhysicalHashJoin::operator==(const BaseOperatorNode &r) { return true; } -//===--------------------------------------------------------------------===// -// InnerNLJoin -//===--------------------------------------------------------------------===// -Operator PhysicalInnerNLJoin::make( - std::vector conditions, - std::vector> &left_keys, - std::vector> &right_keys) { - PhysicalInnerNLJoin *join = new PhysicalInnerNLJoin(); - join->join_predicates = std::move(conditions); - join->left_keys = std::move(left_keys); - join->right_keys = std::move(right_keys); - - return Operator(join); -} - -hash_t PhysicalInnerNLJoin::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); - for (auto &expr : left_keys) - hash = HashUtil::CombineHashes(hash, expr->Hash()); - for (auto &expr : right_keys) - hash = HashUtil::CombineHashes(hash, expr->Hash()); - for (auto &pred : join_predicates) - hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); - return hash; -} - -bool PhysicalInnerNLJoin::operator==(const BaseOperatorNode &r) { - if (r.GetType() != OpType::InnerNLJoin) return false; - const PhysicalInnerNLJoin &node = - *static_cast(&r); - if (join_predicates.size() != node.join_predicates.size() || - left_keys.size() != node.left_keys.size() || - right_keys.size() != node.right_keys.size()) - return false; - for (size_t i = 0; i < left_keys.size(); i++) { - if (!left_keys[i]->ExactlyEquals(*node.left_keys[i].get())) return false; - } - for (size_t i = 0; i < right_keys.size(); i++) { - if (!right_keys[i]->ExactlyEquals(*node.right_keys[i].get())) return false; - } - for (size_t i = 0; i < join_predicates.size(); i++) { - if (!join_predicates[i].expr->ExactlyEquals( - *node.join_predicates[i].expr.get())) - return false; - } - return true; -} - -//===--------------------------------------------------------------------===// -// LeftNLJoin -//===--------------------------------------------------------------------===// -Operator PhysicalLeftNLJoin::make( - std::shared_ptr join_predicate) { - PhysicalLeftNLJoin *join = new PhysicalLeftNLJoin(); - join->join_predicate = join_predicate; - return Operator(join); -} - -//===--------------------------------------------------------------------===// -// RightNLJoin -//===--------------------------------------------------------------------===// -Operator PhysicalRightNLJoin::make( - std::shared_ptr join_predicate) { - PhysicalRightNLJoin *join = new PhysicalRightNLJoin(); - join->join_predicate = join_predicate; - return Operator(join); -} - -//===--------------------------------------------------------------------===// -// OuterNLJoin -//===--------------------------------------------------------------------===// -Operator PhysicalOuterNLJoin::make( - std::shared_ptr join_predicate) { - PhysicalOuterNLJoin *join = new PhysicalOuterNLJoin(); - join->join_predicate = join_predicate; - return Operator(join); -} - -//===--------------------------------------------------------------------===// -// InnerHashJoin -//===--------------------------------------------------------------------===// -Operator PhysicalInnerHashJoin::make( - std::vector conditions, - std::vector> &left_keys, - std::vector> &right_keys) { - PhysicalInnerHashJoin *join = new PhysicalInnerHashJoin(); - join->join_predicates = std::move(conditions); - join->left_keys = std::move(left_keys); - join->right_keys = std::move(right_keys); - return Operator(join); -} - -hash_t PhysicalInnerHashJoin::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); - for (auto &expr : left_keys) - hash = HashUtil::CombineHashes(hash, expr->Hash()); - for (auto &expr : right_keys) - hash = HashUtil::CombineHashes(hash, expr->Hash()); - for (auto &pred : join_predicates) - hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); - return hash; -} - -bool PhysicalInnerHashJoin::operator==(const BaseOperatorNode &r) { - if (r.GetType() != OpType::InnerHashJoin) return false; - const PhysicalInnerHashJoin &node = - *static_cast(&r); - if (join_predicates.size() != node.join_predicates.size() || - left_keys.size() != node.left_keys.size() || - right_keys.size() != node.right_keys.size()) - return false; - for (size_t i = 0; i < left_keys.size(); i++) { - if (!left_keys[i]->ExactlyEquals(*node.left_keys[i].get())) return false; - } - for (size_t i = 0; i < right_keys.size(); i++) { - if (!right_keys[i]->ExactlyEquals(*node.right_keys[i].get())) return false; - } - for (size_t i = 0; i < join_predicates.size(); i++) { - if (!join_predicates[i].expr->ExactlyEquals( - *node.join_predicates[i].expr.get())) - return false; - } - return true; -} - -//===--------------------------------------------------------------------===// -// LeftHashJoin -//===--------------------------------------------------------------------===// -Operator PhysicalLeftHashJoin::make( - std::shared_ptr join_predicate) { - PhysicalLeftHashJoin *join = new PhysicalLeftHashJoin(); - join->join_predicate = join_predicate; - return Operator(join); -} - -//===--------------------------------------------------------------------===// -// RightHashJoin -//===--------------------------------------------------------------------===// -Operator PhysicalRightHashJoin::make( - std::shared_ptr join_predicate) { - PhysicalRightHashJoin *join = new PhysicalRightHashJoin(); - join->join_predicate = join_predicate; - return Operator(join); -} - -//===--------------------------------------------------------------------===// -// OuterHashJoin -//===--------------------------------------------------------------------===// -Operator PhysicalOuterHashJoin::make( - std::shared_ptr join_predicate) { - PhysicalOuterHashJoin *join = new PhysicalOuterHashJoin(); - join->join_predicate = join_predicate; - return Operator(join); -} - //===--------------------------------------------------------------------===// // PhysicalInsert //===--------------------------------------------------------------------===// @@ -966,25 +811,6 @@ std::string OperatorNode::name_ = "PhysicalNLJoin"; template <> std::string OperatorNode::name_ = "PhysicalHashJoin"; template <> -std::string OperatorNode::name_ = "PhysicalInnerNLJoin"; -template <> -std::string OperatorNode::name_ = "PhysicalLeftNLJoin"; -template <> -std::string OperatorNode::name_ = "PhysicalRightNLJoin"; -template <> -std::string OperatorNode::name_ = "PhysicalOuterNLJoin"; -template <> -std::string OperatorNode::name_ = - "PhysicalInnerHashJoin"; -template <> -std::string OperatorNode::name_ = "PhysicalLeftHashJoin"; -template <> -std::string OperatorNode::name_ = - "PhysicalRightHashJoin"; -template <> -std::string OperatorNode::name_ = - "PhysicalOuterHashJoin"; -template <> std::string OperatorNode::name_ = "PhysicalInsert"; template <> std::string OperatorNode::name_ = "PhysicalInsertSelect"; @@ -1053,25 +879,9 @@ OpType OperatorNode::type_ = OpType::Distinct; template <> OpType OperatorNode::type_ = OpType::PhysicalLimit; template <> -OpType OperatorNode::type_ = OpType::HashJoin; -template <> OpType OperatorNode::type_ = OpType::NLJoin; template <> -OpType OperatorNode::type_ = OpType::InnerNLJoin; -template <> -OpType OperatorNode::type_ = OpType::LeftNLJoin; -template <> -OpType OperatorNode::type_ = OpType::RightNLJoin; -template <> -OpType OperatorNode::type_ = OpType::OuterNLJoin; -template <> -OpType OperatorNode::type_ = OpType::InnerHashJoin; -template <> -OpType OperatorNode::type_ = OpType::LeftHashJoin; -template <> -OpType OperatorNode::type_ = OpType::RightHashJoin; -template <> -OpType OperatorNode::type_ = OpType::OuterHashJoin; +OpType OperatorNode::type_ = OpType::HashJoin; template <> OpType OperatorNode::type_ = OpType::Insert; template <> diff --git a/src/optimizer/plan_generator.cpp b/src/optimizer/plan_generator.cpp index c0fe4efe1ea..2e69da41cbc 100644 --- a/src/optimizer/plan_generator.cpp +++ b/src/optimizer/plan_generator.cpp @@ -237,7 +237,7 @@ void PlanGenerator::Visit(const PhysicalHashJoin *op) { } // Evaluate Expr for hash plan vector> hash_keys; - for (auto &expr : op->right_keys) { + for (const auto &expr : op->right_keys) { auto hash_key = expr->Copy(); expression::ExpressionUtil::EvaluateExpression(r_child_map, hash_key); hash_keys.emplace_back(hash_key); diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index 5476d9bb672..a0b3e479e67 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -689,7 +689,7 @@ void JoinToNLJoin::Transform( std::shared_ptr input, std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - // first build an expression representing hash join + // first build an expression representing nested loop join const LogicalJoin *join = input->Op().As(); auto children = input->Children(); diff --git a/test/optimizer/optimizer_rule_test.cpp b/test/optimizer/optimizer_rule_test.cpp index d8ec828035a..7aa26ecf813 100644 --- a/test/optimizer/optimizer_rule_test.cpp +++ b/test/optimizer/optimizer_rule_test.cpp @@ -50,6 +50,9 @@ using namespace optimizer; class OptimizerRuleTests : public PelotonTest {}; TEST_F(OptimizerRuleTests, SimpleCommutativeRuleTest) { + // Start Join Structure: left JOIN right + // End Join Structure: right JOIN left + // Build op plan node to match rule auto left_get = std::make_shared(LogicalGet::make()); auto right_get = std::make_shared(LogicalGet::make()); @@ -74,6 +77,9 @@ TEST_F(OptimizerRuleTests, SimpleCommutativeRuleTest) { } TEST_F(OptimizerRuleTests, LeftJoinCommutativeRuleTest) { + // Start Join Structure: left LEFT JOIN right + // End Join Structure: right RIGHT JOIN left + // Build op plan node to match rule auto left_get = std::make_shared(LogicalGet::make()); auto right_get = std::make_shared(LogicalGet::make()); @@ -99,6 +105,9 @@ TEST_F(OptimizerRuleTests, LeftJoinCommutativeRuleTest) { } TEST_F(OptimizerRuleTests, RightJoinCommutativeRuleTest) { + // Start Join Structure: left RIGHT JOIN right + // End Join Structure: right LEFT JOIN left + // Build op plan node to match rule auto left_get = std::make_shared(LogicalGet::make()); auto right_get = std::make_shared(LogicalGet::make()); @@ -124,6 +133,9 @@ TEST_F(OptimizerRuleTests, RightJoinCommutativeRuleTest) { } TEST_F(OptimizerRuleTests, OuterJoinCommutativeRuleTest) { + // Start Join Structure: left OUTER JOIN right + // End Join Structure: right OUTER JOIN left + // Build op plan node to match rule auto left_get = std::make_shared(LogicalGet::make()); auto right_get = std::make_shared(LogicalGet::make()); From 686bd779d800986eadabb2df6e1e70a71f1786d4 Mon Sep 17 00:00:00 2001 From: Guoquan Zhao Date: Mon, 30 Apr 2018 20:36:15 -0400 Subject: [PATCH 17/26] Use an util function to do the assertion in juint. --- script/testing/junit/OptimizerTest.java | 113 ++++++++++++------------ script/testing/junit/Utils.java | 62 +++++++++++++ 2 files changed, 117 insertions(+), 58 deletions(-) create mode 100644 script/testing/junit/Utils.java diff --git a/script/testing/junit/OptimizerTest.java b/script/testing/junit/OptimizerTest.java index 382af03737e..4991bd0ce2e 100644 --- a/script/testing/junit/OptimizerTest.java +++ b/script/testing/junit/OptimizerTest.java @@ -1,17 +1,22 @@ import org.junit.After; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import java.io.BufferedReader; import java.io.FileNotFoundException; import java.io.FileReader; import java.io.IOException; -import java.sql.*; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; import static java.sql.Statement.EXECUTE_FAILED; import static java.sql.Statement.SUCCESS_NO_INFO; import static org.junit.Assert.*; + /** * Created by Guoquan Zhao on 4/7/18. */ @@ -36,12 +41,11 @@ private void initTables1() throws FileNotFoundException, SQLException { assertTrue("batch failed.", (results[i] >= 0 || results[i] == SUCCESS_NO_INFO) && results[i] != EXECUTE_FAILED); } ResultSet resultSet = stmt.executeQuery("SELECT COUNT(*) FROM t1;"); - resultSet.next(); - assertEquals(3, resultSet.getInt(1)); + ExpectedResult expectedResult = new ExpectedResult("3"); + Utils.assertResultsSetEqual(resultSet, expectedResult); resultSet.close(); resultSet = stmt.executeQuery("SELECT COUNT(*) FROM t2;"); - resultSet.next(); - assertEquals(3, resultSet.getInt(1)); + Utils.assertResultsSetEqual(resultSet, expectedResult); resultSet.close(); } catch (IOException e) { e.printStackTrace(); @@ -59,8 +63,10 @@ public void Setup() { initTables1(); } catch (SQLException ex) { DumpSQLException(ex); + assertTrue(false); } catch (FileNotFoundException e) { e.printStackTrace(); + assertTrue(false); } } @@ -78,11 +84,9 @@ public void testInnerJoin1() throws SQLException { try ( Statement stmt = conn.createStatement(); ResultSet resultSet = stmt.executeQuery("SELECT t1.a FROM t1 INNER JOIN t2 ON (t1.b = t2.b) ORDER BY t1.a;");) { - assertTrue(resultSet.next()); - assertEquals(1, resultSet.getInt(1)); - assertTrue(resultSet.next()); - assertEquals(2, resultSet.getInt(1)); - assertFalse(resultSet.next()); + ExpectedResult expectedResult = new ExpectedResult("1\n" + + "2"); + Utils.assertResultsSetEqual(resultSet, expectedResult); } catch (Exception e) { e.printStackTrace(); fail(); @@ -94,11 +98,8 @@ public void testInnerJoin2() throws SQLException { try ( Statement stmt = conn.createStatement(); ResultSet resultSet = stmt.executeQuery("SELECT x.a FROM t1 AS x INNER JOIN t2 ON(x.b = t2.b AND x.c = t2.c) ORDER BY x.a;");) { - assertTrue(resultSet.next()); - assertEquals(1, resultSet.getInt(1)); - assertTrue(resultSet.next()); - assertEquals(2, resultSet.getInt(1)); - assertFalse(resultSet.next()); + ExpectedResult expectedResult = new ExpectedResult("1\n" + "2"); + Utils.assertResultsSetEqual(resultSet, expectedResult); } catch (Exception e) { e.printStackTrace(); fail(); @@ -110,75 +111,69 @@ public void testLeftOuterJoin1() throws SQLException { try ( Statement stmt = conn.createStatement(); ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT JOIN t2 ON t1.a=t2.d;");) { - assertTrue(resultSet.next()); - assertEquals(3, resultSet.getInt(4)); - assertEquals(4, resultSet.getInt(5)); - assertEquals(5, resultSet.getInt(6)); - assertEquals(1, resultSet.getInt(1)); - assertEquals(2, resultSet.getInt(2)); - assertEquals(3, resultSet.getInt(3)); - assertTrue(resultSet.next()); - assertEquals(null, resultSet.getObject(1)); - assertEquals(null, resultSet.getObject(2)); - assertEquals(null, resultSet.getObject(3)); - assertTrue(resultSet.next()); - assertEquals(null, resultSet.getObject(1)); - assertEquals(null, resultSet.getObject(2)); - assertEquals(null, resultSet.getObject(3)); - assertFalse(resultSet.next()); + String r = + "1|2|3|3|4|5\n" + + "null|null|null|1|2|3\n" + + "null|null|null|2|3|4\n"; + ExpectedResult expectedResult = new ExpectedResult(r); + Utils.assertResultsSetEqual(resultSet, expectedResult); } catch (Exception e) { e.printStackTrace(); fail(); } } + /** + * There is an bug in the executor, skip this test for now. + * + * @throws SQLException + */ @Test + @Ignore public void testLeftOuterJoin2() throws SQLException { try ( Statement stmt = conn.createStatement(); ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT JOIN t2 ON t1.a=t2.d WHERE t1.a>1")) { - assertTrue(resultSet.next()); - assertEquals(3, resultSet.getInt(4)); - assertEquals(4, resultSet.getInt(5)); - assertEquals(5, resultSet.getInt(6)); - assertEquals(1, resultSet.getInt(1)); - assertEquals(2, resultSet.getInt(2)); - assertEquals(3, resultSet.getInt(3)); - assertTrue(resultSet.next()); - assertEquals(null, resultSet.getObject(1)); - assertEquals(null, resultSet.getObject(2)); - assertEquals(null, resultSet.getObject(3)); - assertFalse(resultSet.next()); + String r = + "1|2|3|3|4|5\n" + + "null|null|null|2|3|4\n"; + ExpectedResult expectedResult = new ExpectedResult(r); + Utils.assertResultsSetEqual(resultSet, expectedResult); } catch (Exception e) { e.printStackTrace(); fail(); } } + /** + * There is an bug in the executor, skip this test for now. + * + * @throws SQLException + */ @Test + @Ignore public void testLeftOuterJoin3() throws SQLException { try ( Statement stmt = conn.createStatement(); ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT OUTER JOIN t2 ON t1.a=t2.d WHERE t1.a>1")) { - assertTrue(resultSet.next()); - assertEquals(3, resultSet.getInt(4)); - assertEquals(4, resultSet.getInt(5)); - assertEquals(5, resultSet.getInt(6)); - assertEquals(1, resultSet.getInt(1)); - assertEquals(2, resultSet.getInt(2)); - assertEquals(3, resultSet.getInt(3)); - assertTrue(resultSet.next()); - assertEquals(null, resultSet.getObject(1)); - assertEquals(null, resultSet.getObject(2)); - assertEquals(null, resultSet.getObject(3)); - assertFalse(resultSet.next()); + String r = + "1|2|3|3|4|5\n" + + "null|null|null|2|3|4\n"; + ExpectedResult expectedResult = new ExpectedResult(r); + Utils.assertResultsSetEqual(resultSet, expectedResult); } catch (Exception e) { e.printStackTrace(); fail(); } } + /** + * There is an bug in the executor, skip this test for now. + * + * @throws SQLException + */ @Test + @Ignore public void testLeftOuterJoinWhere() { try ( Statement stmt = conn.createStatement(); @@ -187,9 +182,11 @@ public void testLeftOuterJoinWhere() { // t1 t2 // 1 2 3 {} {} {} // 2 3 4 {} {} {} - assertTrue(resultSet.next()); - assertTrue(resultSet.next()); - assertFalse(resultSet.next()); + String r = + "1|2|3|null|null|null\n" + + "2|3|4|null|null|null"; + ExpectedResult expectedResult = new ExpectedResult(r); + Utils.assertResultsSetEqual(resultSet, expectedResult); } catch (Exception e) { e.printStackTrace(); fail(); diff --git a/script/testing/junit/Utils.java b/script/testing/junit/Utils.java new file mode 100644 index 00000000000..8e054a69ed3 --- /dev/null +++ b/script/testing/junit/Utils.java @@ -0,0 +1,62 @@ +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.junit.Assert.*; + + +/** + * Created by Guoquan Zhao on 4/30/18. + */ +public class Utils { + public static void assertResultsSetEqual(ResultSet results, ExpectedResult expectedResult) throws SQLException { + int rows = expectedResult.getRows(); + int columns = expectedResult.getColumns(); + + for (int i = 0; i < rows; i++) { + assertTrue(results.next()); + for (int j = 0; j < columns; j++) { + String returnedString = results.getString(j + 1); + if (returnedString == null) { + assertEquals(expectedResult.getItemAtIndex(i, j), "null"); + } else { + assertTrue(results.getString(j + 1).equals(expectedResult.getItemAtIndex(i, j))); + } + } + } + assertFalse(results.next()); + } +} + + +class ExpectedResult { + public ExpectedResult(String expectedResult) { + String[] rows = expectedResult.split("\n"); + List results = Stream.of(rows).map(s -> { + String[] columns = s.split("\\|"); + String[] collect = Stream.of(columns).map(c -> c.trim()).collect(Collectors.toList()).toArray(new String[0]); + return collect; + }).collect(Collectors.toList()); + int num_columns = results.get(0).length; + assertTrue(results.stream().allMatch(strings -> strings.length == num_columns)); + this.rows = results; + } + + public int getRows() { + return this.rows.size(); + } + + public int getColumns() { + return this.rows.get(0).length; + } + + public String getItemAtIndex(int row, int column) { + return this.rows.get(row)[column]; + } + + private List rows; + + +} \ No newline at end of file From 5412a11335dd485bb9219b32a0b5f560dda24b01 Mon Sep 17 00:00:00 2001 From: Guoquan Zhao Date: Tue, 1 May 2018 03:29:00 -0400 Subject: [PATCH 18/26] Resolve conflicts with upstream/master --- src/optimizer/plan_generator.cpp | 94 -------------------------------- 1 file changed, 94 deletions(-) diff --git a/src/optimizer/plan_generator.cpp b/src/optimizer/plan_generator.cpp index 2e69da41cbc..75c678185d5 100644 --- a/src/optimizer/plan_generator.cpp +++ b/src/optimizer/plan_generator.cpp @@ -256,100 +256,6 @@ void PlanGenerator::Visit(const PhysicalHashJoin *op) { output_plan_ = move(join_plan); } -void PlanGenerator::Visit(const PhysicalInnerNLJoin *op) { - std::unique_ptr proj_info; - std::shared_ptr proj_schema; - GenerateProjectionForJoin(proj_info, proj_schema); - - auto join_predicate = - expression::ExpressionUtil::JoinAnnotatedExprs(op->join_predicates); - expression::ExpressionUtil::EvaluateExpression(children_expr_map_, - join_predicate.get()); - expression::ExpressionUtil::ConvertToTvExpr(join_predicate.get(), - children_expr_map_); - - vector left_keys; - vector right_keys; - for (auto &expr : op->left_keys) { - PELOTON_ASSERT(children_expr_map_[0].find(expr.get()) != - children_expr_map_[0].end()); - left_keys.push_back(children_expr_map_[0][expr.get()]); - } - for (auto &expr : op->right_keys) { - PELOTON_ASSERT(children_expr_map_[1].find(expr.get()) != - children_expr_map_[1].end()); - right_keys.emplace_back(children_expr_map_[1][expr.get()]); - } - - auto join_plan = - unique_ptr(new planner::NestedLoopJoinPlan( - JoinType::INNER, move(join_predicate), move(proj_info), proj_schema, - left_keys, right_keys)); - - join_plan->AddChild(move(children_plans_[0])); - join_plan->AddChild(move(children_plans_[1])); - output_plan_ = move(join_plan); -} - -void PlanGenerator::Visit(const PhysicalLeftNLJoin *) {} - -void PlanGenerator::Visit(const PhysicalRightNLJoin *) {} - -void PlanGenerator::Visit(const PhysicalOuterNLJoin *) {} - -void PlanGenerator::Visit(const PhysicalInnerHashJoin *op) { - std::unique_ptr proj_info; - std::shared_ptr proj_schema; - GenerateProjectionForJoin(proj_info, proj_schema); - - auto join_predicate = - expression::ExpressionUtil::JoinAnnotatedExprs(op->join_predicates); - expression::ExpressionUtil::EvaluateExpression(children_expr_map_, - join_predicate.get()); - expression::ExpressionUtil::ConvertToTvExpr(join_predicate.get(), - children_expr_map_); - - vector> left_keys; - vector> right_keys; - vector l_child_map{move(children_expr_map_[0])}; - vector r_child_map{move(children_expr_map_[1])}; - for (auto &expr : op->left_keys) { - auto left_key = expr->Copy(); - expression::ExpressionUtil::EvaluateExpression(l_child_map, left_key); - left_keys.emplace_back(left_key); - } - for (auto &expr : op->right_keys) { - auto right_key = expr->Copy(); - expression::ExpressionUtil::EvaluateExpression(r_child_map, right_key); - right_keys.emplace_back(right_key); - } - // Evaluate Expr for hash plan - vector> hash_keys; - for (auto &expr : op->right_keys) { - auto hash_key = expr->Copy(); - expression::ExpressionUtil::EvaluateExpression(r_child_map, hash_key); - hash_keys.emplace_back(hash_key); - } - - unique_ptr hash_plan(new planner::HashPlan(hash_keys)); - hash_plan->AddChild(move(children_plans_[1])); - - auto join_plan = unique_ptr(new planner::HashJoinPlan( - JoinType::INNER, move(join_predicate), move(proj_info), proj_schema, - left_keys, right_keys, settings::SettingsManager::GetBool( - settings::SettingId::hash_join_bloom_filter))); - - join_plan->AddChild(move(children_plans_[0])); - join_plan->AddChild(move(hash_plan)); - output_plan_ = move(join_plan); -} - -void PlanGenerator::Visit(const PhysicalLeftHashJoin *) {} - -void PlanGenerator::Visit(const PhysicalRightHashJoin *) {} - -void PlanGenerator::Visit(const PhysicalOuterHashJoin *) {} - void PlanGenerator::Visit(const PhysicalInsert *op) { unique_ptr insert_plan(new planner::InsertPlan( storage::StorageManager::GetInstance()->GetTableWithOid( From 29174e27c062a9c8c3c189fc7d799593776ad20f Mon Sep 17 00:00:00 2001 From: Guoquan Zhao Date: Tue, 1 May 2018 16:20:19 -0400 Subject: [PATCH 19/26] try to fix the null pointer exception --- script/testing/junit/Utils.java | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/script/testing/junit/Utils.java b/script/testing/junit/Utils.java index 8e054a69ed3..f590af1e334 100644 --- a/script/testing/junit/Utils.java +++ b/script/testing/junit/Utils.java @@ -22,7 +22,7 @@ public static void assertResultsSetEqual(ResultSet results, ExpectedResult expec if (returnedString == null) { assertEquals(expectedResult.getItemAtIndex(i, j), "null"); } else { - assertTrue(results.getString(j + 1).equals(expectedResult.getItemAtIndex(i, j))); + assertTrue(returnedString.equals(expectedResult.getItemAtIndex(i, j))); } } } @@ -53,7 +53,17 @@ public int getColumns() { } public String getItemAtIndex(int row, int column) { - return this.rows.get(row)[column]; + String ret = this.rows.get(row)[column]; + if (ret == null) { + for (int i = 0; i < this.rows.size(); i++) { + for (int j = 0; j < this.rows.get(i).length; j++) { + System.out.print(this.rows.get(i)[j] + "|"); + } + System.out.println(); + } + throw new RuntimeException("Should not return NULL"); + } + return ret; } private List rows; From 4b1c61b2baa5fb235daac28b5366d636991e50ae Mon Sep 17 00:00:00 2001 From: Guoquan Zhao Date: Tue, 1 May 2018 17:40:08 -0400 Subject: [PATCH 20/26] strange behavior in junit test. add more logs. --- script/testing/junit/OptimizerTest.java | 6 +++--- script/testing/junit/Utils.java | 14 +++++++++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/script/testing/junit/OptimizerTest.java b/script/testing/junit/OptimizerTest.java index 4991bd0ce2e..396c08f474c 100644 --- a/script/testing/junit/OptimizerTest.java +++ b/script/testing/junit/OptimizerTest.java @@ -114,7 +114,7 @@ public void testLeftOuterJoin1() throws SQLException { String r = "1|2|3|3|4|5\n" + "null|null|null|1|2|3\n" + - "null|null|null|2|3|4\n"; + "null|null|null|2|3|4"; ExpectedResult expectedResult = new ExpectedResult(r); Utils.assertResultsSetEqual(resultSet, expectedResult); } catch (Exception e) { @@ -136,7 +136,7 @@ public void testLeftOuterJoin2() throws SQLException { ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT JOIN t2 ON t1.a=t2.d WHERE t1.a>1")) { String r = "1|2|3|3|4|5\n" + - "null|null|null|2|3|4\n"; + "null|null|null|2|3|4"; ExpectedResult expectedResult = new ExpectedResult(r); Utils.assertResultsSetEqual(resultSet, expectedResult); } catch (Exception e) { @@ -158,7 +158,7 @@ public void testLeftOuterJoin3() throws SQLException { ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT OUTER JOIN t2 ON t1.a=t2.d WHERE t1.a>1")) { String r = "1|2|3|3|4|5\n" + - "null|null|null|2|3|4\n"; + "null|null|null|2|3|4"; ExpectedResult expectedResult = new ExpectedResult(r); Utils.assertResultsSetEqual(resultSet, expectedResult); } catch (Exception e) { diff --git a/script/testing/junit/Utils.java b/script/testing/junit/Utils.java index f590af1e334..6b7dda3d69b 100644 --- a/script/testing/junit/Utils.java +++ b/script/testing/junit/Utils.java @@ -1,4 +1,5 @@ import java.sql.ResultSet; +import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.util.List; import java.util.stream.Collectors; @@ -14,13 +15,24 @@ public class Utils { public static void assertResultsSetEqual(ResultSet results, ExpectedResult expectedResult) throws SQLException { int rows = expectedResult.getRows(); int columns = expectedResult.getColumns(); + System.out.println("expectedResult.rows = " + rows); + System.out.println("expectedResult.columns = " + columns); + ResultSetMetaData rsmd = results.getMetaData(); + int columnsNumber = rsmd.getColumnCount(); + assertEquals(columns, columnsNumber); for (int i = 0; i < rows; i++) { assertTrue(results.next()); for (int j = 0; j < columns; j++) { + String returnedString = results.getString(j + 1); + System.out.println("i = " + i + "\t j = " + j); + System.out.println("returnedString= " + returnedString); + String expected = expectedResult.getItemAtIndex(i, j); + System.out.println("expected = " + expected); + if (returnedString == null) { - assertEquals(expectedResult.getItemAtIndex(i, j), "null"); + assertEquals(expected, "null"); } else { assertTrue(returnedString.equals(expectedResult.getItemAtIndex(i, j))); } From fb36ac00d5f583005c01c17f20c554a72bfb43e2 Mon Sep 17 00:00:00 2001 From: Guoquan Zhao Date: Tue, 1 May 2018 20:30:20 -0400 Subject: [PATCH 21/26] Bug was that executor return different results on different machines. --- script/testing/junit/OptimizerTest.java | 29 ++++++++++++++++++++----- script/testing/junit/Utils.java | 7 +----- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/script/testing/junit/OptimizerTest.java b/script/testing/junit/OptimizerTest.java index 396c08f474c..ffafc80c8bf 100644 --- a/script/testing/junit/OptimizerTest.java +++ b/script/testing/junit/OptimizerTest.java @@ -111,12 +111,29 @@ public void testLeftOuterJoin1() throws SQLException { try ( Statement stmt = conn.createStatement(); ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT JOIN t2 ON t1.a=t2.d;");) { - String r = - "1|2|3|3|4|5\n" + - "null|null|null|1|2|3\n" + - "null|null|null|2|3|4"; - ExpectedResult expectedResult = new ExpectedResult(r); - Utils.assertResultsSetEqual(resultSet, expectedResult); +// String r = +// "1|2|3|3|4|5\n" + +// "null|null|null|1|2|3\n" + +// "null|null|null|2|3|4"; +// ExpectedResult expectedResult = new ExpectedResult(r); +// Utils.assertResultsSetEqual(resultSet, expectedResult); + assertTrue(resultSet.next()); + assertEquals(3, resultSet.getInt(4)); + assertEquals(4, resultSet.getInt(5)); + assertEquals(5, resultSet.getInt(6)); + assertEquals(1, resultSet.getInt(1)); + assertEquals(2, resultSet.getInt(2)); + assertEquals(3, resultSet.getInt(3)); + assertTrue(resultSet.next()); + assertEquals(null, resultSet.getObject(1)); + assertEquals(null, resultSet.getObject(2)); + assertEquals(null, resultSet.getObject(3)); + assertTrue(resultSet.next()); + assertEquals(null, resultSet.getObject(1)); + assertEquals(null, resultSet.getObject(2)); + assertEquals(null, resultSet.getObject(3)); + assertFalse(resultSet.next()); + } catch (Exception e) { e.printStackTrace(); fail(); diff --git a/script/testing/junit/Utils.java b/script/testing/junit/Utils.java index 6b7dda3d69b..3ce22358d2a 100644 --- a/script/testing/junit/Utils.java +++ b/script/testing/junit/Utils.java @@ -15,8 +15,6 @@ public class Utils { public static void assertResultsSetEqual(ResultSet results, ExpectedResult expectedResult) throws SQLException { int rows = expectedResult.getRows(); int columns = expectedResult.getColumns(); - System.out.println("expectedResult.rows = " + rows); - System.out.println("expectedResult.columns = " + columns); ResultSetMetaData rsmd = results.getMetaData(); int columnsNumber = rsmd.getColumnCount(); assertEquals(columns, columnsNumber); @@ -26,15 +24,12 @@ public static void assertResultsSetEqual(ResultSet results, ExpectedResult expec for (int j = 0; j < columns; j++) { String returnedString = results.getString(j + 1); - System.out.println("i = " + i + "\t j = " + j); - System.out.println("returnedString= " + returnedString); String expected = expectedResult.getItemAtIndex(i, j); - System.out.println("expected = " + expected); if (returnedString == null) { assertEquals(expected, "null"); } else { - assertTrue(returnedString.equals(expectedResult.getItemAtIndex(i, j))); + assertEquals(returnedString,expected); } } } From 8c4dbfafc32c212439113b511d73660b3d3f85c7 Mon Sep 17 00:00:00 2001 From: Guoquan Zhao Date: Wed, 2 May 2018 01:01:01 -0400 Subject: [PATCH 22/26] use jdk1.8 instead of 1.7 --- Jenkinsfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 0d2352b6a35..b53caf3034e 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -58,7 +58,7 @@ pipeline { sh 'cd build && make check -j4' sh 'cd build && make install' sh 'cd build && bash ../script/testing/psql/psql_test.sh' - sh 'sudo apt-get -qq update && sudo apt-get -qq -y --no-install-recommends install wget default-jdk default-jre' // prerequisites for jdbc_validator + sh 'sudo apt-get -qq update && sudo apt-get -qq -y --no-install-recommends install wget openjdk-8-jdk openjdk-8-jre' // prerequisites for jdbc_validator sh 'cd build && python ../script/validators/jdbc_validator.py' sh 'cd build && python ../script/testing/junit/run_junit.py' } @@ -79,7 +79,7 @@ pipeline { sh 'cd build && make check -j4' sh 'cd build && make install' sh 'cd build && bash ../script/testing/psql/psql_test.sh' - sh 'sudo apt-get -qq update && sudo apt-get -qq -y --no-install-recommends install wget default-jdk default-jre' // prerequisites for jdbc_validator + sh 'sudo apt-get -qq update && sudo apt-get -qq -y --no-install-recommends install wget openjdk-8-jdk openjdk-8-jre' // prerequisites for jdbc_validator sh 'cd build && python ../script/validators/jdbc_validator.py' sh 'cd build && python ../script/testing/junit/run_junit.py' } From 64ab79ecac7d07bd11e6636dcc60905b6ad2d34e Mon Sep 17 00:00:00 2001 From: Guoquan Zhao Date: Wed, 2 May 2018 02:48:43 -0400 Subject: [PATCH 23/26] add openjdk ppa --- Jenkinsfile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Jenkinsfile b/Jenkinsfile index b53caf3034e..0fd59368702 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -58,6 +58,7 @@ pipeline { sh 'cd build && make check -j4' sh 'cd build && make install' sh 'cd build && bash ../script/testing/psql/psql_test.sh' + sh 'sudo add-apt-repository ppa:openjdk-r/ppa' sh 'sudo apt-get -qq update && sudo apt-get -qq -y --no-install-recommends install wget openjdk-8-jdk openjdk-8-jre' // prerequisites for jdbc_validator sh 'cd build && python ../script/validators/jdbc_validator.py' sh 'cd build && python ../script/testing/junit/run_junit.py' @@ -79,6 +80,7 @@ pipeline { sh 'cd build && make check -j4' sh 'cd build && make install' sh 'cd build && bash ../script/testing/psql/psql_test.sh' + sh 'sudo add-apt-repository ppa:openjdk-r/ppa' sh 'sudo apt-get -qq update && sudo apt-get -qq -y --no-install-recommends install wget openjdk-8-jdk openjdk-8-jre' // prerequisites for jdbc_validator sh 'cd build && python ../script/validators/jdbc_validator.py' sh 'cd build && python ../script/testing/junit/run_junit.py' From 385a1dfaecca306bf10aefd685f6806fcd2e970b Mon Sep 17 00:00:00 2001 From: Irene Qiuwen Kai Date: Fri, 4 May 2018 20:29:06 -0400 Subject: [PATCH 24/26] Added strong predicates function in util. --- src/include/optimizer/rule_impls.h | 3 -- src/include/optimizer/util.h | 22 ++++++++++++++ src/optimizer/rule_impls.cpp | 41 +++++++++++++++++--------- src/optimizer/util.cpp | 47 ++++++++++++++++++++++++++++++ 4 files changed, 97 insertions(+), 16 deletions(-) diff --git a/src/include/optimizer/rule_impls.h b/src/include/optimizer/rule_impls.h index 2220247418e..b0faa1aabdb 100644 --- a/src/include/optimizer/rule_impls.h +++ b/src/include/optimizer/rule_impls.h @@ -51,9 +51,6 @@ class JoinAssociativity : public Rule { void Transform(std::shared_ptr input, std::vector> &transformed, OptimizeContext *context) const override; - private: - bool StrongPredicate(std::shared_ptr plan, - OptimizeContext *context) const; }; //===--------------------------------------------------------------------===// diff --git a/src/include/optimizer/util.h b/src/include/optimizer/util.h index 8b9eb4baeef..0a984b01240 100644 --- a/src/include/optimizer/util.h +++ b/src/include/optimizer/util.h @@ -19,6 +19,8 @@ #include "expression/abstract_expression.h" #include "parser/copy_statement.h" #include "planner/abstract_plan.h" +#include "optimizer/optimize_context.h" +#include "expression/tuple_value_expression.h" namespace peloton { @@ -167,6 +169,26 @@ void ExtractEquiJoinKeys( const std::unordered_set &left_alias, const std::unordered_set &right_alias); +/** + * @brief Given an operator expression and context information, check if it + * is strong predicate w.r.t to one table + * A predicate p is strong w.r.t S if the fact that all attributes from S are + * NULL implies that p evaluates to false + * It is used in AssociativityRule transforms when certain joins are applied + */ +bool StrongPredicates( + std::vector predicates, + const std::unordered_set &middle_group_aliases_set); + +/** + * @brief Replace the tuple_value_expression in given expression which + * contains table in middle_group_aliases_set with constant_value_expression + * with FALSE value + */ +void ReplaceWithNull( + std::shared_ptr expr, + const std::unordered_set &middle_group_aliases_set); + } // namespace util } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index a0b3e479e67..e849cda1097 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -111,6 +111,13 @@ bool JoinAssociativity::Check(std::shared_ptr expr, auto parent_join = expr->Op().As(); std::vector> children = expr->Children(); auto child_join = children[0]->Op().As(); + auto middle = children[0]->Children()[1]; + // Get Alias sets + auto &memo = context->metadata->memo; + auto middle_group_id = middle->Op().As()->origin_group; + const auto &middle_group_aliases_set = + memo.GetGroupByID(middle_group_id)->GetTableAliases(); + if (parent_join->type == JoinType::INNER) { if (child_join->type == JoinType::INNER || child_join->type == JoinType::RIGHT) { return true; @@ -119,15 +126,32 @@ bool JoinAssociativity::Check(std::shared_ptr expr, if (child_join->type == JoinType::INNER) { return true; } else if (child_join->type == JoinType::LEFT || child_join->type == JoinType::OUTER) { - return StrongPredicate(expr, context); + return util::StrongPredicates(parent_join->join_predicates, + middle_group_aliases_set); } } else if (parent_join->type == JoinType::RIGHT) { if (child_join->type == JoinType::RIGHT) { - return StrongPredicate(expr, context); + return util::StrongPredicates(child_join->join_predicates, + middle_group_aliases_set); } } else if (parent_join->type == JoinType::OUTER) { - if (child_join->type == JoinType::RIGHT || child_join->type == JoinType::OUTER) { - return StrongPredicate(expr, context); + if (child_join->type == JoinType::RIGHT) { + return util::StrongPredicates(child_join->join_predicates, + middle_group_aliases_set); + } else if (child_join->type == JoinType::OUTER) { + auto parent_join_predicates = + std::vector(parent_join->join_predicates); + auto child_join_predicates = + std::vector(child_join->join_predicates); + + std::vector check_predicates; + check_predicates.insert(check_predicates.end(), + parent_join_predicates.begin(), + parent_join_predicates.end()); + check_predicates.insert(check_predicates.end(), + child_join_predicates.begin(), + child_join_predicates.end()); + return util::StrongPredicates(check_predicates, middle_group_aliases_set); } } return false; @@ -214,15 +238,6 @@ void JoinAssociativity::Transform( transformed.push_back(new_parent_join); } -// TODO: some associativity rules can only be applied when the predicate is strong -// To check if the predicate is strong or not is non-trivial -bool JoinAssociativity::StrongPredicate(std::shared_ptr expr, - OptimizeContext *context) const { - (void)context; - (void)expr; - return false; -} - //===--------------------------------------------------------------------===// // Implementation rules //===--------------------------------------------------------------------===// diff --git a/src/optimizer/util.cpp b/src/optimizer/util.cpp index 0d01e35e8ac..6782f294eee 100644 --- a/src/optimizer/util.cpp +++ b/src/optimizer/util.cpp @@ -18,6 +18,7 @@ #include "planner/copy_plan.h" #include "planner/seq_scan_plan.h" #include "storage/data_table.h" +#include "type/value_factory.h" namespace peloton { namespace optimizer { @@ -250,6 +251,52 @@ void ExtractEquiJoinKeys( } } +bool StrongPredicates( + std::vector predicates, + const std::unordered_set &middle_group_aliases_set) { + for (auto predicate : predicates) { + // create a copy of original predicate. + auto copy_expr = std::shared_ptr(predicate.expr->Copy()); + LOG_DEBUG("AnnotatedExp: %s", copy_expr->GetInfo().c_str()); + // replace tuple_value_expression from predicate which contains table in + // middle_group_aliases_set with FALSE constant value + ReplaceWithNull(copy_expr, middle_group_aliases_set); + + // TODO: some expressions cannot be evaluated with no executor context given + bool ret = copy_expr->Evaluate(nullptr, nullptr, nullptr).IsFalse(); + if(ret){ + return true; + } + } + return false; +} + +void ReplaceWithNull( + std::shared_ptr expr, + const std::unordered_set &middle_group_aliases_set) { + if (expr->GetChildrenSize() == 0) { + return; + } + for (size_t i = 0; i < expr->GetChildrenSize(); i++) { + auto child_expr = expr->GetModifiableChild(i); + // Check if its an TupleValueExpression + if (child_expr->GetExpressionType() == ExpressionType::VALUE_TUPLE) { +// auto val_type = child->GetValueType(); + LOG_DEBUG("Tuple value expression found in child"); + auto child_tv_expr = dynamic_cast(child_expr); + if (middle_group_aliases_set.find(child_tv_expr->GetTableName()) != middle_group_aliases_set.end()) { + LOG_DEBUG("ads"); + expr->SetChild(i, expression::ExpressionUtil::ConstantValueFactory( + type::ValueFactory::GetBooleanValue(false))); + } else { + expr->SetChild(i, expression::ExpressionUtil::ConstantValueFactory( + type::ValueFactory::GetBooleanValue(true))); + } + ReplaceWithNull(std::shared_ptr(child_expr), middle_group_aliases_set); + } + } +} + } // namespace util } // namespace optimizer } // namespace peloton From b0fe16b6c2de27d36dd145bb8c9af32ed78dfa10 Mon Sep 17 00:00:00 2001 From: Irene Qiuwen Kai Date: Sat, 5 May 2018 00:58:21 -0400 Subject: [PATCH 25/26] New approach to check StrongPredicates. Tests need to be added --- src/include/optimizer/util.h | 7 +- src/optimizer/util.cpp | 128 ++++++++++++++++++++++++++++------- 2 files changed, 107 insertions(+), 28 deletions(-) diff --git a/src/include/optimizer/util.h b/src/include/optimizer/util.h index 0a984b01240..97f5e4de79e 100644 --- a/src/include/optimizer/util.h +++ b/src/include/optimizer/util.h @@ -183,12 +183,11 @@ bool StrongPredicates( /** * @brief Replace the tuple_value_expression in given expression which * contains table in middle_group_aliases_set with constant_value_expression - * with FALSE value + * with NULL value */ -void ReplaceWithNull( - std::shared_ptr expr, +expression::AbstractExpression* PredicateEvaluate( + expression::AbstractExpression* expr, const std::unordered_set &middle_group_aliases_set); - } // namespace util } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/util.cpp b/src/optimizer/util.cpp index 6782f294eee..77492deb115 100644 --- a/src/optimizer/util.cpp +++ b/src/optimizer/util.cpp @@ -260,40 +260,120 @@ bool StrongPredicates( LOG_DEBUG("AnnotatedExp: %s", copy_expr->GetInfo().c_str()); // replace tuple_value_expression from predicate which contains table in // middle_group_aliases_set with FALSE constant value - ReplaceWithNull(copy_expr, middle_group_aliases_set); - // TODO: some expressions cannot be evaluated with no executor context given - bool ret = copy_expr->Evaluate(nullptr, nullptr, nullptr).IsFalse(); - if(ret){ - return true; + auto eval_expr = PredicateEvaluate(copy_expr.get(), middle_group_aliases_set); + if(eval_expr != nullptr) { + auto eval_val = eval_expr->Evaluate(nullptr, nullptr, nullptr); + if (eval_val.IsFalse()) { + return true; + } } } return false; } -void ReplaceWithNull( - std::shared_ptr expr, +expression::AbstractExpression* PredicateEvaluate( + expression::AbstractExpression* expr, const std::unordered_set &middle_group_aliases_set) { - if (expr->GetChildrenSize() == 0) { - return; + // if at the lowest level + if(expr->GetChildrenSize()==0){ + if(expr->GetExpressionType() == ExpressionType::VALUE_TUPLE){ + auto tv_expr = dynamic_cast(expr); + if (middle_group_aliases_set.find(tv_expr->GetTableName()) != middle_group_aliases_set.end()) { + return expression::ExpressionUtil::ConstantValueFactory( + type::ValueFactory::GetNullValueByType(type::TypeId::BOOLEAN)); + } + } + if(expr->GetExpressionType() == ExpressionType::VALUE_CONSTANT){ + auto cv_expr = dynamic_cast(expr); + return expression::ExpressionUtil::ConstantValueFactory(cv_expr->GetValue()); + } + // The tuple_value_expression uses table which does not belong to middle_group + // or other expression type we cannot evaluate + return nullptr; } - for (size_t i = 0; i < expr->GetChildrenSize(); i++) { - auto child_expr = expr->GetModifiableChild(i); - // Check if its an TupleValueExpression - if (child_expr->GetExpressionType() == ExpressionType::VALUE_TUPLE) { -// auto val_type = child->GetValueType(); - LOG_DEBUG("Tuple value expression found in child"); - auto child_tv_expr = dynamic_cast(child_expr); - if (middle_group_aliases_set.find(child_tv_expr->GetTableName()) != middle_group_aliases_set.end()) { - LOG_DEBUG("ads"); - expr->SetChild(i, expression::ExpressionUtil::ConstantValueFactory( - type::ValueFactory::GetBooleanValue(false))); - } else { - expr->SetChild(i, expression::ExpressionUtil::ConstantValueFactory( - type::ValueFactory::GetBooleanValue(true))); + + // Conjunction check + if(expr->GetExpressionType() == ExpressionType::CONJUNCTION_AND){ + PELOTON_ASSERT(expr->GetChildrenSize() == 2); + auto l_child = expr->GetModifiableChild(0); + auto r_child = expr->GetModifiableChild(1); + auto l_eval_expr = PredicateEvaluate(l_child, middle_group_aliases_set); + auto r_eval_expr = PredicateEvaluate(r_child, middle_group_aliases_set); + + type::Value l_val; + type::Value r_val; + + if (l_eval_expr!= nullptr){ + PELOTON_ASSERT(l_eval_expr->GetExpressionType()==ExpressionType::VALUE_CONSTANT); + l_val = l_eval_expr->Evaluate(nullptr, nullptr, nullptr); + if(l_val.IsFalse()){ + return expression::ExpressionUtil::ConstantValueFactory( + type::ValueFactory::GetBooleanValue(false)); + } + } + if (r_eval_expr!= nullptr){ + PELOTON_ASSERT(r_eval_expr->GetExpressionType()==ExpressionType::VALUE_CONSTANT); + r_val = r_eval_expr->Evaluate(nullptr, nullptr, nullptr); + if(r_val.IsFalse()){ + return expression::ExpressionUtil::ConstantValueFactory( + type::ValueFactory::GetBooleanValue(false)); + } + } + if (l_eval_expr!= nullptr && r_eval_expr!= nullptr && l_val.IsTrue() && r_val.IsTrue()){ + return expression::ExpressionUtil::ConstantValueFactory( + type::ValueFactory::GetBooleanValue(true)); + } + return nullptr; + } + else if(expr->GetExpressionType()==ExpressionType::CONJUNCTION_OR){ + PELOTON_ASSERT(expr->GetChildrenSize() == 2); + auto l_child = expr->GetModifiableChild(0); + auto r_child = expr->GetModifiableChild(1); + auto l_eval_expr = PredicateEvaluate(l_child, middle_group_aliases_set); + auto r_eval_expr = PredicateEvaluate(r_child, middle_group_aliases_set); + + type::Value l_val; + type::Value r_val; + + if (l_eval_expr!= nullptr){ + PELOTON_ASSERT(l_eval_expr->GetExpressionType()==ExpressionType::VALUE_CONSTANT); + l_val = l_eval_expr->Evaluate(nullptr, nullptr, nullptr); + if(l_val.IsTrue()){ + return expression::ExpressionUtil::ConstantValueFactory( + type::ValueFactory::GetBooleanValue(true)); + } + } + if (r_eval_expr!= nullptr){ + PELOTON_ASSERT(r_eval_expr->GetExpressionType()==ExpressionType::VALUE_CONSTANT); + r_val = r_eval_expr->Evaluate(nullptr, nullptr, nullptr); + if(r_val.IsTrue()){ + return expression::ExpressionUtil::ConstantValueFactory( + type::ValueFactory::GetBooleanValue(true)); + } + } + if (l_eval_expr!= nullptr && r_eval_expr!= nullptr && l_val.IsFalse() && r_val.IsFalse()){ + return expression::ExpressionUtil::ConstantValueFactory( + type::ValueFactory::GetBooleanValue(false)); + } + return nullptr; + } + else{ + for (size_t i = 0; i < expr->GetChildrenSize(); i++) { + auto child_expr = expr->GetModifiableChild(i); + auto child_eval_expr = PredicateEvaluate(child_expr, middle_group_aliases_set); + if(child_eval_expr== nullptr){ + // cannot evaluate + return nullptr; + } + else{ + PELOTON_ASSERT(child_eval_expr->GetExpressionType()==ExpressionType::VALUE_CONSTANT); + expr->SetChild(i, child_eval_expr); } - ReplaceWithNull(std::shared_ptr(child_expr), middle_group_aliases_set); } + // all child are constant value expression + auto expr_val = expr->Evaluate(nullptr, nullptr, nullptr); + return expression::ExpressionUtil::ConstantValueFactory(expr_val); } } From 8e72b8ffa2a40a3f60ea9d750c96ad4b06dfd2fa Mon Sep 17 00:00:00 2001 From: Irene Qiuwen Kai Date: Mon, 14 May 2018 09:31:36 -0400 Subject: [PATCH 26/26] style & format fix --- src/include/optimizer/util.h | 25 +++++----- src/optimizer/rule_impls.cpp | 28 +++++++----- src/optimizer/util.cpp | 88 ++++++++++++++++++++---------------- 3 files changed, 76 insertions(+), 65 deletions(-) diff --git a/src/include/optimizer/util.h b/src/include/optimizer/util.h index 97f5e4de79e..06f0420be11 100644 --- a/src/include/optimizer/util.h +++ b/src/include/optimizer/util.h @@ -35,11 +35,11 @@ class DataTable; namespace optimizer { namespace util { - /** - * @brief Convert upper case letters into lower case in a string - * - * @param str The string to operate on - */ +/** + * @brief Convert upper case letters into lower case in a string + * + * @param str The string to operate on + */ inline void to_lower_string(std::string &str) { std::transform(str.begin(), str.end(), str.begin(), ::tolower); } @@ -112,7 +112,6 @@ expression::AbstractExpression *ConstructJoinPredicate( std::unordered_set &table_alias_set, MultiTablePredicates &join_predicates); - /** * @breif Check if there are any join columns in the join expression * For example, expr = (expr_1) AND (expr_2) AND (expr_3) @@ -154,8 +153,8 @@ ConstructSelectElementMap( */ expression::AbstractExpression *TransformQueryDerivedTablePredicates( const std::unordered_map> - &alias_to_expr_map, + std::shared_ptr> & + alias_to_expr_map, expression::AbstractExpression *expr); /** @@ -171,7 +170,7 @@ void ExtractEquiJoinKeys( /** * @brief Given an operator expression and context information, check if it - * is strong predicate w.r.t to one table + * is strong predicate w.r.t to one table * A predicate p is strong w.r.t S if the fact that all attributes from S are * NULL implies that p evaluates to false * It is used in AssociativityRule transforms when certain joins are applied @@ -180,13 +179,13 @@ bool StrongPredicates( std::vector predicates, const std::unordered_set &middle_group_aliases_set); -/** +/** * @brief Replace the tuple_value_expression in given expression which * contains table in middle_group_aliases_set with constant_value_expression - * with NULL value + * with False value */ -expression::AbstractExpression* PredicateEvaluate( - expression::AbstractExpression* expr, +expression::AbstractExpression *PredicateEvaluate( + expression::AbstractExpression *expr, const std::unordered_set &middle_group_aliases_set); } // namespace util } // namespace optimizer diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index e849cda1097..3e06eeb2501 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -119,33 +119,35 @@ bool JoinAssociativity::Check(std::shared_ptr expr, memo.GetGroupByID(middle_group_id)->GetTableAliases(); if (parent_join->type == JoinType::INNER) { - if (child_join->type == JoinType::INNER || child_join->type == JoinType::RIGHT) { + if (child_join->type == JoinType::INNER || + child_join->type == JoinType::RIGHT) { return true; } } else if (parent_join->type == JoinType::LEFT) { if (child_join->type == JoinType::INNER) { return true; - } else if (child_join->type == JoinType::LEFT || child_join->type == JoinType::OUTER) { + } else if (child_join->type == JoinType::LEFT || + child_join->type == JoinType::OUTER) { return util::StrongPredicates(parent_join->join_predicates, - middle_group_aliases_set); + middle_group_aliases_set); } } else if (parent_join->type == JoinType::RIGHT) { if (child_join->type == JoinType::RIGHT) { return util::StrongPredicates(child_join->join_predicates, - middle_group_aliases_set); + middle_group_aliases_set); } } else if (parent_join->type == JoinType::OUTER) { if (child_join->type == JoinType::RIGHT) { return util::StrongPredicates(child_join->join_predicates, - middle_group_aliases_set); + middle_group_aliases_set); } else if (child_join->type == JoinType::OUTER) { - auto parent_join_predicates = + auto parent_join_predicates = std::vector(parent_join->join_predicates); auto child_join_predicates = std::vector(child_join->join_predicates); std::vector check_predicates; - check_predicates.insert(check_predicates.end(), + check_predicates.insert(check_predicates.end(), parent_join_predicates.begin(), parent_join_predicates.end()); check_predicates.insert(check_predicates.end(), @@ -909,9 +911,9 @@ void PushFilterThroughJoin::Transform( std::vector join_predicates; auto join_type = join_op_expr->Op().As()->type; - bool outer_push = (join_type == JoinType::OUTER || - join_type == JoinType::LEFT || - join_type == JoinType::RIGHT); + bool outer_push = + (join_type == JoinType::OUTER || join_type == JoinType::LEFT || + join_type == JoinType::RIGHT); // Loop over all predicates, check each of them if they can be pushed down to // either the left child or the right child to be evaluated @@ -920,10 +922,12 @@ void PushFilterThroughJoin::Transform( // E.g. An expression (test.a = test1.b and test.a = 5) would become // {test.a = test1.b, test.a = 5} for (auto &predicate : predicates) { - if (util::IsSubset(left_group_aliases_set, predicate.table_alias_set) && !outer_push) { + if (util::IsSubset(left_group_aliases_set, predicate.table_alias_set) && + !outer_push) { left_predicates.emplace_back(predicate); } else if (util::IsSubset(right_group_aliases_set, - predicate.table_alias_set) && !outer_push) { + predicate.table_alias_set) && + !outer_push) { right_predicates.emplace_back(predicate); } else { join_predicates.emplace_back(predicate); diff --git a/src/optimizer/util.cpp b/src/optimizer/util.cpp index 77492deb115..1396f17cbb0 100644 --- a/src/optimizer/util.cpp +++ b/src/optimizer/util.cpp @@ -180,8 +180,7 @@ std::unordered_map> ConstructSelectElementMap( std::vector> &select_list) { std::unordered_map> - res; + std::shared_ptr> res; for (auto &expr : select_list) { std::string alias; if (!expr->alias.empty()) { @@ -200,8 +199,8 @@ ConstructSelectElementMap( expression::AbstractExpression *TransformQueryDerivedTablePredicates( const std::unordered_map> - &alias_to_expr_map, + std::shared_ptr> & + alias_to_expr_map, expression::AbstractExpression *expr) { if (expr->GetExpressionType() == ExpressionType::VALUE_TUPLE) { auto new_expr = @@ -256,13 +255,15 @@ bool StrongPredicates( const std::unordered_set &middle_group_aliases_set) { for (auto predicate : predicates) { // create a copy of original predicate. - auto copy_expr = std::shared_ptr(predicate.expr->Copy()); + auto copy_expr = + std::shared_ptr(predicate.expr->Copy()); LOG_DEBUG("AnnotatedExp: %s", copy_expr->GetInfo().c_str()); // replace tuple_value_expression from predicate which contains table in // middle_group_aliases_set with FALSE constant value - auto eval_expr = PredicateEvaluate(copy_expr.get(), middle_group_aliases_set); - if(eval_expr != nullptr) { + auto eval_expr = + PredicateEvaluate(copy_expr.get(), middle_group_aliases_set); + if (eval_expr != nullptr) { auto eval_val = eval_expr->Evaluate(nullptr, nullptr, nullptr); if (eval_val.IsFalse()) { return true; @@ -272,29 +273,31 @@ bool StrongPredicates( return false; } -expression::AbstractExpression* PredicateEvaluate( - expression::AbstractExpression* expr, +expression::AbstractExpression *PredicateEvaluate( + expression::AbstractExpression *expr, const std::unordered_set &middle_group_aliases_set) { // if at the lowest level - if(expr->GetChildrenSize()==0){ - if(expr->GetExpressionType() == ExpressionType::VALUE_TUPLE){ + if (expr->GetChildrenSize() == 0) { + if (expr->GetExpressionType() == ExpressionType::VALUE_TUPLE) { auto tv_expr = dynamic_cast(expr); - if (middle_group_aliases_set.find(tv_expr->GetTableName()) != middle_group_aliases_set.end()) { + if (middle_group_aliases_set.find(tv_expr->GetTableName()) != + middle_group_aliases_set.end()) { return expression::ExpressionUtil::ConstantValueFactory( type::ValueFactory::GetNullValueByType(type::TypeId::BOOLEAN)); } } - if(expr->GetExpressionType() == ExpressionType::VALUE_CONSTANT){ + if (expr->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { auto cv_expr = dynamic_cast(expr); - return expression::ExpressionUtil::ConstantValueFactory(cv_expr->GetValue()); + return expression::ExpressionUtil::ConstantValueFactory( + cv_expr->GetValue()); } - // The tuple_value_expression uses table which does not belong to middle_group - // or other expression type we cannot evaluate + // Returns nullptr if the tuple_value_expression uses table which does not + // belong to middle_group or other expression type we cannot evaluate return nullptr; } // Conjunction check - if(expr->GetExpressionType() == ExpressionType::CONJUNCTION_AND){ + if (expr->GetExpressionType() == ExpressionType::CONJUNCTION_AND) { PELOTON_ASSERT(expr->GetChildrenSize() == 2); auto l_child = expr->GetModifiableChild(0); auto r_child = expr->GetModifiableChild(1); @@ -304,29 +307,31 @@ expression::AbstractExpression* PredicateEvaluate( type::Value l_val; type::Value r_val; - if (l_eval_expr!= nullptr){ - PELOTON_ASSERT(l_eval_expr->GetExpressionType()==ExpressionType::VALUE_CONSTANT); + if (l_eval_expr != nullptr) { + PELOTON_ASSERT(l_eval_expr->GetExpressionType() == + ExpressionType::VALUE_CONSTANT); l_val = l_eval_expr->Evaluate(nullptr, nullptr, nullptr); - if(l_val.IsFalse()){ + if (l_val.IsFalse()) { return expression::ExpressionUtil::ConstantValueFactory( type::ValueFactory::GetBooleanValue(false)); } } - if (r_eval_expr!= nullptr){ - PELOTON_ASSERT(r_eval_expr->GetExpressionType()==ExpressionType::VALUE_CONSTANT); + if (r_eval_expr != nullptr) { + PELOTON_ASSERT(r_eval_expr->GetExpressionType() == + ExpressionType::VALUE_CONSTANT); r_val = r_eval_expr->Evaluate(nullptr, nullptr, nullptr); - if(r_val.IsFalse()){ + if (r_val.IsFalse()) { return expression::ExpressionUtil::ConstantValueFactory( type::ValueFactory::GetBooleanValue(false)); } } - if (l_eval_expr!= nullptr && r_eval_expr!= nullptr && l_val.IsTrue() && r_val.IsTrue()){ + if (l_eval_expr != nullptr && r_eval_expr != nullptr && l_val.IsTrue() && + r_val.IsTrue()) { return expression::ExpressionUtil::ConstantValueFactory( type::ValueFactory::GetBooleanValue(true)); } return nullptr; - } - else if(expr->GetExpressionType()==ExpressionType::CONJUNCTION_OR){ + } else if (expr->GetExpressionType() == ExpressionType::CONJUNCTION_OR) { PELOTON_ASSERT(expr->GetChildrenSize() == 2); auto l_child = expr->GetModifiableChild(0); auto r_child = expr->GetModifiableChild(1); @@ -336,38 +341,41 @@ expression::AbstractExpression* PredicateEvaluate( type::Value l_val; type::Value r_val; - if (l_eval_expr!= nullptr){ - PELOTON_ASSERT(l_eval_expr->GetExpressionType()==ExpressionType::VALUE_CONSTANT); + if (l_eval_expr != nullptr) { + PELOTON_ASSERT(l_eval_expr->GetExpressionType() == + ExpressionType::VALUE_CONSTANT); l_val = l_eval_expr->Evaluate(nullptr, nullptr, nullptr); - if(l_val.IsTrue()){ + if (l_val.IsTrue()) { return expression::ExpressionUtil::ConstantValueFactory( type::ValueFactory::GetBooleanValue(true)); } } - if (r_eval_expr!= nullptr){ - PELOTON_ASSERT(r_eval_expr->GetExpressionType()==ExpressionType::VALUE_CONSTANT); + if (r_eval_expr != nullptr) { + PELOTON_ASSERT(r_eval_expr->GetExpressionType() == + ExpressionType::VALUE_CONSTANT); r_val = r_eval_expr->Evaluate(nullptr, nullptr, nullptr); - if(r_val.IsTrue()){ + if (r_val.IsTrue()) { return expression::ExpressionUtil::ConstantValueFactory( type::ValueFactory::GetBooleanValue(true)); } } - if (l_eval_expr!= nullptr && r_eval_expr!= nullptr && l_val.IsFalse() && r_val.IsFalse()){ + if (l_eval_expr != nullptr && r_eval_expr != nullptr && l_val.IsFalse() && + r_val.IsFalse()) { return expression::ExpressionUtil::ConstantValueFactory( type::ValueFactory::GetBooleanValue(false)); } return nullptr; - } - else{ + } else { for (size_t i = 0; i < expr->GetChildrenSize(); i++) { auto child_expr = expr->GetModifiableChild(i); - auto child_eval_expr = PredicateEvaluate(child_expr, middle_group_aliases_set); - if(child_eval_expr== nullptr){ + auto child_eval_expr = + PredicateEvaluate(child_expr, middle_group_aliases_set); + if (child_eval_expr == nullptr) { // cannot evaluate return nullptr; - } - else{ - PELOTON_ASSERT(child_eval_expr->GetExpressionType()==ExpressionType::VALUE_CONSTANT); + } else { + PELOTON_ASSERT(child_eval_expr->GetExpressionType() == + ExpressionType::VALUE_CONSTANT); expr->SetChild(i, child_eval_expr); } }