diff --git a/JuliaSyntax/src/porcelain/syntax_graph.jl b/JuliaSyntax/src/porcelain/syntax_graph.jl
index fb37e2581e248..d47ce59b77176 100644
--- a/JuliaSyntax/src/porcelain/syntax_graph.jl
+++ b/JuliaSyntax/src/porcelain/syntax_graph.jl
@@ -740,6 +740,287 @@ function _copy_ast(graph2::SyntaxGraph, graph1::SyntaxGraph,
     return id2
 end
 
+#-------------------------------------------------------------------------------
+# AST destructuring utilities
+
+raw"""
+Simple `SyntaxTree` pattern matching
+
+Returns the first result where its corresponding pattern matches `syntax_tree`
+and each extra `cond` is true. Throws an error if no match is found.
+
+## Patterns
+
+A pattern is used as both a conditional (does this syntax tree have a certain
+structure?) and a `let` (bind trees to these names if so). Each pattern uses a
+limited version of the @ast syntax:
+
+```
+<pattern> = <identifier>
+          | [K"<kind>" <pattern>*]
+          | [K"<kind>" <pattern>* <identifier>... <pattern>*]
+
+# note "*" is the meta-operator meaning zero or more, and "..." is literal
+```
+
+where a `[K"k" p1 p2 ps...]` form matches any tree with kind `k` and >=2
+children (bound to `p1` and `p2`), and `ps` is bound to the possibly-empty
+SyntaxList of children `3:end`. Identifiers (except `_`) can't be re-used, but
+may check for some form of tree equivalence in a future implementation.
+
+## Extra condition: `when`
+
+Like an escape hatch to the structure-matching mechanism. `when=cond` requires
+`cond` to evaluate to `true` for this branch to be taken. `cond` may also bind
+variables or printf-debug the matching process, as it runs only when its pattern
+matches and no previous branch was taken. `cond` may not mutate the object
+being matched.
+
+## Scope of variables
+
+Every `(pattern, when=cond) -> result` introduces a local scope. Identifiers in
+the pattern are let-bound when evaluating `cond` and `result`. `cond` can
+introduce variables for use in `result`. User code in `cond` and `result` (but
+not `pattern`) can refer to outer variables.
+
+## Example
+
+```
+julia> st = JuliaSyntax.parsestmt(
+           JuliaSyntax.SyntaxTree, "function foo(x,y,z); x; end")
+
+julia> JuliaSyntax.@stm st begin
+    [K"function" [K"call" fname [K"parameters" kws...]] body] ->
+        "no positional args, only kwargs: $(kws)"
+    [K"function" fname] ->
+        "zero-method function $fname"
+    [K"function" [K"call" fname args...] body] ->
+        "normal function $fname"
+    ([K"=" [K"call" _...] _...], when=(args=if_valid_get_args(st[1]); !isnothing(args))) ->
+        "deprecated call-equals form with args $args"
+    (_, when=(show("printf debugging is great"); true)) -> "something else"
+    _ -> "unreachable due to the case above"
+end
+"normal function foo"
+```
+
+See [Racket `match`](https://docs.racket-lang.org/reference/match.html) for the
+inspiration for this macro and an example of a much more featureful pattern
+language.
+"""
+macro stm(st, pats)
+    _stm(__source__, st, pats; debug=false)
+end
+
+"Like `@stm`, but prints a trace during matching."
+macro stm_debug(st, pats)
+    _stm(__source__, st, pats; debug=true)
+end
+
+# TODO: SyntaxList pattern matching could take similar syntax and use most of
+# the same machinery
+
+# Shared implementation of `@stm` and `@stm_debug`. `line` is the macro's
+# `__source__`, `st` the expression for the tree to match, and `pats` the
+# `begin ... end` block of `pattern -> result` cases. Returns an escaped
+# `let` expression containing an if/elseif chain with one branch per case and
+# a final `else` that throws an `ErrorException`.
+function _stm(line::LineNumberNode, st, pats; debug=false)
+    _stm_check_usage(pats)
+    # We leave most code untouched, so the user probably wants esc(output)
+    # Everything below is escaped, so names we rely on are interpolated as
+    # values; a bare `SyntaxTree` would be looked up in the caller's module.
+    st_gs, result_gs, k_gs, nc_gs = gensym.("st", "result", "k", "nc")
+    out_blk = Expr(:let, Expr(:block, :($st_gs = $st::$SyntaxTree),
+                              :($result_gs = nothing),
+                              :($k_gs = $kind($st_gs)),
+                              :($nc_gs = $numchildren($st_gs))),
+                   Expr(:if, false, nothing))
+    case_list_tail = out_blk.args[2].args
+    # Line of the case currently being processed; `line` itself keeps pointing
+    # at the macro call so the no-match error reports where `@stm` was used.
+    case_line = line
+    for pcr in pats.args
+        pcr isa LineNumberNode && (case_line = pcr; continue)
+        p, cond, result = _stm_destruct_pat(pcr)
+        pat_ok = p isa Symbol ? true : _stm_matches(p, st_gs, k_gs, nc_gs, debug)
+        # We need to let-bind patvars in both cond and the result, so result
+        # needs to live in the first argument of :if with the extra conditions.
+        case = Expr(:elseif,
+                    Expr(:&&, pat_ok,
+                         Expr(:let, _stm_assigns(p, st_gs),
+                              Expr(:&&, cond,
+                                   Expr(:block, case_line,
+                                        :($result_gs = $result), true)))),
+                    result_gs)
+        push!(case_list_tail, case)
+        case_list_tail = case_list_tail[3].args
+    end
+    push!(case_list_tail,
+          :(throw(ErrorException(string(
+              "No match found for `", $st_gs, "` at ", $(string(line)))))))
+    return esc(out_blk)
+end
+
+# recursively flatten `vcat` expressions
+function _stm_vcat_to_hcat(p::Expr)
+    if Meta.isexpr(p, :vcat)
+        out = Expr(:hcat)
+        for a in p.args
+            Meta.isexpr(a, :row) ? append!(out.args, a.args) : push!(out.args, a)
+        end
+    else
+        out = Expr(p.head, p.args...)
+    end
+    for i in eachindex(out.args)
+        out.args[i] = _stm_vcat_to_hcat(out.args[i])
+    end
+    return out
+end
+_stm_vcat_to_hcat(x) = x
+
+# return (pat_expr, when_expr|true, res_expr)
+function _stm_destruct_pat(pcr::Expr)
+    pc, r = pcr.args[1:2]
+    Base.remove_linenums!(pc) # errors in lhs of `->` are caught in usage check
+    (p_vcat, c) = Meta.isexpr(pc, :tuple) ?
+        (pc.args[1], pc.args[2].args[2]) : (pc, true)
+    return (_stm_vcat_to_hcat(p_vcat), c, r)
+end
+
+# Like `_stm_matches`, but for a subtree: binds the value of `st_ex` and its
+# kind and child count to fresh names so each is computed once.
+function _stm_matches_wrapper(p::Expr, st_ex, debug)
+    st_gs, k_gs, nc_gs = gensym.("st", "k", "nc")
+    Expr(:let, Expr(:block, :($st_gs = $st_ex),
+                    :($k_gs = $kind($st_gs)),
+                    :($nc_gs = $numchildren($st_gs))),
+         _stm_matches(p, st_gs, k_gs, nc_gs, debug))
+end
+
+# Build an `&&` expression that is true when the tree named `st_gs` (with kind
+# `k_gs` and `nc_gs` children) has the structure of pattern `p`. Symbols match
+# anything; sub-patterns before and after any `...` are checked recursively.
+function _stm_matches(p::Expr, st_gs::Symbol, k_gs::Symbol, nc_gs::Symbol, debug)
+    pat_k = Kind(p.args[1].args[3])
+    out = Expr(:&&, :($pat_k === $k_gs))
+    debug && push!(out.args, Expr(:block, :(printstyled(
+        string("[kind]: ", $k_gs, "\n"); color=:yellow)), true))
+
+    p_args = p.args[2:end]
+    dots_i = findfirst(x->Meta.isexpr(x, :(...)), p_args)
+    dots_start = something(dots_i, length(p_args) + 1)
+    n_after_dots = length(p_args) - dots_start # -1 if no dots
+
+    push!(out.args, isnothing(dots_i) ?
+        :($nc_gs === $(length(p_args))) :
+        :($nc_gs >= $(length(p_args) - 1)))
+    debug && push!(out.args, Expr(:block, :(printstyled(
+        string("[numc]: ", $nc_gs, "\n"); color=:yellow)), true))
+
+    for i in 1:dots_start-1
+        p_args[i] isa Symbol && continue
+        push!(out.args,
+              _stm_matches_wrapper(p_args[i], :($st_gs[$i]), debug))
+    end
+    for i in n_after_dots-1:-1:0
+        p_args[end-i] isa Symbol && continue
+        push!(out.args,
+              _stm_matches_wrapper(p_args[end-i], :($st_gs[end-$i]), debug))
+    end
+    debug && push!(out.args, Expr(:block, :(printstyled(
+        string("matched: ", $st_gs, " with ", $(QuoteNode(p)), "\n");
+        color=:green)), true))
+    return out
+end
+
+# Assuming _stm_matches, construct an Expr that assigns syms to SyntaxTrees.
+# Note st_rhs_expr is a ref-expr with a SyntaxTree/List value (in context).
+function _stm_assigns(p, st_rhs_expr; assigns=Expr(:block))
+    if p isa Symbol
+        p != :_ && push!(assigns.args, Expr(:(=), p, st_rhs_expr))
+        return assigns
+    elseif p isa Expr
+        p_args = p.args[2:end]
+        dots_i = findfirst(x->Meta.isexpr(x, :(...)), p_args)
+        dots_start = something(dots_i, length(p_args) + 1)
+        n_after_dots = length(p_args) - dots_start
+        for i in 1:dots_start-1
+            _stm_assigns(p_args[i], :($st_rhs_expr[$i]); assigns)
+        end
+        if !isnothing(dots_i)
+            _stm_assigns(p_args[dots_i].args[1],
+                         :($st_rhs_expr[$dots_i:end-$n_after_dots]); assigns)
+            for i in n_after_dots-1:-1:0
+                _stm_assigns(p_args[end-i], :($st_rhs_expr[end-$i]); assigns)
+            end
+        end
+        return assigns
+    end
+    @assert false "unexpected syntax; enable or fix `_stm_check_usage`"
+end
+
+# Check for correct pattern syntax. Not needed outside of development.
+function _stm_check_usage(pats::Expr)
+    function _stm_check_pattern(p; syms=Set{Symbol}())
+        if Meta.isexpr(p, :(...), 1)
+            p = p.args[1]
+            @assert(p isa Symbol, "Expected symbol before `...` in $p")
+        end
+        if p isa Symbol
+            # No support for duplicate syms for now (user is either looking for
+            # some form of equality we don't implement, or they made a mistake)
+            dup = p in syms && p !== :_
+            push!(syms, p)
+            @assert(!dup, "invalid duplicate non-underscore identifier $p")
+            return nothing
+        elseif Meta.isexpr(p, :vect)
+            @assert(length(p.args) === 1,
+                    "use spaces, not commas, in @stm []-patterns")
+        elseif Meta.isexpr(p, :hcat)
+            @assert(length(p.args) >= 2)
+        elseif Meta.isexpr(p, :vcat)
+            p = _stm_vcat_to_hcat(p)
+            @assert(length(p.args) >= 2)
+        else
+            @assert(false, "malformed pattern $p")
+        end
+        @assert(count(x->Meta.isexpr(x, :(...)), p.args[2:end]) <= 1,
+                "Multiple `...` in a pattern is ambiguous")
+
+        # This exact `K"kind"` syntax is not necessary since the kind can't be
+        # provided by a variable, but requiring [K"kinds"] is consistent with
+        # `@ast` and allows us to implement list matching later.
+        @assert(Meta.isexpr(p.args[1], :macrocall, 3) &&
+            p.args[1].args[1] === Symbol("@K_str") &&
+            p.args[1].args[3] isa String, "first pattern elt must be K\"\"")
+
+        for subp in p.args[2:end]
+            _stm_check_pattern(subp; syms)
+        end
+    end
+
+    @assert Meta.isexpr(pats, :block) "Usage: @stm st begin; ...; end"
+    for pcr in filter(e->!isa(e, LineNumberNode), pats.args)
+        @assert(Meta.isexpr(pcr, :(->), 2), "Expected pat -> res, got malformed case: $pcr")
+        if Meta.isexpr(pcr.args[1], :tuple)
+            @assert(length(pcr.args[1].args) === 2,
+                    "Expected `pat` or `(pat, when=cond)`, got $(pcr.args[1])")
+            p = pcr.args[1].args[1]
+            c = pcr.args[1].args[2]
+            @assert(Meta.isexpr(c, :(=), 2) && c.args[1] === :when,
+                    "Expected `(when=cond)` in tuple pattern, got $(c)")
+        else
+            p = pcr.args[1]
+        end
+        # `_stm_check_pattern` accepts `x...` because it recurses into
+        # sub-patterns, but a splat makes no sense as the whole pattern.
+        @assert(!Meta.isexpr(p, :(...)),
+                "`...` is only valid inside a []-pattern, got $p")
+        _stm_check_pattern(p)
+    end
+end
+
 #-------------------------------------------------------------------------------
 # RawGreenNode->SyntaxTree
 # WIP: expr_structure param will be deleted
diff --git a/JuliaSyntax/test/syntax_graph.jl b/JuliaSyntax/test/syntax_graph.jl
index 42583bab76bd7..7d7a6776a70ab 100644
--- a/JuliaSyntax/test/syntax_graph.jl
+++ b/JuliaSyntax/test/syntax_graph.jl
@@ -1,4 +1,4 @@
-using .JuliaSyntax: SyntaxGraph, SyntaxTree, SyntaxList, freeze_attrs, unfreeze_attrs, ensure_attributes, ensure_attributes!, delete_attributes, copy_ast, attrdefs
+using .JuliaSyntax: SyntaxGraph, SyntaxTree, SyntaxList, freeze_attrs, unfreeze_attrs, ensure_attributes, ensure_attributes!, delete_attributes, copy_ast, attrdefs, @stm
 
 @testset "SyntaxGraph attrs" begin
     st = parsestmt(SyntaxTree, "function foo end")
@@ -100,3 +100,218 @@ end
         @test_throws ErrorException copy_ast(new_g, st; copy_source=false)
     end
 end
+
+@testset "@stm SyntaxTree pattern-matching" begin
+    st = parsestmt(SyntaxTree, "foo(a,b=1,c(d=2))")
+    # (call foo a (kw b 1) (call c (kw d 2)))
+
+    @testset "basic functionality" begin
+        @test @stm st begin
+            _ -> true
+        end
+
+        @test @stm st begin
+            x -> x isa SyntaxTree
+        end
+
+        @test @stm st begin
+            [K"function" f a b c] -> false
+            [K"call" f a b c] -> true
+        end
+
+        @test @stm st begin
+            [K"function" _ _ _ _] -> false
+            [K"call" _ _ _ _] -> true
+        end
+
+        @test @stm st begin
+            [K"call" f a b] -> false
+            [K"call" f a b c d] -> false
+            [K"call" f a b c] -> true
+        end
+
+        @test @stm st begin
+            [K"call" f a b c] ->
+                kind(f) === K"Identifier" &&
+                kind(b) === K"kw" &&
+                kind(c) === K"call"
+        end
+    end
+
+    @testset "errors" begin
+        # no match
+        @test_throws ErrorException @stm st begin
+            [K"Identifier"] -> false
+        end
+
+        # assuming we run this checker by default
+        @testset "_stm_check_usage" begin
+            bad = Expr[
+                :(@stm st begin
+                    [a] -> false
+                end)
+                :(@stm st begin
+                    [K"None",a] -> false
+                end)
+                :(@stm st begin
+                    [K"None" a a] -> false
+                end)
+                :(@stm st begin
+                    x
+                end)
+                :(@stm st begin
+                    x() -> false
+                end)
+                :(@stm st begin
+                    (a, b=1) -> false
+                end)
+                :(@stm st begin
+                    [K"None" a... b...] -> false
+                end)
+                :(@stm st begin
+                    (x...) -> false
+                end)
+            ]
+            for e in bad
+                Base.remove_linenums!(e)
+                @testset "$(string(e))" begin
+                    @test_throws AssertionError macroexpand(@__MODULE__, e)
+                end
+            end
+        end
+    end
+
+    @testset "nested patterns" begin
+        @test 1 === @stm st begin
+            [K"call" [K"Identifier"] [K"Identifier"] [K"kw" [K"Identifier"] k1] [K"call" [K"Identifier"] [K"kw" [K"Identifier"] k2]]] -> 1
+            [K"call" [K"Identifier"] [K"Identifier"] [K"kw" [K"Identifier"] k1] [K"call" [K"Identifier"] [K"kw" _ k2]]] -> 2
+            [K"call" [K"Identifier"] [K"Identifier"] [K"kw" _ k1] [K"call" _ _]] -> 3
+            [K"call" [K"Identifier"] [K"Identifier"] _ _ ] -> 4
+            [K"call" _ _ _ _] -> 5
+        end
+        @test 1 === @stm st begin
+            [K"call" _ _ [K"None" [K"Identifier"] k1] [K"None" [K"Identifier"] [K"None" [K"None"] k2]]] -> 5
+            [K"call" _ _ [K"kw" [K"Identifier"] k1] [K"None" [K"Identifier"] [K"None" [K"None"] k2]]] -> 4
+            [K"call" _ _ [K"kw" [K"Identifier"] k1] [K"call" [K"Identifier"] [K"None" [K"None"] k2]]] -> 3
+            [K"call" _ _ [K"kw" [K"Identifier"] k1] [K"call" [K"Identifier"] [K"kw" [K"None"] k2]]] -> 2
+            [K"call" _ _ [K"kw" [K"Identifier"] k1] [K"call" [K"Identifier"] [K"kw" [K"Identifier"] k2]]] -> 1
+        end
+        @test 1 === @stm st begin
+            [K"call" _ _ [K"kw" [K"Identifier"] k1] [K"call" [K"Identifier"] [K"kw" [K"Identifier"] k2] bad]] -> 4
+            [K"call" _ _ [K"kw" [K"Identifier"] k1] [K"call" [K"Identifier"] [K"kw" [K"Identifier"] k2 bad]]] -> 3
+            [K"call" _ _ [K"kw" [K"Identifier"] k1] [K"call" [K"Identifier"] [K"kw" [K"Identifier" bad] k2]]] -> 2
+            [K"call" _ _ [K"kw" [K"Identifier"] k1] [K"call" [K"Identifier"] [K"kw" [K"Identifier"] k2]]] -> 1
+        end
+    end
+
+    @testset "vcat form (newlines in pattern)" begin
+        @test @stm st begin
+            [K"call"
+             f
+             a
+             b
+             c] -> true
+        end
+        @test @stm st begin
+            [K"call"
+             f a b c] -> true
+        end
+        @test @stm st begin
+            [K"call"
+
+
+             f a b c] -> true
+        end
+        @test @stm st begin
+            [K"call"
+             [K"Identifier"] [K"Identifier"]
+             [K"kw" [K"Identifier"] k1]
+             [K"call"
+              [K"Identifier"]
+              [K"kw"
+               [K"Identifier"]
+               k2]]] -> true
+        end
+    end
+
+    @testset "SyntaxList splat matching" begin
+        # trailing splat
+        @test @stm st begin
+            [K"call" f _...] -> true
+        end
+        @test @stm st begin
+            [K"call" f args...] -> kind(f) === K"Identifier"
+        end
+        @test @stm st begin
+            [K"call" f args...] -> args isa SyntaxList && length(args) === 3
+        end
+        @test @stm st begin
+            [K"call" f args...] -> kind(args[1]) === K"Identifier" &&
+                kind(args[2]) === K"kw" &&
+                kind(args[3]) === K"call"
+        end
+        @test @stm st begin
+            [K"call" f a b c empty...] -> empty isa SyntaxList && length(empty) === 0
+        end
+
+        # binds after splat
+        @test @stm st begin
+            [K"call" f args... last] ->
+                args isa SyntaxList &&
+                length(args) === 2
+        end
+        @test @stm st begin
+            [K"call" f args... last] ->
+                kind(f) === K"Identifier" &&
+                kind(args[1]) === K"Identifier" &&
+                kind(args[2]) === K"kw" &&
+                kind(last) === K"call"
+        end
+        @test @stm st begin
+            [K"call" empty... f a b c] -> empty isa SyntaxList && length(empty) === 0
+        end
+    end
+
+    @testset "`when` clauses affect matching" begin
+        @test @stm st begin
+            (_, when=false) -> false
+            (_, when=true) -> true
+        end
+        @test @stm st begin
+            ([K"call" _...], when=false) -> false
+            ([K"call" _...], when=true) -> true
+        end
+        @test @stm st begin
+            ([K"call" _ _...], when=kind(st[1])===K"Identifier") -> true
+        end
+        @test @stm st begin
+            ([K"call" f _...], when=kind(f)===K"Identifier") -> true
+        end
+    end
+
+    @testset "effects of when=cond" begin
+        let x = Int[]
+            @test @stm st begin
+                (_, when=(push!(x, 1); true)) -> x == [1]
+            end
+            empty!(x)
+
+            @test @stm st begin
+                (_, when=(push!(x, 1); false)) -> false
+                (_, when=(push!(x, 2); false)) -> false
+                (_, when=(push!(x, 3); true)) -> x == [1, 2, 3]
+            end
+            empty!(x)
+
+            @test @stm st begin
+                ([K"block"], when=(push!(x, 123); false)) -> false
+                (_, when=(push!(x, 1); true)) -> x == [1]
+            end
+            empty!(x)
+
+            @test @stm st begin
+                (x_pat, when=((x_when = x_pat); true)) -> x_pat == x_when
+            end
+        end
+    end
+end