C++ polymorphism: how to create derived class objects

Time:02-22

I have an abstract base class called BaseStrategy. It contains one pure virtual function calculateEfficiency(). There are two classes ConvolutionStrategy and MaxPoolStrategy which derive from this base class and implement their own specific version of calculateEfficiency().

Here is some code:

class BaseStrategy {
public:
    explicit BaseStrategy();
    virtual ~BaseStrategy() = default;

private:
    virtual double calculateEfficiency(mlir::Operation* op) = 0;
};

class ConvolutionStrategy : public BaseStrategy {
private:
    double calculateEfficiency(mlir::Operation* op)
    {
        //some formula for convolution 
        return 1;
    }
};

class MaxPoolStrategy : public BaseStrategy {
private:
    double calculateEfficiency(mlir::Operation* op)
    {
        //some formula for MaxPool
        return 1;
    }
};

Now I have another class called StrategyAssigner. Its method calculateAllLayerEfficiencies() iterates over all layers in a network. A switch statement on the layer type should then call the calculateEfficiency() that matches that type.

class StrategyAssigner final {
public:
    explicit StrategyAssigner(){};

public:
    void calculateAllLayerEfficiencies() {
        // Logic to iterate over all layers in
        // a network
        switch (layerType) {
        case Convolution:
            // Call calculateEfficiency() for Convolution
            break;
        case MaxPool:
            // Call calculateEfficiency() for MaxPool
            break;
        }
    }
};

int main ()
{
    StrategyAssigner assigner;
    assigner.calculateAllLayerEfficiencies();
}

My question is: should I store references to ConvolutionStrategy and MaxPoolStrategy objects in the class StrategyAssigner so that I can call the respective calculateEfficiency()?

Or could you suggest a better way to call calculateEfficiency()? I don't really know how to create the objects (stupid as that sounds).

I can't make calculateEfficiency() static, as I need it to be virtual so that each derived class can implement its own formula.

CodePudding user response:

If you included complete code I could give a more detailed answer, but you need to store BaseStrategy pointers that are initialized with derived class instances. Here's an example made from some of your code:

std::vector<std::unique_ptr<BaseStrategy>> strategies;
strategies.push_back(std::make_unique<ConvolutionStrategy>());
strategies.push_back(std::make_unique<MaxPoolStrategy>());

for (std::size_t i = 0; i < strategies.size(); ++i) {
    std::unique_ptr<BaseStrategy>& pStrat = strategies[i];
    pStrat->calculateEfficiency(...);
}

Note that this won't compile because I don't have enough details from the code you posted to make it so, but this shows how to exploit polymorphism in the way that you need.

Also, I used smart pointers for memory management; use these at your discretion.
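
For reference, here is a self-contained sketch of the same idea that should compile on its own. It rests on a few assumptions that are not in your code: an empty Operation struct stands in for mlir::Operation, calculateEfficiency() is made public, the return values are placeholders rather than real formulas, and runAll is just an illustrative name.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical stand-in for mlir::Operation so the sketch builds without MLIR.
struct Operation {};

class BaseStrategy {
public:
    virtual ~BaseStrategy() = default;
    // Public (an assumption) so callers can reach it through a base pointer.
    virtual double calculateEfficiency(Operation* op) = 0;
};

class ConvolutionStrategy : public BaseStrategy {
public:
    double calculateEfficiency(Operation*) override { return 1.0; } // placeholder value
};

class MaxPoolStrategy : public BaseStrategy {
public:
    double calculateEfficiency(Operation*) override { return 2.0; } // placeholder value
};

// Illustrative helper: runs every stored strategy on one operation
// and collects the results in insertion order.
std::vector<double> runAll(Operation* op) {
    assert(op != nullptr); // this sketch expects a valid operation

    std::vector<std::unique_ptr<BaseStrategy>> strategies;
    strategies.push_back(std::make_unique<ConvolutionStrategy>());
    strategies.push_back(std::make_unique<MaxPoolStrategy>());

    std::vector<double> results;
    for (const auto& strategy : strategies) {
        // Virtual dispatch selects the derived implementation here.
        results.push_back(strategy->calculateEfficiency(op));
    }
    return results;
}
```

Calling runAll(&op) with a local Operation op gives one result per strategy, in the order they were added.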

CodePudding user response:

You can indeed use runtime polymorphism here:

  • Declare ~BaseStrategy virtual (you are already doing it ;-)
  • If you are never going to instantiate a BaseStrategy, declare one of its methods as pure virtual, e.g. calculateEfficiency (you are already doing it as well!). I would make that method const, since it doesn't look like it's going to modify the instance. And it will need to be public, because it will need to be accessed from StrategyAssigner.
  • In each of the subclasses, mark calculateEfficiency with override. It could also be marked final there if you don't want further subclasses to override it.
  • I'd keep a std::vector of smart pointers to BaseStrategy in StrategyAssigner. You can use unique_ptrs if you think this class is not going to be sharing those pointers.
  • The key point now is that you create heap instances of the subclasses and assign them to a pointer of the base class.
class StrategyAssigner final {
public:
    void addStrategy(std::unique_ptr<BaseStrategy> s) {
        strategies_.push_back(std::move(s));
    }
private:
    std::vector<std::unique_ptr<BaseStrategy>> strategies_{};
};

int main()
{
    StrategyAssigner assigner;
    assigner.addStrategy(std::make_unique<ConvolutionStrategy>());
}
  • Then, when you call calculateEfficiency through any of those pointers to BaseStrategy, runtime polymorphism will kick in and the subclass's version of the method is the one that actually gets called.
class ConvolutionStrategy : public BaseStrategy {
private:
    virtual double calculateEfficiency() const override {
        std::cout << "ConvolutionStrategy::calculateEfficiency()\n";
        return 10;
    }
};

class MaxPoolStrategy : public BaseStrategy {
private:
    virtual double calculateEfficiency() const override {
        std::cout << "MaxPoolStrategy::calculateEfficiency()\n";
        return 20;
    }
};

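class StrategyAssigner final {
public:
    // addStrategy() and strategies_ are the same as in the previous snippet.
    // std::accumulate needs <numeric>.
    void calculateAllLayerEfficiencies() {
        // Start from 0.0 so the running total is a double, not an int.
        auto sum = std::accumulate(std::cbegin(strategies_), std::cend(strategies_), 0.0,
            [](auto total, const auto& strategy_up) {
                return total + strategy_up->calculateEfficiency(); });
        std::cout << "Sum of all efficiencies: " << sum << "\n";
    }
};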

int main()
{
    StrategyAssigner assigner;
    assigner.addStrategy(std::make_unique<ConvolutionStrategy>());
    assigner.addStrategy(std::make_unique<MaxPoolStrategy>());
    assigner.calculateAllLayerEfficiencies(); 
}

// Outputs:
//
//   ConvolutionStrategy::calculateEfficiency()
//   MaxPoolStrategy::calculateEfficiency()
//   Sum of all efficiencies: 30
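
If you still need to pick a strategy by layer type, as the switch in the question does, one option is to keep one strategy object per layer type in a map and look it up instead of switching. This is only a sketch under assumptions: the LayerType enum, efficiencyFor, and the layers parameter are names I made up for illustration, and the parameterless const calculateEfficiency() follows the simplified signature used above.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <vector>

// Hypothetical layer kinds, mirroring the cases in the question's switch.
enum class LayerType { Convolution, MaxPool };

class BaseStrategy {
public:
    virtual ~BaseStrategy() = default;
    virtual double calculateEfficiency() const = 0;
};

class ConvolutionStrategy : public BaseStrategy {
public:
    double calculateEfficiency() const override { return 10.0; }
};

class MaxPoolStrategy : public BaseStrategy {
public:
    double calculateEfficiency() const override { return 20.0; }
};

class StrategyAssigner final {
public:
    StrategyAssigner() {
        // One strategy object per layer type, created once and owned here.
        strategies_[LayerType::Convolution] = std::make_unique<ConvolutionStrategy>();
        strategies_[LayerType::MaxPool] = std::make_unique<MaxPoolStrategy>();
    }

    // Looks up the strategy for a layer type and dispatches virtually.
    double efficiencyFor(LayerType type) const {
        const auto it = strategies_.find(type);
        assert(it != strategies_.end() && "no strategy registered for this layer type");
        return it->second->calculateEfficiency();
    }

    // Sums the efficiencies of all layers in a (hypothetical) list of layer types.
    double calculateAllLayerEfficiencies(const std::vector<LayerType>& layers) const {
        double total = 0.0;
        for (LayerType type : layers) {
            total += efficiencyFor(type);
        }
        return total;
    }

private:
    std::map<LayerType, std::unique_ptr<BaseStrategy>> strategies_;
};
```

Adding a new layer type then means adding one enum value, one subclass, and one line in the constructor, with no switch to keep in sync.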
