diff --git a/.codeclimate.yml b/.codeclimate.yml index 8f566cb..2dabf02 100644 --- a/.codeclimate.yml +++ b/.codeclimate.yml @@ -4,7 +4,7 @@ engines: radon: enabled: true config: - python_version: 2 + python_version: 3 ratings: paths: diff --git a/.gitignore b/.gitignore index e105803..5c83f1f 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,7 @@ *.pyc env/ docs/_build +__pycache__/ +.python-version + +sqlalchemy.db diff --git a/circle.yml b/circle.yml index 8f2aeec..d4501fd 100644 --- a/circle.yml +++ b/circle.yml @@ -1,7 +1,7 @@ # circle.yml machine: python: - version: 2.7.3 + version: 3.5.1 services: - docker @@ -9,6 +9,6 @@ test: pre: - pip install coveralls override: - - nosetests tests --with-coverage --cover-package=jsonapi_collections + - nosetests tests --with-coverage --cover-package=jsonapi_query post: - coveralls diff --git a/jsonapi_collections/__init__.py b/jsonapi_collections/__init__.py deleted file mode 100644 index cc41b43..0000000 --- a/jsonapi_collections/__init__.py +++ /dev/null @@ -1,245 +0,0 @@ -# -*- coding: utf-8 -*- -from jsonapi_collections.drivers.sqlalchemy import SQLAlchemyDriver -from jsonapi_collections.errors import JSONAPIError -from jsonapi_collections.filter import FilterParameter -from jsonapi_collections.include import IncludeValue -from jsonapi_collections.sort import SortValue - - -class Resource(object): - """The `Resource` object acts as the central orchestration object. - - Supplying the minimum set of a arguments, `model` and `parameters`, - offers the full suite of filtering and sorting capability to - `SQLAlchemy` query objects. Optionally specifying a `driver` and - `schema` allows for intermediary libraries and implementations to - control how `SQLAlchemy` columns are accessed from their JSONAPI - field representations. - - To use the `Resource` object you must first initialize it. Once - initialized, call the `filter_query` or `sort_query` methods - against a `SQLAlchemy` query object. The response from both - methods will be a permutated query object. - """ - - ERRORS = { - 'field': 'Invalid field specified: {}.', - 'parameter': 'Invalid parameter specified: {}.', - 'value': 'Invalid parameter value specified for {}.' - } - - def __init__(self, model, parameters, driver=None, schema=None): - """Initialize the collection controller. - - :param model: SQLAlchemy model class. - :param parameters: Dictionary of parameter, value pairs. - :param driver: `jsonapi_collections` driver instance. - :param schema: Schema to validate against. - If `None`, the key, value pairs will be validated against - the model. - """ - self.model = model - self.schema = schema or model - self.driver = driver(self) if driver else SQLAlchemyDriver(self) - self.parameters = self._handle_parameters(parameters) - - def filter_query(self, query): - """Filter a given query by a set of parameters or error. - - A successful call to this method will result in a permutated - filtered query object. You can filter or sort your query - object at any point before or after calling this method. The - ordering of the filters and sorts are irrelevant. - - A failed call will result in suppressed FieldError messages - being marshaled to the current context. The list of errors are - formatted according to the JSONAPI 1.0 specification. The - `generate` method will return a list of errors each pointing to - the specific parameter that caused the failure. - - Errors raised by this method are intended for general, public - view and should not be suppressed by default. - - :param query: `SQLAlchemy` query object. - """ - field_names = self.parameters.get('filters', {}) - filters, errors = FilterParameter.generate(self.driver, field_names) - if errors: - raise JSONAPIError(errors) - return FilterParameter.filter_by(query, filters) - - def sort_query(self, query): - """Sort a given query by a set of the `sort` parameter or error. - - A successful call to this method will result in a permutated, - sorted query object. You can filter or sort your query object - at any point before or after calling this method. The ordering - of the filters and sorts are irrelevant. - - A failed call will result in suppressed FieldError messages - being marshaled to the current context. The list of error - messages are formatted according to the JSONAPI 1.0 - specification. The `generate` method will return exactly one - error regardless of the number of sorts. The number of errors - detected can be measured by the length of the error message - list. - - Errors raised by this method are intended for general, public - view and should not be suppressed by default. - - :param query: `SQLAlchemy` query object. - """ - field_names = self.parameters.get('sort', []) - sorts, error = SortValue.generate(self.driver, field_names) - if error: - raise JSONAPIError([error]) - return SortValue.sort_by(query, sorts) - - def paginate_query(self, query): - """Paginate and retrieve a list of models.""" - page = self.parameters['page'] - return query.limit(page['limit']).offset(page['offset']) - - def compound_response(self, models): - """Compound a response object. - - :params models: List of `SQLAlchemy` model instances. - """ - if not isinstance(models, list): - models = [models] - field_names = self.parameters.get('include', []) - includes, error = IncludeValue.generate(self.driver, field_names) - if error: - raise JSONAPIError([error]) - - included = [] - for model in models: - included.extend(IncludeValue.include(model, includes)) - return included - - def _handle_parameters(self, parameters): - """Return a formatted JSONAPI parameters object.""" - errors = [] - - filters, err = self._get_filtered_fields(parameters) - errors.extend(err) - - include, err = self._get_included_parameters(parameters) - errors.extend(err) - - sort, err = self._get_sorted_parameters(parameters) - errors.extend(err) - - page, err = self._get_pagination_parameters(parameters) - errors.extend(err) - - if errors: - raise JSONAPIError(errors) - return { - 'filters': filters, 'include': include, 'sort': sort, 'page': page - } - - def _get_filtered_fields(self, parameters): - """Return a dictionary of field, value pairs to filter by. - - By specifying a parameter, you are requesting for it to be - filtered. Empty strings are interpreted as being `None` type - filters. The following `filter[field]=` is interpreted as - `WHERE field IS NULL`. - - :param parameters: A dictionary of parameters specified during init. - """ - errors = [] - filters = {} - for key, value in parameters.iteritems(): - if not key.startswith('filter['): - continue - - if not self.driver.validate_attribute_path(key[7:-1]): - errors.append({ - 'detail': self.ERRORS['parameter'].format(key), - 'source': {'parameter': key} - }) - continue - - if value == '': - value = None - else: - value = value.split(',') - filters[key[7:-1]] = value - return filters, errors - - def _get_pagination_parameters(self, parameters): - """Return a dictionary of parameter, value pairs to paginate by.""" - errors = [] - pagination_parameters = {} - - limits = ['limit', 'page[size]', 'page[limit]'] - offsets = ['offset', 'page[number]', 'page[offset]'] - for key, value in parameters.iteritems(): - if key not in limits and key not in offsets: - continue - try: - pagination_parameters[key] = int(value) - except ValueError: - errors.append({ - 'detail': self.ERRORS['value'].format(key), - 'source': {'parameter': key} - }) - - page = {'limit': 100, 'offset': 0} - for key, value in pagination_parameters.iteritems(): - if key == 'page[limit]' or key == 'limit' or key == 'page[size]': - page['limit'] = value - elif key == 'page[offset]' or key == 'offset': - page['offset'] = value - elif key == 'page[number]': - page['offset'] = value * parameters.get('page[size]', 0) - return page, errors - - def _get_included_parameters(self, parameters): - """Return a list of field names. - - If the `key` parameter is specified but does not contain a - value then the `key` key will be ignored. - - :param key: String reference to dictionary key. - :param parameters: Dictionary of parameters specified during init. - """ - errors = [] - fields = parameters.get('include', '') - if fields == '': - return [], errors - - fields = fields.split(',') - for field in fields: - if not self.driver.validate_relationship_path(field): - errors.append({ - 'detail': self.ERRORS['field'].format(field), - 'source': {'parameter': 'include'} - }) - return fields, errors - - def _get_sorted_parameters(self, parameters): - """Return a list of field names. - - If the `key` parameter is specified but does not contain a - value then the `key` key will be ignored. - - :param key: String reference to dictionary key. - :param parameters: Dictionary of parameters specified during init. - """ - errors = [] - fields = parameters.get('sort', '') - if fields == '': - return [], errors - - fields = fields.split(',') - for field in fields: - validated_field = field[1:] if field.startswith('-') else field - if not self.driver.validate_attribute_path(validated_field): - errors.append({ - 'detail': self.ERRORS['field'].format(field), - 'source': {'parameter': 'sort'} - }) - return fields, errors diff --git a/jsonapi_collections/drivers/__init__.py b/jsonapi_collections/drivers/__init__.py deleted file mode 100644 index 9986dd3..0000000 --- a/jsonapi_collections/drivers/__init__.py +++ /dev/null @@ -1,78 +0,0 @@ -# -*- coding: utf-8 -*- -from jsonapi_collections.errors import FieldError - - -class BaseDriver(object): - """Extensible driver template for interacting with `SQLAlchemy`. - - Drivers act a bindings for validating and extracting relationship - and attribute data. - - The `BaseDriver` template provides generic `SQLAlchemy` bindings as - well as overridable methods for interacting with third-party schemas. - """ - - def __init__(self, collection): - """DO NOT OVERRIDE. - - :param collection: A `jsonapi_collections` Collection instance. - """ - self.collection = collection - - def get_column_model(self, column): - """Get the parent model of a relationship.""" - if self.is_relationship(column): - return column.property.mapper.class_ - raise FieldError('Invalid relationship specified.') - - def is_relationship(self, column): - """Determine if a field is a relationship.""" - return hasattr(column.property, 'mapper') - - def get_column_type(self, column): - """Return the column's Python type.""" - return column.property.columns[0].type.python_type - - def is_enum(self, column): - """Determine if a column is an enumeration.""" - if hasattr(column.property.columns[0].type, 'enums'): - return True - return False - - def get_column(self, field): - """Return a `SQLAlchemy` column instance. - - :param field: A schema field instance. - """ - raise NotImplementedError - - def get_column_name(self, field_name): - """Return a string reference to a model column.""" - raise NotImplementedError - - def get_field(self, field_name): - """Return a schema field instance. - - :param field_name: A string reference to a schema's field name. - """ - raise NotImplementedError - - def get_related_schema(self, field): - """Return a related schema reference.""" - raise NotImplementedError - - def deserialize(self, column, values, schema=None): - """Parse a set of values into the appropriate type.""" - raise NotImplementedError - - def serialize(self, models): - """Serialize a set of SQLAlchemy instances.""" - raise NotImplementedError - - def validate_attribute_path(self, path): - """Return `False` if the last member is not a valid attribute.""" - raise NotImplementedError - - def validate_relationship_path(self, path): - """Return `False` if all members are not valid relationships.""" - raise NotImplementedError diff --git a/jsonapi_collections/drivers/marshmallow.py b/jsonapi_collections/drivers/marshmallow.py deleted file mode 100644 index 0e68e3c..0000000 --- a/jsonapi_collections/drivers/marshmallow.py +++ /dev/null @@ -1,82 +0,0 @@ -# -*- coding: utf-8 -*- -from jsonapi_collections.drivers import BaseDriver -from jsonapi_collections.errors import FieldError - - -class MarshmallowDriver(BaseDriver): - """Marshmallow bindings.""" - - def get_column(self, column_name, model=None): - """Return a column instance.""" - return getattr(model or self.collection.model, column_name) - - def get_column_name(self, field_name, schema=None): - """Return a string reference to a model column.""" - field = self.get_field(field_name, schema) - return field.attribute or field_name - - def get_field(self, field_name, schema=None): - """Return a marshmallow field instance.""" - schema = schema or self.collection.schema - field = schema._declared_fields.get(field_name) - if field is None: - raise FieldError('Invalid field specified.') - return field - - def get_related_schema(self, field): - """Return a related schema reference.""" - schema = getattr(field, 'schema', None) - if schema is None: - raise FieldError('Invalid relationship specified.') - return schema - - def deserialize(self, column, field_name, values, schema=None): - """Deserialize a given set of values into their python types.""" - field = self.get_field(field_name, schema) - try: - return [field.deserialize(value) for value in values] - except Exception as exc: - if exc.__class__.__name__ == 'ValidationError': - raise FieldError(exc) - raise - - def serialize(self, schema, items): - return schema(many=True).dump(items).data.get('data', []) - - def validate_attribute_path(self, path): - """Return `False` if the provided path cannot be found.""" - fields = path.split('.') - length = len(fields) - - model = None - schema = None - for pos, field in enumerate(fields, 1): - try: - column_name = self.get_column_name(field, schema) - column = self.get_column(column_name, model) - except FieldError: - return False - - if pos != length: - if not self.is_relationship(column): - return False - model = column.property.mapper.class_ - if pos == length and self.is_relationship(column): - return False - return True - - def validate_relationship_path(self, path): - """Return `False` if the path cannot be found.""" - model = None - schema = None - for field in path.split('.'): - try: - column_name = self.get_column_name(field, schema) - column = self.get_column(column_name, model) - except FieldError: - return False - - if not self.is_relationship(column): - return False - model = self.get_column_model(column) - return True diff --git a/jsonapi_collections/drivers/sqlalchemy.py b/jsonapi_collections/drivers/sqlalchemy.py deleted file mode 100644 index 1e6072e..0000000 --- a/jsonapi_collections/drivers/sqlalchemy.py +++ /dev/null @@ -1,118 +0,0 @@ -# -*- coding: utf-8 -*- -from datetime import datetime -from decimal import Decimal, InvalidOperation - -from jsonapi_collections.drivers import BaseDriver -from jsonapi_collections.errors import FieldError - -import json - - - -class UnsafeEncoder(json.JSONEncoder): - """Do not use an encoder like this in production. You need to have - your own specialized security concious encoder. - """ - - def default(self, obj): - fields = {} - columns = [ - x for x in dir(obj) if not x.startswith('_') and - x != 'metadata'] - for column in columns: - data = obj.__getattribute__(column) - try: - json.dumps(data) - fields[column] = data - except TypeError: - fields[column] = None - return fields - - -class SQLAlchemyDriver(BaseDriver): - """SQLAlchemy bindings.""" - - def get_column(self, column_name, model=None): - """Return a column instance.""" - return self.get_field(column_name, model) - - def get_column_name(self, field_name, schema=None): - """Return a string reference to a model column.""" - return field_name - - def get_field(self, field_name, schema=None): - """Return a SQLAlchemy column instance. - - :param field_name: A string reference to a field's name. - """ - field = getattr(schema or self.collection.model, field_name, None) - if field is None: - raise FieldError('Invalid field specified: {}.'.format(field_name)) - return field - - def get_related_schema(self, field): - """Return a related schema reference.""" - return self.get_column_model(field) - - def deserialize(self, column, field_name, values, schema=None): - """Deserialize a set of values.""" - field = self.get_field(field_name, schema) - return [self._deserialize(field, value) for value in values] - - def _deserialize(self, column, value): - """Deserialize a value into its Python type.""" - if self.is_enum(column) and value not in self._enum_choices(column): - raise FieldError('Not a valid choice.') - - if value == '': - return None - - column_type = self.get_column_type(column) - try: - if column_type == datetime: - return datetime.strptime(value, '%Y-%m-%d') - elif column_type in [bool, int, Decimal]: - return column_type(value) - except (ValueError, InvalidOperation) as exc: - raise FieldError(exc.message) - return value - - def _enum_choices(self, column): - """Return a set of choices.""" - return column.property.columns[0].type.enums - - def serialize(self, schema, items): - """Dangerously serialize `SQLAlchemy` model instance.""" - return [json.dumps(item, cls=UnsafeEncoder) for item in items] - - def validate_attribute_path(self, path): - """Return `False` if the provided path cannot be found.""" - fields = path.split('.') - length = len(fields) - model = None - for pos, field in enumerate(fields, 1): - try: - field = self.get_field(field, model) - except FieldError: - return False - if pos != length: - if not self.is_relationship(field): - return False - model = field.property.mapper.class_ - if pos == length and self.is_relationship(field): - return False - return True - - def validate_relationship_path(self, path): - """Return `False` if the path cannot be found.""" - model = None - for field in path.split('.'): - try: - field = self.get_field(field, model) - except FieldError: - return False - - if not self.is_relationship(field): - return False - model = field.property.mapper.class_ - return True diff --git a/jsonapi_collections/errors.py b/jsonapi_collections/errors.py deleted file mode 100644 index cf62f42..0000000 --- a/jsonapi_collections/errors.py +++ /dev/null @@ -1,37 +0,0 @@ -# -*- coding: utf-8 -*- - - -class FieldError(Exception): - """Raised when one or more query parameters could not be found. - - The `FieldError` exception is exclusively used within driver - classes. Failing to find the specified attribute, relationship - attribute, or relationship schema constitues raising a - `FieldError`. - """ - - pass - - -class JSONAPIError(Exception): - """Raised when one or more field errors have been found. - - The `JSONAPIError` exception aggregates a collection of JSONAPI - formatted error messages. If the specified error message is a - string, instead of a list, the message is encapsulated within an - object with a key of `detail`. - """ - - def __init__(self, message, code=400): - """Format the error's message to match the JSONAPI 1.0 spec. - - :param message: A list errors or a string message. - :param code: HTTP status code. - """ - data = {'status': code} - if isinstance(message, str): - data.update({'errors': {'detail': message}}) - else: - data.update({'errors': message}) - self.message = data - super(JSONAPIError, self).__init__(data) diff --git a/jsonapi_collections/filter.py b/jsonapi_collections/filter.py deleted file mode 100644 index f274eb3..0000000 --- a/jsonapi_collections/filter.py +++ /dev/null @@ -1,186 +0,0 @@ -# -*- coding: utf-8 -*- -from datetime import datetime, timedelta - -from jsonapi_collections.errors import FieldError -from sqlalchemy import and_, or_ - - -class FilterParameter(object): - """Formulate a query filter.""" - - relationship = None - - def __init__(self, driver, field_name, values): - """Set the column, driver, many, and values attributes. - - :param driver: `jsonapi_collections` driver instance. - :param field_name: A string representation of a schema field. - :param values: A list of typed values to filter by. - """ - self.driver = driver - self.many = len(values) > 1 - - if "." in field_name: - relationship_name, field_name = field_name.split('.') - - column_name = self.driver.get_column_name(relationship_name) - self.relationship = self.driver.get_column(column_name) - - relationship_field = self.driver.get_field(relationship_name) - schema = self.driver.get_related_schema(relationship_field) - - model = self.driver.get_column_model(self.relationship) - column_name = self.driver.get_column_name(field_name, schema) - self.column = self.driver.get_column(column_name, model) - else: - column_name = self.driver.get_column_name(field_name) - self.column = self.driver.get_column(column_name) - schema = None - - self.values = self.driver.deserialize( - self.column, field_name, values, schema) - - def __call__(self): - """Create a `SQLAlchemy` query expression. - - Filters are constructed with three considerations: - * Is this query occuring across a relationship? - * Is this query one-to-many or many-to-many? - * Does the query need to evaluate more than one value? - - If the query is not filtering across a relationship column, we - can return the filters formulated by the `_prepare_strategies` - call. Multiple strategies are wrapped in an `or_` function. - - If the query is a relationship, we determine whether or not the - relationship has many related models or has one related model. - - If a many-to-many relationship is detected, we query the values - with the `any` method. If a one-to-many relationship is - detected, we query the values with the `has` method. - - If the query has more than one strategy it needs to executed - as an argument to the `or_` function. - """ - filters = self._prepare_strategies(self.values) - - if self.relationship is None: - if self.many: - return or_(*filters) - return filters[0] - - if self.relationship.property.uselist: - wrapper = self.relationship.any - else: - wrapper = self.relationship.has - - if self.many: - return wrapper(or_(*filters)) - return wrapper(*filters) - - @property - def column_type(self): - """Extract the column's type.""" - return self.driver.get_column_type(self.column) - - def _prepare_strategies(self, values): - """Return a set of filters. - - The `_prepare_strategies` method calls `_prepare_strategy` in a - loop and returns the aggregate set of strategies. - - :param values: List of typed values to filter with. - """ - filters = [] - for value in values: - filters.append(self._prepare_strategy(value)) - return filters - - def _prepare_strategy(self, value): - """Return a `SQLAlchemy` query expression. - - The `_prepare_strategy` method considers three things: - * Are you filtering against an Enum column instance? - * Are you filtering with a `None` type value. - * Are you filtering against some specifically handled type. - - If the column is determined to be an enumeration then no - meaningful filtering can occur other than a simple equality - check. - - If the value is determined to be a `None` type then no - meaningful filtering can occur other than a simple equality - check. - - If the column's type is determined to be one of the special - cases, specialized filtering can occur. This filtering is not - specified by the JSONAPI 1.0 specification. Strings columns - are filtered as a wildcard search. Boolean columns are - filtered using the `is_` method. Datetime columns retrieve all - datetimes within the current day. - - If the column is not a special type, a simple equality check - against the value is returned. - - :param value: Typed value to filter with. - """ - if self.driver.is_enum(self.column) or value is None: - return self.column == value - - if self.column_type == str: - return self.column.ilike('%{}%'.format(value)) - elif self.column_type == bool: - return self.column.is_(value) - elif self.column_type == datetime: - tomorrow = value + timedelta(days=1) - return and_(self.column >= value, self.column < tomorrow) - return self.column == value - - @classmethod - def generate(cls, driver, parameters): - """Parse field, value pairs into `FilterParameter` instances. - - The `generate` classmethod bulk initializes a set of field, - value pairs into the `FilterParameter` class. - - A successful generation results in the `filters` list being - appended the newly formed instance. A failed generation - results in a JSONAPI 1.0 specification error object being - appended to the errors list. - - This method can return both a set of filters and a set of - errors. It is recommended that you evaluate the errors - recieved before continuing. - - :param driver: `jsonapi_collections` driver instance. - :param parameters: A dictionary of field, value pairs. - """ - filters = [] - errors = [] - for field_name, values in parameters.iteritems(): - try: - filters.append(cls(driver, field_name, values)) - except FieldError as exc: - message = { - "source": { - "parameter": 'filter[{}]'.format(field_name) - }, - "detail": exc.message - } - errors.append(message) - return filters, errors - - @staticmethod - def filter_by(query, filters): - """Apply a series of `FilterParameter` instances as query filters. - - The `filter_by` staticmethod acts as a helper method to always - ensure that API changes do not disrupt the general process of - filter application. - - :param query: `SQLAlchemy` query object. - :param filters: List of `FilterParameter` instances. - """ - for filter in filters: - query = query.filter(filter()) - return query diff --git a/jsonapi_collections/include.py b/jsonapi_collections/include.py deleted file mode 100644 index 6e9463a..0000000 --- a/jsonapi_collections/include.py +++ /dev/null @@ -1,96 +0,0 @@ -# -*- coding: utf-8 -*- -from jsonapi_collections.errors import FieldError - - -class IncludeValue(object): - """Validate and include a given relationship field. - - The `IncludeValue` object is responsible for validating the - provided fields and including a properly serialized data response - based on the relationships present on the primary resource. - """ - - def __init__(self, driver, field_name): - """Validate the inputs. - - :param schema: Marshmallow schema reference. - :param value: String name of a relationship field. - :param id_field: Optionally specified `ID` field to query. - """ - self.driver = driver - self.field_name = field_name - - def __call__(self, model): - """Return the serailized output of a related field.""" - values, schema = self._get_included_values( - self.field_name.split('.'), self.driver.collection.schema, [model]) - return self.driver.serialize(schema, values) - - def _get_included_values(self, path, schema, values): - """Return a set of model instances and a related serializer. - - :param path: List of schema field names. - :param schema: Serializer class object. - :param values: List of `SQLAlchemy` model instances. - """ - for field_name in path: - column_name = self.driver.get_column_name(field_name, schema) - values = self._get_nested_values(values, column_name) - field = self.driver.get_field(field_name, schema) - schema = self.driver.get_related_schema(field) - return values, schema - - def _get_nested_values(self, values, name): - """Return a set of `SQLAlchemy` model instances. - - This method iterates through a set of relationships and returns - the value of each relationship's specified attribute in a - combined set. - - This method normalizes the value of each relationship to ensure - it is of a list-type. `None` type relationships are skipped - and `one-to-many` and `one-to-one` relationships are typed as a - list. - - :param values: List of `SQLAlchemy` model instances. - :param name: String name of the relationship column to extract. - """ - new_values = [] - for value in values: - relationship = getattr(value, name) - if relationship is None: - continue - elif not isinstance(relationship, list): - relationship = [relationship] - new_values.extend(relationship) - return new_values - - @classmethod - def generate(cls, driver, values): - """Parse a series of strings into `IncludeValue` instances. - - :param driver: `jsonapi_collections` driver instance. - :param values: String list of relationship fields to include. - """ - includes = [] - errors = [] - for value in values: - try: - includes.append(cls(driver, value)) - except FieldError as exc: - errors.append(exc.message) - if errors: - return [], {"source": {"parameter": 'include'}, "detail": errors} - return includes, {} - - @staticmethod - def include(model, includes): - """Dump a series of `IncludeValue` instances to a dictionary. - - :param includes: List of `IncludeValue` instances. - :param model: `SQLAlchemy` model instance. - """ - included = [] - for include in includes: - included.extend(include(model)) - return included diff --git a/jsonapi_collections/sort.py b/jsonapi_collections/sort.py deleted file mode 100644 index 8f9b78f..0000000 --- a/jsonapi_collections/sort.py +++ /dev/null @@ -1,82 +0,0 @@ -# -*- coding: utf-8 -*- -from sqlalchemy import desc - -from jsonapi_collections.errors import FieldError - - -class SortValue(object): - """Formulate a query sort.""" - - def __init__(self, driver, field_name): - """Set a join and a sort reference. - - :param driver: `jsonapi_collections` driver instance. - :param field_name: A string representation of a schema field. - """ - descending = field_name.startswith('-') - if descending: - field_name = field_name[1:] - - if "." in field_name: - relationship_name, attribute_name = field_name.split('.') - - relationship_field = driver.get_field(relationship_name) - relationship_column = driver.get_column(driver.get_column_name( - relationship_name)) - relationship_schema = driver.get_related_schema(relationship_field) - relationship_model = driver.get_column_model(relationship_column) - - column_name = driver.get_column_name( - attribute_name, relationship_schema) - column = driver.get_column(column_name, relationship_model) - self.join = relationship_name - else: - column_name = driver.get_column_name(field_name) - column = driver.get_column(column_name) - self.join = None - - if descending: - self.sort = desc(column) - else: - self.sort = column - - @classmethod - def generate(cls, driver, field_names): - """Parse a series of strings into `SortValue` instances. - - Dot notation can be used to sort by the attributes of a related - schema. E.g. `relationship.attribute`. - - If the string can not be converted, an error is marshaled as a - member of a string list. - - :param driver: `jsonapi_collections` driver reference. - :param field_names: String list of attributes. - """ - sorts = [] - errors = [] - for field_name in field_names: - try: - sorts.append(cls(driver, field_name)) - except FieldError as exc: - errors.append(exc.message) - if errors: - return sorts, {"source": {"parameter": 'sort'}, "detail": errors} - return sorts, None - - @staticmethod - def sort_by(query, values): - """Apply a series of `SortValue` instances to a `SQLAlchemy` query. - - Dot seperated sorts will have the appropriate tables joined - prior to applying the sort. - - :param query: `SQLAlchemy` query object. - :param values: List of `SortValue` instances. - """ - sorts = [] - for value in values: - if value.join is not None: - query = query.join(value.join) - sorts.append(value.sort) - return query.order_by(*sorts) diff --git a/jsonapi_query/__init__.py b/jsonapi_query/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jsonapi_query/compat.py b/jsonapi_query/compat.py new file mode 100644 index 0000000..e69de29 diff --git a/jsonapi_query/database/__init__.py b/jsonapi_query/database/__init__.py new file mode 100644 index 0000000..cb28004 --- /dev/null +++ b/jsonapi_query/database/__init__.py @@ -0,0 +1,25 @@ +from abc import abstractmethod, ABCMeta + + +class BaseQueryMixin(metaclass=ABCMeta): + """Base query class mixin.""" + + @abstractmethod + def apply_filters(self): + return + + @abstractmethod + def apply_filter(self): + return + + @abstractmethod + def apply_sorts(self): + return + + @abstractmethod + def apply_sort(self): + return + + @abstractmethod + def apply_paginators(self): + return diff --git a/jsonapi_query/database/sqlalchemy.py b/jsonapi_query/database/sqlalchemy.py new file mode 100644 index 0000000..a37819c --- /dev/null +++ b/jsonapi_query/database/sqlalchemy.py @@ -0,0 +1,235 @@ +"""SQLAlchemy jsonapi-query adapter.""" +from sqlalchemy.orm import aliased +from jsonapi_query.database import BaseQueryMixin + + +class QueryMixin(BaseQueryMixin): + """SQLAlchemy query class mixin.""" + + default_limit = 50 + default_offset = 0 + + def apply_filters(self, filters): + """Return a query object filtered by a set of column, value pairs. + + :param filters: Triple of column, strategy, and values arguments. + """ + for column, strategy, values, joins in filters: + self = self.apply_filter(column, strategy, values, joins) + return self + + def apply_filter(self, column, strategy, values, joins=[]): + """Return a query object filtered by a column, value pair. + + :param column: SQLAlchemy column object. + :param strategy: Query filter string name reference. + :param values: List of typed values. + :param joins: List of SQLAlchemy mapper objects. + """ + column, classes = self._alias_mappers(column, joins) + for pos, class_ in enumerate(classes): + self = self.join(class_, joins[pos]) + + negated = strategy.startswith('~') + if negated: + strategy = strategy[1:] + + if strategy == 'in': + if negated: + return self.filter(~column.in_(values)) + return self.filter(column.in_(values)) + + if strategy == 'eq': + strategy = self._filter_eq + elif strategy == 'gt': + strategy = self._filter_gt + elif strategy == 'gte': + strategy = self._filter_gte + elif strategy == 'lt': + strategy = self._filter_lt + elif strategy == 'lte': + strategy = self._filter_lte + elif strategy == 'like': + strategy = self._filter_like + elif strategy == 'ilike': + strategy = self._filter_ilike + else: + raise ValueError('Invalid query strategy: {}'.format(strategy)) + + filters = self._get_filters(column, values, strategy, negated) + return self.filter(filters) + + def _get_filters(self, column, values, strategy, negated=False): + filters = None + for value in values: + expression = strategy(column, value) + if negated: + expression = ~expression + if filters is None: + filters = expression + else: + filters = filters | expression + return filters + + def _filter_eq(self, column, value): + return column == value + + def _filter_gt(self, column, value): + return column > value + + def _filter_gte(self, column, value): + return column >= value + + def _filter_lt(self, column, value): + return column < value + + def _filter_lte(self, column, value): + return column <= value + + def _filter_like(self, column, value): + return column.contains(value) + + def _filter_ilike(self, column, value): + return column.ilike('%{}%'.format(value)) + + def apply_sorts(self, sorts): + """Return a query object sorted by a set of columns. + + :param sorts: Triple of direction, column, and joins arguments. + """ + for direction, column, joins in sorts: + self = self.apply_sort(direction, column, joins) + return self + + def apply_sort(self, column, direction, joins=[]): + """Return a query object sorted by a column. + + :param column: SQLAlchemy column object. + :param direction: Query sort direction reference. + :param join: List of SQLAlchemy model objects. + """ + column, classes = self._alias_mappers(column, joins) + for pos, class_ in enumerate(classes): + self = self.join(class_, joins[pos]) + + if direction == '-': + column = column.desc() + return self.order_by(column) + + def apply_paginators(self, paginators): + """Return a query object paginated by a limit and offset value. + + :param paginators: List of stategy and value arguments. + """ + pagination = { + 'limit': self.default_limit, + 'offset': self.default_offset + } + pagination.update({ + strategy: int(value) for strategy, value in paginators}) + if 'number' in pagination: + limit = pagination['limit'] + pagination['offset'] = pagination['number'] * limit - limit + return self.limit(pagination['limit']).offset(pagination['offset']) + + def include(self, mappers): + """Return an additional set of data with the query. + + For a given query, join a set of mappers and select an aliased + entity. The mappers may chain off one another. + + This method does not return a stable data-type. If no mappers + are included, the response type will be a model instance or a + list of model instances. With a set of mappers a tuple or set + of tuples will be returned of length `mappers` + 1. + + Note, if at any point you intend to use `filter_by(column=value)`, + it is highly recommend you call it prior to calling the include + method. Using `filter_by` will select from the most recently + joined aliased model. Using `filter(Model.column == value)` is + safe and will filter as expected. + + :param mappers: A list of SQLAlchemy mapper objects. + """ + if mappers == []: + return self + + selects = [aliased(_get_mapper_class(mapper)) for mapper in mappers] + for pos, select in enumerate(selects): + self = self.outerjoin(select, mappers[pos]).add_entity(select) + return self + + def _alias_mappers(self, column, mappers): + classes = [aliased(_get_mapper_class(mapper)) for mapper in mappers] + for pos, class_ in enumerate(classes): + if column.class_ == _get_aliased_class(class_): + column = getattr(class_, column.property.class_attribute.key) + return column, classes + + +def group_and_remove(items, models): + """Group and restructure a list of tuples by like items. + + This function groups a list of tuples by their model type. Tuple + members that do not match any of the provided models will be not + be returned. + + :param items: A list of rows containing column tuples. + :param models: A list of SQLAlchemy model classes. + """ + if items == [None]: + return [[]] + elif items == [] or not isinstance(items[0], tuple): + return [items] + + response = [[] for model in models] + for item in items: + for member in item: + if member is None: + continue + + position = _get_model_position(member, models) + if member not in response[position]: + response[position].append(member) + return response + + +def _get_model_position(model, models): + class_ = model.__class__ + if class_ in models: + return models.index(class_) + + for model in models: + if issubclass(class_, model): + return models.index(model) + + +def _get_aliased_class(x): + return x._aliased_insp.class_ + + +def _get_mapper_class(mapper): + return mapper.property.mapper.class_ + + +def include(session, model, columns, joins, ids): + """Query a list of models restricted by the filter_model's ID. + + :param session: SQLAlchemy query session. + :param model: SQLAlchemy model class. + :param columns: A list of SQLAlchemy model classes. + :param joins: A list of relationship mappers. + :param ids: A list of IDs to filter by. + """ + if columns == [] or ids == []: + return [] + + selects = [aliased(_get_mapper_class(join)) for join in joins] + query = session.query(*selects).filter(model.id.in_(ids)) + for join in joins: + for select in selects: + if _get_mapper_class(join) == _get_aliased_class(select): + query = query.outerjoin(select, join) + selects.remove(select) + break + return group_and_remove(query.all(), columns) diff --git a/jsonapi_query/errors.py b/jsonapi_query/errors.py new file mode 100644 index 0000000..69fe8c0 --- /dev/null +++ b/jsonapi_query/errors.py @@ -0,0 +1,12 @@ + + +class JSONAPIQueryError(Exception): + """All errors raised from this module will subclass this class.""" + + +class PathError(JSONAPIQueryError): + """Raised when a Python path can not be derived.""" + + +class DataError(JSONAPIQueryError): + """Raised when a value can not be deserialized to a Python type.""" diff --git a/jsonapi_query/translation/__init__.py b/jsonapi_query/translation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jsonapi_query/translation/model/__init__.py b/jsonapi_query/translation/model/__init__.py new file mode 100644 index 0000000..6ada7f1 --- /dev/null +++ b/jsonapi_query/translation/model/__init__.py @@ -0,0 +1,18 @@ +from abc import abstractmethod, ABCMeta + + +class BaseModelDriver(metaclass=ABCMeta): + """Base model driver.""" + + def __init__(self, model, default_attribute='id'): + """Setup the driver. + + :param model: SQLAlchemy model reference. + """ + self.model = model + self.default_attribute = default_attribute + + @abstractmethod + def parse_path(self, path): + """Parse a string path to a column attribute.""" + return diff --git a/jsonapi_query/translation/model/sqlalchemy.py b/jsonapi_query/translation/model/sqlalchemy.py new file mode 100644 index 0000000..0ffe286 --- /dev/null +++ b/jsonapi_query/translation/model/sqlalchemy.py @@ -0,0 +1,75 @@ +"""SQLAlchemy model translation module.""" +from jsonapi_query.errors import PathError +from jsonapi_query.translation.model import BaseModelDriver + + +class SQLAlchemyModelDriver(BaseModelDriver): + """Translate string paths to column references.""" + + def parse_path(self, path): + """Parse a string path to a column attribute. + + Return a triple of a column attribute, the intermediary models, + and the list of joins required to get the attribute. + + :param path: A dot seperated string representing the attributes + of a model class. + """ + if path == '': + return (None, [], []) + + stones = path.split('.') + relationships, attribute = stones[:-1], stones[-1] + + model = self.model + joins = [] + models = [] + for relationship in relationships: + reference = self._get_relationship(relationship, model) + joins.append(self._extract_join(reference, model)) + model = self._get_relationship_class(reference) + models.append(model) + + column = self._get_column(attribute, model) + if self._is_relationship(column): + joins.append(self._extract_join(column, model)) + model = self._get_relationship_class(column) + models.append(model) + column = self._get_column(self.default_attribute, model) + return (column, models, joins) + + def _get_column(self, attribute, model): + try: + return getattr(model, attribute) + except AttributeError: + raise PathError('Invalid path specified.') + + def _get_relationship(self, attribute, model): + column = self._get_column(attribute, model) + if not self._is_relationship(column): + raise PathError('Invalid field type specified.') + return column + + def _get_relationship_class(self, relationship): + """Return the class reference of the given relationship.""" + return relationship.property.mapper.class_ + + def _is_relationship(self, relationship): + """Return `True` if a relationship mapper was specified.""" + try: + self._get_relationship_class(relationship) + except AttributeError: + return False + return True + + def _extract_join(self, mapper, model): + """Extract the join condition for a given relationship mapper.""" + # If the mapper is not self-referential we can return it. + if model != self._get_relationship_class(mapper): + return mapper + + # If the mapper is self-referential we want to join its backref. + if mapper.property.backref is not None: + return mapper.property.backref + else: + return getattr(model, mapper.property.back_populates) diff --git a/jsonapi_query/translation/view/__init__.py b/jsonapi_query/translation/view/__init__.py new file mode 100644 index 0000000..c2c203f --- /dev/null +++ b/jsonapi_query/translation/view/__init__.py @@ -0,0 +1,32 @@ +from abc import abstractmethod, ABCMeta + + +class BaseViewDriver(metaclass=ABCMeta): + """Base view driver.""" + + def __init__(self, view): + """Setup the view driver. + + :param view: Schema object reference. + """ + self.view = view + + @abstractmethod + def initialize_path(self, path): + """Initialize a specified attribute path.""" + return + + @abstractmethod + def get_model_path(self): + """Return a model-safe path.""" + return + + @abstractmethod + def deserialize_values(self, values): + """Deserialize a set of values into their appropriate types.""" + return + + @abstractmethod + def deserialize_value(self, field, value): + """Deserialize a string value to the appropriate type.""" + return diff --git a/jsonapi_query/translation/view/marshmallow_jsonapi.py b/jsonapi_query/translation/view/marshmallow_jsonapi.py new file mode 100644 index 0000000..848afce --- /dev/null +++ b/jsonapi_query/translation/view/marshmallow_jsonapi.py @@ -0,0 +1,81 @@ +"""marshmallow-jsonapi schema translation module.""" +from jsonapi_query.errors import DataError, PathError +from jsonapi_query.translation.view import BaseViewDriver + + +def remove_inflection(text): + """Replace hyphens with underscores.""" + return text.replace('-', '_') + + +class MarshmallowJSONAPIDriver(BaseViewDriver): + """Schema translation handler.""" + + fields = [] + field_names = [] + schemas = [] + + def initialize_path(self, path): + """Initialize a specified attribute path.""" + self.fields = [] + self.field_names = [] + self.schemas = [] + + path = remove_inflection(path) + if path == '': + return self + + stones = path.split('.') + relationships, attribute = stones[:-1], stones[-1] + + schema = self.view + for field_name in relationships: + field = self._get_relationship(field_name, schema) + self._append_field_meta(field, field_name) + schema = field.schema + self.schemas.append(schema) + + field = self._get_field(attribute, schema) + if self._is_relationship(field): + self.schemas.append(field.schema) + self._append_field_meta(field, attribute) + return self + + def _append_field_meta(self, field, field_name): + self.fields.append(field) + self.field_names.append(field.attribute or field_name) + + def get_model_path(self): + """Return a model-safe path.""" + return '.'.join(self.field_names) + + def deserialize_values(self, values): + """Deserialize a set of values into their appropriate types.""" + new = [] + for value in values: + new.append(self.deserialize_value(self.fields[-1], value)) + return new + + def deserialize_value(self, field, value): + """Deserialize a string value to the appropriate type.""" + try: + if value == '': + return None + return field._deserialize(value, None, None) + except: + raise DataError('Invalid value specified.') + + def _get_field(self, attribute, schema): + try: + return schema._declared_fields[attribute] + except KeyError: + raise PathError('Invalid path specified.') + + def _get_relationship(self, attribute, schema): + field = self._get_field(attribute, schema) + if not self._is_relationship(field): + raise PathError('Invalid field type specified.') + return field + + def _is_relationship(self, field): + return hasattr(field, 'schema') diff --git a/jsonapi_query/url.py b/jsonapi_query/url.py new file mode 100644 index 0000000..51c9f4f --- /dev/null +++ b/jsonapi_query/url.py @@ -0,0 +1,82 @@ +""".""" +from urllib.parse import parse_qsl, urlparse + + +STRATEGIES = ['eq', 'gt', 'gte', 'lt', 'lte', 'in', 'like', 'ilike'] +STRATEGY_PARTITION = ':' + + +def get_parameters(url: str) -> dict: + """Convert a URL into a dictionary of parameter, value pairs.""" + parsed_url = urlparse(url) + return {key: value for key, value in parse_qsl(parsed_url.query)} + + +def get_includes(parameters: dict) -> list: + """Return a list of relationships to include. + + :param parameters: Dictionary of parameter name, value pairs. + """ + return parameters.get('include', '').split(',') + + +def get_sorts(parameters: dict) -> list: + """Return a list of doubles to sort by. + + :param parameters: Dictionary of parameter name, value pairs. + """ + sorts = [] + for sort in parameters.get('sort', '').split(','): + if sort.startswith('-') or sort.startswith('+'): + sorts.append((sort[1:], sort[:1])) + elif sort != '': + sorts.append((sort, '+')) + return sorts + + +def get_filters(parameters: dict) -> list: + """Return a list of triples to filter by. + + :param parameters: Dictionary of parameter name, value pairs. + """ + filters = [] + for parameter, value in parameters.items(): + if parameter.startswith('filter[') and parameter.endswith(']'): + filters.append(_get_filter(parameter, value)) + return filters + + +def _get_filter(key: str, value: str) -> tuple: + """Return a triple to filter by.""" + strategy, partition, values = value.partition(STRATEGY_PARTITION) + + negated = strategy.startswith('~') + if negated: + strategy = strategy[1:] + + if partition == '': + values = strategy + strategy = None + elif strategy not in STRATEGIES: + values = ''.join((strategy, partition, values)) + strategy = None + + if negated: + strategy = '~{}'.format(strategy) + return key[7:-1], strategy, values.split(',') + + +def get_paginators(parameters: dict) -> dict: + """Return a list of doubles to filter by. + + :param parameters: Dictionary of parameter name, value pairs. + """ + paginators = [] + for key, value in parameters.items(): + if key in ['page[size]', 'page[limit]']: + paginators.append(('limit', value)) + elif key == 'page[offset]': + paginators.append(('offset', value)) + elif key == 'page[number]': + paginators.append(('number', value)) + return paginators diff --git a/requirements.txt b/requirements.txt index 500f18d..8eadaa8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,4 @@ -Flask==0.10.1 -Flask-SQLAlchemy==2.1 -Jinja2==2.8 -MarkupSafe==0.23 -SQLAlchemy==1.0.11 -Werkzeug==0.11.3 -itsdangerous==0.24 -marshmallow==2.6.0 -marshmallow-jsonapi==0.5.0 -nose==1.3.7 -wsgiref==0.1.2 -recommonmark -sphinx -sphinx-autobuild -coverage +nose +sqlalchemy +marshmallow +marshmallow-jsonapi diff --git a/setup.py b/setup.py index f1a0094..0f4c93e 100644 --- a/setup.py +++ b/setup.py @@ -1,28 +1,25 @@ -"""Flask-SQLAlchemy-JSONAPI. +"""jsonapi-query. -A URL parsing library that uses documented JSONAPI 1.0 specification -URL parameters to filter, sort, and include JSONAPI response objects. +A JSONAPI compliant query library. """ from setuptools import find_packages, setup setup( - name='Flask-SQLAlchemy-JSONAPI', - version='0.1', + name='jsonapi-query', + version='0.2', url='https://github.com/caxiam/sqlalchemy-jsonapi-collections', license='Apache Version 2.0', author='Colton Allen', author_email='colton.allen@caxiam.com', - description='A collection response filtering library.', + description='A JSONAPI compliant query library.', long_description=__doc__, packages=find_packages(exclude=("test*", )), - package_dir={'flask-sqlalchemy-jsonapi': 'flask-sqlalchemy-jsonapi'}, + package_dir={'jsonapi-query': 'jsonapi-query'}, zip_safe=False, include_package_data=True, platforms='any', - install_requires=[ - 'Flask', 'SQLAlchemy', 'marshmallow', 'marshmallow-jsonapi' - ], + install_requires=[], classifiers=[ 'Environment :: Web Environment', 'Intended Audience :: Developers', diff --git a/tests/__init__.py b/tests/__init__.py index 5d163b8..e69de29 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,34 +0,0 @@ -# -*- coding: utf-8 -*- -from unittest import TestCase -from flask import Flask - -from tests import settings -from tests.database import db - -import sys - - -sys.dont_write_bytecode = True - - -class UnitTestCase(TestCase): - - def setUp(self): - """Create app test client""" - self.app = Flask(__name__) - self.app.config.from_object(settings) - if db is not None: - db.init_app(self.app) - self.app_context = self.app.app_context() - self.app_context.push() - self.client = self.app.test_client() - db.create_all() - db.session.commit() - self.addCleanup(self.cleanup) - - def cleanup(self): - """Tear down database""" - db.session.remove() - db.drop_all() - db.get_engine(self.app).dispose() - self.app_context.pop() diff --git a/tests/database.py b/tests/database.py deleted file mode 100644 index 02c35bf..0000000 --- a/tests/database.py +++ /dev/null @@ -1,12 +0,0 @@ -# -*- coding: utf-8 -*- -from flask import current_app -from flask_sqlalchemy import get_state, SQLAlchemy - - -def save(model): - session = get_state(current_app).db.session - session.add(model) - session.commit() - - -db = SQLAlchemy() diff --git a/tests/functional/__init__.py b/tests/functional/__init__.py new file mode 100644 index 0000000..42b31da --- /dev/null +++ b/tests/functional/__init__.py @@ -0,0 +1,5 @@ +from unittest import TestCase + + +class FunctionalTestCase(TestCase): + pass diff --git a/tests/functional/sqlalchemy_tests.py b/tests/functional/sqlalchemy_tests.py new file mode 100644 index 0000000..9ba887a --- /dev/null +++ b/tests/functional/sqlalchemy_tests.py @@ -0,0 +1,233 @@ +"""End to end query testing.""" +from datetime import datetime + +from sqlalchemy.orm import Query, sessionmaker + +from jsonapi_query import url +from jsonapi_query.database.sqlalchemy import group_and_remove, QueryMixin +from jsonapi_query.translation.model.sqlalchemy import SQLAlchemyModelDriver +from jsonapi_query.translation.view.marshmallow_jsonapi import ( + MarshmallowJSONAPIDriver) +from tests.marshmallow_jsonapi import Person as PersonSchema +from tests.sqlalchemy import BaseSQLAlchemyTestCase, Person, School, Student + + +class SQLAlchemyTestCase(BaseSQLAlchemyTestCase): + + def setUp(self): + super().setUp() + self.m_driver = SQLAlchemyModelDriver(Person) + self.v_driver = MarshmallowJSONAPIDriver(PersonSchema) + + class BaseQuery(QueryMixin, Query): + pass + + self.session = sessionmaker(bind=self.engine, query_cls=BaseQuery)() + self.session.begin_nested() + + date = datetime.strptime('2014-01-01', "%Y-%m-%d").date() + fred = Person(name='Fred', age=5, birth_date=date) + self.session.add(fred) + + date = datetime.strptime('2015-01-01', "%Y-%m-%d").date() + carl = Person(name='Carl', age=10, birth_date=date) + self.session.add(carl) + + school = School(name='School') + self.session.add(school) + school = School(name='College') + self.session.add(school) + + student = Student(school_id=1, person_id=1) + self.session.add(student) + + student = Student(school_id=2, person_id=2) + self.session.add(student) + + def test_filter_query(self): + """Test filtering a query by a url string.""" + link = 'testsite.com/people?filter[age]=lt:10' + params = url.get_parameters(link) + + filters = [] + for fltr in url.get_filters(params): + self.v_driver.initialize_path(fltr[0]) + path = self.v_driver.get_model_path() + values = self.v_driver.deserialize_values(fltr[2]) + filters.append((path, fltr[1], values)) + + new = [] + for fltr in filters: + column, models, joins = self.m_driver.parse_path(fltr[0]) + new.append((column, fltr[1], fltr[2], joins)) + + models = self.session.query(Person).apply_filters(new).all() + self.assertTrue(len(models) == 1) + self.assertTrue(models[0].age == 5) + + def test_filter_query_deeply_nested(self): + """Test filtering a query by a deeply nested url string.""" + link = 'testsite.com/people?filter[student.school.title]=eq:School' + params = url.get_parameters(link) + + filters = [] + for fltr in url.get_filters(params): + self.v_driver.initialize_path(fltr[0]) + path = self.v_driver.get_model_path() + values = self.v_driver.deserialize_values(fltr[2]) + filters.append((path, fltr[1], values)) + + new = [] + for fltr in filters: + column, models, joins = self.m_driver.parse_path(fltr[0]) + new.append((column, fltr[1], fltr[2], joins)) + + models = self.session.query(Person).apply_filters(new).all() + self.assertTrue(len(models) == 1) + self.assertTrue(models[0].name == 'Fred') + + def test_sort_query(self): + """Test sorting by an attribute.""" + link = 'testsite.com/people?sort=age' + params = url.get_parameters(link) + + sorts = [] + for sort in url.get_sorts(params): + self.v_driver.initialize_path(sort[0]) + path = self.v_driver.get_model_path() + sorts.append((path, sort[1])) + + new = [] + for sort in sorts: + column, models, joins = self.m_driver.parse_path(sort[0]) + new.append((column, sort[1], joins)) + + models = self.session.query(Person).apply_sorts(new).all() + self.assertTrue(len(models) == 2) + self.assertTrue(models[0].name == 'Fred') + + def test_sort_query_deeply_nested(self): + """Test sorting by a deeply nested attribute.""" + link = 'testsite.com/people?sort=-student.school.title' + params = url.get_parameters(link) + + sorts = [] + for sort in url.get_sorts(params): + self.v_driver.initialize_path(sort[0]) + path = self.v_driver.get_model_path() + sorts.append((path, sort[1])) + + new = [] + for sort in sorts: + column, models, joins = self.m_driver.parse_path(sort[0]) + new.append((column, sort[1], joins)) + + models = self.session.query(Person).apply_sorts(new).all() + self.assertTrue(len(models) == 2) + self.assertTrue(models[0].name == 'Fred') + self.assertTrue(models[1].name == 'Carl') + + def test_paginate_query_by_limit(self): + """Test paginating a query by the limit strategy.""" + link = 'testsite.com/people?page[limit]=1&page[offset]=1' + params = url.get_parameters(link) + + paginators = url.get_paginators(params) + + models = self.session.query(Person).apply_paginators(paginators).all() + self.assertTrue(len(models) == 1) + self.assertTrue(models[0].name == 'Carl') + + def test_paginate_query_by_page(self): + """Test paginating a query by the number strategy.""" + link = 'testsite.com/people?page[size]=1&page[number]=2' + params = url.get_parameters(link) + + paginators = url.get_paginators(params) + + models = self.session.query(Person).apply_paginators(paginators).all() + self.assertTrue(len(models) == 1) + self.assertTrue(models[0].name == 'Carl') + + def test_include_single_column(self): + """Test including a relationship.""" + link = 'testsite.com/people?include=student' + params = url.get_parameters(link) + + includes = [] + schemas = [] + for include_ in url.get_includes(params): + self.v_driver.initialize_path(include_) + includes.append(self.v_driver.get_model_path()) + schemas.extend(self.v_driver.schemas) + + included_models = [] + mappers = [] + for include_ in includes: + _, models, joins = self.m_driver.parse_path(include_) + included_models.extend(models) + mappers.extend(joins) + + items = self.session.query( + Person).filter_by(id=1).include(mappers).all() + items = group_and_remove(items, [Person] + included_models)[1:] + included = [] + for position, columns in enumerate(items): + schema = schemas[position] + included.extend(schema.dump(columns, many=True).data['data']) + + self.assertTrue(len(schemas) == 1) + self.assertTrue(len(included_models) == 1) + self.assertTrue(len(items) == 1) + self.assertTrue(len(items[0]) == 1) + self.assertTrue(len(included) == 1) + + def test_include_multiple_columns(self): + """Test including a list of relationships.""" + def unique(items): + unqiues = [] + for item in items: + if item not in unqiues: + unqiues.append(item) + return unqiues + + link = 'testsite.com/people?include=student.school,student' + params = url.get_parameters(link) + + includes = [] + schemas = [] + for include_ in url.get_includes(params): + self.v_driver.initialize_path(include_) + includes.append(self.v_driver.get_model_path()) + schemas.extend(self.v_driver.schemas) + schemas = unique(schemas) + + included_models = [] + mappers = [] + for include_ in includes: + _, models, joins = self.m_driver.parse_path(include_) + included_models.extend(models) + mappers.extend(joins) + included_models = unique(included_models) + + items = self.session.query( + Person).filter_by(id=1).include(mappers).all() + items = group_and_remove(items, [Person] + included_models)[1:] + included = [] + for position, columns in enumerate(items): + schema = schemas[position] + included.extend(schema.dump(columns, many=True).data['data']) + + self.assertTrue(len(schemas) == 2) + self.assertTrue(len(included_models) == 2) + self.assertTrue(len(items) == 2) + self.assertTrue(len(items[0]) == 1) + self.assertTrue(len(included) == 2) + + self.assertIn('id', included[0]) + self.assertIn('type', included[0]) + self.assertIn('relationships', included[0]) + + self.assertIn('id', included[1]) + self.assertIn('type', included[1]) + self.assertIn('attributes', included[1]) diff --git a/tests/marshmallow_jsonapi.py b/tests/marshmallow_jsonapi.py new file mode 100644 index 0000000..3843ee3 --- /dev/null +++ b/tests/marshmallow_jsonapi.py @@ -0,0 +1,51 @@ +"""Base marshmallow-jsonapi test case module.""" +from marshmallow import class_registry +from marshmallow.base import SchemaABC +from marshmallow.compat import basestring +from marshmallow_jsonapi import fields, Schema + +from tests.unit import UnitTestCase + + +def dasherize(text): + """Replace underscores with hyphens.""" + return text.replace('_', '-') + + +class Person(Schema): + id = fields.Integer() + name = fields.String() + age = fields.Integer() + birth_date = fields.Date() + updated_at = fields.DateTime() + student = fields.Relationship(schema='Student') + + class Meta: + inflect = dasherize + type_ = 'people' + + +class Student(Schema): + id = fields.Integer() + school = fields.Relationship( + schema='School', include_resource_linkage=True, type_='schools') + person = fields.Relationship( + schema='Person', include_resource_linkage=True, type_='people') + + class Meta: + inflect = dasherize + type_ = 'students' + + +class School(Schema): + id = fields.Integer() + title = fields.String(attribute='name') + students = fields.Relationship(schema='Student') + + class Meta: + inflect = dasherize + type_ = 'schools' + + +class BaseMarshmallowJSONAPITestCase(UnitTestCase): + pass diff --git a/tests/mock.py b/tests/mock.py deleted file mode 100644 index 671c553..0000000 --- a/tests/mock.py +++ /dev/null @@ -1,168 +0,0 @@ -# -*- coding: utf-8 -*- -from datetime import datetime - -from sqlalchemy import case, orm -from marshmallow import class_registry, validate -from marshmallow.base import SchemaABC -from marshmallow_jsonapi import fields, Schema - -from tests.database import db, save - - -person_company = db.Table( - 'person_company', - db.Column('person_id', db.Integer, db.ForeignKey('person.id')), - db.Column('company_id', db.Integer, db.ForeignKey('company.id')) -) - - -class PersonModel(db.Model): - __tablename__ = 'person' - - id = db.Column(db.Integer(), primary_key=True) - name = db.Column(db.String) - age = db.Column(db.Integer) - is_employed = db.Column(db.Boolean) - gender = db.Column(db.Enum('male', 'female', name='person_gender')) - rate = db.Column(db.Numeric(12, 2)) - employed_integer = orm.column_property( - case([(is_employed.is_(True), 1)], else_=0)) - created_at = db.Column(db.DateTime, default=datetime.now()) - - companies = db.relationship( - 'CompanyModel', secondary=person_company, backref='persons') - employee = db.relationship( - 'EmployeeModel', uselist=False, back_populates='person') - - @classmethod - def mock(cls, **kwargs): - data = { - 'name': kwargs.pop('name', 'Test'), - 'age': kwargs.pop('age', 10), - 'is_employed': kwargs.pop('is_employed', True), - 'companies': kwargs.pop('companies', []), - 'gender': kwargs.pop('gender', 'male'), - 'rate': kwargs.pop('rate', '1.00') - } - model = cls(**data) - save(model) - return model - - -class EmployeeModel(db.Model): - __tablename__ = 'employee' - - id = db.Column(db.Integer(), primary_key=True) - person_id = db.Column(db.Integer, db.ForeignKey('person.id')) - name = db.Column(db.String) - months_of_service = db.Column(db.Integer) - is_manager = db.Column(db.Boolean) - created_at = db.Column(db.DateTime, default=datetime.now()) - - person = db.relationship('PersonModel', back_populates='employee') - - @classmethod - def mock(cls, **kwargs): - data = { - 'name': kwargs.pop('name', 'Test'), - 'months_of_service': kwargs.pop('months_of_service', 10), - 'is_manager': kwargs.pop('is_manager', True), - 'person_id': kwargs.pop('person_id', None) - } - model = cls(**data) - save(model) - return model - - -class CompanyModel(db.Model): - __tablename__ = 'company' - - id = db.Column(db.Integer(), primary_key=True) - name = db.Column(db.String) - year_established = db.Column(db.Integer) - is_profitable = db.Column(db.Boolean) - created_at = db.Column(db.DateTime, default=datetime.now()) - - @classmethod - def mock(cls, **kwargs): - data = { - 'name': kwargs.pop('name', 'Test'), - 'year_established': kwargs.pop('year_established', 10), - 'is_profitable': kwargs.pop('is_profitable', True) - } - model = cls(**data) - save(model) - return model - - -class SchemaRelationship(fields.Relationship): - - def __init__(self, dump_to=None, related_schema=None, **kwargs): - self.related_schema = related_schema - self.__schema = None - super(SchemaRelationship, self).__init__(**kwargs) - - @property - def schema(self): - if isinstance(self.related_schema, SchemaABC): - self.__schema = self.related_schema - elif (isinstance(self.related_schema, type) and - issubclass(self.related_schema, SchemaABC)): - self.__schema = self.related_schema - elif isinstance(self.related_schema, basestring): - if self.related_schema == 'self': - parent_class = self.parent.__class__ - self.__schema = parent_class - else: - schema_class = class_registry.get_class(self.related_schema) - self.__schema = schema_class - return self.__schema - - -class EmployeeSchema(Schema): - id = fields.String() - name = fields.String() - months_of_service = fields.Integer() - is_manager = fields.Boolean() - created_at = fields.DateTime() - - person = SchemaRelationship(related_schema='PersonSchema') - person_id = SchemaRelationship() - - class Meta: - model = EmployeeModel - type_ = 'employees' - ordered = True - - -class CompanySchema(Schema): - id = fields.String() - name = fields.String() - year_established = fields.Integer() - is_profitable = fields.Boolean() - created_at = fields.DateTime() - - class Meta: - model = CompanyModel - type_ = 'companies' - ordered = True - - -class PersonSchema(Schema): - id = fields.String() - name = fields.String() - age = fields.Integer() - is_employed = fields.Boolean() - gender = fields.String(validate=validate.OneOf(['male', 'female'])) - rate = fields.Decimal(as_string=True, places=2) - employed_integer = fields.Integer() - created_at = fields.DateTime(format='%Y-%m-%d') - - companies = SchemaRelationship(many=True, related_schema=CompanySchema) - employee = SchemaRelationship(many=False, related_schema=EmployeeSchema) - employee_id = SchemaRelationship(many=False, related_schema=EmployeeSchema) - - class Meta: - model = PersonModel - type_ = 'people' - ordered = True diff --git a/tests/settings.py b/tests/settings.py deleted file mode 100644 index c623b10..0000000 --- a/tests/settings.py +++ /dev/null @@ -1,4 +0,0 @@ -# -*- coding: utf-8 -*- -DEBUG = True -TESTING = True -SQLALCHEMY_ENGINE = 'sqlite://' diff --git a/tests/sqlalchemy.py b/tests/sqlalchemy.py new file mode 100644 index 0000000..bd35b09 --- /dev/null +++ b/tests/sqlalchemy.py @@ -0,0 +1,118 @@ +"""Base SQLAlchemy test case module.""" +from datetime import datetime + +from sqlalchemy import create_engine, event, ForeignKey +from sqlalchemy import Column, Date, DateTime, Integer, String +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker, relationship +from unittest import TestCase + + +Base = declarative_base() + + +class Category(Base): + """Mock category table.""" + + __tablename__ = "category" + + id = Column(Integer, primary_key=True) + category_id = Column(Integer, ForeignKey('category.id')) + name = Column(String) + + category = relationship( + 'Category', backref='categories', remote_side=[id], uselist=False) + + +class Product(Base): + """Mock product table.""" + + __tablename__ = "product" + + id = Column(Integer, primary_key=True) + primary_category_id = Column(Integer, ForeignKey('category.id')) + secondary_category_id = Column(Integer, ForeignKey('category.id')) + name = Column(String) + + primary_category = relationship( + 'Category', backref='primary_products', + foreign_keys=[primary_category_id]) + secondary_category = relationship( + 'Category', backref='secondary_products', + foreign_keys=[secondary_category_id]) + + +class Person(Base): + """Mock person table.""" + + __tablename__ = "person" + + id = Column(Integer, primary_key=True) + name = Column(String) + age = Column(Integer) + birth_date = Column(Date) + updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now) + + +class Student(Base): + """Mock 'through' table for `School` and `Person` models.""" + + __tablename__ = "student" + + id = Column(Integer, primary_key=True) + school_id = Column(Integer, ForeignKey('school.id')) + person_id = Column(Integer, ForeignKey('person.id')) + + school = relationship('School', backref='student') + person = relationship('Person', backref='student') + + +class School(Base): + """Mock school table.""" + + __tablename__ = "school" + + id = Column(Integer, primary_key=True) + name = Column(String) + + +class BaseSQLAlchemyTestCase(TestCase): + """Base SQLAlchemy test case. + + For each unittest class, create a database, start a session, run + the tests, rollback the session after each test, and finally drop + the database once all tests have completed. + """ + + def setUp(self): + """Create a save point and start the session.""" + self.engine = engine + self.session = sessionmaker(bind=engine)() + self.session.begin_nested() + + def tearDown(self): + """Close the session and rollback to the previous save point.""" + self.session.rollback() + self.session.close() + + @classmethod + def setUpClass(cls): + """Create the database.""" + global engine + + engine = create_engine('sqlite:///sqlalchemy.db') + + @event.listens_for(engine, "connect") + def do_connect(dbapi_connection, connection_record): + dbapi_connection.isolation_level = None + + @event.listens_for(engine, "begin") + def do_begin(conn): + conn.execute("BEGIN") + + Base.metadata.create_all(engine) + + @classmethod + def tearDownClass(cls): + """Destroy the database.""" + engine.dispose() diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index e69de29..fb344ba 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -0,0 +1,5 @@ +from unittest import TestCase + + +class UnitTestCase(TestCase): + pass diff --git a/tests/unit/database/__init__.py b/tests/unit/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/database/sqlalchemy_tests.py b/tests/unit/database/sqlalchemy_tests.py new file mode 100644 index 0000000..c0078c8 --- /dev/null +++ b/tests/unit/database/sqlalchemy_tests.py @@ -0,0 +1,366 @@ +"""Test database interactions.""" +from datetime import datetime + +from sqlalchemy.orm import Query, sessionmaker + +from jsonapi_query.database.sqlalchemy import ( + group_and_remove, include, QueryMixin) +from tests.sqlalchemy import ( + BaseSQLAlchemyTestCase, Category, Person, Product, School, Student) + + +class BaseDatabaseSQLAlchemyTests(BaseSQLAlchemyTestCase): + """Base database SQLAlchemy test case for establishing mock environment.""" + + def setUp(self): + """Set the query class and create a set of rows to test against.""" + super().setUp() + + class BaseQuery(QueryMixin, Query): + pass + + self.session = sessionmaker(bind=self.engine, query_cls=BaseQuery)() + self.session.begin_nested() + + date = datetime.strptime('2014-01-01', "%Y-%m-%d").date() + fred = Person(name='Fred', age=5, birth_date=date) + self.session.add(fred) + + date = datetime.strptime('2015-01-01', "%Y-%m-%d").date() + carl = Person(name='Carl', age=10, birth_date=date) + self.session.add(carl) + + school = School(name='School') + self.session.add(school) + school = School(name='College') + self.session.add(school) + + student = Student(school_id=1, person_id=1) + self.session.add(student) + + student = Student(school_id=2, person_id=2) + self.session.add(student) + + +class FilterSQLAlchemyTestCase(BaseDatabaseSQLAlchemyTests): + """Test query filtering related methods.""" + + def test_query_filter_strategy_eq(self): + """Test filtering a query with the `eq` strategy.""" + models = self.session.query( + Person).apply_filter(Person.name, 'eq', ['Fred']).all() + self.assertTrue(len(models) == 1) + self.assertTrue(models[0].name == 'Fred') + + def test_query_filter_strategy_negation(self): + """Test filtering a query with a negated strategy.""" + models = self.session.query( + Person).apply_filter(Person.name, '~eq', ['Fred']).all() + self.assertTrue(len(models) == 1) + self.assertTrue(models[0].name == 'Carl') + + def test_query_filter_strategy_gt(self): + """Test filtering a query with the `gt` strategy.""" + models = self.session.query( + Person).apply_filter(Person.age, 'gt', [5]).all() + self.assertTrue(len(models) == 1) + self.assertTrue(models[0].age == 10) + + def test_query_filter_strategy_gte(self): + """Test filtering a query with the `gte` strategy.""" + models = self.session.query( + Person).apply_filter(Person.age, 'gte', [5]).all() + self.assertTrue(len(models) == 2) + + def test_query_filter_strategy_lt(self): + """Test filtering a query with the `lt` strategy.""" + date = datetime.strptime('2015-01-01', "%Y-%m-%d").date() + models = self.session.query( + Person).apply_filter(Person.birth_date, 'lt', [date]).all() + self.assertTrue(len(models) == 1) + + def test_query_filter_strategy_lte(self): + """Test filtering a query with the `lte` strategy.""" + date = datetime.strptime('2015-01-01', "%Y-%m-%d").date() + models = self.session.query( + Person).apply_filter(Person.birth_date, 'lte', [date]).all() + self.assertTrue(len(models) == 2) + + def test_query_filter_strategy_like(self): + """Test filtering a query with the `like` strategy.""" + models = self.session.query( + Person).apply_filter(Person.name, 'like', ['Fred']).all() + self.assertTrue(len(models) == 1) + self.assertTrue(models[0].name == 'Fred') + + def test_query_filter_strategy_ilike(self): + """Test filtering a query with the `ilike` strategy.""" + models = self.session.query( + Person).apply_filter(Person.name, 'ilike', ['fred']).all() + self.assertTrue(len(models) == 1) + self.assertTrue(models[0].name == 'Fred') + + def test_query_filter_in_values(self): + """Test filtering a query by the `in` strategy.""" + models = self.session.query( + Person).apply_filter(Person.name, 'in', ['Fred']).all() + self.assertTrue(len(models) == 1) + self.assertTrue(models[0].name == 'Fred') + + def test_query_filter_not_in_values(self): + """Test filtering a query by the `~in` strategy.""" + models = self.session.query( + Person).apply_filter(Person.name, '~in', ['Fred']).all() + self.assertTrue(len(models) == 1) + self.assertTrue(models[0].name == 'Carl') + + def test_query_filter_multiple_values(self): + """Test filtering a query by multiple values.""" + models = self.session.query( + Person).apply_filter(Person.name, 'eq', ['Fred', 'Carl']).all() + self.assertTrue(len(models) == 2) + + def test_query_filter_invalid_strategy(self): + """Test filtering a query by an invalid strategy.""" + try: + self.session.query( + Person).apply_filter(Person.name, 'qq', ['Fred']).all() + self.assertTrue(False) + except ValueError: + self.assertTrue(True) + + def test_query_filter_multiple_joins(self): + """Test filtering a query with multiple join conditions.""" + models = self.session.query(Person).apply_filter( + School.name, 'eq', ['School'], + [Person.student, Student.school]).all() + self.assertTrue(len(models) == 1) + self.assertTrue(models[0].name == 'Fred') + + def test_query_filter_ambiguous_join_conditions(self): + """Test filtering a query under ambiguous join conditions.""" + a = Category(name='Category A') + self.session.add(a) + b = Category(name='Category B', category_id=1) + self.session.add(b) + p = Product(primary_category_id=1, secondary_category_id=2, name='Tst') + self.session.add(p) + + model = self.session.query(Product).apply_filters([ + (Category.name, 'eq', ['Category A'], [Product.primary_category]), + (Category.name, 'eq', ['Category B'], [Product.secondary_category]) + ]).first() + self.assertTrue(isinstance(model, Product)) + self.assertTrue(model.name == 'Tst') + + +class SortSQLAlchemyTestCase(BaseDatabaseSQLAlchemyTests): + """Test query sorting related methods.""" + + def test_query_sort_attribute_ascending(self): + """Test sorting a query by an ascending column.""" + models = self.session.query(Person).apply_sort(Person.name, '+').all() + self.assertTrue(models[0].name == 'Carl') + self.assertTrue(models[1].name == 'Fred') + + def test_query_sort_attribute_descending(self): + """Test sorting a query by a descending column.""" + models = self.session.query(Person).apply_sort(Person.name, '-').all() + self.assertTrue(models[0].name == 'Fred') + self.assertTrue(models[1].name == 'Carl') + + def test_query_sort_relationship_ascending(self): + """Test sorting a query by an ascending relationship column.""" + models = self.session.query( + Student).apply_sort(Person.name, '+', [Student.person]).all() + self.assertTrue(models[0].person.name == 'Carl') + self.assertTrue(models[1].person.name == 'Fred') + + def test_query_sort_relationship_descending(self): + """Test sorting a query by a descending relationship column.""" + models = self.session.query( + Student).apply_sort(Person.name, '-', [Student.person]).all() + self.assertTrue(models[0].person.name == 'Fred') + self.assertTrue(models[1].person.name == 'Carl') + + def test_query_sort_over_multiple_joins(self): + """Test sorting a query with multiple join conditions.""" + models = self.session.query(Person).apply_sort( + School.name, '+', [Person.student, Student.school]).all() + self.assertTrue(models[0].name == 'Carl') + self.assertTrue(models[1].name == 'Fred') + + def test_query_sort_ambiguous_join_conditions(self): + """Test filtering a query under ambiguous join conditions.""" + a = Category(name='Category A') + self.session.add(a) + b = Category(name='Category B', category_id=1) + self.session.add(b) + p = Product(primary_category_id=1, name='A') + self.session.add(p) + p = Product(primary_category_id=2, name='B') + self.session.add(p) + + models = self.session.query(Product).apply_sort( + Category.name, '-', [Product.primary_category]).all() + self.assertTrue(models[0].name == 'B') + self.assertTrue(models[1].name == 'A') + + +class PaginateSQLAlchemyTestCase(BaseDatabaseSQLAlchemyTests): + """Test query pagination related methods.""" + + def test_query_paginate_limit(self): + """Test limiting a query.""" + models = self.session.query( + Person).apply_paginators([('limit', 1)]).all() + self.assertTrue(len(models) == 1) + self.assertTrue(models[0].name == 'Fred') + + def test_query_paginate_offset(self): + """Test offsetting a query.""" + models = self.session.query( + Person).apply_paginators([('offset', 1)]).all() + self.assertTrue(len(models) == 1) + self.assertTrue(models[0].name == 'Carl') + + def test_query_paginate_number(self): + """Test offsetting a query by page number.""" + models = self.session.query( + Person).apply_paginators([('number', 2), ('limit', 1)]).all() + self.assertTrue(len(models) == 1) + self.assertTrue(models[0].name == 'Carl') + + +class IncludeSQLAlchemyTestCase(BaseDatabaseSQLAlchemyTests): + """Test constructing an included query.""" + + def test_include_one_column(self): + """Test including a single relationship.""" + models = self.session.query(Person).include([Person.student]).first() + self.assertTrue(isinstance(models[0], Person)) + self.assertTrue(isinstance(models[1], Student)) + self.assertTrue(models[1].person_id == 1) + + def test_include_multiple_columns(self): + """Test including multiple relationships.""" + models = self.session.query(Person).include([ + Person.student, Student.school]).first() + self.assertTrue(isinstance(models[0], Person)) + self.assertTrue(isinstance(models[1], Student)) + self.assertTrue(isinstance(models[2], School)) + + def test_include_self_referential_relationship(self): + """Test including a self-referential relationship.""" + a = Category(name='Category A') + self.session.add(a) + b = Category(name='Category B', category_id=1) + self.session.add(b) + c = Category(name='Category C', category_id=2) + self.session.add(c) + + models = self.session.query(Category).filter(Category.id == 2).include( + [Category.category, Category.categories]).first() + self.assertTrue(len(models) == 3) + + self.assertTrue(isinstance(models[0], Category)) + self.assertTrue(models[0].id == 2) + + self.assertTrue(isinstance(models[1], Category)) + self.assertTrue(models[1].id == 1) + + self.assertTrue(isinstance(models[2], Category)) + self.assertTrue(models[2].id == 3) + + def test_include_ambiguous_join_conditions(self): + """Test including a model when a join can be made multiple ways.""" + a = Category(name='Category A') + self.session.add(a) + b = Category(name='Category B', category_id=1) + self.session.add(b) + p = Product(primary_category_id=1, secondary_category_id=2, name='Tst') + self.session.add(p) + + models = self.session.query(Product).include([ + Product.primary_category, Product.secondary_category]).first() + self.assertTrue(isinstance(models[1], Category)) + self.assertTrue(isinstance(models[2], Category)) + + def test_include_does_not_restrict_primary_output(self): + """Test including a relationship does not restrict primary output.""" + a = Category(name='Category A') + self.session.add(a) + b = Category(name='Category B', category_id=1) + self.session.add(b) + p = Product(primary_category_id=1, secondary_category_id=2, name='Tst') + self.session.add(p) + + p = Product(name='Tst 2') + self.session.add(p) + + models = self.session.query(Product).include( + [Product.primary_category]).all() + self.assertTrue(len(models) == 2) + + def test_include_no_mappers(self): + """Test including an empty set of relationships.""" + models = self.session.query(Person).include([]).first() + self.assertTrue(isinstance(models, Person)) + + def test_include_from_model(self): + """Test including from a model.""" + a = Category(name='Category A') + self.session.add(a) + + p = Product(primary_category_id=1, name='Test') + self.session.add(p) + + items = include( + self.session, Product, [Category], [Product.primary_category], [1]) + self.assertTrue(isinstance(items, list)) + self.assertTrue(isinstance(items[0], list)) + self.assertTrue(isinstance(items[0][0], Category)) + + def test_include_from_model_without_relationship(self): + """Test including from a model when no relationship exists.""" + p = Product(name='Test') + self.session.add(p) + + items = include( + self.session, Product, [Category], [Product.primary_category], [1]) + self.assertTrue(isinstance(items, list)) + self.assertTrue(isinstance(items[0], list)) + self.assertTrue(items == [[]]) + + items = include( + self.session, Product, [Category, Category], + [Product.primary_category, Product.secondary_category], [1]) + self.assertTrue(isinstance(items, list)) + self.assertTrue(isinstance(items[0], list)) + self.assertTrue(items == [[], []]) + + +class UtilitySQLAlchemyTestCase(BaseDatabaseSQLAlchemyTests): + """Test handling a query's output.""" + + def test_group_and_remove(self): + """Test group and remove utility function.""" + a = Category(name='Category A') + self.session.add(a) + b = Category(name='Category B', category_id=1) + self.session.add(b) + p = Product(primary_category_id=1, secondary_category_id=2, name='Tst') + self.session.add(p) + + p = Product(name='Tst 2') + self.session.add(p) + + # Returns two products and one category. + items = self.session.query(Product).include( + [Product.primary_category]).all() + self.assertTrue(len(items) == 2) + + output = group_and_remove(items, [Product, Category]) + self.assertTrue(len(output) == 2) + self.assertTrue(len(output[0]) == 2) + self.assertTrue(len(output[1]) == 1) diff --git a/tests/unit/filter_tests.py b/tests/unit/filter_tests.py deleted file mode 100644 index 34e8181..0000000 --- a/tests/unit/filter_tests.py +++ /dev/null @@ -1,408 +0,0 @@ -# -*- coding: utf-8 -*- -from datetime import date - -from jsonapi_collections import Resource -from jsonapi_collections.drivers.marshmallow import MarshmallowDriver -from jsonapi_collections.errors import JSONAPIError -from tests import UnitTestCase -from tests.mock import CompanyModel, EmployeeModel, PersonModel, PersonSchema - - -class FilterTestCase(UnitTestCase): - - def setUp(self): - super(FilterTestCase, self).setUp() - self.model = PersonModel - self.query = PersonModel.query - - -class SQLAlchemyTestCase(FilterTestCase): - """Test filtering with the `SQLAlchemy` driver.""" - - def test_filter_string_field(self): - """Test filtering by a string field.""" - PersonModel.mock(name='A PRODUCT Wildcard') - - parameters = {'filter[name]': 'prod'} - query = Resource(self.model, parameters).filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_multiple_string_field(self): - """Test filtering by multiple string fields.""" - PersonModel.mock(name='A PRODUCT Wildcard') - - parameters = {'filter[name]': 'prod,test,card'} - query = Resource(self.model, parameters).filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_enum_field(self): - """Test filtering by an enum field.""" - PersonModel.mock(gender='male') - - parameters = {'filter[gender]': 'male'} - query = Resource(self.model, parameters).filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_enum_field_invalid_value(self): - """Test filtering by an invalid enum value.""" - PersonModel.mock(gender='male') - - try: - parameters = {'filter[gender]': 'mal'} - Resource(self.model, parameters).filter_query(self.query) - except JSONAPIError as exc: - message = exc.message - self.assertIn('detail', message['errors'][0]) - self.assertTrue( - message['errors'][0]['source']['parameter'] == - 'filter[gender]') - - def test_filter_decimal_field(self): - """Test filtering by a decimal field.""" - PersonModel.mock(rate='12.51') - - parameters = {'filter[rate]': '12.51'} - query = Resource(self.model, parameters).filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_decimal_field_invalid_value(self): - """Test filtering by an invalid decimal value.""" - PersonModel.mock(rate='12.51') - - try: - parameters = {'filter[rate]': 'a'} - Resource(self.model, parameters).filter_query(self.query) - except JSONAPIError as exc: - message = exc.message - self.assertIn('detail', message['errors'][0]) - self.assertTrue( - message['errors'][0]['source']['parameter'] == 'filter[rate]') - - def test_filter_boolean_field(self): - """Test filtering by a boolean field.""" - PersonModel.mock(is_employed=True) - - parameters = {'filter[is_employed]': '1'} - query = Resource(self.model, parameters).filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_datetime_field(self): - """Test filtering by a datetime field.""" - PersonModel.mock() - - parameters = {'filter[created_at]': date.today().strftime('%Y-%m-%d')} - query = Resource(self.model, parameters).filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_integer_field(self): - """Test filtering by an integer field.""" - PersonModel.mock(age=80) - - parameters = {'filter[age]': '80'} - query = Resource(self.model, parameters).filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_computed_field(self): - """Test filtering by a computed field.""" - PersonModel.mock(is_employed=True) - - parameters = {'filter[employed_integer]': '1'} - query = Resource(self.model, parameters).filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_one_to_many_relationship(self): - """Test filtering by a foreign key relationship field.""" - person = PersonModel.mock() - EmployeeModel.mock(name="employee", person_id=person.id) - - parameters = {'filter[employee.name]': 'EMPLOYEE'} - query = Resource(self.model, parameters).filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_many_to_many_relationship(self): - """Test filtering by a many-to-many relationship field.""" - company = CompanyModel.mock(name="company") - PersonModel.mock(companies=[company]) - - parameters = {'filter[companies.name]': 'COMPANY'} - query = Resource(self.model, parameters).filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_invalid_field_parameter(self): - """Test filtering by an unknown field.""" - try: - parameters = {'filter[xyz]': 'value'} - Resource(self.model, parameters).filter_query(self.query) - except JSONAPIError as exc: - message = exc.message - self.assertIn('detail', message['errors'][0]) - self.assertTrue( - message['errors'][0]['source']['parameter'] == 'filter[xyz]') - - def test_filter_invalid_relationship_parameter(self): - """Test filtering by an invalid relationship field.""" - try: - parameters = {'filter[employee.xyz]': 'value'} - Resource(self.model, parameters).filter_query(self.query) - except JSONAPIError as exc: - message = exc.message - self.assertIn('detail', message['errors'][0]) - self.assertTrue( - message['errors'][0]['source']['parameter'] == - 'filter[employee.xyz]') - - def test_filter_invalid_relationship(self): - """Test filtering by an invalid relationship value.""" - try: - parameters = {'filter[rate.name]': 'whatever'} - Resource(self.model, parameters).filter_query(self.query) - except JSONAPIError as exc: - message = exc.message - self.assertIn('detail', message['errors'][0]) - self.assertTrue( - message['errors'][0]['source']['parameter'] == - 'filter[rate.name]') - - def test_filter_invalid_value_type(self): - """Test filtering by an invalid field value.""" - try: - parameters = {'filter[datetime]': 'notadate'} - Resource(self.model, parameters).filter_query(self.query) - except JSONAPIError as exc: - message = exc.message - self.assertIn('detail', message['errors'][0]) - self.assertTrue( - message['errors'][0]['source']['parameter'] == - 'filter[datetime]') - - -class MarshmallowTestCase(FilterTestCase): - """Test filtering with the `marshmallow` driver.""" - - def test_filter_string_field(self): - """Test filtering by a string field.""" - PersonModel.mock(name='A PRODUCT Wildcard') - - parameters = {'filter[name]': 'prod'} - query = Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_multiple_string_field(self): - """Test filtering by multiple string fields.""" - PersonModel.mock(name='A PRODUCT Wildcard') - - parameters = {'filter[name]': 'prod,test,card'} - query = Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_enum_field(self): - """Test filtering by an enum field.""" - PersonModel.mock(gender='male') - - parameters = {'filter[gender]': 'male'} - query = Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_enum_field_invalid_value(self): - """Test filtering by an invalid enum value.""" - PersonModel.mock(gender='male') - - try: - parameters = {'filter[gender]': 'mal'} - Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - filter_query(self.query) - except JSONAPIError as exc: - message = exc.message - self.assertIn('detail', message['errors'][0]) - self.assertTrue( - message['errors'][0]['source']['parameter'] == - 'filter[gender]') - - def test_filter_decimal_field(self): - """Test filtering by a decimal field.""" - PersonModel.mock(rate='12.51') - - parameters = {'filter[rate]': '12.51'} - query = Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_decimal_field_invalid_value(self): - """Test filtering by an invalid decimal value.""" - PersonModel.mock(rate='12.51') - - try: - parameters = {'filter[rate]': 'a'} - Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - filter_query(self.query) - except JSONAPIError as exc: - message = exc.message - self.assertIn('detail', message['errors'][0]) - self.assertTrue( - message['errors'][0]['source']['parameter'] == 'filter[rate]') - - def test_filter_boolean_field(self): - """Test filtering by a boolean field.""" - PersonModel.mock(is_employed=True) - - parameters = {'filter[is_employed]': '1'} - query = Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_datetime_field(self): - """Test filtering by a datetime field.""" - PersonModel.mock() - - parameters = {'filter[created_at]': date.today().strftime('%Y-%m-%d')} - query = Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_integer_field(self): - """Test filtering by an integer field.""" - PersonModel.mock(age=80) - - parameters = {'filter[age]': '80'} - query = Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_computed_field(self): - """Test filtering by a computed field.""" - PersonModel.mock(is_employed=True) - - parameters = {'filter[employed_integer]': '1'} - query = Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_one_to_many_relationship(self): - """Test filtering by a foreign key relationship field.""" - person = PersonModel.mock() - EmployeeModel.mock(name="employee", person_id=person.id) - - parameters = {'filter[employee.name]': 'EMPLOYEE'} - query = Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_many_to_many_relationship(self): - """Test filtering by a many-to-many relationship field.""" - company = CompanyModel.mock(name="company") - PersonModel.mock(companies=[company]) - - parameters = {'filter[companies.name]': 'COMPANY'} - query = Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - filter_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - - def test_filter_invalid_field_parameter(self): - """Test filtering by an unknown field.""" - try: - parameters = {'filter[xyz]': 'value'} - Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - filter_query(self.query) - except JSONAPIError as exc: - message = exc.message - self.assertIn('detail', message['errors'][0]) - self.assertTrue( - message['errors'][0]['source']['parameter'] == 'filter[xyz]') - - def test_filter_invalid_relationship_attribute(self): - """Test filtering by an invalid relationship field.""" - try: - parameters = {'filter[employee.xyz]': 'value'} - Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - filter_query(self.query) - except JSONAPIError as exc: - message = exc.message - self.assertIn('detail', message['errors'][0]) - self.assertTrue( - message['errors'][0]['source']['parameter'] == - 'filter[employee.xyz]') - - def test_filter_invalid_relationship(self): - """Test filtering by an invalid relationship value.""" - try: - parameters = {'filter[rate.name]': 'whatever'} - Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - filter_query(self.query) - except JSONAPIError as exc: - message = exc.message - self.assertIn('detail', message['errors'][0]) - self.assertTrue( - message['errors'][0]['source']['parameter'] == - 'filter[rate.name]') - - def test_filter_invalid_value_type(self): - """Test filtering by an invalid field value.""" - try: - parameters = {'filter[datetime]': 'notadate'} - Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - filter_query(self.query) - except JSONAPIError as exc: - message = exc.message - self.assertIn('detail', message['errors'][0]) - self.assertTrue( - message['errors'][0]['source']['parameter'] == - 'filter[datetime]') diff --git a/tests/unit/include_tests.py b/tests/unit/include_tests.py deleted file mode 100644 index 2d66170..0000000 --- a/tests/unit/include_tests.py +++ /dev/null @@ -1,142 +0,0 @@ -# -*- coding: utf-8 -*- -from jsonapi_collections import Resource -from jsonapi_collections.drivers import marshmallow -from jsonapi_collections.errors import JSONAPIError -from tests import UnitTestCase -from tests.mock import ( - CompanyModel, EmployeeModel, EmployeeSchema, PersonModel, PersonSchema -) - - -class IncludeTestCase(UnitTestCase): - - def setUp(self): - super(IncludeTestCase, self).setUp() - self.model = PersonModel - self.view = PersonSchema - self.query = PersonModel.query - self.driver = marshmallow.MarshmallowDriver - - -class MarshmallowIncludeTestCase(IncludeTestCase): - - def test_include_one_to_one(self): - """Test including a one-to-one relationship.""" - model = PersonModel.mock() - EmployeeModel.mock(person_id=1) - - parameters = {'include': 'employee'} - included = Resource( - self.model, parameters, self.driver, self.view).\ - compound_response(model) - self.assertTrue(len(included) == 1) - - def test_include_many_to_many(self): - """Test including a many-to-many relationship.""" - company = CompanyModel.mock() - model = PersonModel.mock(companies=[company]) - - parameters = {'include': 'companies'} - included = Resource( - self.model, parameters, self.driver, self.view).\ - compound_response(model) - self.assertTrue(len(included) == 1) - - def test_include_attribute(self): - """Test including an attribute.""" - model = PersonModel.mock() - - try: - parameters = {'include': 'name'} - Resource( - self.model, parameters, self.driver, self.view).\ - compound_response(model) - self.assertTrue(False) - except JSONAPIError as exc: - message = exc.message - self.assertTrue( - message['errors'][0]['source']['parameter'] == 'include') - self.assertIn('detail', message['errors'][0]) - - def test_include_missing_field(self): - """Test including an unknown field.""" - model = PersonModel.mock() - - try: - parameters = {'include': 'wxyz'} - Resource( - self.model, parameters, self.driver, self.view).\ - compound_response(model) - self.assertTrue(False) - except JSONAPIError as exc: - message = exc.message - self.assertTrue( - message['errors'][0]['source']['parameter'] == 'include') - self.assertIn('detail', message['errors'][0]) - - -class SQLAlchemyIncludeTestCase(IncludeTestCase): - - def test_include_nested(self): - """Test including a nested relationship.""" - company = CompanyModel.mock() - model = PersonModel.mock(companies=[company]) - EmployeeModel.mock(person_id=model.id) - - parameters = {'include': 'employee.person.companies'} - included = Resource(self.model, parameters).compound_response(model) - self.assertTrue(len(included) == 1) - - def test_include_one_to_one(self): - """Test including a one-to-one relationship.""" - model = PersonModel.mock() - EmployeeModel.mock(person_id=1) - - parameters = {'include': 'employee'} - included = Resource(self.model, parameters).compound_response(model) - self.assertTrue(len(included) == 1) - - def test_include_many_to_many(self): - """Test including a many-to-many relationship.""" - company = CompanyModel.mock() - model = PersonModel.mock(companies=[company]) - - parameters = {'include': 'companies'} - included = Resource(self.model, parameters).compound_response(model) - self.assertTrue(len(included) == 1) - - def test_include_empty_relationship(self): - """Test including an empty one-to-one relationship.""" - model = PersonModel.mock() - - parameters = {'include': 'employee'} - included = Resource(self.model, parameters).compound_response(model) - self.assertTrue(len(included) == 0) - - def test_include_attribute(self): - """Test including an attribute.""" - model = PersonModel.mock() - - try: - parameters = {'include': 'name'} - Resource(self.model, parameters).compound_response(model) - self.assertTrue(False) - except JSONAPIError as exc: - message = exc.message - self.assertTrue( - message['errors'][0]['source']['parameter'] == 'include') - self.assertIn('detail', message['errors'][0]) - - def test_include_missing_field(self): - """Test including an unknown field.""" - model = PersonModel.mock() - - try: - parameters = {'include': 'wxyz'} - Resource(self.model, parameters).compound_response(model) - self.assertTrue(False) - except JSONAPIError as exc: - message = exc.message - self.assertTrue( - message['errors'][0]['source']['parameter'] == 'include') - self.assertIn('detail', message['errors'][0]) diff --git a/tests/unit/page_tests.py b/tests/unit/page_tests.py deleted file mode 100644 index 11e73cf..0000000 --- a/tests/unit/page_tests.py +++ /dev/null @@ -1,184 +0,0 @@ -# -*- coding: utf-8 -*- -"""JSONAPI pagination implementation testing. - -This module is dedicated to testing against the various pagination -strategies described in the JSONAPI 1.0 specification. -""" -from jsonapi_collections import Resource -from jsonapi_collections.drivers.marshmallow import MarshmallowDriver -from tests import UnitTestCase -from tests.mock import PersonModel, PersonSchema - - -class PaginationTestCase(UnitTestCase): - """Base pagination test case.""" - - def setUp(self): - """Establish some helpful model and query shortcuts.""" - super(PaginationTestCase, self).setUp() - self.model = PersonModel - self.query = PersonModel.query - - -class SQLAlchemyPaginationTestCase(PaginationTestCase): - """SQLAlchemy driver pagination tests.""" - - def test_page_limit(self): - """Test limiting a page by the page[limit] parameter.""" - PersonModel.mock(name='First') - PersonModel.mock(name='Second') - - parameters = {'page[limit]': '1'} - query = Resource(self.model, parameters).paginate_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - self.assertTrue(result[0].name == 'First') - - def test_page_size(self): - """Test limiting a page by the page[size] parameter.""" - PersonModel.mock(name='First') - PersonModel.mock(name='Second') - - parameters = {'page[size]': '1'} - query = Resource(self.model, parameters).paginate_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - self.assertTrue(result[0].name == 'First') - - def test_limit(self): - """Test limiting a page by the limit parameter.""" - PersonModel.mock(name='First') - PersonModel.mock(name='Second') - - parameters = {'limit': '1'} - query = Resource(self.model, parameters).paginate_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - self.assertTrue(result[0].name == 'First') - - def test_page_offset(self): - """Test offsetting a page by the page[offset] parameter.""" - PersonModel.mock(name='First') - PersonModel.mock(name='Second') - - parameters = {'page[offset]': '1'} - query = Resource(self.model, parameters).paginate_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - self.assertTrue(result[0].name == 'Second') - - def test_page_number(self): - """Test offsetting a page by the page[number] parameter.""" - PersonModel.mock(name='First') - PersonModel.mock(name='Second') - - parameters = {'page[number]': '1', 'page[size]': '1'} - query = Resource(self.model, parameters).paginate_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - self.assertTrue(result[0].name == 'Second') - - def test_offset(self): - """Test offsetting a page by the offset parameter.""" - PersonModel.mock(name='First') - PersonModel.mock(name='Second') - - parameters = {'offset': '1'} - query = Resource(self.model, parameters).paginate_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - self.assertTrue(result[0].name == 'Second') - - -class MarshmallowPaginationTestCase(PaginationTestCase): - """Marshmallow driver pagination tests.""" - - def test_page_limit(self): - """Test limiting a page by the page[limit] parameter.""" - PersonModel.mock(name='First') - PersonModel.mock(name='Second') - - parameters = {'page[limit]': '1'} - query = Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - paginate_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - self.assertTrue(result[0].name == 'First') - - def test_page_size(self): - """Test limiting a page by the page[size] parameter.""" - PersonModel.mock(name='First') - PersonModel.mock(name='Second') - - parameters = {'page[size]': '1'} - query = Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - paginate_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - self.assertTrue(result[0].name == 'First') - - def test_limit(self): - """Test limiting a page by the limit parameter.""" - PersonModel.mock(name='First') - PersonModel.mock(name='Second') - - parameters = {'limit': '1'} - query = Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - paginate_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - self.assertTrue(result[0].name == 'First') - - def test_page_offset(self): - """Test offsetting a page by the page[offset] parameter.""" - PersonModel.mock(name='First') - PersonModel.mock(name='Second') - - parameters = {'page[offset]': '1'} - query = Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - paginate_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - self.assertTrue(result[0].name == 'Second') - - def test_page_number(self): - """Test offsetting a page by the page[number] parameter.""" - PersonModel.mock(name='First') - PersonModel.mock(name='Second') - - parameters = {'page[number]': '1', 'page[size]': '1'} - query = Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - paginate_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - self.assertTrue(result[0].name == 'Second') - - def test_offset(self): - """Test offsetting a page by the offset parameter.""" - PersonModel.mock(name='First') - PersonModel.mock(name='Second') - - parameters = {'offset': '1'} - query = Resource( - self.model, parameters, MarshmallowDriver, PersonSchema).\ - paginate_query(self.query) - - result = query.all() - self.assertEqual(len(result), 1) - self.assertTrue(result[0].name == 'Second') diff --git a/tests/unit/sort_tests.py b/tests/unit/sort_tests.py deleted file mode 100644 index f1b18ae..0000000 --- a/tests/unit/sort_tests.py +++ /dev/null @@ -1,183 +0,0 @@ -# -*- coding: utf-8 -*- -from jsonapi_collections import Resource -from jsonapi_collections.drivers import marshmallow -from jsonapi_collections.errors import JSONAPIError -from tests import UnitTestCase -from tests.mock import CompanyModel, PersonModel, PersonSchema - - -class SortTestCase(UnitTestCase): - - def setUp(self): - super(SortTestCase, self).setUp() - self.model = PersonModel - self.view = PersonSchema - self.query = PersonModel.query - - -class SQLAlchemySortTestCase(SortTestCase): - """Test sorting with the SQLAlchemy driver.""" - - def test_sort_field_ascending(self): - """Test sorting a field in ascending order.""" - PersonModel.mock(name="A") - PersonModel.mock(name="B") - - parameters = {'sort': 'name'} - query = Resource(self.model, parameters).sort_query(self.query) - - result = query.all() - self.assertEqual(result[0].name, 'A') - - def test_sort_field_descending(self): - """Test sorting a field in descending order.""" - PersonModel.mock(name="A") - PersonModel.mock(name="B") - - parameters = {'sort': '-name'} - query = Resource(self.model, parameters).sort_query(self.query) - - result = query.all() - self.assertEqual(result[0].name, 'B') - - def test_sort_relationship_ascending(self): - """Test sorting a relationhip's field in descending order.""" - a = CompanyModel.mock(name="Last") - b = CompanyModel.mock(name="First") - PersonModel.mock(name="A", companies=[a]) - PersonModel.mock(name="B", companies=[b]) - - parameters = {'sort': 'companies.name'} - query = Resource(self.model, parameters).sort_query(self.query) - - result = query.all() - self.assertEqual(result[0].companies[0].name, 'First') - - def test_sort_relationship_descending(self): - """Test sorting a relationhip's field in descending order.""" - a = CompanyModel.mock(name="Last") - b = CompanyModel.mock(name="First") - PersonModel.mock(name="A", companies=[a]) - PersonModel.mock(name="B", companies=[b]) - - parameters = {'sort': '-companies.name'} - query = Resource(self.model, parameters).sort_query(self.query) - - result = query.all() - self.assertEqual(result[0].companies[0].name, 'Last') - - def test_sort_invalid_field(self): - """Test sorting against non-existant attribute.""" - PersonModel.mock() - - try: - parameters = {'sort': 'x'} - Resource(self.model, parameters).sort_query(self.query) - except JSONAPIError as exc: - message = exc.message - self.assertTrue( - message['errors'][0]['source']['parameter'] == 'sort') - self.assertIn('detail', message['errors'][0]) - - def test_sort_invalid_relationship(self): - """Test sorting against non-existant relationship attribute.""" - PersonModel.mock() - - try: - parameters = {'sort': 'companies.x'} - Resource(self.model, parameters).sort_query(self.query) - except JSONAPIError as exc: - message = exc.message - self.assertTrue( - message['errors'][0]['source']['parameter'] == 'sort') - self.assertIn('detail', message['errors'][0]) - - -class MarshmallowSortTestCase(SortTestCase): - """Test sorting with the marshmallow driver.""" - - def test_sort_field_ascending(self): - """Test sorting a field in ascending order.""" - PersonModel.mock(name="A") - PersonModel.mock(name="B") - - parameters = {'sort': 'name'} - query = Resource( - self.model, parameters, marshmallow.MarshmallowDriver, self.view).\ - sort_query(self.query) - - result = query.all() - self.assertEqual(result[0].name, 'A') - - def test_sort_field_descending(self): - """Test sorting a field in descending order.""" - PersonModel.mock(name="A") - PersonModel.mock(name="B") - - parameters = {'sort': '-name'} - query = Resource( - self.model, parameters, marshmallow.MarshmallowDriver, self.view).\ - sort_query(self.query) - - result = query.all() - self.assertEqual(result[0].name, 'B') - - def test_sort_relationship_ascending(self): - """Test sorting a relationhip's field in descending order.""" - a = CompanyModel.mock(name="Last") - b = CompanyModel.mock(name="First") - PersonModel.mock(name="A", companies=[a]) - PersonModel.mock(name="B", companies=[b]) - - parameters = {'sort': 'companies.name'} - query = Resource( - self.model, parameters, marshmallow.MarshmallowDriver, self.view).\ - sort_query(self.query) - - result = query.all() - self.assertEqual(result[0].companies[0].name, 'First') - - def test_sort_relationship_descending(self): - """Test sorting a relationhip's field in descending order.""" - a = CompanyModel.mock(name="Last") - b = CompanyModel.mock(name="First") - PersonModel.mock(name="A", companies=[a]) - PersonModel.mock(name="B", companies=[b]) - - parameters = {'sort': '-companies.name'} - query = Resource( - self.model, parameters, marshmallow.MarshmallowDriver, self.view).\ - sort_query(self.query) - - result = query.all() - self.assertEqual(result[0].companies[0].name, 'Last') - - def test_sort_invalid_field(self): - """Test sorting against non-existant attribute.""" - PersonModel.mock() - - try: - parameters = {'sort': 'x'} - Resource( - self.model, parameters, marshmallow.MarshmallowDriver, - self.view).sort_query(self.query) - except JSONAPIError as exc: - message = exc.message - self.assertTrue( - message['errors'][0]['source']['parameter'] == 'sort') - self.assertIn('detail', message['errors'][0]) - - def test_sort_invalid_relationship(self): - """Test sorting against non-existant relationship attribute.""" - PersonModel.mock() - - try: - parameters = {'sort': 'companies.x'} - Resource( - self.model, parameters, marshmallow.MarshmallowDriver, - self.view).sort_query(self.query) - except JSONAPIError as exc: - message = exc.message - self.assertTrue( - message['errors'][0]['source']['parameter'] == 'sort') - self.assertIn('detail', message['errors'][0]) diff --git a/tests/unit/translation/__init__.py b/tests/unit/translation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/translation/model/__init__.py b/tests/unit/translation/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/translation/model/sqlalchemy_tests.py b/tests/unit/translation/model/sqlalchemy_tests.py new file mode 100644 index 0000000..85ecc2f --- /dev/null +++ b/tests/unit/translation/model/sqlalchemy_tests.py @@ -0,0 +1,71 @@ +"""SQLAlchemy model driver module.""" +from jsonapi_query.errors import PathError +from jsonapi_query.translation.model.sqlalchemy import SQLAlchemyModelDriver +from tests.sqlalchemy import BaseSQLAlchemyTestCase, Person, School, Student + + +class SQLAlchemyModelTestCase(BaseSQLAlchemyTestCase): + """Test converting attribute paths and adding join conditions.""" + + def setUp(self): + """Ran before every test.""" + super().setUp() + # Create a row to initiate the mappers. + self.session.add(Person(name='Fred')) + self.driver = SQLAlchemyModelDriver(Person) + + def test_attribute_column(self): + """Test getting an attribute column.""" + column, models, joins = self.driver.parse_path('name') + + self.assertTrue(column == Person.name) + self.assertTrue(models == []) + self.assertTrue(joins == []) + + def test_nested_attribute_column(self): + """Test getting a related attribute column.""" + column, models, joins = self.driver.parse_path('student.school_id') + + self.assertTrue(column == Student.school_id) + self.assertTrue(models == [Student]) + self.assertTrue(joins == [Person.student]) + + def test_deeply_nested_attribute_column(self): + """Test getting a deeply related attribute column.""" + column, models, joins = self.driver.parse_path('student.school.name') + + self.assertTrue(column == School.name) + self.assertTrue(models == [Student, School]) + self.assertTrue(joins == [Person.student, Student.school]) + + def test_missing_attribute_column(self): + """Test getting a missing attribute's default column.""" + column, models, joins = self.driver.parse_path('student') + + self.assertTrue(column == Student.id) + self.assertTrue(models == [Student]) + self.assertTrue(joins == [Person.student]) + + def test_unknown_attribute(self): + """Test getting an attribute that doesn't exist on the object.""" + try: + self.driver.parse_path('height') + self.assertTrue(False) + except PathError: + self.assertTrue(True) + + def test_column_as_relationship(self): + """Test getting a relationshiup that doesn't exist on the model.""" + try: + self.driver.parse_path('age.id') + self.assertTrue(False) + except PathError: + self.assertTrue(True) + + def test_empty_path(self): + """Test parsing an empty path.""" + column, models, joins = self.driver.parse_path('') + + self.assertTrue(column is None) + self.assertTrue(models == []) + self.assertTrue(joins == []) diff --git a/tests/unit/translation/view/__init__.py b/tests/unit/translation/view/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/translation/view/marshmallow_jsonapi_tests.py b/tests/unit/translation/view/marshmallow_jsonapi_tests.py new file mode 100644 index 0000000..8295426 --- /dev/null +++ b/tests/unit/translation/view/marshmallow_jsonapi_tests.py @@ -0,0 +1,98 @@ +from datetime import date + +from jsonapi_query.errors import DataError, PathError +from jsonapi_query.translation.view.marshmallow_jsonapi import ( + MarshmallowJSONAPIDriver) +from tests.marshmallow_jsonapi import ( + BaseMarshmallowJSONAPITestCase, Person, School, Student) + + +class MarshmallowJSONAPIViewTestCase(BaseMarshmallowJSONAPITestCase): + + def setUp(self): + super().setUp() + self.driver = MarshmallowJSONAPIDriver(Person()) + + def test_initialize_path(self): + """Test initializing a path.""" + self.driver.initialize_path('student.school.title') + + self.assertTrue(len(self.driver.fields) == 3) + self.assertTrue( + self.driver.field_names == ['student', 'school', 'name']) + schemas = [schema.__class__ for schema in self.driver.schemas] + self.assertTrue(schemas == [Student, School]) + + def test_initialize_dasherized_path(self): + """Test initializing a dasherized path.""" + self.driver.initialize_path('birth-date') + + self.assertTrue( + self.driver.fields == [Person._declared_fields['birth_date']]) + self.assertTrue(self.driver.field_names == ['birth_date']) + self.assertTrue(self.driver.schemas == []) + + def test_initialize_empty_path(self): + """Test initializing an empty path.""" + driver = self.driver.initialize_path('') + self.assertTrue(driver == self.driver) + + def test_initializing_multiple_paths(self): + """Test initializing multiple paths.""" + self.driver.initialize_path('student.school.title') + self.driver.initialize_path('birth-date') + + self.assertTrue( + self.driver.fields == [Person._declared_fields['birth_date']]) + self.assertTrue(self.driver.field_names == ['birth_date']) + self.assertTrue(self.driver.schemas == []) + + def test_get_model_path(self): + """Test getting a model-safe path.""" + self.driver.initialize_path('student.school.title') + + path = self.driver.get_model_path() + self.assertTrue(path == 'student.school.name') + + def test_deserialize_values(self): + """Test deserializing a list of values.""" + self.driver.initialize_path('birth-date') + values = self.driver.deserialize_values(['2014-01-01']) + self.assertTrue(values == [date(2014, 1, 1)]) + + self.driver.initialize_path('student.school.title') + values = self.driver.deserialize_values(['PS #118']) + self.assertTrue(values == ['PS #118']) + + def test_deserialize_value(self): + """Test deserializing a single value.""" + self.driver.initialize_path('birth-date') + field = self.driver.fields[-1] + + value = self.driver.deserialize_value(field, '2014-01-01') + self.assertTrue(value == date(2014, 1, 1)) + + def test_get_invalid_attribute(self): + """Test retrieving an invalid attribute path.""" + try: + self.driver.initialize_path('student.title') + self.assertTrue(False) + except PathError: + self.assertTrue(True) + + def test_get_invalid_relationship(self): + """Test retrieving an invalid relationship path.""" + try: + self.driver.initialize_path('birth-date.school') + self.assertTrue(False) + except PathError: + self.assertTrue(True) + + def test_deserialize_invalid_value(self): + """Test deserializing an invalid value.""" + try: + self.driver.initialize_path('birth-date') + self.driver.deserialize_values(['test']) + self.assertTrue(False) + except DataError: + self.assertTrue(True) diff --git a/tests/unit/url_tests.py b/tests/unit/url_tests.py new file mode 100644 index 0000000..694b98f --- /dev/null +++ b/tests/unit/url_tests.py @@ -0,0 +1,216 @@ +from jsonapi_query.url import ( + get_parameters, get_includes, get_sorts, get_filters, get_paginators) +from tests.unit import UnitTestCase + + +class URLTestCase(UnitTestCase): + pass + + +class ParametersURLTestCase(URLTestCase): + + def test_get_parameters(self): + base_url = 'http://www.test.com/1' + query = '?filter[field]=1&page[size]=100&sort=field&include=field' + parameters = get_parameters(base_url + query) + + self.assertTrue(isinstance(parameters, dict)) + self.assertIn('page[size]', parameters) + self.assertIn('filter[field]', parameters) + self.assertIn('sort', parameters) + self.assertIn('include', parameters) + + +class IncludeURLTestCase(URLTestCase): + + def test_get_includes(self): + parameters = {'include': 'relationship,nested.relationship'} + includes = get_includes(parameters) + + self.assertTrue(isinstance(includes, list)) + self.assertTrue(len(includes) == 2) + + self.assertTrue(includes[0] == 'relationship') + self.assertTrue(includes[1] == 'nested.relationship') + + +class SortURLTestCase(URLTestCase): + + def test_get_descending_sort(self): + parameters = {'sort': '-field'} + sorts = get_sorts(parameters) + + self.assertTrue(isinstance(sorts, list)) + self.assertTrue(len(sorts) == 1) + self.assertTrue(sorts[0][0] == 'field') + self.assertTrue(sorts[0][1] == '-') + + def test_get_ascending_sort(self): + parameters = {'sort': '+field'} + sorts = get_sorts(parameters) + + self.assertTrue(isinstance(sorts, list)) + self.assertTrue(len(sorts) == 1) + self.assertTrue(sorts[0][0] == 'field') + self.assertTrue(sorts[0][1] == '+') + + def test_get_sort(self): + parameters = {'sort': 'field'} + sorts = get_sorts(parameters) + + self.assertTrue(isinstance(sorts, list)) + self.assertTrue(len(sorts) == 1) + self.assertTrue(sorts[0][0] == 'field') + self.assertTrue(sorts[0][1] == '+') + + def test_get_multiple_sorts(self): + parameters = {'sort': 'field,-relationship.field'} + sorts = get_sorts(parameters) + + self.assertTrue(isinstance(sorts, list)) + self.assertTrue(len(sorts) == 2) + self.assertTrue(sorts[0][0] == 'field') + self.assertTrue(sorts[0][1] == '+') + self.assertTrue(sorts[1][0] == 'relationship.field') + self.assertTrue(sorts[1][1] == '-') + + +class PageURLTestCase(URLTestCase): + + def test_get_paginator_limit(self): + parameters = {'page[limit]': '100'} + pages = get_paginators(parameters) + + self.assertTrue(isinstance(pages, list)) + self.assertTrue(len(pages) == 1) + + self.assertTrue(isinstance(pages[0], tuple)) + self.assertTrue(pages[0][0] == 'limit') + self.assertTrue(pages[0][1] == '100') + + def test_get_paginator_size(self): + parameters = {'page[size]': '100'} + pages = get_paginators(parameters) + + self.assertTrue(isinstance(pages, list)) + self.assertTrue(len(pages) == 1) + + self.assertTrue(isinstance(pages[0], tuple)) + self.assertTrue(pages[0][0] == 'limit') + self.assertTrue(pages[0][1] == '100') + + def test_get_paginator_offset(self): + parameters = {'page[offset]': '100'} + pages = get_paginators(parameters) + + self.assertTrue(isinstance(pages, list)) + self.assertTrue(len(pages) == 1) + + self.assertTrue(isinstance(pages[0], tuple)) + self.assertTrue(pages[0][0] == 'offset') + self.assertTrue(pages[0][1] == '100') + + def test_get_paginator_number(self): + parameters = {'page[number]': '100'} + pages = get_paginators(parameters) + + self.assertTrue(isinstance(pages, list)) + self.assertTrue(len(pages) == 1) + + self.assertTrue(isinstance(pages[0], tuple)) + self.assertTrue(pages[0][0] == 'number') + self.assertTrue(pages[0][1] == '100') + + +class FilterURLTestCase(URLTestCase): + + def test_get_filter(self): + parameters = {'filter[field]': 'gte:100'} + filters = get_filters(parameters) + + self.assertTrue(isinstance(filters, list)) + self.assertTrue(len(filters) == 1) + + self.assertTrue(isinstance(filters[0], tuple)) + self.assertTrue(filters[0][0] == 'field') + self.assertTrue(filters[0][1] == 'gte') + self.assertTrue(filters[0][2] == ['100']) + + def test_get_filter_negated_strategy(self): + parameters = {'filter[field.id]': '~in:1,2,3'} + filters = get_filters(parameters) + + self.assertTrue(isinstance(filters, list)) + self.assertTrue(len(filters) == 1) + + self.assertTrue(isinstance(filters[0], tuple)) + self.assertTrue(filters[0][0] == 'field.id') + self.assertTrue(filters[0][1] == '~in') + self.assertTrue(filters[0][2] == ['1', '2', '3']) + + def test_get_filter_multiple_fields(self): + parameters = {'filter[field1,field2]': '1'} + filters = get_filters(parameters) + + self.assertTrue(isinstance(filters, list)) + self.assertTrue(len(filters) == 1) + + self.assertTrue(isinstance(filters[0], tuple)) + self.assertTrue(filters[0][0] == 'field1,field2') + self.assertTrue(filters[0][1] is None) + self.assertTrue(filters[0][2] == ['1']) + + def test_get_filter_multiple_strategies(self): + parameters = {'filter[field]': 'eq:gte:test'} + filters = get_filters(parameters) + + self.assertTrue(isinstance(filters, list)) + self.assertTrue(len(filters) == 1) + + self.assertTrue(isinstance(filters[0], tuple)) + self.assertTrue(filters[0][0] == 'field') + self.assertTrue(filters[0][1] == 'eq') + self.assertTrue(filters[0][2] == ['gte:test']) + + def test_get_filter_multiple_values(self): + parameters = {'filter[field]': '1,2,hello'} + filters = get_filters(parameters) + + self.assertTrue(isinstance(filters, list)) + self.assertTrue(len(filters) == 1) + + self.assertTrue(isinstance(filters[0], tuple)) + self.assertTrue(filters[0][0] == 'field') + self.assertTrue(filters[0][1] is None) + self.assertTrue(filters[0][2] == ['1', '2', 'hello']) + + def test_get_filter_default_strategy(self): + parameters = {'filter[field]': 'hello'} + filters = get_filters(parameters) + + self.assertTrue(isinstance(filters, list)) + self.assertTrue(len(filters) == 1) + + self.assertTrue(isinstance(filters[0], tuple)) + self.assertTrue(filters[0][0] == 'field') + self.assertTrue(filters[0][1] is None) + self.assertTrue(filters[0][2] == ['hello']) + + def test_get_filter_invalid_strategy(self): + parameters = {'filter[field]': 'invalid:strategy'} + filters = get_filters(parameters) + + self.assertTrue(isinstance(filters, list)) + self.assertTrue(len(filters) == 1) + + self.assertTrue(isinstance(filters[0], tuple)) + self.assertTrue(filters[0][0] == 'field') + self.assertTrue(filters[0][1] is None) + self.assertTrue(filters[0][2] == ['invalid:strategy']) + + def test_skip_invalid_filter(self): + parameters = {'filter[field]t': 'eq:hello'} + filters = get_filters(parameters) + + self.assertTrue(isinstance(filters, list)) + self.assertTrue(len(filters) == 0)