I am trying to understand if there is a fast way to do a matrix transpose (64x64 bits) using ARM SIMD instructions.
I tried to explore the VTRN instruction of ARM SIMD, but I am not sure how to apply it effectively in this scenario.
The input matrix is represented as uint64 mat[64], and the output is supposed to be its bitwise transpose.
For example, if the input is:
0000....
1111....
0000....
1111....
The expected output:
0101....
0101....
0101....
0101....
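In plain C, the operation I am after is simply the following (a naive reference only; the convention that bit j of row i holds the element in row i, column j is just my assumption about the layout):
#include <stdint.h>

/* Naive bitwise 64x64 transpose: out[i] bit j = in[j] bit i. */
static void transpose64x64_ref(uint64_t out[64], const uint64_t in[64])
{
    for (int i = 0; i < 64; i++) {
        uint64_t row = 0;
        for (int j = 0; j < 64; j++)
            row |= ((in[j] >> i) & 1) << j;
        out[i] = row;
    }
}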
CodePudding user response:
Use vld4.u64 to load 4 columns and transpose them into rows. It will load 8 linear values (01234567) and deinterleave them on load into registers as (04, 15, 26, 37). This is halfway to what you want (i.e. assuming this is an 8x8 matrix, you want 08, not 04), so the pairings are not yet entirely correct.
Load the "other half" pairs you need, then use vswp / vzip.64 / vuzp.64 to interleave the correct pairings between registers.
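For illustration, here is a minimal intrinsics sketch of that deinterleaving load, assuming AArch64 where vld4q_u64 maps to LD4 with the .2d arrangement (the helper names are mine):
#include <arm_neon.h>

/* vld4q_u64 on eight consecutive uint64_t values p[0..7] deinterleaves
   them on load, giving the (04, 15, 26, 37) pairing described above:
     r.val[0] = { p[0], p[4] }
     r.val[1] = { p[1], p[5] }
     r.val[2] = { p[2], p[6] }
     r.val[3] = { p[3], p[7] } */
static inline uint64x2x4_t load_deinterleaved(const uint64_t *p)
{
    return vld4q_u64(p);
}

/* Once the matching block further down has been loaded the same way,
   zip1/zip2 recombine lanes into the pairings actually needed. */
static inline void repair_pairs(uint64x2_t a, uint64x2_t b,
                                uint64x2_t *lo, uint64x2_t *hi)
{
    *lo = vzip1q_u64(a, b);   /* { a[0], b[0] } */
    *hi = vzip2q_u64(a, b);   /* { a[1], b[1] } */
}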
CodePudding user response:
The basic recursive scheme for a matrix transposition is to represent the matrix as a 2x2 block matrix
A B
C D
which you transpose by first transposing each of A, B, C, and D and then swapping B and C. In practice this means applying a sequence of increasingly coarse swizzle steps, first using bitwise operations and later using permutation operations.
For example, for the first step you can swap the bits of two consecutive rows like this:
// assuming V0 and V1 hold the two rows
movi v27.16b, #0x55            // mask selecting the even bit positions
ushr v4.16b, v0.16b, #1        // v4 = row0 >> 1
shl  v5.16b, v1.16b, #1        // v5 = row1 << 1
bif  v0.16b, v5.16b, v27.16b   // row0 = (row0 & 0x55) | ((row1 << 1) & 0xAA)
bit  v1.16b, v4.16b, v27.16b   // row1 = ((row0 >> 1) & 0x55) | (row1 & 0xAA)
The next step uses a shift amount of 2 and a mask of #0x33, and it combines rows that are two apart (row 1 with row 3, row 2 with row 4, and so on). Try to draw it on a piece of paper.
Repeat with increasing coarseness until you are done.
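For reference, here is the first step above written with intrinsics (the function name is mine; vbslq_u8 plays the role of the bif/bit pair):
#include <arm_neon.h>

/* One swizzle step: exchange the odd bits of row0 with the even bits of
   row1, i.e. transpose every 2x2 bit block that spans the two rows.
   Equivalent to the movi/ushr/shl/bif/bit sequence shown above. */
static inline void swap_bit_pairs(uint8x16_t *row0, uint8x16_t *row1)
{
    const uint8x16_t mask = vdupq_n_u8(0x55);   /* even bit positions */
    uint8x16_t r0_shr = vshrq_n_u8(*row0, 1);
    uint8x16_t r1_shl = vshlq_n_u8(*row1, 1);
    /* keep row0 where the mask is set, take row1 << 1 where it is clear */
    uint8x16_t new0 = vbslq_u8(mask, *row0, r1_shl);
    /* take row0 >> 1 where the mask is set, keep row1 where it is clear */
    uint8x16_t new1 = vbslq_u8(mask, r0_shr, *row1);
    *row0 = new0;
    *row1 = new1;
}
The following steps use the same pattern with a mask of 0x33 and shifts of 2, then 0x0F and shifts of 4; once the swapped blocks are 8 bits or wider, the work turns into byte and word permutations, which is where instructions like TRN/ZIP/EXT (or TBL) take over.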
CodePudding user response:
The data size far exceeds the size of the register bank. You have a choice between:
- strided load and consecutive store
- consecutive load and strided store
And consecutive store is always much more preferable.
#include <arm_neon.h>
void transposeBitwise64x64(uint64_t *pDst, uint64_t *pSrc)
{
uint8x8_t drow0, drow1, drow2, drow3, drow4, drow5, drow6, drow7;
uint8x8_t dtmp0, dtmp1, dtmp2, dtmp3, dtmp4, dtmp5, dtmp6, dtmp7;
uint8x16_t qrow0, qrow1, qrow2, qrow3, qrow4, qrow5, qrow6, qrow7;
uint8x16_t qtmp0, qtmp1, qtmp2, qtmp3, qtmp4, qtmp5, qtmp6, qtmp7;
const intptr_t sstride = 16;    // 16 bytes = two consecutive 64-bit source rows
uint8_t *pSrc1, *pSrc2, *pSrcBase;
uint32_t count = 8;             // one pass per byte column of the source
drow0 = vmov_n_u8(0);   // zero-initialize so the lane loads below start from defined registers
drow1 = vmov_n_u8(0);
drow2 = vmov_n_u8(0);
drow3 = vmov_n_u8(0);
drow4 = vmov_n_u8(0);
drow5 = vmov_n_u8(0);
drow6 = vmov_n_u8(0);
drow7 = vmov_n_u8(0);
pSrcBase = (uint8_t *) pSrc;
do {
pSrc1 = pSrcBase;          // even source rows at the current byte column
pSrc2 = pSrcBase + 8;      // odd source rows (one row is 8 bytes)
pSrcBase += 1;             // advance to the next byte column each iteration
// gather one byte column of the source: after these loads, drowN lane k holds the current byte of row 8*k + N
drow0 = vld1_lane_u8(pSrc1, drow0, 0); pSrc1 += sstride;
drow1 = vld1_lane_u8(pSrc2, drow1, 0); pSrc2 += sstride;
drow2 = vld1_lane_u8(pSrc1, drow2, 0); pSrc1 += sstride;
drow3 = vld1_lane_u8(pSrc2, drow3, 0); pSrc2 += sstride;
drow4 = vld1_lane_u8(pSrc1, drow4, 0); pSrc1 += sstride;
drow5 = vld1_lane_u8(pSrc2, drow5, 0); pSrc2 += sstride;
drow6 = vld1_lane_u8(pSrc1, drow6, 0); pSrc1 += sstride;
drow7 = vld1_lane_u8(pSrc2, drow7, 0); pSrc2 += sstride;
drow0 = vld1_lane_u8(pSrc1, drow0, 1); pSrc1 += sstride;
drow1 = vld1_lane_u8(pSrc2, drow1, 1); pSrc2 += sstride;
drow2 = vld1_lane_u8(pSrc1, drow2, 1); pSrc1 += sstride;
drow3 = vld1_lane_u8(pSrc2, drow3, 1); pSrc2 += sstride;
drow4 = vld1_lane_u8(pSrc1, drow4, 1); pSrc1 += sstride;
drow5 = vld1_lane_u8(pSrc2, drow5, 1); pSrc2 += sstride;
drow6 = vld1_lane_u8(pSrc1, drow6, 1); pSrc1 += sstride;
drow7 = vld1_lane_u8(pSrc2, drow7, 1); pSrc2 += sstride;
drow0 = vld1_lane_u8(pSrc1, drow0, 2); pSrc1 += sstride;
drow1 = vld1_lane_u8(pSrc2, drow1, 2); pSrc2 += sstride;
drow2 = vld1_lane_u8(pSrc1, drow2, 2); pSrc1 += sstride;
drow3 = vld1_lane_u8(pSrc2, drow3, 2); pSrc2 += sstride;
drow4 = vld1_lane_u8(pSrc1, drow4, 2); pSrc1 += sstride;
drow5 = vld1_lane_u8(pSrc2, drow5, 2); pSrc2 += sstride;
drow6 = vld1_lane_u8(pSrc1, drow6, 2); pSrc1 += sstride;
drow7 = vld1_lane_u8(pSrc2, drow7, 2); pSrc2 += sstride;
drow0 = vld1_lane_u8(pSrc1, drow0, 3); pSrc1 += sstride;
drow1 = vld1_lane_u8(pSrc2, drow1, 3); pSrc2 += sstride;
drow2 = vld1_lane_u8(pSrc1, drow2, 3); pSrc1 += sstride;
drow3 = vld1_lane_u8(pSrc2, drow3, 3); pSrc2 += sstride;
drow4 = vld1_lane_u8(pSrc1, drow4, 3); pSrc1 += sstride;
drow5 = vld1_lane_u8(pSrc2, drow5, 3); pSrc2 += sstride;
drow6 = vld1_lane_u8(pSrc1, drow6, 3); pSrc1 += sstride;
drow7 = vld1_lane_u8(pSrc2, drow7, 3); pSrc2 += sstride;
drow0 = vld1_lane_u8(pSrc1, drow0, 4); pSrc1 += sstride;
drow1 = vld1_lane_u8(pSrc2, drow1, 4); pSrc2 += sstride;
drow2 = vld1_lane_u8(pSrc1, drow2, 4); pSrc1 += sstride;
drow3 = vld1_lane_u8(pSrc2, drow3, 4); pSrc2 += sstride;
drow4 = vld1_lane_u8(pSrc1, drow4, 4); pSrc1 += sstride;
drow5 = vld1_lane_u8(pSrc2, drow5, 4); pSrc2 += sstride;
drow6 = vld1_lane_u8(pSrc1, drow6, 4); pSrc1 += sstride;
drow7 = vld1_lane_u8(pSrc2, drow7, 4); pSrc2 += sstride;
drow0 = vld1_lane_u8(pSrc1, drow0, 5); pSrc1 += sstride;
drow1 = vld1_lane_u8(pSrc2, drow1, 5); pSrc2 += sstride;
drow2 = vld1_lane_u8(pSrc1, drow2, 5); pSrc1 += sstride;
drow3 = vld1_lane_u8(pSrc2, drow3, 5); pSrc2 += sstride;
drow4 = vld1_lane_u8(pSrc1, drow4, 5); pSrc1 += sstride;
drow5 = vld1_lane_u8(pSrc2, drow5, 5); pSrc2 += sstride;
drow6 = vld1_lane_u8(pSrc1, drow6, 5); pSrc1 += sstride;
drow7 = vld1_lane_u8(pSrc2, drow7, 5); pSrc2 += sstride;
drow0 = vld1_lane_u8(pSrc1, drow0, 6); pSrc1 += sstride;
drow1 = vld1_lane_u8(pSrc2, drow1, 6); pSrc2 += sstride;
drow2 = vld1_lane_u8(pSrc1, drow2, 6); pSrc1 += sstride;
drow3 = vld1_lane_u8(pSrc2, drow3, 6); pSrc2 += sstride;
drow4 = vld1_lane_u8(pSrc1, drow4, 6); pSrc1 += sstride;
drow5 = vld1_lane_u8(pSrc2, drow5, 6); pSrc2 += sstride;
drow6 = vld1_lane_u8(pSrc1, drow6, 6); pSrc1 += sstride;
drow7 = vld1_lane_u8(pSrc2, drow7, 6); pSrc2 += sstride;
drow0 = vld1_lane_u8(pSrc1, drow0, 7); pSrc1 += sstride;
drow1 = vld1_lane_u8(pSrc2, drow1, 7); pSrc2 += sstride;
drow2 = vld1_lane_u8(pSrc1, drow2, 7); pSrc1 += sstride;
drow3 = vld1_lane_u8(pSrc2, drow3, 7); pSrc2 += sstride;
drow4 = vld1_lane_u8(pSrc1, drow4, 7); pSrc1 += sstride;
drow5 = vld1_lane_u8(pSrc2, drow5, 7); pSrc2 += sstride;
drow6 = vld1_lane_u8(pSrc1, drow6, 7);
drow7 = vld1_lane_u8(pSrc2, drow7, 7);
dtmp0 = vshr_n_u8(drow0, 1);
dtmp1 = vshr_n_u8(drow1, 1);
dtmp2 = vshr_n_u8(drow2, 1);
dtmp3 = vshr_n_u8(drow3, 1);
dtmp4 = vshr_n_u8(drow4, 1);
dtmp5 = vshr_n_u8(drow5, 1);
dtmp6 = vshr_n_u8(drow6, 1);
dtmp7 = vshr_n_u8(drow7, 1);
qrow0 = vcombine_u8(drow0, dtmp0);
qrow1 = vcombine_u8(drow1, dtmp1);
qrow2 = vcombine_u8(drow2, dtmp2);
qrow3 = vcombine_u8(drow3, dtmp3);
qrow4 = vcombine_u8(drow4, dtmp4);
qrow5 = vcombine_u8(drow5, dtmp5);
qrow6 = vcombine_u8(drow6, dtmp6);
qrow7 = vcombine_u8(drow7, dtmp7);
////////////////////////////////////// bits 0 and 1 of each gathered byte -> output rows 0 and 1 of this column group
qtmp0 = qrow0;
qtmp1 = qrow1;
qtmp2 = qrow2;
qtmp3 = qrow3;
qtmp4 = qrow4;
qtmp5 = qrow5;
qtmp6 = qrow6;
qtmp7 = qrow7;
qtmp0 = vsliq_n_u8(qtmp0, qtmp1, 1); // SLI keeps bit 0 of qtmp0 and fills bits 1..7 from qtmp1 << 1
qtmp2 = vsliq_n_u8(qtmp2, qtmp3, 1);
qtmp4 = vsliq_n_u8(qtmp4, qtmp5, 1);
qtmp6 = vsliq_n_u8(qtmp6, qtmp7, 1);
qtmp0 = vsliq_n_u8(qtmp0, qtmp2, 2);
qtmp4 = vsliq_n_u8(qtmp4, qtmp6, 2);
qtmp0 = vsliq_n_u8(qtmp0, qtmp4, 4);
vst1q_u8((uint8_t *)pDst, qtmp0); pDst += 2;
////////////////////////////////////// bits 2 and 3 -> output rows 2 and 3 of this column group
qtmp0 = vshrq_n_u8(qrow0, 2);
qtmp1 = vshrq_n_u8(qrow1, 2);
qtmp2 = vshrq_n_u8(qrow2, 2);
qtmp3 = vshrq_n_u8(qrow3, 2);
qtmp4 = vshrq_n_u8(qrow4, 2);
qtmp5 = vshrq_n_u8(qrow5, 2);
qtmp6 = vshrq_n_u8(qrow6, 2);
qtmp7 = vshrq_n_u8(qrow7, 2);
qtmp0 = vsliq_n_u8(qtmp0, qtmp1, 1);
qtmp2 = vsliq_n_u8(qtmp2, qtmp3, 1);
qtmp4 = vsliq_n_u8(qtmp4, qtmp5, 1);
qtmp6 = vsliq_n_u8(qtmp6, qtmp7, 1);
qtmp0 = vsliq_n_u8(qtmp0, qtmp2, 2);
qtmp4 = vsliq_n_u8(qtmp4, qtmp6, 2);
qtmp0 = vsliq_n_u8(qtmp0, qtmp4, 4);
vst1q_u8((uint8_t *)pDst, qtmp0); pDst += 2;
////////////////////////////////////// bits 4 and 5 -> output rows 4 and 5 of this column group
qtmp0 = vshrq_n_u8(qrow0, 4);
qtmp1 = vshrq_n_u8(qrow1, 4);
qtmp2 = vshrq_n_u8(qrow2, 4);
qtmp3 = vshrq_n_u8(qrow3, 4);
qtmp4 = vshrq_n_u8(qrow4, 4);
qtmp5 = vshrq_n_u8(qrow5, 4);
qtmp6 = vshrq_n_u8(qrow6, 4);
qtmp7 = vshrq_n_u8(qrow7, 4);
qtmp0 = vsliq_n_u8(qtmp0, qtmp1, 1);
qtmp2 = vsliq_n_u8(qtmp2, qtmp3, 1);
qtmp4 = vsliq_n_u8(qtmp4, qtmp5, 1);
qtmp6 = vsliq_n_u8(qtmp6, qtmp7, 1);
qtmp0 = vsliq_n_u8(qtmp0, qtmp2, 2);
qtmp4 = vsliq_n_u8(qtmp4, qtmp6, 2);
qtmp0 = vsliq_n_u8(qtmp0, qtmp4, 4);
vst1q_u8((uint8_t *)pDst, qtmp0); pDst += 2;
////////////////////////////////////// bits 6 and 7 -> output rows 6 and 7 of this column group
qtmp0 = vshrq_n_u8(qrow0, 6);
qtmp1 = vshrq_n_u8(qrow1, 6);
qtmp2 = vshrq_n_u8(qrow2, 6);
qtmp3 = vshrq_n_u8(qrow3, 6);
qtmp4 = vshrq_n_u8(qrow4, 6);
qtmp5 = vshrq_n_u8(qrow5, 6);
qtmp6 = vshrq_n_u8(qrow6, 6);
qtmp7 = vshrq_n_u8(qrow7, 6);
qtmp0 = vsliq_n_u8(qtmp0, qtmp1, 1);
qtmp2 = vsliq_n_u8(qtmp2, qtmp3, 1);
qtmp4 = vsliq_n_u8(qtmp4, qtmp5, 1);
qtmp6 = vsliq_n_u8(qtmp6, qtmp7, 1);
qtmp0 = vsliq_n_u8(qtmp0, qtmp2, 2);
qtmp4 = vsliq_n_u8(qtmp4, qtmp6, 2);
qtmp0 = vsliq_n_u8(qtmp0, qtmp4, 4);
vst1q_u8((uint8_t *)pDst, qtmp0); pDst += 2;
} while (--count);
}
I tried my best to talk the compilers into generating optimized machine code, but they simply won't listen: godbolt
GCC in particular sucks (as always).
I'll add an assembly version by tomorrow.
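For completeness, a small self-contained harness (not part of the answer above; the xorshift filler is arbitrary) that checks the routine against the defining property out[i] bit j == in[j] bit i:
#include <stdint.h>
#include <stdio.h>

void transposeBitwise64x64(uint64_t *pDst, uint64_t *pSrc);

int main(void)
{
    static uint64_t in[64], out[64];
    uint64_t x = 0x9E3779B97F4A7C15ULL;            /* arbitrary nonzero seed */
    for (int i = 0; i < 64; i++) {
        x ^= x << 13; x ^= x >> 7; x ^= x << 17;   /* xorshift64 */
        in[i] = x;
    }
    transposeBitwise64x64(out, in);
    for (int i = 0; i < 64; i++)
        for (int j = 0; j < 64; j++)
            if (((out[i] >> j) & 1) != ((in[j] >> i) & 1)) {
                printf("mismatch at row %d, bit %d\n", i, j);
                return 1;
            }
    puts("transpose OK");
    return 0;
}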