The minimal code that reproduces the issue is as follows:
aaa.hpp
#include <string>
#include <vector>
#include <iostream>
#include <stdexcept>
#include <cmath>
#include <cassert>
#include <utility>
/// Root of the small template hierarchy used in this reproduction.
/// It carries one protected flag that subclasses are expected to set.
template <class D>
class BaseClass {
protected:
    /// Flag visible to subclasses only; starts out as true.
    bool skip_nan{true};
};
/// Concrete subclass of BaseClass<D>; its only constructor overwrites the
/// inherited flag.
/// @param skip_nan_flag value stored into the inherited BaseClass<D>::skip_nan.
template <class D>
class DerivedClass : public BaseClass<D> {
public:
    explicit DerivedClass(bool skip_nan_flag)
    {
        // Qualified name is required: skip_nan lives in a dependent base.
        BaseClass<D>::skip_nan = skip_nan_flag;
    }
};
aaa_py.cpp
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include "aaa.hpp"
namespace py = pybind11;
template <typename D>
void fff(BaseClass<D>& rs){
}
// Bindings exactly as written in the question; these produce the TypeError
// quoted below.
PYBIND11_MODULE(aaa_py, m) {
// fff_float's parameter is BaseClass<float>&. BaseClass<float> is never
// registered with pybind11 in this module, which is why the error text shows
// the raw C++ name "BaseClass<float>" in the signature.
m.def("fff_float", &fff<float>);
// DerivedClass<float> is registered without naming BaseClass<float> as its
// base, so pybind11 has no recorded relationship that would let it convert a
// DerivedClass_float into a BaseClass<float>& argument.
py::class_<DerivedClass<float>>(m, "DerivedClass_float").def(py::init<bool>(), py::arg("skip_nan")=true);
}
test.py
# Reproduction script from the question.
import aaa_py as rsp
# No argument given: skip_nan falls back to the py::arg default (true).
ds = rsp.DerivedClass_float()
# With the question's bindings this raises the TypeError shown below.
rsp.fff_float(ds)
Error:
Traceback (most recent call last):
File "usage_example_python.py", line 8, in <module>
rsp.fff_float(ds)
TypeError: fff_float(): incompatible function arguments. The following argument types are supported:
1. (arg0: BaseClass<float>) -> None
Invoked with: <aaa_py.DerivedClass_float object at 0x000002C9EAF57308>
Basically, the error says that the function fff_float expects an instance of BaseClass<float> but receives a DerivedClass<float>. If I change the function to accept DerivedClass<D>&, there are no errors. Likewise, if I create an instance of DerivedClass<float> directly in C++ and pass it to fff_float, there are no issues either. So it seems that, for an object created in Python, the information about its base class is somehow lost. How can I deal with this?
Answer:
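You need to register BaseClass<float> with pybind11 as well, and tell pybind11 that DerivedClass_float extends it. pybind11 only knows about a C++ inheritance relationship if it is declared as an extra template argument of py::class_. Without that, it cannot convert a Python DerivedClass_float object into a BaseClass<float>& argument.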
// The module name passed to PYBIND11_MODULE must match the compiled
// extension's file name; test.py does `import aaa_py`, so it is aaa_py here.
PYBIND11_MODULE(aaa_py, m)
{
    // Expose the base class. No constructor is bound, so a BaseClass_float
    // cannot be constructed from Python; the registration exists so that
    // pybind11 knows the type and can use it as a base.
    py::class_<BaseClass<float>>(m, "BaseClass_float");

    // The second template argument (BaseClass<float>) tells pybind11 that
    // DerivedClass<float> extends BaseClass<float>. That lets a
    // DerivedClass_float be passed wherever a BaseClass<float>& is expected.
    py::class_<DerivedClass<float>, BaseClass<float>>(m, "DerivedClass_float")
        .def(py::init<bool>(), py::arg("skip_nan") = true);

    // Bind functions after the classes they use. pybind11 builds the
    // signature docstring when .def() is called, so binding fff_float first
    // would show the raw C++ name "BaseClass<float>" instead of the
    // Python-visible BaseClass_float.
    m.def("fff_float", &fff<float>);
}