GCC Code Coverage Report


Directory: src/oiseau/
File: src/oiseau/plotting/pyplot.hpp
Date: 2025-05-24 01:28:39
Exec Total Coverage
Lines: 0 25 0.0%
Functions: 0 17 0.0%
Branches: 0 62 0.0%

Line Branch Exec Source
1 // Copyright (C) 2025 Tiago V. L. Amorim (@tiagovla)
2 //
3 // This file is part of oiseau (https://github.com/tiagovla/oiseau)
4 //
5 // SPDX-License-Identifier: GPL-3.0-or-later
6
#pragma once

#include <pybind11/cast.h>
#include <pybind11/detail/common.h>
#include <pybind11/detail/descr.h>
#include <pybind11/embed.h>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>

#include <memory>
#include <utility>
#include <vector>

#include <xtensor/containers/xarray.hpp>
#include <xtensor/core/xexpression.hpp>
#include <xtensor/core/xtensor_forward.hpp>
21
22 namespace py = pybind11;
23 using namespace py::literals;
24
25 namespace pybind11::detail {
26
27 template <typename T>
28 struct type_caster<xt::xarray<T>> {
29 public:
30 PYBIND11_TYPE_CASTER(xt::xarray<T>, const_name("xarray"));
31 static handle cast(xt::xarray<T> &src, return_value_policy, handle) {
32 auto v = py::array(src.size(), src.data(), py::capsule([]() {}));
33 v.resize(src.shape());
34 return v.release();
35 }
36 };
37
38 template <typename E>
39 requires xt::is_xexpression<E>::value
40 struct type_caster<E> {
41 using value_type = typename E::value_type;
42 using array_type = xt::xarray<value_type>;
43
44 PYBIND11_TYPE_CASTER(array_type, _("xexpression"));
45 static handle cast(const E &expr, return_value_policy, handle) {
46 xt::xarray<value_type> tmp = xt::eval(expr);
47 auto *data_holder = new xt::xarray<double>(std::move(tmp));
48 py::capsule owner_capsule(data_holder,
49 [](void *p) { delete static_cast<xt::xarray<value_type> *>(p); });
50 py::array result(data_holder->shape(), data_holder->data(), owner_capsule);
51 return result.release();
52 }
53 };
54
55 } // namespace pybind11::detail
56
57 namespace plt {
58
59 using scoped_interpreter = py::scoped_interpreter;
60
61 inline auto plt() { return py::module_::import("matplotlib.pyplot"); }
62
63 #define DEFINE_PYPLOT_FUNC(name) \
64 template <typename... Args> \
65 auto name(Args &&...args) { \
66 return plt().attr(#name)(std::forward<Args>(args)...); \
67 }
68
69 #define DEFINE_PYPLOT_FUNC_WRAPPED(WrapperClass, name) \
70 template <typename... Args> \
71 WrapperClass name(Args &&...args) { \
72 return WrapperClass(plt().attr(#name)(std::forward<Args>(args)...)); \
73 }
74
75 #define DEFINE_PY_CLASS_METHOD_AUTO(name) \
76 template <typename... Args> \
77 auto name(Args &&...args) { \
78 return m_obj.attr(#name)(std::forward<Args>(args)...); \
79 }
80
81 #define DEFINE_PY_CLASS_METHOD_WRAPPED(WrapperClass, name) \
82 template <typename... Args> \
83 WrapperClass name(Args &&...args) { \
84 return WrapperClass(m_obj.attr(#name)(std::forward<Args>(args)...)); \
85 }
86
87 // --- Classes using the simplified macros ---
88
89 class PYBIND11_EXPORT AxesSubPlot {
90 public:
91 explicit AxesSubPlot(py::object axes) : m_obj(std::move(axes)) {}
92
93 DEFINE_PY_CLASS_METHOD_AUTO(scatter)
94 DEFINE_PY_CLASS_METHOD_AUTO(triplot)
95 DEFINE_PY_CLASS_METHOD_AUTO(plot)
96 DEFINE_PY_CLASS_METHOD_AUTO(text)
97 DEFINE_PY_CLASS_METHOD_AUTO(set_xlabel)
98 DEFINE_PY_CLASS_METHOD_AUTO(set_ylabel)
99 DEFINE_PY_CLASS_METHOD_AUTO(set_zlabel)
100 DEFINE_PY_CLASS_METHOD_AUTO(set_title)
101 DEFINE_PY_CLASS_METHOD_AUTO(grid)
102 DEFINE_PY_CLASS_METHOD_AUTO(legend)
103 DEFINE_PY_CLASS_METHOD_AUTO(set_xlim)
104 DEFINE_PY_CLASS_METHOD_AUTO(set_ylim)
105 DEFINE_PY_CLASS_METHOD_AUTO(set_zlim)
106
107 private:
108 py::object m_obj;
109 };
110
111 class PYBIND11_EXPORT Figure {
112 public:
113 explicit Figure(py::object fig) : m_obj(std::move(fig)) {}
114
115 DEFINE_PY_CLASS_METHOD_WRAPPED(AxesSubPlot, add_subplot)
116 DEFINE_PY_CLASS_METHOD_AUTO(savefig)
117 DEFINE_PY_CLASS_METHOD_AUTO(show)
118
119 private:
120 py::object m_obj;
121 };
122
123 // --- Free functions using the simplified macros ---
124
125 DEFINE_PYPLOT_FUNC(scatter)
126 DEFINE_PYPLOT_FUNC(show)
127 DEFINE_PYPLOT_FUNC(plot)
128 DEFINE_PYPLOT_FUNC(xlabel)
129 DEFINE_PYPLOT_FUNC(ylabel)
130 DEFINE_PYPLOT_FUNC(title)
131 DEFINE_PYPLOT_FUNC(grid)
132 DEFINE_PYPLOT_FUNC(legend)
133 DEFINE_PYPLOT_FUNC(savefig)
134
135 DEFINE_PYPLOT_FUNC_WRAPPED(Figure, figure)
136
137 template <typename... Args>
138 std::pair<Figure, AxesSubPlot> subplots(Args &&...args) {
139 py::object result = plt().attr("subplots")(std::forward<Args>(args)...);
140 auto res = result.cast<py::tuple>();
141 return {Figure(res[0]), AxesSubPlot(res[1])};
142 }
143
144 // --- Undefine all macros ---
145
146 #undef DEFINE_PYPLOT_FUNC
147 #undef DEFINE_PYPLOT_FUNC_WRAPPED
148 #undef DEFINE_PY_CLASS_METHOD_AUTO
149 #undef DEFINE_PY_CLASS_METHOD_WRAPPED
150
151 } // namespace plt
152