diff --git a/Jenkinsfile b/Jenkinsfile index 0d2352b6a35..0fd59368702 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -58,7 +58,8 @@ pipeline { sh 'cd build && make check -j4' sh 'cd build && make install' sh 'cd build && bash ../script/testing/psql/psql_test.sh' - sh 'sudo apt-get -qq update && sudo apt-get -qq -y --no-install-recommends install wget default-jdk default-jre' // prerequisites for jdbc_validator + sh 'sudo add-apt-repository ppa:openjdk-r/ppa' + sh 'sudo apt-get -qq update && sudo apt-get -qq -y --no-install-recommends install wget openjdk-8-jdk openjdk-8-jre' // prerequisites for jdbc_validator sh 'cd build && python ../script/validators/jdbc_validator.py' sh 'cd build && python ../script/testing/junit/run_junit.py' } @@ -79,7 +80,8 @@ pipeline { sh 'cd build && make check -j4' sh 'cd build && make install' sh 'cd build && bash ../script/testing/psql/psql_test.sh' - sh 'sudo apt-get -qq update && sudo apt-get -qq -y --no-install-recommends install wget default-jdk default-jre' // prerequisites for jdbc_validator + sh 'sudo add-apt-repository ppa:openjdk-r/ppa' + sh 'sudo apt-get -qq update && sudo apt-get -qq -y --no-install-recommends install wget openjdk-8-jdk openjdk-8-jre' // prerequisites for jdbc_validator sh 'cd build && python ../script/validators/jdbc_validator.py' sh 'cd build && python ../script/testing/junit/run_junit.py' } diff --git a/script/testing/junit/OptimizerTest.java b/script/testing/junit/OptimizerTest.java new file mode 100644 index 00000000000..ffafc80c8bf --- /dev/null +++ b/script/testing/junit/OptimizerTest.java @@ -0,0 +1,214 @@ +import org.junit.After; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; + +import java.io.BufferedReader; +import java.io.FileNotFoundException; +import java.io.FileReader; +import java.io.IOException; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; + +import static java.sql.Statement.EXECUTE_FAILED; 
+import static java.sql.Statement.SUCCESS_NO_INFO; +import static org.junit.Assert.*; + + +/** + * Created by Guoquan Zhao on 4/7/18. + */ +public class OptimizerTest extends PLTestBase { + private static final String[] SQL_DROP_TABLES = + {"DROP TABLE IF EXISTS t1;", + "DROP TABLE IF EXISTS t2;"}; + private Connection conn; + + private void initTables1() throws FileNotFoundException, SQLException { + try (BufferedReader reader = new BufferedReader(new FileReader("create_tables_1.sql")); + Statement stmt = conn.createStatement();) { + reader.lines().forEach(s -> { + try { + stmt.addBatch(s); + } catch (SQLException e) { + e.printStackTrace(); + } + }); + int[] results = stmt.executeBatch(); + for (int i = 0; i < results.length; i++) { + assertTrue("batch failed.", (results[i] >= 0 || results[i] == SUCCESS_NO_INFO) && results[i] != EXECUTE_FAILED); + } + ResultSet resultSet = stmt.executeQuery("SELECT COUNT(*) FROM t1;"); + ExpectedResult expectedResult = new ExpectedResult("3"); + Utils.assertResultsSetEqual(resultSet, expectedResult); + resultSet.close(); + resultSet = stmt.executeQuery("SELECT COUNT(*) FROM t2;"); + Utils.assertResultsSetEqual(resultSet, expectedResult); + resultSet.close(); + } catch (IOException e) { + e.printStackTrace(); + } + + + } + + + @Before + public void Setup() { + try { + conn = makeDefaultConnection(); + conn.setAutoCommit(true); + initTables1(); + } catch (SQLException ex) { + DumpSQLException(ex); + assertTrue(false); + } catch (FileNotFoundException e) { + e.printStackTrace(); + assertTrue(false); + } + } + + @After + public void Teardown() throws SQLException { + Statement stmt = conn.createStatement(); + for (String s : SQL_DROP_TABLES) { + stmt.execute(s); + } + } + + + @Test + public void testInnerJoin1() throws SQLException { + try ( + Statement stmt = conn.createStatement(); + ResultSet resultSet = stmt.executeQuery("SELECT t1.a FROM t1 INNER JOIN t2 ON (t1.b = t2.b) ORDER BY t1.a;");) { + ExpectedResult expectedResult = new 
ExpectedResult("1\n" + + "2"); + Utils.assertResultsSetEqual(resultSet, expectedResult); + } catch (Exception e) { + e.printStackTrace(); + fail(); + } + } + + @Test + public void testInnerJoin2() throws SQLException { + try ( + Statement stmt = conn.createStatement(); + ResultSet resultSet = stmt.executeQuery("SELECT x.a FROM t1 AS x INNER JOIN t2 ON(x.b = t2.b AND x.c = t2.c) ORDER BY x.a;");) { + ExpectedResult expectedResult = new ExpectedResult("1\n" + "2"); + Utils.assertResultsSetEqual(resultSet, expectedResult); + } catch (Exception e) { + e.printStackTrace(); + fail(); + } + } + + @Test + public void testLeftOuterJoin1() throws SQLException { + try ( + Statement stmt = conn.createStatement(); + ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT JOIN t2 ON t1.a=t2.d;");) { +// String r = +// "1|2|3|3|4|5\n" + +// "null|null|null|1|2|3\n" + +// "null|null|null|2|3|4"; +// ExpectedResult expectedResult = new ExpectedResult(r); +// Utils.assertResultsSetEqual(resultSet, expectedResult); + assertTrue(resultSet.next()); + assertEquals(3, resultSet.getInt(4)); + assertEquals(4, resultSet.getInt(5)); + assertEquals(5, resultSet.getInt(6)); + assertEquals(1, resultSet.getInt(1)); + assertEquals(2, resultSet.getInt(2)); + assertEquals(3, resultSet.getInt(3)); + assertTrue(resultSet.next()); + assertEquals(null, resultSet.getObject(1)); + assertEquals(null, resultSet.getObject(2)); + assertEquals(null, resultSet.getObject(3)); + assertTrue(resultSet.next()); + assertEquals(null, resultSet.getObject(1)); + assertEquals(null, resultSet.getObject(2)); + assertEquals(null, resultSet.getObject(3)); + assertFalse(resultSet.next()); + + } catch (Exception e) { + e.printStackTrace(); + fail(); + } + } + + /** + * There is a bug in the executor, skip this test for now. 
+ * + * @throws SQLException + */ + @Test + @Ignore + public void testLeftOuterJoin2() throws SQLException { + try ( + Statement stmt = conn.createStatement(); + ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT JOIN t2 ON t1.a=t2.d WHERE t1.a>1")) { + String r = + "1|2|3|3|4|5\n" + + "null|null|null|2|3|4"; + ExpectedResult expectedResult = new ExpectedResult(r); + Utils.assertResultsSetEqual(resultSet, expectedResult); + } catch (Exception e) { + e.printStackTrace(); + fail(); + } + } + + /** + * There is a bug in the executor, skip this test for now. + * + * @throws SQLException + */ + @Test + @Ignore + public void testLeftOuterJoin3() throws SQLException { + try ( + Statement stmt = conn.createStatement(); + ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT OUTER JOIN t2 ON t1.a=t2.d WHERE t1.a>1")) { + String r = + "1|2|3|3|4|5\n" + + "null|null|null|2|3|4"; + ExpectedResult expectedResult = new ExpectedResult(r); + Utils.assertResultsSetEqual(resultSet, expectedResult); + } catch (Exception e) { + e.printStackTrace(); + fail(); + } + } + + /** + * There is a bug in the executor, skip this test for now. 
+ * + * @throws SQLException + */ + @Test + @Ignore + public void testLeftOuterJoinWhere() { + try ( + Statement stmt = conn.createStatement(); + ResultSet resultSet = stmt.executeQuery("SELECT * FROM t1 LEFT JOIN t2 ON t1.a=t2.d WHERE t2.b IS NULL OR t2.b>1")) { + // expected result is + // t1 t2 + // 1 2 3 {} {} {} + // 2 3 4 {} {} {} + String r = + "1|2|3|null|null|null\n" + + "2|3|4|null|null|null"; + ExpectedResult expectedResult = new ExpectedResult(r); + Utils.assertResultsSetEqual(resultSet, expectedResult); + } catch (Exception e) { + e.printStackTrace(); + fail(); + } + } + + +} diff --git a/script/testing/junit/Utils.java b/script/testing/junit/Utils.java new file mode 100644 index 00000000000..3ce22358d2a --- /dev/null +++ b/script/testing/junit/Utils.java @@ -0,0 +1,79 @@ +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.junit.Assert.*; + + +/** + * Created by Guoquan Zhao on 4/30/18. 
+ */ +public class Utils { + public static void assertResultsSetEqual(ResultSet results, ExpectedResult expectedResult) throws SQLException { + int rows = expectedResult.getRows(); + int columns = expectedResult.getColumns(); + ResultSetMetaData rsmd = results.getMetaData(); + int columnsNumber = rsmd.getColumnCount(); + assertEquals(columns, columnsNumber); + + for (int i = 0; i < rows; i++) { + assertTrue(results.next()); + for (int j = 0; j < columns; j++) { + + String returnedString = results.getString(j + 1); + String expected = expectedResult.getItemAtIndex(i, j); + + if (returnedString == null) { + assertEquals(expected, "null"); + } else { + assertEquals(returnedString,expected); + } + } + } + assertFalse(results.next()); + } +} + + +class ExpectedResult { + public ExpectedResult(String expectedResult) { + String[] rows = expectedResult.split("\n"); + List results = Stream.of(rows).map(s -> { + String[] columns = s.split("\\|"); + String[] collect = Stream.of(columns).map(c -> c.trim()).collect(Collectors.toList()).toArray(new String[0]); + return collect; + }).collect(Collectors.toList()); + int num_columns = results.get(0).length; + assertTrue(results.stream().allMatch(strings -> strings.length == num_columns)); + this.rows = results; + } + + public int getRows() { + return this.rows.size(); + } + + public int getColumns() { + return this.rows.get(0).length; + } + + public String getItemAtIndex(int row, int column) { + String ret = this.rows.get(row)[column]; + if (ret == null) { + for (int i = 0; i < this.rows.size(); i++) { + for (int j = 0; j < this.rows.get(i).length; j++) { + System.out.print(this.rows.get(i)[j] + "|"); + } + System.out.println(); + } + throw new RuntimeException("Should not return NULL"); + } + return ret; + } + + private List rows; + + +} \ No newline at end of file diff --git a/script/testing/junit/create_tables_1.sql b/script/testing/junit/create_tables_1.sql new file mode 100644 index 00000000000..b5b876889f5 --- /dev/null +++ 
b/script/testing/junit/create_tables_1.sql @@ -0,0 +1,8 @@ +CREATE TABLE t1(a INT,b INT,c INT); +INSERT INTO t1 VALUES(1,2,3); +INSERT INTO t1 VALUES(2,3,4); +INSERT INTO t1 VALUES(3,4,5); +CREATE TABLE t2(b INT,c INT,d INT); +INSERT INTO t2 VALUES(1,2,3); +INSERT INTO t2 VALUES(2,3,4); +INSERT INTO t2 VALUES(3,4,5); \ No newline at end of file diff --git a/src/codegen/query_compiler.cpp b/src/codegen/query_compiler.cpp index 104e4f5783a..66d4ba90ff7 100644 --- a/src/codegen/query_compiler.cpp +++ b/src/codegen/query_compiler.cpp @@ -74,9 +74,10 @@ bool QueryCompiler::IsSupported(const planner::AbstractPlan &plan) { case PlanNodeType::HASHJOIN: { const auto &join = static_cast(plan); // Right now, only support inner joins - if (join.GetJoinType() == JoinType::INNER) { - break; + if (join.GetJoinType() != JoinType::INNER) { + return false; } + break; } case PlanNodeType::HASH: { break; diff --git a/src/executor/nested_loop_join_executor.cpp b/src/executor/nested_loop_join_executor.cpp index 6f1bca1eb36..4bec12cdc10 100644 --- a/src/executor/nested_loop_join_executor.cpp +++ b/src/executor/nested_loop_join_executor.cpp @@ -101,7 +101,7 @@ bool NestedLoopJoinExecutor::DExecute() { // If we have already retrieved all left child's results in buffer if (left_child_done_ == true) { LOG_TRACE("Left is done which means all join comparison completes"); - return false; + return BuildOuterJoinOutput(); } // If left tile result is not done, continue the left tuples diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 995a92cea2d..00f4d5be0d2 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1325,8 +1325,8 @@ std::ostream &operator<<(std::ostream &os, const PropertyType &type); enum class RuleType : uint32_t { // Transformation rules (logical -> logical) - INNER_JOIN_COMMUTE = 0, - INNER_JOIN_ASSOCIATE, + JOIN_COMMUTE = 0, + JOIN_ASSOCIATE, // Don't move this one LogicalPhysicalDelimiter, @@ 
-1342,6 +1342,8 @@ enum class RuleType : uint32_t { INSERT_SELECT_TO_PHYSICAL, AGGREGATE_TO_HASH_AGGREGATE, AGGREGATE_TO_PLAIN_AGGREGATE, + JOIN_TO_NL_JOIN, + JOIN_TO_HASH_JOIN, INNER_JOIN_TO_NL_JOIN, INNER_JOIN_TO_HASH_JOIN, IMPLEMENT_DISTINCT, diff --git a/src/include/optimizer/child_property_deriver.h b/src/include/optimizer/child_property_deriver.h index bd4aeb7b933..c01c37389c7 100644 --- a/src/include/optimizer/child_property_deriver.h +++ b/src/include/optimizer/child_property_deriver.h @@ -42,14 +42,8 @@ class ChildPropertyDeriver : public OperatorVisitor { void Visit(const QueryDerivedScan *op) override; void Visit(const PhysicalOrderBy *) override; void Visit(const PhysicalLimit *) override; - void Visit(const PhysicalInnerNLJoin *) override; - void Visit(const PhysicalLeftNLJoin *) override; - void Visit(const PhysicalRightNLJoin *) override; - void Visit(const PhysicalOuterNLJoin *) override; - void Visit(const PhysicalInnerHashJoin *) override; - void Visit(const PhysicalLeftHashJoin *) override; - void Visit(const PhysicalRightHashJoin *) override; - void Visit(const PhysicalOuterHashJoin *) override; + void Visit(const PhysicalNLJoin *) override; + void Visit(const PhysicalHashJoin *) override; void Visit(const PhysicalInsert *) override; void Visit(const PhysicalInsertSelect *) override; void Visit(const PhysicalDelete *) override; diff --git a/src/include/optimizer/child_stats_deriver.h b/src/include/optimizer/child_stats_deriver.h index d0c72f9bf9b..cca76ba7071 100644 --- a/src/include/optimizer/child_stats_deriver.h +++ b/src/include/optimizer/child_stats_deriver.h @@ -27,21 +27,17 @@ class Memo; // expression class ChildStatsDeriver : public OperatorVisitor { public: - std::vector DeriveInputStats( - GroupExpression *gexpr, - ExprSet required_cols, Memo *memo); + std::vector DeriveInputStats(GroupExpression *gexpr, + ExprSet required_cols, Memo *memo); void Visit(const LogicalQueryDerivedGet *) override; - void Visit(const LogicalInnerJoin *) 
override; - void Visit(const LogicalLeftJoin *) override; - void Visit(const LogicalRightJoin *) override; - void Visit(const LogicalOuterJoin *) override; + void Visit(const LogicalJoin *) override; void Visit(const LogicalSemiJoin *) override; void Visit(const LogicalAggregateAndGroupBy *) override; private: void PassDownRequiredCols(); - void PassDownColumn(expression::AbstractExpression* col); + void PassDownColumn(expression::AbstractExpression *col); ExprSet required_cols_; GroupExpression *gexpr_; Memo *memo_; diff --git a/src/include/optimizer/cost_calculator.h b/src/include/optimizer/cost_calculator.h index 442f386fc5f..87410d0bb07 100644 --- a/src/include/optimizer/cost_calculator.h +++ b/src/include/optimizer/cost_calculator.h @@ -30,14 +30,8 @@ class CostCalculator : public OperatorVisitor { void Visit(const QueryDerivedScan *) override; void Visit(const PhysicalOrderBy *) override; void Visit(const PhysicalLimit *) override; - void Visit(const PhysicalInnerNLJoin *) override; - void Visit(const PhysicalLeftNLJoin *) override; - void Visit(const PhysicalRightNLJoin *) override; - void Visit(const PhysicalOuterNLJoin *) override; - void Visit(const PhysicalInnerHashJoin *) override; - void Visit(const PhysicalLeftHashJoin *) override; - void Visit(const PhysicalRightHashJoin *) override; - void Visit(const PhysicalOuterHashJoin *) override; + void Visit(const PhysicalNLJoin *) override; + void Visit(const PhysicalHashJoin *) override; void Visit(const PhysicalInsert *) override; void Visit(const PhysicalInsertSelect *) override; void Visit(const PhysicalDelete *) override; diff --git a/src/include/optimizer/input_column_deriver.h b/src/include/optimizer/input_column_deriver.h index fa1ec6ca5a1..792ebe52bfb 100644 --- a/src/include/optimizer/input_column_deriver.h +++ b/src/include/optimizer/input_column_deriver.h @@ -59,21 +59,9 @@ class InputColumnDeriver : public OperatorVisitor { void Visit(const PhysicalLimit *) override; - void Visit(const 
PhysicalInnerNLJoin *) override; + void Visit(const PhysicalNLJoin *) override; - void Visit(const PhysicalLeftNLJoin *) override; - - void Visit(const PhysicalRightNLJoin *) override; - - void Visit(const PhysicalOuterNLJoin *) override; - - void Visit(const PhysicalInnerHashJoin *) override; - - void Visit(const PhysicalLeftHashJoin *) override; - - void Visit(const PhysicalRightHashJoin *) override; - - void Visit(const PhysicalOuterHashJoin *) override; + void Visit(const PhysicalHashJoin *) override; void Visit(const PhysicalInsert *) override; diff --git a/src/include/optimizer/operator_node.h b/src/include/optimizer/operator_node.h index cb20c163bbe..3692d629f44 100644 --- a/src/include/optimizer/operator_node.h +++ b/src/include/optimizer/operator_node.h @@ -33,10 +33,7 @@ enum class OpType { LogicalMarkJoin, LogicalDependentJoin, LogicalSingleJoin, - InnerJoin, - LeftJoin, - RightJoin, - OuterJoin, + LogicalJoin, SemiJoin, LogicalAggregateAndGroupBy, LogicalInsert, @@ -55,6 +52,8 @@ enum class OpType { OrderBy, PhysicalLimit, Distinct, + NLJoin, + HashJoin, InnerNLJoin, LeftNLJoin, RightNLJoin, diff --git a/src/include/optimizer/operator_visitor.h b/src/include/optimizer/operator_visitor.h index 75b0a9f9c67..4638c1f74ce 100644 --- a/src/include/optimizer/operator_visitor.h +++ b/src/include/optimizer/operator_visitor.h @@ -32,14 +32,8 @@ class OperatorVisitor { virtual void Visit(const QueryDerivedScan *) {} virtual void Visit(const PhysicalOrderBy *) {} virtual void Visit(const PhysicalLimit *) {} - virtual void Visit(const PhysicalInnerNLJoin *) {} - virtual void Visit(const PhysicalLeftNLJoin *) {} - virtual void Visit(const PhysicalRightNLJoin *) {} - virtual void Visit(const PhysicalOuterNLJoin *) {} - virtual void Visit(const PhysicalInnerHashJoin *) {} - virtual void Visit(const PhysicalLeftHashJoin *) {} - virtual void Visit(const PhysicalRightHashJoin *) {} - virtual void Visit(const PhysicalOuterHashJoin *) {} + virtual void Visit(const 
PhysicalNLJoin *) {} + virtual void Visit(const PhysicalHashJoin *) {} virtual void Visit(const PhysicalInsert *) {} virtual void Visit(const PhysicalInsertSelect *) {} virtual void Visit(const PhysicalDelete *) {} @@ -58,10 +52,7 @@ class OperatorVisitor { virtual void Visit(const LogicalMarkJoin *) {} virtual void Visit(const LogicalSingleJoin *) {} virtual void Visit(const LogicalDependentJoin *) {} - virtual void Visit(const LogicalInnerJoin *) {} - virtual void Visit(const LogicalLeftJoin *) {} - virtual void Visit(const LogicalRightJoin *) {} - virtual void Visit(const LogicalOuterJoin *) {} + virtual void Visit(const LogicalJoin *) {} virtual void Visit(const LogicalSemiJoin *) {} virtual void Visit(const LogicalAggregateAndGroupBy *) {} virtual void Visit(const LogicalInsert *) {} diff --git a/src/include/optimizer/operators.h b/src/include/optimizer/operators.h index a745439251a..2b2bcfcc43f 100644 --- a/src/include/optimizer/operators.h +++ b/src/include/optimizer/operators.h @@ -31,7 +31,7 @@ class UpdateClause; } namespace catalog { - class TableCatalogObject; +class TableCatalogObject; } namespace optimizer { @@ -51,10 +51,10 @@ class LeafOperator : OperatorNode { //===--------------------------------------------------------------------===// class LogicalGet : public OperatorNode { public: - static Operator make(oid_t get_id = 0, - std::vector predicates = {}, - std::shared_ptr table = nullptr, - std::string alias = "", bool update = false); + static Operator make( + oid_t get_id = 0, std::vector predicates = {}, + std::shared_ptr table = nullptr, + std::string alias = "", bool update = false); bool operator==(const BaseOperatorNode &r) override; @@ -163,49 +163,21 @@ class LogicalSingleJoin : public OperatorNode { }; //===--------------------------------------------------------------------===// -// InnerJoin +// Join (Inner + Outer Joins) //===--------------------------------------------------------------------===// -class LogicalInnerJoin : public 
OperatorNode { +class LogicalJoin : public OperatorNode { public: - static Operator make(); + static Operator make(JoinType type); - static Operator make(std::vector &conditions); + static Operator make(JoinType type, + std::vector &conditions); bool operator==(const BaseOperatorNode &r) override; hash_t Hash() const override; std::vector join_predicates; -}; - -//===--------------------------------------------------------------------===// -// LeftJoin -//===--------------------------------------------------------------------===// -class LogicalLeftJoin : public OperatorNode { - public: - static Operator make(expression::AbstractExpression *condition = nullptr); - - std::shared_ptr join_predicate; -}; - -//===--------------------------------------------------------------------===// -// RightJoin -//===--------------------------------------------------------------------===// -class LogicalRightJoin : public OperatorNode { - public: - static Operator make(expression::AbstractExpression *condition = nullptr); - - std::shared_ptr join_predicate; -}; - -//===--------------------------------------------------------------------===// -// OuterJoin -//===--------------------------------------------------------------------===// -class LogicalOuterJoin : public OperatorNode { - public: - static Operator make(expression::AbstractExpression *condition = nullptr); - - std::shared_ptr join_predicate; + JoinType type; }; //===--------------------------------------------------------------------===// @@ -246,7 +218,8 @@ class LogicalAggregateAndGroupBy class LogicalInsert : public OperatorNode { public: static Operator make( - std::shared_ptr target_table, const std::vector *columns, + std::shared_ptr target_table, + const std::vector *columns, const std::vector>> *values); @@ -258,7 +231,8 @@ class LogicalInsert : public OperatorNode { class LogicalInsertSelect : public OperatorNode { public: - static Operator make(std::shared_ptr target_table); + static Operator make( + 
std::shared_ptr target_table); std::shared_ptr target_table; }; @@ -286,7 +260,8 @@ class LogicalLimit : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalDelete : public OperatorNode { public: - static Operator make(std::shared_ptr target_table); + static Operator make( + std::shared_ptr target_table); std::shared_ptr target_table; }; @@ -317,7 +292,8 @@ class DummyScan : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalSeqScan : public OperatorNode { public: - static Operator make(oid_t get_id, std::shared_ptr table, + static Operator make(oid_t get_id, + std::shared_ptr table, std::string alias, std::vector predicates, bool update); @@ -339,7 +315,8 @@ class PhysicalSeqScan : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalIndexScan : public OperatorNode { public: - static Operator make(oid_t get_id, std::shared_ptr table, + static Operator make(oid_t get_id, + std::shared_ptr table, std::string alias, std::vector predicates, bool update, oid_t index_id, std::vector key_column_id_list, @@ -408,12 +385,12 @@ class PhysicalLimit : public OperatorNode { }; //===--------------------------------------------------------------------===// -// InnerNLJoin +// NLJoin (Inner + Outer Joins) //===--------------------------------------------------------------------===// -class PhysicalInnerNLJoin : public OperatorNode { +class PhysicalNLJoin : public OperatorNode { public: static Operator make( - std::vector conditions, + JoinType type, std::vector conditions, std::vector> &left_keys, std::vector> &right_keys); @@ -425,45 +402,16 @@ class PhysicalInnerNLJoin : public OperatorNode { std::vector> right_keys; std::vector join_predicates; + JoinType type; }; //===--------------------------------------------------------------------===// -// LeftNLJoin +// HashJoin (Inner + Outer Joins) 
//===--------------------------------------------------------------------===// -class PhysicalLeftNLJoin : public OperatorNode { +class PhysicalHashJoin : public OperatorNode { public: - std::shared_ptr join_predicate; - static Operator make( - std::shared_ptr join_predicate); -}; - -//===--------------------------------------------------------------------===// -// RightNLJoin -//===--------------------------------------------------------------------===// -class PhysicalRightNLJoin : public OperatorNode { - public: - std::shared_ptr join_predicate; static Operator make( - std::shared_ptr join_predicate); -}; - -//===--------------------------------------------------------------------===// -// OuterNLJoin -//===--------------------------------------------------------------------===// -class PhysicalOuterNLJoin : public OperatorNode { - public: - std::shared_ptr join_predicate; - static Operator make( - std::shared_ptr join_predicate); -}; - -//===--------------------------------------------------------------------===// -// InnerHashJoin -//===--------------------------------------------------------------------===// -class PhysicalInnerHashJoin : public OperatorNode { - public: - static Operator make( - std::vector conditions, + JoinType type, std::vector conditions, std::vector> &left_keys, std::vector> &right_keys); @@ -475,36 +423,7 @@ class PhysicalInnerHashJoin : public OperatorNode { std::vector> right_keys; std::vector join_predicates; -}; - -//===--------------------------------------------------------------------===// -// LeftHashJoin -//===--------------------------------------------------------------------===// -class PhysicalLeftHashJoin : public OperatorNode { - public: - std::shared_ptr join_predicate; - static Operator make( - std::shared_ptr join_predicate); -}; - -//===--------------------------------------------------------------------===// -// RightHashJoin -//===--------------------------------------------------------------------===// -class 
PhysicalRightHashJoin : public OperatorNode { - public: - std::shared_ptr join_predicate; - static Operator make( - std::shared_ptr join_predicate); -}; - -//===--------------------------------------------------------------------===// -// OuterHashJoin -//===--------------------------------------------------------------------===// -class PhysicalOuterHashJoin : public OperatorNode { - public: - std::shared_ptr join_predicate; - static Operator make( - std::shared_ptr join_predicate); + JoinType type; }; //===--------------------------------------------------------------------===// @@ -513,7 +432,8 @@ class PhysicalOuterHashJoin : public OperatorNode { class PhysicalInsert : public OperatorNode { public: static Operator make( - std::shared_ptr target_table, const std::vector *columns, + std::shared_ptr target_table, + const std::vector *columns, const std::vector>> *values); @@ -525,7 +445,8 @@ class PhysicalInsert : public OperatorNode { class PhysicalInsertSelect : public OperatorNode { public: - static Operator make(std::shared_ptr target_table); + static Operator make( + std::shared_ptr target_table); std::shared_ptr target_table; }; @@ -535,7 +456,8 @@ class PhysicalInsertSelect : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalDelete : public OperatorNode { public: - static Operator make(std::shared_ptr target_table); + static Operator make( + std::shared_ptr target_table); std::shared_ptr target_table; }; diff --git a/src/include/optimizer/plan_generator.h b/src/include/optimizer/plan_generator.h index c0a21259bc6..11184a05df6 100644 --- a/src/include/optimizer/plan_generator.h +++ b/src/include/optimizer/plan_generator.h @@ -60,21 +60,9 @@ class PlanGenerator : public OperatorVisitor { void Visit(const PhysicalLimit *) override; - void Visit(const PhysicalInnerNLJoin *) override; + void Visit(const PhysicalNLJoin *) override; - void Visit(const PhysicalLeftNLJoin *) override; - - void 
Visit(const PhysicalRightNLJoin *) override; - - void Visit(const PhysicalOuterNLJoin *) override; - - void Visit(const PhysicalInnerHashJoin *) override; - - void Visit(const PhysicalLeftHashJoin *) override; - - void Visit(const PhysicalRightHashJoin *) override; - - void Visit(const PhysicalOuterHashJoin *) override; + void Visit(const PhysicalHashJoin *) override; void Visit(const PhysicalInsert *) override; diff --git a/src/include/optimizer/rule_impls.h b/src/include/optimizer/rule_impls.h index 2c40e3f3c81..b0faa1aabdb 100644 --- a/src/include/optimizer/rule_impls.h +++ b/src/include/optimizer/rule_impls.h @@ -26,9 +26,9 @@ namespace optimizer { /** * @brief (A join B) -> (B join A) */ -class InnerJoinCommutativity : public Rule { +class JoinCommutativity : public Rule { public: - InnerJoinCommutativity(); + JoinCommutativity(); bool Check(std::shared_ptr plan, OptimizeContext *context) const override; @@ -41,10 +41,9 @@ class InnerJoinCommutativity : public Rule { /** * @brief (A join B) join C -> A join (B join C) */ - -class InnerJoinAssociativity : public Rule { +class JoinAssociativity : public Rule { public: - InnerJoinAssociativity(); + JoinAssociativity(); bool Check(std::shared_ptr plan, OptimizeContext *context) const override; @@ -210,11 +209,11 @@ class LogicalAggregateToPhysical : public Rule { }; /** - * @brief (Logical Inner Join -> Inner Nested-Loop Join) + * @brief (Logical Join -> Nested-Loop Join) */ -class InnerJoinToInnerNLJoin : public Rule { +class JoinToNLJoin : public Rule { public: - InnerJoinToInnerNLJoin(); + JoinToNLJoin(); bool Check(std::shared_ptr plan, OptimizeContext *context) const override; @@ -225,11 +224,11 @@ class InnerJoinToInnerNLJoin : public Rule { }; /** - * @brief (Logical Inner Join -> Inner Hash Join) + * @brief (Logical Join -> Hash Join) */ -class InnerJoinToInnerHashJoin : public Rule { +class JoinToHashJoin : public Rule { public: - InnerJoinToInnerHashJoin(); + JoinToHashJoin(); bool Check(std::shared_ptr 
plan, OptimizeContext *context) const override; @@ -341,7 +340,8 @@ class EmbedFilterIntoGet : public Rule { /////////////////////////////////////////////////////////////////////////////// /// Unnesting rules // We use this promise to determine which rules should be applied first if -// multiple rules are applicable, we need to first pull filters up through mark-join +// multiple rules are applicable, we need to first pull filters up through +// mark-join // then turn mark-join into a regular join operator enum class UnnestPromise { Low = 1, High }; // TODO(boweic): MarkJoin and SingleJoin should not be transformed into inner diff --git a/src/include/optimizer/stats_calculator.h b/src/include/optimizer/stats_calculator.h index 5aed2902671..224e0faa909 100644 --- a/src/include/optimizer/stats_calculator.h +++ b/src/include/optimizer/stats_calculator.h @@ -26,15 +26,12 @@ class TableStats; */ class StatsCalculator : public OperatorVisitor { public: - void CalculateStats(GroupExpression *gexpr, ExprSet required_cols, - Memo *memo, concurrency::TransactionContext* txn); + void CalculateStats(GroupExpression *gexpr, ExprSet required_cols, Memo *memo, + concurrency::TransactionContext *txn); void Visit(const LogicalGet *) override; void Visit(const LogicalQueryDerivedGet *) override; - void Visit(const LogicalInnerJoin *) override; - void Visit(const LogicalLeftJoin *) override; - void Visit(const LogicalRightJoin *) override; - void Visit(const LogicalOuterJoin *) override; + void Visit(const LogicalJoin *) override; void Visit(const LogicalSemiJoin *) override; void Visit(const LogicalAggregateAndGroupBy *) override; void Visit(const LogicalLimit *) override; @@ -64,8 +61,8 @@ class StatsCalculator : public OperatorVisitor { */ void UpdateStatsForFilter( size_t num_rows, - std::unordered_map> - &predicate_stats, + std::unordered_map> & + predicate_stats, const std::vector &predicates); double CalculateSelectivityForPredicate( @@ -75,7 +72,7 @@ class StatsCalculator : 
public OperatorVisitor { GroupExpression *gexpr_; ExprSet required_cols_; Memo *memo_; - concurrency::TransactionContext* txn_; + concurrency::TransactionContext *txn_; }; } // namespace optimizer diff --git a/src/include/optimizer/util.h b/src/include/optimizer/util.h index 8b9eb4baeef..06f0420be11 100644 --- a/src/include/optimizer/util.h +++ b/src/include/optimizer/util.h @@ -19,6 +19,8 @@ #include "expression/abstract_expression.h" #include "parser/copy_statement.h" #include "planner/abstract_plan.h" +#include "optimizer/optimize_context.h" +#include "expression/tuple_value_expression.h" namespace peloton { @@ -33,11 +35,11 @@ class DataTable; namespace optimizer { namespace util { - /** - * @brief Convert upper case letters into lower case in a string - * - * @param str The string to operate on - */ +/** + * @brief Convert upper case letters into lower case in a string + * + * @param str The string to operate on + */ inline void to_lower_string(std::string &str) { std::transform(str.begin(), str.end(), str.begin(), ::tolower); } @@ -110,7 +112,6 @@ expression::AbstractExpression *ConstructJoinPredicate( std::unordered_set &table_alias_set, MultiTablePredicates &join_predicates); - /** * @breif Check if there are any join columns in the join expression * For example, expr = (expr_1) AND (expr_2) AND (expr_3) @@ -152,8 +153,8 @@ ConstructSelectElementMap( */ expression::AbstractExpression *TransformQueryDerivedTablePredicates( const std::unordered_map> - &alias_to_expr_map, + std::shared_ptr> & + alias_to_expr_map, expression::AbstractExpression *expr); /** @@ -167,6 +168,25 @@ void ExtractEquiJoinKeys( const std::unordered_set &left_alias, const std::unordered_set &right_alias); +/** + * @brief Given an operator expression and context information, check if it + * is strong predicate w.r.t to one table + * A predicate p is strong w.r.t S if the fact that all attributes from S are + * NULL implies that p evaluates to false + * It is used in AssociativityRule 
transforms when certain joins are applied + */ +bool StrongPredicates( + std::vector predicates, + const std::unordered_set &middle_group_aliases_set); + +/** + * @brief Replace the tuple_value_expression in given expression which + * contains table in middle_group_aliases_set with constant_value_expression + * with False value + */ +expression::AbstractExpression *PredicateEvaluate( + expression::AbstractExpression *expr, + const std::unordered_set &middle_group_aliases_set); } // namespace util } // namespace optimizer } // namespace peloton diff --git a/src/include/planner/abstract_join_plan.h b/src/include/planner/abstract_join_plan.h index 4c7734cfaba..3172c4079e2 100644 --- a/src/include/planner/abstract_join_plan.h +++ b/src/include/planner/abstract_join_plan.h @@ -87,6 +87,10 @@ class AbstractJoinPlan : public AbstractPlan { virtual void HandleSubplanBinding(bool from_left, const BindingContext &input) = 0; + const std::string GetPredicateInfo() const { + return predicate_ != nullptr ? 
predicate_->GetInfo() : ""; + } + private: /** @brief The type of join that we're going to perform */ JoinType join_type_; diff --git a/src/include/planner/hash_join_plan.h b/src/include/planner/hash_join_plan.h index 5159834d7f2..0ca3bbbbd31 100644 --- a/src/include/planner/hash_join_plan.h +++ b/src/include/planner/hash_join_plan.h @@ -55,7 +55,9 @@ class HashJoinPlan : public AbstractJoinPlan { void SetBloomFilterFlag(bool flag) { build_bloomfilter_ = flag; } - const std::string GetInfo() const override { return "HashJoin"; } + const std::string GetInfo() const override { + return "HashJoin(" + GetPredicateInfo() + ")"; + } void GetLeftHashKeys( std::vector &keys) const; diff --git a/src/optimizer/child_property_deriver.cpp b/src/optimizer/child_property_deriver.cpp index 1df06b3ea50..b216d49dbd8 100644 --- a/src/optimizer/child_property_deriver.cpp +++ b/src/optimizer/child_property_deriver.cpp @@ -143,37 +143,37 @@ void ChildPropertyDeriver::Visit(const PhysicalDistinct *) { output_.push_back(make_pair(requirements_, move(child_input_properties))); } + void ChildPropertyDeriver::Visit(const PhysicalOrderBy *) {} -void ChildPropertyDeriver::Visit(const PhysicalInnerNLJoin *) { + +void ChildPropertyDeriver::Visit(const PhysicalNLJoin *) { DeriveForJoin(); } -void ChildPropertyDeriver::Visit(const PhysicalLeftNLJoin *) {} -void ChildPropertyDeriver::Visit(const PhysicalRightNLJoin *) {} -void ChildPropertyDeriver::Visit(const PhysicalOuterNLJoin *) {} -void ChildPropertyDeriver::Visit(const PhysicalInnerHashJoin *) { + +void ChildPropertyDeriver::Visit(const PhysicalHashJoin *) { DeriveForJoin(); } -void ChildPropertyDeriver::Visit(const PhysicalLeftHashJoin *) {} -void ChildPropertyDeriver::Visit(const PhysicalRightHashJoin *) {} -void ChildPropertyDeriver::Visit(const PhysicalOuterHashJoin *) {} void ChildPropertyDeriver::Visit(const PhysicalInsert *) { vector> child_input_properties; output_.push_back(make_pair(requirements_, move(child_input_properties))); } 
+ void ChildPropertyDeriver::Visit(const PhysicalInsertSelect *) { // Let child fulfil all the required properties vector> child_input_properties{requirements_}; output_.push_back(make_pair(requirements_, move(child_input_properties))); } + void ChildPropertyDeriver::Visit(const PhysicalUpdate *) { // Let child fulfil all the required properties vector> child_input_properties{requirements_}; output_.push_back(make_pair(requirements_, move(child_input_properties))); } + void ChildPropertyDeriver::Visit(const PhysicalDelete *) { // Let child fulfil all the required properties vector> child_input_properties{requirements_}; diff --git a/src/optimizer/child_stats_deriver.cpp b/src/optimizer/child_stats_deriver.cpp index 0833d55a0f0..5aa581b0ce6 100644 --- a/src/optimizer/child_stats_deriver.cpp +++ b/src/optimizer/child_stats_deriver.cpp @@ -33,7 +33,7 @@ vector ChildStatsDeriver::DeriveInputStats(GroupExpression *gexpr, // TODO(boweic): support stats derivation for derivedGet void ChildStatsDeriver::Visit(const LogicalQueryDerivedGet *) {} -void ChildStatsDeriver::Visit(const LogicalInnerJoin *op) { +void ChildStatsDeriver::Visit(const LogicalJoin *op) { PassDownRequiredCols(); for (auto &annotated_expr : op->join_predicates) { auto predicate = annotated_expr.expr.get(); @@ -44,9 +44,7 @@ void ChildStatsDeriver::Visit(const LogicalInnerJoin *op) { } } } -void ChildStatsDeriver::Visit(UNUSED_ATTRIBUTE const LogicalLeftJoin *) {} -void ChildStatsDeriver::Visit(UNUSED_ATTRIBUTE const LogicalRightJoin *) {} -void ChildStatsDeriver::Visit(UNUSED_ATTRIBUTE const LogicalOuterJoin *) {} + void ChildStatsDeriver::Visit(const LogicalSemiJoin *) {} // TODO(boweic): support stats of aggregation void ChildStatsDeriver::Visit(const LogicalAggregateAndGroupBy *) { diff --git a/src/optimizer/cost_calculator.cpp b/src/optimizer/cost_calculator.cpp index 5dda9e67c8a..98f95013294 100644 --- a/src/optimizer/cost_calculator.cpp +++ b/src/optimizer/cost_calculator.cpp @@ -72,7 +72,8 @@ void 
CostCalculator::Visit(const PhysicalLimit *op) { output_cost_ = std::min((size_t)child_num_rows, (size_t)op->limit) * DEFAULT_TUPLE_COST; } -void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalInnerNLJoin *op) { + +void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalNLJoin *op) { auto left_child_rows = memo_->GetGroupByID(gexpr_->GetChildGroupId(0))->GetNumRows(); auto right_child_rows = @@ -80,38 +81,41 @@ void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalInnerNLJoin *op) { output_cost_ = left_child_rows * right_child_rows * DEFAULT_TUPLE_COST; } -void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalLeftNLJoin *op) {} -void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalRightNLJoin *op) {} -void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalOuterNLJoin *op) {} -void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalInnerHashJoin *op) { + +void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalHashJoin *op) { auto left_child_rows = memo_->GetGroupByID(gexpr_->GetChildGroupId(0))->GetNumRows(); auto right_child_rows = memo_->GetGroupByID(gexpr_->GetChildGroupId(1))->GetNumRows(); - // TODO(boweic): Build (left) table should have different cost to probe table - output_cost_ = (left_child_rows + right_child_rows) * DEFAULT_TUPLE_COST; + + output_cost_ = left_child_rows * right_child_rows * DEFAULT_TUPLE_COST; } -void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalLeftHashJoin *op) {} -void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalRightHashJoin *op) {} -void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalOuterHashJoin *op) {} + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalInsert *op) {} + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalInsertSelect *op) {} + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalDelete *op) {} + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalUpdate *op) {} + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalHashGroupBy *op) 
{ // TODO(boweic): Integrate hash in groupby may cause us to miss the // opportunity to further optimize some query where the child output is // already hashed by the GroupBy key, we'll do a hash anyway output_cost_ = HashCost() + GroupByCost(); } + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalSortGroupBy *op) { // Sort group by does not sort the tuples, it requires input columns to be // sorted output_cost_ = GroupByCost(); } + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalDistinct *op) { output_cost_ = HashCost(); } + void CostCalculator::Visit(UNUSED_ATTRIBUTE const PhysicalAggregate *op) { // TODO(boweic): Ditto, separate groupby operator and implementation(e.g. // hash, sort) may enable opportunity for further optimization diff --git a/src/optimizer/input_column_deriver.cpp b/src/optimizer/input_column_deriver.cpp index 7819f81afb9..c942d973de9 100644 --- a/src/optimizer/input_column_deriver.cpp +++ b/src/optimizer/input_column_deriver.cpp @@ -123,26 +123,14 @@ void InputColumnDeriver::Visit(const PhysicalAggregate *op) { void InputColumnDeriver::Visit(const PhysicalDistinct *) { Passdown(); } -void InputColumnDeriver::Visit(const PhysicalInnerNLJoin *op) { +void InputColumnDeriver::Visit(const PhysicalNLJoin *op) { JoinHelper(op); } -void InputColumnDeriver::Visit(const PhysicalLeftNLJoin *) {} - -void InputColumnDeriver::Visit(const PhysicalRightNLJoin *) {} - -void InputColumnDeriver::Visit(const PhysicalOuterNLJoin *) {} - -void InputColumnDeriver::Visit(const PhysicalInnerHashJoin *op) { +void InputColumnDeriver::Visit(const PhysicalHashJoin *op) { JoinHelper(op); } -void InputColumnDeriver::Visit(const PhysicalLeftHashJoin *) {} - -void InputColumnDeriver::Visit(const PhysicalRightHashJoin *) {} - -void InputColumnDeriver::Visit(const PhysicalOuterHashJoin *) {} - void InputColumnDeriver::Visit(const PhysicalInsert *) { output_input_cols_ = pair, vector>>{ @@ -246,13 +234,13 @@ void InputColumnDeriver::JoinHelper(const 
BaseOperatorNode *op) { const vector> *left_keys = nullptr; const vector> *right_keys = nullptr; - if (op->GetType() == OpType::InnerHashJoin) { - auto join_op = reinterpret_cast(op); + if (op->GetType() == OpType::HashJoin) { + auto join_op = reinterpret_cast(op); join_conds = &(join_op->join_predicates); left_keys = &(join_op->left_keys); right_keys = &(join_op->right_keys); - } else if (op->GetType() == OpType::InnerNLJoin) { - auto join_op = reinterpret_cast(op); + } else if (op->GetType() == OpType::NLJoin) { + auto join_op = reinterpret_cast(op); join_conds = &(join_op->join_predicates); left_keys = &(join_op->left_keys); right_keys = &(join_op->right_keys); diff --git a/src/optimizer/operators.cpp b/src/optimizer/operators.cpp index 78c34d16257..232dbb51349 100644 --- a/src/optimizer/operators.cpp +++ b/src/optimizer/operators.cpp @@ -51,7 +51,7 @@ hash_t LogicalGet::Hash() const { } bool LogicalGet::operator==(const BaseOperatorNode &r) { - if (r.GetType()!= OpType::Get) return false; + if (r.GetType() != OpType::Get) return false; const LogicalGet &node = *static_cast(&r); if (predicates.size() != node.predicates.size()) return false; for (size_t i = 0; i < predicates.size(); i++) { @@ -231,30 +231,34 @@ bool LogicalSingleJoin::operator==(const BaseOperatorNode &r) { } //===--------------------------------------------------------------------===// -// InnerJoin +// Join (Inner + Outer Joins) //===--------------------------------------------------------------------===// -Operator LogicalInnerJoin::make() { - LogicalInnerJoin *join = new LogicalInnerJoin; +Operator LogicalJoin::make(JoinType type) { + LogicalJoin *join = new LogicalJoin; join->join_predicates = {}; + join->type = type; return Operator(join); } -Operator LogicalInnerJoin::make(std::vector &conditions) { - LogicalInnerJoin *join = new LogicalInnerJoin; +Operator LogicalJoin::make(JoinType type, + std::vector &conditions) { + LogicalJoin *join = new LogicalJoin; join->join_predicates = 
std::move(conditions); + join->type = type; return Operator(join); } -hash_t LogicalInnerJoin::Hash() const { +hash_t LogicalJoin::Hash() const { hash_t hash = BaseOperatorNode::Hash(); for (auto &pred : join_predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); return hash; } -bool LogicalInnerJoin::operator==(const BaseOperatorNode &r) { - if (r.GetType() != OpType::InnerJoin) return false; - const LogicalInnerJoin &node = *static_cast(&r); +bool LogicalJoin::operator==(const BaseOperatorNode &r) { + if (r.GetType() != OpType::LogicalJoin) return false; + + const LogicalJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size()) return false; for (size_t i = 0; i < join_predicates.size(); i++) { if (!join_predicates[i].expr->ExactlyEquals( @@ -265,37 +269,7 @@ bool LogicalInnerJoin::operator==(const BaseOperatorNode &r) { } //===--------------------------------------------------------------------===// -// LeftJoin -//===--------------------------------------------------------------------===// -Operator LogicalLeftJoin::make(expression::AbstractExpression *condition) { - LogicalLeftJoin *join = new LogicalLeftJoin; - join->join_predicate = - std::shared_ptr(condition); - return Operator(join); -} - -//===--------------------------------------------------------------------===// -// RightJoin -//===--------------------------------------------------------------------===// -Operator LogicalRightJoin::make(expression::AbstractExpression *condition) { - LogicalRightJoin *join = new LogicalRightJoin; - join->join_predicate = - std::shared_ptr(condition); - return Operator(join); -} - -//===--------------------------------------------------------------------===// -// OuterJoin -//===--------------------------------------------------------------------===// -Operator LogicalOuterJoin::make(expression::AbstractExpression *condition) { - LogicalOuterJoin *join = new LogicalOuterJoin; - join->join_predicate = - 
std::shared_ptr(condition); - return Operator(join); -} - -//===--------------------------------------------------------------------===// -// OuterJoin +// SemiJoin //===--------------------------------------------------------------------===// Operator LogicalSemiJoin::make(expression::AbstractExpression *condition) { LogicalSemiJoin *join = new LogicalSemiJoin; @@ -385,8 +359,8 @@ Operator LogicalDelete::make( //===--------------------------------------------------------------------===// Operator LogicalUpdate::make( std::shared_ptr target_table, - const std::vector> - *updates) { + const std::vector> * + updates) { LogicalUpdate *update_op = new LogicalUpdate; update_op->target_table = target_table; update_op->updates = updates; @@ -554,39 +528,41 @@ Operator PhysicalLimit::make(int64_t offset, int64_t limit) { } //===--------------------------------------------------------------------===// -// InnerNLJoin +// NLJoin (Inner + Outer Joins) //===--------------------------------------------------------------------===// -Operator PhysicalInnerNLJoin::make( - std::vector conditions, +Operator PhysicalNLJoin::make( + JoinType type, std::vector conditions, std::vector> &left_keys, std::vector> &right_keys) { - PhysicalInnerNLJoin *join = new PhysicalInnerNLJoin(); + PhysicalNLJoin *join = new PhysicalNLJoin(); join->join_predicates = std::move(conditions); join->left_keys = std::move(left_keys); join->right_keys = std::move(right_keys); + join->type = type; return Operator(join); } -hash_t PhysicalInnerNLJoin::Hash() const { +hash_t PhysicalNLJoin::Hash() const { hash_t hash = BaseOperatorNode::Hash(); - for (auto &expr : left_keys) + for (const auto &expr : left_keys) hash = HashUtil::CombineHashes(hash, expr->Hash()); - for (auto &expr : right_keys) + for (const auto &expr : right_keys) hash = HashUtil::CombineHashes(hash, expr->Hash()); - for (auto &pred : join_predicates) + for (const auto &pred : join_predicates) hash = HashUtil::CombineHashes(hash, 
pred.expr->Hash()); return hash; } -bool PhysicalInnerNLJoin::operator==(const BaseOperatorNode &r) { - if (r.GetType() != OpType::InnerNLJoin) return false; - const PhysicalInnerNLJoin &node = - *static_cast(&r); +bool PhysicalNLJoin::operator==(const BaseOperatorNode &r) { + if (r.GetType() != OpType::NLJoin) return false; + + const PhysicalNLJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size() || left_keys.size() != node.left_keys.size() || right_keys.size() != node.right_keys.size()) return false; + for (size_t i = 0; i < left_keys.size(); i++) { if (!left_keys[i]->ExactlyEquals(*node.left_keys[i].get())) return false; } @@ -602,64 +578,35 @@ bool PhysicalInnerNLJoin::operator==(const BaseOperatorNode &r) { } //===--------------------------------------------------------------------===// -// LeftNLJoin -//===--------------------------------------------------------------------===// -Operator PhysicalLeftNLJoin::make( - std::shared_ptr join_predicate) { - PhysicalLeftNLJoin *join = new PhysicalLeftNLJoin(); - join->join_predicate = join_predicate; - return Operator(join); -} - +// HashJoin //===--------------------------------------------------------------------===// -// RightNLJoin -//===--------------------------------------------------------------------===// -Operator PhysicalRightNLJoin::make( - std::shared_ptr join_predicate) { - PhysicalRightNLJoin *join = new PhysicalRightNLJoin(); - join->join_predicate = join_predicate; - return Operator(join); -} - -//===--------------------------------------------------------------------===// -// OuterNLJoin -//===--------------------------------------------------------------------===// -Operator PhysicalOuterNLJoin::make( - std::shared_ptr join_predicate) { - PhysicalOuterNLJoin *join = new PhysicalOuterNLJoin(); - join->join_predicate = join_predicate; - return Operator(join); -} - -//===--------------------------------------------------------------------===// -// InnerHashJoin 
-//===--------------------------------------------------------------------===// -Operator PhysicalInnerHashJoin::make( - std::vector conditions, +Operator PhysicalHashJoin::make( + JoinType type, std::vector conditions, std::vector> &left_keys, std::vector> &right_keys) { - PhysicalInnerHashJoin *join = new PhysicalInnerHashJoin(); + PhysicalHashJoin *join = new PhysicalHashJoin(); join->join_predicates = std::move(conditions); join->left_keys = std::move(left_keys); join->right_keys = std::move(right_keys); + join->type = type; return Operator(join); } -hash_t PhysicalInnerHashJoin::Hash() const { +hash_t PhysicalHashJoin::Hash() const { hash_t hash = BaseOperatorNode::Hash(); - for (auto &expr : left_keys) + for (const auto &expr : left_keys) hash = HashUtil::CombineHashes(hash, expr->Hash()); - for (auto &expr : right_keys) + for (const auto &expr : right_keys) hash = HashUtil::CombineHashes(hash, expr->Hash()); - for (auto &pred : join_predicates) + for (const auto &pred : join_predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); return hash; } -bool PhysicalInnerHashJoin::operator==(const BaseOperatorNode &r) { - if (r.GetType() != OpType::InnerHashJoin) return false; - const PhysicalInnerHashJoin &node = - *static_cast(&r); +bool PhysicalHashJoin::operator==(const BaseOperatorNode &r) { + if (r.GetType() != OpType::HashJoin) return false; + + const PhysicalHashJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size() || left_keys.size() != node.left_keys.size() || right_keys.size() != node.right_keys.size()) @@ -678,36 +625,6 @@ bool PhysicalInnerHashJoin::operator==(const BaseOperatorNode &r) { return true; } -//===--------------------------------------------------------------------===// -// LeftHashJoin -//===--------------------------------------------------------------------===// -Operator PhysicalLeftHashJoin::make( - std::shared_ptr join_predicate) { - PhysicalLeftHashJoin *join = new 
PhysicalLeftHashJoin(); - join->join_predicate = join_predicate; - return Operator(join); -} - -//===--------------------------------------------------------------------===// -// RightHashJoin -//===--------------------------------------------------------------------===// -Operator PhysicalRightHashJoin::make( - std::shared_ptr join_predicate) { - PhysicalRightHashJoin *join = new PhysicalRightHashJoin(); - join->join_predicate = join_predicate; - return Operator(join); -} - -//===--------------------------------------------------------------------===// -// OuterHashJoin -//===--------------------------------------------------------------------===// -Operator PhysicalOuterHashJoin::make( - std::shared_ptr join_predicate) { - PhysicalOuterHashJoin *join = new PhysicalOuterHashJoin(); - join->join_predicate = join_predicate; - return Operator(join); -} - //===--------------------------------------------------------------------===// // PhysicalInsert //===--------------------------------------------------------------------===// @@ -748,8 +665,8 @@ Operator PhysicalDelete::make( //===--------------------------------------------------------------------===// Operator PhysicalUpdate::make( std::shared_ptr target_table, - const std::vector> - *updates) { + const std::vector> * + updates) { PhysicalUpdate *update = new PhysicalUpdate; update->target_table = target_table; update->updates = updates; @@ -859,13 +776,7 @@ std::string OperatorNode::name_ = "LogicalSingleJoin"; template <> std::string OperatorNode::name_ = "LogicalDependentJoin"; template <> -std::string OperatorNode::name_ = "LogicalInnerJoin"; -template <> -std::string OperatorNode::name_ = "LogicalLeftJoin"; -template <> -std::string OperatorNode::name_ = "LogicalRightJoin"; -template <> -std::string OperatorNode::name_ = "LogicalOuterJoin"; +std::string OperatorNode::name_ = "LogicalJoin"; template <> std::string OperatorNode::name_ = "LogicalSemiJoin"; template <> @@ -896,24 +807,9 @@ std::string 
OperatorNode::name_ = "PhysicalOrderBy"; template <> std::string OperatorNode::name_ = "PhysicalLimit"; template <> -std::string OperatorNode::name_ = "PhysicalInnerNLJoin"; -template <> -std::string OperatorNode::name_ = "PhysicalLeftNLJoin"; -template <> -std::string OperatorNode::name_ = "PhysicalRightNLJoin"; -template <> -std::string OperatorNode::name_ = "PhysicalOuterNLJoin"; -template <> -std::string OperatorNode::name_ = - "PhysicalInnerHashJoin"; -template <> -std::string OperatorNode::name_ = "PhysicalLeftHashJoin"; -template <> -std::string OperatorNode::name_ = - "PhysicalRightHashJoin"; +std::string OperatorNode::name_ = "PhysicalNLJoin"; template <> -std::string OperatorNode::name_ = - "PhysicalOuterHashJoin"; +std::string OperatorNode::name_ = "PhysicalHashJoin"; template <> std::string OperatorNode::name_ = "PhysicalInsert"; template <> @@ -950,13 +846,7 @@ OpType OperatorNode::type_ = OpType::LogicalSingleJoin; template <> OpType OperatorNode::type_ = OpType::LogicalDependentJoin; template <> -OpType OperatorNode::type_ = OpType::InnerJoin; -template <> -OpType OperatorNode::type_ = OpType::LeftJoin; -template <> -OpType OperatorNode::type_ = OpType::RightJoin; -template <> -OpType OperatorNode::type_ = OpType::OuterJoin; +OpType OperatorNode::type_ = OpType::LogicalJoin; template <> OpType OperatorNode::type_ = OpType::SemiJoin; template <> @@ -989,21 +879,9 @@ OpType OperatorNode::type_ = OpType::Distinct; template <> OpType OperatorNode::type_ = OpType::PhysicalLimit; template <> -OpType OperatorNode::type_ = OpType::InnerNLJoin; -template <> -OpType OperatorNode::type_ = OpType::LeftNLJoin; -template <> -OpType OperatorNode::type_ = OpType::RightNLJoin; -template <> -OpType OperatorNode::type_ = OpType::OuterNLJoin; -template <> -OpType OperatorNode::type_ = OpType::InnerHashJoin; -template <> -OpType OperatorNode::type_ = OpType::LeftHashJoin; -template <> -OpType OperatorNode::type_ = OpType::RightHashJoin; +OpType OperatorNode::type_ = 
OpType::NLJoin; template <> -OpType OperatorNode::type_ = OpType::OuterHashJoin; +OpType OperatorNode::type_ = OpType::HashJoin; template <> OpType OperatorNode::type_ = OpType::Insert; template <> diff --git a/src/optimizer/optimizer_task.cpp b/src/optimizer/optimizer_task.cpp index f0a489906ae..8654f92982c 100644 --- a/src/optimizer/optimizer_task.cpp +++ b/src/optimizer/optimizer_task.cpp @@ -423,6 +423,11 @@ void TopDownRewrite::execute() { r.rule->GetMatchPattern()); if (iterator.HasNext()) { auto before = iterator.Next(); + + if (!r.rule->Check(before, context_.get())) { + continue; + } + PELOTON_ASSERT(!iterator.HasNext()); std::vector> after; r.rule->Transform(before, after, context_.get()); diff --git a/src/optimizer/plan_generator.cpp b/src/optimizer/plan_generator.cpp index a16b70c3878..a5715280ae2 100644 --- a/src/optimizer/plan_generator.cpp +++ b/src/optimizer/plan_generator.cpp @@ -198,7 +198,7 @@ void PlanGenerator::Visit(const PhysicalDistinct *) { output_plan_ = move(hash_plan); } -void PlanGenerator::Visit(const PhysicalInnerNLJoin *op) { +void PlanGenerator::Visit(const PhysicalNLJoin *op) { std::unique_ptr proj_info; std::shared_ptr proj_schema; GenerateProjectionForJoin(proj_info, proj_schema); @@ -214,18 +214,18 @@ void PlanGenerator::Visit(const PhysicalInnerNLJoin *op) { vector right_keys; for (auto &expr : op->left_keys) { PELOTON_ASSERT(children_expr_map_[0].find(expr.get()) != - children_expr_map_[0].end()); + children_expr_map_[0].end()); left_keys.push_back(children_expr_map_[0][expr.get()]); } for (auto &expr : op->right_keys) { PELOTON_ASSERT(children_expr_map_[1].find(expr.get()) != - children_expr_map_[1].end()); + children_expr_map_[1].end()); right_keys.emplace_back(children_expr_map_[1][expr.get()]); } - auto join_plan = + unique_ptr join_plan = unique_ptr(new planner::NestedLoopJoinPlan( - JoinType::INNER, move(join_predicate), move(proj_info), proj_schema, + op->type, move(join_predicate), move(proj_info), proj_schema, 
left_keys, right_keys)); join_plan->AddChild(move(children_plans_[0])); @@ -233,13 +233,7 @@ void PlanGenerator::Visit(const PhysicalInnerNLJoin *op) { output_plan_ = move(join_plan); } -void PlanGenerator::Visit(const PhysicalLeftNLJoin *) {} - -void PlanGenerator::Visit(const PhysicalRightNLJoin *) {} - -void PlanGenerator::Visit(const PhysicalOuterNLJoin *) {} - -void PlanGenerator::Visit(const PhysicalInnerHashJoin *op) { +void PlanGenerator::Visit(const PhysicalHashJoin *op) { std::unique_ptr proj_info; std::shared_ptr proj_schema; GenerateProjectionForJoin(proj_info, proj_schema); @@ -267,7 +261,7 @@ void PlanGenerator::Visit(const PhysicalInnerHashJoin *op) { } // Evaluate Expr for hash plan vector> hash_keys; - for (auto &expr : op->right_keys) { + for (const auto &expr : op->right_keys) { auto hash_key = expr->Copy(); expression::ExpressionUtil::EvaluateExpression(r_child_map, hash_key); hash_keys.emplace_back(hash_key); @@ -277,7 +271,7 @@ void PlanGenerator::Visit(const PhysicalInnerHashJoin *op) { hash_plan->AddChild(move(children_plans_[1])); auto join_plan = unique_ptr(new planner::HashJoinPlan( - JoinType::INNER, move(join_predicate), move(proj_info), proj_schema, + op->type, move(join_predicate), move(proj_info), proj_schema, left_keys, right_keys, settings::SettingsManager::GetBool( settings::SettingId::hash_join_bloom_filter))); @@ -286,12 +280,6 @@ void PlanGenerator::Visit(const PhysicalInnerHashJoin *op) { output_plan_ = move(join_plan); } -void PlanGenerator::Visit(const PhysicalLeftHashJoin *) {} - -void PlanGenerator::Visit(const PhysicalRightHashJoin *) {} - -void PlanGenerator::Visit(const PhysicalOuterHashJoin *) {} - void PlanGenerator::Visit(const PhysicalInsert *op) { unique_ptr insert_plan(new planner::InsertPlan( storage::StorageManager::GetInstance()->GetTableWithOid( diff --git a/src/optimizer/query_to_operator_transformer.cpp b/src/optimizer/query_to_operator_transformer.cpp index ff75140d5f5..d71cd0cf4a0 100644 --- 
a/src/optimizer/query_to_operator_transformer.cpp +++ b/src/optimizer/query_to_operator_transformer.cpp @@ -120,6 +120,8 @@ void QueryToOperatorTransformer::Visit(parser::SelectStatement *op) { predicates_ = std::move(pre_predicates); } void QueryToOperatorTransformer::Visit(parser::JoinDefinition *node) { + auto pre_predicates = std::move(predicates_); + // Get left operator node->left->Accept(this); auto left_expr = output_expr_; @@ -133,23 +135,26 @@ void QueryToOperatorTransformer::Visit(parser::JoinDefinition *node) { switch (node->type) { case JoinType::INNER: { predicates_ = CollectPredicates(node->condition.get(), predicates_); - join_expr = - std::make_shared(LogicalInnerJoin::make()); + join_expr = std::make_shared( + LogicalJoin::make(JoinType::INNER, predicates_)); break; } case JoinType::OUTER: { + predicates_ = CollectPredicates(node->condition.get(), predicates_); join_expr = std::make_shared( - LogicalOuterJoin::make(node->condition->Copy())); + LogicalJoin::make(JoinType::OUTER, predicates_)); break; } case JoinType::LEFT: { + predicates_ = CollectPredicates(node->condition.get(), predicates_); join_expr = std::make_shared( - LogicalLeftJoin::make(node->condition->Copy())); + LogicalJoin::make(JoinType::LEFT, predicates_)); break; } case JoinType::RIGHT: { + predicates_ = CollectPredicates(node->condition.get(), predicates_); join_expr = std::make_shared( - LogicalRightJoin::make(node->condition->Copy())); + LogicalJoin::make(JoinType::RIGHT, predicates_)); break; } case JoinType::SEMI: { @@ -165,6 +170,7 @@ void QueryToOperatorTransformer::Visit(parser::JoinDefinition *node) { join_expr->PushChild(right_expr); output_expr_ = join_expr; + predicates_ = std::move(pre_predicates); } void QueryToOperatorTransformer::Visit(parser::TableRef *node) { if (node->select != nullptr) { @@ -201,8 +207,8 @@ void QueryToOperatorTransformer::Visit(parser::TableRef *node) { // Build a left deep join tree for (size_t i = 1; i < node->list.size(); i++) { 
node->list.at(i)->Accept(this); - auto join_expr = - std::make_shared(LogicalInnerJoin::make()); + auto join_expr = std::make_shared( + LogicalJoin::make(JoinType::INNER)); join_expr->PushChild(prev_expr); join_expr->PushChild(output_expr_); PELOTON_ASSERT(join_expr->Children().size() == 2); diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index 1e81799147d..1cce07ee125 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -28,8 +28,8 @@ int Rule::Promise(GroupExpression *group_expr, OptimizeContext *context) const { } RuleSet::RuleSet() { - AddTransformationRule(new InnerJoinCommutativity()); - AddTransformationRule(new InnerJoinAssociativity()); + AddTransformationRule(new JoinCommutativity()); + AddTransformationRule(new JoinAssociativity()); AddImplementationRule(new LogicalDeleteToPhysical()); AddImplementationRule(new LogicalUpdateToPhysical()); AddImplementationRule(new LogicalInsertToPhysical()); @@ -40,13 +40,13 @@ RuleSet::RuleSet() { AddImplementationRule(new GetToSeqScan()); AddImplementationRule(new GetToIndexScan()); AddImplementationRule(new LogicalQueryDerivedGetToPhysical()); - AddImplementationRule(new InnerJoinToInnerNLJoin()); - AddImplementationRule(new InnerJoinToInnerHashJoin()); + AddImplementationRule(new JoinToNLJoin()); + AddImplementationRule(new JoinToHashJoin()); AddImplementationRule(new ImplementDistinct()); AddImplementationRule(new ImplementLimit()); AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN, - new PushFilterThroughJoin()); + new PushFilterThroughJoin()); AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN, new PushFilterThroughAggregation()); AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN, diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index e540555c9e3..3e06eeb2501 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -30,38 +30,46 @@ namespace optimizer { 
//===--------------------------------------------------------------------===// /////////////////////////////////////////////////////////////////////////////// -/// InnerJoinCommutativity -InnerJoinCommutativity::InnerJoinCommutativity() { - type_ = RuleType::INNER_JOIN_COMMUTE; +/// JoinCommutativity +JoinCommutativity::JoinCommutativity() { + type_ = RuleType::JOIN_COMMUTE; std::shared_ptr left_child(std::make_shared(OpType::Leaf)); std::shared_ptr right_child(std::make_shared(OpType::Leaf)); - match_pattern = std::make_shared(OpType::InnerJoin); + match_pattern = std::make_shared(OpType::LogicalJoin); match_pattern->AddChild(left_child); match_pattern->AddChild(right_child); } -bool InnerJoinCommutativity::Check(std::shared_ptr expr, - OptimizeContext *context) const { +bool JoinCommutativity::Check(std::shared_ptr expr, + OptimizeContext *context) const { (void)context; (void)expr; return true; } -void InnerJoinCommutativity::Transform( +void JoinCommutativity::Transform( std::shared_ptr input, std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - auto join_op = input->Op().As(); + auto join_op = input->Op().As(); auto join_predicates = std::vector(join_op->join_predicates); + + auto join_type = join_op->type; + if (join_type == JoinType::LEFT) { + join_type = JoinType::RIGHT; + } else if (join_type == JoinType::RIGHT) { + join_type = JoinType::LEFT; + } + auto result_plan = std::make_shared( - LogicalInnerJoin::make(join_predicates)); + LogicalJoin::make(join_type, join_predicates)); std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 2); - LOG_TRACE( - "Reorder left child with op %s and right child with op %s for inner join", - children[0]->Op().GetName().c_str(), children[1]->Op().GetName().c_str()); + LOG_TRACE("Reorder left child with op %s and right child with op %s for join", + children[0]->Op().GetName().c_str(), + children[1]->Op().GetName().c_str()); result_plan->PushChild(children[1]); 
result_plan->PushChild(children[0]); @@ -69,42 +77,100 @@ void InnerJoinCommutativity::Transform( } /////////////////////////////////////////////////////////////////////////////// -/// InnerJoinAssociativity -InnerJoinAssociativity::InnerJoinAssociativity() { - type_ = RuleType::INNER_JOIN_ASSOCIATE; +/// JoinAssociativity +JoinAssociativity::JoinAssociativity() { + type_ = RuleType::JOIN_ASSOCIATE; // Create left nested join - auto left_child = std::make_shared(OpType::InnerJoin); + auto left_child = std::make_shared(OpType::LogicalJoin); left_child->AddChild(std::make_shared(OpType::Leaf)); left_child->AddChild(std::make_shared(OpType::Leaf)); std::shared_ptr right_child(std::make_shared(OpType::Leaf)); - match_pattern = std::make_shared(OpType::InnerJoin); + match_pattern = std::make_shared(OpType::LogicalJoin); match_pattern->AddChild(left_child); match_pattern->AddChild(right_child); } -// TODO: As far as I know, theres nothing else that needs to be checked -bool InnerJoinAssociativity::Check(std::shared_ptr expr, - OptimizeContext *context) const { +bool JoinAssociativity::Check(std::shared_ptr expr, + OptimizeContext *context) const { (void)context; - (void)expr; - return true; + // Associativity rules taken from + // http://15721.courses.cs.cmu.edu/spring2017/papers/14-optimizer1/p539-moerkotte.pdf + /* + * (A inner B) inner C = A inner (B inner C) + * (A right B) inner C = A right (B inner C) + * (A inner B) left C = A inner (B left C) + * (A left B) left C = A left (B left C) with Strong Predicates PBC + * (A full B) left C = A full (B left C) with Strong Predicate PBC + * (A right B) right C = A right (B right C) with Strong Predicate PAB + * (A right B) full C = A right (B full C) with Strong Predicate PAB + * (A full B) full C = A full (B full C) with Strong Predicate PAB & PBC + */ + auto parent_join = expr->Op().As(); + std::vector> children = expr->Children(); + auto child_join = children[0]->Op().As(); + auto middle = children[0]->Children()[1]; + 
// Get Alias sets + auto &memo = context->metadata->memo; + auto middle_group_id = middle->Op().As()->origin_group; + const auto &middle_group_aliases_set = + memo.GetGroupByID(middle_group_id)->GetTableAliases(); + + if (parent_join->type == JoinType::INNER) { + if (child_join->type == JoinType::INNER || + child_join->type == JoinType::RIGHT) { + return true; + } + } else if (parent_join->type == JoinType::LEFT) { + if (child_join->type == JoinType::INNER) { + return true; + } else if (child_join->type == JoinType::LEFT || + child_join->type == JoinType::OUTER) { + return util::StrongPredicates(parent_join->join_predicates, + middle_group_aliases_set); + } + } else if (parent_join->type == JoinType::RIGHT) { + if (child_join->type == JoinType::RIGHT) { + return util::StrongPredicates(child_join->join_predicates, + middle_group_aliases_set); + } + } else if (parent_join->type == JoinType::OUTER) { + if (child_join->type == JoinType::RIGHT) { + return util::StrongPredicates(child_join->join_predicates, + middle_group_aliases_set); + } else if (child_join->type == JoinType::OUTER) { + auto parent_join_predicates = + std::vector(parent_join->join_predicates); + auto child_join_predicates = + std::vector(child_join->join_predicates); + + std::vector check_predicates; + check_predicates.insert(check_predicates.end(), + parent_join_predicates.begin(), + parent_join_predicates.end()); + check_predicates.insert(check_predicates.end(), + child_join_predicates.begin(), + child_join_predicates.end()); + return util::StrongPredicates(check_predicates, middle_group_aliases_set); + } + } + return false; } -void InnerJoinAssociativity::Transform( +void JoinAssociativity::Transform( std::shared_ptr input, std::vector> &transformed, OptimizeContext *context) const { // NOTE: Transforms (left JOIN middle) JOIN right -> left JOIN (middle JOIN // right) Variables are named accordingly to above transformation - auto parent_join = input->Op().As(); + auto parent_join = input->Op().As(); 
std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 2); - PELOTON_ASSERT(children[0]->Op().GetType() == OpType::InnerJoin); + PELOTON_ASSERT(children[0]->Op().GetType() == OpType::LogicalJoin); PELOTON_ASSERT(children[0]->Children().size() == 2); - auto child_join = children[0]->Op().As(); + auto child_join = children[0]->Op().As(); auto left = children[0]->Children()[0]; auto middle = children[0]->Children()[1]; auto right = children[1]; @@ -153,17 +219,21 @@ void InnerJoinAssociativity::Transform( } } + JoinType new_parent_join_type; + JoinType new_child_join_type; + new_parent_join_type = child_join->type; + new_child_join_type = parent_join->type; // Construct new child join operator std::shared_ptr new_child_join = std::make_shared( - LogicalInnerJoin::make(new_child_join_predicates)); + LogicalJoin::make(new_child_join_type, new_child_join_predicates)); new_child_join->PushChild(middle); new_child_join->PushChild(right); // Construct new parent join operator std::shared_ptr new_parent_join = std::make_shared( - LogicalInnerJoin::make(new_parent_join_predicates)); + LogicalJoin::make(new_parent_join_type, new_parent_join_predicates)); new_parent_join->PushChild(left); new_parent_join->PushChild(new_child_join); @@ -275,9 +345,8 @@ void GetToIndexScan::Transform( sort_by_asc_base_column = false; break; } - auto bound_oids = - reinterpret_cast(expr) - ->GetBoundOid(); + auto bound_oids = reinterpret_cast( + expr)->GetBoundOid(); sort_col_ids.push_back(std::get<2>(bound_oids)); } // Check whether any index can fulfill sort property @@ -358,20 +427,16 @@ void GetToIndexScan::Transform( if (value_expr->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { value_list.push_back( reinterpret_cast( - value_expr) - ->GetValue()); + value_expr)->GetValue()); LOG_TRACE("Value Type: %d", static_cast( reinterpret_cast( - expr->GetModifiableChild(1)) - ->GetValueType())); + expr->GetModifiableChild(1))->GetValueType())); } else { 
value_list.push_back( type::ValueFactory::GetParameterOffsetValue( reinterpret_cast( - value_expr) - ->GetValueIdx()) - .Copy()); + value_expr)->GetValueIdx()).Copy()); LOG_TRACE("Parameter offset: %s", (*value_list.rbegin()).GetInfo().c_str()); } @@ -612,16 +677,16 @@ void LogicalAggregateToPhysical::Transform( } /////////////////////////////////////////////////////////////////////////////// -/// InnerJoinToInnerNLJoin -InnerJoinToInnerNLJoin::InnerJoinToInnerNLJoin() { - type_ = RuleType::INNER_JOIN_TO_NL_JOIN; +/// JoinToNLJoin +JoinToNLJoin::JoinToNLJoin() { + type_ = RuleType::JOIN_TO_NL_JOIN; // TODO NLJoin currently only support left deep tree std::shared_ptr left_child(std::make_shared(OpType::Leaf)); std::shared_ptr right_child(std::make_shared(OpType::Leaf)); // Initialize a pattern for optimizer to match - match_pattern = std::make_shared(OpType::InnerJoin); + match_pattern = std::make_shared(OpType::LogicalJoin); // Add node - we match join relation R and S match_pattern->AddChild(left_child); @@ -630,19 +695,19 @@ InnerJoinToInnerNLJoin::InnerJoinToInnerNLJoin() { return; } -bool InnerJoinToInnerNLJoin::Check(std::shared_ptr plan, - OptimizeContext *context) const { +bool JoinToNLJoin::Check(std::shared_ptr plan, + OptimizeContext *context) const { (void)context; (void)plan; return true; } -void InnerJoinToInnerNLJoin::Transform( +void JoinToNLJoin::Transform( std::shared_ptr input, std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - // first build an expression representing hash join - const LogicalInnerJoin *inner_join = input->Op().As(); + // first build an expression representing nested loop join + const LogicalJoin *join = input->Op().As(); auto children = input->Children(); PELOTON_ASSERT(children.size() == 2); @@ -655,13 +720,14 @@ void InnerJoinToInnerNLJoin::Transform( std::vector> left_keys; std::vector> right_keys; - util::ExtractEquiJoinKeys(inner_join->join_predicates, left_keys, right_keys, + 
util::ExtractEquiJoinKeys(join->join_predicates, left_keys, right_keys, left_group_alias, right_group_alias); PELOTON_ASSERT(right_keys.size() == left_keys.size()); - auto result_plan = - std::make_shared(PhysicalInnerNLJoin::make( - inner_join->join_predicates, left_keys, right_keys)); + std::shared_ptr result_plan; + + result_plan = std::make_shared(PhysicalNLJoin::make( + join->type, join->join_predicates, left_keys, right_keys)); // Then push all children into the child list of the new operator result_plan->PushChild(children[0]); @@ -673,16 +739,16 @@ void InnerJoinToInnerNLJoin::Transform( } /////////////////////////////////////////////////////////////////////////////// -/// InnerJoinToInnerHashJoin -InnerJoinToInnerHashJoin::InnerJoinToInnerHashJoin() { - type_ = RuleType::INNER_JOIN_TO_HASH_JOIN; +/// JoinToInnerHashJoin +JoinToHashJoin::JoinToHashJoin() { + type_ = RuleType::JOIN_TO_HASH_JOIN; // Make three node types for pattern matching std::shared_ptr left_child(std::make_shared(OpType::Leaf)); std::shared_ptr right_child(std::make_shared(OpType::Leaf)); // Initialize a pattern for optimizer to match - match_pattern = std::make_shared(OpType::InnerJoin); + match_pattern = std::make_shared(OpType::LogicalJoin); // Add node - we match join relation R and S as well as the predicate exp match_pattern->AddChild(left_child); @@ -691,19 +757,19 @@ InnerJoinToInnerHashJoin::InnerJoinToInnerHashJoin() { return; } -bool InnerJoinToInnerHashJoin::Check(std::shared_ptr plan, - OptimizeContext *context) const { +bool JoinToHashJoin::Check(std::shared_ptr plan, + OptimizeContext *context) const { (void)context; (void)plan; return true; } -void InnerJoinToInnerHashJoin::Transform( +void JoinToHashJoin::Transform( std::shared_ptr input, std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { // first build an expression representing hash join - const LogicalInnerJoin *inner_join = input->Op().As(); + const LogicalJoin *join = input->Op().As(); 
auto children = input->Children(); PELOTON_ASSERT(children.size() == 2); @@ -716,14 +782,14 @@ void InnerJoinToInnerHashJoin::Transform( std::vector> left_keys; std::vector> right_keys; - util::ExtractEquiJoinKeys(inner_join->join_predicates, left_keys, right_keys, + util::ExtractEquiJoinKeys(join->join_predicates, left_keys, right_keys, left_group_alias, right_group_alias); PELOTON_ASSERT(right_keys.size() == left_keys.size()); if (!left_keys.empty()) { auto result_plan = - std::make_shared(PhysicalInnerHashJoin::make( - inner_join->join_predicates, left_keys, right_keys)); + std::make_shared(PhysicalHashJoin::make( + join->type, join->join_predicates, left_keys, right_keys)); // Then push all children into the child list of the new operator result_plan->PushChild(children[0]); @@ -807,7 +873,8 @@ PushFilterThroughJoin::PushFilterThroughJoin() { type_ = RuleType::PUSH_FILTER_THROUGH_JOIN; // Make three node types for pattern matching - std::shared_ptr child(std::make_shared(OpType::InnerJoin)); + std::shared_ptr child( + std::make_shared(OpType::LogicalJoin)); child->AddChild(std::make_shared(OpType::Leaf)); child->AddChild(std::make_shared(OpType::Leaf)); @@ -843,6 +910,11 @@ void PushFilterThroughJoin::Transform( std::vector right_predicates; std::vector join_predicates; + auto join_type = join_op_expr->Op().As()->type; + bool outer_push = + (join_type == JoinType::OUTER || join_type == JoinType::LEFT || + join_type == JoinType::RIGHT); + // Loop over all predicates, check each of them if they can be pushed down to // either the left child or the right child to be evaluated // All predicates in this loop follow conjunction relationship because we @@ -850,10 +922,12 @@ void PushFilterThroughJoin::Transform( // E.g. 
An expression (test.a = test1.b and test.a = 5) would become // {test.a = test1.b, test.a = 5} for (auto &predicate : predicates) { - if (util::IsSubset(left_group_aliases_set, predicate.table_alias_set)) { + if (util::IsSubset(left_group_aliases_set, predicate.table_alias_set) && + !outer_push) { left_predicates.emplace_back(predicate); } else if (util::IsSubset(right_group_aliases_set, - predicate.table_alias_set)) { + predicate.table_alias_set) && + !outer_push) { right_predicates.emplace_back(predicate); } else { join_predicates.emplace_back(predicate); @@ -862,12 +936,12 @@ void PushFilterThroughJoin::Transform( // Construct join operator auto pre_join_predicate = - join_op_expr->Op().As()->join_predicates; + join_op_expr->Op().As()->join_predicates; join_predicates.insert(join_predicates.end(), pre_join_predicate.begin(), pre_join_predicate.end()); std::shared_ptr output = - std::make_shared( - LogicalInnerJoin::make(join_predicates)); + std::make_shared(LogicalJoin::make( + join_op_expr->Op().As()->type, join_predicates)); // Construct left filter if any if (!left_predicates.empty()) { @@ -1087,7 +1161,7 @@ void MarkJoinToInnerJoin::Transform( PELOTON_ASSERT(mark_join->join_predicates.empty()); std::shared_ptr output = - std::make_shared(LogicalInnerJoin::make()); + std::make_shared(LogicalJoin::make(JoinType::INNER)); output->PushChild(join_children[0]); output->PushChild(join_children[1]); @@ -1138,7 +1212,7 @@ void SingleJoinToInnerJoin::Transform( PELOTON_ASSERT(single_join->join_predicates.empty()); std::shared_ptr output = - std::make_shared(LogicalInnerJoin::make()); + std::make_shared(LogicalJoin::make(JoinType::INNER)); output->PushChild(join_children[0]); output->PushChild(join_children[1]); diff --git a/src/optimizer/stats_calculator.cpp b/src/optimizer/stats_calculator.cpp index 3cdb34c4d9d..e6ea7d3346e 100644 --- a/src/optimizer/stats_calculator.cpp +++ b/src/optimizer/stats_calculator.cpp @@ -42,8 +42,8 @@ void StatsCalculator::Visit(const 
LogicalGet *op) { return; } auto table_stats = std::dynamic_pointer_cast( - StatsStorage::GetInstance()->GetTableStats(op->table->GetDatabaseOid(), - op->table->GetTableOid(), txn_)); + StatsStorage::GetInstance()->GetTableStats( + op->table->GetDatabaseOid(), op->table->GetTableOid(), txn_)); // First, get the required stats of the base table std::unordered_map> required_stats; for (auto &col : required_cols_) { @@ -96,7 +96,7 @@ void StatsCalculator::Visit(const LogicalQueryDerivedGet *) { } } -void StatsCalculator::Visit(const LogicalInnerJoin *op) { +void StatsCalculator::Visit(const LogicalJoin *op) { // Check if there's join condition PELOTON_ASSERT(gexpr_->GetChildrenGroupsSize() == 2); auto left_child_group = memo_->GetGroupByID(gexpr_->GetChildGroupId(0)); @@ -143,7 +143,8 @@ void StatsCalculator::Visit(const LogicalInnerJoin *op) { column_stats = std::make_shared( *left_child_group->GetStats(tv_expr->GetColFullName())); } else { - PELOTON_ASSERT(right_child_group->HasColumnStats(tv_expr->GetColFullName())); + PELOTON_ASSERT( + right_child_group->HasColumnStats(tv_expr->GetColFullName())); column_stats = std::make_shared( *right_child_group->GetStats(tv_expr->GetColFullName())); } @@ -154,9 +155,7 @@ void StatsCalculator::Visit(const LogicalInnerJoin *op) { // TODO(boweic): calculate stats based on predicates other than join // conditions } -void StatsCalculator::Visit(UNUSED_ATTRIBUTE const LogicalLeftJoin *op) {} -void StatsCalculator::Visit(UNUSED_ATTRIBUTE const LogicalRightJoin *op) {} -void StatsCalculator::Visit(UNUSED_ATTRIBUTE const LogicalOuterJoin *op) {} + void StatsCalculator::Visit(UNUSED_ATTRIBUTE const LogicalSemiJoin *op) {} void StatsCalculator::Visit(const LogicalAggregateAndGroupBy *) { // TODO(boweic): For now we just pass the stats needed without any @@ -235,8 +234,8 @@ void StatsCalculator::AddBaseTableStats( void StatsCalculator::UpdateStatsForFilter( size_t num_rows, - std::unordered_map> - &predicate_stats, + std::unordered_map> & 
+ predicate_stats, const std::vector &predicates) { // First, construct the table stats as the interface needed it to compute // selectivity @@ -285,10 +284,10 @@ double StatsCalculator::CalculateSelectivityForPredicate( : 0; auto left_expr = expr->GetChild(1 - right_index); - PELOTON_ASSERT(left_expr->GetExpressionType() == ExpressionType::VALUE_TUPLE); - auto col_name = - reinterpret_cast(left_expr) - ->GetColFullName(); + PELOTON_ASSERT(left_expr->GetExpressionType() == + ExpressionType::VALUE_TUPLE); + auto col_name = reinterpret_cast( + left_expr)->GetColFullName(); auto expr_type = expr->GetExpressionType(); if (right_index == 0) { @@ -314,14 +313,12 @@ double StatsCalculator::CalculateSelectivityForPredicate( if (expr->GetChild(right_index)->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { value = reinterpret_cast( - expr->GetModifiableChild(right_index)) - ->GetValue(); + expr->GetModifiableChild(right_index))->GetValue(); } else { - value = type::ValueFactory::GetParameterOffsetValue( - reinterpret_cast( - expr->GetModifiableChild(right_index)) - ->GetValueIdx()) - .Copy(); + value = + type::ValueFactory::GetParameterOffsetValue( + reinterpret_cast( + expr->GetModifiableChild(right_index))->GetValueIdx()).Copy(); } ValueCondition condition(col_name, expr_type, value); selectivity = diff --git a/src/optimizer/util.cpp b/src/optimizer/util.cpp index 0d01e35e8ac..1396f17cbb0 100644 --- a/src/optimizer/util.cpp +++ b/src/optimizer/util.cpp @@ -18,6 +18,7 @@ #include "planner/copy_plan.h" #include "planner/seq_scan_plan.h" #include "storage/data_table.h" +#include "type/value_factory.h" namespace peloton { namespace optimizer { @@ -179,8 +180,7 @@ std::unordered_map> ConstructSelectElementMap( std::vector> &select_list) { std::unordered_map> - res; + std::shared_ptr> res; for (auto &expr : select_list) { std::string alias; if (!expr->alias.empty()) { @@ -199,8 +199,8 @@ ConstructSelectElementMap( expression::AbstractExpression 
*TransformQueryDerivedTablePredicates( const std::unordered_map> - &alias_to_expr_map, + std::shared_ptr> & + alias_to_expr_map, expression::AbstractExpression *expr) { if (expr->GetExpressionType() == ExpressionType::VALUE_TUPLE) { auto new_expr = @@ -250,6 +250,141 @@ void ExtractEquiJoinKeys( } } +bool StrongPredicates( + std::vector predicates, + const std::unordered_set &middle_group_aliases_set) { + for (auto predicate : predicates) { + // create a copy of original predicate. + auto copy_expr = + std::shared_ptr(predicate.expr->Copy()); + LOG_DEBUG("AnnotatedExp: %s", copy_expr->GetInfo().c_str()); + // replace tuple_value_expression from predicate which contains table in + // middle_group_aliases_set with FALSE constant value + + auto eval_expr = + PredicateEvaluate(copy_expr.get(), middle_group_aliases_set); + if (eval_expr != nullptr) { + auto eval_val = eval_expr->Evaluate(nullptr, nullptr, nullptr); + if (eval_val.IsFalse()) { + return true; + } + } + } + return false; +} + +expression::AbstractExpression *PredicateEvaluate( + expression::AbstractExpression *expr, + const std::unordered_set &middle_group_aliases_set) { + // if at the lowest level + if (expr->GetChildrenSize() == 0) { + if (expr->GetExpressionType() == ExpressionType::VALUE_TUPLE) { + auto tv_expr = dynamic_cast(expr); + if (middle_group_aliases_set.find(tv_expr->GetTableName()) != + middle_group_aliases_set.end()) { + return expression::ExpressionUtil::ConstantValueFactory( + type::ValueFactory::GetNullValueByType(type::TypeId::BOOLEAN)); + } + } + if (expr->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { + auto cv_expr = dynamic_cast(expr); + return expression::ExpressionUtil::ConstantValueFactory( + cv_expr->GetValue()); + } + // Returns nullptr if the tuple_value_expression uses table which does not + // belong to middle_group or other expression type we cannot evaluate + return nullptr; + } + + // Conjunction check + if (expr->GetExpressionType() == 
ExpressionType::CONJUNCTION_AND) { + PELOTON_ASSERT(expr->GetChildrenSize() == 2); + auto l_child = expr->GetModifiableChild(0); + auto r_child = expr->GetModifiableChild(1); + auto l_eval_expr = PredicateEvaluate(l_child, middle_group_aliases_set); + auto r_eval_expr = PredicateEvaluate(r_child, middle_group_aliases_set); + + type::Value l_val; + type::Value r_val; + + if (l_eval_expr != nullptr) { + PELOTON_ASSERT(l_eval_expr->GetExpressionType() == + ExpressionType::VALUE_CONSTANT); + l_val = l_eval_expr->Evaluate(nullptr, nullptr, nullptr); + if (l_val.IsFalse()) { + return expression::ExpressionUtil::ConstantValueFactory( + type::ValueFactory::GetBooleanValue(false)); + } + } + if (r_eval_expr != nullptr) { + PELOTON_ASSERT(r_eval_expr->GetExpressionType() == + ExpressionType::VALUE_CONSTANT); + r_val = r_eval_expr->Evaluate(nullptr, nullptr, nullptr); + if (r_val.IsFalse()) { + return expression::ExpressionUtil::ConstantValueFactory( + type::ValueFactory::GetBooleanValue(false)); + } + } + if (l_eval_expr != nullptr && r_eval_expr != nullptr && l_val.IsTrue() && + r_val.IsTrue()) { + return expression::ExpressionUtil::ConstantValueFactory( + type::ValueFactory::GetBooleanValue(true)); + } + return nullptr; + } else if (expr->GetExpressionType() == ExpressionType::CONJUNCTION_OR) { + PELOTON_ASSERT(expr->GetChildrenSize() == 2); + auto l_child = expr->GetModifiableChild(0); + auto r_child = expr->GetModifiableChild(1); + auto l_eval_expr = PredicateEvaluate(l_child, middle_group_aliases_set); + auto r_eval_expr = PredicateEvaluate(r_child, middle_group_aliases_set); + + type::Value l_val; + type::Value r_val; + + if (l_eval_expr != nullptr) { + PELOTON_ASSERT(l_eval_expr->GetExpressionType() == + ExpressionType::VALUE_CONSTANT); + l_val = l_eval_expr->Evaluate(nullptr, nullptr, nullptr); + if (l_val.IsTrue()) { + return expression::ExpressionUtil::ConstantValueFactory( + type::ValueFactory::GetBooleanValue(true)); + } + } + if (r_eval_expr != nullptr) { + 
PELOTON_ASSERT(r_eval_expr->GetExpressionType() == + ExpressionType::VALUE_CONSTANT); + r_val = r_eval_expr->Evaluate(nullptr, nullptr, nullptr); + if (r_val.IsTrue()) { + return expression::ExpressionUtil::ConstantValueFactory( + type::ValueFactory::GetBooleanValue(true)); + } + } + if (l_eval_expr != nullptr && r_eval_expr != nullptr && l_val.IsFalse() && + r_val.IsFalse()) { + return expression::ExpressionUtil::ConstantValueFactory( + type::ValueFactory::GetBooleanValue(false)); + } + return nullptr; + } else { + for (size_t i = 0; i < expr->GetChildrenSize(); i++) { + auto child_expr = expr->GetModifiableChild(i); + auto child_eval_expr = + PredicateEvaluate(child_expr, middle_group_aliases_set); + if (child_eval_expr == nullptr) { + // cannot evaluate + return nullptr; + } else { + PELOTON_ASSERT(child_eval_expr->GetExpressionType() == + ExpressionType::VALUE_CONSTANT); + expr->SetChild(i, child_eval_expr); + } + } + // all child are constant value expression + auto expr_val = expr->Evaluate(nullptr, nullptr, nullptr); + return expression::ExpressionUtil::ConstantValueFactory(expr_val); + } +} + } // namespace util } // namespace optimizer } // namespace peloton diff --git a/test/optimizer/optimizer_rule_test.cpp b/test/optimizer/optimizer_rule_test.cpp index 12d047ad51a..7aa26ecf813 100644 --- a/test/optimizer/optimizer_rule_test.cpp +++ b/test/optimizer/optimizer_rule_test.cpp @@ -36,6 +36,7 @@ #include "planner/update_plan.h" #include "sql/testing_sql_util.h" #include "type/value_factory.h" +#include "common/internal_types.h" namespace peloton { namespace test { @@ -49,15 +50,102 @@ using namespace optimizer; class OptimizerRuleTests : public PelotonTest {}; TEST_F(OptimizerRuleTests, SimpleCommutativeRuleTest) { + // Start Join Structure: left JOIN right + // End Join Structure: right JOIN left + + // Build op plan node to match rule + auto left_get = std::make_shared(LogicalGet::make()); + auto right_get = std::make_shared(LogicalGet::make()); + auto join 
= + std::make_shared(LogicalJoin::make(JoinType::INNER)); + join->PushChild(left_get); + join->PushChild(right_get); + + // Setup rule + JoinCommutativity rule; + + EXPECT_TRUE(rule.Check(join, nullptr)); + + std::vector> outputs; + rule.Transform(join, outputs, nullptr); + EXPECT_EQ(outputs.size(), 1); + + auto output_join = outputs[0]; + + EXPECT_EQ(output_join->Children()[0], right_get); + EXPECT_EQ(output_join->Children()[1], left_get); +} + +TEST_F(OptimizerRuleTests, LeftJoinCommutativeRuleTest) { + // Start Join Structure: left LEFT JOIN right + // End Join Structure: right RIGHT JOIN left + + // Build op plan node to match rule + auto left_get = std::make_shared(LogicalGet::make()); + auto right_get = std::make_shared(LogicalGet::make()); + auto join = + std::make_shared(LogicalJoin::make(JoinType::LEFT)); + join->PushChild(left_get); + join->PushChild(right_get); + + // Setup rule + JoinCommutativity rule; + + EXPECT_TRUE(rule.Check(join, nullptr)); + + std::vector> outputs; + rule.Transform(join, outputs, nullptr); + EXPECT_EQ(outputs.size(), 1); + + auto output_join = outputs[0]; + + EXPECT_EQ(output_join->Children()[0], right_get); + EXPECT_EQ(output_join->Children()[1], left_get); + EXPECT_EQ(output_join->Op().As()->type, JoinType::RIGHT); +} + +TEST_F(OptimizerRuleTests, RightJoinCommutativeRuleTest) { + // Start Join Structure: left RIGHT JOIN right + // End Join Structure: right LEFT JOIN left + + // Build op plan node to match rule + auto left_get = std::make_shared(LogicalGet::make()); + auto right_get = std::make_shared(LogicalGet::make()); + auto join = + std::make_shared(LogicalJoin::make(JoinType::RIGHT)); + join->PushChild(left_get); + join->PushChild(right_get); + + // Setup rule + JoinCommutativity rule; + + EXPECT_TRUE(rule.Check(join, nullptr)); + + std::vector> outputs; + rule.Transform(join, outputs, nullptr); + EXPECT_EQ(outputs.size(), 1); + + auto output_join = outputs[0]; + + EXPECT_EQ(output_join->Children()[0], right_get); + 
EXPECT_EQ(output_join->Children()[1], left_get); + EXPECT_EQ(output_join->Op().As()->type, JoinType::LEFT); +} + +TEST_F(OptimizerRuleTests, OuterJoinCommutativeRuleTest) { + // Start Join Structure: left OUTER JOIN right + // End Join Structure: right OUTER JOIN left + // Build op plan node to match rule auto left_get = std::make_shared(LogicalGet::make()); auto right_get = std::make_shared(LogicalGet::make()); - auto join = std::make_shared(LogicalInnerJoin::make()); + auto join = + std::make_shared(LogicalJoin::make(JoinType::OUTER)); join->PushChild(left_get); join->PushChild(right_get); // Setup rule - InnerJoinCommutativity rule; + JoinCommutativity rule; EXPECT_TRUE(rule.Check(join, nullptr)); @@ -69,6 +157,7 @@ TEST_F(OptimizerRuleTests, SimpleCommutativeRuleTest) { EXPECT_EQ(output_join->Children()[0], right_get); EXPECT_EQ(output_join->Children()[1], left_get); + EXPECT_EQ(output_join->Op().As()->type, JoinType::OUTER); } TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { @@ -113,7 +202,7 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { child_join_predicates.push_back(pred); auto child_join = std::make_shared( - LogicalInnerJoin::make(child_join_predicates)); + LogicalJoin::make(JoinType::INNER, child_join_predicates)); child_join->PushChild(left_leaf); child_join->PushChild(middle_leaf); optimizer.GetMetadata().memo.InsertExpression( @@ -126,7 +215,7 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { parent_join_predicates.push_back(pred); auto parent_join = std::make_shared( - LogicalInnerJoin::make(parent_join_predicates)); + LogicalJoin::make(JoinType::INNER, parent_join_predicates)); parent_join->PushChild(child_join); parent_join->PushChild(right_leaf); @@ -138,15 +227,14 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { EXPECT_EQ(left_leaf, parent_join->Children()[0]->Children()[0]); EXPECT_EQ(middle_leaf, parent_join->Children()[0]->Children()[1]); EXPECT_EQ(right_leaf, parent_join->Children()[1]); - EXPECT_EQ(1, - 
parent_join->Op().As()->join_predicates.size()); + EXPECT_EQ(1, parent_join->Op().As()->join_predicates.size()); EXPECT_EQ(1, parent_join->Children()[0] ->Op() - .As() + .As() ->join_predicates.size()); // Setup rule - InnerJoinAssociativity rule; + JoinAssociativity rule; EXPECT_TRUE(rule.Check(parent_join, root_context)); std::vector> outputs; @@ -159,8 +247,8 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { EXPECT_EQ(middle_leaf, output_join->Children()[1]->Children()[0]); EXPECT_EQ(right_leaf, output_join->Children()[1]->Children()[1]); - auto parent_join_op = output_join->Op().As(); - auto child_join_op = output_join->Children()[1]->Op().As(); + auto parent_join_op = output_join->Op().As(); + auto child_join_op = output_join->Children()[1]->Op().As(); EXPECT_EQ(2, parent_join_op->join_predicates.size()); EXPECT_EQ(0, child_join_op->join_predicates.size()); delete root_context; @@ -201,7 +289,7 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest2) { // Make Child Join auto child_join = - std::make_shared(LogicalInnerJoin::make()); + std::make_shared(LogicalJoin::make(JoinType::INNER)); child_join->PushChild(left_leaf); child_join->PushChild(middle_leaf); optimizer.GetMetadata().memo.InsertExpression( @@ -221,7 +309,7 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest2) { parent_join_predicates.push_back(pred2); auto parent_join = std::make_shared( - LogicalInnerJoin::make(parent_join_predicates)); + LogicalJoin::make(JoinType::INNER, parent_join_predicates)); parent_join->PushChild(child_join); parent_join->PushChild(right_leaf); @@ -233,15 +321,14 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest2) { EXPECT_EQ(left_leaf, parent_join->Children()[0]->Children()[0]); EXPECT_EQ(middle_leaf, parent_join->Children()[0]->Children()[1]); EXPECT_EQ(right_leaf, parent_join->Children()[1]); - EXPECT_EQ(2, - parent_join->Op().As()->join_predicates.size()); + EXPECT_EQ(2, parent_join->Op().As()->join_predicates.size()); EXPECT_EQ(0, 
parent_join->Children()[0] ->Op() - .As() + .As() ->join_predicates.size()); // Setup rule - InnerJoinAssociativity rule; + JoinAssociativity rule; EXPECT_TRUE(rule.Check(parent_join, root_context)); std::vector> outputs; @@ -254,8 +341,8 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest2) { EXPECT_EQ(middle_leaf, output_join->Children()[1]->Children()[0]); EXPECT_EQ(right_leaf, output_join->Children()[1]->Children()[1]); - auto parent_join_op = output_join->Op().As(); - auto child_join_op = output_join->Children()[1]->Op().As(); + auto parent_join_op = output_join->Op().As(); + auto child_join_op = output_join->Children()[1]->Op().As(); EXPECT_EQ(1, parent_join_op->join_predicates.size()); EXPECT_EQ(1, child_join_op->join_predicates.size()); delete root_context; diff --git a/test/optimizer/optimizer_test.cpp b/test/optimizer/optimizer_test.cpp index 8b5ed1e0ec7..91ef4e580e1 100644 --- a/test/optimizer/optimizer_test.cpp +++ b/test/optimizer/optimizer_test.cpp @@ -366,8 +366,8 @@ TEST_F(OptimizerTests, PushFilterThroughJoinTest) { // Check join in the root auto group_expr = GetSingleGroupExpression(memo, head_gexpr.get(), 0); - EXPECT_EQ(OpType::InnerJoin, group_expr->Op().GetType()); - auto join_op = group_expr->Op().As(); + EXPECT_EQ(OpType::LogicalJoin, group_expr->Op().GetType()); + auto join_op = group_expr->Op().As(); EXPECT_EQ(1, join_op->join_predicates.size()); EXPECT_TRUE(join_op->join_predicates[0].expr->ExactlyEquals(*predicates[0])); @@ -453,8 +453,9 @@ TEST_F(OptimizerTests, PredicatePushDownRewriteTest) { // Check join in the root auto group_expr = GetSingleGroupExpression(memo, head_gexpr.get(), 0); - EXPECT_EQ(OpType::InnerJoin, group_expr->Op().GetType()); - auto join_op = group_expr->Op().As(); + EXPECT_EQ(OpType::LogicalJoin, group_expr->Op().GetType()); + auto join_op = group_expr->Op().As(); + EXPECT_EQ(JoinType::INNER, join_op->type); EXPECT_EQ(1, join_op->join_predicates.size()); 
EXPECT_TRUE(join_op->join_predicates[0].expr->ExactlyEquals(*predicates[0]));