How do I properly derive from a nested struct?

Time:06-19

I have an abstract (templated) class that I want to have its own return type InferenceData.

template <typename StateType>
class Model {
public:
    struct InferenceData;
    virtual InferenceData inference () = 0;
};

Now below is an attempt to derive from it

template <typename StateType>
class MonteCarlo : public Model<StateType> {
public:
    
    // struct InferenceData {};
    
    typename MonteCarlo::InferenceData inference () {
        typename MonteCarlo::InferenceData x;
        return x;
    }
};

This works, but only because the definition of MonteCarlo::InferenceData is commented out. If I uncomment it, I get an "invalid covariant return type" error. I want each ModelDerivation<StateType>::InferenceData to be its own type, with its own implementation as a struct. How do I achieve this?

CodePudding user response:

An override may not change the return type of a virtual function, with one exception: covariant return types, where both functions return a pointer or reference and the type returned by the derived function derives from the type returned by the base function. Your MonteCarlo::inference() returns an unrelated struct by value, so the exception does not apply and compilation fails.

To get what you want, use a polymorphic return type, which requires pointer/reference semantics. Your derived InferenceData has to inherit from the base InferenceData, and inference() should return a pointer/reference to the base InferenceData.

One way to do it is with a smart pointer - e.g. a std::unique_ptr - see the code below:

#include <memory>

template <typename StateType>
class Model {
public:
    struct InferenceData {
        virtual ~InferenceData() = default;  // derived objects are deleted through a base pointer
    };
    virtual std::unique_ptr<InferenceData> inference() = 0;
};


template <typename StateType>
class MonteCarlo : public Model<StateType> {
public:
    struct InferenceDataSpecific : public Model<StateType>::InferenceData {};

    // The base class is dependent, so its nested type must be fully qualified.
    std::unique_ptr<typename Model<StateType>::InferenceData> inference() override {
        return std::make_unique<InferenceDataSpecific>();
    }
};

int main()
{
    MonteCarlo<int> m;
    auto d = m.inference();
    return 0;
}

Note: if you need to share the data, you can use a std::shared_ptr.
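A minimal sketch of the std::shared_ptr variant, assuming the same class layout as above. The Base alias and the samples member are hypothetical and only there to make the example concrete.

```cpp
#include <memory>

template <typename StateType>
class Model {
public:
    struct InferenceData {
        virtual ~InferenceData() = default;  // makes the type polymorphic
    };
    virtual ~Model() = default;
    virtual std::shared_ptr<InferenceData> inference() = 0;
};

template <typename StateType>
class MonteCarlo : public Model<StateType> {
public:
    // Hypothetical alias to keep the dependent name short.
    using Base = typename Model<StateType>::InferenceData;

    struct InferenceDataSpecific : Base {
        int samples = 100;  // hypothetical payload
    };

    // Same return type as the base function, so this is a plain override.
    std::shared_ptr<Base> inference() override {
        return std::make_shared<InferenceDataSpecific>();
    }
};
```

A caller that needs the derived members can get back to them with std::dynamic_pointer_cast<MonteCarlo<int>::InferenceDataSpecific>(ptr), which yields an empty pointer if the object is of a different type.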

CodePudding user response:

Another option is to make the return type one of the template parameters:

template <typename StateType, typename InferenceData>
class Model {
public:
    virtual InferenceData inference () = 0;
};

Then you can set the return type when you derive from it.
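As a rough sketch, a derived class could then look like the following. MonteCarloData and its samples member are hypothetical names; the struct is declared outside MonteCarlo because it has to be named in the base-class list.

```cpp
template <typename StateType, typename InferenceData>
class Model {
public:
    virtual ~Model() = default;
    virtual InferenceData inference() = 0;
};

// Hypothetical result type for the Monte Carlo model.
template <typename StateType>
struct MonteCarloData {
    int samples = 0;
};

template <typename StateType>
class MonteCarlo : public Model<StateType, MonteCarloData<StateType>> {
public:
    // Re-export the result type under the name used in the question.
    using InferenceData = MonteCarloData<StateType>;

    // Returned by value: the return type matches the base exactly.
    InferenceData inference() override {
        InferenceData x;
        x.samples = 1000;  // placeholder value
        return x;
    }
};
```

The trade-off is that models with different result types no longer share a base class: Model<int, A> and Model<int, B> are unrelated types, so they cannot be stored behind one common Model pointer.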

CodePudding user response:

You can actually have your MonteCarlo::inference return a pointer (or reference) to a MonteCarlo::InferenceData, as long as you do things correctly otherwise. A simple version looks like this:

#include <memory>
#include <iostream>

template <typename StateType>
class Model {
public:
    // base return type:
    struct InferenceData { };

    virtual InferenceData *inference() = 0;
};

template <typename StateType>
class MonteCarlo : public Model<StateType> {
public:
    // derived return type must be derived from base return type:
    struct InferenceData : public ::Model<StateType>::InferenceData { };

    InferenceData *inference() override { return new InferenceData; }
};

int main() {
    MonteCarlo<int> mci;
    auto v = mci.inference();  // v is a MonteCarlo<int>::InferenceData*
    delete v;                  // the caller owns the object created with new
}

This is a covariant return type (as the compiler alluded to in its error message). There are a couple of points to keep in mind here though:

  1. The return type really does have to be covariant. That is, the base class function has to be declared to return a pointer/reference to some base class, and the derived function has to return a pointer/reference to a type derived from the one the base function returns.
  2. A unique_ptr<Derived> allows implicit conversion to unique_ptr<Base>, assuming Derived is derived from Base, but a unique_ptr<Derived> still isn't actually derived from unique_ptr<Base>, so you can't use (most typical) smart pointers for covariant returns.
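Both points can be checked at compile time. The sketch below uses hypothetical Base, Derived, Producer and DerivedProducer types that are not part of the question's code.

```cpp
#include <memory>
#include <type_traits>

struct Base { virtual ~Base() = default; };
struct Derived : Base {};

// The smart pointers convert implicitly...
static_assert(std::is_convertible<std::unique_ptr<Derived>,
                                  std::unique_ptr<Base>>::value,
              "unique_ptr<Derived> converts to unique_ptr<Base>");

// ...but they are unrelated class types, so they cannot be covariant.
static_assert(!std::is_base_of<std::unique_ptr<Base>,
                               std::unique_ptr<Derived>>::value,
              "unique_ptr<Derived> does not derive from unique_ptr<Base>");

// Raw pointers to the same classes do satisfy the covariance rule.
struct Producer {
    virtual ~Producer() = default;
    virtual Base* make() { return new Base; }
};

struct DerivedProducer : Producer {
    Derived* make() override { return new Derived; }  // covariant override
};
```

Calling make() on a DerivedProducer directly gives a Derived*, while calling it through a Producer reference gives a Base* that points at a Derived object.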

For the moment, I've used new to create the returned object. That's pretty common when dealing with derivation and such, but it can be avoided if necessary. Doing that can get non-trivial in itself, depending on your needs. In a really simple case, define a static object of the correct type, and return a pointer to it (but that leads to problems if you do multi-threading).
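A minimal sketch of that static-object variant, assuming the same classes as above. The samples member is hypothetical and only makes the sharing visible.

```cpp
template <typename StateType>
class Model {
public:
    struct InferenceData {};
    virtual ~Model() = default;
    virtual InferenceData* inference() = 0;
};

template <typename StateType>
class MonteCarlo : public Model<StateType> {
public:
    struct InferenceData : Model<StateType>::InferenceData {
        int samples = 0;  // hypothetical payload
    };

    InferenceData* inference() override {
        // One object shared by every call and every MonteCarlo<StateType>
        // instance; the caller must not delete it.
        static InferenceData result;
        result.samples += 1;  // each call modifies the same shared object
        return &result;
    }
};
```

Every call hands out the same address, so an earlier result is modified by the next call. Declaring the object thread_local instead of static would give each thread its own copy, which avoids the data race but not the overwriting.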
