diff --git a/mockredis/client.py b/mockredis/client.py index 926e048..b615b37 100644 --- a/mockredis/client.py +++ b/mockredis/client.py @@ -40,6 +40,7 @@ def __init__(self, load_lua_dependencies=True, blocking_timeout=1000, blocking_sleep_interval=0.01, + decode_responses=False, **kwargs): """ Initialize as either StrictRedis or Redis. @@ -59,6 +60,7 @@ def __init__(self, self.pubsub = defaultdict(list) # Dictionary from script to sha ''Script'' self.shas = dict() + self.decode_responses = decode_responses @classmethod def from_url(cls, url, db=None, **kwargs): @@ -70,7 +72,7 @@ def echo(self, msg): return self._encode(msg) def ping(self): - return b'PONG' + return 'PONG' if self.decode_responses else b'PONG' # Transactions Functions # @@ -137,19 +139,24 @@ def execute(self): def type(self, key): key = self._encode(key) if key not in self.redis: - return b'none' - type_ = type(self.redis[key]) - if type_ is dict: - return b'hash' - elif type_ is str: - return b'string' - elif type_ is set: - return b'set' - elif type_ is list: - return b'list' - elif type_ is SortedSet: - return b'zset' - raise TypeError("unhandled type {}".format(type_)) + res = b'none' + else: + type_ = type(self.redis[key]) + if type_ is dict: + res = b'hash' + elif type_ is str: + res = b'string' + elif type_ is set: + res = b'set' + elif type_ is list: + res = b'list' + elif type_ is SortedSet: + res = b'zset' + else: + raise TypeError("unhandled type {}".format(type_)) + if self.decode_responses: + return res.decode('utf-8') + return res def keys(self, pattern='*'): """Emulate keys.""" @@ -166,7 +173,9 @@ def keys(self, pattern='*'): regex = re.compile(re.sub(r'(^|[^\\])\.', r'\1[^/]', regex)) # Find every key that matches the pattern - return [key for key in self.redis.keys() if regex.match(key.decode('utf-8'))] + return [key for key in self.redis.keys() + if regex.match(key if self.decode_responses + else key.decode('utf-8'))] def delete(self, *keys): """Emulate delete.""" @@ -815,15 +824,17 @@ def sort(self, name, return [] by = self._encode(by) if by is not None else by + by_special = dict(zip(('*', 'nosort', '#'), + [self._encode(x) for x in (b'*', b'nosort', b'#')])) # always organize the items as tuples of the value from the list and the sort key - if by and b'*' in by: - items = [(i, self.get(by.replace(b'*', self._encode(i)))) for i in items] - elif by in [None, b'nosort']: + if by and by_special['*'] in by: + items = [(i, self.get(by.replace(by_special['*'], self._encode(i)))) for i in items] + elif by in [None, by_special['nosort']]: items = [(i, i) for i in items] else: raise ValueError('invalid value for "by": %s' % by) - if by != b'nosort': + if by != by_special['nosort']: # if sorting, do alpha sort or float (default) and take desc flag into account sort_type = alpha and str or float items.sort(key=lambda x: sort_type(x[1]), reverse=bool(desc)) @@ -835,10 +846,10 @@ def sort(self, name, # always deal with get specifiers as a list get = [get] for g in map(self._encode, get): - if g == b'#': + if g == by_special['#']: results.append([self.get(i) for i in items]) else: - results.append([self.get(g.replace(b'*', self._encode(i[0]))) for i in items]) + results.append([self.get(g.replace(by_special['*'], self._encode(i[0]))) for i in items]) else: # if not using GET then returning just the item itself results.append([i[0] for i in items]) @@ -896,7 +907,10 @@ def _common_scan(self, values_function, cursor='0', match=None, count=10, key=No values = values[cursor:cursor+count] if match is not None: - regex = re.compile(b'^' + re.escape(self._encode(match)).replace(b'\\*', b'.*') + b'$') + m_special = dict(zip(('^', '*', '.*', '$'), + [self._encode(x) for x in (b'^', b'\\*', b'.*', b'$')])) + regex = re.compile(m_special['^'] + re.escape(self._encode(match)).replace(m_special['*'], + m_special['.*']) + m_special['$']) if not key: key = lambda v: v values = [v for v in values if regex.match(key(v))] @@ -1461,7 +1475,8 @@ def _get_by_type(self, key, operation, create, type_, default, return_default=Tr Get (and maybe create) a redis data structure by name and type. """ key = self._encode(key) - if self.type(key) in [type_, b'none']: + keys = [self._encode(x) for x in (type_, b'none')] + if self.type(key) in keys: if create: return self.redis.setdefault(key, default) else: @@ -1545,7 +1560,7 @@ def _score_inclusive(self, score): def _encode(self, value): "Return a bytestring representation of the value. Taken from redis-py connection.py" if isinstance(value, bytes): - return value + value = value elif isinstance(value, (int, long)): value = str(value).encode('utf-8') elif isinstance(value, float): @@ -1554,6 +1569,9 @@ def _encode(self, value): value = str(value).encode('utf-8') else: value = value.encode('utf-8', 'strict') + + if self.decode_responses: + return value.decode('utf-8') return value @@ -1567,7 +1585,7 @@ def mock_redis_client(**kwargs): can return a MockRedis object instead of a Redis object. """ - return MockRedis() + return MockRedis(**kwargs) mock_redis_client.from_url = mock_redis_client @@ -1578,6 +1596,6 @@ def mock_strict_redis_client(**kwargs): can return a MockRedis object instead of a StrictRedis object. """ - return MockRedis(strict=True) + return MockRedis(strict=True, **kwargs) mock_strict_redis_client.from_url = mock_strict_redis_client