Skip to content

Commit 658169d

Browse files
committed
Adds custom encoder/decoder
1 parent 3ae4e43 commit 658169d

File tree

3 files changed

+119
-64
lines changed

3 files changed

+119
-64
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@
66
!**/*.md
77

88
.env
9+
.vscode
910

1011
*.pyc

rejson/client.py

Lines changed: 74 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,22 @@ def long_or_none(r):
2525
return long(r)
2626
return r
2727

28-
def json_or_none(r):
28+
def json_or_none(d):
2929
"Return a deserialized JSON object or None"
30-
if r:
31-
return json.loads(r)
32-
return r
30+
def _f(r):
31+
if r:
32+
return d(r)
33+
return r
34+
return _f
3335

34-
def bulk_of_jsons(b):
36+
def bulk_of_jsons(d):
3537
"Replace serialized JSON values with objects in a bulk array response (list)"
36-
for index, item in enumerate(b):
37-
if item is not None:
38-
b[index] = json.loads(item)
39-
return b
38+
def _f(b):
39+
for index, item in enumerate(b):
40+
if item is not None:
41+
b[index] = d(item)
42+
return b
43+
return _f
4044

4145
class Client(StrictRedis):
4246
"""
@@ -51,49 +55,55 @@ class Client(StrictRedis):
5155
'ver': 1
5256
}
5357

54-
MODULE_CALLBACKS = {
55-
'JSON.DEL': long,
56-
'JSON.GET': json_or_none,
57-
'JSON.MGET': bulk_of_jsons,
58-
'JSON.SET': lambda r: r and nativestr(r) == 'OK',
59-
'JSON.NUMINCRBY': float_or_long,
60-
'JSON.NUMMULTBY': float_or_long,
61-
'JSON.STRAPPEND': long_or_none,
62-
'JSON.STRLEN': long_or_none,
63-
'JSON.ARRAPPEND': long_or_none,
64-
'JSON.ARRINDEX': long_or_none,
65-
'JSON.ARRINSERT': long_or_none,
66-
'JSON.ARRLEN': long_or_none,
67-
'JSON.ARRPOP': json_or_none,
68-
'JSON.ARRTRIM': long_or_none,
69-
'JSON.OBJLEN': long_or_none,
70-
}
58+
_encoder = None
59+
_encode = None
60+
_decoder = None
61+
_decode = None
62+
63+
def __init__(self, encoder=None, decoder=None, *args, **kwargs):
64+
"""
65+
Creates a new ReJSON client
66+
"""
67+
self.setEncoder(encoder)
68+
self.setDecoder(decoder)
69+
StrictRedis.__init__(self, *args, **kwargs)
7170

72-
def __init__(self, *args, **kwargs):
73-
super(Client, self).__init__(*args, **kwargs)
74-
self.__checkPrerequirements()
7571
# Set the module commands' callbacks
76-
for k, v in self.MODULE_CALLBACKS.iteritems():
72+
MODULE_CALLBACKS = {
73+
'JSON.DEL': long,
74+
'JSON.GET': json_or_none(self._decode),
75+
'JSON.MGET': bulk_of_jsons(self._decode),
76+
'JSON.SET': lambda r: r and nativestr(r) == 'OK',
77+
'JSON.NUMINCRBY': float_or_long,
78+
'JSON.NUMMULTBY': float_or_long,
79+
'JSON.STRAPPEND': long_or_none,
80+
'JSON.STRLEN': long_or_none,
81+
'JSON.ARRAPPEND': long_or_none,
82+
'JSON.ARRINDEX': long_or_none,
83+
'JSON.ARRINSERT': long_or_none,
84+
'JSON.ARRLEN': long_or_none,
85+
'JSON.ARRPOP': json_or_none(self._decode),
86+
'JSON.ARRTRIM': long_or_none,
87+
'JSON.OBJLEN': long_or_none,
88+
}
89+
for k, v in MODULE_CALLBACKS.iteritems():
7790
self.set_response_callback(k, v)
91+
92+
def setEncoder(self, encoder):
93+
"Sets the encoder"
94+
if not encoder:
95+
self._encoder = json.JSONEncoder()
96+
else:
97+
self._encoder = encoder
98+
self._encode = self._encoder.encode
7899

79-
def __checkPrerequirements(self):
80-
"Checks that the module is ready"
81-
try:
82-
reply = self.execute_command('MODULE', 'LIST')
83-
except exceptions.ResponseError as e:
84-
if e.message.startswith('unknown command'):
85-
raise exceptions.RedisError('Modules are not supported '
86-
'on your Redis server - consider '
87-
'upgrading to a newer version.')
88-
finally:
89-
info = self.MODULE_INFO
90-
for r in reply:
91-
module = dict(zip(r[0::2], r[1::2]))
92-
if info['name'] == module['name'] and \
93-
info['ver'] <= module['ver']:
94-
return
95-
raise exceptions.RedisError('ReJSON is not loaded - follow the '
96-
'instructions at http://rejson.io')
100+
def setDecoder(self, decoder):
101+
"Sets the decoder"
102+
if not decoder:
103+
self._decoder = json.JSONDecoder()
104+
else:
105+
self._decoder = decoder
106+
self._decode = self._decoder.decode
97107

98108
def JSONDel(self, name, path=Path.rootPath()):
99109
"""
@@ -130,7 +140,8 @@ def JSONSet(self, name, path, obj, nx=False, xx=False):
130140
``nx`` if set to True, set ``value`` only if it does not exist
131141
``xx`` if set to True, set ``value`` only if it exists
132142
"""
133-
pieces = [name, str_path(path), json.dumps(obj)]
143+
pieces = [name, str_path(path), self._encode(obj)]
144+
134145
# Handle existential modifiers
135146
if nx and xx:
136147
raise Exception('nx and xx are mutually exclusive: use one, the '
@@ -152,21 +163,21 @@ def JSONNumIncrBy(self, name, path, number):
152163
Increments the numeric (integer or floating point) JSON value under
153164
``path`` at key ``name`` by the provided ``number``
154165
"""
155-
return self.execute_command('JSON.NUMINCRBY', name, str_path(path), json.dumps(number))
166+
return self.execute_command('JSON.NUMINCRBY', name, str_path(path), self._encode(number))
156167

157168
def JSONNumMultBy(self, name, path, number):
158169
"""
159170
Multiplies the numeric (integer or floating point) JSON value under
160171
``path`` at key ``name`` with the provided ``number``
161172
"""
162-
return self.execute_command('JSON.NUMMULTBY', name, str_path(path), json.dumps(number))
173+
return self.execute_command('JSON.NUMMULTBY', name, str_path(path), self._encode(number))
163174

164175
def JSONStrAppend(self, name, string, path=Path.rootPath()):
165176
"""
166177
Appends to the string JSON value under ``path`` at key ``name`` the
167178
provided ``string``
168179
"""
169-
return self.execute_command('JSON.STRAPPEND', name, str_path(path), json.dumps(string))
180+
return self.execute_command('JSON.STRAPPEND', name, str_path(path), self._encode(string))
170181

171182
def JSONStrLen(self, name, path=Path.rootPath()):
172183
"""
@@ -182,7 +193,7 @@ def JSONArrAppend(self, name, path=Path.rootPath(), *args):
182193
"""
183194
pieces = [name, str_path(path)]
184195
for o in args:
185-
pieces.append(json.dumps(o))
196+
pieces.append(self._encode(o))
186197
return self.execute_command('JSON.ARRAPPEND', *pieces)
187198

188199
def JSONArrIndex(self, name, path, scalar, start=0, stop=-1):
@@ -191,7 +202,7 @@ def JSONArrIndex(self, name, path, scalar, start=0, stop=-1):
191202
``name``. The search can be limited using the optional inclusive
192203
``start`` and exclusive ``stop`` indices.
193204
"""
194-
return self.execute_command('JSON.ARRINDEX', name, str_path(path), json.dumps(scalar), start, stop)
205+
return self.execute_command('JSON.ARRINDEX', name, str_path(path), self._encode(scalar), start, stop)
195206

196207
def JSONArrInsert(self, name, path, index, *args):
197208
"""
@@ -200,7 +211,7 @@ def JSONArrInsert(self, name, path, index, *args):
200211
"""
201212
pieces = [name, str_path(path), index]
202213
for o in args:
203-
pieces.append(json.dumps(o))
214+
pieces.append(self._encode(o))
204215
return self.execute_command('JSON.ARRINSERT', *pieces)
205216

206217
def JSONArrLen(self, name, path=Path.rootPath()):
@@ -246,12 +257,14 @@ def pipeline(self, transaction=True, shard_hint=None):
246257
atomic, pipelines are useful for reducing the back-and-forth overhead
247258
between the client and server.
248259
"""
249-
return Pipeline(
250-
self.connection_pool,
251-
self.response_callbacks,
252-
transaction,
253-
shard_hint)
260+
p = Pipeline(
261+
connection_pool=self.connection_pool,
262+
response_callbacks=self.response_callbacks,
263+
transaction=transaction,
264+
shard_hint=shard_hint)
265+
p.setEncoder(self._encoder)
266+
p.setDecoder(self._decoder)
267+
return p
254268

255269
class Pipeline(BasePipeline, Client):
256270
"Pipeline for ReJSONClient"
257-
pass

test/test.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import redis
23
from unittest import TestCase
34
from rejson import Client, Path
@@ -148,7 +149,7 @@ def testObjLenShouldSucceed(self):
148149
rj.JSONSet('obj', Path.rootPath(), obj)
149150
self.assertEqual(len(obj), rj.JSONObjLen('obj', Path.rootPath()))
150151

151-
def testPipeline(self):
152+
def testPipelineShouldSucceed(self):
152153
"Test pipeline"
153154
rj = Client()
154155
rj.flushdb()
@@ -158,8 +159,48 @@ def testPipeline(self):
158159
p.JSONGet('foo')
159160
p.JSONDel('foo')
160161
p.exists('foo')
161-
self.assertListEqual([ True, 'bar', 1, False], p.execute())
162-
162+
self.assertListEqual([ True, 'bar', 1, False ], p.execute())
163+
164+
def testCustomEncoderDecoderShouldSucceed(self):
165+
"Test a custom encoder and decoder"
166+
167+
class CustomClass(object):
168+
key = ''
169+
val = ''
170+
def __init__(self, k='', v=''):
171+
self.key = k
172+
self.val = v
173+
174+
class TestEncoder(json.JSONEncoder):
175+
def default(self, obj):
176+
if isinstance(obj, CustomClass):
177+
return 'CustomClass:{}:{}'.format(obj.key, obj.val)
178+
return json.JSONEncoder.encode(self, obj)
179+
180+
class TestDecoder(json.JSONDecoder):
181+
def decode(self, obj):
182+
d = json.JSONDecoder.decode(self, obj)
183+
if isinstance(d, basestring) and d.startswith('CustomClass:'):
184+
s = d.split(':')
185+
return CustomClass(k=s[1], v=s[2])
186+
return d
187+
188+
rj = Client(encoder=TestEncoder(), decoder=TestDecoder())
189+
rj.flushdb()
190+
191+
# Check a regular string
192+
self.assertTrue(rj.JSONSet('foo', Path.rootPath(), 'bar'))
193+
self.assertEqual('string', rj.JSONType('foo', Path.rootPath()))
194+
self.assertEqual('bar', rj.JSONGet('foo', Path.rootPath()))
195+
196+
# Check the custom encoder
197+
self.assertTrue(rj.JSONSet('cus', Path.rootPath(), CustomClass('foo', 'bar')))
198+
obj = rj.JSONGet('cus', Path.rootPath())
199+
self.assertIsNotNone(obj)
200+
self.assertEqual(CustomClass, obj.__class__)
201+
self.assertEqual('foo', obj.key)
202+
self.assertEqual('bar', obj.val)
203+
163204
def testUsageExampleShouldSucceed(self):
164205
"Test the usage example"
165206

0 commit comments

Comments
 (0)