Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cuda_bindings/cuda/bindings/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
from typing import Any, Callable

from ._nvvm_utils import check_nvvm_options
from ._ptx_utils import get_minimal_required_cuda_ver_from_ptx_ver, get_ptx_ver
from ._version_check import warn_if_cuda_major_version_mismatch

Expand Down
93 changes: 93 additions & 0 deletions cuda_bindings/cuda/bindings/utils/_nvvm_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# SPDX-FileCopyrightText: Copyright (c) 2026-2027 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove -2027

# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import contextlib
from typing import Sequence

# Template for a minimal NVVM IR module containing a single empty kernel
# (@dummy_kernel); check_nvvm_options compiles it as the test program.
# The text is filled in with str.format(), so literal braces in the IR are
# doubled ({{ }}) and only {major}, {minor}, {debug_major} and {debug_minor}
# are substituted (with the values reported by nvvm.ir_version()).
_PRECHECK_NVVM_IR = """target triple = "nvptx64-unknown-cuda"
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"

define void @dummy_kernel() {{
entry:
ret void
}}

!nvvm.annotations = !{{!0}}
!0 = !{{void ()* @dummy_kernel, !"kernel", i32 1}}

!nvvmir.version = !{{!1}}
!1 = !{{i32 {major}, i32 {minor}, i32 {debug_major}, i32 {debug_minor}}}
"""


def check_nvvm_options(options: Sequence[bytes]) -> bool:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For a public Python API, bytes seems unusual here.

Most pythonic would be Sequence[str], but Sequence[str | bytes] would seem fine, too.

The implementation below actually converts bytes to str:

        options_list = [opt.decode("utf-8") if isinstance(opt, bytes) else opt for opt in options]

So requiring bytes in the API is especially surprising. My recommendation is to simply use Sequence[str]; then the options_list line can be removed completely.

I also recommend using check_nvvm_compiler_options as the function name, for clarity, especially in the call sites.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I will update this. I initially thought this should be Sequence[bytes | str].

"""
Abstracted from https://github.com/NVIDIA/numba-cuda/pull/681

Check if the specified options are supported by the current libNVVM version.

The options are a list of bytes, each representing a compiler option.

If the test program fails to compile, the options are not supported and False
is returned.

If the test program compiles successfully, True is returned.

cuda.bindings.nvvm raises exceptions instead of returning error codes.

Parameters
----------
options : Sequence[bytes]
List of compiler options as bytes (e.g., [b"-arch=compute_90", b"-g"]).

Returns
-------
bool
True if the options are supported, False otherwise.

Examples
--------
>>> from cuda.bindings.utils import check_nvvm_options
>>> check_nvvm_options([b"-arch=compute_90", b"-g"])
True
>>> check_nvvm_options([b"-arch=compute_90", b"-numba-debug"])
True # if -numba-debug is supported by the installed libNVVM
Comment on lines +53 to +54
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use a non-numba example here?

"""
try:
from cuda.bindings import nvvm
from cuda.bindings._internal.nvvm import _inspect_function_pointer

if _inspect_function_pointer("__nvvmCreateProgram") == 0:
return False
except Exception:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is far too broad; it is very likely to mask bugs. The same applies to the other except Exception further below. Could you please try this in Cursor:

Could you please narrow down the exception catching as much as possible? Please use ModuleNotFoundError for the nvvm import; we also want to check exc.name == "nvvm". If that import works, we don't want to guard the _inspect_function_pointer import, or the _inspect_function_pointer() call. If those don't work, that's a bug we want to surface.

return False

program = None
try:
program = nvvm.create_program()

major, minor, debug_major, debug_minor = nvvm.ir_version()
precheck_ir = _PRECHECK_NVVM_IR.format(
major=major,
minor=minor,
debug_major=debug_major,
debug_minor=debug_minor,
)
precheck_ir_bytes = precheck_ir.encode("utf-8")
nvvm.add_module_to_program(
program,
precheck_ir_bytes,
len(precheck_ir_bytes),
"precheck.ll",
)

options_list = [opt.decode("utf-8") if isinstance(opt, bytes) else opt for opt in options]
nvvm.verify_program(program, len(options_list), options_list)
nvvm.compile_program(program, len(options_list), options_list)
except Exception:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're compiling anyway, do you actually need the verify_program() call?

I'm thinking the try should just be around this one line:

        try:
            nvvm.compile_program(prog, len(options), options)
        except nvvm.nvvmError as e:
            # can we add something here to ensure we're not masking errors other than invalid options?

I believe it's really important to take great care that we're not masking actual errors; e.g. the hard-wired _PRECHECK_NVVM_IR might need tweaks for future GPU generations. If we're simply reporting any error as "invalid compiler option", it'll potentially take someone downstream a long time to drill down all the way back here.

return False
finally:
if program is not None:
with contextlib.suppress(Exception):
nvvm.destroy_program(program)
return True
89 changes: 11 additions & 78 deletions cuda_core/tests/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,81 +70,14 @@ def _has_nvrtc_pch_apis_for_tests():
)


_libnvvm_version = None
_libnvvm_version_attempted = False

precheck_nvvm_ir = """target triple = "nvptx64-unknown-cuda"
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"

define void @dummy_kernel() {{
entry:
ret void
}}

!nvvm.annotations = !{{!0}}
!0 = !{{void ()* @dummy_kernel, !"kernel", i32 1}}

!nvvmir.version = !{{!1}}
!1 = !{{i32 {major}, i32 {minor}, i32 {debug_major}, i32 {debug_minor}}}
"""


def _get_libnvvm_version_for_tests():
"""
Detect libNVVM version by compiling dummy IR and analyzing the PTX output.

Workaround for the lack of direct libNVVM version API (nvbugs 5312315).
The approach:
- Compile a small dummy NVVM IR to PTX
- Use PTX version analysis APIs if available to infer libNVVM version
- Cache the result for future use
"""
global _libnvvm_version, _libnvvm_version_attempted

if _libnvvm_version_attempted:
return _libnvvm_version

_libnvvm_version_attempted = True

def _check_nvvm_arch(arch: str) -> bool:
"""Check if the given NVVM arch is supported by the installed libNVVM."""
try:
from cuda.core._program import _get_nvvm_module

nvvm = _get_nvvm_module()

try:
from cuda.bindings.utils import get_minimal_required_cuda_ver_from_ptx_ver, get_ptx_ver
except ImportError:
_libnvvm_version = None
return _libnvvm_version

program = nvvm.create_program()
try:
major, minor, debug_major, debug_minor = nvvm.ir_version()
global precheck_nvvm_ir
precheck_nvvm_ir = precheck_nvvm_ir.format(
major=major, minor=minor, debug_major=debug_major, debug_minor=debug_minor
)
precheck_ir_bytes = precheck_nvvm_ir.encode("utf-8")
nvvm.add_module_to_program(program, precheck_ir_bytes, len(precheck_ir_bytes), "precheck.ll")

options = ["-arch=compute_90"]
nvvm.verify_program(program, len(options), options)
nvvm.compile_program(program, len(options), options)

ptx_size = nvvm.get_compiled_result_size(program)
ptx_data = bytearray(ptx_size)
nvvm.get_compiled_result(program, ptx_data)
ptx_str = ptx_data.decode("utf-8")
ptx_version = get_ptx_ver(ptx_str)
cuda_version = get_minimal_required_cuda_ver_from_ptx_ver(ptx_version)
_libnvvm_version = cuda_version
return _libnvvm_version
finally:
nvvm.destroy_program(program)
from cuda.bindings.utils import check_nvvm_options

return check_nvvm_options([f"-arch={arch}".encode()])
except Exception:
Copy link
Copy Markdown
Collaborator

@rwgk rwgk Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's only a test, therefore not nearly as critical as in the production code, but we may miss regressions if this isn't handled with similar care.

_libnvvm_version = None
return _libnvvm_version
return False


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -525,8 +458,8 @@ def test_nvvm_compile_invalid_ir():
pytest.param(
ProgramOptions(name="test_sm110_1", arch="sm_110", device_code_optimize=False),
marks=pytest.mark.skipif(
(_get_libnvvm_version_for_tests() or 0) < 13000,
reason="Compute capability 110 requires libNVVM >= 13.0",
not _check_nvvm_arch("compute_110"),
reason="Compute capability 110 not supported by installed libNVVM",
),
),
pytest.param(
Expand All @@ -540,15 +473,15 @@ def test_nvvm_compile_invalid_ir():
device_code_optimize=True,
),
marks=pytest.mark.skipif(
(_get_libnvvm_version_for_tests() or 0) < 13000,
reason="Compute capability 110 requires libNVVM >= 13.0",
not _check_nvvm_arch("compute_110"),
reason="Compute capability 110 not supported by installed libNVVM",
),
),
pytest.param(
ProgramOptions(name="test_sm110_3", arch="sm_110", link_time_optimization=True),
marks=pytest.mark.skipif(
(_get_libnvvm_version_for_tests() or 0) < 13000,
reason="Compute capability 110 requires libNVVM >= 13.0",
not _check_nvvm_arch("compute_110"),
reason="Compute capability 110 not supported by installed libNVVM",
),
),
],
Expand Down
Loading