I understand how to do general arithmetic operations in AVX2. However, my scalar code contains conditional operations that I would like to translate to AVX2. How should I do it? For example, I would like to vectorize:
#include <iostream>
int main() {
    double arr[4] = {1.0,2.0,3.0,4.0};
    double condition = 3.0;
    for (int i = 0; i < 4; i++) {
        if (arr[i] < condition) {
            arr[i] *= 1.75;
        }
        else {
            arr[i] *= 6.5;
        }
    }
    for (auto i : arr) {
        std::cout << i << '\t';
    }
}
Expected output:
1.75 3.5 19.5 26
How can I perform conditional operations like above in AVX2?
Answer:
Use AVX2 compare-and-blend operations instead of branches. Calculate both possible results on the whole vector, then use a comparison mask to keep, for each element, the result whose condition holds. For your case:
// Build with AVX enabled (e.g. -mavx2 on GCC/Clang)
#include <immintrin.h>
#include <iostream>
int main() {
    double arr[4] = { 1.0,2.0,3.0,4.0 };
    double condition = 3.0;
    __m256d _arr = _mm256_loadu_pd(&arr[0]);
    __m256d _multiplier1 = _mm256_set1_pd(1.75);
    __m256d _multiplier2 = _mm256_set1_pd(6.5);
    __m256d _firstResult = _mm256_mul_pd(_arr, _multiplier1); //if-branch
    __m256d _secondResult = _mm256_mul_pd(_arr, _multiplier2); //else-branch
    __m256d _mask = _mm256_set1_pd(condition);
    _mask = _mm256_cmp_pd(_arr, _mask, _CMP_LT_OQ); //a < b ordered (non-signalling)
    // Use mask to choose between _firstResult and _secondResult for each element:
    // blendv takes the second operand wherever the mask lane's sign bit is set
    _firstResult = _mm256_blendv_pd(_secondResult, _firstResult, _mask);
    double res[4];
    _mm256_storeu_pd(&res[0], _firstResult);
    for (auto i : res) {
        std::cout << i << '\t';
    }
}
A possible alternative to BLENDV is a combination of AND, ANDNOT and OR. However, BLENDV is simpler and usually faster, since it is one instruction instead of three. Use BLENDV as long as you have at least SSE4.1 and don't have AVX512 yet (with AVX512 you would use its mask registers instead).
For the meaning of _CMP_LT_OQ and the other comparison predicates, see Dave Dopson's table. You can perform any other comparison by changing this predicate accordingly.
Peter Cordes has written detailed notes about conditional operations in AVX2 and AVX512. There are more examples of conditional vectorization (including SSE and AVX512 examples) in Agner Fog's "Optimizing software in C++", chapter 12.4, pages 121-124.
Maybe you don't want to do any computation in the else-branch, or you explicitly want to zero those elements, so that your expected output looks like:
1.75 3.5 0.0 0.0
In that case you can use a slightly faster instruction sequence, since there is no else-branch to compute. There are at least two ways to get a speedup:
- Remove the second multiplication but keep BLENDV: instead of _secondResult, blend with a zeroed vector (it can be a global constant).
- Remove both the second multiplication and BLENDV, and use an AND with the mask instead. This variant doesn't need a zeroed vector at all.
The second way is better. For example, according to the uops tables, VBLENDVPD on the Skylake microarchitecture is 2 uops with 2 cycles of latency and a throughput of one per clock, whereas VANDPD is a single uop with 1 cycle of latency and can execute three times per clock.
Worse way: blending with zero
#include <immintrin.h>
#include <iostream>
int main() {
    double arr[4] = { 1.0,2.0,3.0,4.0 };
    double condition = 3.0;
    __m256d _arr = _mm256_loadu_pd(&arr[0]);
    __m256d _multiplier1 = _mm256_set1_pd(1.75);
    __m256d _firstResult = _mm256_mul_pd(_arr, _multiplier1); //if-branch
    __m256d _zeroes = _mm256_setzero_pd();
    __m256d _mask = _mm256_set1_pd(condition);
    _mask = _mm256_cmp_pd(_arr, _mask, _CMP_LT_OQ); //a < b ordered (non-signalling)
    // Blend in _firstResult where the IF condition holds, zeroes otherwise
    _firstResult = _mm256_blendv_pd(_zeroes, _firstResult, _mask);
    double res[4];
    _mm256_storeu_pd(&res[0], _firstResult);
    for (auto i : res) {
        std::cout << i << '\t';
    }
}
Better way: a bitwise AND with a compare result is a cheaper way to conditionally zero.
#include <immintrin.h>
#include <iostream>
int main() {
    double arr[4] = { 1.0,2.0,3.0,4.0 };
    double condition = 3.0;
    __m256d _arr = _mm256_loadu_pd(&arr[0]);
    __m256d _multiplier1 = _mm256_set1_pd(1.75);
    __m256d _firstResult = _mm256_mul_pd(_arr, _multiplier1); //if-branch
    __m256d _mask = _mm256_set1_pd(condition);
    _mask = _mm256_cmp_pd(_arr, _mask, _CMP_LT_OQ); //a < b ordered (non-signalling)
    // Lanes that fail the condition have an all-zero mask, so the AND turns them into 0.0
    _firstResult = _mm256_and_pd(_firstResult, _mask);
    double res[4];
    _mm256_storeu_pd(&res[0], _firstResult);
    for (auto i : res) {
        std::cout << i << '\t';
    }
}
This takes advantage of what a compare-result vector really is (each lane is either all-ones or all-zeros), and of the fact that the bit pattern of IEEE 754 +0.0 is all bits zero.