feat: expose DLPack and numpy __array__ protocols on OrtValue#27836
Open
Rishi-Dave wants to merge 1 commit intomicrosoft:mainfrom
Open
feat: expose DLPack and numpy __array__ protocols on OrtValue#27836Rishi-Dave wants to merge 1 commit intomicrosoft:mainfrom
Rishi-Dave wants to merge 1 commit intomicrosoft:mainfrom
Conversation
Add __dlpack__, __dlpack_device__, from_dlpack(), and __array__() to the Python OrtValue wrapper class so that OrtValues interoperate natively with DLPack-aware frameworks (PyTorch, JAX, CuPy) and with numpy's array protocol (np.array/np.asarray). The C++ pybind layer already implements these methods on C.OrtValue; this change surfaces them on the user-facing onnxruntime.OrtValue class with appropriate build guards and documentation.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
__dlpack__,__dlpack_device__,from_dlpack(), and__array__()to the PythonOrtValuewrapper classnp.array(ort_value)andnp.asarray(ort_value)via the numpy array protocolMotivation
Fixes #24071
The C++ pybind layer (
C.OrtValue) already implements__dlpack__,__dlpack_device__, andfrom_dlpack, but these aren't exposed on the user-facingonnxruntime.OrtValueclass. Users currently have to reach throughort_value._ortvalueto access DLPack functionality. Similarly, there's no__array__protocol, sonp.asarray(ort_value)doesn't work.This change surfaces these protocols on the public API so that:
Changes
onnxruntime/python/onnxruntime_inference_collection.py: Added four methods toOrtValue:__dlpack__(stream=None)— delegates toC.OrtValue.__dlpack____dlpack_device__()— delegates toC.OrtValue.__dlpack_device__from_dlpack(source)— classmethod that callssource.__dlpack__(), auto-detects bool tensors viais_dlpack_uint8_tensor, and creates anOrtValue__array__(dtype=None, copy=None)— delegates tonumpy()with optional dtype castingonnxruntime/test/python/onnxruntime_test_python.py: Added 5 tests:test_ort_value_wrapper_dlpack— round-trip via__dlpack__/__dlpack_device__/from_dlpacktest_ort_value_from_dlpack_numpy—OrtValue.from_dlpack(numpy_arr)test_ort_value_from_dlpack_dtypes— dtype preservation across float32/float64/int32/int64test_ort_value_array_protocol—np.array()/np.asarray()with dtype castingtest_ort_value_array_protocol_int— integer tensor via__array__Test Plan
test_ort_value_dlpackandtest_ort_value_dlpack_zero_sizetests remain unaffectedruff checkandruff formatpass clean