Home > Blockchain > Can I define a virtual function (or similar) to return a vector of derived class members?
Can I define a virtual function (or similar) to return a vector of derived class members?

Time:11-08

I have a base class B with derived classes X, Y and Z (in fact, more than 20 derived classes). Each class has a tag() function that identifies which (derived) class it is. My program stores instances of the derived classes as pointers in a vector defined as vector<B*>. Each derived class may appear in this vector 0..n times.

I would like to have a function that looks through the vector for instances of a particular derived type and returns a new vector whose elements are pointers to that derived type, e.g.:

#include <vector>
using namespace std;

// Polymorphic base of all tagged classes; derived objects are stored and
// handled through B* (see vector<B*> in the question).
class B {
  public:
  // ...

  // Objects are handled through B*, so the destructor must be virtual:
  // deleting a derived object through a B* with a non-virtual destructor is
  // undefined behaviour.
  virtual ~B() = default;

  // Identifies the dynamic class of the object. B itself reports 0xFF; each
  // derived class is expected to override this with its own distinct value.
  virtual int tag() { return 0xFF; }
};

class X : public B {
  // ...
  int tag() {return 1;};
  vector<X*> find_derived(vector<B*> base_vec) {
    vector<X*> derived_vec;
      for (auto p : base_vec) {
        if (p->tag() == tag()) {
          derived_vec.push_back((X*) p);
        }
      }
    return derived_vec;
  }
};

Obviously I don't want to have to define find_derived in each derived class, but I don't see how to do this with a virtual function. Currently I am doing it with a macro, but since I am learning C++, I would prefer a method that uses language constructs rather than the preprocessor. Is there another way?

Answer:

One possibility:

template <typename D>
class FindDerivedMixin {
public:
  vector<D*> find_derived(const vector<B*>& base_vec) {
    int my_tag = static_cast<D*>(this)->tag();
    vector<D*> derived_vec;
    for (auto p : base_vec) {
      if (p->tag() == my_tag) derived_vec.push_back(static_cast<D*>(p));
    }
    return derived_vec;
  }
};

// Usage: X gets find_derived() from the mixin. X must still override tag()
// with a value unique to X. Otherwise it reports B's default 0xFF, like every
// other non-overriding class, and find_derived() would downcast unrelated
// objects to X*.
class X : public B, public FindDerivedMixin<X> {
public:
  // ...
  int tag() override { return 1; }
};

Answer:

As in the previous answer, what you need is some template programming. Here is an example without a mixin; it uses dynamic_cast instead of tags to find the objects of a given type:

#include <iostream>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

//-----------------------------------------------------------------------------
// Base class

class Base
{
public:
    virtual ~Base() = default;

    // Every concrete subclass must say hello in its own way.
    virtual void Hello() const = 0;

protected:
    // Stores the message shown by Hello().
    // Protected so that Base can only be constructed as part of a derived object.
    explicit Base(const std::string& text)
        : m_message{text}
    {
    }

    // Read-only access to the stored message for subclasses.
    const std::string& message() const { return m_message; }

private:
    std::string m_message;
};

//-----------------------------------------------------------------------------
// Class which contains a collection of derived classes of base

// Owns a heterogeneous set of objects derived from Base and lets callers
// retrieve the ones of a particular derived type.
class Collection
{
public:
    Collection() = default;
    virtual ~Collection() = default;

    // The user-declared destructor above suppresses the implicitly declared
    // move operations, which made Collection impossible to move or return by
    // value. Restore them explicitly. Copying is deleted because the
    // collection uniquely owns its objects.
    Collection(Collection&&) = default;
    Collection& operator=(Collection&&) = default;
    Collection(const Collection&) = delete;
    Collection& operator=(const Collection&) = delete;

    // Constructs a new type_t from args (perfectly forwarded to type_t's
    // constructor) and takes ownership of it.
    template<typename type_t, typename... args_t>
    void Add(args_t&&... args)
    {
        // Compile-time check that the user adds a class derived from Base.
        static_assert(std::is_base_of_v<Base, type_t>, "You must add a class derived from Base");

        // Stored as unique_ptr<Base>: the collection is the sole owner, and
        // the objects can later be inspected polymorphically.
        m_collection.push_back(std::make_unique<type_t>(std::forward<args_t>(args)...));
    }

    // Returns non-owning pointers to every stored object whose dynamic type is
    // type_t (or a class derived from type_t), in insertion order. The pointers
    // dangle once the owning collection is destroyed or assigned to.
    template<typename type_t>
    std::vector<type_t*> get_objects()
    {
        static_assert(std::is_base_of_v<Base, type_t>, "You must request a class derived from Base");

        std::vector<type_t*> retval;
        for (const auto& ptr : m_collection)
        {
            // dynamic_cast yields nullptr when the object is not a type_t.
            if (auto* derived_ptr = dynamic_cast<type_t*>(ptr.get()))
            {
                retval.push_back(derived_ptr);
            }
        }
        return retval;
    }

    // Read-only counterpart for const collections: same selection and order,
    // but the returned pointers do not allow modification.
    template<typename type_t>
    std::vector<const type_t*> get_objects() const
    {
        static_assert(std::is_base_of_v<Base, type_t>, "You must request a class derived from Base");

        std::vector<const type_t*> retval;
        for (const auto& ptr : m_collection)
        {
            if (const auto* derived_ptr = dynamic_cast<const type_t*>(ptr.get()))
            {
                retval.push_back(derived_ptr);
            }
        }
        return retval;
    }

private:
    std::vector<std::unique_ptr<Base>> m_collection;
};

//-----------------------------------------------------------------------------
// some derived classes for testing.

class Derived1 :
    public Base
{
public:
    explicit Derived1(const std::string& message) :
        Base(message)
    {
    }
    
    virtual ~Derived1() = default;


    void Hello() const override
    {
        std::cout << "Derived1 : " << message() << "\n";
    }
};

//-----------------------------------------------------------------------------

class Derived2 :
    public Base
{
public:
    explicit Derived2(const std::string& message) :
        Base(message)
    {
    }

    virtual ~Derived2() = default;

    void Hello() const override
    {
        std::cout << "Derived2 : " << message() << "\n";
    }
};

//-----------------------------------------------------------------------------

int main()
{
    Collection collection;
    collection.Add<Derived1>("Instance 1");
    collection.Add<Derived1>("Instance 2");
    collection.Add<Derived2>("Instance 1");
    collection.Add<Derived2>("Instance 2");
    collection.Add<Derived1>("Instance 3");
    
    // This is where template programming really helps 
    // the lines above where just to get the collection filled
    auto objects = collection.get_objects<Derived1>();

    for (auto& derived : objects)
    {
        derived->Hello();
    }

    return 0;
}
  • Related