Skip to content

feat: expose DLPack and numpy __array__ protocols on OrtValue#27836

Open
Rishi-Dave wants to merge 1 commit intomicrosoft:mainfrom
Rishi-Dave:rishidave/feat/ortvalue-dlpack-array-protocols
Open

feat: expose DLPack and numpy __array__ protocols on OrtValue#27836
Rishi-Dave wants to merge 1 commit intomicrosoft:mainfrom
Rishi-Dave:rishidave/feat/ortvalue-dlpack-array-protocols

Conversation

@Rishi-Dave
Copy link
Contributor

Summary

  • Add __dlpack__, __dlpack_device__, from_dlpack(), and __array__() to the Python OrtValue wrapper class
  • Enable zero-copy tensor sharing between ONNX Runtime and DLPack-compatible frameworks (PyTorch, JAX, CuPy)
  • Allow np.array(ort_value) and np.asarray(ort_value) via the numpy array protocol

Motivation

Fixes #24071

The C++ pybind layer (C.OrtValue) already implements __dlpack__, __dlpack_device__, and from_dlpack, but these aren't exposed on the user-facing onnxruntime.OrtValue class. Users currently have to reach through ort_value._ortvalue to access DLPack functionality. Similarly, there's no __array__ protocol, so np.asarray(ort_value) doesn't work.

This change surfaces these protocols on the public API so that:

# DLPack: zero-copy sharing with PyTorch
import torch
torch_tensor = torch.from_dlpack(ort_value)
ort_value = OrtValue.from_dlpack(torch_tensor)

# numpy: transparent conversion
arr = np.asarray(ort_value)

Changes

  • onnxruntime/python/onnxruntime_inference_collection.py: Added four methods to OrtValue:

    • __dlpack__(stream=None) — delegates to C.OrtValue.__dlpack__
    • __dlpack_device__() — delegates to C.OrtValue.__dlpack_device__
    • from_dlpack(source) — classmethod that calls source.__dlpack__(), auto-detects bool tensors via is_dlpack_uint8_tensor, and creates an OrtValue
    • __array__(dtype=None, copy=None) — delegates to numpy() with optional dtype casting
    • All DLPack methods guard against builds where DLPack is disabled
  • onnxruntime/test/python/onnxruntime_test_python.py: Added 5 tests:

    • test_ort_value_wrapper_dlpack — round-trip via __dlpack__/__dlpack_device__/from_dlpack
    • test_ort_value_from_dlpack_numpyOrtValue.from_dlpack(numpy_arr)
    • test_ort_value_from_dlpack_dtypes — dtype preservation across float32/float64/int32/int64
    • test_ort_value_array_protocolnp.array()/np.asarray() with dtype casting
    • test_ort_value_array_protocol_int — integer tensor via __array__

Test Plan

  • All 5 new tests pass locally
  • Existing test_ort_value_dlpack and test_ort_value_dlpack_zero_size tests remain unaffected
  • No C++ changes; pure Python addition to the wrapper layer
  • ruff check and ruff format pass clean

Add __dlpack__, __dlpack_device__, from_dlpack(), and __array__() to
the Python OrtValue wrapper class so that OrtValues interoperate
natively with DLPack-aware frameworks (PyTorch, JAX, CuPy) and with
numpy's array protocol (np.array/np.asarray).

The C++ pybind layer already implements these methods on C.OrtValue;
this change surfaces them on the user-facing onnxruntime.OrtValue class
with appropriate build guards and documentation.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature Request] Make OrtValue compatible with numpy __array__ and dlpack protocols

1 participant