How to use recursion in RISC-V? Translating C into RISC-V

Time:10-16

We have an assignment to translate the following C code to assembly:

#include <stdio.h>
#include <stdlib.h>

int gcd(int a, int b)
{
    int ret;
    while (a != b){
        if (a > b){
             ret = gcd(a - b, b);
             return ret;
        }
        else{
            ret = gcd(a , b - a);
             return ret;
        }
    }
    return a;
}

void main(void)
{
    int a;
    int b;
    int c;
    
    printf("a & b? ");
    scanf(" %d %d",&a,&b );
    
    c = gcd(a, b);
    printf("result = %d",c);
} 

It finds the greatest common divisor of two integers. We must translate it to RISC-V assembly; this is my attempt:

.data
str1:    .ascii "result = \0"
str2:    .ascii "\n\0"

.text

.global main

main:
    addi sp,sp,-32 
    sw  ra,28(sp)
    sw  s0,24(sp)
    addi s0,sp,32
    call read_int
    mv s0,a0    
    call read_int
    mv s1,a0    
    mv a0,s0    #a0 = a
    mv a1,s1    #a1 = b
    call gcd
    mv s1,a0
    la a0,str1
    call print_string
    mv a0,s1
    call print_int
    la a0,str2
    call print_string
    lw  ra,28(sp)
    lw  s0,24(sp)
    addi    sp,sp,32
    call show_pc
    call exit
    ret

gcd:
    addi sp,sp,-8   # Increase stack w/ space for 1 int
    sw s3,4(sp)     # push S0 to stack
    sw s4,4(sp)
L1: beq a0,a1,L2    # if (a0==a1) go to L2
    slt s4,a0,a1    # if (a<b) s1=1, else s4=0
    beq s4,zero,L3  # if s4==0 go to L3
    sub s3,a0,a1    # varRet(s3) = a-b
    call gcd        # recursion
L3: sub s3,a1,a0    # varRet(s3) = b-a
    call gcd        # recursion
    beq zero,zero,L1 # jump to L1
L2: lw s3,4(sp)     # restore old s0
    lw s4,4(sp)
    addi sp,sp,4    # decrease stack
    jr ra           # jump to ra

But I am getting a data saving error in my gcd function, probably because of how I am using the variables or how I end the function.

Any help understanding how to do this would be appreciated.

Answer:

Recursion is a bit of a red herring. As long as the runtime architecture supports a (recursive) call stack, which RISC-V environments do, supporting or doing recursion comes down to the same thing as one function, A, calling another function, B. (In the recursive case, A and B are the same function.) The point is that all you have to do is follow the rules for function calling, and recursion will simply work.

(There are opportunities for optimizing function calling by deviating from the standard calling rules, but that's generally not what instructors are going for here.)

I would address the following:

  • Using registers s3 & s4, which are unnecessary for this function. (If you do use them, you need to fix the prologue and epilogue: both s3 and s4 are saved to the same memory location, 4(sp), so whichever is saved last wins that slot.)

  • Use a call-clobbered register for ret; a0 is a good choice, since it is also where the return value must end up.

  • Consider using t0 for the temporary in the relational test, but even that is a MIPS style of programming: unlike MIPS, RISC-V has the full complement of relations in its conditional branches (beq, bne, blt, bge, bltu, bgeu), so an slt is not required.

  • Properly pass (a-b, b) as the parameters of the first recursive call: the first argument in a0, the second in a1. The same goes for passing (a, b-a) in the second call.

  • Allocate stack space and preserve ra: every call overwrites ra with a new return address, and you'll need the original ra value to get back to the proper caller.

  • The control flow of the assembly doesn't match the C code. Though the C code is written using a while loop, it actually contains no looping. Every path inside the loop body has a return statement, so the "loop" body runs either never or exactly once; it would be better written as an if. The assembly code attempts to make an actual loop, missing the point that every code path in the loop ends in a return. Each return needs to execute the function epilogue (and return to the caller) rather than falling through to whatever code happens to come next.
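To make the last point concrete, here is a minimal C sketch of the question's gcd() with the while replaced by a plain if/else; the name gcd_if is made up for illustration. Like the original, it assumes both inputs are positive, since the subtraction-based algorithm can recurse forever otherwise (the original gcd(0, 5) keeps calling gcd(0, 5)); the assert makes that assumption explicit.

```c
#include <assert.h>

/* Same logic as the question's gcd(), written with if/else instead of a
 * while loop: every path of the original loop body returns, so the loop
 * never iterates more than once. Assumes a > 0 and b > 0. */
int gcd_if(int a, int b)
{
    assert(a > 0 && b > 0);

    if (a == b)
        return a;                   /* base case: the answer is a */
    if (a > b)
        return gcd_if(a - b, b);    /* first recursive call */
    return gcd_if(a, b - a);        /* second recursive call */
}
```

Written this way, every call either returns a directly or returns the result of exactly one recursive call, which is the shape the assembly should follow: test, set up a0/a1, call, then run the epilogue.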

Answer:

Erik's answer explains what is wrong with the code. There are a number of errors; I'll list a few:

  1. The register save/restore code is wrong. Both the stores and the loads use 4(sp), so one register is clobbered.
  2. Wrong addi in the restore code: it adds 4, but the save code subtracted 8.
  3. No save/restore of the return address (ra). Actually, this is the only register that needs save/restore.
  4. a0/a1 (i.e. a and b) are never changed for the recursive calls.
  5. The return value belongs in a0 (v0 is the MIPS convention; RISC-V has no v0 register). Conveniently, a is already in a0, which is exactly the result when a == b.
  6. We only need one extra register to hold the slt result. By convention, this can be a t* register [which doesn't need save/restore].
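Points 1 and 2 come down to giving every saved register its own slot and making the epilogue's addi exactly undo the prologue's. As a sketch, here is a hypothetical helper (frame_size and save_offset are names invented for this example) that computes such a layout, assuming RV32 (4-byte registers) and the standard RISC-V calling convention's rule that sp stays 16-byte aligned. Some teaching simulators do not enforce the alignment, but following it is harmless.

```c
#include <assert.h>

/* Hypothetical helper: size of an RV32 stack frame that saves `nregs`
 * 4-byte registers, rounded up so sp stays 16-byte aligned. */
int frame_size(int nregs)
{
    assert(nregs >= 0);
    int bytes = 4 * nregs;
    return (bytes + 15) / 16 * 16;      /* round up to a multiple of 16 */
}

/* Offset (from the new sp) of saved register i, counting down from the
 * top of the frame so that every register gets its own distinct slot. */
int save_offset(int nregs, int i)
{
    assert(i >= 0 && i < nregs);
    return frame_size(nregs) - 4 * (i + 1);
}
```

For example, saving only ra gives frame_size(1) == 16 with ra at 12(sp): addi sp,sp,-16 and sw ra,12(sp) in the prologue, lw ra,12(sp) and addi sp,sp,16 in the epilogue. If you did keep s3 and s4, frame_size(2) is still 16 and they would go at 12(sp) and 8(sp), two distinct slots.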

Note that in the code below, I've tested the modified C functions but not the asm.


When I want to write assembler, I first rewrite the C code using only if and goto, so that it more closely mimics the [proposed] assembler.

Each if must be simple, e.g. if (a > b), and not if ((a > b) && (c < d)). The latter must be split into simple ifs (using more gotos).
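As an illustration of that rule, here is a small sketch comparing the compound condition from the text with a version that uses only simple ifs and gotos. The function names are hypothetical. The && short-circuit turns into an early jump past the second test, which is how it would look as two branch instructions.

```c
#include <assert.h>

/* The compound test, written the usual C way. */
int both_compound(int a, int b, int c, int d)
{
    return (a > b) && (c < d);
}

/* The same test using only simple ifs and gotos: one comparison per
 * branch, as in assembly. Each if tests the negated condition and jumps
 * to "no" when it fails. */
int both_goto(int a, int b, int c, int d)
{
    if (a <= b)                 /* a > b is false: skip the second test */
        goto no;
    if (c >= d)                 /* c < d is false */
        goto no;
    return 1;

no:
    return 0;
}

/* Compare both forms over every combination of inputs in -2..2. */
void check_both(void)
{
    for (int a = -2; a <= 2; a++)
        for (int b = -2; b <= 2; b++)
            for (int c = -2; c <= 2; c++)
                for (int d = -2; d <= 2; d++)
                    assert(both_goto(a, b, c, d) == both_compound(a, b, c, d));
}
```

check_both() aborts via assert if the two forms ever disagree.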

int
gcd3(int a, int b)
{
    int ret = a;

    if (a == b)
        goto done;

    if (a > b)
        goto a_gt_b;

a_lt_b:
    ret = gcd3(a, b - a);
    goto done;

a_gt_b:
    ret = gcd3(a - b, b);
    goto done;

done:
    return ret;
}

Here is the refactored asm code:

gcd:
    addi    sp,sp,-16               # allocate a frame (the ABI keeps sp 16-byte aligned)
    sw      ra,12(sp)               # save return address

    # a0 already holds a, which is the return value if a == b
    beq     a0,a1,done              # if (a == b) we are done

    slt     t0,a0,a1                # t0 = (a < b)
    beq     t0,zero,a_gt_b          # not less and not equal, so a > b: fly

    # a < b
a_lt_b:
    sub     a1,a1,a0                # b = b - a
    call    gcd                     # recursion: result comes back in a0
    j       done

    # a > b
a_gt_b:
    sub     a0,a0,a1                # a = a - b
    call    gcd                     # recursion: result comes back in a0
    j       done

done:
    lw      ra,12(sp)               # restore return address
    addi    sp,sp,16                # release the frame
    jr      ra                      # return
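Since the asm above hasn't been run, one way to check its branch logic on a host machine is to transliterate it back into C, one statement per instruction, with variables named after the registers. This is only a sketch (gcd_asm_model is a made-up name): it models the data flow, not the stack frame, and it assumes both inputs are positive.

```c
#include <assert.h>

/* C transliteration of the recursive gcd asm, one statement per
 * instruction. The RISC-V calling convention passes a and b in a0/a1
 * and returns the result in a0. Assumes a0 > 0 and a1 > 0. */
int gcd_asm_model(int a0, int a1)
{
    assert(a0 > 0 && a1 > 0);
    int t0;

    if (a0 == a1)                   /* beq a0,a1,done */
        goto done;

    t0 = (a0 < a1);                 /* slt t0,a0,a1 */
    if (t0 == 0)                    /* beq t0,zero,a_gt_b */
        goto a_gt_b;

    /* a_lt_b: */
    a1 = a1 - a0;                   /* sub a1,a1,a0 */
    a0 = gcd_asm_model(a0, a1);     /* call gcd: result lands in a0 */
    goto done;                      /* jump to done */

a_gt_b:
    a0 = a0 - a1;                   /* sub a0,a0,a1 */
    a0 = gcd_asm_model(a0, a1);     /* call gcd: result lands in a0 */
    goto done;                      /* jump to done */

done:
    return a0;                      /* jr ra with the result in a0 */
}
```

One thing the model cannot show is that the callee may clobber a1 (it is call-clobbered); that is harmless here because nothing reads a1 after a recursive call returns.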

Your original C code combined elements of recursion and looping. The function doesn't really need to be recursive at all.
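To put a number on that, here is a sketch (gcd_steps is a hypothetical name) that counts the subtraction steps the algorithm takes for positive inputs. In the recursive version, each step is one more nested call and therefore one more stack frame.

```c
#include <assert.h>

/* Count the subtraction steps the gcd algorithm takes; for the recursive
 * version this equals the number of nested calls. Assumes a > 0, b > 0. */
long gcd_steps(int a, int b)
{
    assert(a > 0 && b > 0);
    long steps = 0;
    while (a != b) {
        if (a > b)
            a -= b;
        else
            b -= a;
        steps++;
    }
    return steps;
}
```

For example, gcd_steps(1, 100000) is 99999, so the recursive form would need 99999 nested frames (about 1.6 MB if each frame is 16 bytes), while the looping version below uses no stack at all.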

Here's a looping version:

int
gcd4(int a, int b)
{

loop:
    if (a == b)
        goto done;

    if (a > b)
        goto a_gt_b;

a_lt_b:
    b = b - a;
    goto loop;

a_gt_b:
    a = a - b;
    goto loop;

done:
    return a;
}

Here is the assembler code:

gcd:
    beq     a0,a1,done              # if (a == b) we are done

    slt     t0,a0,a1                # t0 = (a < b)
    beq     t0,zero,a_gt_b          # not less and not equal, so a > b: fly

    # a < b
a_lt_b:
    sub     a1,a1,a0                # b = b - a
    j       gcd                     # loop back to the top

    # a > b
a_gt_b:
    sub     a0,a0,a1                # a = a - b
    j       gcd                     # loop back to the top

done:
    # a0 already holds a, which is the return value
    jr      ra                      # return

Here's the full test C program that I used:

#include <stdio.h>
#include <stdlib.h>

int
gcd(int a, int b)
{
    int ret;

    while (a != b) {
        if (a > b) {
            ret = gcd(a - b, b);
            return ret;
        }
        else {
            ret = gcd(a, b - a);
            return ret;
        }
    }

    return a;
}

int
gcd2(int a, int b)
{
    int ret = a;

    while (a != b) {
        if (a > b) {
            ret = gcd2(a - b, b);
            break;
        }
        else {
            ret = gcd2(a, b - a);
            break;
        }
    }

    return ret;
}

int
gcd3(int a, int b)
{
    int ret = a;

    if (a == b)
        goto done;

    if (a > b)
        goto a_gt_b;

a_lt_b:
    ret = gcd3(a, b - a);
    goto done;

a_gt_b:
    ret = gcd3(a - b, b);
    goto done;

done:
    return ret;
}

int
gcd4(int a, int b)
{

loop:
    if (a == b)
        goto done;

    if (a > b)
        goto a_gt_b;

a_lt_b:
    b = b - a;
    goto loop;

a_gt_b:
    a = a - b;
    goto loop;

done:
    return a;
}

int
main(void)
{
    int a;
    int b;
    int code = 0;
    char buf[100];

    while (1) {
        printf("a & b? ");
        fflush(stdout);

        if (fgets(buf,sizeof(buf),stdin) == NULL)
            break;
        if (buf[0] == '\n')
            break;

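        /* stop on malformed input; the subtraction-based gcd can recurse
           forever unless a and b are both positive */
        if (sscanf(buf, " %d %d", &a, &b) != 2 || a <= 0 || b <= 0)
            break;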
        printf("%d %d\n", a, b);

        int c1 = gcd(a, b);
        printf("gcd1 = %d\n", c1);

        int c3 = gcd3(a, b);
        printf("gcd3 = %d\n", c3);
        if (c3 != c1) {
            printf("MISMATCH\n");
            code = 1;
        }

        int c4 = gcd4(a, b);
        printf("gcd4 = %d\n", c4);
        if (c4 != c1) {
            printf("MISMATCH\n");
            code = 1;
        }

        if (code)
            break;
    }

    return code;
}
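The test program above needs typed input. As a non-interactive alternative, here is a sketch that compares the loop form of the subtraction gcd (same logic as gcd4) against a brute-force reference over every pair in 1..limit. The names gcd_sub, gcd_brute and check_gcd are made up for this example.

```c
#include <assert.h>

/* Loop form of the subtraction gcd (same logic as gcd4). Assumes a, b > 0. */
int gcd_sub(int a, int b)
{
    while (a != b) {
        if (a > b)
            a -= b;
        else
            b -= a;
    }
    return a;
}

/* Reference answer: the largest d that divides both a and b.
 * Terminates because d == 1 always divides both. Assumes a, b > 0. */
int gcd_brute(int a, int b)
{
    int d = a < b ? a : b;
    while (a % d != 0 || b % d != 0)
        d--;
    return d;
}

/* Compare the two over every pair in 1..limit; returns the mismatch count. */
int check_gcd(int limit)
{
    assert(limit >= 1);
    int mismatches = 0;
    for (int a = 1; a <= limit; a++)
        for (int b = 1; b <= limit; b++)
            if (gcd_sub(a, b) != gcd_brute(a, b))
                mismatches++;
    return mismatches;
}
```

check_gcd(60) returning 0 means all 3600 pairs agree.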