Skip to content

Support decode_responses #124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 44 additions & 26 deletions mockredis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self,
load_lua_dependencies=True,
blocking_timeout=1000,
blocking_sleep_interval=0.01,
decode_responses=False,
**kwargs):
"""
Initialize as either StrictRedis or Redis.
Expand All @@ -59,6 +60,7 @@ def __init__(self,
self.pubsub = defaultdict(list)
# Dictionary from script to sha ''Script''
self.shas = dict()
self.decode_responses = decode_responses

@classmethod
def from_url(cls, url, db=None, **kwargs):
Expand All @@ -70,7 +72,7 @@ def echo(self, msg):
return self._encode(msg)

def ping(self):
return b'PONG'
return 'PONG' if self.decode_responses else b'PONG'

# Transactions Functions #

Expand Down Expand Up @@ -137,19 +139,24 @@ def execute(self):
def type(self, key):
key = self._encode(key)
if key not in self.redis:
return b'none'
type_ = type(self.redis[key])
if type_ is dict:
return b'hash'
elif type_ is str:
return b'string'
elif type_ is set:
return b'set'
elif type_ is list:
return b'list'
elif type_ is SortedSet:
return b'zset'
raise TypeError("unhandled type {}".format(type_))
res = b'none'
else:
type_ = type(self.redis[key])
if type_ is dict:
res = b'hash'
elif type_ is str:
res = b'string'
elif type_ is set:
res = b'set'
elif type_ is list:
res = b'list'
elif type_ is SortedSet:
res = b'zset'
else:
raise TypeError("unhandled type {}".format(type_))
if self.decode_responses:
return res.decode('utf-8')
return res

def keys(self, pattern='*'):
"""Emulate keys."""
Expand All @@ -166,7 +173,9 @@ def keys(self, pattern='*'):
regex = re.compile(re.sub(r'(^|[^\\])\.', r'\1[^/]', regex))

# Find every key that matches the pattern
return [key for key in self.redis.keys() if regex.match(key.decode('utf-8'))]
return [key for key in self.redis.keys()
if regex.match(key if self.decode_responses
else key.decode('utf-8'))]

def delete(self, *keys):
"""Emulate delete."""
Expand Down Expand Up @@ -815,15 +824,17 @@ def sort(self, name,
return []

by = self._encode(by) if by is not None else by
by_special = dict(zip(('*', 'nosort', '#'),
[self._encode(x) for x in (b'*', b'nosort', b'#')]))
# always organize the items as tuples of the value from the list and the sort key
if by and b'*' in by:
items = [(i, self.get(by.replace(b'*', self._encode(i)))) for i in items]
elif by in [None, b'nosort']:
if by and by_special['*'] in by:
items = [(i, self.get(by.replace(by_special['*'], self._encode(i)))) for i in items]
elif by in [None, by_special['nosort']]:
items = [(i, i) for i in items]
else:
raise ValueError('invalid value for "by": %s' % by)

if by != b'nosort':
if by != by_special['nosort']:
# if sorting, do alpha sort or float (default) and take desc flag into account
sort_type = alpha and str or float
items.sort(key=lambda x: sort_type(x[1]), reverse=bool(desc))
Expand All @@ -835,10 +846,10 @@ def sort(self, name,
# always deal with get specifiers as a list
get = [get]
for g in map(self._encode, get):
if g == b'#':
if g == by_special['#']:
results.append([self.get(i) for i in items])
else:
results.append([self.get(g.replace(b'*', self._encode(i[0]))) for i in items])
results.append([self.get(g.replace(by_special['*'], self._encode(i[0]))) for i in items])
else:
# if not using GET then returning just the item itself
results.append([i[0] for i in items])
Expand Down Expand Up @@ -896,7 +907,10 @@ def _common_scan(self, values_function, cursor='0', match=None, count=10, key=No
values = values[cursor:cursor+count]

if match is not None:
regex = re.compile(b'^' + re.escape(self._encode(match)).replace(b'\\*', b'.*') + b'$')
m_special = dict(zip(('^', '*', '.*', '$'),
[self._encode(x) for x in (b'^', b'\\*', b'.*', b'$')]))
regex = re.compile(m_special['^'] + re.escape(self._encode(match)).replace(m_special['*'],
m_special['.*']) + m_special['$'])
if not key:
key = lambda v: v
values = [v for v in values if regex.match(key(v))]
Expand Down Expand Up @@ -1461,7 +1475,8 @@ def _get_by_type(self, key, operation, create, type_, default, return_default=Tr
Get (and maybe create) a redis data structure by name and type.
"""
key = self._encode(key)
if self.type(key) in [type_, b'none']:
keys = [self._encode(x) for x in (type_, b'none')]
if self.type(key) in keys:
if create:
return self.redis.setdefault(key, default)
else:
Expand Down Expand Up @@ -1545,7 +1560,7 @@ def _score_inclusive(self, score):
def _encode(self, value):
"Return a bytestring representation of the value. Taken from redis-py connection.py"
if isinstance(value, bytes):
return value
value = value
elif isinstance(value, (int, long)):
value = str(value).encode('utf-8')
elif isinstance(value, float):
Expand All @@ -1554,6 +1569,9 @@ def _encode(self, value):
value = str(value).encode('utf-8')
else:
value = value.encode('utf-8', 'strict')

if self.decode_responses:
return value.decode('utf-8')
return value


Expand All @@ -1567,7 +1585,7 @@ def mock_redis_client(**kwargs):
can return a MockRedis object
instead of a Redis object.
"""
return MockRedis()
return MockRedis(**kwargs)

mock_redis_client.from_url = mock_redis_client

Expand All @@ -1578,6 +1596,6 @@ def mock_strict_redis_client(**kwargs):
can return a MockRedis object
instead of a StrictRedis object.
"""
return MockRedis(strict=True)
return MockRedis(strict=True, **kwargs)

mock_strict_redis_client.from_url = mock_strict_redis_client