diff --git a/requirements.txt b/requirements.txt index 90d0d76..d376ca3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ asyncpg==0.30.0 -streamflow==0.2.0.dev12 \ No newline at end of file +streamflow @ git+https://github.com/alpha-unito/streamflow@master \ No newline at end of file diff --git a/streamflow/plugins/unito/postgresql/database.py b/streamflow/plugins/unito/postgresql/database.py index 8ed6c30..281a731 100644 --- a/streamflow/plugins/unito/postgresql/database.py +++ b/streamflow/plugins/unito/postgresql/database.py @@ -242,12 +242,17 @@ async def add_target( ) async def add_token( - self, tag: str, type: type[Token], value: Any, port: int | None = None + self, + tag: str, + type: type[Token], + value: Any, + port: int | None = None, + recoverable: bool = False, ) -> int: async with self.pool as pool: async with pool.acquire() as conn: async with conn.transaction(): - return await conn.fetchval( + token_id = await conn.fetchval( "INSERT INTO token(port, type, tag, value) " "VALUES($1, $2, $3, $4) " "RETURNING id", @@ -256,6 +261,11 @@ async def add_token( tag, bytearray(value, "utf-8"), ) + if recoverable: + await conn.fetchval( + "INSERT INTO recoverable(id) VALUES($1)", token_id + ) + return token_id async def add_workflow( self, @@ -460,7 +470,13 @@ async def get_target(self, target_id: int) -> MutableMapping[str, Any]: async def get_token(self, token_id: int) -> MutableMapping[str, Any]: async with self.pool as pool: async with pool.acquire() as conn: - row = await conn.fetchrow("SELECT * FROM token WHERE id = $1", token_id) + row = await conn.fetchrow( + "SELECT *, " + "EXISTS(SELECT 1 FROM recoverable AS r WHERE r.id =$1) AS recoverable " + "FROM token " + "WHERE id =$1", + token_id, + ) return { k: bytearray(v) if isinstance(v, memoryview) else v for k, v in row.items() diff --git a/streamflow/plugins/unito/postgresql/schemas/postgresql.sql b/streamflow/plugins/unito/postgresql/schemas/postgresql.sql index 93100ef..60fb2da 100644 --- a/streamflow/plugins/unito/postgresql/schemas/postgresql.sql +++ b/streamflow/plugins/unito/postgresql/schemas/postgresql.sql @@ -67,6 +67,13 @@ CREATE TABLE IF NOT EXISTS token ); +CREATE TABLE IF NOT EXISTS recoverable +( + id SERIAL PRIMARY KEY, + FOREIGN KEY (id) REFERENCES token (id) +); + + CREATE TABLE IF NOT EXISTS provenance ( dependee INTEGER, diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 63dcbe7..d622898 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -252,9 +252,12 @@ async def test_local_target(context: StreamFlowContext): @pytest.mark.asyncio -async def test_token(context: StreamFlowContext): +@pytest.mark.parametrize("recoverable", ["recoverable", "unrecoverable"]) +async def test_token(context: StreamFlowContext, recoverable: str): """Test saving and loading Token from database""" - token = Token(value=["test", "token"]) + token = Token( + value=["test", "token"], tag="0.0", recoverable=recoverable == "recoverable" + ) await save_load_and_test(token, context)