Skip to content

Shem fix post body with bytes in batch #711

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions src/msgraph_core/requests/batch_request_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import TypeVar, Type, Dict, Optional, Union
import logging
import json
import base64

from kiota_abstractions.request_adapter import RequestAdapter
from kiota_abstractions.request_information import RequestInformation
Expand Down Expand Up @@ -59,12 +61,9 @@ async def post(
response_type = BatchResponseContent

if isinstance(batch_request_content, BatchRequestContent):
print(f"Batch request content: {batch_request_content.requests}")
request_info = await self.to_post_request_information(batch_request_content)
bytes_content = request_info.content
json_content = bytes_content.decode("utf-8")
updated_str = '{"requests":' + json_content + '}'
updated_bytes = updated_str.encode("utf-8")
request_info.content = updated_bytes
request_info.content = self._prepare_request_content(request_info.content)
error_map = error_map or self.error_map
response = None
try:
Expand Down Expand Up @@ -107,14 +106,56 @@ async def _post_batch_collection(
batch_responses = BatchResponseContentCollection()

for batch_request_content in batch_request_content_collection.batches:
print(f"Batch request content: {batch_request_content.requests}")

request_info = await self.to_post_request_information(batch_request_content)
print(f"content before processing {request_info.content}")
updated_bytes = self._prepare_request_content(request_info.content)
request_info.content = updated_bytes
response = await self._request_adapter.send_async(
request_info, BatchResponseContent, error_map or self.error_map
)
batch_responses.add_response(response)

return batch_responses

def _prepare_request_content(self, content: bytes) -> bytes:
"""
Prepares the request content by updating the JSON structure and converting
the 'body' field from string to a dictionary if necessary.

Args:
content (bytes): The original request content.

Returns:
bytes: The updated request content.
"""
json_content = content.decode("utf-8")
print(json_content)
requests_list = json.loads(json_content)
for request in requests_list:
if 'body' in request:
if isinstance(request['body'], dict):
pass
elif isinstance(request['body'], str):
try:
request['body'] = json.loads(request['body'])
except json.JSONDecodeError:
pass
elif isinstance(request['body'], bytes):
request['body'] = base64.b64encode(request['body']).decode('utf-8')

if isinstance(request['body'], dict):
request['headers'] = {"Content-Type": "application/json"}
else:
request['headers'] = {"Content-Type": "application/octet-stream"}
else:
request['headers'] = {"Content-Type": "application/json"}

updated_json_content = json.dumps({"requests": requests_list})
return updated_json_content.encode("utf-8")
# return json.dumps(requests_list).encode("utf-8")

async def to_post_request_information(
self, batch_request_content: BatchRequestContent
) -> RequestInformation:
Expand All @@ -131,6 +172,7 @@ async def to_post_request_information(
if batch_request_content is None:
raise ValueError("batch_request_content cannot be Null.")
batch_request_items = list(batch_request_content.requests.values())
print(f"Batch request items: {batch_request_items}")

request_info = RequestInformation()
request_info.http_method = Method.POST
Expand Down
8 changes: 6 additions & 2 deletions src/msgraph_core/requests/batch_request_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def add_request(self, request_id: Optional[str], request: BatchRequestItem) -> N
request.id = str(uuid.uuid4())
if hasattr(request, 'depends_on') and request.depends_on:
for dependent_id in request.depends_on:
if dependent_id not in [req.id for req in self.requests]:
if dependent_id not in self.requests:
dependent_request = self._request_by_id(dependent_id)
if dependent_request:
self._requests[dependent_id] = dependent_request
Expand Down Expand Up @@ -137,4 +137,8 @@ def serialize(self, writer: SerializationWriter) -> None:
Args:
writer: Serialization writer to use to serialize this model
"""
writer.write_collection_of_object_values("requests", self.requests)
if not writer:
raise ValueError("writer cannot be None")
writer.write_collection_of_object_values({"requests", list(self.requests.values())})
# requests_dict = {request_id: request for request_id, request in self.requests.items()}
# writer.write_object_value("requests", requests_dict)
25 changes: 18 additions & 7 deletions src/msgraph_core/requests/batch_request_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from typing import List, Optional, Dict, Union, Any
from io import BytesIO
import base64
import logging

import urllib.request
from urllib.parse import urlparse

Expand Down Expand Up @@ -238,16 +240,25 @@ def serialize(self, writer: SerializationWriter) -> None:
Args:
writer (SerializationWriter): The writer to write to.
"""
if not writer:
raise ValueError("writer cannot be None")

writer.write_str_value('id', self.id)
writer.write_str_value('method', self.method)
writer.write_str_value('url', self.url)

writer.write_collection_of_primitive_values('depends_on', self._depends_on)
headers = {key: ", ".join(val) for key, val in self._headers.items()}

headers = self._headers
writer.write_collection_of_object_values('headers', headers)

if self._body:
json_object = json.loads(self._body)
is_json_string = json_object and isinstance(json_object, dict)
writer.write_collection_of_object_values(
'body',
json_object if is_json_string else base64.b64encode(self._body).decode('utf-8')
)
if isinstance(self._body, bytes):
body_content = base64.b64encode(self._body).decode('utf-8')
elif isinstance(self._body, str):
body_content = self._body
else:
raise ValueError("Unsupported body type")
writer.write_str_value('body', body_content)
else:
logging.info("Content info: there is no body to serialize")
4 changes: 3 additions & 1 deletion tests/requests/test_batch_request_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def test_get_field_deserializers(batch_request_content):

def test_serialize(batch_request_content):
writer = Mock(spec=SerializationWriter)

batch_request_content.serialize(writer)

writer.write_collection_of_object_values.assert_called_once_with(
"requests", batch_request_content.requests
"requests", list(batch_request_content.requests.values())
)
75 changes: 75 additions & 0 deletions tests/requests/test_batch_request_item.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest
from unittest.mock import Mock
import base64
import json

from urllib.request import Request
from kiota_abstractions.request_information import RequestInformation
from kiota_abstractions.method import Method
Expand All @@ -25,6 +28,60 @@ def batch_request_item(request_info):
return BatchRequestItem(request_information=request_info)


@pytest.fixture
def request_info_json():
request_info = RequestInformation()
request_info.http_method = "POST"
request_info.url = "https://graph.microsoft.com/v1.0/me/events"
request_info.headers = RequestHeaders()
request_info.headers.add("Content-Type", "application/json")
request_info.content = json.dumps(
{
"@odata.type": "#microsoft.graph.event",
"end": {
"dateTime": "2024-10-14T17:30:00",
"timeZone": "Pacific Standard Time"
},
"start": {
"dateTime": "2024-10-14T17:00:00",
"timeZone": "Pacific Standard Time"
},
"subject": "File end-of-day report"
}
).encode('utf-8')
return request_info


@pytest.fixture
def request_info_bytes():
request_info = RequestInformation()
request_info.http_method = "POST"
request_info.url = "https://graph.microsoft.com/v1.0/me/events"
request_info.headers = RequestHeaders()
request_info.headers.add("Content-Type", "application/json")
request_info.content = b'{"@odata.type": "#microsoft.graph.event", "end": {"dateTime": "2024-10-14T17:30:00", "timeZone": "Pacific Standard Time"}, "start": {"dateTime": "2024-10-14T17:00:00", "timeZone": "Pacific Standard Time"}, "subject": "File end-of-day report"}'
return request_info


@pytest.fixture
def batch_request_item_json(request_info_json):
return BatchRequestItem(request_information=request_info_json)


@pytest.fixture
def batch_request_item_bytes(request_info_bytes):
return BatchRequestItem(request_information=request_info_bytes)


def encode_body_to_base64(body):
if isinstance(body, bytes):
return base64.b64encode(body).decode('utf-8')
elif isinstance(body, str):
return base64.b64encode(body.encode('utf-8')).decode('utf-8')
else:
raise ValueError("Unsupported body type")


def test_initialization(batch_request_item, request_info):
assert batch_request_item.method == "GET"
assert batch_request_item.url == "f{base_url}/me"
Expand Down Expand Up @@ -124,3 +181,21 @@ def test_batch_request_item_method_enum():
def test_depends_on_property(batch_request_item):
batch_request_item.set_depends_on(["request1", "request2"])
assert batch_request_item.depends_on == ["request1", "request2"]


def test_serialize_with_json_body(batch_request_item_json):
item = batch_request_item_json
writer = Mock()
processed_body = encode_body_to_base64(item.body)

item.serialize(writer)
writer.write_str_value.assert_called_with('body', processed_body)


def test_serialize_with_bytes_body(batch_request_item_bytes):
item = batch_request_item_bytes
writer = Mock()
processed_body = encode_body_to_base64(item.body)

item.serialize(writer)
writer.write_str_value.assert_called_with('body', processed_body)
Loading