I have an abstract base class called BaseStrategy. It contains one pure virtual function, calculateEfficiency(). There are two classes, ConvolutionStrategy and MaxPoolStrategy, which derive from this base class and implement their own specific version of calculateEfficiency().
Here is some code:
class BaseStrategy {
public:
    // Defaulted so that derived classes can actually be constructed.
    BaseStrategy() = default;
    virtual ~BaseStrategy() = default;
private:
    virtual double calculateEfficiency(mlir::Operation* op) = 0;
};

class ConvolutionStrategy : public BaseStrategy {
private:
    double calculateEfficiency(mlir::Operation* op)
    {
        //some formula for convolution
        return 1;
    }
};

class MaxPoolStrategy : public BaseStrategy {
private:
    double calculateEfficiency(mlir::Operation* op)
    {
        //some formula for MaxPool
        return 1;
    }
};
Now I have another class called StrategyAssigner. It has a method calculateAllLayerEfficiencies() whose purpose is to iterate over all layers in a network. It contains a switch statement on the layer type, which should call the correct calculateEfficiency() for that type of layer.
class StrategyAssigner final {
public:
    StrategyAssigner() = default;

    void calculateAllLayerEfficiencies() {
        // Logic to iterate over all layers in
        // a network
        switch (layerType) {
        case Convolution:
            // Call calculateEfficiency() for Convolution
            break;
        case MaxPool:
            // Call calculateEfficiency() for MaxPool
            break;
        }
    }
};

int main()
{
    StrategyAssigner assigner;
    assigner.calculateAllLayerEfficiencies();
}
My question is: should I store references to ConvolutionStrategy and MaxPoolStrategy objects in the class StrategyAssigner so that I can call the respective calculateEfficiency()? Or could you suggest a better way to call calculateEfficiency()? I don't really know how to create the objects (stupid as that sounds).
I can't make calculateEfficiency() static, as I need it to be virtual so that each derived class can implement its own formula.
Answer:
If you included complete code I could give a more detailed answer, but you need to store BaseStrategy pointers that are initialized with derived class instances. Here's an example made from some of your code:
std::vector<std::unique_ptr<BaseStrategy>> strategies;
strategies.emplace_back(new ConvolutionStrategy);
strategies.emplace_back(new MaxPoolStrategy);
for (std::size_t i = 0; i < strategies.size(); ++i) {
    std::unique_ptr<BaseStrategy>& pStrat = strategies[i];
    pStrat->calculateEfficiency(...);
}
Note that this won't compile because I don't have enough details from the code you posted to make it so, but this shows how to exploit polymorphism in the way that you need.
Also, I used smart pointers for memory management; use these at your discretion.
Answer:
You can indeed use runtime polymorphism here:
- Declare ~BaseStrategy virtual (you are already doing it ;-)
- If you are never going to instantiate a BaseStrategy, declare one of its methods as pure virtual, e.g. calculateEfficiency (you are already doing it as well!). I would make that method const, since it doesn't look like it's going to modify the instance. And it will need to be public, because it will need to be accessed from StrategyAssigner.
- Declare calculateEfficiency as virtual and override in each of the subclasses. It could also be final if you don't want further subclasses to override it.
- I'd keep a std::vector of smart pointers to BaseStrategy in StrategyAssigner. You can use unique_ptrs if you think this class is not going to be sharing those pointers.
- The key point now is that you create heap instances of the subclasses and assign them to a pointer of the base class.
#include <memory>
#include <utility>
#include <vector>

class StrategyAssigner final {
public:
    // Takes ownership of a strategy created as a derived-class instance.
    void addStrategy(std::unique_ptr<BaseStrategy> s) {
        strategies_.push_back(std::move(s));
    }
private:
    std::vector<std::unique_ptr<BaseStrategy>> strategies_{};
};

int main()
{
    StrategyAssigner assigner;
    assigner.addStrategy(std::make_unique<ConvolutionStrategy>());
}
- Then, when you call calculateEfficiency through any of those pointers to BaseStrategy, runtime polymorphism kicks in and the subclass's override is the one that actually gets called.
#include <iostream>
#include <numeric>

class ConvolutionStrategy : public BaseStrategy {
private:
    // Private is fine here: the call goes through a BaseStrategy pointer,
    // where (as recommended above) the method is public.
    virtual double calculateEfficiency() const override {
        std::cout << "ConvolutionStrategy::calculateEfficiency()\n";
        return 10;
    }
};

class MaxPoolStrategy : public BaseStrategy {
private:
    virtual double calculateEfficiency() const override {
        std::cout << "MaxPoolStrategy::calculateEfficiency()\n";
        return 20;
    }
};

class StrategyAssigner final {
public:
    // addStrategy() and strategies_ as in the previous snippet.

    void calculateAllLayerEfficiencies() {
        // Start from 0.0 so the total is accumulated as a double, not truncated to int.
        auto sum = std::accumulate(std::cbegin(strategies_), std::cend(strategies_), 0.0,
            [](auto total, const auto& strategy_up) {
                return total + strategy_up->calculateEfficiency(); });
        std::cout << "Sum of all efficiencies: " << sum << "\n";
    }
};

int main()
{
    StrategyAssigner assigner;
    assigner.addStrategy(std::make_unique<ConvolutionStrategy>());
    assigner.addStrategy(std::make_unique<MaxPoolStrategy>());
    assigner.calculateAllLayerEfficiencies();
}
// Outputs:
//
// ConvolutionStrategy::calculateEfficiency()
// MaxPoolStrategy::calculateEfficiency()
// Sum of all efficiencies: 30