I want to create objects of polymorphic types based on a string read from a config text file. A simple, naive solution is to assign a string to each possible type and then compare the config string against all the defined types in an else-if chain. Something like:
// Abstract base for every type that can be named in the config file.
class Base
{
public:
    // Virtual so that deleting a derived object through std::unique_ptr<Base>
    // runs the derived destructor instead of being undefined behaviour.
    virtual ~Base() = default;

    // Name identifying the concrete type; compared against the config string.
    // Public so callers can query it (class members default to private).
    virtual std::string GetStringType() = 0;
};
// Concrete type selected by the config string "Derived1".
class Derived1 : public Base
{
public:
    // Public so the type name can be queried directly on a Derived1 object.
    std::string GetStringType() override { return "Derived1"; }
};
// Concrete type selected by the config string "Derived2".
class Derived2 : public Base
{
public:
    // Public so the type name can be queried directly on a Derived2 object.
    std::string GetStringType() override { return "Derived2"; }
};
// etc ...
void main(int argc, char *argv[])
{
std::unique_ptr<Base> ptr;
auto derived1 = std::make_unique<Derived1>();
auto derived2 = std::make_unique<Derived2>();
// etc ...
std::string stringType(argv[1]);
if (stringType == derived1->GetStringType())
ptr = std::make_unique<decltype(derived1)>();
else if (stringType == derived2->GetStringType())
ptr = std::make_unique<decltype(derived2)>();
// etc ...
}
However, with this approach, each time a new derived class is added, a new else-if branch needs to be manually added, and I am trying to avoid that. Is there a better, more automatic approach to this?
Also, ideally, when a new derived class is defined somewhere (just defined, not instantiated), I would like it to be picked up by this check automatically as well. Is this somehow possible? I'd be happy with any solution that works, macros included.
Answer:
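A simple map-based factory does the trick: each derived class registers a creation function under its name once, and `create` looks the name up at run time, so no else-if chain has to be maintained: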
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
/// Polymorphic base class that doubles as a name-keyed factory for its
/// derived types.
///
/// Every derived class T must provide `static std::string GetStringType()`.
/// `registerDerived<T>()` stores a creator under that name, and `create(name)`
/// builds a new instance from it.
class Base
{
public:
    virtual ~Base() = default;

    /// Creates a new instance of the class registered under @p name.
    /// @throws std::out_of_range if no class was registered under @p name.
    static std::unique_ptr<Base> create(const std::string& name) {
        const auto& factories = registry();
        const auto it = factories.find(name);
        if (it == factories.end()) {
            throw std::out_of_range("No class registered for name: " + name);
        }
        return it->second();
    }

    /// Registers T under T::GetStringType(). A later registration with the
    /// same name replaces the earlier one.
    template<typename T>
    static void registerDerived() {
        static_assert(std::is_base_of_v<Base, T>, "T must derive from Base");
        // A lambda is required: std::make_unique<T> names an overload set, so it
        // cannot be converted to std::function directly.
        registry()[T::GetStringType()] = [] { return std::make_unique<T>(); };
    }

private:
    using Factory = std::function<std::unique_ptr<Base>()>;

    // Function-local static: constructed on first use. Registrations performed
    // from static initializers in other translation units therefore can never
    // run before the map exists (static initialization order fiasco).
    static std::map<std::string, Factory>& registry() {
        static std::map<std::string, Factory> factories;
        return factories;
    }
};
/// Concrete product created by the registry under the key "Derived1".
struct Derived1 : Base
{
    /// Key under which Base::registerDerived<Derived1>() stores this type.
    static std::string GetStringType()
    {
        return std::string("Derived1");
    }
};
/// Concrete product created by the registry under the key "Derived2".
struct Derived2 : Base
{
    /// Key under which Base::registerDerived<Derived2>() stores this type.
    static std::string GetStringType()
    {
        return std::string("Derived2");
    }
};
int main(int argc, char *argv[]) {
    // Register every concrete type once; this is the only list that grows
    // when a new derived class is added.
    Base::registerDerived<Derived1>();
    Base::registerDerived<Derived2>();
    // etc...

    // argv[1] is only valid when at least one argument was passed.
    if (argc < 2) {
        std::cerr << "usage: " << (argc > 0 ? argv[0] : "program") << " <type-name>\n";
        return 1;
    }

    std::unique_ptr<Base> ptr;
    try {
        ptr = Base::create(argv[1]);
    } catch (const std::out_of_range& e) {
        // Unknown type name in the config: report it instead of terminating.
        std::cerr << e.what() << '\n';
        return 1;
    }
    // ...
}
Answer:
This is an example based on `std::unordered_map` with an explicit class factory, which completely decouples the creation of instances from the base class (separation of concerns).
#include <type_traits>
#include <string>
#include <functional>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <stdexcept>
#include <sstream>
//-------------------------------------------------------------------------------------------------
/// Common interface for every class the ClassFactory can create.
class Base
{
public:
    /// Virtual so that deleting a derived object through std::unique_ptr<Base>
    /// (which is how the factory hands instances out) runs the derived
    /// destructor instead of being undefined behaviour.
    virtual ~Base() = default;

    /// Implemented by each concrete class; in this example each one prints
    /// its own class name.
    virtual void Hello() = 0;
};
/// Concrete product registered with the factory as "Class1".
struct Class1 : Base
{
    /// Writes "Class1" followed by a newline to standard output.
    void Hello() override { std::cout << "Class1\n"; }
};
/// Concrete product registered with the factory as "Class2".
struct Class2 : Base
{
    /// Writes "Class2" followed by a newline to standard output.
    void Hello() override { std::cout << "Class2\n"; }
};
//-------------------------------------------------------------------------------------------------
class ClassFactory final
{
public:
// For each class registered at a function to the map
// that will create a unique_ptr to a new instance of that class
template<typename class_t>
void Register(const std::string& string)
{
static_assert(std::is_base_of_v<Base, class_t>, "You can only register classes derived from Base");
// create_function is a lambda
auto create_function = [] { return std::make_unique<class_t>(); };
// add the function to the map, with the string as key
create_function_map.insert({ string,create_function });
}
std::unique_ptr<Base> Create(const std::string& string)
{
// if nothing is found for the string then throw an exception
if (create_function_map.find(string) == create_function_map.end())
{
std::ostringstream os;
os << "No class registered for string : " << string;
throw std::invalid_argument(os.str());
}
// otherwise call the function that will make an instance of the
// derived class
auto create_function = create_function_map.at(string);
return create_function();
}
static ClassFactory& Instance()
{
static ClassFactory instance;
return instance;
}
private:
ClassFactory() = default;
~ClassFactory() = default;
// use unordered_map it has lookup complexity of O(1)
// std::map has a lookup complexity of O(n log(n))
std::unordered_map<std::string, std::function<std::unique_ptr<Base>()>> create_function_map;
};
//-------------------------------------------------------------------------------------------------
int main()
{
// For each new class register a name, this is the only code
// you need to expand to add new classes
// factory is a singleton so you can reuse it throughout your code
// without having to pass it around.
auto& factory = ClassFactory::Instance();
factory.Register<Class1>("Class1");
factory.Register<Class2>("Class2");
// create an instance of an object of type Class1.
auto object1 = factory.Create("Class1");
object1->Hello();
// create an instance of an object of type Class2.
auto object2 = factory.Create("Class2");
object2->Hello();
return 0;
}