Difference between template specialization and SFINAE with std::enable_if?

Time:09-14

If I have a template function in C++ and want it to behave differently for a specific template argument, I can use a template specialization:

#include <iostream>
#include <type_traits>

template<typename T> void myFunction(T&& input) {
    std::cout << "The parameter is " << input << '\n';
}

template<> void myFunction<int>(int&& input) {
    std::cout << "Double of the parameter is " << 2 * input << '\n';
}

int main() {
    myFunction(24.2);
    myFunction(24);
    myFunction("Hello");
}

Using std::enable_if from <type_traits> makes it possible to enable similar behaviour for specific categories of types. With any other type, substitution fails for every overload, so no candidate remains and the call does not compile.

#include <iostream>
#include <type_traits>

template<typename T, std::enable_if_t<std::is_integral_v<T>>* = nullptr > 
void myFunction(T&& input) {
    std::cout << "Double of the parameter is " << 2 * input << '\n';
}

template<typename T, std::enable_if_t<std::is_floating_point_v<T>>* = nullptr > 
void myFunction(T&& input) {
    std::cout << "The parameter is " << input << '\n';
}

int main() {
    myFunction(24.2);
    myFunction(24);
    // myFunction("Hello"); // error: no matching function (every overload is discarded by SFINAE)
}

My question is: what is really gained by using the fancier std::enable_if approach in place of a standard template specialization? Are they really alternatives, or is std::enable_if exclusively for template metaprogramming?

CodePudding user response:

The two approaches take effect at different stages. The first stage is overload resolution. Explicit specializations are not considered at this stage; only the declaration of the primary (generic) function template takes part.

The SFINAE approach, on the other hand, acts at exactly this stage: overloads whose substitution fails are silently dropped, leaving only the candidate whose enable_if condition holds.

The second stage is template instantiation, where the body of the function is generated. Here the compiler uses the body provided by the most specialized matching template.

CodePudding user response:

With std::enable_if<>, you can write arbitrary conditions. One example is selecting an implementation based on the size of the object; others include whether a field or member function exists, or whether the result of a static constexpr function is what is expected. Specialization is limited to matching on the type.

CodePudding user response:

tl;dr

SFINAE also helps turn misuse of template entities into more readable errors.

A bit longer

Not using SFINAE and relying only on specializations means, among other things, that you leave an unconstrained primary template available for overload resolution. This can result in very lengthy, hard-to-read errors if the template is fed types for which the implementation doesn't make sense.

Indeed, the template will always be instantiated thanks to template type deduction, but if the type turns out to be unsuitable for the implementation of that function, this is discovered only later, maybe deep inside several nested function calls.

As noted in another answer, SFINAE lets you discard some overloads ("SFINAE them out", as some say) while the candidate set for overload resolution is being built, hopefully leaving far fewer overloads standing, to the benefit of the programmer, who will read shorter errors.


Example

This is a not-too-far-from-real-life example of how template errors can be long, and how SFINAE can make them shorter.

In the toy example below, I've defined a template function for inserting an element in a collection. A few details to better understand why I wrote it the way I did:

  • Inserting an element into a std::unordered_set<T, whatever> requires that operator== is defined for two arguments of type T; as you can see from the commented-out friend function, I've intentionally not defined it for the type A, so as to make the code fail at compile time (if you uncomment that line, the code compiles and runs smoothly).

    • You can disregard the silly hash function; it's there just to make the code compile in case you uncomment the friend definition.
  • I've defined an Equatable trait to check whether operator== is defined for two arguments of a given type T, i.e. whether the type is suitable as the element type of an unordered_set; this gives a predicate that can be applied to a type as the condition for an enable_if.

  • myInsert is the function under discussion, an unconstrained function template.

    • You can also imagine you've written a specialization for another class B on which operator== is defined, but the point is that among all the other types you can throw at it, there's A, on which operator== is not defined, and this is what causes the error.

#include <cmath>
#include <type_traits>
#include <unordered_set>

// trait to detect whether operator== is defined for two arguments of type T
template<typename T, typename = void>
struct Equatable : std::false_type {};
template<typename T>
struct Equatable<T, std::void_t<decltype(T{} == T{})>> : std::true_type {};
template<typename T>
constexpr bool equatable_v = Equatable<T>::value;

// The problematic class which has no operator== defined on it
struct A {
    void empty(int) {
    }
    //friend bool operator==(A const&, A const&) { return true; }
};

// Our template function not using SFINAE
template<typename Coll, typename Elem>
void myInsert(Coll& coll, Elem&& elem) {
    coll.insert(std::forward<Elem>(elem)); // Elem&& is a forwarding reference: forward rather than move
}

int main() {
    auto hash = [](A const&){ return 1; };
    std::unordered_set<A, decltype(hash)> s;
    myInsert(s, A{});
}

Try to compile the code above. You'll most likely get an error which is long with respect to compile-time errors of non-templated code (but believe me, it's short! I've actually included <cmath> for the purpose of triggering more instantiation attempts to make the error a bit longer). Here's what I get:

In file included from /usr/include/c++/12.2.0/unordered_set:44,
                 from example.cpp:3:
/usr/include/c++/12.2.0/bits/stl_function.h: In instantiation of ‘constexpr bool std::equal_to<_Tp>::operator()(const _Tp&, const _Tp&) const [with _Tp = A]’:
/usr/include/c++/12.2.0/bits/hashtable_policy.h:1701:18:   required from ‘bool std::__detail::_Hashtable_base<_Key, _Value, _ExtractKey, _Equal, _Hash, _RangeHash, _Unused, _Traits>::_M_key_equals_tr(const _Kt&, const std::__detail::_Hash_node_value<_Value, typename _Traits::__hash_cached::value>&) const [with _Kt = A; _Key = A; _Value = A; _ExtractKey = std::__detail::_Identity; _Equal = std::equal_to<A>; _Hash = main()::<lambda(const A&)>; _RangeHash = std::__detail::_Mod_range_hashing; _Unused = std::__detail::_Default_ranged_hash; _Traits = std::__detail::_Hashtable_traits<true, true, true>; typename _Traits::__hash_cached = std::__detail::_Hashtable_traits<true, true, true>::__hash_cached]’
/usr/include/c++/12.2.0/bits/hashtable.h:2237:32:   required from ‘std::pair<typename std::__detail::_Insert<_Key, _Value, _Alloc, _ExtractKey, _Equal, _Hash, _RangeHash, _Unused, _RehashPolicy, _Traits>::iterator, bool> std::_Hashtable<_Key, _Value, _Alloc, _ExtractKey, _Equal, _Hash, _RangeHash, _Unused, _RehashPolicy, _Traits>::_M_insert_unique(_Kt&&, _Arg&&, const _NodeGenerator&) [with _Kt = A; _Arg = A; _NodeGenerator = std::__detail::_AllocNode<std::allocator<std::__detail::_Hash_node<A, true> > >; _Key = A; _Value = A; _Alloc = std::allocator<A>; _ExtractKey = std::__detail::_Identity; _Equal = std::equal_to<A>; _Hash = main()::<lambda(const A&)>; _RangeHash = std::__detail::_Mod_range_hashing; _Unused = std::__detail::_Default_ranged_hash; _RehashPolicy = std::__detail::_Prime_rehash_policy; _Traits = std::__detail::_Hashtable_traits<true, true, true>; typename std::__detail::_Insert<_Key, _Value, _Alloc, _ExtractKey, _Equal, _Hash, _RangeHash, _Unused, _RehashPolicy, _Traits>::iterator = std::__detail::_Insert_base<A, A, std::allocator<A>, std::__detail::_Identity, std::equal_to<A>, main()::<lambda(const A&)>, std::__detail::_Mod_range_hashing, std::__detail::_Default_ranged_hash, std::__detail::_Prime_rehash_policy, std::__detail::_Hashtable_traits<true, true, true> >::iterator; typename _Traits::__constant_iterators = std::__detail::_Hashtable_traits<true, true, true>::__constant_iterators]’
/usr/include/c++/12.2.0/bits/hashtable.h:906:27:   required from ‘std::pair<typename std::__detail::_Insert<_Key, _Value, _Alloc, _ExtractKey, _Equal, _Hash, _RangeHash, _Unused, _RehashPolicy, _Traits>::iterator, bool> std::_Hashtable<_Key, _Value, _Alloc, _ExtractKey, _Equal, _Hash, _RangeHash, _Unused, _RehashPolicy, _Traits>::_M_insert(_Arg&&, const _NodeGenerator&, std::true_type) [with _Arg = A; _NodeGenerator = std::__detail::_AllocNode<std::allocator<std::__detail::_Hash_node<A, true> > >; _Key = A; _Value = A; _Alloc = std::allocator<A>; _ExtractKey = std::__detail::_Identity; _Equal = std::equal_to<A>; _Hash = main()::<lambda(const A&)>; _RangeHash = std::__detail::_Mod_range_hashing; _Unused = std::__detail::_Default_ranged_hash; _RehashPolicy = std::__detail::_Prime_rehash_policy; _Traits = std::__detail::_Hashtable_traits<true, true, true>; typename std::__detail::_Insert<_Key, _Value, _Alloc, _ExtractKey, _Equal, _Hash, _RangeHash, _Unused, _RehashPolicy, _Traits>::iterator = std::__detail::_Insert_base<A, A, std::allocator<A>, std::__detail::_Identity, std::equal_to<A>, main()::<lambda(const A&)>, std::__detail::_Mod_range_hashing, std::__detail::_Default_ranged_hash, std::__detail::_Prime_rehash_policy, std::__detail::_Hashtable_traits<true, true, true> >::iterator; typename _Traits::__constant_iterators = std::__detail::_Hashtable_traits<true, true, true>::__constant_iterators; std::true_type = std::integral_constant<bool, true>]’
/usr/include/c++/12.2.0/bits/hashtable_policy.h:1035:22:   required from ‘std::__detail::_Insert<_Key, _Value, _Alloc, _ExtractKey, _Equal, _Hash, _RangeHash, _Unused, _RehashPolicy, _Traits, true>::__ireturn_type std::__detail::_Insert<_Key, _Value, _Alloc, _ExtractKey, _Equal, _Hash, _RangeHash, _Unused, _RehashPolicy, _Traits, true>::insert(value_type&&) [with _Key = A; _Value = A; _Alloc = std::allocator<A>; _ExtractKey = std::__detail::_Identity; _Equal = std::equal_to<A>; _Hash = main()::<lambda(const A&)>; _RangeHash = std::__detail::_Mod_range_hashing; _Unused = std::__detail::_Default_ranged_hash; _RehashPolicy = std::__detail::_Prime_rehash_policy; _Traits = std::__detail::_Hashtable_traits<true, true, true>; __ireturn_type = std::__detail::_Insert<A, A, std::allocator<A>, std::__detail::_Identity, std::equal_to<A>, main()::<lambda(const A&)>, std::__detail::_Mod_range_hashing, std::__detail::_Default_ranged_hash, std::__detail::_Prime_rehash_policy, std::__detail::_Hashtable_traits<true, true, true>, true>::__ireturn_type; value_type = A]’
/usr/include/c++/12.2.0/bits/unordered_set.h:426:27:   required from ‘std::pair<typename std::_Hashtable<_Value, _Value, _Alloc, std::__detail::_Identity, _Pred, _Hash, std::__detail::_Mod_range_hashing, std::__detail::_Default_ranged_hash, std::__detail::_Prime_rehash_policy, std::__detail::_Hashtable_traits<std::__not_<std::__and_<std::__is_fast_hash<_Hash>, std::__is_nothrow_invocable<const _Hash&, const _Tp&> > >::value, true, true> >::iterator, bool> std::unordered_set<_Value, _Hash, _Pred, _Alloc>::insert(value_type&&) [with _Value = A; _Hash = main()::<lambda(const A&)>; _Pred = std::equal_to<A>; _Alloc = std::allocator<A>; typename std::_Hashtable<_Value, _Value, _Alloc, std::__detail::_Identity, _Pred, _Hash, std::__detail::_Mod_range_hashing, std::__detail::_Default_ranged_hash, std::__detail::_Prime_rehash_policy, std::__detail::_Hashtable_traits<std::__not_<std::__and_<std::__is_fast_hash<_Hash>, std::__is_nothrow_invocable<const _Hash&, const _Tp&> > >::value, true, true> >::iterator = std::__detail::_Insert_base<A, A, std::allocator<A>, std::__detail::_Identity, std::equal_to<A>, main()::<lambda(const A&)>, std::__detail::_Mod_range_hashing, std::__detail::_Default_ranged_hash, std::__detail::_Prime_rehash_policy, std::__detail::_Hashtable_traits<true, true, true> >::iterator; value_type = A]’
example.cpp:28:16:   required from ‘void myInsert(Coll&, Elem&&) [with Coll = std::unordered_set<A, main()::<lambda(const A&)> >; Elem = A]’
example.cpp:34:13:   required from here
/usr/include/c++/12.2.0/bits/stl_function.h:378:20: error: no match for ‘operator==’ (operand types are ‘const A’ and ‘const A’)
  378 |       { return __x == __y; }
      |                ~~~~^~~~~~
In file included from /usr/include/c++/12.2.0/bits/stl_algobase.h:67,
                 from /usr/include/c++/12.2.0/bits/specfun.h:45,
                 from /usr/include/c++/12.2.0/cmath:1935,
                 from example.cpp:1:
/usr/include/c++/12.2.0/bits/stl_iterator.h:530:5: note: candidate: ‘template<class _IteratorL, class _IteratorR> constexpr bool std::operator==(const reverse_iterator<_IteratorL>&, const reverse_iterator<_IteratorR>&) requires requires{{std::operator==::__x->base() == std::operator==::__y->base()} -> decltype(auto) [requires std::convertible_to<<placeholder>, bool>];}’ (reversed)
  530 |     operator==(const reverse_iterator<_IteratorL>& __x,
      |     ^~~~~~~~
/usr/include/c++/12.2.0/bits/stl_iterator.h:530:5: note:   template argument deduction/substitution failed:
/usr/include/c++/12.2.0/bits/stl_function.h:378:20: note:   ‘const A’ is not derived from ‘const std::reverse_iterator<_IteratorL>’
  378 |       { return __x == __y; }
      |                ~~~~^~~~~~
/usr/include/c++/12.2.0/bits/stl_iterator.h:1656:5: note: candidate: ‘template<class _IteratorL, class _IteratorR> constexpr bool std::operator==(const move_iterator<_IteratorL>&, const move_iterator<_IteratorR>&) requires requires{{std::operator==::__x->base() == std::operator==::__y->base()} -> decltype(auto) [requires std::convertible_to<<placeholder>, bool>];}’ (reversed)
 1656 |     operator==(const move_iterator<_IteratorL>& __x,
      |     ^~~~~~~~
/usr/include/c++/12.2.0/bits/stl_iterator.h:1656:5: note:   template argument deduction/substitution failed:
/usr/include/c++/12.2.0/bits/stl_function.h:378:20: note:   ‘const A’ is not derived from ‘const std::move_iterator<_IteratorL>’
  378 |       { return __x == __y; }
      |                ~~~~^~~~~~
In file included from /usr/include/c++/12.2.0/unordered_set:40:
/usr/include/c++/12.2.0/bits/allocator.h:219:5: note: candidate: ‘template<class _T1, class _T2> constexpr bool std::operator==(const allocator<_Up>&, const allocator<_T2>&)’ (reversed)
  219 |     operator==(const allocator<_T1>&, const allocator<_T2>&)
      |     ^~~~~~~~
/usr/include/c++/12.2.0/bits/allocator.h:219:5: note:   template argument deduction/substitution failed:
/usr/include/c++/12.2.0/bits/stl_function.h:378:20: note:   ‘const A’ is not derived from ‘const std::allocator<_Up>’
  378 |       { return __x == __y; }
      |                ~~~~^~~~~~
In file included from /usr/include/c++/12.2.0/bits/stl_algobase.h:64:
/usr/include/c++/12.2.0/bits/stl_pair.h:640:5: note: candidate: ‘template<class _T1, class _T2> constexpr bool std::operator==(const pair<_T1, _T2>&, const pair<_T1, _T2>&)’
  640 |     operator==(const pair<_T1, _T2>& __x, const pair<_T1, _T2>& __y)
      |     ^~~~~~~~
/usr/include/c++/12.2.0/bits/stl_pair.h:640:5: note:   template argument deduction/substitution failed:
/usr/include/c++/12.2.0/bits/stl_function.h:378:20: note:   ‘const A’ is not derived from ‘const std::pair<_T1, _T2>’
  378 |       { return __x == __y; }
      |                ~~~~^~~~~~
/usr/include/c++/12.2.0/bits/stl_iterator.h:589:5: note: candidate: ‘template<class _Iterator> constexpr bool std::operator==(const reverse_iterator<_IteratorL>&, const reverse_iterator<_IteratorL>&) requires requires{{std::operator==::__x->base() == std::operator==::__y->base()} -> decltype(auto) [requires std::convertible_to<<placeholder>, bool>];}’
  589 |     operator==(const reverse_iterator<_Iterator>& __x,
      |     ^~~~~~~~
/usr/include/c++/12.2.0/bits/stl_iterator.h:589:5: note:   template argument deduction/substitution failed:
/usr/include/c++/12.2.0/bits/stl_function.h:378:20: note:   ‘const A’ is not derived from ‘const std::reverse_iterator<_IteratorL>’
  378 |       { return __x == __y; }
      |                ~~~~^~~~~~
/usr/include/c++/12.2.0/bits/stl_iterator.h:1726:5: note: candidate: ‘template<class _Iterator> constexpr bool std::operator==(const move_iterator<_IteratorL>&, const move_iterator<_IteratorL>&)’
 1726 |     operator==(const move_iterator<_Iterator>& __x,
      |     ^~~~~~~~
/usr/include/c++/12.2.0/bits/stl_iterator.h:1726:5: note:   template argument deduction/substitution failed:
/usr/include/c++/12.2.0/bits/stl_function.h:378:20: note:   ‘const A’ is not derived from ‘const std::move_iterator<_IteratorL>’
  378 |       { return __x == __y; }
      |                ~~~~^~~~~~

I challenge you to read it and tell me it's easy to understand what it's telling us.

What's the alternative? Using std::enable_if to enforce the condition that the types fed to myInsert have to satisfy. A simplified and incomplete attempt consists in changing the line

template<typename Coll, typename Elem>

to

template<typename Coll, typename Elem, std::enable_if_t<equatable_v<Elem>, int> = 0>

which makes the error much shorter:

example.cpp: In function ‘int main()’:
example.cpp:34:13: error: no matching function for call to ‘myInsert(std::unordered_set<A, main()::<lambda(const A&)> >&, A)’
   34 |     myInsert(s, A{});
      |     ~~~~~~~~^~~~~~~~
example.cpp:27:6: note: candidate: ‘template<class Coll, class Elem, typename std::enable_if<equatable_v<Elem>, int>::type <anonymous> > void myInsert(Coll&, Elem&&)’
   27 | void myInsert(Coll& coll, Elem&& elem) {
      |      ^~~~~~~~
example.cpp:27:6: note:   template argument deduction/substitution failed:
example.cpp:23:46: error: no type named ‘type’ in ‘struct std::enable_if<false, int>’
   23 | , std::enable_if_t<equatable_v<Elem>, int> = 0>
      |                                              ^