@@ -57,7 +57,6 @@ def put(
5757 metadata : CheckpointMetadata ,
5858 new_versions : ChannelVersions ,
5959 ) -> RunnableConfig :
60- """Store a checkpoint to Redis with separate blob storage."""
6160 thread_id = config ['configurable' ]['thread_id' ]
6261 checkpoint_ns = config ['configurable' ].get ('checkpoint_ns' , '' )
6362 config_checkpoint_id = config ['configurable' ].get ('checkpoint_id' , '' )
@@ -74,9 +73,9 @@ def put(
7473 parent_checkpoint_id = config_checkpoint_id
7574 checkpoint_id = checkpoint ['id' ]
7675
77- storage_safe_thread_id = self ._safe_redis_id (thread_id )
78- storage_safe_checkpoint_ns = self ._safe_redis_ns (checkpoint_ns )
79- storage_safe_checkpoint_id = self ._safe_redis_id (checkpoint_id )
76+ storage_safe_thread_id = self ._safe_id (thread_id )
77+ storage_safe_checkpoint_ns = self ._safe_ns (checkpoint_ns )
78+ storage_safe_checkpoint_id = self ._safe_id (checkpoint_id )
8079
8180 copy = checkpoint .copy ()
8281 next_config = {
@@ -114,7 +113,7 @@ def put(
114113 checkpoint_data ['source' ] = metadata ['source' ]
115114 checkpoint_data ['step' ] = metadata ['step' ]
116115
117- checkpoint_key = self ._make_redis_checkpoint_key (
116+ checkpoint_key = self ._make_safe_checkpoint_key (
118117 thread_id = thread_id , checkpoint_ns = checkpoint_ns , checkpoint_id = checkpoint_id
119118 )
120119
@@ -142,16 +141,16 @@ def put_writes(
142141 thread_id = config ['configurable' ]['thread_id' ]
143142 checkpoint_ns = config ['configurable' ].get ('checkpoint_ns' , '' )
144143 checkpoint_id = config ['configurable' ].get ('checkpoint_id' , '' )
145- storage_safe_thread_id = (self ._safe_redis_id (thread_id ),)
146- storage_safe_checkpoint_ns = self ._safe_redis_ns (checkpoint_ns )
144+ storage_safe_thread_id = (self ._safe_id (thread_id ),)
145+ storage_safe_checkpoint_ns = self ._safe_ns (checkpoint_ns )
147146
148147 writes_objects : List [Dict [str , Any ]] = []
149148 for idx , (channel , value ) in enumerate (writes ):
150149 type_ , blob = self .serde .dumps_typed (value )
151150 write_obj : Dict [str , Any ] = {
152151 'thread_id' : storage_safe_thread_id ,
153152 'checkpoint_ns' : storage_safe_checkpoint_ns ,
154- 'checkpoint_id' : self ._safe_redis_id (checkpoint_id ),
153+ 'checkpoint_id' : self ._safe_id (checkpoint_id ),
155154 'task_id' : task_id ,
156155 'task_path' : task_path ,
157156 'idx' : WRITES_IDX_MAP .get (channel , idx ),
@@ -164,13 +163,13 @@ def put_writes(
164163 for write_obj in writes_objects :
165164 idx_value = write_obj ['idx' ]
166165 assert isinstance (idx_value , int )
167- key = self ._make_redis_checkpoint_key (
166+ key = self ._make_safe_checkpoint_key (
168167 thread_id = thread_id , checkpoint_ns = checkpoint_ns , checkpoint_id = checkpoint_id
169168 )
170169
171170 self .client .save_state (store_name = self .store_name , key = key , value = json .dumps (write_obj ))
172171
173- checkpoint_key = self ._make_redis_checkpoint_key (
172+ checkpoint_key = self ._make_safe_checkpoint_key (
174173 thread_id = thread_id , checkpoint_ns = checkpoint_ns , checkpoint_id = checkpoint_id
175174 )
176175
@@ -234,8 +233,8 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
234233 thread_id = config ['configurable' ]['thread_id' ]
235234 checkpoint_ns = config ['configurable' ].get ('checkpoint_ns' , '' )
236235
237- storage_safe_thread_id = self ._safe_redis_id (thread_id )
238- storage_safe_checkpoint_ns = self ._safe_redis_ns (checkpoint_ns )
236+ storage_safe_thread_id = self ._safe_id (thread_id )
237+ storage_safe_checkpoint_ns = self ._safe_ns (checkpoint_ns )
239238
240239 key = ':' .join (
241240 [
@@ -252,10 +251,11 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
252251
253252 # To then derive the checkpoint data
254253 checkpoint_data = self .client .get_state (
255- store_name = self .store_name , key =
256- # checkpoint_key.data can either be str or bytes
257- checkpoint_key .data .decode () if isinstance (checkpoint_key .data , bytes ) else
258- checkpoint_key .data
254+ store_name = self .store_name ,
255+ # checkpoint_key.data can either be str or bytes
256+ key = checkpoint_key .data .decode ()
257+ if isinstance (checkpoint_key .data , bytes )
258+ else checkpoint_key .data ,
259259 )
260260
261261 if not checkpoint_data .data :
@@ -315,10 +315,10 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
315315 pending_writes = [],
316316 )
317317
318- def _safe_redis_id (self , id ) -> str :
318+ def _safe_id (self , id ) -> str :
319319 return '00000000-0000-0000-0000-000000000000' if id == '' else id
320320
321- def _safe_redis_ns (self , ns ) -> str :
321+ def _safe_ns (self , ns ) -> str :
322322 return '__empty__' if ns == '' else ns
323323
324324 def _convert_checkpoint_message (self , msg_item ):
@@ -390,7 +390,7 @@ def _dump_metadata(self, metadata: CheckpointMetadata) -> str:
390390 _ , serialized_bytes = self .serde .dumps_typed (metadata )
391391 return serialized_bytes
392392
393- def _make_redis_checkpoint_key (
393+ def _make_safe_checkpoint_key (
394394 self ,
395395 thread_id : str ,
396396 checkpoint_ns : str ,
0 commit comments