Skip to content

Commit 69c55ff

Browse files
Kick out py_shared_ptr and instead keep the python instance alive manually by holding a reference to it
See pybind/pybind11#1389 for why py_shared_ptr was needed in the first place, and the comment from May 27 why we may not want to use it (reference cycle)
1 parent 1e412b8 commit 69c55ff

7 files changed

+19
-54
lines changed

src/simsopt/field/magneticfieldclasses.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,9 +365,10 @@ class InterpolatedField(sopp.InterpolatedField, MagneticField):
365365
This resulting interpolant can then be evaluated very quickly.
366366
"""
367367

368-
def __init__(self, *args):
368+
def __init__(self, underlying_field, *args):
369+
self.__underlying_field = underlying_field
369370
MagneticField.__init__(self)
370-
sopp.InterpolatedField.__init__(self, *args)
371+
sopp.InterpolatedField.__init__(self, underlying_field, *args)
371372

372373
def to_vtk(self, filename, h=0.1):
373374
"""Export the field evaluated on a regular grid for visualisation with e.g. Paraview."""

src/simsoptpp/py_shared_ptr.h

Lines changed: 0 additions & 30 deletions
This file was deleted.

src/simsoptpp/python.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
#include "pybind11/pybind11.h"
22
#include "pybind11/stl.h"
33
#include "pybind11/functional.h"
4-
#include "py_shared_ptr.h"
5-
PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr<T>);
64
#define FORCE_IMPORT_ARRAY
75
#include "xtensor-python/pyarray.hpp" // Numpy bindings
86
typedef xt::pyarray<double> PyArray;

src/simsoptpp/python_curves.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
#include "pybind11/pybind11.h"
22
#include "pybind11/stl.h"
3+
namespace py = pybind11;
34
#include "xtensor-python/pyarray.hpp" // Numpy bindings
45
typedef xt::pyarray<double> PyArray;
5-
#include "py_shared_ptr.h"
6-
PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr<T>);
76
using std::shared_ptr;
87

98

@@ -89,16 +88,16 @@ template <typename T, typename S> void register_common_curve_methods(S &c) {
8988
}
9089

9190
void init_curves(py::module_ &m) {
92-
auto pycurve = py::class_<PyCurve, py_shared_ptr<PyCurve>, PyCurveTrampoline<PyCurve>>(m, "Curve")
91+
auto pycurve = py::class_<PyCurve, PyCurveTrampoline<PyCurve>, shared_ptr<PyCurve>>(m, "Curve")
9392
.def(py::init<vector<double>>());
9493
register_common_curve_methods<PyCurve>(pycurve);
9594

96-
auto pycurvexyzfourier = py::class_<PyCurveXYZFourier, py_shared_ptr<PyCurveXYZFourier>, PyCurveXYZFourierTrampoline<PyCurveXYZFourier>, PyCurve>(m, "CurveXYZFourier")
95+
auto pycurvexyzfourier = py::class_<PyCurveXYZFourier, PyCurveXYZFourierTrampoline<PyCurveXYZFourier>, shared_ptr<PyCurveXYZFourier>, PyCurve>(m, "CurveXYZFourier")
9796
.def(py::init<vector<double>, int>())
9897
.def_readonly("dofs", &PyCurveXYZFourier::dofs);
9998
register_common_curve_methods<PyCurveXYZFourier>(pycurvexyzfourier);
10099

101-
auto pycurverzfourier = py::class_<PyCurveRZFourier, py_shared_ptr<PyCurveRZFourier>, PyCurveRZFourierTrampoline<PyCurveRZFourier>, PyCurve>(m, "CurveRZFourier")
100+
auto pycurverzfourier = py::class_<PyCurveRZFourier, PyCurveRZFourierTrampoline<PyCurveRZFourier>, shared_ptr<PyCurveRZFourier>, PyCurve>(m, "CurveRZFourier")
102101
//.def(py::init<int, int>())
103102
.def(py::init<vector<double>, int, int, bool>())
104103
.def_readwrite("rc", &PyCurveRZFourier::rc)

src/simsoptpp/python_magneticfield.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
#include "pybind11/pybind11.h"
22
#include "pybind11/stl.h"
33
#include "pybind11/functional.h"
4+
namespace py = pybind11;
45
#include "xtensor-python/pyarray.hpp" // Numpy bindings
56
#include "xtensor-python/pytensor.hpp" // Numpy bindings
67
typedef xt::pyarray<double> PyArray;
78
typedef xt::pytensor<double, 2, xt::layout_type::row_major> PyTensor;
8-
#include "py_shared_ptr.h"
9-
PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr<T>);
109
using std::shared_ptr;
1110
using std::vector;
1211

@@ -53,17 +52,17 @@ template <typename T, typename S> void register_common_field_methods(S &c) {
5352

5453
void init_magneticfields(py::module_ &m){
5554

56-
py::class_<InterpolationRule, py_shared_ptr<InterpolationRule>>(m, "InterpolationRule", "Abstract class for interpolation rules on an interval.")
55+
py::class_<InterpolationRule, shared_ptr<InterpolationRule>>(m, "InterpolationRule", "Abstract class for interpolation rules on an interval.")
5756
.def_readonly("degree", &InterpolationRule::degree, "The degree of the polynomial. The number of interpolation points in `degree+1`.");
5857

59-
py::class_<UniformInterpolationRule, py_shared_ptr<UniformInterpolationRule>, InterpolationRule>(m, "UniformInterpolationRule", "Polynomial interpolation using equispaced points.")
58+
py::class_<UniformInterpolationRule, shared_ptr<UniformInterpolationRule>, InterpolationRule>(m, "UniformInterpolationRule", "Polynomial interpolation using equispaced points.")
6059
.def(py::init<int>())
6160
.def_readonly("degree", &UniformInterpolationRule::degree, "The degree of the polynomial. The number of interpolation points in `degree+1`.");
62-
py::class_<ChebyshevInterpolationRule, py_shared_ptr<ChebyshevInterpolationRule>, InterpolationRule>(m, "ChebyshevInterpolationRule", "Polynomial interpolation using chebychev points.")
61+
py::class_<ChebyshevInterpolationRule, shared_ptr<ChebyshevInterpolationRule>, InterpolationRule>(m, "ChebyshevInterpolationRule", "Polynomial interpolation using chebychev points.")
6362
.def(py::init<int>())
6463
.def_readonly("degree", &ChebyshevInterpolationRule::degree, "The degree of the polynomial. The number of interpolation points in `degree+1`.");
6564

66-
py::class_<RegularGridInterpolant3D<PyTensor>, py_shared_ptr<RegularGridInterpolant3D<PyTensor>>>(m, "RegularGridInterpolant3D",
65+
py::class_<RegularGridInterpolant3D<PyTensor>, shared_ptr<RegularGridInterpolant3D<PyTensor>>>(m, "RegularGridInterpolant3D",
6766
R"pbdoc(
6867
Interpolates a (vector valued) function on a uniform grid.
6968
This interpolant is optimized for fast function evaluation (at the cost of memory usage). The main purpose of this class is to be used to interpolate magnetic fields and then use the interpolant for tasks such as fieldline or particle tracing for which the field needs to be evaluated many many times.
@@ -74,32 +73,32 @@ void init_magneticfields(py::module_ &m){
7473
.def("evaluate_batch", &RegularGridInterpolant3D<PyTensor>::evaluate_batch, "Evaluate the interpolant at multiple points (faster than `evaluate` as it uses prefetching).");
7574

7675

77-
py::class_<Current<PyArray>, py_shared_ptr<Current<PyArray>>>(m, "Current", "Simple class that wraps around a single double representing a coil current.")
76+
py::class_<Current<PyArray>, shared_ptr<Current<PyArray>>>(m, "Current", "Simple class that wraps around a single double representing a coil current.")
7877
.def(py::init<double>())
7978
.def("set_dofs", &Current<PyArray>::set_dofs, "Set the current.")
8079
.def("get_dofs", &Current<PyArray>::get_dofs, "Get the current.")
8180
.def("set_value", &Current<PyArray>::set_value, "Set the current.")
8281
.def("get_value", &Current<PyArray>::get_value, "Get the current.");
8382

8483

85-
py::class_<Coil<PyArray>, py_shared_ptr<Coil<PyArray>>>(m, "Coil", "Optimizable that represents a coil, consisting of a curve and a current.")
84+
py::class_<Coil<PyArray>, shared_ptr<Coil<PyArray>>>(m, "Coil", "Optimizable that represents a coil, consisting of a curve and a current.")
8685
.def(py::init<shared_ptr<Curve<PyArray>>, shared_ptr<Current<PyArray>>>())
8786
.def_readonly("curve", &Coil<PyArray>::curve, "Get the underlying curve.")
8887
.def_readonly("current", &Coil<PyArray>::current, "Get the underlying current.");
8988

90-
auto mf = py::class_<PyMagneticField, PyMagneticFieldTrampoline<PyMagneticField>, py_shared_ptr<PyMagneticField>>(m, "MagneticField", "Abstract class representing magnetic fields.")
89+
auto mf = py::class_<PyMagneticField, PyMagneticFieldTrampoline<PyMagneticField>, shared_ptr<PyMagneticField>>(m, "MagneticField", "Abstract class representing magnetic fields.")
9190
.def(py::init<>());
9291
register_common_field_methods<PyMagneticField>(mf);
9392
//.def("B", py::overload_cast<>(&PyMagneticField::B));
9493

95-
auto bs = py::class_<PyBiotSavart, PyMagneticFieldTrampoline<PyBiotSavart>, py_shared_ptr<PyBiotSavart>, PyMagneticField>(m, "BiotSavart")
94+
auto bs = py::class_<PyBiotSavart, PyMagneticFieldTrampoline<PyBiotSavart>, shared_ptr<PyBiotSavart>, PyMagneticField>(m, "BiotSavart")
9695
.def(py::init<vector<shared_ptr<Coil<PyArray>>>>())
9796
.def("compute", &PyBiotSavart::compute)
9897
.def("fieldcache_get_or_create", &PyBiotSavart::fieldcache_get_or_create)
9998
.def("fieldcache_get_status", &PyBiotSavart::fieldcache_get_status);
10099
register_common_field_methods<PyBiotSavart>(bs);
101100

102-
auto ifield = py::class_<PyInterpolatedField, py_shared_ptr<PyInterpolatedField>, PyMagneticField>(m, "InterpolatedField")
101+
auto ifield = py::class_<PyInterpolatedField, shared_ptr<PyInterpolatedField>, PyMagneticField>(m, "InterpolatedField")
103102
.def(py::init<shared_ptr<PyMagneticField>, InterpolationRule, RangeTriplet, RangeTriplet, RangeTriplet, bool>())
104103
.def(py::init<shared_ptr<PyMagneticField>, int, RangeTriplet, RangeTriplet, RangeTriplet, bool>())
105104
.def("estimate_error_B", &PyInterpolatedField::estimate_error_B)

src/simsoptpp/python_surfaces.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
#include "pybind11/pybind11.h"
22
#include "pybind11/stl.h"
3+
namespace py = pybind11;
34
#include "xtensor-python/pyarray.hpp" // Numpy bindings
45
typedef xt::pyarray<double> PyArray;
5-
#include "py_shared_ptr.h"
6-
PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr<T>);
76
using std::shared_ptr;
87
using std::vector;
98

src/simsoptpp/python_tracing.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
#include "pybind11/pybind11.h"
22
#include "pybind11/stl.h"
33
#include "pybind11/functional.h"
4+
namespace py = pybind11;
45
#include "xtensor-python/pyarray.hpp" // Numpy bindings
56
typedef xt::pyarray<double> PyArray;
67
#include "xtensor-python/pytensor.hpp" // Numpy bindings
78
typedef xt::pytensor<double, 2, xt::layout_type::row_major> PyTensor;
8-
#include "py_shared_ptr.h"
9-
PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr<T>);
109
using std::shared_ptr;
1110
using std::vector;
1211
#include "tracing.h"

0 commit comments

Comments
 (0)