SIMD code for transforming one-letter amino acid code into integer between 0 and 22

Time:01-20

I'm trying to code the following transformation in SIMD C++. Any ideas?

Code is from https://github.com/soedinglab/hh-suite/blob/master/src/hhutil-inl.h#L45-L83

const int ANY=20;       //number representing an X (any amino acid) internally
const int GAP=21;

/////////////////////////////////////////////////////////////////////////////////////
// Transforms the one-letter amino acid code into an integer between 0 and 22
/////////////////////////////////////////////////////////////////////////////////////
inline char aa2i(char c)
{
  //A  R  N  D  C  Q  E  G  H  I  L  K  M  F  P  S  T  W  Y  V
  if (c>='a' && c<='z') c += 'A'-'a';   // fold lower case to upper case
  switch (c)
    {
    case 'A': return 0;
    case 'R': return 1;
    case 'N': return 2;
    case 'D': return 3;
    case 'C': return 4;
    case 'Q': return 5;
    case 'E': return 6;
    case 'G': return 7;
    case 'H': return 8;
    case 'I': return 9;
    case 'L': return 10;
    case 'K': return 11;
    case 'M': return 12;
    case 'F': return 13;
    case 'P': return 14;
    case 'S': return 15;
    case 'T': return 16;
    case 'W': return 17;
    case 'Y': return 18;
    case 'V': return 19;
    case 'X': return ANY;
    case 'J': return ANY;
    case 'O': return ANY;
    case 'U': return 4;  //Selenocysteine -> Cysteine
    case 'B': return 3;  //D (or N)
    case 'Z': return 6;  //E (or Q)
    case '-': return GAP;
    case '.': return GAP;
    case '_': return GAP;
    }
  if (c>=0 && c<=32) return -1; // white space and control characters
  return -2;
}

Answer:

Here is one idea using AVX2. The main idea is to use VPSHUFB in "lookup table" mode: the data is used as the shuffle mask, and the LUT is the first source operand. The rest of the code deals with the special cases.
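To make the "lookup table" mode concrete, here is a scalar model of what `_mm256_shuffle_epi8(table, idx)` computes, based on the documented VPSHUFB semantics. The function name `vpshufb256` is made up for illustration. Note that each 128-bit lane is looked up on its own, which is why the 16-entry LUTs in the code below are written out twice.

```cpp
#include <array>
#include <cstdint>

// Hypothetical scalar model of _mm256_shuffle_epi8(table, idx) (AVX2 VPSHUFB).
// "table" is the LUT (first operand), "idx" is the data used as shuffle mask.
std::array<int8_t, 32> vpshufb256(const std::array<int8_t, 32>& table,
                                  const std::array<uint8_t, 32>& idx)
{
    std::array<int8_t, 32> out{};
    for (int i = 0; i < 32; ++i) {
        // Output byte i can only read from its own 128-bit lane
        const int lane_base = (i / 16) * 16;   // 0 for bytes 0..15, 16 for bytes 16..31
        out[i] = (idx[i] & 0x80) ? 0            // top bit set: result byte is zero
                                 : table[lane_base + (idx[i] & 0x0F)];  // low 4 bits pick the entry
    }
    return out;
}
```

Bits 4..6 of an index are ignored, so an index like 0x13 still selects entry 3; only the top bit forces a zero.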

For range checks, I shift the range up (by adding a constant) so that the top of the range becomes 127, then check whether the result is signed-greater-than the shifted lower bound minus one. Anything above the range wraps around to a negative value, so a single signed compare covers both ends of the range.
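A scalar sketch of this range check, modelling one byte of `_mm256_add_epi8` followed by `_mm256_cmpgt_epi8`. The helpers `wrap8` and `in_range_shifted` are hypothetical names, not part of the answer's code.

```cpp
#include <cstdint>

// Two's-complement wrap to 8 bits, like one byte lane of an AVX2 register
// (guaranteed since C++20, and what mainstream compilers do before that).
static int8_t wrap8(int v) { return static_cast<int8_t>(static_cast<uint8_t>(v)); }

// Tests lo <= c <= hi with one wrapping add and one signed compare.
bool in_range_shifted(int8_t c, int8_t lo, int8_t hi)
{
    const int shift = 127 - hi;                   // moves hi up to 127
    const int8_t shifted = wrap8(c + shift);      // values above hi wrap to negative
    const int8_t bound = wrap8(lo + shift - 1);   // shifted lower bound, minus one for ">"
    return shifted > bound;                       // signed compare, as VPCMPGTB does
}
```

For the range 0..32 this gives shift 95 and bound 94; 33..63 gives 64 and 96; '-'..'.' gives 81 and 125. Those are the constants used in the three compares of the code below.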

I rely on VPSHUFB's property that it writes 0 to a byte if the top bit of that byte's shuffle index is set. First I shift the range down so that its start lands on zero; because the subtraction wraps, anything outside the range is now above the end of the range. Then I shift it up again with an unsigned saturating addition of 0x70, so that every index above the end of the range has its top bit set and the lookup-by-shuffle returns zero there. That means I can OR the results of the two lookups together instead of having to blend them.
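Here is a per-byte scalar model of that subtract / saturating-add / shuffle sequence. The helper names (`shuffle_byte`, `adds_u8`, `lookup16`, `lookup_upper`) are hypothetical; `LUT1` and `LUT2` copy one lane of the tables from the code below, and `ANY`/`GAP` reuse the question's values.

```cpp
#include <cstdint>

constexpr int8_t ANY = 20, GAP = 21;  // same values as in the question

// '@'..'O' and 'P'..'_', one 128-bit lane of the answer's LUT1 / LUT2
const int8_t LUT1[16] = {-2, 0, 3, 4, 3, 6, 13, 7, 8, 9, ANY, 11, 10, 12, 2, ANY};
const int8_t LUT2[16] = {14, 5, 1, 15, 16, 4, 19, 17, ANY, 18, 6, -2, -2, -2, -2, GAP};

static int8_t shuffle_byte(const int8_t lut[16], uint8_t idx)  // one byte of VPSHUFB
{
    return (idx & 0x80) ? 0 : lut[idx & 0x0F];
}

static uint8_t adds_u8(uint8_t a, uint8_t b)  // one byte of _mm256_adds_epu8
{
    const unsigned s = unsigned(a) + b;
    return s > 255 ? 255 : uint8_t(s);
}

// Looks up c in a table covering [base, base + 15]; anything else yields 0.
int8_t lookup16(const int8_t lut[16], uint8_t c, uint8_t base)
{
    const uint8_t r = uint8_t(c - base);  // range start -> 0, values below base wrap high
    // 0..15 -> 0x70..0x7F (top bit clear); 16 and above -> 0x80..0xFF (top bit set) -> 0
    return shuffle_byte(lut, adds_u8(r, 0x70));
}

int8_t lookup_upper(uint8_t c)
{
    // At most one lookup is non-zero, so OR merges them without a blend
    return int8_t(lookup16(LUT1, c, '@') | lookup16(LUT2, c, 'P'));
}
```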

Perhaps some of this can be simplified, especially the handling of the special cases.

BTW, I tested this somewhat (not rigorously), and it seemed to work, at least for non-negative characters. For negative characters, (char)-1 isn't handled the same way as in the scalar code (it comes out as 0 instead of -2). That would be fixable, but I assumed nobody cares about (char)-1.

Leftovers (when N is not a multiple of 32) are handled by one extra iteration over the last 32 bytes, which partially overlaps the previous iteration. If N < 32, that would read memory before the start of the data, so tiny arrays use the scalar fallback instead.
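The overlapping-tail pattern can be sketched generically, with a small block width `W` standing in for the 32-byte vector and a per-byte function standing in for the vector body. `process_overlapping` and `W` are hypothetical names. This sketch also copies the final block of input before any output is written: when `in == out`, the overlapping iteration would otherwise re-read bytes that were already converted.

```cpp
#include <array>
#include <cstddef>

constexpr std::size_t W = 4;  // stands in for the 32-byte AVX2 vector width

template <class F>
void process_overlapping(const char* in, char* out, std::size_t n, F f)
{
    if (n < W) {  // too short for even one block: scalar path
        for (std::size_t i = 0; i < n; ++i) out[i] = f(in[i]);
        return;
    }
    // Copy the last W input bytes before storing anything, so in == out stays correct
    std::array<char, W> tail;
    for (std::size_t k = 0; k < W; ++k) tail[k] = in[n - W + k];

    std::size_t i = 0;
    for (; i + W <= n; i += W)  // full blocks ("one vector" each)
        for (std::size_t k = 0; k < W; ++k) out[i + k] = f(in[i + k]);

    if (i < n)  // leftover: redo the last W bytes, overlapping the previous block
        for (std::size_t k = 0; k < W; ++k) out[n - W + k] = f(tail[k]);
}
```

With a transform that is not idempotent (such as adding 1), an in-place call on 7 bytes still gives each byte exactly one application, because the overlapping block works from the saved copy.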

in and out need no particular alignment and may be equal (in-place conversion), but they should not otherwise partially overlap.

#include <immintrin.h>   // AVX2 intrinsics; compile with -mavx2 or equivalent
#include <cstddef>

// ANY, GAP and the scalar aa2i(char) are the ones from the question.

// Converts one 32-byte vector of characters.
static inline __m256i aa2i_block(__m256i data, __m256i LUT1, __m256i LUT2)
{
    // 0..32 (white space and control characters), becomes -1
    __m256i is_ws_or_control = _mm256_cmpgt_epi8(_mm256_add_epi8(data, _mm256_set1_epi8(95)), _mm256_set1_epi8(94));
    // 33..63 (above white space, below the letters), becomes -2
    __m256i is_punct_or_digit = _mm256_cmpgt_epi8(_mm256_add_epi8(data, _mm256_set1_epi8(64)), _mm256_set1_epi8(96));
    // is '-' or '.'
    __m256i is_dash_or_dot = _mm256_cmpgt_epi8(_mm256_add_epi8(data, _mm256_set1_epi8(81)), _mm256_set1_epi8(125));
    // convert '`'..'z' to '@'..'Z' by flipping bit 5; this also folds '{'..'~' onto the -2
    // entries '['..'^' of LUT2, while DEL and negative bytes are overridden by the blend below
    __m256i is_not_a_to_z_lower =
        _mm256_cmpgt_epi8(_mm256_sub_epi8(data, _mm256_set1_epi8('`' - 128)), _mm256_set1_epi8(26));
    __m256i folded = _mm256_xor_si256(data, _mm256_andnot_si256(is_not_a_to_z_lower, _mm256_set1_epi8(0x20)));
    // range @ .. O: '@' -> index 0, anything outside gets the top bit set and looks up 0
    __m256i rangeA = _mm256_sub_epi8(folded, _mm256_set1_epi8('@'));
    __m256i partA = _mm256_shuffle_epi8(LUT1, _mm256_adds_epu8(rangeA, _mm256_set1_epi8(0x70)));
    // range P .. _
    __m256i rangeB = _mm256_sub_epi8(folded, _mm256_set1_epi8('P'));
    __m256i partB = _mm256_shuffle_epi8(LUT2, _mm256_adds_epu8(rangeB, _mm256_set1_epi8(0x70)));
    // assemble parts: at most one of them is non-zero
    __m256i res = _mm256_or_si256(partA, partB);
    // -2 where data + 1 has its top bit set, i.e. 127 and -128..-2
    res = _mm256_blendv_epi8(res, _mm256_set1_epi8(-2), _mm256_add_epi8(data, _mm256_set1_epi8(1)));
    // -1 for white space/control, -2 for 33..63 (res is 0 there)
    res = _mm256_or_si256(res, _mm256_or_si256(is_ws_or_control, _mm256_and_si256(is_punct_or_digit, _mm256_set1_epi8(-2))));
    return _mm256_blendv_epi8(res, _mm256_set1_epi8(GAP), is_dash_or_dot);
}

void aa2i(const char *in, char *out, std::size_t N)
{
    const __m256i LUT1 = _mm256_setr_epi8(
    //   @  A  B  C  D  E   F  G  H  I    J   K   L   M  N    O
        -2, 0, 3, 4, 3, 6, 13, 7, 8, 9, ANY, 11, 10, 12, 2, ANY,
        -2, 0, 3, 4, 3, 6, 13, 7, 8, 9, ANY, 11, 10, 12, 2, ANY);
    const __m256i LUT2 = _mm256_setr_epi8(
    //   P  Q  R   S   T  U   V   W    X   Y  Z   [   \   ]   ^    _
        14, 5, 1, 15, 16, 4, 19, 17, ANY, 18, 6, -2, -2, -2, -2, GAP,
        14, 5, 1, 15, 16, 4, 19, 17, ANY, 18, 6, -2, -2, -2, -2, GAP);

    if (N < 32)
    {
        // scalar fallback, only used for tiny arrays
        for (std::size_t i = 0; i < N; i++)
            out[i] = aa2i(in[i]);
        return;
    }

    // load the last 32 bytes before anything is stored, so that in == out still works
    // when the final iteration overlaps the previous one
    const __m256i tail = _mm256_loadu_si256((const __m256i*) &in[N - 32]);

    std::size_t i = 0;
    for (; i + 31 < N; i += 32)
    {
        __m256i data = _mm256_loadu_si256((const __m256i*) &in[i]);
        _mm256_storeu_si256((__m256i*) &out[i], aa2i_block(data, LUT1, LUT2));
    }
    // if there is a leftover, do one last iteration over the final 32 bytes,
    // partly overlapping the previous iteration
    if (i < N)
        _mm256_storeu_si256((__m256i*) &out[N - 32], aa2i_block(tail, LUT1, LUT2));
}

Answer:

I don't see a lot of room for SIMD (as such) to help here. I'd start with something like this though:

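// Indexed by c - 'A'; ANY and GAP are the constants from the question
static const char vals[26] = {
    0, 3, 4, 3, 6, 13, 7, 8, 9, ANY, 11, 10, 12,
    2, ANY, 14, 5, 1, 15, 16, 4, 19, 17, ANY, 18, 6
};
inline char aa2i(char c)
{
    if (c >= 'a' && c <= 'z') c -= 'a' - 'A';      // fold lower case to upper case
    if (c >= 'A' && c <= 'Z') return vals[c - 'A'];
    if (c == '-' || c == '.' || c == '_') return GAP;
    return (c >= 0 && c <= 32) ? -1 : -2;           // white space/control vs. anything else
}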

[This depends on a character set in which the letters of the alphabet are contiguous, as in ASCII, so it's not portable to IBM mainframes (EBCDIC), but I doubt you care.]
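One way to take the table idea further, sketched here as an assumption rather than something from the answer: precompute the result for all 256 byte values once, so the per-character work is a single load with no branches. The names `make_aa_table`, `AA_TABLE` and `aa2i_table` are hypothetical; this assumes ASCII and C++17 (for `constexpr` writes into `std::array`).

```cpp
#include <array>
#include <cstddef>

constexpr int ANY = 20, GAP = 21;  // same values as in the question

// Builds the full byte -> code table at compile time.
constexpr std::array<signed char, 256> make_aa_table()
{
    constexpr signed char letters[26] = {
        0, 3, 4, 3, 6, 13, 7, 8, 9, ANY, 11, 10, 12,
        2, ANY, 14, 5, 1, 15, 16, 4, 19, 17, ANY, 18, 6};
    std::array<signed char, 256> t{};
    for (int b = 0; b < 256; ++b) {
        signed char v = (b <= 32) ? -1 : -2;   // white space/control vs. everything else
        if (b >= 'A' && b <= 'Z') v = letters[b - 'A'];
        if (b >= 'a' && b <= 'z') v = letters[b - 'a'];
        if (b == '-' || b == '.' || b == '_') v = GAP;
        t[b] = v;                              // bytes 128..255 keep -2
    }
    return t;
}

constexpr auto AA_TABLE = make_aa_table();
static_assert(AA_TABLE['W'] == 17, "table should agree with the switch in the question");

void aa2i_table(const char* in, signed char* out, std::size_t n)
{
    for (std::size_t i = 0; i < n; ++i)
        out[i] = AA_TABLE[static_cast<unsigned char>(in[i])];  // index by the raw byte value
}
```

Indexing through `unsigned char` keeps negative `char` values in range, and those bytes map to -2, as in the question's scalar function.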
