Skip to content

Commit

Permalink
Add dexFromCAtom to the C API
Browse files Browse the repository at this point in the history
This makes it possible to use `CAtom` as a common language for getting
values both into and out of Dex.
  • Loading branch information
apaszke committed Sep 29, 2021
1 parent 0755093 commit d3ca6fc
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 10 deletions.
26 changes: 23 additions & 3 deletions python/dex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __getattr__(self, name):
result = api.lookup(self, api.as_cstr(name))
if not result:
api.raise_from_dex()
return Atom(result, self)
return Atom._from_ptr(result, self)


class Prelude(Module):
Expand All @@ -52,15 +52,35 @@ def eval(expr: str, module=prelude, _env=None):
result = api.evalExpr(_env, api.as_cstr(expr))
if not result:
api.raise_from_dex()
return Atom(result, module)
return Atom._from_ptr(result, module)


class Atom:
__slots__ = ('__weakref__', '_as_parameter_', 'module')

def __init__(self, ptr, module):
def __init__(self, value):
catom = api.CAtom()
if isinstance(value, int):
catom.tag = 0
catom.value.tag = 1
catom.value.value = ctypes.c_int(value)
elif isinstance(value, float):
catom.tag = 0
catom.value.tag = 4
catom.value.value = ctypes.c_float(value)
else:
raise ValueError("Can't convert given value to a Dex Atom")
self.module = prelude
self._as_parameter_ = api.fromCAtom(ctypes.pointer(catom))
if not self._as_parameter_:
api.raise_from_dex()

@classmethod
def _from_ptr(cls, ptr, module):
self = super().__new__(cls)
self._as_parameter_ = ptr
self.module = module
return self

def __del__(self):
# TODO: Free
Expand Down
20 changes: 16 additions & 4 deletions python/dex/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,22 @@ def tagged_union(name: str, members: List[type]):
payload = type(name + "Payload", (ctypes.Union,), {"_fields_": named_members})
union = type(name, (ctypes.Structure,), {
"_fields_": [("tag", ctypes.c_uint64), ("payload", payload)],
"value": property(lambda self: getattr(self.payload, f"t{self.tag}")),
"value": property(
fget=lambda self: getattr(self.payload, f"t{self.tag}"),
fset=lambda self, value: setattr(self.payload, f"t{self.tag}", value)),
"Payload": payload,
})
return union

CLit = tagged_union("Lit", [ctypes.c_int64, ctypes.c_int32, ctypes.c_int8, ctypes.c_double, ctypes.c_float])
CLit = tagged_union("Lit", [
ctypes.c_int64,
ctypes.c_int32,
ctypes.c_uint8,
ctypes.c_double,
ctypes.c_float,
ctypes.c_uint32,
ctypes.c_uint64
])
class CRectArray(ctypes.Structure):
_fields_ = [("data", ctypes.c_void_p),
("shape_ptr", ctypes.POINTER(ctypes.c_int64)),
Expand Down Expand Up @@ -65,8 +76,9 @@ def dex_func(name, *signature):
evalExpr = dex_func('dexEvalExpr', HsContextPtr, ctypes.c_char_p, HsAtomPtr)
lookup = dex_func('dexLookup', HsContextPtr, ctypes.c_char_p, HsAtomPtr)

print = dex_func('dexPrint', HsAtomPtr, ctypes.c_char_p)
toCAtom = dex_func('dexToCAtom', HsAtomPtr, CAtomPtr, ctypes.c_int)
print = dex_func('dexPrint', HsAtomPtr, ctypes.c_char_p)
toCAtom = dex_func('dexToCAtom', HsAtomPtr, CAtomPtr, ctypes.c_int)
fromCAtom = dex_func('dexFromCAtom', CAtomPtr, HsAtomPtr)

createJIT = dex_func('dexCreateJIT', HsJITPtr)
destroyJIT = dex_func('dexDestroyJIT', HsJITPtr, None)
Expand Down
2 changes: 2 additions & 0 deletions python/tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def addOne (x: Float) : Float = x + 1.0
def test_scalar_conversions(self):
assert float(dex.eval("5.0")) == 5.0
assert int(dex.eval("5")) == 5
assert str(dex.Atom(5)) == "5"
assert str(dex.Atom(5.0)) == "5."


if __name__ == "__main__":
Expand Down
5 changes: 3 additions & 2 deletions src/Dex/Foreign/API.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ foreign export ccall "dexEvalExpr" dexEvalExpr :: Ptr Context -> CString -> IO
foreign export ccall "dexLookup" dexLookup :: Ptr Context -> CString -> IO (Ptr Atom)

-- Serialization
foreign export ccall "dexPrint" dexPrint :: Ptr Atom -> IO CString
foreign export ccall "dexToCAtom" dexToCAtom :: Ptr Atom -> Ptr CAtom -> IO CInt
foreign export ccall "dexPrint" dexPrint :: Ptr Atom -> IO CString
foreign export ccall "dexToCAtom" dexToCAtom :: Ptr Atom -> Ptr CAtom -> IO CInt
foreign export ccall "dexFromCAtom" dexFromCAtom :: Ptr CAtom -> IO (Ptr Atom)

-- JIT
foreign export ccall "dexCreateJIT" dexCreateJIT :: IO (Ptr JIT)
Expand Down
11 changes: 10 additions & 1 deletion src/Dex/Foreign/Serialize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

module Dex.Foreign.Serialize (
CAtom,
dexPrint, dexToCAtom
dexPrint, dexToCAtom, dexFromCAtom
) where

import Data.Word
Expand Down Expand Up @@ -79,3 +79,12 @@ dexToCAtom atomPtr resultPtr = do
_ -> notSerializable
where
notSerializable = setError "Unserializable atom" $> 0

dexFromCAtom :: Ptr CAtom -> IO (Ptr Atom)
dexFromCAtom catomPtr = do
catom <- peek catomPtr
case catom of
CLit lit -> toStablePtr $ Con $ Lit lit
CRectArray _ _ _ -> unsupported
where
unsupported = setError "Unsupported CAtom" $> nullPtr

0 comments on commit d3ca6fc

Please sign in to comment.