diff --git a/python/dex/__init__.py b/python/dex/__init__.py index 46ba9e088..c6359ccc4 100644 --- a/python/dex/__init__.py +++ b/python/dex/__init__.py @@ -33,7 +33,7 @@ def __getattr__(self, name): result = api.lookup(self, api.as_cstr(name)) if not result: api.raise_from_dex() - return Atom(result, self) + return Atom._from_ptr(result, self) class Prelude(Module): @@ -52,15 +52,35 @@ def eval(expr: str, module=prelude, _env=None): result = api.evalExpr(_env, api.as_cstr(expr)) if not result: api.raise_from_dex() - return Atom(result, module) + return Atom._from_ptr(result, module) class Atom: __slots__ = ('__weakref__', '_as_parameter_', 'module') - def __init__(self, ptr, module): + def __init__(self, value): + catom = api.CAtom() + if isinstance(value, int): + catom.tag = 0 + catom.value.tag = 1 + catom.value.value = ctypes.c_int(value) + elif isinstance(value, float): + catom.tag = 0 + catom.value.tag = 4 + catom.value.value = ctypes.c_float(value) + else: + raise ValueError("Can't convert given value to a Dex Atom") + self.module = prelude + self._as_parameter_ = api.fromCAtom(ctypes.pointer(catom)) + if not self._as_parameter_: + api.raise_from_dex() + + @classmethod + def _from_ptr(cls, ptr, module): + self = super().__new__(cls) self._as_parameter_ = ptr self.module = module + return self def __del__(self): # TODO: Free diff --git a/python/dex/api.py b/python/dex/api.py index 9bd272230..910137ed0 100644 --- a/python/dex/api.py +++ b/python/dex/api.py @@ -17,11 +17,22 @@ def tagged_union(name: str, members: List[type]): payload = type(name + "Payload", (ctypes.Union,), {"_fields_": named_members}) union = type(name, (ctypes.Structure,), { "_fields_": [("tag", ctypes.c_uint64), ("payload", payload)], - "value": property(lambda self: getattr(self.payload, f"t{self.tag}")), + "value": property( + fget=lambda self: getattr(self.payload, f"t{self.tag}"), + fset=lambda self, value: setattr(self.payload, f"t{self.tag}", value)), + "Payload": payload, }) return union -CLit = tagged_union("Lit", [ctypes.c_int64, ctypes.c_int32, ctypes.c_int8, ctypes.c_double, ctypes.c_float]) +CLit = tagged_union("Lit", [ + ctypes.c_int64, + ctypes.c_int32, + ctypes.c_uint8, + ctypes.c_double, + ctypes.c_float, + ctypes.c_uint32, + ctypes.c_uint64 +]) class CRectArray(ctypes.Structure): _fields_ = [("data", ctypes.c_void_p), ("shape_ptr", ctypes.POINTER(ctypes.c_int64)), @@ -65,8 +76,9 @@ def dex_func(name, *signature): evalExpr = dex_func('dexEvalExpr', HsContextPtr, ctypes.c_char_p, HsAtomPtr) lookup = dex_func('dexLookup', HsContextPtr, ctypes.c_char_p, HsAtomPtr) -print = dex_func('dexPrint', HsAtomPtr, ctypes.c_char_p) -toCAtom = dex_func('dexToCAtom', HsAtomPtr, CAtomPtr, ctypes.c_int) +print = dex_func('dexPrint', HsAtomPtr, ctypes.c_char_p) +toCAtom = dex_func('dexToCAtom', HsAtomPtr, CAtomPtr, ctypes.c_int) +fromCAtom = dex_func('dexFromCAtom', CAtomPtr, HsAtomPtr) createJIT = dex_func('dexCreateJIT', HsJITPtr) destroyJIT = dex_func('dexDestroyJIT', HsJITPtr, None) diff --git a/python/tests/api_test.py b/python/tests/api_test.py index 9a7c0d353..96f806d49 100644 --- a/python/tests/api_test.py +++ b/python/tests/api_test.py @@ -39,6 +39,8 @@ def addOne (x: Float) : Float = x + 1.0 def test_scalar_conversions(self): assert float(dex.eval("5.0")) == 5.0 assert int(dex.eval("5")) == 5 + assert str(dex.Atom(5)) == "5" + assert str(dex.Atom(5.0)) == "5." if __name__ == "__main__": diff --git a/src/Dex/Foreign/API.hs b/src/Dex/Foreign/API.hs index f6c8349c7..70d15571c 100644 --- a/src/Dex/Foreign/API.hs +++ b/src/Dex/Foreign/API.hs @@ -31,8 +31,9 @@ foreign export ccall "dexEvalExpr" dexEvalExpr :: Ptr Context -> CString -> IO foreign export ccall "dexLookup" dexLookup :: Ptr Context -> CString -> IO (Ptr Atom) -- Serialization -foreign export ccall "dexPrint" dexPrint :: Ptr Atom -> IO CString -foreign export ccall "dexToCAtom" dexToCAtom :: Ptr Atom -> Ptr CAtom -> IO CInt +foreign export ccall "dexPrint" dexPrint :: Ptr Atom -> IO CString +foreign export ccall "dexToCAtom" dexToCAtom :: Ptr Atom -> Ptr CAtom -> IO CInt +foreign export ccall "dexFromCAtom" dexFromCAtom :: Ptr CAtom -> IO (Ptr Atom) -- JIT foreign export ccall "dexCreateJIT" dexCreateJIT :: IO (Ptr JIT) diff --git a/src/Dex/Foreign/Serialize.hs b/src/Dex/Foreign/Serialize.hs index b2d3a223f..f97e37982 100644 --- a/src/Dex/Foreign/Serialize.hs +++ b/src/Dex/Foreign/Serialize.hs @@ -6,7 +6,7 @@ module Dex.Foreign.Serialize ( CAtom, - dexPrint, dexToCAtom + dexPrint, dexToCAtom, dexFromCAtom ) where import Data.Word @@ -79,3 +79,12 @@ dexToCAtom atomPtr resultPtr = do _ -> notSerializable where notSerializable = setError "Unserializable atom" $> 0 + +dexFromCAtom :: Ptr CAtom -> IO (Ptr Atom) +dexFromCAtom catomPtr = do + catom <- peek catomPtr + case catom of + CLit lit -> toStablePtr $ Con $ Lit lit + CRectArray _ _ _ -> unsupported + where + unsupported = setError "Unsupported CAtom" $> nullPtr