diff --git a/CHANGELOG.md b/CHANGELOG.md index d9f1007e2..d81f57074 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). # Release Notes ## [Unreleased](https://github.com/algolia/algoliasearch-client-python/compare/v2.5.0...master) +- Add support for Answers API [#528](https://github.com/algolia/algoliasearch-client-python/pull/528) ## [v2.5.0](https://github.com/algolia/algoliasearch-client-python/compare/v2.4.0...v2.5.0) diff --git a/algoliasearch/answers_client.py b/algoliasearch/answers_client.py new file mode 100644 index 000000000..d443abcad --- /dev/null +++ b/algoliasearch/answers_client.py @@ -0,0 +1,61 @@ +from typing import Optional, Union, Dict, Any + +from algoliasearch.configs import AnswersConfig +from algoliasearch.helpers import is_async_available +from algoliasearch.http.request_options import RequestOptions +from algoliasearch.http.requester import Requester +from algoliasearch.http.transporter import Transporter +from algoliasearch.http.verb import Verb + + +class AnswersClient(object): + def __init__(self, transporter, config): + # type: (Transporter, AnswersConfig) -> None + + self._transporter = transporter + self._config = config + + @staticmethod + def create(app_id=None, api_key=None): + # type: (Optional[str], Optional[str]) -> AnswersClient # noqa: E501 + + config = AnswersConfig(app_id, api_key) + + return AnswersClient.create_with_config(config) + + @staticmethod + def create_with_config(config): + # type: (AnswersConfig) -> AnswersClient + + requester = Requester() + transporter = Transporter(requester, config) + + client = AnswersClient(transporter, config) + + if is_async_available(): + from algoliasearch.answers_client_async import AnswersClientAsync + from algoliasearch.http.transporter_async import TransporterAsync + from algoliasearch.http.requester_async import RequesterAsync + + return AnswersClientAsync( + client, TransporterAsync(RequesterAsync(), config), config + ) + + return client + + def predict( + self, index_name, answers_parameters, request_options=None + ): # noqa: E501 + # type: (str, dict, Optional[Union[dict, RequestOptions]]) -> dict + + return self._transporter.write( + Verb.POST, + "1/answers/{}/prediction".format(index_name), + answers_parameters, + request_options, + ) + + def close(self): + # type: () -> None + + return self._transporter.close() # type: ignore diff --git a/algoliasearch/answers_client_async.py b/algoliasearch/answers_client_async.py new file mode 100644 index 000000000..c14d57358 --- /dev/null +++ b/algoliasearch/answers_client_async.py @@ -0,0 +1,43 @@ +import types +import asyncio +from typing import Optional, Type + +from algoliasearch.answers_client import AnswersClient +from algoliasearch.configs import AnswersConfig +from algoliasearch.helpers_async import _create_async_methods_in +from algoliasearch.http.transporter_async import TransporterAsync + + +class AnswersClientAsync(AnswersClient): + def __init__(self, answers_client, transporter, search_config): + # type: (AnswersClient, TransporterAsync, AnswersConfig) -> None # noqa: E501 + + self._transporter_async = transporter + + super(AnswersClientAsync, self).__init__( + answers_client._transporter, search_config + ) + + client = AnswersClient(transporter, search_config) + + _create_async_methods_in(self, client) + + @asyncio.coroutine + def __aenter__(self): + # type: () -> AnswersClientAsync # type: ignore + + return self # type: ignore + + @asyncio.coroutine + def __aexit__(self, exc_type, exc, tb): # type: ignore + # type: (Optional[Type[BaseException]], Optional[BaseException],Optional[types.TracebackType]) -> None # noqa: E501 + + yield from self.close_async() # type: ignore + + @asyncio.coroutine + def close_async(self): # type: ignore + # type: () -> None + + super().close() + + yield from self._transporter_async.close() # type: ignore diff --git a/algoliasearch/configs.py b/algoliasearch/configs.py index 02cc9a08d..fd47583fa 100644 --- a/algoliasearch/configs.py +++ b/algoliasearch/configs.py @@ -124,3 +124,10 @@ def build_hosts(self): return HostsCollection( [Host("{}.{}.{}".format("recommendation", self._region, "algolia.com"))] ) + + +class AnswersConfig(SearchConfig): + def __init__(self, app_id=None, api_key=None): + # type: (Optional[str], Optional[str]) -> None + + super(AnswersConfig, self).__init__(app_id, api_key) diff --git a/tests/features/test_answers_client.py b/tests/features/test_answers_client.py new file mode 100644 index 000000000..6cf2d82cc --- /dev/null +++ b/tests/features/test_answers_client.py @@ -0,0 +1,45 @@ +import unittest + +from algoliasearch.exceptions import RequestException +from tests.helpers.factory import Factory as F + + +class TestAnswersClient(unittest.TestCase): + def setUp(self): + self.search_client = F.search_client() + self.client = F.answers_client() + self.index = F.index(self.search_client, self._testMethodName) + self.index.save_objects( + [ + { + "name": "Something", + "description": "The usage is strong in that one", + "objectID": 0, + }, + { + "name": "Another thing", + "description": "This is creative, but unused. ;)", + "objectID": 1, + }, + ] + ).wait() + + def tearDown(self): + self.client.close() + + def test_answers(self): + data = { + "query": "Any usage?", + "queryLanguages": ["en"], + "attributesForPrediction": ["title", "description"], + "nbHits": 2, + } + + try: + response = self.client.predict(self.index.name, data) + print(response) + self.assertTrue("hits" in response) + self.assertIn("usage", response["hits"][0]["_answer"]["extract"]) + self.assertEqual("0", response["hits"][0]["objectID"]) + except RequestException as err: + self.fail(err) # noqa: E501 diff --git a/tests/helpers/factory.py b/tests/helpers/factory.py index 581f45837..ec868bd8d 100644 --- a/tests/helpers/factory.py +++ b/tests/helpers/factory.py @@ -8,6 +8,7 @@ from typing import Optional from algoliasearch.analytics_client import AnalyticsClient +from algoliasearch.answers_client import AnswersClient from algoliasearch.insights_client import InsightsClient from algoliasearch.search_client import SearchClient, SearchConfig from algoliasearch.recommendation_client import RecommendationClient @@ -87,6 +88,15 @@ def recommendation_client(app_id=None, api_key=None): return Factory.decide(RecommendationClient.create(app_id, api_key)) + @staticmethod + def answers_client(app_id=None, api_key=None): + # type: (Optional[str], Optional[str]) -> AnswersClient + + app_id = app_id if app_id is not None else Factory.get_app_id() + api_key = api_key if api_key is not None else Factory.get_api_key() + + return Factory.decide(AnswersClient.create(app_id, api_key)) + @staticmethod def insights_client(app_id=None, api_key=None): # type: (Optional[str], Optional[str]) -> InsightsClient