bugfix: fix failed unittest on spark (sm_121) #1951
base: main
Conversation
Summary of Changes: Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request addresses and resolves three specific unittest failures encountered on Spark environments (sm_121). It implements targeted adjustments to test logic, introduces conditional test skipping based on available hardware resources, and temporarily marks one test as an expected failure due to a known numerical-accuracy issue. The overall goal is to enhance test-suite stability and reliability on Spark without compromising the integrity of the tests.
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough: Add CUDA-runtime guards and validation around green-context splitting to detect insufficient SMs and resource errors, raise clearer RuntimeError messages, and have tests skip on those conditions; add an xfail for SM 12.1 in a JIT test and tighten a softmax numeric assertion to include rtol.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as Test function
    participant GreenSplit as flashinfer.green_ctx
    participant DeviceQuery as runtime/device query
    participant PyTest as pytest
    Test->>GreenSplit: call split_device_green_ctx* (groups/min_count or sm_count)
    GreenSplit->>DeviceQuery: query device SMs / resource info
    DeviceQuery-->>GreenSplit: available_sms
    alt runtime error OR required_sms > available_sms
        GreenSplit-->>Test: raise RuntimeError("insufficient SMs / resource config …")
        Test->>PyTest: catch RuntimeError -> pytest.skip(message with device SMs & params)
    else
        GreenSplit-->>Test: return split contexts
        Test->>Test: run kernels and assertions
    end
    Note right of Test: Separate flow — test_jit_example queries compute capability\nand marks xfail for SM 12.1
```
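As a minimal sketch of the separate xfail flow the note describes (the function body, names, and message are illustrative, not the PR's exact code), the pattern could look like this:

```python
import pytest
import torch

def test_jit_example():
    # hypothetical guard; the real test queries compute capability similarly
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) == (12, 1):
        # known numerical-accuracy issue on SM 12.1, to be fixed later
        pytest.xfail("logits mismatch vs. reference on sm_121")
    ...  # rest of the test body
```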
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ❌ 1 failed check (warning), ✅ 2 passed.
Code Review
This pull request addresses three failing unit tests on Spark (sm_121) by adding a guard for SM availability in test_green_ctx.py, marking a test as xfail in test_jit_example.py due to numerical issues, and increasing the tolerance in test_sampling.py. The changes are correct and effectively fix the described issues. I've provided a couple of suggestions for test_green_ctx.py to improve code clarity and reduce duplication.
tests/utils/test_green_ctx.py (Outdated)
```python
    total = 0
    for sm_count in sm_counts:
        rounded = round_up(max(sm_count, min_sm), alignment)
        total += rounded
    return total
```
This for-loop can be expressed more concisely using the built-in sum() function with a generator expression. This is a common Python idiom that improves readability.
```diff
-    total = 0
-    for sm_count in sm_counts:
-        rounded = round_up(max(sm_count, min_sm), alignment)
-        total += rounded
-    return total
+    return sum(round_up(max(sm_count, min_sm), alignment) for sm_count in sm_counts)
```
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Actionable comments posted: 0
♻️ Duplicate comments (2)
tests/utils/test_green_ctx.py (2)
20-24: Consider using built-in sum() for improved readability. As noted in previous reviews, this for-loop can be expressed more concisely using the built-in sum() function with a generator expression, which is a common Python idiom. Apply this diff to refactor:

```diff
-    total = 0
-    for sm_count in sm_counts:
-        rounded = round_up(max(sm_count, min_sm), alignment)
-        total += rounded
-    return total
+    return sum(round_up(max(sm_count, min_sm), alignment) for sm_count in sm_counts)
```
36-42: Address the pre-commit formatting failure. The pipeline indicates a formatting issue that needs to be resolved. Please run pre-commit run --all-files to apply the formatting changes.

Additionally, as noted in previous reviews, this pre-check logic is duplicated across multiple tests. Consider either:
- Extracting it into a pytest fixture or helper function
- Moving the check into the split_device_green_ctx API itself to raise an exception
🧹 Nitpick comments (1)
tests/utils/test_green_ctx.py (1)
43-45: Prefix unused variable with underscore. The streams variable is unpacked but never used in this test function. Prefix it with an underscore to indicate it's intentionally unused. Apply this diff:

```diff
-    streams, resources = green_ctx.split_device_green_ctx(
+    _streams, resources = green_ctx.split_device_green_ctx(
         dev, num_groups, min_count
     )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/utils/test_green_ctx.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/utils/test_green_ctx.py (2)
- flashinfer/utils.py (2): get_compute_capability (251-254), get_device_sm_count (595-596)
- flashinfer/green_ctx.py (2): get_sm_count_constraint (34-44), split_device_green_ctx (126-178)
🪛 GitHub Actions: pre-commit
tests/utils/test_green_ctx.py
[error] 40-40: ruff-format: 1 file reformatted by this hook. The pre-commit hook failed; please re-run with 'pre-commit run --all-files' to apply formatting changes.
[error] 40-40: Code style formatting changed by ruff-format. Updated call should be: streams, resources = green_ctx.split_device_green_ctx(dev, num_groups, min_count).
🪛 Ruff (0.14.1)
tests/utils/test_green_ctx.py
43-43: Unpacked variable streams is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (6)
tests/utils/test_green_ctx.py (6)
5-5: LGTM! The imports are necessary for the SM calculation helpers and are correctly placed.

8-13: LGTM! The helper correctly calculates the total SM count required by rounding up the minimum count to meet alignment requirements and multiplying by the number of groups; a reconstructed sketch of that arithmetic follows.
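A sketch of that calculation, assuming round_up is the usual ceil-to-multiple helper (this mirrors the description above, not the PR's verbatim code):

```python
def round_up(x: int, align: int) -> int:
    # ceiling of x to the nearest multiple of align
    # (assumed to match flashinfer.utils.round_up)
    return -(-x // align) * align

def calculate_required_sms(num_groups: int, min_count: int,
                           min_sm: int, alignment: int) -> int:
    # each group needs at least min_count SMs, raised to the device minimum
    # and rounded up to the alignment granularity
    return round_up(max(min_count, min_sm), alignment) * num_groups

# e.g. 3 groups of 16 SMs with assumed min_sm=8, alignment=8 -> 48 SMs required
assert calculate_required_sms(3, 16, 8, 8) == 48
```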
61-67: LGTM! The pre-check logic correctly validates SM availability before running the test.

97-103: LGTM! The pre-check correctly uses calculate_required_sms_by_counts to validate SM availability for tests with specific SM counts.

130-136: LGTM! The pre-check correctly validates SM availability before running the kernel execution test.

165-171: LGTM! The pre-check correctly validates SM availability before running the alignment test.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/utils/test_green_ctx.py (1)
39-46: Consider consistency in device object creation. Unlike test_green_ctx_creation (line 15), this test passes torch.device(device) directly without creating a dev variable first. While both approaches work, consistent usage across all tests would improve readability. Apply this diff for consistency:

```diff
+    dev = torch.device(device)
     try:
         streams, resources = green_ctx.split_device_green_ctx(
-            torch.device(device), num_groups, min_count
+            dev, num_groups, min_count
         )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
flashinfer/green_ctx.py (3 hunks), tests/utils/test_green_ctx.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/green_ctx.py (1)
- flashinfer/utils.py (3): get_compute_capability (251-254), get_device_sm_count (595-596), round_up (589-591)

tests/utils/test_green_ctx.py (1)
- flashinfer/green_ctx.py (2): split_device_green_ctx (126-190), split_device_green_ctx_by_sm_count (193-281)
🪛 Ruff (0.14.1)
flashinfer/green_ctx.py
180-183: Avoid specifying long messages outside the exception class
(TRY003)
264-264: Avoid specifying long messages outside the exception class
(TRY003)
272-275: Avoid specifying long messages outside the exception class
(TRY003)
tests/utils/test_green_ctx.py
17-17: Unpacked variable streams is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (5)
tests/utils/test_green_ctx.py (1)
15-23: Good error handling pattern for insufficient SMs. The try-except block properly catches and skips tests when the device lacks sufficient SMs, which addresses the Spark (sm_121) test failures mentioned in the PR objectives.
flashinfer/green_ctx.py (4)
31-31: LGTM! Required import for SM count validation. The get_device_sm_count import is correctly added and used in both validation checks (lines 177 and 269).

173-184: Excellent early validation for SM availability. The pre-check correctly computes the required SMs and fails fast before any CUDA operations, providing a clear error message that aligns with the test expectations.

261-261: Good optimization: constraint calculation moved outside loop. Moving get_sm_count_constraint outside the loop avoids redundant calls, as the constraints don't change between iterations.

267-276: Proper SM validation with informative error message. The validation correctly sums the rounded SM counts and raises a clear error if insufficient. The error message helpfully includes the actual rounded_sm_counts list to aid debugging.
I can confirm that test_jit_example.py now passes or xfails.
test_green_ctx.py still has 7 failures:
```text
================================= short test summary info =================================
FAILED tests/utils/test_green_ctx.py::test_green_ctx_creation[16-3-cuda:0] - RuntimeError: CUDA error code=914(b'CUDA_ERROR_INVALID_RESOURCE_TYPE')
FAILED tests/utils/test_green_ctx.py::test_green_ctx_kernel_execution[16-3-cuda:0] - RuntimeError: CUDA error code=914(b'CUDA_ERROR_INVALID_RESOURCE_TYPE')
FAILED tests/utils/test_green_ctx.py::test_split_device_green_ctx_by_sm_count_creation[sm_counts0-cuda:0] - RuntimeError: CUDA error code=914(b'CUDA_ERROR_INVALID_RESOURCE_TYPE')
FAILED tests/utils/test_green_ctx.py::test_split_device_green_ctx_by_sm_count_creation[sm_counts1-cuda:0] - RuntimeError: CUDA error code=914(b'CUDA_ERROR_INVALID_RESOURCE_TYPE')
FAILED tests/utils/test_green_ctx.py::test_split_device_green_ctx_by_sm_count_kernel_execution[sm_counts0-cuda:0] - RuntimeError: CUDA error code=914(b'CUDA_ERROR_INVALID_RESOURCE_TYPE')
FAILED tests/utils/test_green_ctx.py::test_split_device_green_ctx_by_sm_count_kernel_execution[sm_counts1-cuda:0] - RuntimeError: CUDA error code=914(b'CUDA_ERROR_INVALID_RESOURCE_TYPE')
FAILED tests/utils/test_green_ctx.py::test_split_device_green_ctx_by_sm_count_alignment[sm_counts1-cuda:0] - RuntimeError: CUDA error code=914(b'CUDA_ERROR_INVALID_RESOURCE_TYPE')
======================= 7 failed, 10 passed, 5 skipped, 1 warning in 0.91s =======================
```
Please see my other comment for test_sampling.py. There might be NaNs coming from the kernel, at least in my local env.
```diff
 probs_ref = torch.softmax(logits_scaled, dim=-1)

-assert torch.allclose(probs, probs_ref, atol=1e-5)
+assert torch.allclose(probs, probs_ref, rtol=1e-5, atol=1e-5)
```
I cannot seem to repro the fix on Spark. Also, torch.allclose already defaults to rtol=1e-5, so passing it explicitly may not effectively change anything.
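For context, torch.allclose(a, b, rtol, atol) checks elementwise |a - b| <= atol + rtol * |b|, with defaults rtol=1e-5 and atol=1e-8, so spelling out rtol=1e-5 leaves the check unchanged; a quick demonstration with made-up values:

```python
import torch

a = torch.tensor([1.0 + 3e-5])
b = torch.tensor([1.0])
# threshold = atol + rtol * |b| = 1e-5 + 1e-5 * 1.0 = 2e-5, below |a - b| ~= 3e-5
print(torch.allclose(a, b, atol=1e-5))             # False
print(torch.allclose(a, b, rtol=1e-5, atol=1e-5))  # False (identical check)
# allclose is also False whenever either tensor contains NaN (unless
# equal_nan=True), consistent with the NaN observation in this thread
```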
In fact, in my local env (cu130 container), when I change the tolerance and inject print statements as:

```python
probs_ref = torch.softmax(logits_scaled, dim=-1)
print(f"{torch.isnan(probs).sum().item() = }")
print(f"{torch.isnan(probs_ref).sum().item() = }")
assert torch.allclose(probs, probs_ref, rtol=100, atol=100)
```

I am seeing NaNs.
```text
(py312) root@c661e6d696f6:/flashinfer# pytest tests/utils/test_sampling.py -x -s
=============================== test session starts ===============================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 900 items
tests/utils/test_sampling.py torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() = 0
torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() = 0
.torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() = 0
torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() = 0
.torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() = 0
torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() = 0
.torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() = 0
torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() = 0
.torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() = 0
torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() = 0
.torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() = 0
torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() = 0
.torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() = 0
torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() = 0
.torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() = 0
torch.isnan(probs).sum().item() = 0
torch.isnan(probs_ref).sum().item() = 0
.torch.isnan(probs).sum().item() = 4873728
torch.isnan(probs_ref).sum().item() = 0
F
==================================== FAILURES =====================================
____________ test_softmax[True-True-1.0-normal_distribution(std=1)-128256-989] ____________
...
>       assert torch.allclose(probs, probs_ref, rtol=100, atol=100)
E       AssertionError: assert False
E        + where False = <built-in method allclose of type object at 0x16bc850>(tensor([[0.0000e+00, 7.8481e-05, 0.0000e+00, ..., 9.0452e-06, 8.5036e-06,\n 0.0000e+00],\n [2.4505e-05, ...05],\n [0.0000e+00, 0.0000e+00, 7.0366e-06, ..., 0.0000e+00, 7.1824e-06,\n 2.0367e-06]], device='cuda:0'), tensor([[0.0000e+00, 7.8481e-05, 0.0000e+00, ..., 9.0452e-06, 8.5036e-06,\n 0.0000e+00],\n [2.4505e-05, ...05],\n [0.0000e+00, 0.0000e+00, 7.0366e-06, ..., 0.0000e+00, 7.1824e-06,\n 2.0367e-06]], device='cuda:0'), rtol=100, atol=100)
E        + where <built-in method allclose of type object at 0x16bc850> = torch.allclose
tests/utils/test_sampling.py:76: AssertionError
...
```
Actionable comments posted: 1
♻️ Duplicate comments (1)
tests/utils/test_green_ctx.py (1)
24-32: Extract duplicated error handling to a helper function. The CUDA error checking pattern is duplicated across all 5 test functions. This was previously flagged in review comments and should be addressed to improve maintainability.
Extract the error handling to a helper function:
```python
def _skip_if_insufficient_sms(e: RuntimeError, device: str, context_msg: str):
    """Helper to skip tests when device has insufficient SMs for green context splitting."""
    if "CUDA error code=914" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) or \
       "CUDA error code=915" in str(e) or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e):
        cu_dev = green_ctx.get_cudevice(torch.device(device))
        device_resource = green_ctx.get_device_resource(cu_dev)
        total_sms = device_resource.sm.smCount
        pytest.skip(f"Insufficient SMs on device. Total SMs available: {total_sms}. {context_msg}")
    raise
```

Then simplify each test's except block to:

```python
except RuntimeError as e:
    _skip_if_insufficient_sms(
        e, device, f"requested: num_groups={num_groups}, min_count={min_count}"
    )
```

Based on learnings
Also applies to: 57-65, 94-102, 132-140, 170-178
🧹 Nitpick comments (1)
tests/utils/test_green_ctx.py (1)
15-18: Prefix unused variable with underscore. The streams variable is unpacked but never used in this test. Prefix it with _ to indicate it's intentionally unused. Apply this diff:

```diff
 try:
-    streams, resources = green_ctx.split_device_green_ctx(
+    _streams, resources = green_ctx.split_device_green_ctx(
         torch.device(device), num_groups, min_count
     )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
flashinfer/green_ctx.py (2 hunks), tests/utils/test_green_ctx.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/green_ctx.py (2)
- flashinfer/utils.py (2): get_compute_capability (251-254), round_up (589-591)
- flashinfer/comm/mnnvl.py (1): round_up (55-57)

tests/utils/test_green_ctx.py (2)
- flashinfer/green_ctx.py (5): split_device_green_ctx (126-189), get_cudevice (47-53), get_device_resource (56-61), split_device_green_ctx_by_sm_count (192-283), get_sm_count_constraint (34-44)
- flashinfer/utils.py (1): get_compute_capability (251-254)
🪛 GitHub Actions: pre-commit
flashinfer/green_ctx.py
[error] 1-1: pre-commit: ruff-format reformatted 2 files. Review and commit changes. Command 'pre-commit run --all-files' reported modifications by hook 'ruff-format'.
tests/utils/test_green_ctx.py
[error] 1-1: pre-commit: ruff-format reformatted 2 files. Review and commit changes. Command 'pre-commit run --all-files' reported modifications by hook 'ruff-format'.
🪛 Ruff (0.14.1)
flashinfer/green_ctx.py
177-177: Consider [*results, remaining] instead of concatenation
Replace with [*results, remaining]
(RUF005)
179-179: Consider moving this statement to an else block
(TRY300)
183-188: Avoid specifying long messages outside the exception class
(TRY003)
266-266: Avoid specifying long messages outside the exception class
(TRY003)
271-271: Consider [*results, remaining] instead of concatenation
Replace with [*results, remaining]
(RUF005)
273-273: Consider moving this statement to an else block
(TRY300)
277-282: Avoid specifying long messages outside the exception class
(TRY003)
tests/utils/test_green_ctx.py
16-16: Unpacked variable streams is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🔇 Additional comments (1)
flashinfer/green_ctx.py (1)
255-283: LGTM! Consistent error handling with helpful validation. The implementation correctly validates input SM counts and provides descriptive error messages for CUDA resource failures. The pattern is consistent with split_device_green_ctx.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
flashinfer/green_ctx.py (3)
64-78: Type annotation is incorrect; results is a list.

split_resource returns a list of CUdevResource and a single remaining CUdevResource. Adjust the return type to avoid misleading type checkers.

```diff
 def split_resource(
     resource: CUdevResource,
     num_groups: int,
     min_count: int,
-) -> Tuple[CUdevResource, CUdevResource]:
+) -> Tuple[List[CUdevResource], CUdevResource]:
```
103-106: Parameter type should be CUdevice, not CUdevResource.

create_green_ctx_streams receives cu_dev from get_cudevice (a CUdevice) and passes it to cuGreenCtxCreate. Fix the annotation.

```diff
 def create_green_ctx_streams(
-    cu_dev: CUdevResource, resources: List[CUdevResource]
+    cu_dev: CUdevice, resources: List[CUdevResource]
 ) -> List[torch.Stream]:
```
80-101: Green-context handle leak confirmed in two functions; refactor to eliminate unnecessary context creation.

The review is correct. The codebase creates green contexts but never destroys them: no cuGreenCtxDestroy calls exist anywhere. Two functions are affected:

- split_resource_by_sm_count() (lines 80-100): creates a green context solely to extract a resource already returned by split_resource(). The "refresh" operation is unnecessary; the proposed fix (use remaining directly) is valid and eliminates the leak for this function.
- create_green_ctx_streams() (lines 103-123): creates green contexts in a loop to generate streams, but never stores or destroys the contexts. They go out of scope immediately after stream extraction, creating a handle leak.

The proposed fix for split_resource_by_sm_count() is sound:

```diff
-        result, remaining = split_resource(resource, 1, sm_count)
-        results.extend(result)
-        # Refresh the remaining resource for the next iteration
-        desc = checkCudaErrors(driver.cuDevResourceGenerateDesc([remaining], 1))
-        green_ctx = checkCudaErrors(
-            driver.cuGreenCtxCreate(
-                desc, cu_dev, driver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM
-            )
-        )
-        resource = checkCudaErrors(
-            driver.cuGreenCtxGetDevResource(
-                green_ctx, driver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM
-            )
-        )
+        result, remaining = split_resource(resource, 1, sm_count)
+        results.extend(result)
+        resource = remaining
```

Additionally, review create_green_ctx_streams() to determine whether green contexts must remain alive for stream validity. If yes, contexts must be retained and properly destroyed; if no, context creation can be eliminated.
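If the contexts do need to stay alive, the retain-and-destroy shape could look like the following sketch. It reuses the names from the snippet above and assumes the driver binding exposes cuGreenCtxDestroy (available in the CUDA green-context driver API, but worth verifying against the CUDA version in use); this is not code from this PR.

```python
# hypothetical retain-and-destroy sketch, not code from this PR
g_ctxs = []
try:
    for res in resources:
        desc = checkCudaErrors(driver.cuDevResourceGenerateDesc([res], 1))
        g_ctx = checkCudaErrors(
            driver.cuGreenCtxCreate(
                desc, cu_dev, driver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM
            )
        )
        g_ctxs.append(g_ctx)  # retain the handle instead of dropping it
        # ... create the stream from g_ctx as before ...
finally:
    for g_ctx in g_ctxs:
        checkCudaErrors(driver.cuGreenCtxDestroy(g_ctx))
```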
♻️ Duplicate comments (1)
tests/utils/test_green_ctx.py (1)
25-38: Deduplicate skip logic via a fixture/helper. The same RuntimeError substring checks and SM-count fetch/skip are repeated across tests. Extract once (fixture/helper) to improve maintainability and keep messages consistent. This was raised earlier; repeating here for the new blocks.
Example fixture:
```python
# conftest.py
import pytest
import torch

import flashinfer.green_ctx as green_ctx

CUDA_RES_ERR = (
    "CUDA error code=914",
    "CUDA_ERROR_INVALID_RESOURCE_TYPE",
    "CUDA error code=915",
    "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION",
)

def skip_if_insufficient_sms(device: str, err: Exception, extra: str) -> None:
    s = str(err)
    if any(sig in s for sig in CUDA_RES_ERR):
        cu_dev = green_ctx.get_cudevice(torch.device(device))
        total_sms = green_ctx.get_device_resource(cu_dev).sm.smCount
        pytest.skip(f"Insufficient SMs ({total_sms}). {extra}")
    raise err
```

Then in tests:

```python
try:
    ...
except RuntimeError as e:
    skip_if_insufficient_sms(
        device, e, f"requested: num_groups={num_groups}, min_count={min_count}"
    )
```

Also applies to: 64-77, 107-120, 151-164, 195-208
🧹 Nitpick comments (5)
tests/utils/test_green_ctx.py (2)
57-63: Remove prints from tests; assert instead.

print(...) adds noisy logs. Prefer simple assertions on shape to keep CI output clean.

```diff
-    print(z.shape)
+    assert z.shape == (8192, 8192)
```

```diff
-    print(f"Partition {i}: {z.shape}")
+    assert z.shape == (4096, 4096)
```

Optional: consider smaller matrices (e.g., 2048 or parametrize) to reduce CI time on small GPUs.
Also applies to: 144-150
180-194: Micro: avoid repeated device construction.

Compute dev = torch.device(device) once and reuse; minor readability and overhead win.

```diff
-    _, resources = green_ctx.split_device_green_ctx_by_sm_count(
-        torch.device(device), sm_counts
-    )
+    dev = torch.device(device)
+    _, resources = green_ctx.split_device_green_ctx_by_sm_count(dev, sm_counts)
     ...
-    min_sm_count, sm_alignment = green_ctx.get_sm_count_constraint(
-        *green_ctx.get_compute_capability(torch.device(device))
-    )
+    min_sm_count, sm_alignment = green_ctx.get_sm_count_constraint(
+        *green_ctx.get_compute_capability(dev)
+    )
```

flashinfer/green_ctx.py (3)
173-193: Style and lints: list-unpack concat; try/else; centralize error checks.

- Use list-unpack for concat (RUF005).
- Move return to the else block of try (TRY300).
- Optional: centralize error signature checks to a helper constant.

```diff
     try:
         cu_dev = get_cudevice(dev)
         resource = get_device_resource(cu_dev)
         results, remaining = split_resource(resource, num_groups, min_count)
-        resources = results + [remaining]
+        resources = [*results, remaining]
         streams = create_green_ctx_streams(cu_dev, resources)
-        return streams, resources
     except RuntimeError as e:
-        if (
-            "CUDA error code=914" in str(e)
-            or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e)
-            or "CUDA error code=915" in str(e)
-            or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e)
-        ):
+        if any(sig in str(e) for sig in (
+            "CUDA error code=914",
+            "CUDA_ERROR_INVALID_RESOURCE_TYPE",
+            "CUDA error code=915",
+            "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION",
+        )):
             raise RuntimeError(
                 f"{e}\n"
                 f"Failed to split device into {num_groups} groups with min_count={min_count}. "
                 f"This is likely due to insufficient number of SMs available on the device. "
                 f"Please reduce the number of groups or the minimum SM count per group."
             ) from e
         raise
+    else:
+        return streams, resources
```
259-295: Hoist constraints; empty-input check; style/lints parity with above.

- Compute (min_sm_count, sm_alignment) once per device.
- Validate sm_counts is non-empty (docstring promises ValueError).
- Apply list-unpack concat and try/else.

```diff
     try:
         cu_dev = get_cudevice(dev)
         resource = get_device_resource(cu_dev)
         # Round sm counts to meet the alignment and granularity requirements
-        rounded_sm_counts = []
-        for sm_count in sm_counts:
-            min_sm_count, sm_alignment = get_sm_count_constraint(
-                *get_compute_capability(dev)
-            )
-            if sm_count <= 0:
-                raise ValueError(f"SM count must be positive, got {sm_count}")
-            rounded_sm_counts.append(
-                round_up(max(sm_count, min_sm_count), sm_alignment)
-            )
+        if not sm_counts:
+            raise ValueError("sm_counts must be non-empty")
+        min_sm_count, sm_alignment = get_sm_count_constraint(
+            *get_compute_capability(dev)
+        )
+        rounded_sm_counts = []
+        for sm_count in sm_counts:
+            if sm_count <= 0:
+                raise ValueError(f"SM count must be positive, got {sm_count}")
+            rounded_sm_counts.append(round_up(max(sm_count, min_sm_count), sm_alignment))
         # Split the device into multiple green contexts
         results, remaining = split_resource_by_sm_count(
             cu_dev, resource, rounded_sm_counts
         )
-        resources = results + [remaining]
+        resources = [*results, remaining]
         streams = create_green_ctx_streams(cu_dev, resources)
-        return streams, resources
     except RuntimeError as e:
-        if (
-            "CUDA error code=914" in str(e)
-            or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e)
-            or "CUDA error code=915" in str(e)
-            or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e)
-        ):
+        if any(sig in str(e) for sig in (
+            "CUDA error code=914",
+            "CUDA_ERROR_INVALID_RESOURCE_TYPE",
+            "CUDA error code=915",
+            "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION",
+        )):
             raise RuntimeError(
                 f"{e}\n"
                 f"Failed to split device with SM counts {sm_counts} (rounded to {rounded_sm_counts}). "
                 f"This is likely due to insufficient number of SMs available on the device. "
                 f"Please reduce the requested SM counts or use fewer partitions."
             ) from e
         raise
+    else:
+        return streams, resources
```
187-193: Optional: avoid long message construction in except (TRY003). Consider defining a small custom exception (e.g., SMAllocationError) or assembling the message inside the exception class / __str__ to satisfy linters and keep handlers catching by type, not by substrings; a sketch follows below.

Also applies to: 289-294
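A minimal sketch of that option, using the reviewer's hypothetical name (SMAllocationError is not an existing class in the codebase):

```python
class SMAllocationError(RuntimeError):
    """Raised when a device cannot be split into the requested green contexts."""

    def __init__(self, num_groups: int, min_count: int):
        # assembling the message inside the class keeps call sites TRY003-clean
        super().__init__(
            f"Failed to split device into {num_groups} groups with "
            f"min_count={min_count}; likely insufficient SMs on the device."
        )

# callers can then catch by type instead of matching message substrings:
#     except SMAllocationError:
#         pytest.skip("insufficient SMs")
```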
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
flashinfer/green_ctx.py (2 hunks), tests/utils/test_green_ctx.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/utils/test_green_ctx.py (2)
- flashinfer/green_ctx.py (5): split_device_green_ctx (126-193), get_cudevice (47-53), get_device_resource (56-61), split_device_green_ctx_by_sm_count (196-295), get_sm_count_constraint (34-44)
- flashinfer/utils.py (1): get_compute_capability (251-254)

flashinfer/green_ctx.py (1)
- flashinfer/utils.py (2): get_compute_capability (251-254), round_up (589-591)
🪛 Ruff (0.14.2)
tests/utils/test_green_ctx.py
16-16: Unpacked variable streams is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
flashinfer/green_ctx.py
177-177: Consider [*results, remaining] instead of concatenation
Replace with [*results, remaining]
(RUF005)
179-179: Consider moving this statement to an else block
(TRY300)
187-192: Avoid specifying long messages outside the exception class
(TRY003)
270-270: Avoid specifying long messages outside the exception class
(TRY003)
279-279: Consider [*results, remaining] instead of concatenation
Replace with [*results, remaining]
(RUF005)
281-281: Consider moving this statement to an else block
(TRY300)
289-294: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
```python
try:
    streams, resources = green_ctx.split_device_green_ctx(
        torch.device(device), num_groups, min_count
    )
```
Fix unused variable per Ruff (RUF059).
streams is not used in this test. Use _ to silence the warning.
```diff
-    streams, resources = green_ctx.split_device_green_ctx(
+    _, resources = green_ctx.split_device_green_ctx(
         torch.device(device), num_groups, min_count
     )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
try:
    _, resources = green_ctx.split_device_green_ctx(
        torch.device(device), num_groups, min_count
    )
```
🧰 Tools
🪛 Ruff (0.14.2)
16-16: Unpacked variable streams is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
```text
In tests/utils/test_green_ctx.py around lines 15 to 18, the variable `streams`
from the tuple assignment is unused and triggers Ruff RUF059; change the
unpacking to use a throwaway name (e.g., `_, resources =
green_ctx.split_device_green_ctx(torch.device(device), num_groups, min_count)`)
so the test retains the same behavior while silencing the unused-variable
warning.
```
📌 Description
There are three failing unit tests on Spark (sm_121):

- The first fails because Spark has a small number of SMs (48) and we had no guard on green context splitting (see the arithmetic sketch below).
- The second is an unknown issue (logits don't match the reference), probably related to barriers on sm_121; it is marked xfail for now and will be fixed later.
- The third comes from a reduction-size difference, so we should increase the tolerance (by adding an rtol).

This PR fixes these issues.
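As a worked illustration of the first point, here is the guard arithmetic with assumed constraint values (the actual min_sm and alignment come from get_sm_count_constraint and may differ on sm_121):

```python
def round_up(x: int, align: int) -> int:
    return -(-x // align) * align  # ceil to a multiple of align

total_sms = 48            # Spark (sm_121)
min_sm, alignment = 8, 8  # assumed constraint values for illustration
for num_groups, min_count in [(2, 16), (3, 16), (4, 16)]:
    required = num_groups * round_up(max(min_count, min_sm), alignment)
    print(num_groups, required, "skip" if required > total_sms else "ok")
# -> 2 32 ok / 3 48 ok / 4 64 skip: 4 groups of 16 SMs exceed a 48-SM device
```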
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- All tests are passing (unittest, etc.).

Reviewer Notes