Using type trait to ensure a type cannot be derived from itself

Time:05-02

I would like to check statically that a class is derived from a base class, but not from itself. Below is an example of what I am trying to achieve. Unfortunately, the code compiles (never thought I would say that).

I was hoping the second static assertion would fire and detect that I tried to derive a class from itself. I would appreciate it if you could help me understand what I am doing wrong; I searched online without success.

#include <type_traits>

struct Base {};

template <typename T>
struct Derived : T {
  static_assert(std::is_base_of<Base, T>::value, "Type must derive from Base");
  static_assert(!(std::is_base_of<Derived, T>::value),
                "Type must not derive from Derived");
};

int main() {
  Derived<Base> d_base;            // should compile
  Derived<Derived<Base>> d_d_base; // should be rejected

  return 0;
}

Answer:

Regarding the assertion "Type must not derive from Derived": inside the class template, the bare name Derived is not the template but the injected-class-name, which refers to the current specialization. In std::is_base_of<Derived, T>::value it therefore means Derived<T>, and that can never be a base of T, because Derived<T> itself derives from T. For Derived<Derived<Base>>, T is Derived<Base> while the bare Derived is Derived<Derived<Base>>, so the assertion asks whether Derived<Base> derives from Derived<Derived<Base>>, which it does not.
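The explanation above can be checked with a small compile-time sketch (assuming C++11 or later). It reuses the question's Base and Derived; the std::is_same assertion is added for illustration only and confirms that the bare name Derived inside the template is the same type as Derived<T>, which is why the question's second assertion can never fire.

```cpp
#include <type_traits>

struct Base {};

template <typename T>
struct Derived : T {
  // Inside the class template, the bare name `Derived` is the
  // injected-class-name: it denotes the current specialization Derived<T>.
  static_assert(std::is_same<Derived, Derived<T>>::value,
                "bare Derived names the current specialization");

  // So the question's check really asks "does T derive from Derived<T>?".
  // Derived<T> already derives from T, so the answer is always no and
  // this assertion never fires.
  static_assert(!std::is_base_of<Derived, T>::value,
                "never fires");
};

// Both instantiations compile. For the second one, T is Derived<Base>
// while the bare `Derived` inside the class is Derived<Derived<Base>>.
Derived<Base> d_base;
Derived<Derived<Base>> d_d_base;
```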

You could add a type trait that checks whether T is (or derives from) some specialization Derived<something>:

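#include <type_traits>
#include <utility>  // std::declval

// true if T is, or derives from, some specialization F<U...>
template <template<class...> class F, class T>
struct is_from_template {
    // Fallback chosen when T does not bind to any const F<U...>&
    static std::false_type test(...);

    template <class... U>
    static std::true_type test(const F<U...>&);

    // Only the return type is inspected; test is never actually called
    static constexpr bool value = decltype(test(std::declval<T>()))::value;
};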
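A few compile-time checks show what the trait reports. This is a sketch assuming C++11; the trait is repeated so the snippet stands alone, Derived is reduced to a plain template without the assertions, and MoreDerived is a hypothetical type added for illustration. Note that a class that merely inherits from a Derived specialization also matches, because deduction against const F<U...>& allows a derived-to-base conversion.

```cpp
#include <type_traits>
#include <utility>

// The trait from the answer, repeated so this snippet compiles on its own.
template <template <class...> class F, class T>
struct is_from_template {
  static std::false_type test(...);

  template <class... U>
  static std::true_type test(const F<U...>&);

  static constexpr bool value = decltype(test(std::declval<T>()))::value;
};

struct Base {};

// Simplified Derived without the assertions, for testing the trait only.
template <typename T>
struct Derived : T {};

// Hypothetical type: inherits from a Derived specialization.
struct MoreDerived : Derived<Base> {};

static_assert(is_from_template<Derived, Derived<Base>>::value,
              "a direct specialization matches");
static_assert(is_from_template<Derived, MoreDerived>::value,
              "a class derived from a specialization also matches");
static_assert(!is_from_template<Derived, Base>::value,
              "Base is not a Derived<...>");
static_assert(!is_from_template<Derived, int>::value,
              "non-class types never match");
```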

Using this trait in the second assertion rejects Derived<Derived<Base>>:

// Requires <type_traits> and is_from_template from above.
struct Base {};

template <typename T>
struct Derived : T {
    static_assert(std::is_base_of<Base, T>::value,
                  "Type must derive from Base");
    static_assert(!is_from_template<Derived, T>::value,
                  "Type must not derive from Derived<>");
};

int main() {
    Derived<Base> d_base;                // OK
    // Derived<Derived<Base>> d_d_base;  // error
}
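Because is_from_template also matches classes that merely inherit from a Derived specialization (deduction against const F<U...>& allows a derived-to-base conversion), it rejects those too. If only exact Derived<...> types should be rejected, a partial-specialization trait is a common alternative. The sketch below assumes C++11; the trait name is_specialization_of and the MoreDerived type are hypothetical, chosen for illustration.

```cpp
#include <type_traits>

// Hypothetical alternative trait: true only for exact specializations
// F<U...>, not for classes that inherit from one.
template <template <class...> class F, class T>
struct is_specialization_of : std::false_type {};

template <template <class...> class F, class... U>
struct is_specialization_of<F, F<U...>> : std::true_type {};

struct Base {};

template <typename T>
struct Derived : T {
  static_assert(std::is_base_of<Base, T>::value,
                "Type must derive from Base");
  static_assert(!is_specialization_of<Derived, T>::value,
                "Type must not be a Derived<>");
};

// Hypothetical type: inherits from a Derived specialization.
struct MoreDerived : Derived<Base> {};

static_assert(is_specialization_of<Derived, Derived<Base>>::value,
              "exact specialization matches");
static_assert(!is_specialization_of<Derived, MoreDerived>::value,
              "a subclass of a specialization does not match");

Derived<Base> d_base;        // OK
Derived<MoreDerived> d_more; // accepted here; is_from_template would reject it
// Derived<Derived<Base>> d_d_base;  // error: second static_assert fires
```

Which trait fits depends on whether a subclass of Derived<...> should count as "derived from itself".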