Skip to content

Commit c200a39

Browse files
authored
Merge pull request #363 from erizocosmico/fix/squash-filter
internal/rule: fix indexes of filters in squash
2 parents 9825ae2 + 796cec2 commit c200a39

File tree

3 files changed

+57
-18
lines changed

3 files changed

+57
-18
lines changed

integration_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"gopkg.in/src-d/go-mysql-server.v0/sql"
1818
"gopkg.in/src-d/go-mysql-server.v0/sql/analyzer"
1919
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
20+
sqlfunction "gopkg.in/src-d/go-mysql-server.v0/sql/expression/function"
2021
"gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa"
2122
)
2223

@@ -160,6 +161,26 @@ func TestIntegration(t *testing.T) {
160161
{"6ecf0ef2c2dffb796033e5a02219af86ec6584e5", int32(9)},
161162
},
162163
},
164+
{
165+
`SELECT MONTH(committer_when) as month,
166+
r.repository_id as repo_id,
167+
committer_email
168+
FROM ref_commits r
169+
INNER JOIN commits c
170+
ON YEAR(c.committer_when) = 2015
171+
AND r.commit_hash = c.commit_hash
172+
WHERE r.ref_name = 'HEAD'`,
173+
[]sql.Row{
174+
{int32(4), "worktree", "[email protected]"},
175+
{int32(3), "worktree", "[email protected]"},
176+
{int32(3), "worktree", "[email protected]"},
177+
{int32(3), "worktree", "[email protected]"},
178+
{int32(3), "worktree", "[email protected]"},
179+
{int32(3), "worktree", "[email protected]"},
180+
{int32(3), "worktree", "[email protected]"},
181+
{int32(3), "worktree", "[email protected]"},
182+
},
183+
},
163184
}
164185

165186
runTests := func(t *testing.T) {
@@ -776,6 +797,7 @@ func newSquashEngine() *sqle.Engine {
776797
Build()
777798
e := sqle.New(catalog, analyzer)
778799
e.AddDatabase(gitbase.NewDatabase("foo"))
800+
e.Catalog.RegisterFunctions(sqlfunction.Defaults)
779801
e.Catalog.RegisterFunctions(function.Functions)
780802
return e
781803
}

internal/rule/squashjoins.go

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,27 +61,34 @@ func SquashJoins(
6161
}
6262

6363
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
64-
if projectSquashes <= 0 {
65-
return n, nil
66-
}
64+
switch n := n.(type) {
65+
case *plan.Project:
66+
if projectSquashes <= 0 {
67+
return n, nil
68+
}
6769

68-
project, ok := n.(*plan.Project)
69-
if !ok {
70-
return n, nil
71-
}
70+
child, ok := n.Child.(*plan.Project)
71+
if !ok {
72+
return n, nil
73+
}
7274

73-
child, ok := project.Child.(*plan.Project)
74-
if !ok {
75-
return n, nil
76-
}
75+
squashedProject, ok := squashProjects(n, child)
76+
if !ok {
77+
return n, nil
78+
}
7779

78-
squashedProject, ok := squashProjects(project, child)
79-
if !ok {
80+
projectSquashes--
81+
return squashedProject, nil
82+
case *plan.Filter:
83+
expr, err := fixFieldIndexes(n.Expression, n.Schema())
84+
if err != nil {
85+
return nil, err
86+
}
87+
88+
return plan.NewFilter(expr, n.Child), nil
89+
default:
8090
return n, nil
8191
}
82-
83-
projectSquashes--
84-
return squashedProject, nil
8592
})
8693
}
8794

@@ -101,7 +108,7 @@ func countProjectSquashes(n sql.Node) int {
101108
}
102109

103110
func squashProjects(parent, child *plan.Project) (sql.Node, bool) {
104-
projections := []sql.Expression{}
111+
var projections []sql.Expression
105112
for _, expr := range parent.Expressions() {
106113
parentField, ok := expr.(*expression.GetField)
107114
if !ok {
@@ -1078,7 +1085,7 @@ func (t *squashedTable) TransformExpressionsUp(sql.TransformExprFunc) (sql.Node,
10781085
return t, nil
10791086
}
10801087
func (t *squashedTable) TransformUp(fn sql.TransformNodeFunc) (sql.Node, error) {
1081-
return t, nil
1088+
return fn(t)
10821089
}
10831090

10841091
type schemaMapperIter struct {

internal/rule/squashjoins_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,16 @@ func TestSquashJoins(t *testing.T) {
9696

9797
result, err := SquashJoins(sql.NewEmptyContext(), analyzer.NewDefault(nil), node)
9898
require.NoError(err)
99+
result, err = result.TransformUp(func(n sql.Node) (sql.Node, error) {
100+
t, ok := n.(*squashedTable)
101+
if ok {
102+
t.schema = nil
103+
return t, nil
104+
}
105+
106+
return n, nil
107+
})
108+
require.NoError(err)
99109
require.Equal(expected, result)
100110
}
101111

0 commit comments

Comments
 (0)