diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index 0562cb2cf..9d42c3e1b 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -140,7 +140,8 @@ def request_token(self) -> Optional[str]: a token from the server and returns it. """ if self._auth: - self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1]) + client = self._client() + client.authenticate_basic_token(self._auth[0], self._auth[1]) return self._auth_middleware.token() else: return "IGNORED" @@ -654,8 +655,9 @@ def _do_get( ticket = flight.Ticket(json.dumps(payload).encode("utf-8")) + client = self._client() try: - get = self._flight_client.do_get(ticket) + get = client.do_get(ticket) arrow_table = get.read_all() except Exception as e: self.handle_flight_error(e) @@ -683,10 +685,11 @@ def __exit__( exception_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: - self._flight_client.close() + self.close() def close(self) -> None: - self._flight_client.close() + if self._flight_client: + self._flight_client.close() def _versioned_action_type(self, action_type: str) -> str: return self._arrow_endpoint_version.prefix() + action_type