| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // Copyright (C) 2025 Tiago V. L. Amorim (@tiagovla) | ||
| 2 | // | ||
| 3 | // This file is part of oiseau (https://github.com/tiagovla/oiseau) | ||
| 4 | // | ||
| 5 | // SPDX-License-Identifier: GPL-3.0-or-later | ||
| 6 | |||
| 7 | #include <pybind11/cast.h> | ||
| 8 | #include <pybind11/detail/common.h> | ||
| 9 | #include <pybind11/detail/descr.h> | ||
| 10 | #include <pybind11/embed.h> | ||
| 11 | #include <pybind11/functional.h> | ||
| 12 | #include <pybind11/numpy.h> | ||
| 13 | #include <pybind11/pybind11.h> | ||
| 14 | #include <pybind11/pytypes.h> | ||
| 15 | #include <pybind11/stl.h> | ||
| 16 | |||
| 17 | #include <utility> | ||
| 18 | #include <xtensor/containers/xarray.hpp> | ||
| 19 | #include <xtensor/core/xexpression.hpp> | ||
| 20 | #include <xtensor/core/xtensor_forward.hpp> | ||
| 21 | |||
| 22 | namespace py = pybind11; | ||
| 23 | using namespace py::literals; | ||
| 24 | |||
| 25 | namespace pybind11::detail { | ||
| 26 | |||
| 27 | template <typename T> | ||
| 28 | struct type_caster<xt::xarray<T>> { | ||
| 29 | public: | ||
| 30 | PYBIND11_TYPE_CASTER(xt::xarray<T>, const_name("xarray")); | ||
| 31 | ✗ | static handle cast(xt::xarray<T> &src, return_value_policy, handle) { | |
| 32 | ✗ | auto v = py::array(src.size(), src.data(), py::capsule([]() {})); | |
| 33 | ✗ | v.resize(src.shape()); | |
| 34 | ✗ | return v.release(); | |
| 35 | ✗ | } | |
| 36 | }; | ||
| 37 | |||
| 38 | template <typename E> | ||
| 39 | requires xt::is_xexpression<E>::value | ||
| 40 | struct type_caster<E> { | ||
| 41 | using value_type = typename E::value_type; | ||
| 42 | using array_type = xt::xarray<value_type>; | ||
| 43 | |||
| 44 | PYBIND11_TYPE_CASTER(array_type, _("xexpression")); | ||
| 45 | ✗ | static handle cast(const E &expr, return_value_policy, handle) { | |
| 46 | ✗ | xt::xarray<value_type> tmp = xt::eval(expr); | |
| 47 | ✗ | auto *data_holder = new xt::xarray<double>(std::move(tmp)); | |
| 48 | ✗ | py::capsule owner_capsule(data_holder, | |
| 49 | ✗ | [](void *p) { delete static_cast<xt::xarray<value_type> *>(p); }); | |
| 50 | ✗ | py::array result(data_holder->shape(), data_holder->data(), owner_capsule); | |
| 51 | ✗ | return result.release(); | |
| 52 | ✗ | } | |
| 53 | }; | ||
| 54 | |||
| 55 | } // namespace pybind11::detail | ||
| 56 | |||
| 57 | namespace plt { | ||
| 58 | |||
| 59 | using scoped_interpreter = py::scoped_interpreter; | ||
| 60 | |||
| 61 | ✗ | inline auto plt() { return py::module_::import("matplotlib.pyplot"); } | |
| 62 | |||
| 63 | #define DEFINE_PYPLOT_FUNC(name) \ | ||
| 64 | template <typename... Args> \ | ||
| 65 | auto name(Args &&...args) { \ | ||
| 66 | return plt().attr(#name)(std::forward<Args>(args)...); \ | ||
| 67 | } | ||
| 68 | |||
| 69 | #define DEFINE_PYPLOT_FUNC_WRAPPED(WrapperClass, name) \ | ||
| 70 | template <typename... Args> \ | ||
| 71 | WrapperClass name(Args &&...args) { \ | ||
| 72 | return WrapperClass(plt().attr(#name)(std::forward<Args>(args)...)); \ | ||
| 73 | } | ||
| 74 | |||
| 75 | #define DEFINE_PY_CLASS_METHOD_AUTO(name) \ | ||
| 76 | template <typename... Args> \ | ||
| 77 | auto name(Args &&...args) { \ | ||
| 78 | return m_obj.attr(#name)(std::forward<Args>(args)...); \ | ||
| 79 | } | ||
| 80 | |||
| 81 | #define DEFINE_PY_CLASS_METHOD_WRAPPED(WrapperClass, name) \ | ||
| 82 | template <typename... Args> \ | ||
| 83 | WrapperClass name(Args &&...args) { \ | ||
| 84 | return WrapperClass(m_obj.attr(#name)(std::forward<Args>(args)...)); \ | ||
| 85 | } | ||
| 86 | |||
| 87 | // --- Classes using the simplified macros --- | ||
| 88 | |||
| 89 | class PYBIND11_EXPORT AxesSubPlot { | ||
| 90 | public: | ||
| 91 | ✗ | explicit AxesSubPlot(py::object axes) : m_obj(std::move(axes)) {} | |
| 92 | |||
| 93 | ✗ | DEFINE_PY_CLASS_METHOD_AUTO(scatter) | |
| 94 | ✗ | DEFINE_PY_CLASS_METHOD_AUTO(triplot) | |
| 95 | ✗ | DEFINE_PY_CLASS_METHOD_AUTO(plot) | |
| 96 | DEFINE_PY_CLASS_METHOD_AUTO(text) | ||
| 97 | DEFINE_PY_CLASS_METHOD_AUTO(set_xlabel) | ||
| 98 | DEFINE_PY_CLASS_METHOD_AUTO(set_ylabel) | ||
| 99 | DEFINE_PY_CLASS_METHOD_AUTO(set_zlabel) | ||
| 100 | DEFINE_PY_CLASS_METHOD_AUTO(set_title) | ||
| 101 | DEFINE_PY_CLASS_METHOD_AUTO(grid) | ||
| 102 | DEFINE_PY_CLASS_METHOD_AUTO(legend) | ||
| 103 | DEFINE_PY_CLASS_METHOD_AUTO(set_xlim) | ||
| 104 | DEFINE_PY_CLASS_METHOD_AUTO(set_ylim) | ||
| 105 | DEFINE_PY_CLASS_METHOD_AUTO(set_zlim) | ||
| 106 | |||
| 107 | private: | ||
| 108 | py::object m_obj; | ||
| 109 | }; | ||
| 110 | |||
| 111 | class PYBIND11_EXPORT Figure { | ||
| 112 | public: | ||
| 113 | ✗ | explicit Figure(py::object fig) : m_obj(std::move(fig)) {} | |
| 114 | |||
| 115 | DEFINE_PY_CLASS_METHOD_WRAPPED(AxesSubPlot, add_subplot) | ||
| 116 | DEFINE_PY_CLASS_METHOD_AUTO(savefig) | ||
| 117 | DEFINE_PY_CLASS_METHOD_AUTO(show) | ||
| 118 | |||
| 119 | private: | ||
| 120 | py::object m_obj; | ||
| 121 | }; | ||
| 122 | |||
| 123 | // --- Free functions using the simplified macros --- | ||
| 124 | |||
| 125 | DEFINE_PYPLOT_FUNC(scatter) | ||
| 126 | ✗ | DEFINE_PYPLOT_FUNC(show) | |
| 127 | DEFINE_PYPLOT_FUNC(plot) | ||
| 128 | DEFINE_PYPLOT_FUNC(xlabel) | ||
| 129 | DEFINE_PYPLOT_FUNC(ylabel) | ||
| 130 | DEFINE_PYPLOT_FUNC(title) | ||
| 131 | DEFINE_PYPLOT_FUNC(grid) | ||
| 132 | DEFINE_PYPLOT_FUNC(legend) | ||
| 133 | DEFINE_PYPLOT_FUNC(savefig) | ||
| 134 | |||
| 135 | DEFINE_PYPLOT_FUNC_WRAPPED(Figure, figure) | ||
| 136 | |||
| 137 | template <typename... Args> | ||
| 138 | ✗ | std::pair<Figure, AxesSubPlot> subplots(Args &&...args) { | |
| 139 | ✗ | py::object result = plt().attr("subplots")(std::forward<Args>(args)...); | |
| 140 | ✗ | auto res = result.cast<py::tuple>(); | |
| 141 | ✗ | return {Figure(res[0]), AxesSubPlot(res[1])}; | |
| 142 | ✗ | } | |
| 143 | |||
| 144 | // --- Undefine all macros --- | ||
| 145 | |||
| 146 | #undef DEFINE_PYPLOT_FUNC | ||
| 147 | #undef DEFINE_PYPLOT_FUNC_WRAPPED | ||
| 148 | #undef DEFINE_PY_CLASS_METHOD_AUTO | ||
| 149 | #undef DEFINE_PY_CLASS_METHOD_WRAPPED | ||
| 150 | |||
| 151 | } // namespace plt | ||
| 152 |