Template specialization with macros

Time:05-24

I'm looking at the following function, reproduced below:

#define FBGEMM_SPECIALIZED_REQUANTIZE(T)                            \
  template <>                                                       \
  FBGEMM_API void Requantize<T>(                                    \
      const int32_t* src,                                           \
      T* dst,                                                       \
      const int64_t len,                                            \
      const RequantizationParams& params,                           \
      int thread_id,                                                \
      int num_threads) {                                            \
    int64_t i_begin, i_end;                                         \
    fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \
    for (int64_t i = i_begin; i < i_end; ++i) {                     \
      dst[i] = Requantize<T>(src[i], params);                       \
    }                                                               \
  }
FBGEMM_SPECIALIZED_REQUANTIZE(uint16_t)
FBGEMM_SPECIALIZED_REQUANTIZE(int32_t)
#undef FBGEMM_SPECIALIZED_REQUANTIZE

It appears to be using a macro to specialize the functions.

I'm wondering: what is the difference between doing that vs. not using macros and just specializing everything as usual in C++?

Answer:

As mentioned in the comments, macros are merely about text replacement (more precisely: token replacement). Macros cannot do anything that more typing could not do as well. Instead of

FBGEMM_SPECIALIZED_REQUANTIZE(uint16_t)
FBGEMM_SPECIALIZED_REQUANTIZE(int32_t)

the author could have spelled out the two specializations without any macro, but that would duplicate the function body. As fabian mentioned, the duplication can in turn be avoided with a helper template:

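template <typename T>
void Helper(const int32_t* src,
            T* dst,
            const int64_t len,
            const RequantizationParams& params,
            int thread_id,
            int num_threads) {
  int64_t i_begin, i_end;
  fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end);
  for (int64_t i = i_begin; i < i_end; ++i) {
    dst[i] = Requantize<T>(src[i], params);
  }
}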

and then

template <>
FBGEMM_API void Requantize<uint16_t>(const int32_t* src,
                                     uint16_t* dst,
                                     const int64_t len,
                                     const RequantizationParams& params,
                                     int thread_id,
                                     int num_threads) {
  Helper<uint16_t>(src, dst, len, params, thread_id, num_threads);
}

The same specialization is needed for int32_t. Note how the argument list alone already leads to a lot of repetition. Macros are usually avoided because they obfuscate code; code duplication is usually avoided because it makes code hard to maintain. It's a trade-off.
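A self-contained sketch of the helper-plus-thin-specializations approach, with no macro. `RequantizationParams`, `RequantizeOne` (standing in for the scalar `Requantize<T>(src[i], params)` call) and the inline partitioning are hypothetical stand-ins, not FBGEMM's real definitions. The loop is written once in `Helper`; what remains repeated is each specialization's parameter list.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical stand-ins (not FBGEMM's real definitions).
struct RequantizationParams {
  float multiplier;
  std::int32_t zero_point;
};

// Scalar requantization: scale, round, add zero point, clamp to T's range.
template <typename T>
T RequantizeOne(std::int32_t src, const RequantizationParams& params) {
  std::int64_t q = std::llround(src * static_cast<double>(params.multiplier)) +
                   params.zero_point;
  q = std::max<std::int64_t>(q, std::numeric_limits<T>::min());
  q = std::min<std::int64_t>(q, std::numeric_limits<T>::max());
  return static_cast<T>(q);
}

// Primary template: declared only, as in the header.
template <typename T>
void Requantize(const std::int32_t* src, T* dst, std::int64_t len,
                const RequantizationParams& params,
                int thread_id = 0, int num_threads = 1);

// All of the shared logic lives here, written once, with no macro.
template <typename T>
void Helper(const std::int32_t* src, T* dst, std::int64_t len,
            const RequantizationParams& params, int thread_id,
            int num_threads) {
  assert(num_threads > 0);
  // Simplified partition: each thread gets one contiguous chunk.
  const std::int64_t chunk = (len + num_threads - 1) / num_threads;
  const std::int64_t begin = std::min<std::int64_t>(len, thread_id * chunk);
  const std::int64_t end = std::min<std::int64_t>(len, begin + chunk);
  for (std::int64_t i = begin; i < end; ++i) {
    dst[i] = RequantizeOne<T>(src[i], params);
  }
}

// Each specialization only forwards; the repeated part is the signature.
template <>
void Requantize<std::uint16_t>(const std::int32_t* src, std::uint16_t* dst,
                               std::int64_t len,
                               const RequantizationParams& params,
                               int thread_id, int num_threads) {
  Helper<std::uint16_t>(src, dst, len, params, thread_id, num_threads);
}

template <>
void Requantize<std::int32_t>(const std::int32_t* src, std::int32_t* dst,
                              std::int64_t len,
                              const RequantizationParams& params,
                              int thread_id, int num_threads) {
  Helper<std::int32_t>(src, dst, len, params, thread_id, num_threads);
}
```

Callers still write `Requantize<std::uint16_t>(src, dst, len, params)`; the default thread arguments come from the primary template's declaration.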

Another way to specialize for two different types at once is SFINAE, but that requires modifying the primary template, which may not be desirable. Concepts are a further option, but they are only available since C++20.
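A minimal sketch of the SFINAE variant, assuming it is acceptable to change the primary template itself: an extra defaulted `std::enable_if` template parameter makes one definition apply to exactly `uint16_t` and `int32_t`. `RequantizationParams`, `RequantizeOne` and `IsGenericRequantType` are hypothetical, and the thread parameters are omitted for brevity.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>

// Hypothetical stand-ins (not FBGEMM's real definitions).
struct RequantizationParams {
  float multiplier;
  std::int32_t zero_point;
};

template <typename T>
T RequantizeOne(std::int32_t src, const RequantizationParams& params) {
  std::int64_t q = std::llround(src * static_cast<double>(params.multiplier)) +
                   params.zero_point;
  q = std::max<std::int64_t>(q, std::numeric_limits<T>::min());
  q = std::min<std::int64_t>(q, std::numeric_limits<T>::max());
  return static_cast<T>(q);
}

// True only for the types that share the generic implementation.
template <typename T>
struct IsGenericRequantType
    : std::integral_constant<bool, std::is_same<T, std::uint16_t>::value ||
                                       std::is_same<T, std::int32_t>::value> {};

// The primary template itself is modified: for any other T, substitution of
// the enable_if parameter fails and the template drops out of overload
// resolution, so one definition covers exactly uint16_t and int32_t.
template <typename T,
          typename std::enable_if<IsGenericRequantType<T>::value, int>::type = 0>
void Requantize(const std::int32_t* src, T* dst, std::int64_t len,
                const RequantizationParams& params) {
  for (std::int64_t i = 0; i < len; ++i) {
    dst[i] = RequantizeOne<T>(src[i], params);
  }
}
```

Because the constraint is part of the declaration, a call with, say, a `float*` destination fails at compile time rather than at link time.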


Anyhow...

I'm wondering what is the difference between doing that vs. no macros and just specializing everything like usual in C++?

The amount of typing.

Answer:

Let's start with the original template:

template <typename T>
FBGEMM_API void Requantize(
    const std::int32_t* src,
    T* dst,
    std::int64_t len,
    const RequantizationParams& params,
    int thread_id = 0,
    int num_threads = 1);

Note that this template is only declared, in a header file; I can't find a definition of the primary template anywhere.

There is also a separate specialization for uint8_t.

So the author intended it to work for only three types: uint8_t, uint16_t and int32_t.

You can see this because the macro FBGEMM_SPECIALIZED_REQUANTIZE is defined, used exactly twice, and then immediately undefined.
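To make the define/use/undefine pattern concrete, here is a self-contained sketch that mirrors the macro. `RequantizationParams`, `RequantizeOne` (standing in for the scalar `Requantize<T>(src[i], params)`) and `Partition1D` (standing in for `fbgemmPartition1D`) are hypothetical simplifications, not FBGEMM's real code. The primary template is only declared; the macro stamps out one explicit specialization per type.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical stand-ins for FBGEMM's types and helpers (not the real API).
struct RequantizationParams {
  float multiplier;
  std::int32_t zero_point;
};

// Scalar requantization: scale, round, add zero point, clamp to T's range.
template <typename T>
T RequantizeOne(std::int32_t src, const RequantizationParams& params) {
  std::int64_t q = std::llround(src * static_cast<double>(params.multiplier)) +
                   params.zero_point;
  q = std::max<std::int64_t>(q, std::numeric_limits<T>::min());
  q = std::min<std::int64_t>(q, std::numeric_limits<T>::max());
  return static_cast<T>(q);
}

// Simplified 1-D partition: thread `thread_id` gets one contiguous chunk.
inline void Partition1D(int thread_id, int num_threads, std::int64_t len,
                        std::int64_t& begin, std::int64_t& end) {
  assert(num_threads > 0);
  const std::int64_t chunk = (len + num_threads - 1) / num_threads;
  begin = std::min<std::int64_t>(len, thread_id * chunk);
  end = std::min<std::int64_t>(len, begin + chunk);
}

// Primary template: declared but never defined, as in the header.
template <typename T>
void Requantize(const std::int32_t* src, T* dst, std::int64_t len,
                const RequantizationParams& params,
                int thread_id = 0, int num_threads = 1);

// The macro pastes one full explicit specialization per type argument.
#define SPECIALIZED_REQUANTIZE(T)                                 \
  template <>                                                     \
  void Requantize<T>(const std::int32_t* src, T* dst,             \
                     std::int64_t len,                            \
                     const RequantizationParams& params,          \
                     int thread_id, int num_threads) {            \
    std::int64_t i_begin, i_end;                                  \
    Partition1D(thread_id, num_threads, len, i_begin, i_end);     \
    for (std::int64_t i = i_begin; i < i_end; ++i) {              \
      dst[i] = RequantizeOne<T>(src[i], params);                  \
    }                                                             \
  }
SPECIALIZED_REQUANTIZE(std::uint16_t)
SPECIALIZED_REQUANTIZE(std::int32_t)
// The macro is needed only for the two lines above.
#undef SPECIALIZED_REQUANTIZE
```

After preprocessing, this is exactly two hand-written specializations; calling `Requantize<float>` would still compile and then fail to link, since no definition exists for it.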

So how does this behave in its current state?

  • If a user of the library calls this template with a pointer to an unsupported type (anything other than uint8_t, uint16_t and int32_t), the code compiles but fails to link: "undefined reference to function template .....".
  • If a user calls this template with a uint8_t pointer, the dedicated uint8_t implementation is used; for uint16_t and int32_t, the macro-generated implementation is used.

Is this good? IMO no: it is better to get a compilation error than a linker error. Note that while new code is under development, the build can keep passing as long as that code is not actually used (for example, before a test is written); only when you add a test or otherwise use the new functionality do you get the linker error.

Can this be done better? YES!

But it depends on the requirements. One way to solve this is to declare plain, old-fashioned overloaded functions in the header file:

FBGEMM_API void Requantize(
    const std::int32_t* src,
    uint8_t* dst,
    std::int64_t len,
    const RequantizationParams& params,
    int thread_id = 0,
    int num_threads = 1);

FBGEMM_API void Requantize(
    const std::int32_t* src,
    uint16_t* dst,
    std::int64_t len,
    const RequantizationParams& params,
    int thread_id = 0,
    int num_threads = 1);

FBGEMM_API void Requantize(
    const std::int32_t* src,
    int32_t* dst,
    std::int64_t len,
    const RequantizationParams& params,
    int thread_id = 0,
    int num_threads = 1);

Then, in the .cpp file, define a helper template and call it from these functions:

// this one has its own version
// (default arguments belong only on the header declarations)
FBGEMM_API void Requantize(
    const std::int32_t* src,
    uint8_t* dst,
    std::int64_t len,
    const RequantizationParams& params,
    int thread_id,
    int num_threads) {
  int64_t i_begin, i_end;
  fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end);
  if (params.target_qparams.precision == 8 && cpuinfo_initialize() &&
      fbgemmHasAvx2Support()) {
    RequantizeAvx2(&src[i_begin], &dst[i_begin], i_end - i_begin, params);
  } else {
    for (int64_t i = i_begin; i < i_end; ++i) {
      dst[i] = Requantize<uint8_t>(src[i], params);
    }
  }
}

namespace detail {
  // A primary template takes no <T> after its name. The distinct name keeps
  // this helper from hiding the scalar Requantize<T>(src[i], params)
  // overload during lookup inside namespace detail.
  template <typename T>
  void RequantizeImpl(
      const int32_t* src,
      T* dst,
      const int64_t len,
      const RequantizationParams& params,
      int thread_id,
      int num_threads) {
    int64_t i_begin, i_end;
    fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end);
    for (int64_t i = i_begin; i < i_end; ++i) {
      dst[i] = Requantize<T>(src[i], params);
    }
  }
} // namespace detail

FBGEMM_API void Requantize(
    const std::int32_t* src,
    uint16_t* dst,
    std::int64_t len,
    const RequantizationParams& params,
    int thread_id,
    int num_threads) {
  detail::RequantizeImpl(src, dst, len, params, thread_id, num_threads);
}

FBGEMM_API void Requantize(
    const std::int32_t* src,
    int32_t* dst,
    std::int64_t len,
    const RequantizationParams& params,
    int thread_id,
    int num_threads) {
  detail::RequantizeImpl(src, dst, len, params, thread_id, num_threads);
}

Now the linker errors are replaced with compilation errors (a destination such as float* simply matches no overload), and macros are not needed.
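To see the compile-time rejection described above in isolation, here is a minimal single-file sketch. `RequantizationParams`, `RequantizeOne`, `detail::RequantizeImpl` and the `CanRequantize` detection trait are hypothetical; the thread parameters and the uint8_t overload are left out for brevity. The trait checks whether a `Requantize(src, T*, len, params)` call would compile at all.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>
#include <utility>

// Hypothetical stand-ins (not FBGEMM's real definitions).
struct RequantizationParams {
  float multiplier;
  std::int32_t zero_point;
};

template <typename T>
T RequantizeOne(std::int32_t src, const RequantizationParams& params) {
  std::int64_t q = std::llround(src * static_cast<double>(params.multiplier)) +
                   params.zero_point;
  q = std::max<std::int64_t>(q, std::numeric_limits<T>::min());
  q = std::min<std::int64_t>(q, std::numeric_limits<T>::max());
  return static_cast<T>(q);
}

namespace detail {
// Shared implementation: a template, but not part of the public interface.
template <typename T>
void RequantizeImpl(const std::int32_t* src, T* dst, std::int64_t len,
                    const RequantizationParams& params) {
  for (std::int64_t i = 0; i < len; ++i) {
    dst[i] = RequantizeOne<T>(src[i], params);
  }
}
}  // namespace detail

// Public interface: one plain (non-template) overload per supported type.
void Requantize(const std::int32_t* src, std::uint16_t* dst, std::int64_t len,
                const RequantizationParams& params) {
  detail::RequantizeImpl(src, dst, len, params);
}

void Requantize(const std::int32_t* src, std::int32_t* dst, std::int64_t len,
                const RequantizationParams& params) {
  detail::RequantizeImpl(src, dst, len, params);
}

// Detection trait: true if Requantize(src, T*, len, params) compiles.
template <typename T, typename = void>
struct CanRequantize : std::false_type {};

template <typename T>
struct CanRequantize<
    T, decltype(Requantize(std::declval<const std::int32_t*>(),
                           std::declval<T*>(), std::int64_t{},
                           std::declval<const RequantizationParams&>()))>
    : std::true_type {};

// An unsupported destination type is rejected at compile time.
static_assert(!CanRequantize<float>::value, "float* matches no overload");
```

Only the plain overloads are visible to callers, so there is no template left for the linker to look for.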


Another way to do it is to keep the template declaration in the header file and define the template in the .cpp file, together with a specialization for uint8_t. The header file stays unchanged, and in the .cpp file:

template <typename T>
FBGEMM_API void Requantize(
    const int32_t* src,
    T* dst,
    const int64_t len,
    const RequantizationParams& params,
    int thread_id,
    int num_threads) {
  int64_t i_begin, i_end;
  fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end);
  for (int64_t i = i_begin; i < i_end; ++i) {
    dst[i] = Requantize<T>(src[i], params);
  }
}

// Default arguments are not allowed on an explicit specialization.
template <>
FBGEMM_API void Requantize<uint8_t>(
    const std::int32_t* src,
    uint8_t* dst,
    std::int64_t len,
    const RequantizationParams& params,
    int thread_id,
    int num_threads) {
  int64_t i_begin, i_end;
  fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end);
  if (params.target_qparams.precision == 8 && cpuinfo_initialize() &&
      fbgemmHasAvx2Support()) {
    RequantizeAvx2(&src[i_begin], &dst[i_begin], i_end - i_begin, params);
  } else {
    for (int64_t i = i_begin; i < i_end; ++i) {
      dst[i] = Requantize<uint8_t>(src[i], params);
    }
  }
}

// An explicit instantiation names the full function type.
template void Requantize<uint16_t>(
    const std::int32_t*, uint16_t*, std::int64_t,
    const RequantizationParams&, int, int);
template void Requantize<int32_t>(
    const std::int32_t*, int32_t*, std::int64_t,
    const RequantizationParams&, int, int);

But I do not know how this interacts with the other overloads of Requantize.
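A single-file sketch of this last variant (generic definition, explicit specialization for uint8_t, explicit instantiation for the other two types), with hypothetical stand-ins `RequantizationParams`, `RequantizeOne` and a `g_uint8_specialization_calls` counter that exists only to make the dispatch observable. Threading is ignored to keep it short.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical stand-ins (not FBGEMM's real definitions).
struct RequantizationParams {
  float multiplier;
  std::int32_t zero_point;
};

template <typename T>
T RequantizeOne(std::int32_t src, const RequantizationParams& params) {
  std::int64_t q = std::llround(src * static_cast<double>(params.multiplier)) +
                   params.zero_point;
  q = std::max<std::int64_t>(q, std::numeric_limits<T>::min());
  q = std::min<std::int64_t>(q, std::numeric_limits<T>::max());
  return static_cast<T>(q);
}

// Hypothetical counter, only to show which definition a call reaches.
int g_uint8_specialization_calls = 0;

// Declaration, as it would appear in the unchanged header.
template <typename T>
void Requantize(const std::int32_t* src, T* dst, std::int64_t len,
                const RequantizationParams& params,
                int thread_id = 0, int num_threads = 1);

// Generic definition (would live in the .cpp file).
template <typename T>
void Requantize(const std::int32_t* src, T* dst, std::int64_t len,
                const RequantizationParams& params, int /*thread_id*/,
                int /*num_threads*/) {
  for (std::int64_t i = 0; i < len; ++i) {
    dst[i] = RequantizeOne<T>(src[i], params);
  }
}

// Explicit specialization for uint8_t; no default arguments may appear here.
template <>
void Requantize<std::uint8_t>(const std::int32_t* src, std::uint8_t* dst,
                              std::int64_t len,
                              const RequantizationParams& params,
                              int /*thread_id*/, int /*num_threads*/) {
  ++g_uint8_specialization_calls;  // a dedicated fast path would go here
  for (std::int64_t i = 0; i < len; ++i) {
    dst[i] = RequantizeOne<std::uint8_t>(src[i], params);
  }
}

// Explicit instantiation definitions: emit code for exactly these types.
template void Requantize<std::uint16_t>(const std::int32_t*, std::uint16_t*,
                                        std::int64_t,
                                        const RequantizationParams&, int, int);
template void Requantize<std::int32_t>(const std::int32_t*, std::int32_t*,
                                       std::int64_t,
                                       const RequantizationParams&, int, int);
```

In a real two-file setup, other translation units see only the header declaration, so an unsupported type such as float would still produce a linker error rather than a compile error.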
