OISEAU
A modern DGTD framework
Loading...
Searching...
No Matches
pyplot.hpp
1// Copyright (C) 2025 Tiago V. L. Amorim (@tiagovla)
2//
3// This file is part of oiseau (https://github.com/tiagovla/oiseau)
4//
5// SPDX-License-Identifier: GPL-3.0-or-later
6
#include <pybind11/cast.h>
#include <pybind11/detail/common.h>
#include <pybind11/detail/descr.h>
#include <pybind11/embed.h>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>

#include <memory>
#include <utility>

#include <xtensor/containers/xarray.hpp>
#include <xtensor/core/xexpression.hpp>
#include <xtensor/core/xtensor_forward.hpp>
21
22namespace py = pybind11;
23using namespace py::literals;
24
25namespace pybind11::detail {
26
27template <typename T>
28struct type_caster<xt::xarray<T>> {
29 public:
30 PYBIND11_TYPE_CASTER(xt::xarray<T>, const_name("xarray"));
31 static handle cast(xt::xarray<T> &src, return_value_policy, handle) {
32 auto v = py::array(src.size(), src.data(), py::capsule([]() {}));
33 v.resize(src.shape());
34 return v.release();
35 }
36};
37
38template <typename E>
39 requires xt::is_xexpression<E>::value
40struct type_caster<E> {
41 using value_type = typename E::value_type;
42 using array_type = xt::xarray<value_type>;
43
44 PYBIND11_TYPE_CASTER(array_type, _("xexpression"));
45 static handle cast(const E &expr, return_value_policy, handle) {
46 xt::xarray<value_type> tmp = xt::eval(expr);
47 auto *data_holder = new xt::xarray<double>(std::move(tmp));
48 py::capsule owner_capsule(data_holder,
49 [](void *p) { delete static_cast<xt::xarray<value_type> *>(p); });
50 py::array result(data_holder->shape(), data_holder->data(), owner_capsule);
51 return result.release();
52 }
53};
54
55} // namespace pybind11::detail
56
57namespace plt {
58
59using scoped_interpreter = py::scoped_interpreter;
60
61inline auto plt() { return py::module_::import("matplotlib.pyplot"); }
62
63#define DEFINE_PYPLOT_FUNC(name) \
64 template <typename... Args> \
65 auto name(Args &&...args) { \
66 return plt().attr(#name)(std::forward<Args>(args)...); \
67 }
68
69#define DEFINE_PYPLOT_FUNC_WRAPPED(WrapperClass, name) \
70 template <typename... Args> \
71 WrapperClass name(Args &&...args) { \
72 return WrapperClass(plt().attr(#name)(std::forward<Args>(args)...)); \
73 }
74
75#define DEFINE_PY_CLASS_METHOD_AUTO(name) \
76 template <typename... Args> \
77 auto name(Args &&...args) { \
78 return m_obj.attr(#name)(std::forward<Args>(args)...); \
79 }
80
81#define DEFINE_PY_CLASS_METHOD_WRAPPED(WrapperClass, name) \
82 template <typename... Args> \
83 WrapperClass name(Args &&...args) { \
84 return WrapperClass(m_obj.attr(#name)(std::forward<Args>(args)...)); \
85 }
86
87// --- Classes using the simplified macros ---
88
89class PYBIND11_EXPORT AxesSubPlot {
90 public:
91 explicit AxesSubPlot(py::object axes) : m_obj(std::move(axes)) {}
92
93 DEFINE_PY_CLASS_METHOD_AUTO(scatter)
94 DEFINE_PY_CLASS_METHOD_AUTO(triplot)
95 DEFINE_PY_CLASS_METHOD_AUTO(plot)
96 DEFINE_PY_CLASS_METHOD_AUTO(text)
97 DEFINE_PY_CLASS_METHOD_AUTO(set_xlabel)
98 DEFINE_PY_CLASS_METHOD_AUTO(set_ylabel)
99 DEFINE_PY_CLASS_METHOD_AUTO(set_zlabel)
100 DEFINE_PY_CLASS_METHOD_AUTO(set_title)
101 DEFINE_PY_CLASS_METHOD_AUTO(grid)
102 DEFINE_PY_CLASS_METHOD_AUTO(legend)
103 DEFINE_PY_CLASS_METHOD_AUTO(set_xlim)
104 DEFINE_PY_CLASS_METHOD_AUTO(set_ylim)
105 DEFINE_PY_CLASS_METHOD_AUTO(set_zlim)
106
107 private:
108 py::object m_obj;
109};
110
111class PYBIND11_EXPORT Figure {
112 public:
113 explicit Figure(py::object fig) : m_obj(std::move(fig)) {}
114
115 DEFINE_PY_CLASS_METHOD_WRAPPED(AxesSubPlot, add_subplot)
116 DEFINE_PY_CLASS_METHOD_AUTO(savefig)
117 DEFINE_PY_CLASS_METHOD_AUTO(show)
118
119 private:
120 py::object m_obj;
121};
122
123// --- Free functions using the simplified macros ---
124
125DEFINE_PYPLOT_FUNC(scatter)
126DEFINE_PYPLOT_FUNC(show)
127DEFINE_PYPLOT_FUNC(plot)
128DEFINE_PYPLOT_FUNC(xlabel)
129DEFINE_PYPLOT_FUNC(ylabel)
130DEFINE_PYPLOT_FUNC(title)
131DEFINE_PYPLOT_FUNC(grid)
132DEFINE_PYPLOT_FUNC(legend)
133DEFINE_PYPLOT_FUNC(savefig)
134
135DEFINE_PYPLOT_FUNC_WRAPPED(Figure, figure)
136
137template <typename... Args>
138std::pair<Figure, AxesSubPlot> subplots(Args &&...args) {
139 py::object result = plt().attr("subplots")(std::forward<Args>(args)...);
140 auto res = result.cast<py::tuple>();
141 return {Figure(res[0]), AxesSubPlot(res[1])};
142}
143
144// --- Undefine all macros ---
145
146#undef DEFINE_PYPLOT_FUNC
147#undef DEFINE_PYPLOT_FUNC_WRAPPED
148#undef DEFINE_PY_CLASS_METHOD_AUTO
149#undef DEFINE_PY_CLASS_METHOD_WRAPPED
150
151} // namespace plt
Definition pyplot.hpp:89
Definition pyplot.hpp:111