Skip to content

Commit

Permalink
update math parsing function (#9053)
Browse files Browse the repository at this point in the history
Co-authored-by: Harshil Goel <[email protected]>
  • Loading branch information
harshil-goel and Harshil Goel authored Mar 19, 2024
1 parent 1073f92 commit 5ef9ae6
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 11 deletions.
39 changes: 34 additions & 5 deletions dql/math.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,10 @@ func isMathFunc(f string) bool {
f == "==" || f == "!=" ||
f == "min" || f == "max" || f == "sqrt" ||
f == "pow" || f == "logbase" || f == "floor" || f == "ceil" ||
f == "since"
f == "since" || f == "dot"
}

func parseMathFunc(it *lex.ItemIterator, again bool) (*MathTree, bool, error) {
func parseMathFunc(gq *GraphQuery, it *lex.ItemIterator, again bool) (*MathTree, bool, error) {
if !again {
it.Next()
item := it.Item()
Expand Down Expand Up @@ -218,7 +218,7 @@ loop:
again := false
var child *MathTree
for {
child, again, err = parseMathFunc(it, again)
child, again, err = parseMathFunc(gq, it, again)
if err != nil {
return nil, false, err
}
Expand All @@ -240,7 +240,7 @@ loop:
}
var child *MathTree
for {
child, again, err = parseMathFunc(it, again)
child, again, err = parseMathFunc(gq, it, again)
if err != nil {
return nil, false, err
}
Expand Down Expand Up @@ -320,6 +320,14 @@ loop:
// The parentheses are balanced out. Let's break.
break loop
}
case item.Typ == itemDollar:
varName, err := parseVarName(it)
if err != nil {
return nil, false, err
}
child := &MathTree{}
child.Var = varName
valueStack.push(child)
default:
return nil, false, errors.Errorf("Unexpected item while parsing math expression: %v",
item)
Expand Down Expand Up @@ -347,6 +355,27 @@ loop:
return res, false, err
}

func (t *MathTree) subs(vmap varMap) error {
if strings.HasPrefix(t.Var, "$") {
va, ok := vmap[t.Var]
if !ok {
return errors.Errorf("Variable not found in math")
}
var err error
t.Const, err = parseValue(va)
if err != nil {
return err
}
t.Var = ""
}
for _, i := range t.Child {
if err := i.subs(vmap); err != nil {
return err
}
}
return nil
}

// debugString converts mathTree to a string. Good for testing, debugging.
// nolint: unused
func (t *MathTree) debugString() string {
Expand Down Expand Up @@ -383,7 +412,7 @@ func (t *MathTree) stringHelper(buf *bytes.Buffer) {
switch t.Fn {
case "+", "-", "/", "*", "%", "exp", "ln", "cond", "min",
"sqrt", "max", "<", ">", "<=", ">=", "==", "!=", "u-",
"logbase", "pow":
"logbase", "pow", "dot":
x.Check2(buf.WriteString(t.Fn))
default:
x.Fatalf("Unknown operator: %q", t.Fn)
Expand Down
100 changes: 95 additions & 5 deletions dql/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

"github.com/dgraph-io/dgraph/lex"
"github.com/dgraph-io/dgraph/protos/pb"
"github.com/dgraph-io/dgraph/types"
"github.com/dgraph-io/dgraph/x"
)

Expand Down Expand Up @@ -208,11 +209,23 @@ var mathOpPrecedence = map[string]int{
"max": 85,
"min": 84,

// NOTE: Previously, we had "/" at precedence 50 and "*" at precedence 49.
// This is problematic because it would evaluate:
// 5 * 10 / 50 as: 5 * (10/50). This is fine for floating point, but breaks
// for integer arithmetic! The result in integer arithmetic is 0, but the precedence
// should actually be evaluated as (5*10)/50 = 1.
"/": 50,
"*": 49,
"%": 48,
"-": 47,
"+": 46,
"*": 50,
// We add dot as lower priority than "/" and "*" so that the expression:
// c1 * v1 dot c2 * v2 gets evaluated as (c1 * v1) dot (c2 * v2).
// Note that v dot c where v is a vector and c is a float (or int) is not legal,
// so we must evaluate this with this precedence! This also implies that we need
// support for v / c where v is a vector and c is a float, and that this should
// be interpreted the same as v * (1/c).
"dot": 49,
"%": 48,
"-": 47,
"+": 46,

"<": 10,
">": 9,
Expand Down Expand Up @@ -303,6 +316,67 @@ type Request struct {
Variables map[string]string
}

func parseValue(v varInfo) (types.Val, error) {
typ := v.Type
if v.Value == "" {
return types.Val{}, errors.Errorf("No value found")
}

switch typ {
case "int":
{
if i, err := strconv.ParseInt(v.Value, 0, 64); err != nil {
return types.Val{}, errors.Wrapf(err, "Expected an int but got %v", v.Value)
} else {
return types.Val{
Tid: types.IntID,
Value: i,
}, nil
}
}
case "float":
{
if i, err := strconv.ParseFloat(v.Value, 64); err != nil {
return types.Val{}, errors.Wrapf(err, "Expected a float but got %v", v.Value)
} else {
return types.Val{
Tid: types.FloatID,
Value: i,
}, nil
}
}
case "bool":
{
if i, err := strconv.ParseBool(v.Value); err != nil {
return types.Val{}, errors.Wrapf(err, "Expected a bool but got %v", v.Value)
} else {
return types.Val{
Tid: types.BoolID,
Value: i,
}, nil
}
}
case "vfloat":
{
if i, err := types.ParseVFloat(v.Value); err != nil {
return types.Val{}, errors.Wrapf(err, "Expected a vfloat but got %v", v.Value)
} else {
return types.Val{
Tid: types.VFloatID,
Value: i,
}, nil
}
}
case "string": // Value is a valid string. No checks required.
return types.Val{
Tid: types.StringID,
Value: v.Value,
}, nil
default:
return types.Val{}, errors.Errorf("Type %q not supported", typ)
}
}

func checkValueType(vm varMap) error {
for k, v := range vm {
typ := v.Type
Expand Down Expand Up @@ -340,6 +414,12 @@ func checkValueType(vm varMap) error {
return errors.Wrapf(err, "Expected a bool but got %v", v.Value)
}
}
case "vfloat":
{
if _, err := types.ParseVFloat(v.Value); err != nil {
return errors.Wrapf(err, "Expected a vfloat but got %v", v.Value)
}
}
case "string": // Value is a valid string. No checks required.
default:
return errors.Errorf("Type %q not supported", typ)
Expand All @@ -361,6 +441,10 @@ func substituteVar(f string, res *string, vmap varMap) error {
return nil
}

func substituteVarInMath(gq *GraphQuery, vmap varMap) error {
return gq.MathExp.subs(vmap)
}

func substituteVariables(gq *GraphQuery, vmap varMap) error {
for k, v := range gq.Args {
// v won't be empty as its handled in parseDqlVariables.
Expand All @@ -382,6 +466,12 @@ func substituteVariables(gq *GraphQuery, vmap varMap) error {
delete(gq.Args, "id")
}

if gq.MathExp != nil {
if err := substituteVarInMath(gq, vmap); err != nil {
return err
}
}

if gq.Func != nil {
if err := substituteVar(gq.Func.Attr, &gq.Func.Attr, vmap); err != nil {
return err
Expand Down Expand Up @@ -3064,7 +3154,7 @@ func godeep(it *lex.ItemIterator, gq *GraphQuery) error {
if varName == "" && alias == "" {
return it.Errorf("Function math should be used with a variable or have an alias")
}
mathTree, again, err := parseMathFunc(it, false)
mathTree, again, err := parseMathFunc(gq, it, false)
if err != nil {
return err
}
Expand Down
100 changes: 99 additions & 1 deletion dql/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/dgraph-io/dgo/v230/protos/api"
"github.com/dgraph-io/dgraph/chunker"
"github.com/dgraph-io/dgraph/lex"
"github.com/dgraph-io/dgraph/types"
)

func childAttrs(g *GraphQuery) []string {
Expand All @@ -38,6 +39,34 @@ func childAttrs(g *GraphQuery) []string {
return out
}

func TestParseMathSubs(t *testing.T) {
q := `query test($a: int) {
q(func: uid(0x1)) {
x as count(uid)
p : math(x + $a)
}
}`

r := Request{
Str: q,
Variables: map[string]string{"$a": "3"},
}
gq, err := Parse(r)
require.NoError(t, err)
val := gq.Query[0].Children[1].MathExp.Child[1].Const
require.NotNil(t, val)
require.Equal(t, val.Tid, types.IntID)
require.Equal(t, val.Value, int64(3))

r = Request{
Str: q,
Variables: map[string]string{"$a": "3.3"},
}
_, err = Parse(r)
require.Error(t, err)
require.Contains(t, err.Error(), "Expected an int but got 3.3")
}

func TestParseCountValError(t *testing.T) {
query := `
{
Expand Down Expand Up @@ -533,6 +562,29 @@ func TestParseQueryWithVarValAggNested4(t *testing.T) {
res.Query[1].Children[0].Children[3].MathExp.debugString())
}

func TestParseQueryWithVarValAggNested5(t *testing.T) {
query := `
{
me(func: uid(L), orderasc: val(d) ) {
name
}
var(func: uid(0x0a)) {
L as friends {
a as age
b as count(friends)
c as count(relatives)
d as math(a * b / c)
}
}
}
`
res, err := Parse(Request{Str: query})
require.NoError(t, err)
require.EqualValues(t, "(/ (* a b) c)",
res.Query[1].Children[0].Children[3].MathExp.debugString())
}

func TestParseQueryWithVarValAggLogSqrt(t *testing.T) {
query := `
{
Expand All @@ -558,6 +610,52 @@ func TestParseQueryWithVarValAggLogSqrt(t *testing.T) {
res.Query[1].Children[0].Children[2].MathExp.debugString())
}

func TestParseQueryWithVarValDotProduct(t *testing.T) {
query := `
{
me(func: uid(L), orderasc: val(d) ) {
name
}
var(func: uid(0x0a)) {
L as friends {
a as vfloat
b as vfloat
c as count(relatives)
d as math(a dot b * c)
}
}
}
`
res, err := Parse(Request{Str: query})
require.NoError(t, err)
require.EqualValues(t, "(dot a (* b c))",
res.Query[1].Children[0].Children[3].MathExp.debugString())
}

func TestParseQueryWithVarValDotProduct2(t *testing.T) {
query := `
{
me(func: uid(L), orderasc: val(d) ) {
name
}
var(func: uid(0x0a)) {
L as friends {
a as vfloat
b as vfloat
c as count(relatives)
d as math(a dot b / c)
}
}
}
`
res, err := Parse(Request{Str: query})
require.NoError(t, err)
require.EqualValues(t, "(dot a (/ b c))",
res.Query[1].Children[0].Children[3].MathExp.debugString())
}

func TestParseQueryWithVarValAggNestedConditional(t *testing.T) {
query := `
{
Expand Down Expand Up @@ -607,7 +705,7 @@ func TestParseQueryWithVarValAggNested3(t *testing.T) {
`
res, err := Parse(Request{Str: query})
require.NoError(t, err)
require.EqualValues(t, "(+ (+ a (* b (/ c a))) (- (exp (+ (+ a b) 1E+00)) (ln c)))",
require.EqualValues(t, "(+ (+ a (/ (* b c) a)) (- (exp (+ (+ a b) 1E+00)) (ln c)))",
res.Query[1].Children[0].Children[3].MathExp.debugString())
}

Expand Down

0 comments on commit 5ef9ae6

Please sign in to comment.