Skip to content

Commit 8f7c906

Browse files
authored
Custom Claims and List Users Support (#86)
* Implemented set_custom_user_claims() function * Unit and integration tests for the new functionality * Implemented list_users() function * Tests for list users API * Updated tests * Implemented user iteration using Python's native iterable API * Cleaned up the test code * Moved user iteration logic to _user_mgt * Code cleanup * Updated the list users API by adding the ListUsersPage class * Updated error message
1 parent 11f70f0 commit 8f7c906

File tree

7 files changed

+658
-10
lines changed

7 files changed

+658
-10
lines changed

firebase_admin/_user_mgt.py

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Firebase user management sub module."""
1616

17+
import json
1718
import re
1819

1920
from google.auth import transport
@@ -29,9 +30,17 @@
2930
USER_CREATE_ERROR = 'USER_CREATE_ERROR'
3031
USER_UPDATE_ERROR = 'USER_UPDATE_ERROR'
3132
USER_DELETE_ERROR = 'USER_DELETE_ERROR'
33+
USER_DOWNLOAD_ERROR = 'LIST_USERS_ERROR'
3234

3335
ID_TOOLKIT_URL = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/'
3436

37+
MAX_LIST_USERS_RESULTS = 1000
38+
MAX_CLAIMS_PAYLOAD_SIZE = 1000
39+
RESERVED_CLAIMS = set([
40+
'acr', 'amr', 'at_hash', 'aud', 'auth_time', 'azp', 'cnf', 'c_hash', 'exp', 'iat',
41+
'iss', 'jti', 'nbf', 'nonce', 'sub', 'firebase',
42+
])
43+
3544

3645
class _Validator(object):
3746
"""A collectoin of data validation utilities.
@@ -118,6 +127,37 @@ def validate_delete_list(cls, delete_attr):
118127
'Invalid delete list: "{0}". Delete list must be a '
119128
'non-empty list.'.format(delete_attr))
120129

130+
@classmethod
131+
def validate_custom_claims(cls, custom_claims):
132+
"""Validates the specified custom claims.
133+
134+
Custom claims must be specified as a JSON string.The string must not exceed 1000
135+
characters, and the parsed JSON payload must not contain reserved JWT claims.
136+
"""
137+
if not isinstance(custom_claims, six.string_types) or not custom_claims:
138+
raise ValueError(
139+
'Invalid custom claims: "{0}". Custom claims must be a non-empty JSON '
140+
'string.'.format(custom_claims))
141+
142+
if len(custom_claims) > MAX_CLAIMS_PAYLOAD_SIZE:
143+
raise ValueError(
144+
'Custom claims payload must not exceed {0} '
145+
'characters.'.format(MAX_CLAIMS_PAYLOAD_SIZE))
146+
try:
147+
parsed = json.loads(custom_claims)
148+
except Exception:
149+
raise ValueError('Failed to parse custom claims string as JSON.')
150+
else:
151+
if not isinstance(parsed, dict):
152+
raise ValueError('Custom claims must be parseable as a JSON object.')
153+
invalid_claims = RESERVED_CLAIMS.intersection(set(parsed.keys()))
154+
if len(invalid_claims) > 1:
155+
joined = ', '.join(sorted(invalid_claims))
156+
raise ValueError('Claims "{0}" are reserved, and must not be set.'.format(joined))
157+
elif len(invalid_claims) == 1:
158+
raise ValueError(
159+
'Claim "{0}" is reserved, and must not be set.'.format(invalid_claims.pop()))
160+
121161

122162
class ApiCallError(Exception):
123163
"""Represents an Exception encountered while invoking the Firebase user management API."""
@@ -132,6 +172,7 @@ class UserManager(object):
132172
"""Provides methods for interacting with the Google Identity Toolkit."""
133173

134174
_VALIDATORS = {
175+
'customAttributes' : _Validator.validate_custom_claims,
135176
'deleteAttribute' : _Validator.validate_delete_list,
136177
'deleteProvider' : _Validator.validate_delete_list,
137178
'disabled' : _Validator.validate_disabled,
@@ -163,7 +204,8 @@ class UserManager(object):
163204
'phone_number' : 'phoneNumber',
164205
'photo_url' : 'photoUrl',
165206
'password' : 'password',
166-
'disabled' : 'disabled',
207+
'disabled' : 'disableUser',
208+
'custom_claims' : 'customAttributes',
167209
}
168210

169211
_REMOVABLE_FIELDS = {
@@ -207,6 +249,26 @@ def get_user(self, **kwargs):
207249
'No user record found for the provided {0}: {1}.'.format(key_type, key))
208250
return response['users'][0]
209251

252+
def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS):
253+
"""Retrieves a batch of users."""
254+
if page_token is not None:
255+
if not isinstance(page_token, six.string_types) or not page_token:
256+
raise ValueError('Page token must be a non-empty string.')
257+
if not isinstance(max_results, int):
258+
raise ValueError('Max results must be an integer.')
259+
elif max_results < 1 or max_results > MAX_LIST_USERS_RESULTS:
260+
raise ValueError(
261+
'Max results must be a positive integer less than '
262+
'{0}.'.format(MAX_LIST_USERS_RESULTS))
263+
264+
payload = {'maxResults': max_results}
265+
if page_token:
266+
payload['nextPageToken'] = page_token
267+
try:
268+
return self._request('post', 'downloadAccount', json=payload)
269+
except requests.exceptions.RequestException as error:
270+
self._handle_http_error(USER_DOWNLOAD_ERROR, 'Failed to download user accounts.', error)
271+
210272
def create_user(self, **kwargs):
211273
"""Creates a new user account with the specified properties."""
212274
payload = self._init_payload('create_user', UserManager._CREATE_USER_FIELDS, **kwargs)
@@ -236,9 +298,12 @@ def update_user(self, uid, **kwargs):
236298
if 'phoneNumber' in payload and payload['phoneNumber'] is None:
237299
payload['deleteProvider'] = ['phone']
238300
del payload['phoneNumber']
239-
if 'disabled' in payload:
240-
payload['disableUser'] = payload['disabled']
241-
del payload['disabled']
301+
if 'customAttributes' in payload:
302+
custom_claims = payload['customAttributes']
303+
if custom_claims is None:
304+
custom_claims = {}
305+
if isinstance(custom_claims, dict):
306+
payload['customAttributes'] = json.dumps(custom_claims)
242307

243308
self._validate(payload, self._VALIDATORS, 'update user')
244309
try:
@@ -306,3 +371,35 @@ def _request(self, method, urlpath, **kwargs):
306371
resp = self._session.request(method, ID_TOOLKIT_URL + urlpath, **kwargs)
307372
resp.raise_for_status()
308373
return resp.json()
374+
375+
376+
class UserIterator(object):
377+
"""An iterator that allows iterating over user accounts, one at a time.
378+
379+
This implementation loads a page of users into memory, and iterates on them. When the whole
380+
page has been traversed, it loads another page. This class never keeps more than one page
381+
of entries in memory.
382+
"""
383+
384+
def __init__(self, current_page):
385+
if not current_page:
386+
raise ValueError('Current page must not be None.')
387+
self._current_page = current_page
388+
self._index = 0
389+
390+
def next(self):
391+
if self._index == len(self._current_page.users):
392+
if self._current_page.has_next_page:
393+
self._current_page = self._current_page.get_next_page()
394+
self._index = 0
395+
if self._index < len(self._current_page.users):
396+
result = self._current_page.users[self._index]
397+
self._index += 1
398+
return result
399+
raise StopIteration
400+
401+
def __next__(self):
402+
return self.next()
403+
404+
def __iter__(self):
405+
return self

firebase_admin/auth.py

Lines changed: 153 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
creating and managing user accounts in Firebase projects.
2020
"""
2121

22+
import json
2223
import time
2324

2425
from google.auth import jwt
@@ -163,6 +164,36 @@ def get_user_by_phone_number(phone_number, app=None):
163164
except _user_mgt.ApiCallError as error:
164165
raise AuthError(error.code, str(error), error.detail)
165166

167+
def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, app=None):
168+
"""Retrieves a page of user accounts from a Firebase project.
169+
170+
The ``page_token`` argument governs the starting point of the page. The ``max_results``
171+
argument governs the maximum number of user accounts that may be included in the returned page.
172+
This function never returns None. If there are no user accounts in the Firebase project, this
173+
returns an empty page.
174+
175+
Args:
176+
page_token: A non-empty page token string, which indicates the starting point of the page
177+
(optional). Defaults to ``None``, which will retrieve the first page of users.
178+
max_results: A positive integer indicating the maximum number of users to include in the
179+
returned page (optional). Defaults to 1000, which is also the maximum number allowed.
180+
app: An App instance (optional).
181+
182+
Returns:
183+
ListUsersPage: A ListUsersPage instance.
184+
185+
Raises:
186+
ValueError: If max_results or page_token are invalid.
187+
AuthError: If an error occurs while retrieving the user accounts.
188+
"""
189+
user_manager = _get_auth_service(app).user_manager
190+
def download(page_token, max_results):
191+
try:
192+
return user_manager.list_users(page_token, max_results)
193+
except _user_mgt.ApiCallError as error:
194+
raise AuthError(error.code, str(error), error.detail)
195+
return ListUsersPage(download, page_token, max_results)
196+
166197

167198
def create_user(**kwargs):
168199
"""Creates a new user account with the specified properties.
@@ -195,11 +226,12 @@ def create_user(**kwargs):
195226
raise AuthError(error.code, str(error), error.detail)
196227

197228

198-
def update_user(uid, **kwargs): # pylint: disable=missing-param-doc
229+
def update_user(uid, **kwargs):
199230
"""Updates an existing user account with the specified properties.
200231
201232
Args:
202233
uid: A user ID string.
234+
kwargs: A series of keyword arguments (optional).
203235
204236
Keyword Args:
205237
display_name: The user's display name (optional). Can be removed by explicitly passing
@@ -212,7 +244,8 @@ def update_user(uid, **kwargs): # pylint: disable=missing-param-doc
212244
photo_url: The user's photo URL (optional). Can be removed by explicitly passing None.
213245
password: The user's raw, unhashed password. (optional).
214246
disabled: A boolean indicating whether or not the user account is disabled (optional).
215-
app: An App instance (optional).
247+
custom_claims: A dictionary or a JSON string contining the custom claims to be set on the
248+
user account (optional).
216249
217250
Returns:
218251
UserRecord: An updated UserRecord instance for the user.
@@ -229,6 +262,31 @@ def update_user(uid, **kwargs): # pylint: disable=missing-param-doc
229262
except _user_mgt.ApiCallError as error:
230263
raise AuthError(error.code, str(error), error.detail)
231264

265+
def set_custom_user_claims(uid, custom_claims, app=None):
266+
"""Sets additional claims on an existing user account.
267+
268+
Custom claims set via this function can be used to define user roles and privilege levels.
269+
These claims propagate to all the devices where the user is already signed in (after token
270+
expiration or when token refresh is forced), and next time the user signs in. The claims
271+
can be accessed via the user's ID token JWT. If a reserved OIDC claim is specified (sub, iat,
272+
iss, etc), an error is thrown. Claims payload must also not be larger then 1000 characters
273+
when serialized into a JSON string.
274+
275+
Args:
276+
uid: A user ID string.
277+
custom_claims: A dictionary or a JSON string of custom claims. Pass None to unset any
278+
claims set previously.
279+
app: An App instance (optional).
280+
281+
Raises:
282+
ValueError: If the specified user ID or the custom claims are invalid.
283+
AuthError: If an error occurs while updating the user account.
284+
"""
285+
user_manager = _get_auth_service(app).user_manager
286+
try:
287+
user_manager.update_user(uid, custom_claims=custom_claims)
288+
except _user_mgt.ApiCallError as error:
289+
raise AuthError(error.code, str(error), error.detail)
232290

233291
def delete_user(uid, app=None):
234292
"""Deletes the user identified by the specified user ID.
@@ -393,6 +451,20 @@ def provider_data(self):
393451
providers = self._data.get('providerUserInfo', [])
394452
return [_ProviderUserInfo(entry) for entry in providers]
395453

454+
@property
455+
def custom_claims(self):
456+
"""Returns any custom claims set on this user account.
457+
458+
Returns:
459+
dict: A dictionary of claims or None.
460+
"""
461+
claims = self._data.get('customAttributes')
462+
if claims:
463+
parsed = json.loads(claims)
464+
if parsed != {}:
465+
return parsed
466+
return None
467+
396468

397469
class UserMetadata(object):
398470
"""Contains additional metadata associated with a user account."""
@@ -414,6 +486,85 @@ def last_sign_in_timestamp(self):
414486
return int(self._data['lastLoginAt'])
415487
return None
416488

489+
class ExportedUserRecord(UserRecord):
490+
"""Contains metadata associated with a user including password hash and salt."""
491+
492+
def __init__(self, data):
493+
super(ExportedUserRecord, self).__init__(data)
494+
495+
@property
496+
def password_hash(self):
497+
"""The user's password hash as a base64-encoded string.
498+
499+
If the Firebase Auth hashing algorithm (SCRYPT) was used to create the user account, this
500+
will be the base64-encoded password hash of the user. If a different hashing algorithm was
501+
used to create this user, as is typical when migrating from another Auth system, this
502+
will be an empty string. If no password is set, this will be None.
503+
"""
504+
return self._data.get('passwordHash')
505+
506+
@property
507+
def password_salt(self):
508+
"""The user's password salt as a base64-encoded string.
509+
510+
If the Firebase Auth hashing algorithm (SCRYPT) was used to create the user account, this
511+
will be the base64-encoded password salt of the user. If a different hashing algorithm was
512+
used to create this user, as is typical when migrating from another Auth system, this will
513+
be an empty string. If no password is set, this will be None.
514+
"""
515+
return self._data.get('salt')
516+
517+
518+
class ListUsersPage(object):
519+
"""Represents a page of user records exported from a Firebase project.
520+
521+
Provides methods for traversing the user accounts included in this page, as well as retrieving
522+
subsequent pages of users. The iterator returned by ``iterate_all()`` can be used to iterate
523+
through all users in the Firebase project starting from this page.
524+
"""
525+
526+
def __init__(self, download, page_token, max_results):
527+
self._download = download
528+
self._max_results = max_results
529+
self._current = download(page_token, max_results)
530+
531+
@property
532+
def users(self):
533+
"""A list of ``ExportedUserRecord`` instances available in this page."""
534+
return [ExportedUserRecord(user) for user in self._current.get('users', [])]
535+
536+
@property
537+
def next_page_token(self):
538+
"""Page token string for the next page (empty string indicates no more pages)."""
539+
return self._current.get('nextPageToken', '')
540+
541+
@property
542+
def has_next_page(self):
543+
"""A boolean indicating whether more pages are available."""
544+
return bool(self.next_page_token)
545+
546+
def get_next_page(self):
547+
"""Retrieves the next page of user accounts, if available.
548+
549+
Returns:
550+
ListUsersPage: Next page of users, or None if this is the last page.
551+
"""
552+
if self.has_next_page:
553+
return ListUsersPage(self._download, self.next_page_token, self._max_results)
554+
return None
555+
556+
def iterate_all(self):
557+
"""Retrieves an iterator for user accounts.
558+
559+
Returned iterator will iterate through all the user accounts in the Firebase project
560+
starting from this page. The iterator will never buffer more than one page of users
561+
in memory at a time.
562+
563+
Returns:
564+
iterator: An iterator of ExportedUserRecord instances.
565+
"""
566+
return _user_mgt.UserIterator(self)
567+
417568

418569
class _ProviderUserInfo(UserInfo):
419570
"""Contains metadata regarding how a user is known by a particular identity provider."""

0 commit comments

Comments
 (0)