diff --git a/src/runtime/c/pgf/typechecker.cxx b/src/runtime/c/pgf/typechecker.cxx index 6d29af00b..9ae1aa566 100644 --- a/src/runtime/c/pgf/typechecker.cxx +++ b/src/runtime/c/pgf/typechecker.cxx @@ -89,8 +89,9 @@ PgfTypechecker::~PgfTypechecker() { } } -PgfTypechecker::Context::Context(PgfTypechecker *tc, Type *exp_type, PgfBindType bind_type) { +PgfTypechecker::Context::Context(PgfTypechecker *tc, Scope *scope, Type *exp_type, PgfBindType bind_type) { this->tc = tc; + this->scope = scope; this->bind_type = bind_type; this->exp_type = exp_type; this->inf_type = NULL; @@ -100,6 +101,22 @@ PgfExpr PgfTypechecker::Context::eabs(PgfBindType btype, PgfText *name, PgfExpr { if (!checkImplArgument()) return 0; + + if (exp_type == NULL) { + return tc->type_error("Cannot infer the type of a lambda abstraction"); + } + + Pi *pi = exp_type->is_pi(); + if (!pi) { + return tc->type_error("A lambda abstraction must have a function type"); + } + + Scope new_scope = {.tail=scope, .var=name, .ty=pi->arg}; + Context body_ctxt(tc,&new_scope,pi->res); + body = tc->m->match_expr(&body_ctxt, body); + if (body == 0) + return 0; + return tc->u->eabs(btype,name,body); } @@ -108,7 +125,7 @@ PgfExpr PgfTypechecker::Context::eapp(PgfExpr fun, PgfExpr arg) if (!checkImplArgument()) return 0; - Context fun_ctxt(tc); + Context fun_ctxt(tc, scope); fun = tc->m->match_expr(&fun_ctxt, fun); if (fun == 0) return 0; @@ -120,7 +137,7 @@ PgfExpr PgfTypechecker::Context::eapp(PgfExpr fun, PgfExpr arg) return 0; } - Context arg_ctxt(tc,pi->arg,pi->bind_type); + Context arg_ctxt(tc,scope,pi->arg,pi->bind_type); PgfExpr new_arg = tc->m->match_expr(&arg_ctxt, arg); if (new_arg == 0) { if (tc->err->type == PGF_EXN_TYPE_ERROR && tc->err->code == 1) { @@ -147,8 +164,7 @@ PgfExpr PgfTypechecker::Context::eapp(PgfExpr fun, PgfExpr arg) inf_type = pi->res; if (!unifyTypes(&e)) { - free_ref(fun); - free_ref(arg); + free_ref(e); return 0; } @@ -166,8 +182,10 @@ PgfExpr PgfTypechecker::Context::elit(PgfLiteral lit) PgfExpr e = tc->u->elit(lit); free_ref(lit); - if (!unifyTypes(&e)) + if (!unifyTypes(&e)) { + tc->u->free_ref(e); return 0; + } return e; } @@ -184,8 +202,10 @@ PgfExpr PgfTypechecker::Context::emeta(PgfMetaId meta) inf_type = exp_type; - if (!unifyTypes(&e)) + if (!unifyTypes(&e)) { + tc->u->free_ref(e); return 0; + } return e; } @@ -219,7 +239,26 @@ PgfExpr PgfTypechecker::Context::evar(int index) if (!checkImplArgument()) return 0; - return tc->u->evar(index); + Scope *s = scope; + while (s != NULL && index > 0) { + s = s->tail; + index--; + } + + if (s == NULL) { + return tc->type_error("Cannot type check an open expression (de Bruijn index %d)", index); + } + + inf_type = s->ty; + + PgfExpr e = tc->u->evar(index); + + if (!unifyTypes(&e)) { + tc->u->free_ref(e); + return 0; + } + + return e; } PgfExpr PgfTypechecker::Context::etyped(PgfExpr expr, PgfType type) @@ -227,7 +266,7 @@ PgfExpr PgfTypechecker::Context::etyped(PgfExpr expr, PgfType type) if (!checkImplArgument()) return 0; - Context type_ctxt(tc); + Context type_ctxt(tc, scope); type = tc->m->match_type(&type_ctxt, type); if (type == 0) return 0; @@ -235,7 +274,7 @@ PgfExpr PgfTypechecker::Context::etyped(PgfExpr expr, PgfType type) Unmarshaller2 tu(tc->m); Type *ty = (Type*) tc->m->match_type(&tu,type); - Context expr_ctxt(tc, ty, PGF_BIND_TYPE_EXPLICIT); + Context expr_ctxt(tc, scope, ty, PGF_BIND_TYPE_EXPLICIT); expr = tc->m->match_expr(&expr_ctxt, expr); if (expr == 0) { free_ref(type); @@ -258,7 +297,7 @@ PgfExpr PgfTypechecker::Context::eimplarg(PgfExpr expr) return 0; } - Context expr_ctxt(tc,exp_type,PGF_BIND_TYPE_EXPLICIT); + Context expr_ctxt(tc,scope,exp_type,PGF_BIND_TYPE_EXPLICIT); expr = tc->m->match_expr(&expr_ctxt, expr); if (expr == 0) { return 0; @@ -327,7 +366,7 @@ PgfType PgfTypechecker::Context::dtyp(size_t n_hypos, PgfTypeHypo *hypos, PgfHypo *hypo = vector_elem(abscat->context,i); Type *ty = (Type *) tc->db_m.match_type(&tu,hypo->type.as_object()); tc->temps.push_back(ty); - Context expr_ctxt(tc,ty,hypo->bind_type); + Context expr_ctxt(tc,scope,ty,hypo->bind_type); new_exprs[i] = tc->m->match_expr(&expr_ctxt, exprs[j]); if (new_exprs[i] == 0) { if (tc->err->type == PGF_EXN_TYPE_ERROR && tc->err->code == 1) { @@ -441,7 +480,7 @@ bool PgfTypechecker::type_error(const char *fmt, ...) PgfType PgfTypechecker::infer_expr(PgfExpr *pe) { - Context ctxt(this); + Context ctxt(this,NULL); *pe = m->match_expr(&ctxt, *pe); if (*pe == 0) return 0; @@ -452,13 +491,13 @@ PgfExpr PgfTypechecker::check_expr(PgfExpr expr, PgfType type) { Unmarshaller2 tu(m); Type *ty = (Type*) m->match_type(&tu, type); - Context ctxt(this,ty,PGF_BIND_TYPE_EXPLICIT); + Context ctxt(this,NULL,ty,PGF_BIND_TYPE_EXPLICIT); expr = m->match_expr(&ctxt, expr); return expr; } PgfType PgfTypechecker::check_type(PgfType type) { - Context ctxt(this); + Context ctxt(this,NULL); return m->match_type(&ctxt, type); } diff --git a/src/runtime/c/pgf/typechecker.h b/src/runtime/c/pgf/typechecker.h index ff1918fd5..cf33ebca2 100644 --- a/src/runtime/c/pgf/typechecker.h +++ b/src/runtime/c/pgf/typechecker.h @@ -63,9 +63,16 @@ class PGF_INTERNAL_DECL PgfTypechecker { bool unifyTypes(Type *ty1, Type *ty2); + struct Scope { + Scope *tail; + PgfText *var; + Type *ty; + }; + struct Context : public PgfUnmarshaller { PgfTypechecker *tc; + Scope *scope; PgfBindType bind_type; Type *exp_type; Type *inf_type; @@ -74,7 +81,7 @@ class PGF_INTERNAL_DECL PgfTypechecker { bool unifyTypes(PgfExpr *e); public: - Context(PgfTypechecker *tc, Type *exp_type = NULL, PgfBindType bind_type = PGF_BIND_TYPE_EXPLICIT); + Context(PgfTypechecker *tc, Scope *scope, Type *exp_type = NULL, PgfBindType bind_type = PGF_BIND_TYPE_EXPLICIT); virtual PgfExpr eabs(PgfBindType btype, PgfText *name, PgfExpr body); virtual PgfExpr eapp(PgfExpr fun, PgfExpr arg); diff --git a/src/runtime/haskell/tests/typechecking.hs b/src/runtime/haskell/tests/typechecking.hs index 70be26bb5..6db7b2314 100644 --- a/src/runtime/haskell/tests/typechecking.hs +++ b/src/runtime/haskell/tests/typechecking.hs @@ -11,7 +11,7 @@ main = do ,TestCase (assertInference "infer n-args 1" gr (Left "Too many arguments") "z z") ,TestCase (assertInference "infer n-args 2" gr (Left "Too many arguments") "s z z") ,TestCase (assertInference "infer implarg 1" gr (Left "Unexpected implicit argument") "s {z}") - ,TestCase (assertInference "infer implarg 2" gr (Right "(y : N) -> S") "imp {z}") -- + ,TestCase (assertInference "infer implarg 2" gr (Right "(y : N) -> S") "imp {z}") ,TestCase (assertInference "infer implarg 3" gr (Right "S") "imp {z} z") ,TestCase (assertInference "infer implarg 4" gr (Right "({x},y : N) -> S") "imp") ,TestCase (assertInference' "infer implarg 4" gr (Right ("imp {?} z","S")) "imp z") @@ -24,9 +24,12 @@ main = do ,TestCase (assertInference "infer literal 3" gr (Right "String") "\"abc\"") ,TestCase (assertInference "infer meta 1" gr (Left "Cannot infer the type of a meta variable") "?") ,TestCase (assertInference "infer meta 2" gr (Right "N->N") "N>") - ,TestCase (assertChecking "check fun" gr (Right "s") "s" "N->N") - ,TestCase (assertChecking "check fun" gr (Right "s z") "s z" "N") - ,TestCase (assertChecking "check fun" gr (Left "Types doesn't match") "s z" "N->N") + ,TestCase (assertInference "infer lambda" gr (Left "Cannot infer the type of a lambda abstraction") "\\x->x") + ,TestCase (assertChecking "check fun 1" gr (Right "s") "s" "N->N") + ,TestCase (assertChecking "check fun 2" gr (Right "s z") "s z" "N") + ,TestCase (assertChecking "check fun 3" gr (Left "Types doesn't match") "s z" "N->N") + ,TestCase (assertChecking "check lambda 1" gr (Right "\\x->x") "\\x->x" "N->N") + ,TestCase (assertChecking "check lambda 2" gr (Right "\\x->s x") "\\x->s x" "N->N") ,TestCase (assertType "check type 1" gr (Right "N -> N") "N -> N") ,TestCase (assertType "check type 2" gr (Left "Category s is not defined") "s") ,TestCase (assertType "check type 3" gr (Left "Too many arguments to category N - 0 expected but 1 given") "N z")