Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT extend http target to allow custom http client #804

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 56 additions & 17 deletions pyrit/prompt_target/http_target/http_target.py
Original file line number Diff line number Diff line change
@@ -44,6 +44,7 @@ def __init__(
use_tls: bool = True,
callback_function: Callable | None = None,
max_requests_per_minute: Optional[int] = None,
client: Optional[httpx.AsyncClient] = None,
**httpx_client_kwargs: Any,
) -> None:
super().__init__(max_requests_per_minute=max_requests_per_minute)
@@ -52,36 +53,72 @@ def __init__(
self.prompt_regex_string = prompt_regex_string
self.use_tls = use_tls
self.httpx_client_kwargs = httpx_client_kwargs or {}
self._client = client

@limit_requests_per_minute
async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
"""
Sends prompt to HTTP endpoint and returns the response
@classmethod
def with_client(
cls,
client: httpx.AsyncClient,
http_request: str,
prompt_regex_string: str = "{PROMPT}",
callback_function: Callable | None = None,
max_requests_per_minute: Optional[int] = None,
) -> "HTTPTarget":
"""
Alternative constructor that accepts a pre-configured httpx client.

self._validate_request(prompt_request=prompt_request)
request = prompt_request.request_pieces[0]

# Add Prompt into URL (if the URL takes it)
Parameters:
client: Pre-configured httpx.AsyncClient instance
http_request: the header parameters as a request (i.e., from Burp)
prompt_regex_string: the placeholder for the prompt
callback_function: function to parse HTTP response
max_requests_per_minute: Optional rate limiting
"""
instance = cls(
http_request=http_request,
prompt_regex_string=prompt_regex_string,
callback_function=callback_function,
max_requests_per_minute=max_requests_per_minute,
_client=client,
)
return instance

def _inject_prompt_into_request(self, request: PromptRequestPiece) -> str:
"""
Adds the prompt into the URL if the prompt_regex_string is found in the
http_request
"""
re_pattern = re.compile(self.prompt_regex_string)
if re.search(self.prompt_regex_string, self.http_request):
http_request_w_prompt = re_pattern.sub(request.converted_value, self.http_request)
else:
http_request_w_prompt = self.http_request
return http_request_w_prompt

header_dict, http_body, url, http_method, http_version = self.parse_raw_http_request(http_request_w_prompt)
@limit_requests_per_minute
async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
self._validate_request(prompt_request=prompt_request)
request = prompt_request.request_pieces[0]

http_request_w_prompt = self._inject_prompt_into_request(request)

# Make the actual HTTP request:
header_dict, http_body, url, http_method, http_version = self.parse_raw_http_request(http_request_w_prompt)

# Fix Content-Length if it is in the headers after the prompt is added in:
if "Content-Length" in header_dict:
header_dict["Content-Length"] = str(len(http_body))

http2_version = False
if http_version and "HTTP/2" in http_version:
http2_version = True

async with httpx.AsyncClient(http2=http2_version, **self.httpx_client_kwargs) as client:
if self._client is not None:
client = self._client
cleanup_client = False
else:
client = httpx.AsyncClient(http2=http2_version, **self.httpx_client_kwargs)
cleanup_client = True

try:
match http_body:
case dict():
response = await client.request(
@@ -99,14 +136,16 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P
content=http_body,
follow_redirects=True,
)
response_content = response.content

if self.callback_function:
response_content = self.callback_function(response=response)
response_content = response.content

response_entry = construct_response_from_request(request=request, response_text_pieces=[str(response_content)])
if self.callback_function:
response_content = self.callback_function(response=response)

return response_entry
return construct_response_from_request(request=request, response_text_pieces=[str(response_content)])
finally:
if cleanup_client:
await client.aclose()

def parse_raw_http_request(self, http_request: str) -> tuple[dict[str, str], RequestBody, str, str, str]:
"""
42 changes: 42 additions & 0 deletions tests/unit/target/test_http_target.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
from typing import Callable
from unittest.mock import MagicMock, patch

import httpx
import pytest

from pyrit.prompt_target.http_target.http_target import HTTPTarget
@@ -171,3 +172,44 @@ async def test_send_prompt_async_keeps_original_template(mock_request, mock_http
content='{"prompt": "second_test_prompt"}',
follow_redirects=True,
)


@pytest.mark.asyncio
async def test_http_target_with_injected_client():
custom_client = httpx.AsyncClient(timeout=30.0, verify=False, headers={"X-Custom-Header": "test_value"})

sample_request = (
'POST / HTTP/1.1\nHost: example.com\nContent-Type: application/json\n\n{"prompt": "{PLACEHOLDER_PROMPT}"}'
)

target = HTTPTarget.with_client(
client=custom_client,
http_request=sample_request,
prompt_regex_string="{PLACEHOLDER_PROMPT}",
callback_function=get_http_target_json_response_callback_function(key="mock_key"),
)

assert target._client is custom_client

with patch.object(custom_client, "request") as mock_request:
mock_response = MagicMock()
mock_response.content = b'{"mock_key": "test_value"}'
mock_request.return_value = mock_response

prompt_request = MagicMock()
prompt_request.request_pieces = [MagicMock(converted_value="test_prompt")]

response = await target.send_prompt_async(prompt_request=prompt_request)

assert response.get_value() == "test_value"

mock_request.assert_called_once_with(
method="POST",
url="https://example.com/",
headers={"Host": "example.com", "Content-Type": "application/json", "X-Custom-Header": "test_value"},
content='{"prompt": "test_prompt"}',
follow_redirects=True,
)

assert not custom_client.is_closed, "Client must not be closed after sending a prompt"
await custom_client.aclose()