How should I implement a generic FMA/FMAF instruction in software?


FMA is a fused multiply-add instruction. The fmaf (float x, float y, float z) function in glibc maps to the vfmadd213ss instruction. I want to know how this instruction is implemented. My understanding is:

  1. Add the exponents of x and y.
  2. Multiply the mantissas of x and y.
  3. Normalize the result of x * y, but do not round it.
  4. Compare with the exponent of z and shift the mantissa of the operand with the smaller exponent.
  5. Add the mantissas and normalize the result again.
  6. Round (RN).

CodePudding user response:

The current x86-64 architecture implements the so-called FMA3 variants: since a fused multiply-add operation requires three source operands, and the instructions implementing it have only three operands in total, the encoding must specify which of the source operands is also the destination: vfmadd132ss, vfmadd213ss, vfmadd231ss.
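
To make the naming concrete: the digits in the mnemonic state which operands are multiplied and which one is added (operand numbering per the usual Intel convention). The small wrapper below is only a sketch of one way to request the operation through the _mm_fmadd_ss intrinsic (the same intrinsic used by ref_fmaf further down); the compiler picks whichever of the three variants best fits its register allocation.

#include <immintrin.h>

/* Operand-ordering convention of the scalar single-precision FMA3 forms:
   vfmadd132ss xmm1, xmm2, xmm3 : xmm1 = xmm1 * xmm3 + xmm2
   vfmadd213ss xmm1, xmm2, xmm3 : xmm1 = xmm2 * xmm1 + xmm3
   vfmadd231ss xmm1, xmm2, xmm3 : xmm1 = xmm2 * xmm3 + xmm1 */
float fma_via_intrinsic (float a, float b, float c)
{
    float res;
    __m128 r = _mm_fmadd_ss (_mm_set_ss (a), _mm_set_ss (b), _mm_set_ss (c));
    _mm_store_ss (&res, r);
    return res;
}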

In terms of mathematical functionality these instructions are all equivalent: they compute a*b+c with a single rounding, where a, b, and c are IEEE-754 binary32 (single-precision) operands, which in programming languages of the C and C++ families are typically mapped to float.
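
The single rounding is what distinguishes an FMA from a separate multiply and add. The small program below makes the difference visible; the values are chosen so that the product's lowest-order bit is lost when the multiplication is rounded on its own. It assumes round-to-nearest, single-precision expression evaluation (FLT_EVAL_METHOD == 0), and that floating-point contraction is disabled (e.g. -ffp-contract=off), so the compiler does not itself fuse the plain expression.

#include <stdio.h>
#include <math.h>

int main (void)
{
    float a = 1.0f + 0x1.0p-12f;     /* 1 + 2**(-12) */
    float c = -(1.0f + 0x1.0p-11f);  /* -(1 + 2**(-11)) */
    /* a*a = 1 + 2**(-11) + 2**(-24) exactly; the last term is half an ulp
       and disappears when the product is rounded separately */
    printf ("a*a + c     = %a\n", a * a + c);       /* prints 0x0p+0 */
    printf ("fmaf(a,a,c) = %a\n", fmaf (a, a, c));  /* prints 0x1p-24 */
    return 0;
}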

The top-level algorithmic outline provided in the question is correct. The code below demonstrates how one can implement all necessary details under the restrictions that IEEE-754 floating-point exceptions are turned off (i.e. the code provides the masked response prescribed by the IEEE-754 standard), and that subnormal support is turned on (many platforms, including x86-64, also support non-standard "flush-to-zero" and "denormals-are-zero" modes). When there is more than one NaN source operand provided to FMA, any one of them, or a canonical NaN, forms the basis of the result. In the code below I simply matched the behavior of the Xeon W2133 CPU in my workstation; adjustments may be needed for other processors.

The code below is a garden-variety FMA emulation, coded for reasonable performance and reasonable clarity. If a platform offers a CLZ (count leading zeros) or related instruction, it would be helpful to interface with it through intrinsics. Correct rounding depends crucially on proper tracking of the round and sticky bits. In hardware these are typically two actual bits, but for a software emulation it is often convenient to use an entire unsigned integer (rndstk in the code below), whose most significant bit represents the round bit, while all remaining lower-order bits collectively (i.e., ORed together) represent the sticky bit.
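
As a sketch of that bookkeeping (a hypothetical helper written for illustration, not taken from the emulation below): given the discarded low-order part of a significand, left-aligned in a 64-bit word, it can be folded into such a round/sticky word as follows.

#include <stdint.h>

/* 'discarded' holds the bits shifted out of a significand, left-aligned so
   the first discarded bit sits in bit 63. Bit 31 of the result is the round
   bit; any remaining discarded bit is ORed into the low-order sticky bit.
   A result equal to 0x80000000 therefore indicates an exact tie. */
static uint32_t make_round_sticky (uint64_t discarded)
{
    uint32_t round_bit = (uint32_t)(discarded >> 63);
    uint32_t sticky    = ((discarded << 1) != 0);
    return (round_bit << 31) | sticky;
}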

For practical (faster) fmaf() emulation one typically relies on performing intermediate computation in IEEE-754 binary64 (double precision). This gives rise to tricky double-rounding issues, and not all implementations in common open-source libraries work correctly. The best way known from the literature is to use a special rounding mode, rounding to odd, in intermediate computation. See:

Sylvie Boldo and Guillaume Melquiond, "Emulation of a FMA and correctly rounded sums: Proved algorithms using rounding to odd," IEEE Transactions on Computers, Vol. 57, No. 4, April 2008, pp. 462-471.
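
A minimal sketch of that approach, assuming binary64 arithmetic in round-to-nearest and FLT_EVAL_METHOD == 0; the function name is mine and the code is illustrative only, not a substitute for a proven implementation such as the one in the paper:

#include <math.h>
#include <stdint.h>
#include <string.h>

/* fmaf via binary64 with the final addition effectively rounded to odd.
   The product a*b is exact in binary64 (24+24 <= 53 significand bits).
   2Sum recovers the error of the binary64 addition; if that addition was
   inexact and landed on an even significand, step one ulp toward the exact
   sum, which yields the neighbor with an odd significand. The concluding
   conversion to float then cannot suffer from incorrect double rounding. */
float fmaf_via_round_to_odd (float a, float b, float c)
{
    double p = (double)a * (double)b;            /* exact product */
    double s = p + (double)c;                    /* one binary64 rounding */
    double t = s - p;
    double e = (p - (s - t)) + ((double)c - t);  /* 2Sum error term */
    if (isfinite (s) && (e != 0.0)) {
        uint64_t bits;
        memcpy (&bits, &s, sizeof bits);
        if ((bits & 1) == 0) {                   /* force last bit to odd */
            s = nextafter (s, (e > 0.0) ? INFINITY : -INFINITY);
        }
    }
    return (float)s;                             /* final rounding to binary32 */
}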

Rigorously testing an FMA implementation, whether in hardware or software, is a hard problem. Because the search space is huge, simply using massive amounts (say, tens of billions) of random test vectors is going to provide only a "smoke" test, useful for demonstrating that the implementation is not hopelessly broken. Below I am adding pattern-based tests that are capable of exercising a number of corner cases. Nonetheless the code below should be considered only lightly tested.

If you plan to use the FMA emulation in any kind of professional capacity, I highly recommend investing serious time into ensuring functional correctness; there are frankly too many broken FMA emulations out there already. For industrial-strength implementations, hardware vendors employ mechanically-checked mathematical proofs of correct operation. This book by a practitioner with extensive experience provides a good overview of how that works in practice:

David M. Russinoff, Formal Verification of Floating-Point Hardware Design: A Mathematical Approach, Springer 2019.

There are also professional test suites for exercising numerous corner cases, for example the FPgen floating-point test generator from the IBM Research Labs in Haifa. They used to make a collection of single-precision test vectors freely available on their website, but I cannot seem to find them anymore.

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <limits.h>
#include <math.h>

#define PURELY_RANDOM  (0)
#define PATTERN_BASED  (1)
#define TEST_MODE      (PURELY_RANDOM)
#define ROUND_MODE     (roundNearest) 

/* vvvvvvvvvvvvvvvvvvvvvvvvvvv x86-64 specific vvvvvvvvvvvvvvvvvvvvvvvvvvvv */
#include "immintrin.h"

#define roundMinInf  (_MM_ROUND_DOWN)
#define roundPosInf  (_MM_ROUND_UP)
#define roundZero    (_MM_ROUND_TOWARD_ZERO)
#define roundNearest (_MM_ROUND_NEAREST)

#define ftzOff       (_MM_FLUSH_ZERO_OFF)
#define dazOff       (_MM_DENORMALS_ZERO_OFF)

void set_subnormal_support (uint32_t ftz, uint32_t daz)
{
    _MM_SET_FLUSH_ZERO_MODE (ftz);
    _MM_SET_DENORMALS_ZERO_MODE (daz);
}

float ref_fmaf (float a, float b, float c, uint32_t rnd)
{
    __m128 r, s, t, u;
    float res;
    uint32_t old_mxcsr;
    old_mxcsr = _mm_getcsr();
    _MM_SET_ROUNDING_MODE (rnd);
    s = _mm_set_ss (a);
    t = _mm_set_ss (b);
    u = _mm_set_ss (c);
    r = _mm_fmadd_ss (s, t, u);
    _mm_store_ss (&res, r);
    _mm_setcsr (old_mxcsr);
    return res;
}
/* ^^^^^^^^^^^^^^^^^^^^^^^^^^^ x86-64 specific ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ */

/* re-interpret bits of IEEE-754 'binary32' as unsigned 32-bit integer */
uint32_t float_as_uint32 (float a)
{
    uint32_t r;
    memcpy (&r, &a, sizeof r);
    return r;
}
/* re-interpret bits of unsigned 32-bit integer as IEEE-754 'binary32' */
float uint32_as_float (uint32_t a)
{
    float r;
    memcpy (&r, &a, sizeof r);
    return r;
}

/* 32-bit leading zero count. Use platform-specific intrinsic if available */
int clz32 (uint32_t a)
{
    int n = 32;
    if (a >= 0x00010000u) { a >>= 16;  n -= 16; }
    if (a >= 0x00000100u) { a >>=  8;  n -=  8; }
    if (a >= 0x00000010u) { a >>=  4;  n -=  4; }
    if (a >= 0x00000004u) { a >>=  2;  n -=  2; }
    n -= a & ~(a >> 1);
    return n;
}

/* 64-bit leading zero count. Use platform-specific intrinsic if available */
int clz64 (uint64_t a)
{
    uint32_t hi = (uint32_t)(a >> 32);
    uint32_t lo = (uint32_t)(a & 0xffffffff);
    return hi ? clz32 (hi) : (32 + clz32 (lo));
}

/* full product of two 32-bit unsigned integers. May use platform intrinsic */
uint64_t mul_u32_wide (uint32_t a, uint32_t b)
{
    return (uint64_t)a * b;
}

uint32_t fmaf_kernel (uint32_t x, uint32_t y, uint32_t z, int mode)     
 {     
     const uint32_t FP32_SIGN_BIT = 0x80000000;
     const uint32_t FP32_POS_ZERO = 0x00000000;
     const uint32_t FP32_NEG_ZERO = 0x80000000;
     const uint32_t FP32_QNAN_BIT = 0x00400000;
     const uint32_t FP32_INT_BIT  = 0x00800000;
     const uint32_t FP32_EXPO_MASK = 0x7f800000;
     const uint32_t FP32_POS_INFINITY = 0x7f800000;
     const uint32_t FP32_NEG_INFINITY = 0xff800000;
     const uint32_t FP32_POS_MAX_NORMAL = 0x7f7fffff;
     const uint32_t FP32_NEG_MAX_NORMAL = 0xff7fffff;
     const uint32_t FP32_QNAN_INDEFINITE = 0xffc00000;
     const uint32_t FP32_EXPO_BIAS = 127;
     const uint32_t FP32_STORED_MANT_BITS = 23;
     const uint32_t FP32_EXPO_BITS = 8;
     const uint32_t FP32_MAX_NORM_EXPO_M1 = 254 - 1;
     uint64_t mant_p, templl;
     uint32_t mant_x, mant_y, mant_z, mant_r;
     uint32_t expo_x, expo_y, expo_z, expo_r, expo_p;
     uint32_t sign_z, sign_p, sign_r;
     uint32_t r, shift, lz, rndstk, z_zer, temp;
      
     expo_x = ((x & FP32_EXPO_MASK) >> FP32_STORED_MANT_BITS) - 1;
     expo_y = ((y & FP32_EXPO_MASK) >> FP32_STORED_MANT_BITS) - 1;
     expo_z = ((z & FP32_EXPO_MASK) >> FP32_STORED_MANT_BITS) - 1;
     z_zer = (z << 1) == 0x00000000;
      
     if (!((expo_x <= FP32_MAX_NORM_EXPO_M1) &&
           (expo_y <= FP32_MAX_NORM_EXPO_M1) &&
           (expo_z <= FP32_MAX_NORM_EXPO_M1))) {
         uint32_t x_nan = (x << 1) >  0xff000000;
         uint32_t y_nan = (y << 1) >  0xff000000;
         uint32_t z_nan = (z << 1) >  0xff000000;
         uint32_t x_inf = (x << 1) == 0xff000000;
         uint32_t y_inf = (y << 1) == 0xff000000;
         uint32_t z_inf = (z << 1) == 0xff000000;
         uint32_t x_zer = (x << 1) == 0x00000000;
         uint32_t y_zer = (y << 1) == 0x00000000;
         
         /* pass-through quietened NaN arguments */
         if (y_nan) {
             return y | FP32_QNAN_BIT;
         }
         if (x_nan) {
             return x | FP32_QNAN_BIT;
         }
         if (z_nan) {
             return z | FP32_QNAN_BIT;
         }
         /* invalid operations, based on zeros and infinities */
         if (((x_zer && y_inf) || (y_zer && x_inf)) ||
             (z_inf && (x_inf || y_inf) && ((int32_t)(x ^ y ^ z) < 0))) {
             return FP32_QNAN_INDEFINITE;
         }
         /* infinity results */
         if (x_inf) {
             return x ^ (y & FP32_SIGN_BIT);
         }
         if (y_inf) {
             return y ^ (x & FP32_SIGN_BIT);
         }
         if (z_inf) {
             return z;
         }
         /* results of negative zero */
         if ((z == FP32_NEG_ZERO) &&
             (x_zer || y_zer) && ((int32_t)(x ^ y) < 0)) {
             return z;
         }
         /* zero results */
         if (z_zer && (x_zer || y_zer)) {
             return ((mode == roundMinInf) ?
                     ((x ^ y ^ z) & FP32_SIGN_BIT) : (z & ~FP32_SIGN_BIT));
         }
         /* product x*y is zero: pass-through z */
         if (x_zer || y_zer) {
             return z;
         }
         /* normalize x if subnormal */
         if (expo_x == (uint32_t)-1) {    
             temp = x << FP32_EXPO_BITS;
             lz = clz32 (temp);
             temp = temp << lz;
             expo_x = expo_x - lz + 1;
             x = (temp >> FP32_EXPO_BITS) | (x & FP32_SIGN_BIT);
         }
         /* normalize y if subnormal */
         if (expo_y == (uint32_t)-1) {
             temp = y << FP32_EXPO_BITS;
             lz = clz32 (temp);
             temp = temp << lz;
             expo_y = expo_y - lz + 1;
             y = (temp >> FP32_EXPO_BITS) | (y & FP32_SIGN_BIT);
         }
         /* normalize z if subnormal */
         if ((expo_z == (uint32_t)-1) && (!z_zer)) {
             temp = z << FP32_EXPO_BITS;
             lz = clz32 (temp);
             temp = temp << lz;
             expo_z = expo_z - lz + 1;
             z = (temp >> FP32_EXPO_BITS) | (z & FP32_SIGN_BIT);
         }
     }

     /* multiply x * y */
     expo_p = expo_x + expo_y - FP32_EXPO_BIAS + 2;
     sign_p = (x ^ y) & FP32_SIGN_BIT;
     mant_x = (x & 0x00ffffff) | FP32_INT_BIT;
     mant_y = (y << 8) | (FP32_INT_BIT << 8);
     mant_p = mul_u32_wide (mant_x, mant_y);
     
     /* normalize product x*y */
     if (!(mant_p & ((uint64_t)FP32_INT_BIT << 32))) {
         mant_p = mant_p << 1;
         expo_p--;
     }

     /* add z to product x*y */
     if (z_zer) {
         expo_r = expo_p;
         sign_r = sign_p;
         mant_r = (uint32_t)(mant_p >> 32);
         rndstk = (uint32_t)(mant_p);
     } else {
         sign_z = z & FP32_SIGN_BIT;
         mant_z = (z & 0x00ffffff) | FP32_INT_BIT;
         uint64_t large, small, mant_z_ext = (uint64_t)mant_z << 32;
         /* sort summands by magnitude of significands */
         if (((int)expo_p > (int)expo_z) ||
             ((expo_p == expo_z) && (mant_p > mant_z_ext))) {
             expo_r = expo_p;
             sign_r = sign_p;
             large = mant_p;
             small = mant_z_ext;
             shift = expo_p - expo_z;
         } else {
             expo_r = expo_z;
             sign_r = sign_z;
             large = mant_z_ext;
             small = mant_p;
             shift = expo_z - expo_p;
         }
         /* denormalize small */
         if (shift == 0) {
             rndstk = 0;
         } else if (shift > 63) {
             rndstk = 1; // only sticky
             small = 0;
         } else {
             templl = small << (64 - shift);
             rndstk = (uint32_t)(templl >> 32) | (((uint32_t)templl) ? 1 : 0);
             small = small >> shift;
         }
         /* add or subtract significands */
         if (sign_p != sign_z) {
             large = large - small - (rndstk ? 1 : 0);
             /* complete cancelation: return 0 */
             if (large == 0) {
                 return (mode == roundMinInf) ? FP32_NEG_ZERO : FP32_POS_ZERO;
             }
             /* normalize mantissa if necessary */
             if (!(large & ((uint64_t)FP32_INT_BIT << 32))) {
                 lz = clz64 (large);
                 shift = lz - 8;
                 large = large << shift;
                 expo_r = expo_r - shift;
             }
         } else {
             large = large + small;
             /* normalize mantissa if necessary */
             if (large & 0x0100000000000000ULL) {
                 templl = large << 63;
                 rndstk = (uint32_t)(templl >> 32) | (rndstk ? 1 : 0);
                 large = large >> 1;
             expo_r++;
             }
         }
         mant_r = (uint32_t)(large >> 32);
         rndstk = (uint32_t)(large) | (rndstk ? 1 : 0);
     }

     /* round result */
     if (expo_r <= FP32_MAX_NORM_EXPO_M1) { // normal
         if (mode == roundNearest) {
             mant_r += (rndstk == 0x80000000) ? (mant_r & 1) : (rndstk >> 31);
         } else if (mode == roundPosInf) {
             mant_r += rndstk && !sign_r;
         } else if (mode == roundMinInf) {
             mant_r += rndstk && sign_r;
         } else { // mode == roundZero
         }
         r = sign_r + mant_r + (expo_r << 23);
         return r;
     } else if ((int32_t)expo_r >= 0) { // overflow: largest normal or infinity
         if (mode == roundNearest) {
             r = sign_r | FP32_POS_INFINITY;
         } else if (mode == roundZero) {
             r = sign_r | FP32_POS_MAX_NORMAL;
         } else if (mode == roundPosInf) {
             r = sign_r ? FP32_NEG_MAX_NORMAL : FP32_POS_INFINITY;
         } else { // (mode == roundMinInf)
             r = sign_r ? FP32_NEG_INFINITY : FP32_POS_MAX_NORMAL;
         }
         return r;
     } else { /* underflow: smallest normal, subnormal, or zero */
         shift = 0 - expo_r;
         rndstk = (shift > 25) ? 1 : ((mant_r << (32 - shift)) | (rndstk ? 1 : 0));
         mant_r = (shift > 25) ? 0 : (mant_r >> shift);
         if (mode == roundNearest) {
             mant_r += ((rndstk == 0x80000000) ? (mant_r & 1) : (rndstk >> 31));
         } else if (mode == roundPosInf) {
             mant_r += rndstk && !sign_r;
         } else if (mode == roundMinInf) {
             mant_r += rndstk && sign_r;
         } else { // mode == roundZero
         }
         r = sign_r + mant_r;
     }

     return r;
}     

float my_fmaf (float a, float b, float c, uint32_t rnd)
{
    return uint32_as_float (fmaf_kernel (float_as_uint32 (a),
                                         float_as_uint32 (b),
                                         float_as_uint32 (c),
                                         rnd));
}   

uint32_t v[8192];
// George Marsaglia's KISS PRNG, period 2**123. Newsgroup sci.math, 21 Jan 1999
// Bug fix: Greg Rose, "KISS: A Bit Too Simple" http://eprint.iacr.org/2011/007
static uint32_t kiss_z=362436069, kiss_w=521288629;
static uint32_t kiss_jsr=123456789, kiss_jcong=380116160;
#define znew (kiss_z=36969*(kiss_z&65535)+(kiss_z>>16))
#define wnew (kiss_w=18000*(kiss_w&65535)+(kiss_w>>16))
#define MWC  ((znew<<16)+wnew)
#define SHR3 (kiss_jsr^=(kiss_jsr<<13),kiss_jsr^=(kiss_jsr>>17), \
              kiss_jsr^=(kiss_jsr<<5))
#define CONG (kiss_jcong=69069*kiss_jcong+1234567)
#define KISS ((MWC^CONG)+SHR3)

int main (void)
{
    const uint32_t rnd = ROUND_MODE;
    unsigned long long count = 0;
    float a, b, c, res, ref;
    const uint32_t nbrBits = sizeof (uint32_t) * CHAR_BIT;
    uint32_t i, j, patterns, idx = 0;
    uint32_t ai, bi, ci, resi, refi;

    /* pattern class 1: 2**i */
    for (i = 0; i < nbrBits; i++) {
        v [idx] = ((uint32_t)1 << i);
        idx++;
    }
    /* pattern class 2: 2**i-1 */
    for (i = 0; i < nbrBits; i++) {
        v [idx] = (((uint32_t)1 << i) - 1);
        idx++;
    }
    /* pattern class 3: 2**i+1 */
    for (i = 0; i < nbrBits; i++) {
        v [idx] = (((uint32_t)1 << i) + 1);
        idx++;
    }
    /* pattern class 4: 2**i + 2**j */
    for (i = 0; i < nbrBits; i++) {
        for (j = 0; j < nbrBits; j++) {
            v [idx] = (((uint32_t)1 << i) + ((uint32_t)1 << j));
            idx++;
        }
    }
    /* pattern class 5: 2**i - 2**j */
    for (i = 0; i < nbrBits; i++) {
        for (j = 0; j < nbrBits; j++) {
            v [idx] = (((uint32_t)1 << i) - ((uint32_t)1 << j));
            idx++;
        }
    }
    /* pattern class 6: MAX_UINT/(2**i+1) rep. blocks of i zeros and i ones */
    for (i = 0; i < nbrBits; i++) {
        v [idx] = ((~(uint32_t)0) / (((uint32_t)1 << i) + 1));
        idx++;
    }
    patterns = idx;
    /* pattern class 7: one's complement of pattern classes 1 through 6 */
    for (i = 0; i < patterns; i++) {
        v [idx] = ~v [i];
        idx++;
    }
    /* pattern class 8: two's complement of pattern classes 1 through 6 */
    for (i = 0; i < patterns; i++) {
        v [idx] = ~v [i] + 1;
        idx++;
    }
    patterns = idx;

    printf ("testing single-precision FMA\n");
    printf ("rounding mode: ");
    if (rnd == roundZero) {
        printf ("toward zero (truncate)\n");
    } else if (rnd == roundNearest) {
        printf ("round to nearest, ties to even\n");
    } else if (rnd == roundPosInf) {
        printf ("round up (toward positive infinity)\n");
    } else if (rnd == roundMinInf) {
        printf ("round down (toward negative infinity)\n");
    } else {
        printf ("unsupported\n");
        return EXIT_FAILURE;
    }

#if TEST_MODE == PURELY_RANDOM
    printf ("using purely random test vectors\n");
#elif TEST_MODE == PATTERN_BASED
    printf ("using pattern-based test vectors\n");
    printf ("#patterns = %u\n", patterns);
#endif // TEST_MODE

    /* make sure subnormal support is turned on */
    set_subnormal_support (ftzOff, dazOff);

    do {
#if TEST_MODE == PURELY_RANDOM
        ai = KISS;
        bi = KISS;
        ci = KISS;
#elif TEST_MODE == PATTERN_BASED
        ai = KISS;
        bi = KISS;
        ci = KISS;
        ai = ((v[ai%patterns] & 0x7fffff) | (KISS & ~0x7fffff));
        bi = ((v[bi%patterns] & 0x7fffff) | (KISS & ~0x7fffff));
        ci = ((v[ci%patterns] & 0x7fffff) | (KISS & ~0x7fffff));
#endif // TEST_MODE
        a = uint32_as_float (ai);
        b = uint32_as_float (bi);
        c = uint32_as_float (ci);
        res = my_fmaf (a, b, c, rnd);
        ref = ref_fmaf (a, b, c, rnd);
        resi = float_as_uint32 (res);
        refi = float_as_uint32 (ref);
        if (!(resi == refi)) {
            printf ("!!!! error @ a=%08x (% 15.8e)  b=%08x (% 15.8e)  c=%08x (% 15.8e)  res = %08x (% 15.8e)  ref = %08x (% 15.8e)\n",
                    ai, a, bi, b, ci, c, resi, res, refi, ref);
            return EXIT_FAILURE;
        }
        count++;
        if (!(count & 0xffffff)) printf ("\r%llu", count);
    } while (1);
    return EXIT_SUCCESS;
}