Skip to content

Commit fccd1ce

Browse files
committed
fix: remove refs to redis
1 parent 424db24 commit fccd1ce

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

ext/dapr-ext-langgraph/dapr/ext/langgraph/dapr_checkpointer.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ def put(
5757
metadata: CheckpointMetadata,
5858
new_versions: ChannelVersions,
5959
) -> RunnableConfig:
60-
"""Store a checkpoint to Redis with separate blob storage."""
6160
thread_id = config['configurable']['thread_id']
6261
checkpoint_ns = config['configurable'].get('checkpoint_ns', '')
6362
config_checkpoint_id = config['configurable'].get('checkpoint_id', '')
@@ -74,9 +73,9 @@ def put(
7473
parent_checkpoint_id = config_checkpoint_id
7574
checkpoint_id = checkpoint['id']
7675

77-
storage_safe_thread_id = self._safe_redis_id(thread_id)
78-
storage_safe_checkpoint_ns = self._safe_redis_ns(checkpoint_ns)
79-
storage_safe_checkpoint_id = self._safe_redis_id(checkpoint_id)
76+
storage_safe_thread_id = self._safe_id(thread_id)
77+
storage_safe_checkpoint_ns = self._safe_ns(checkpoint_ns)
78+
storage_safe_checkpoint_id = self._safe_id(checkpoint_id)
8079

8180
copy = checkpoint.copy()
8281
next_config = {
@@ -114,7 +113,7 @@ def put(
114113
checkpoint_data['source'] = metadata['source']
115114
checkpoint_data['step'] = metadata['step']
116115

117-
checkpoint_key = self._make_redis_checkpoint_key(
116+
checkpoint_key = self._make_safe_checkpoint_key(
118117
thread_id=thread_id, checkpoint_ns=checkpoint_ns, checkpoint_id=checkpoint_id
119118
)
120119

@@ -142,16 +141,16 @@ def put_writes(
142141
thread_id = config['configurable']['thread_id']
143142
checkpoint_ns = config['configurable'].get('checkpoint_ns', '')
144143
checkpoint_id = config['configurable'].get('checkpoint_id', '')
145-
storage_safe_thread_id = (self._safe_redis_id(thread_id),)
146-
storage_safe_checkpoint_ns = self._safe_redis_ns(checkpoint_ns)
144+
storage_safe_thread_id = (self._safe_id(thread_id),)
145+
storage_safe_checkpoint_ns = self._safe_ns(checkpoint_ns)
147146

148147
writes_objects: List[Dict[str, Any]] = []
149148
for idx, (channel, value) in enumerate(writes):
150149
type_, blob = self.serde.dumps_typed(value)
151150
write_obj: Dict[str, Any] = {
152151
'thread_id': storage_safe_thread_id,
153152
'checkpoint_ns': storage_safe_checkpoint_ns,
154-
'checkpoint_id': self._safe_redis_id(checkpoint_id),
153+
'checkpoint_id': self._safe_id(checkpoint_id),
155154
'task_id': task_id,
156155
'task_path': task_path,
157156
'idx': WRITES_IDX_MAP.get(channel, idx),
@@ -164,13 +163,13 @@ def put_writes(
164163
for write_obj in writes_objects:
165164
idx_value = write_obj['idx']
166165
assert isinstance(idx_value, int)
167-
key = self._make_redis_checkpoint_key(
166+
key = self._make_safe_checkpoint_key(
168167
thread_id=thread_id, checkpoint_ns=checkpoint_ns, checkpoint_id=checkpoint_id
169168
)
170169

171170
self.client.save_state(store_name=self.store_name, key=key, value=json.dumps(write_obj))
172171

173-
checkpoint_key = self._make_redis_checkpoint_key(
172+
checkpoint_key = self._make_safe_checkpoint_key(
174173
thread_id=thread_id, checkpoint_ns=checkpoint_ns, checkpoint_id=checkpoint_id
175174
)
176175

@@ -234,8 +233,8 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
234233
thread_id = config['configurable']['thread_id']
235234
checkpoint_ns = config['configurable'].get('checkpoint_ns', '')
236235

237-
storage_safe_thread_id = self._safe_redis_id(thread_id)
238-
storage_safe_checkpoint_ns = self._safe_redis_ns(checkpoint_ns)
236+
storage_safe_thread_id = self._safe_id(thread_id)
237+
storage_safe_checkpoint_ns = self._safe_ns(checkpoint_ns)
239238

240239
key = ':'.join(
241240
[
@@ -252,10 +251,11 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
252251

253252
# To then derive the checkpoint data
254253
checkpoint_data = self.client.get_state(
255-
store_name=self.store_name, key=
256-
# checkpoint_key.data can either be str or bytes
257-
checkpoint_key.data.decode() if isinstance(checkpoint_key.data, bytes) else
258-
checkpoint_key.data
254+
store_name=self.store_name,
255+
# checkpoint_key.data can either be str or bytes
256+
key=checkpoint_key.data.decode()
257+
if isinstance(checkpoint_key.data, bytes)
258+
else checkpoint_key.data,
259259
)
260260

261261
if not checkpoint_data.data:
@@ -315,10 +315,10 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
315315
pending_writes=[],
316316
)
317317

318-
def _safe_redis_id(self, id) -> str:
318+
def _safe_id(self, id) -> str:
319319
return '00000000-0000-0000-0000-000000000000' if id == '' else id
320320

321-
def _safe_redis_ns(self, ns) -> str:
321+
def _safe_ns(self, ns) -> str:
322322
return '__empty__' if ns == '' else ns
323323

324324
def _convert_checkpoint_message(self, msg_item):
@@ -390,7 +390,7 @@ def _dump_metadata(self, metadata: CheckpointMetadata) -> str:
390390
_, serialized_bytes = self.serde.dumps_typed(metadata)
391391
return serialized_bytes
392392

393-
def _make_redis_checkpoint_key(
393+
def _make_safe_checkpoint_key(
394394
self,
395395
thread_id: str,
396396
checkpoint_ns: str,

0 commit comments

Comments
 (0)