Skip to content

Commit 7b4adf7

Browse files
committed
Resolve merge conflict for infrahub_sdk/client.py
2 parents 5fb0758 + 496376e commit 7b4adf7

File tree

6 files changed

+145
-14
lines changed

6 files changed

+145
-14
lines changed

.nvmrc

Lines changed: 0 additions & 1 deletion
This file was deleted.

infrahub_sdk/client.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@ class ProcessRelationsNodeSync(TypedDict):
9494
related_nodes: list[InfrahubNodeSync]
9595

9696

97-
def handle_relogin(func: Callable[..., Coroutine[Any, Any, httpx.Response]]): # type: ignore[no-untyped-def]
97+
def handle_relogin(
98+
func: Callable[..., Coroutine[Any, Any, httpx.Response]],
99+
) -> Callable[..., Coroutine[Any, Any, httpx.Response]]:
98100
@wraps(func)
99101
async def wrapper(client: InfrahubClient, *args: Any, **kwargs: Any) -> httpx.Response:
100102
response = await func(client, *args, **kwargs)
@@ -108,7 +110,7 @@ async def wrapper(client: InfrahubClient, *args: Any, **kwargs: Any) -> httpx.Re
108110
return wrapper
109111

110112

111-
def handle_relogin_sync(func: Callable[..., httpx.Response]): # type: ignore[no-untyped-def]
113+
def handle_relogin_sync(func: Callable[..., httpx.Response]) -> Callable[..., httpx.Response]:
112114
@wraps(func)
113115
def wrapper(client: InfrahubClientSync, *args: Any, **kwargs: Any) -> httpx.Response:
114116
response = func(client, *args, **kwargs)
@@ -170,6 +172,7 @@ def __init__(
170172
self.group_context: InfrahubGroupContext | InfrahubGroupContextSync
171173
self._initialize()
172174
self._request_context: RequestContext | None = None
175+
_ = self.config.tls_context # Early load of the TLS context to catch errors
173176

174177
def _initialize(self) -> None:
175178
"""Sets the properties for each version of the client"""
@@ -590,7 +593,7 @@ async def _process_nodes_and_relationships(
590593
schema_kind (str): The kind of schema being queried.
591594
branch (str): The branch name.
592595
prefetch_relationships (bool): Flag to indicate whether to prefetch relationship data.
593-
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
596+
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
594597
595598
Returns:
596599
ProcessRelationsNodeSync: A TypedDict containing two lists:
@@ -710,7 +713,7 @@ async def all(
710713
at (Timestamp, optional): Time of the query. Defaults to Now.
711714
branch (str, optional): Name of the branch to query from. Defaults to default_branch.
712715
populate_store (bool, optional): Flag to indicate whether to populate the store with the retrieved nodes.
713-
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
716+
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
714717
offset (int, optional): The offset for pagination.
715718
limit (int, optional): The limit for pagination.
716719
include (list[str], optional): List of attributes or relationships to include in the query.
@@ -807,7 +810,7 @@ async def filters(
807810
kind (str): kind of the nodes to query
808811
at (Timestamp, optional): Time of the query. Defaults to Now.
809812
branch (str, optional): Name of the branch to query from. Defaults to default_branch.
810-
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
813+
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
811814
populate_store (bool, optional): Flag to indicate whether to populate the store with the retrieved nodes.
812815
offset (int, optional): The offset for pagination.
813816
limit (int, optional): The limit for pagination.
@@ -1089,7 +1092,7 @@ async def _default_request_method(
10891092

10901093
async with httpx.AsyncClient(
10911094
**proxy_config,
1092-
verify=self.config.tls_ca_file if self.config.tls_ca_file else not self.config.tls_insecure,
1095+
verify=self.config.tls_context,
10931096
) as client:
10941097
try:
10951098
response = await client.request(
@@ -1961,7 +1964,7 @@ def all(
19611964
kind (str): kind of the nodes to query
19621965
at (Timestamp, optional): Time of the query. Defaults to Now.
19631966
branch (str, optional): Name of the branch to query from. Defaults to default_branch.
1964-
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
1967+
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
19651968
populate_store (bool, optional): Flag to indicate whether to populate the store with the retrieved nodes.
19661969
offset (int, optional): The offset for pagination.
19671970
limit (int, optional): The limit for pagination.
@@ -2008,7 +2011,7 @@ def _process_nodes_and_relationships(
20082011
schema_kind (str): The kind of schema being queried.
20092012
branch (str): The branch name.
20102013
prefetch_relationships (bool): Flag to indicate whether to prefetch relationship data.
2011-
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
2014+
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
20122015
20132016
Returns:
20142017
ProcessRelationsNodeSync: A TypedDict containing two lists:
@@ -2100,7 +2103,7 @@ def filters(
21002103
kind (str): kind of the nodes to query
21012104
at (Timestamp, optional): Time of the query. Defaults to Now.
21022105
branch (str, optional): Name of the branch to query from. Defaults to default_branch.
2103-
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
2106+
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
21042107
populate_store (bool, optional): Flag to indicate whether to populate the store with the retrieved nodes.
21052108
offset (int, optional): The offset for pagination.
21062109
limit (int, optional): The limit for pagination.
@@ -2929,7 +2932,7 @@ def _default_request_method(
29292932

29302933
with httpx.Client(
29312934
**proxy_config,
2932-
verify=self.config.tls_ca_file if self.config.tls_ca_file else not self.config.tls_insecure,
2935+
verify=self.config.tls_context,
29332936
) as client:
29342937
try:
29352938
response = client.request(

infrahub_sdk/config.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

3+
import ssl
34
from copy import deepcopy
45
from typing import Any
56

6-
from pydantic import Field, field_validator, model_validator
7+
from pydantic import Field, PrivateAttr, field_validator, model_validator
78
from pydantic_settings import BaseSettings, SettingsConfigDict
89
from typing_extensions import Self
910

@@ -78,6 +79,7 @@ class ConfigBase(BaseSettings):
7879
Can be useful to test with self-signed certificates.""",
7980
)
8081
tls_ca_file: str | None = Field(default=None, description="File path to CA cert or bundle in PEM format")
82+
_ssl_context: ssl.SSLContext | None = PrivateAttr(default=None)
8183

8284
@model_validator(mode="before")
8385
@classmethod
@@ -133,6 +135,28 @@ def default_infrahub_branch(self) -> str:
133135
def password_authentication(self) -> bool:
134136
return bool(self.username)
135137

138+
@property
139+
def tls_context(self) -> ssl.SSLContext:
140+
if self._ssl_context:
141+
return self._ssl_context
142+
143+
if self.tls_insecure:
144+
self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
145+
self._ssl_context.check_hostname = False
146+
self._ssl_context.verify_mode = ssl.CERT_NONE
147+
return self._ssl_context
148+
149+
if self.tls_ca_file:
150+
self._ssl_context = ssl.create_default_context(cafile=self.tls_ca_file)
151+
152+
if self._ssl_context is None:
153+
self._ssl_context = ssl.create_default_context()
154+
155+
return self._ssl_context
156+
157+
def set_ssl_context(self, context: ssl.SSLContext) -> None:
158+
self._ssl_context = context
159+
136160

137161
class Config(ConfigBase):
138162
recorder: RecorderType = Field(default=RecorderType.NONE, description="Select builtin recorder for later replay.")
@@ -174,4 +198,7 @@ def clone(self, branch: str | None = None) -> Config:
174198
if field not in covered_keys:
175199
config[field] = deepcopy(getattr(self, field))
176200

177-
return Config(**config)
201+
new_config = Config(**config)
202+
if self._ssl_context:
203+
new_config.set_ssl_context(self._ssl_context)
204+
return new_config

tests/unit/sdk/test_client.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import inspect
2+
import ssl
3+
from pathlib import Path
24

35
import pytest
46
from pytest_httpx import HTTPXMock
57

6-
from infrahub_sdk import InfrahubClient, InfrahubClientSync
8+
from infrahub_sdk import Config, InfrahubClient, InfrahubClientSync
79
from infrahub_sdk.exceptions import NodeNotFoundError
810
from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync
911
from tests.unit.sdk.conftest import BothClients
@@ -28,6 +30,88 @@
2830

2931
client_types = ["standard", "sync"]
3032

33+
CURRENT_DIRECTORY = Path(__file__).parent
34+
35+
36+
async def test_verify_config_caches_default_ssl_context(monkeypatch) -> None:
37+
contexts: list[tuple[str | None, object]] = []
38+
39+
def fake_create_default_context(*args: object, **kwargs: object) -> object:
40+
context = object()
41+
contexts.append((kwargs.get("cafile"), context))
42+
return context
43+
44+
monkeypatch.setattr("ssl.create_default_context", fake_create_default_context)
45+
46+
client = InfrahubClient(config=Config(address="http://mock"))
47+
48+
first = client.config.tls_context
49+
second = client.config.tls_context
50+
51+
assert first is second
52+
assert contexts == [(None, first)]
53+
54+
55+
async def test_verify_config_caches_tls_ca_file_context(monkeypatch) -> None:
56+
contexts: list[tuple[str | None, object]] = []
57+
58+
def fake_create_default_context(*args: object, **kwargs: object) -> object:
59+
context = object()
60+
contexts.append((kwargs.get("cafile"), context))
61+
return context
62+
63+
monkeypatch.setattr("ssl.create_default_context", fake_create_default_context)
64+
65+
client = InfrahubClient(
66+
config=Config(address="http://mock", tls_ca_file=str(CURRENT_DIRECTORY / "test_data/path-1.pem"))
67+
)
68+
69+
first = client.config.tls_context
70+
second = client.config.tls_context
71+
72+
assert first is second
73+
assert contexts == [(str(CURRENT_DIRECTORY / "test_data/path-1.pem"), first)]
74+
75+
client.config.tls_ca_file = str(CURRENT_DIRECTORY / "test_data/path-2.pem")
76+
third = client.config.tls_context
77+
78+
assert third is first
79+
assert contexts == [
80+
(str(CURRENT_DIRECTORY / "test_data/path-1.pem"), first),
81+
]
82+
83+
84+
async def test_verify_config_respects_tls_insecure(monkeypatch) -> None:
85+
def fake_create_default_context(*args: object, **kwargs: object) -> object:
86+
raise AssertionError("create_default_context should not be called when TLS is insecure")
87+
88+
monkeypatch.setattr("ssl.create_default_context", fake_create_default_context)
89+
90+
client = InfrahubClient(config=Config(address="http://mock", tls_insecure=True))
91+
92+
verify_value = client.config.tls_context
93+
94+
assert verify_value.check_hostname is False
95+
assert verify_value.verify_mode == ssl.CERT_NONE
96+
97+
98+
async def test_verify_config_uses_custom_tls_context(monkeypatch) -> None:
99+
def fake_create_default_context(*args: object, **kwargs: object) -> object:
100+
raise AssertionError("create_default_context should not be called when custom context is provided")
101+
102+
monkeypatch.setattr("ssl.create_default_context", fake_create_default_context)
103+
104+
config = Config(address="http://mock")
105+
custom_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
106+
config.set_ssl_context(custom_context)
107+
108+
client = InfrahubClient(config=config)
109+
110+
clone_client = client.clone()
111+
112+
assert client.config.tls_context is custom_context
113+
assert clone_client.config.tls_context is custom_context
114+
31115

32116
async def test_method_sanity() -> None:
33117
"""Validate that there is at least one public method and that both clients look the same."""
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
-----BEGIN CERTIFICATE-----
2+
MIIBQDCB86ADAgECAhR6y429KiST51bZy+t330M7dN5SbzAFBgMrZXAwFjEUMBIG
3+
A1UEAwwLZXhhbXBsZS5jb20wHhcNMjUxMDE1MTE0MjUwWhcNMzUxMDEzMTE0MjUw
4+
WjAWMRQwEgYDVQQDDAtleGFtcGxlLmNvbTAqMAUGAytlcAMhAPIl8y8AXSWF33vX
5+
JT2YwhMJzarOuSdPif01Gxr3Rr6Lo1MwUTAdBgNVHQ4EFgQU4heN1ZhyXpOujgcJ
6+
WZ4LQk2m7RAwHwYDVR0jBBgwFoAU4heN1ZhyXpOujgcJWZ4LQk2m7RAwDwYDVR0T
7+
AQH/BAUwAwEB/zAFBgMrZXADQQBoEf+8R+KWwGdaoeqinWOvrqbVZatMis0eUMvA
8+
o+vABSPU7LIYGxLT6fpUwFSTvempzNqGZMVJ9UvVH+hYDU4D
9+
-----END CERTIFICATE-----
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
-----BEGIN CERTIFICATE-----
2+
MIIBQDCB86ADAgECAhQTRmRZxUSA5L7VfYJb3/t+dRK0ETAFBgMrZXAwFjEUMBIG
3+
A1UEAwwLZXhhbXBsZS5jb20wHhcNMjUxMDE1MTE0MzM0WhcNMzUxMDEzMTE0MzM0
4+
WjAWMRQwEgYDVQQDDAtleGFtcGxlLmNvbTAqMAUGAytlcAMhAK1O3ZhE5qzfT7Qx
5+
+0My3ToDVDi5wwpllkKn0X50zXFao1MwUTAdBgNVHQ4EFgQUH+qBMU+h4t1vdLbO
6+
jMSSgXdURewwHwYDVR0jBBgwFoAUH+qBMU+h4t1vdLbOjMSSgXdURewwDwYDVR0T
7+
AQH/BAUwAwEB/zAFBgMrZXADQQB3Z03f3gQcktxk4h/v8pVi5soz8viPx17TSPXf
8+
1WYG+Jlk4C5GQ+tyjZgZUE9LL2BFRYBv28V/NPT/0TjPGtcC
9+
-----END CERTIFICATE-----

0 commit comments

Comments
 (0)