-
Notifications
You must be signed in to change notification settings - Fork 1.7k
ci: Add support for max_inflight_responses
parameter to prevent unbounded memory growth in ensemble models
#8458
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
pskiran1
wants to merge
18
commits into
main
Choose a base branch
from
spolisetty/tri-26-triton-dali-ensemble-model-memory-issue
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+652
−7
Open
Changes from 3 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
240a223
Update
pskiran1 974aa25
Update
pskiran1 337e0a7
Update
pskiran1 05dcb71
Update
pskiran1 4f379ed
Fix pre-commit
pskiran1 9ed216f
Fix pre-commit errors
pskiran1 78698fc
Update
pskiran1 8665a0d
Update
pskiran1 0258eda
Update
pskiran1 81561fd
Remove duplicate code and add request cancellation test
pskiran1 10dacec
Fix pre-commit
pskiran1 e2e48a3
Fix pre-commit
pskiran1 f8f1468
Update
pskiran1 3d8b848
Update
pskiran1 4a1a8fe
Improve model preparation
pskiran1 554e1b9
Update tests
pskiran1 b2ad735
Add documentation
pskiran1 977420a
Update copyright
pskiran1 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,233 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Redistribution and use in source and binary forms, with or without | ||
# modification, are permitted provided that the following conditions | ||
# are met: | ||
# * Redistributions of source code must retain the above copyright | ||
# notice, this list of conditions and the following disclaimer. | ||
# * Redistributions in binary form must reproduce the above copyright | ||
# notice, this list of conditions and the following disclaimer in the | ||
# documentation and/or other materials provided with the distribution. | ||
# * Neither the name of NVIDIA CORPORATION nor the names of its | ||
# contributors may be used to endorse or promote products derived | ||
# from this software without specific prior written permission. | ||
# | ||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY | ||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR | ||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR | ||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, | ||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, | ||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR | ||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | ||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
|
||
import sys | ||
|
||
sys.path.append("../common") | ||
|
||
import queue | ||
import unittest | ||
from functools import partial | ||
import numpy as np | ||
import test_util as tu | ||
import tritonclient.grpc as grpcclient | ||
|
||
|
||
SERVER_URL = "localhost:8001" | ||
DEFAULT_RESPONSE_TIMEOUT = 60 | ||
|
||
|
||
class UserData: | ||
def __init__(self): | ||
self._response_queue = queue.Queue() | ||
|
||
|
||
def callback(user_data, result, error): | ||
if error: | ||
user_data._response_queue.put(error) | ||
else: | ||
user_data._response_queue.put(result) | ||
|
||
|
||
class EnsembleBackpressureTest(tu.TestResultCollector): | ||
""" | ||
Tests for ensemble backpressure feature (max_ensemble_inflight_responses). | ||
""" | ||
|
||
def _prepare_infer_args(self, input_value): | ||
""" | ||
Create InferInput/InferRequestedOutput lists | ||
""" | ||
input_data = np.array([input_value], dtype=np.int32) | ||
infer_input = [grpcclient.InferInput("IN", input_data.shape, "INT32")] | ||
infer_input[0].set_data_from_numpy(input_data) | ||
outputs = [grpcclient.InferRequestedOutput("OUT")] | ||
return infer_input, outputs | ||
|
||
def _collect_responses(self, user_data): | ||
""" | ||
Collect responses from user_data until the final response flag is seen. | ||
""" | ||
responses = [] | ||
while True: | ||
try: | ||
result = user_data._response_queue.get(timeout=DEFAULT_RESPONSE_TIMEOUT) | ||
except queue.Empty: | ||
self.fail( | ||
f"No response received within {DEFAULT_RESPONSE_TIMEOUT} seconds." | ||
) | ||
|
||
self.assertNotIsInstance( | ||
result, Exception, f"Callback returned an exception: {result}" | ||
) | ||
|
||
# Add response to list if it has data (not empty final-only response) | ||
response = result.get_response() | ||
if len(response.outputs) > 0: | ||
responses.append(result) | ||
|
||
# Check if this is the final response | ||
final = response.parameters.get("triton_final_response") | ||
if final and final.bool_param: | ||
break | ||
|
||
return responses | ||
|
||
def test_backpressure_limits_inflight(self): | ||
""" | ||
Test that max_ensemble_inflight_responses correctly limits concurrent | ||
responses and prevents unbounded memory growth. | ||
""" | ||
model_name = "ensemble_enabled_max_inflight_responses" | ||
expected_count = 16 | ||
user_data = UserData() | ||
|
||
with grpcclient.InferenceServerClient(SERVER_URL) as triton_client: | ||
try: | ||
inputs, outputs = self._prepare_infer_args(expected_count) | ||
|
||
triton_client.start_stream(callback=partial(callback, user_data)) | ||
|
||
triton_client.async_stream_infer( | ||
model_name=model_name, inputs=inputs, outputs=outputs | ||
) | ||
|
||
# Collect responses | ||
responses = self._collect_responses(user_data) | ||
|
||
# Verify we received the expected number of responses | ||
self.assertEqual( | ||
len(responses), | ||
expected_count, | ||
f"Expected {expected_count} responses, got {len(responses)}", | ||
) | ||
|
||
# Verify correctness of responses | ||
for idx, resp in enumerate(responses): | ||
output = resp.as_numpy("OUT") | ||
self.assertEqual( | ||
output[0], idx, f"Response {idx} has incorrect value" | ||
) | ||
|
||
finally: | ||
triton_client.stop_stream() | ||
|
||
def test_backpressure_disabled(self): | ||
""" | ||
Test that ensemble model without max_ensemble_inflight_responses parameter | ||
works fine (original behavior). | ||
""" | ||
model_name = "ensemble_disabled_max_inflight_responses" | ||
expected_count = 16 | ||
user_data = UserData() | ||
|
||
with grpcclient.InferenceServerClient(SERVER_URL) as triton_client: | ||
try: | ||
inputs, outputs = self._prepare_infer_args(expected_count) | ||
|
||
triton_client.start_stream(callback=partial(callback, user_data)) | ||
|
||
triton_client.async_stream_infer( | ||
model_name=model_name, inputs=inputs, outputs=outputs | ||
) | ||
|
||
# Collect responses | ||
responses = self._collect_responses(user_data) | ||
|
||
# Verify we received the expected number of responses | ||
self.assertEqual( | ||
len(responses), | ||
expected_count, | ||
f"Expected {expected_count} responses, got {len(responses)}", | ||
) | ||
|
||
# Verify correctness of responses | ||
for idx, resp in enumerate(responses): | ||
output = resp.as_numpy("OUT") | ||
self.assertEqual( | ||
output[0], idx, f"Response {idx} has incorrect value" | ||
) | ||
|
||
finally: | ||
triton_client.stop_stream() | ||
|
||
def test_backpressure_concurrent_requests(self): | ||
""" | ||
Test that backpressure works correctly with multiple concurrent requests. | ||
Each request should have independent backpressure state. | ||
""" | ||
model_name = "ensemble_enabled_max_inflight_responses" | ||
num_concurrent = 8 | ||
expected_per_request = 8 | ||
|
||
clients = [] | ||
user_datas = [] | ||
|
||
try: | ||
inputs, outputs = self._prepare_infer_args(expected_per_request) | ||
|
||
# Create separate client for each concurrent request | ||
for i in range(num_concurrent): | ||
client = grpcclient.InferenceServerClient(SERVER_URL) | ||
user_data = UserData() | ||
|
||
client.start_stream(callback=partial(callback, user_data)) | ||
client.async_stream_infer( | ||
model_name=model_name, inputs=inputs, outputs=outputs | ||
) | ||
|
||
clients.append(client) | ||
user_datas.append(user_data) | ||
|
||
# Collect and verify responses for all requests | ||
for i, ud in enumerate(user_datas): | ||
responses = self._collect_responses(ud) | ||
self.assertEqual( | ||
len(responses), | ||
expected_per_request, | ||
f"Request {i}: expected {expected_per_request} responses, got {len(responses)}", | ||
) | ||
|
||
# Verify correctness of responses | ||
for idx, resp in enumerate(responses): | ||
output = resp.as_numpy("OUT") | ||
self.assertEqual( | ||
output[0], idx, f"Response {idx} has incorrect value" | ||
) | ||
|
||
finally: | ||
for client in clients: | ||
try: | ||
client.stop_stream() | ||
client.close() | ||
except: | ||
|
||
pass | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
54 changes: 54 additions & 0 deletions
54
qa/L0_simple_ensemble/models/decoupled_producer/1/model.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Redistribution and use in source and binary forms, with or without | ||
# modification, are permitted provided that the following conditions | ||
# are met: | ||
# * Redistributions of source code must retain the above copyright | ||
# notice, this list of conditions and the following disclaimer. | ||
# * Redistributions in binary form must reproduce the above copyright | ||
# notice, this list of conditions and the following disclaimer in the | ||
# documentation and/or other materials provided with the distribution. | ||
# * Neither the name of NVIDIA CORPORATION nor the names of its | ||
# contributors may be used to endorse or promote products derived | ||
# from this software without specific prior written permission. | ||
# | ||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY | ||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR | ||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR | ||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, | ||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, | ||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR | ||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | ||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
|
||
|
||
import numpy as np | ||
import triton_python_backend_utils as pb_utils | ||
|
||
|
||
class TritonPythonModel: | ||
""" | ||
Decoupled model that produces N responses based on input value. | ||
""" | ||
|
||
def execute(self, requests): | ||
for request in requests: | ||
# Get input - number of responses to produce | ||
in_tensor = pb_utils.get_input_tensor_by_name(request, "IN") | ||
count = in_tensor.as_numpy()[0] | ||
|
||
response_sender = request.get_response_sender() | ||
|
||
# Produce 'count' responses | ||
for i in range(count): | ||
out_tensor = pb_utils.Tensor("OUT", np.array([i], dtype=np.int32)) | ||
response = pb_utils.InferenceResponse(output_tensors=[out_tensor]) | ||
response_sender.send(response) | ||
|
||
# Send final flag | ||
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) | ||
|
||
return None |
58 changes: 58 additions & 0 deletions
58
qa/L0_simple_ensemble/models/decoupled_producer/config.pbtxt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Redistribution and use in source and binary forms, with or without | ||
# modification, are permitted provided that the following conditions | ||
# are met: | ||
# * Redistributions of source code must retain the above copyright | ||
# notice, this list of conditions and the following disclaimer. | ||
# * Redistributions in binary form must reproduce the above copyright | ||
# notice, this list of conditions and the following disclaimer in the | ||
# documentation and/or other materials provided with the distribution. | ||
# * Neither the name of NVIDIA CORPORATION nor the names of its | ||
# contributors may be used to endorse or promote products derived | ||
# from this software without specific prior written permission. | ||
# | ||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY | ||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR | ||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR | ||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, | ||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, | ||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR | ||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | ||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
|
||
|
||
name: "decoupled_producer" | ||
backend: "python" | ||
max_batch_size: 0 | ||
|
||
input [ | ||
{ | ||
name: "IN" | ||
data_type: TYPE_INT32 | ||
dims: [ 1 ] | ||
} | ||
] | ||
|
||
output [ | ||
{ | ||
name: "OUT" | ||
data_type: TYPE_INT32 | ||
dims: [ 1 ] | ||
} | ||
] | ||
|
||
instance_group [ | ||
{ | ||
count: 1 | ||
kind: KIND_CPU | ||
} | ||
] | ||
|
||
model_transaction_policy { | ||
decoupled: true | ||
} | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.