Skip to content

Commit

Permalink
HOAS in the type checker
Browse files Browse the repository at this point in the history
  • Loading branch information
krangelov committed Mar 10, 2024
1 parent 518e571 commit 280e11c
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 20 deletions.
69 changes: 54 additions & 15 deletions src/runtime/c/pgf/typechecker.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ PgfTypechecker::~PgfTypechecker() {
}
}

PgfTypechecker::Context::Context(PgfTypechecker *tc, Type *exp_type, PgfBindType bind_type) {
PgfTypechecker::Context::Context(PgfTypechecker *tc, Scope *scope, Type *exp_type, PgfBindType bind_type) {
this->tc = tc;
this->scope = scope;
this->bind_type = bind_type;
this->exp_type = exp_type;
this->inf_type = NULL;
Expand All @@ -100,6 +101,22 @@ PgfExpr PgfTypechecker::Context::eabs(PgfBindType btype, PgfText *name, PgfExpr
{
if (!checkImplArgument())
return 0;

if (exp_type == NULL) {
return tc->type_error("Cannot infer the type of a lambda abstraction");
}

Pi *pi = exp_type->is_pi();
if (!pi) {
return tc->type_error("A lambda abstraction must have a function type");
}

Scope new_scope = {.tail=scope, .var=name, .ty=pi->arg};
Context body_ctxt(tc,&new_scope,pi->res);
body = tc->m->match_expr(&body_ctxt, body);
if (body == 0)
return 0;

return tc->u->eabs(btype,name,body);
}

Expand All @@ -108,7 +125,7 @@ PgfExpr PgfTypechecker::Context::eapp(PgfExpr fun, PgfExpr arg)
if (!checkImplArgument())
return 0;

Context fun_ctxt(tc);
Context fun_ctxt(tc, scope);
fun = tc->m->match_expr(&fun_ctxt, fun);
if (fun == 0)
return 0;
Expand All @@ -120,7 +137,7 @@ PgfExpr PgfTypechecker::Context::eapp(PgfExpr fun, PgfExpr arg)
return 0;
}

Context arg_ctxt(tc,pi->arg,pi->bind_type);
Context arg_ctxt(tc,scope,pi->arg,pi->bind_type);
PgfExpr new_arg = tc->m->match_expr(&arg_ctxt, arg);
if (new_arg == 0) {
if (tc->err->type == PGF_EXN_TYPE_ERROR && tc->err->code == 1) {
Expand All @@ -147,8 +164,7 @@ PgfExpr PgfTypechecker::Context::eapp(PgfExpr fun, PgfExpr arg)
inf_type = pi->res;

if (!unifyTypes(&e)) {
free_ref(fun);
free_ref(arg);
free_ref(e);
return 0;
}

Expand All @@ -166,8 +182,10 @@ PgfExpr PgfTypechecker::Context::elit(PgfLiteral lit)

PgfExpr e = tc->u->elit(lit); free_ref(lit);

if (!unifyTypes(&e))
if (!unifyTypes(&e)) {
tc->u->free_ref(e);
return 0;
}

return e;
}
Expand All @@ -184,8 +202,10 @@ PgfExpr PgfTypechecker::Context::emeta(PgfMetaId meta)

inf_type = exp_type;

if (!unifyTypes(&e))
if (!unifyTypes(&e)) {
tc->u->free_ref(e);
return 0;
}

return e;
}
Expand Down Expand Up @@ -219,23 +239,42 @@ PgfExpr PgfTypechecker::Context::evar(int index)
if (!checkImplArgument())
return 0;

return tc->u->evar(index);
Scope *s = scope;
while (s != NULL && index > 0) {
s = s->tail;
index--;
}

if (s == NULL) {
return tc->type_error("Cannot type check an open expression (de Bruijn index %d)", index);
}

inf_type = s->ty;

PgfExpr e = tc->u->evar(index);

if (!unifyTypes(&e)) {
tc->u->free_ref(e);
return 0;
}

return e;
}

PgfExpr PgfTypechecker::Context::etyped(PgfExpr expr, PgfType type)
{
if (!checkImplArgument())
return 0;

Context type_ctxt(tc);
Context type_ctxt(tc, scope);
type = tc->m->match_type(&type_ctxt, type);
if (type == 0)
return 0;

Unmarshaller2 tu(tc->m);
Type *ty = (Type*) tc->m->match_type(&tu,type);

Context expr_ctxt(tc, ty, PGF_BIND_TYPE_EXPLICIT);
Context expr_ctxt(tc, scope, ty, PGF_BIND_TYPE_EXPLICIT);
expr = tc->m->match_expr(&expr_ctxt, expr);
if (expr == 0) {
free_ref(type);
Expand All @@ -258,7 +297,7 @@ PgfExpr PgfTypechecker::Context::eimplarg(PgfExpr expr)
return 0;
}

Context expr_ctxt(tc,exp_type,PGF_BIND_TYPE_EXPLICIT);
Context expr_ctxt(tc,scope,exp_type,PGF_BIND_TYPE_EXPLICIT);
expr = tc->m->match_expr(&expr_ctxt, expr);
if (expr == 0) {
return 0;
Expand Down Expand Up @@ -327,7 +366,7 @@ PgfType PgfTypechecker::Context::dtyp(size_t n_hypos, PgfTypeHypo *hypos,
PgfHypo *hypo = vector_elem(abscat->context,i);
Type *ty = (Type *) tc->db_m.match_type(&tu,hypo->type.as_object());
tc->temps.push_back(ty);
Context expr_ctxt(tc,ty,hypo->bind_type);
Context expr_ctxt(tc,scope,ty,hypo->bind_type);
new_exprs[i] = tc->m->match_expr(&expr_ctxt, exprs[j]);
if (new_exprs[i] == 0) {
if (tc->err->type == PGF_EXN_TYPE_ERROR && tc->err->code == 1) {
Expand Down Expand Up @@ -441,7 +480,7 @@ bool PgfTypechecker::type_error(const char *fmt, ...)

PgfType PgfTypechecker::infer_expr(PgfExpr *pe)
{
Context ctxt(this);
Context ctxt(this,NULL);
*pe = m->match_expr(&ctxt, *pe);
if (*pe == 0)
return 0;
Expand All @@ -452,13 +491,13 @@ PgfExpr PgfTypechecker::check_expr(PgfExpr expr, PgfType type)
{
Unmarshaller2 tu(m);
Type *ty = (Type*) m->match_type(&tu, type);
Context ctxt(this,ty,PGF_BIND_TYPE_EXPLICIT);
Context ctxt(this,NULL,ty,PGF_BIND_TYPE_EXPLICIT);
expr = m->match_expr(&ctxt, expr);
return expr;
}

PgfType PgfTypechecker::check_type(PgfType type)
{
Context ctxt(this);
Context ctxt(this,NULL);
return m->match_type(&ctxt, type);
}
9 changes: 8 additions & 1 deletion src/runtime/c/pgf/typechecker.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,16 @@ class PGF_INTERNAL_DECL PgfTypechecker {

bool unifyTypes(Type *ty1, Type *ty2);

struct Scope {
Scope *tail;
PgfText *var;
Type *ty;
};

struct Context : public PgfUnmarshaller {
PgfTypechecker *tc;

Scope *scope;
PgfBindType bind_type;
Type *exp_type;
Type *inf_type;
Expand All @@ -74,7 +81,7 @@ class PGF_INTERNAL_DECL PgfTypechecker {
bool unifyTypes(PgfExpr *e);

public:
Context(PgfTypechecker *tc, Type *exp_type = NULL, PgfBindType bind_type = PGF_BIND_TYPE_EXPLICIT);
Context(PgfTypechecker *tc, Scope *scope, Type *exp_type = NULL, PgfBindType bind_type = PGF_BIND_TYPE_EXPLICIT);

virtual PgfExpr eabs(PgfBindType btype, PgfText *name, PgfExpr body);
virtual PgfExpr eapp(PgfExpr fun, PgfExpr arg);
Expand Down
11 changes: 7 additions & 4 deletions src/runtime/haskell/tests/typechecking.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ main = do
,TestCase (assertInference "infer n-args 1" gr (Left "Too many arguments") "z z")
,TestCase (assertInference "infer n-args 2" gr (Left "Too many arguments") "s z z")
,TestCase (assertInference "infer implarg 1" gr (Left "Unexpected implicit argument") "s {z}")
,TestCase (assertInference "infer implarg 2" gr (Right "(y : N) -> S") "imp {z}") --
,TestCase (assertInference "infer implarg 2" gr (Right "(y : N) -> S") "imp {z}")
,TestCase (assertInference "infer implarg 3" gr (Right "S") "imp {z} z")
,TestCase (assertInference "infer implarg 4" gr (Right "({x},y : N) -> S") "imp")
,TestCase (assertInference' "infer implarg 4" gr (Right ("imp {?} z","S")) "imp z")
Expand All @@ -24,9 +24,12 @@ main = do
,TestCase (assertInference "infer literal 3" gr (Right "String") "\"abc\"")
,TestCase (assertInference "infer meta 1" gr (Left "Cannot infer the type of a meta variable") "?")
,TestCase (assertInference "infer meta 2" gr (Right "N->N") "<? : N->N>")
,TestCase (assertChecking "check fun" gr (Right "s") "s" "N->N")
,TestCase (assertChecking "check fun" gr (Right "s z") "s z" "N")
,TestCase (assertChecking "check fun" gr (Left "Types doesn't match") "s z" "N->N")
,TestCase (assertInference "infer lambda" gr (Left "Cannot infer the type of a lambda abstraction") "\\x->x")
,TestCase (assertChecking "check fun 1" gr (Right "s") "s" "N->N")
,TestCase (assertChecking "check fun 2" gr (Right "s z") "s z" "N")
,TestCase (assertChecking "check fun 3" gr (Left "Types doesn't match") "s z" "N->N")
,TestCase (assertChecking "check lambda 1" gr (Right "\\x->x") "\\x->x" "N->N")
,TestCase (assertChecking "check lambda 2" gr (Right "\\x->s x") "\\x->s x" "N->N")
,TestCase (assertType "check type 1" gr (Right "N -> N") "N -> N")
,TestCase (assertType "check type 2" gr (Left "Category s is not defined") "s")
,TestCase (assertType "check type 3" gr (Left "Too many arguments to category N - 0 expected but 1 given") "N z")
Expand Down

0 comments on commit 280e11c

Please sign in to comment.