Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .nvmrc

This file was deleted.

23 changes: 13 additions & 10 deletions infrahub_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ class ProcessRelationsNodeSync(TypedDict):
related_nodes: list[InfrahubNodeSync]


def handle_relogin(func: Callable[..., Coroutine[Any, Any, httpx.Response]]): # type: ignore[no-untyped-def]
def handle_relogin(
func: Callable[..., Coroutine[Any, Any, httpx.Response]],
) -> Callable[..., Coroutine[Any, Any, httpx.Response]]:
@wraps(func)
async def wrapper(client: InfrahubClient, *args: Any, **kwargs: Any) -> httpx.Response:
response = await func(client, *args, **kwargs)
Expand All @@ -108,7 +110,7 @@ async def wrapper(client: InfrahubClient, *args: Any, **kwargs: Any) -> httpx.Re
return wrapper


def handle_relogin_sync(func: Callable[..., httpx.Response]): # type: ignore[no-untyped-def]
def handle_relogin_sync(func: Callable[..., httpx.Response]) -> Callable[..., httpx.Response]:
@wraps(func)
def wrapper(client: InfrahubClientSync, *args: Any, **kwargs: Any) -> httpx.Response:
response = func(client, *args, **kwargs)
Expand Down Expand Up @@ -170,6 +172,7 @@ def __init__(
self.group_context: InfrahubGroupContext | InfrahubGroupContextSync
self._initialize()
self._request_context: RequestContext | None = None
_ = self.config.tls_context # Early load of the TLS context to catch errors

def _initialize(self) -> None:
"""Sets the properties for each version of the client"""
Expand Down Expand Up @@ -590,7 +593,7 @@ async def _process_nodes_and_relationships(
schema_kind (str): The kind of schema being queried.
branch (str): The branch name.
prefetch_relationships (bool): Flag to indicate whether to prefetch relationship data.
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.

Returns:
ProcessRelationsNodeSync: A TypedDict containing two lists:
Expand Down Expand Up @@ -710,7 +713,7 @@ async def all(
at (Timestamp, optional): Time of the query. Defaults to Now.
branch (str, optional): Name of the branch to query from. Defaults to default_branch.
populate_store (bool, optional): Flag to indicate whether to populate the store with the retrieved nodes.
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
offset (int, optional): The offset for pagination.
limit (int, optional): The limit for pagination.
include (list[str], optional): List of attributes or relationships to include in the query.
Expand Down Expand Up @@ -807,7 +810,7 @@ async def filters(
kind (str): kind of the nodes to query
at (Timestamp, optional): Time of the query. Defaults to Now.
branch (str, optional): Name of the branch to query from. Defaults to default_branch.
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
populate_store (bool, optional): Flag to indicate whether to populate the store with the retrieved nodes.
offset (int, optional): The offset for pagination.
limit (int, optional): The limit for pagination.
Expand Down Expand Up @@ -1089,7 +1092,7 @@ async def _default_request_method(

async with httpx.AsyncClient(
**proxy_config,
verify=self.config.tls_ca_file if self.config.tls_ca_file else not self.config.tls_insecure,
verify=self.config.tls_context,
) as client:
try:
response = await client.request(
Expand Down Expand Up @@ -1961,7 +1964,7 @@ def all(
kind (str): kind of the nodes to query
at (Timestamp, optional): Time of the query. Defaults to Now.
branch (str, optional): Name of the branch to query from. Defaults to default_branch.
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
populate_store (bool, optional): Flag to indicate whether to populate the store with the retrieved nodes.
offset (int, optional): The offset for pagination.
limit (int, optional): The limit for pagination.
Expand Down Expand Up @@ -2008,7 +2011,7 @@ def _process_nodes_and_relationships(
schema_kind (str): The kind of schema being queried.
branch (str): The branch name.
prefetch_relationships (bool): Flag to indicate whether to prefetch relationship data.
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.

Returns:
ProcessRelationsNodeSync: A TypedDict containing two lists:
Expand Down Expand Up @@ -2100,7 +2103,7 @@ def filters(
kind (str): kind of the nodes to query
at (Timestamp, optional): Time of the query. Defaults to Now.
branch (str, optional): Name of the branch to query from. Defaults to default_branch.
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
populate_store (bool, optional): Flag to indicate whether to populate the store with the retrieved nodes.
offset (int, optional): The offset for pagination.
limit (int, optional): The limit for pagination.
Expand Down Expand Up @@ -2929,7 +2932,7 @@ def _default_request_method(

with httpx.Client(
**proxy_config,
verify=self.config.tls_ca_file if self.config.tls_ca_file else not self.config.tls_insecure,
verify=self.config.tls_context,
) as client:
try:
response = client.request(
Expand Down
31 changes: 29 additions & 2 deletions infrahub_sdk/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import ssl
from copy import deepcopy
from typing import Any

from pydantic import Field, field_validator, model_validator
from pydantic import Field, PrivateAttr, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing_extensions import Self

Expand Down Expand Up @@ -78,6 +79,7 @@ class ConfigBase(BaseSettings):
Can be useful to test with self-signed certificates.""",
)
tls_ca_file: str | None = Field(default=None, description="File path to CA cert or bundle in PEM format")
_ssl_context: ssl.SSLContext | None = PrivateAttr(default=None)

@model_validator(mode="before")
@classmethod
Expand Down Expand Up @@ -133,6 +135,28 @@ def default_infrahub_branch(self) -> str:
def password_authentication(self) -> bool:
return bool(self.username)

@property
def tls_context(self) -> ssl.SSLContext:
if self._ssl_context:
return self._ssl_context

if self.tls_insecure:
self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self._ssl_context.check_hostname = False
self._ssl_context.verify_mode = ssl.CERT_NONE
return self._ssl_context

if self.tls_ca_file:
self._ssl_context = ssl.create_default_context(cafile=self.tls_ca_file)

if self._ssl_context is None:
self._ssl_context = ssl.create_default_context()

return self._ssl_context

def set_ssl_context(self, context: ssl.SSLContext) -> None:
self._ssl_context = context
Comment on lines +138 to +158
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Guard against stale/insecure TLS contexts.

Once _ssl_context is populated, this property returns it forever. Because we eagerly load the context in BaseClient.__init__, changing config.tls_insecure or config.tls_ca_file later never takes effect—e.g., starting in insecure mode leaves verify_mode == CERT_NONE even after you flip tls_insecure=False. That’s an immediate security regression. Please invalidate the cache whenever the relevant settings change (or track the settings used to build the cache). One approach:

-_ssl_context: ssl.SSLContext | None = PrivateAttr(default=None)
+_ssl_context: ssl.SSLContext | None = PrivateAttr(default=None)
+_ssl_context_key: tuple[bool, str | None] | Literal["custom"] | None = PrivateAttr(default=None)

     @property
     def tls_context(self) -> ssl.SSLContext:
-        if self._ssl_context:
+        cache_key: tuple[bool, str | None] | Literal["custom"] = (
+            "custom" if self._ssl_context_key == "custom" else (self.tls_insecure, self.tls_ca_file)
+        )
+        if self._ssl_context and self._ssl_context_key == cache_key:
             return self._ssl_context
@@
-        if self.tls_insecure:
+        if self.tls_insecure:
             self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
             self._ssl_context.check_hostname = False
             self._ssl_context.verify_mode = ssl.CERT_NONE
-            return self._ssl_context
+            self._ssl_context_key = cache_key
+            return self._ssl_context
@@
-        if self.tls_ca_file:
-            self._ssl_context = ssl.create_default_context(cafile=self.tls_ca_file)
+        if self.tls_ca_file:
+            self._ssl_context = ssl.create_default_context(cafile=self.tls_ca_file)
+            self._ssl_context_key = cache_key
+            return self._ssl_context
@@
-        if self._ssl_context is None:
-            self._ssl_context = ssl.create_default_context()
-
-        return self._ssl_context
+        self._ssl_context = ssl.create_default_context()
+        self._ssl_context_key = cache_key
+        return self._ssl_context
@@
     def set_ssl_context(self, context: ssl.SSLContext) -> None:
-        self._ssl_context = context
+        self._ssl_context = context
+        self._ssl_context_key = "custom"

Any similar invalidation is fine so long as updating the config rebuilds the context.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In infrahub_sdk/config.py around lines 138 to 158, the tls_context property
caches a single SSLContext and never rebuilds it when config attributes change,
which can leave an insecure context (CERT_NONE) active after toggling
tls_insecure or updating tls_ca_file; fix by invalidating or rebuilding the
cached _ssl_context whenever relevant settings change—either (a) add setters for
tls_insecure and tls_ca_file that clear self._ssl_context (set to None) whenever
their values are modified, or (b) store the values used to build the current
context and in tls_context compare current settings to those stored and rebuild
if they differ; ensure set_ssl_context still works (overwrites cache) and
document that changing config rebuilds the context on next access.



class Config(ConfigBase):
recorder: RecorderType = Field(default=RecorderType.NONE, description="Select builtin recorder for later replay.")
Expand Down Expand Up @@ -174,4 +198,7 @@ def clone(self, branch: str | None = None) -> Config:
if field not in covered_keys:
config[field] = deepcopy(getattr(self, field))

return Config(**config)
new_config = Config(**config)
if self._ssl_context:
new_config.set_ssl_context(self._ssl_context)
return new_config
86 changes: 85 additions & 1 deletion tests/unit/sdk/test_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import inspect
import ssl
from pathlib import Path

import pytest
from pytest_httpx import HTTPXMock

from infrahub_sdk import InfrahubClient, InfrahubClientSync
from infrahub_sdk import Config, InfrahubClient, InfrahubClientSync
from infrahub_sdk.exceptions import NodeNotFoundError
from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync
from tests.unit.sdk.conftest import BothClients
Expand All @@ -28,6 +30,88 @@

client_types = ["standard", "sync"]

CURRENT_DIRECTORY = Path(__file__).parent


async def test_verify_config_caches_default_ssl_context(monkeypatch) -> None:
contexts: list[tuple[str | None, object]] = []

def fake_create_default_context(*args: object, **kwargs: object) -> object:
context = object()
contexts.append((kwargs.get("cafile"), context))
return context

monkeypatch.setattr("ssl.create_default_context", fake_create_default_context)

client = InfrahubClient(config=Config(address="http://mock"))

first = client.config.tls_context
second = client.config.tls_context

assert first is second
assert contexts == [(None, first)]


async def test_verify_config_caches_tls_ca_file_context(monkeypatch) -> None:
contexts: list[tuple[str | None, object]] = []

def fake_create_default_context(*args: object, **kwargs: object) -> object:
context = object()
contexts.append((kwargs.get("cafile"), context))
return context

monkeypatch.setattr("ssl.create_default_context", fake_create_default_context)

client = InfrahubClient(
config=Config(address="http://mock", tls_ca_file=str(CURRENT_DIRECTORY / "test_data/path-1.pem"))
)

first = client.config.tls_context
second = client.config.tls_context

assert first is second
assert contexts == [(str(CURRENT_DIRECTORY / "test_data/path-1.pem"), first)]

client.config.tls_ca_file = str(CURRENT_DIRECTORY / "test_data/path-2.pem")
third = client.config.tls_context

assert third is first
assert contexts == [
(str(CURRENT_DIRECTORY / "test_data/path-1.pem"), first),
]
Comment on lines +75 to +81
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Rebuild SSL context when configuration changes.

This test codifies that mutating Config.tls_ca_file after first access leaves the cached context untouched, so the client keeps using the old trust store. With validate_assignment=True, callers expect changing tls_ca_file (or toggling tls_insecure) to take effect immediately. Right now the cache masks those updates, which becomes a security risk when trying to re-enable verification. Please adjust the implementation to invalidate the cache when relevant fields change and update this test to expect a fresh context instead of the stale one.

🤖 Prompt for AI Agents
In tests/unit/sdk/test_client.py around lines 75 to 81, the test currently
expects the cached TLS context to remain the same after mutating
client.config.tls_ca_file; instead, update the implementation to
invalidate/rebuild the cached tls_context when relevant Config fields change (at
least tls_ca_file and tls_insecure) so changes take effect immediately; modify
the Config class (e.g., in the setters or in __setattr__ when
validate_assignment is enabled) to clear the cached context on those
assignments, and update this test to assert that a new context object is
returned and that the recorded contexts list includes both the original and the
new (i.e., expect a fresh context rather than the stale one).



async def test_verify_config_respects_tls_insecure(monkeypatch) -> None:
def fake_create_default_context(*args: object, **kwargs: object) -> object:
raise AssertionError("create_default_context should not be called when TLS is insecure")

monkeypatch.setattr("ssl.create_default_context", fake_create_default_context)

client = InfrahubClient(config=Config(address="http://mock", tls_insecure=True))

verify_value = client.config.tls_context

assert verify_value.check_hostname is False
assert verify_value.verify_mode == ssl.CERT_NONE


async def test_verify_config_uses_custom_tls_context(monkeypatch) -> None:
def fake_create_default_context(*args: object, **kwargs: object) -> object:
raise AssertionError("create_default_context should not be called when custom context is provided")

monkeypatch.setattr("ssl.create_default_context", fake_create_default_context)

config = Config(address="http://mock")
custom_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
config.set_ssl_context(custom_context)

client = InfrahubClient(config=config)

clone_client = client.clone()

assert client.config.tls_context is custom_context
assert clone_client.config.tls_context is custom_context


async def test_method_sanity() -> None:
"""Validate that there is at least one public method and that both clients look the same."""
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/sdk/test_data/path-1.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
-----BEGIN CERTIFICATE-----
MIIBQDCB86ADAgECAhR6y429KiST51bZy+t330M7dN5SbzAFBgMrZXAwFjEUMBIG
A1UEAwwLZXhhbXBsZS5jb20wHhcNMjUxMDE1MTE0MjUwWhcNMzUxMDEzMTE0MjUw
WjAWMRQwEgYDVQQDDAtleGFtcGxlLmNvbTAqMAUGAytlcAMhAPIl8y8AXSWF33vX
JT2YwhMJzarOuSdPif01Gxr3Rr6Lo1MwUTAdBgNVHQ4EFgQU4heN1ZhyXpOujgcJ
WZ4LQk2m7RAwHwYDVR0jBBgwFoAU4heN1ZhyXpOujgcJWZ4LQk2m7RAwDwYDVR0T
AQH/BAUwAwEB/zAFBgMrZXADQQBoEf+8R+KWwGdaoeqinWOvrqbVZatMis0eUMvA
o+vABSPU7LIYGxLT6fpUwFSTvempzNqGZMVJ9UvVH+hYDU4D
-----END CERTIFICATE-----
9 changes: 9 additions & 0 deletions tests/unit/sdk/test_data/path-2.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
-----BEGIN CERTIFICATE-----
MIIBQDCB86ADAgECAhQTRmRZxUSA5L7VfYJb3/t+dRK0ETAFBgMrZXAwFjEUMBIG
A1UEAwwLZXhhbXBsZS5jb20wHhcNMjUxMDE1MTE0MzM0WhcNMzUxMDEzMTE0MzM0
WjAWMRQwEgYDVQQDDAtleGFtcGxlLmNvbTAqMAUGAytlcAMhAK1O3ZhE5qzfT7Qx
+0My3ToDVDi5wwpllkKn0X50zXFao1MwUTAdBgNVHQ4EFgQUH+qBMU+h4t1vdLbO
jMSSgXdURewwHwYDVR0jBBgwFoAUH+qBMU+h4t1vdLbOjMSSgXdURewwDwYDVR0T
AQH/BAUwAwEB/zAFBgMrZXADQQB3Z03f3gQcktxk4h/v8pVi5soz8viPx17TSPXf
1WYG+Jlk4C5GQ+tyjZgZUE9LL2BFRYBv28V/NPT/0TjPGtcC
-----END CERTIFICATE-----