Home > Net >  How do I avoid copying the array during every merge in a bottom-up mergesort?
How do I avoid copying the array during every merge in a bottom-up mergesort?

Time:07-06

I have the following code, which implements a bottom-up mergesort:

#include "sort.h"
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Runs of this many elements are insertion-sorted before any merging. */
enum { CUTOFF = 8 };
/* Yields the lesser of A and B; whichever argument is chosen is evaluated
   a second time, so avoid arguments with side effects. */
#define MIN(a, b) (((a) < (b)) ? (a) : (b))

// Forward declaration so merge_sort can call merge, which is defined after it.
// Parameters, in order: array A, scratch buffer BUF, total length LEN, split
// point MID, element size SZ, comparison function CMP.
void merge(uint8_t *, uint8_t *, size_t, size_t, size_t, cmpfn);

// Sort LEN elements of A, each SZ bytes wide, using CMP to perform the
// comparisons. Consecutive runs of CUTOFF elements are insertion-sorted in
// place first; then adjacent runs of width W (W = CUTOFF, 2*CUTOFF, ...) are
// merged in place, with BUF passed to merge() as scratch space.
void merge_sort(void *a, size_t len, size_t sz, cmpfn cmp) {
  uint8_t *base = a;

  // Zero or one element is already sorted.
  if (len < 2) {
    return;
  }

  uint8_t *buf = malloc(len * sz);
  if (buf == NULL) {
    // No scratch memory for merging: insertion-sort the whole array instead.
    // This is slower, but insertion_sort takes no buffer argument.
    insertion_sort(base, len, sz, cmp);
    return;
  }

  for (size_t i = 0; i < len; i += CUTOFF) {
    insertion_sort(base + i * sz, MIN(CUTOFF, len - i), sz, cmp);
  }
  for (size_t w = CUTOFF; w < len; w *= 2) {
    // A trailing run with no partner (i + w >= len) is already sorted and
    // stays where it is, because merging happens in place.
    for (size_t i = 0; i + w < len; i += w * 2) {
      merge(base + i * sz, buf, MIN(w * 2, len - i), w, sz, cmp);
    }
  }
  free(buf);
}

// Merge the adjacent sorted runs A[0, MID) and A[MID, LEN) back into A.
// Elements are SZ bytes each and are ordered by CMP. BUF must hold at least
// MID elements and must not overlap A. The merge is stable: on ties the
// element from the left run is written first.
void merge(uint8_t *a, uint8_t *buf, size_t len, size_t mid, size_t sz,
           cmpfn cmp) {
  // Nothing to merge if either run is empty (this also keeps the indexing
  // below in bounds).
  if (mid == 0 || mid >= len) {
    return;
  }
  // Runs already in order: the first element of the right run is not
  // smaller than the last element of the left run.
  if (cmp(a + mid * sz, a + (mid - 1) * sz) >= 0) {
    return;
  }

  // Only the left run has to be saved. The write index K is always
  // I + (J - MID) < J while both runs are non-empty, so writing A[K] never
  // overwrites an unread right-run element.
  memcpy(buf, a, mid * sz);

  size_t i = 0, j = mid, k = 0;
  while (i < mid && j < len) {
    if (cmp(a + j * sz, buf + i * sz) < 0) {
      memcpy(a + k * sz, a + j * sz, sz);  // K < J: distinct elements
      ++j;
    } else {
      memcpy(a + k * sz, buf + i * sz, sz);
      ++i;
    }
    ++k;
  }
  // Leftover left-run elements go at the end. Leftover right-run elements
  // are already in their final positions.
  memcpy(a + k * sz, buf + i * sz, (mid - i) * sz);
}

Currently, during every call of merge, the code copies the contents of the array into buf. I know that in a top-down recursive approach, you can switch the order of the two arrays when calling merge to avoid the copy. How can I achieve this with an iterative approach?

EDIT: I got it working now, and I've posted the answer. If anyone has any suggestions or notices any drawbacks, feedback would be highly appreciated!

User response:

I think I got it working now.

#include "sort.h"
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Length of the runs that are insertion-sorted before merging begins. */
#define CUTOFF 8
/* Exchange the values of two lvalues of the same type.
   There is no semicolon after while (0), so `SWAP(x, y);` is exactly one
   statement and is safe in an unbraced if/else. __typeof__ is the GCC/Clang
   spelling that also works in strict ISO modes (e.g. -std=c11), where plain
   typeof is not a keyword. The temporary's name is chosen so it is unlikely
   to collide with a caller's variable. */
#define SWAP(a, b)                                                             \
  do {                                                                         \
    __typeof__(a) swap_tmp_ = (a);                                             \
    (a) = (b);                                                                 \
    (b) = swap_tmp_;                                                           \
  } while (0)
/* Smaller of two values; the chosen argument is evaluated twice. */
#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Forward declaration so merge_sort can call merge, which is defined after it.
// Parameters, in order: destination RES, run A, run B, run lengths NA and NB,
// element size SZ, comparison function CMP.
void merge(uint8_t *, uint8_t *, uint8_t *, size_t, size_t, size_t, cmpfn);

// Sort LEN elements of A, each SZ bytes wide, using CMP to perform the
// comparisons.
//
// Runs of CUTOFF elements are first insertion-sorted in place. Each merge pass
// then merges pairs of runs from SRC into DST and swaps the two arrays' roles,
// so no per-merge copy is needed. If the last pass leaves the result in the
// scratch buffer, it is copied back into A once at the end.
void merge_sort(void *a, size_t len, size_t sz, cmpfn cmp) {
  uint8_t *base = a;

  // Zero or one element is already sorted.
  if (len < 2) {
    return;
  }

  uint8_t *buf = malloc(len * sz);
  if (buf == NULL) {
    // No scratch memory for ping-pong merging: insertion-sort the whole array
    // instead. This is slower, but insertion_sort takes no buffer argument.
    insertion_sort(base, len, sz, cmp);
    return;
  }

  for (size_t i = 0; i < len; i += CUTOFF) {
    insertion_sort(base + i * sz, MIN(CUTOFF, len - i), sz, cmp);
  }

  uint8_t *src = base;
  uint8_t *dst = buf;
  for (size_t w = CUTOFF; w < len; w *= 2) {
    // Every element must reach DST on every pass, including a trailing run
    // that has no partner (or is shorter than W). For such a run NB is 0,
    // and merge() simply copies it through.
    for (size_t i = 0; i < len; i += w * 2) {
      size_t na = MIN(w, len - i);
      size_t nb = MIN(w, len - i - na);
      merge(dst + i * sz, src + i * sz, src + (i + na) * sz, na, nb, sz, cmp);
    }
    uint8_t *tmp = src;
    src = dst;
    dst = tmp;
  }

  // SRC holds the sorted data; bring it home if it ended up in BUF.
  if (src != base) {
    memcpy(base, src, len * sz);
  }
  free(buf);
}

// Merge the sorted run A (NA elements) and the sorted run B (NB elements) into
// RES, which receives all NA + NB elements. Elements are SZ bytes each and are
// ordered by CMP. RES must not overlap A or B, because the data is moved with
// memcpy. Either run may be empty. The merge is stable: on ties the element
// from A is written first.
void merge(uint8_t *res, uint8_t *a, uint8_t *b, size_t na, size_t nb,
           size_t sz, cmpfn cmp) {
  // Fast path: the runs are already in order relative to each other, so RES
  // is just A followed by B.
  if (na > 0 && nb > 0 && cmp(b, a + (na - 1) * sz) >= 0) {
    memcpy(res, a, na * sz);
    memcpy(res + na * sz, b, nb * sz);
    return;
  }

  size_t i = 0, j = 0, k = 0;

  while (i < na && j < nb) {
    if (cmp(b + j * sz, a + i * sz) < 0) {
      memcpy(res + k * sz, b + j * sz, sz);
      ++j;
    } else {
      memcpy(res + k * sz, a + i * sz, sz);
      ++i;
    }
    ++k;
  }
  // At most one run still has elements left; append its tail. The other
  // memcpy copies zero bytes.
  memcpy(res + k * sz, a + i * sz, (na - i) * sz);
  k += na - i;
  memcpy(res + k * sz, b + j * sz, (nb - j) * sz);
}

User response:

This is an old example I have. The insertion-sort run size is set to 16 or 32 so that the number of merge passes is even. A second array is allocated once, and the direction of the merge alternates with each pass by swapping pointers.

#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Return the number of merge passes a bottom-up merge sort that starts from
   runs of 1 element needs for n elements: ceil(log2(n)), or 0 when n < 2. */
size_t GetPassCount(size_t n)
{
    size_t passes = 0;
    for (size_t s = 1; s < n; s <<= 1) {
        passes += 1;
        if (s > SIZE_MAX / 2)       /* doubling again would wrap s to 0 */
            break;
    }
    return passes;
}

/* Insertion-sort each consecutive run of `run` elements of a[0..n) in place.
   The last run may be shorter. */
static void InsertionSortRuns(int a[], size_t n, size_t run)
{
    for (size_t lo = 0; lo < n; lo += run) {
        size_t hi = (n - lo < run) ? n : lo + run;
        for (size_t j = lo + 1; j < hi; j++) {
            int t = a[j];
            size_t i = j;
            while (i > lo && a[i - 1] > t) {    /* shift larger elements right */
                a[i] = a[i - 1];
                i--;
            }
            a[i] = t;
        }
    }
}

/* One merge pass: merge each pair of adjacent sorted runs of `run` elements
   from src[0..n) into dst[0..n). A trailing run without a partner is copied
   through unchanged. On ties the left run's element goes first, which keeps
   the sort stable. All indices stay within [0, n], so no out-of-range
   pointers are formed. */
static void MergePass(const int *src, int *dst, size_t n, size_t run)
{
    for (size_t lo = 0; lo < n; lo += 2 * run) {
        size_t mid = (n - lo < run) ? n : lo + run;
        size_t hi = (n - mid < run) ? n : mid + run;
        size_t i = lo, j = mid, k = lo;

        while (i < mid && j < hi)
            dst[k++] = (src[i] <= src[j]) ? src[i++] : src[j++];
        while (i < mid)
            dst[k++] = src[i++];
        while (j < hi)
            dst[k++] = src[j++];
    }
}

/* Sort a[0..n) in ascending order using a bottom-up merge sort.
   A second array is allocated once, and each pass merges from one array into
   the other by swapping pointers. The initial insertion-sort run size is 16
   or 32, chosen with GetPassCount so that the number of merge passes is even.
   The sorted data therefore normally ends up back in a[]; a copy-back guards
   the case where it does not. If the second array cannot be allocated, the
   whole array is insertion-sorted instead. */
void MergeSort(int a[], size_t n)
{
    if (n < 2)                          /* nothing to sort */
        return;

    int *b = malloc(n * sizeof *b);     /* allocate second array */
    if (b == NULL) {
        InsertionSortRuns(a, n, n);     /* slower, but needs no extra memory */
        return;
    }

    /* set run size so merge sort is an even number of passes */
    size_t rsz = ((GetPassCount(n) & 1) != 0) ? 32 : 16;
    InsertionSortRuns(a, n, rsz);

    int *src = a;
    int *dst = b;
    for (; rsz < n; rsz <<= 1) {        /* double run size each pass */
        MergePass(src, dst, n, rsz);
        int *tmp = src;                 /* swap direction of merge */
        src = dst;
        dst = tmp;
    }

    if (src != a)                       /* defensive: result left in b[] */
        memcpy(a, src, n * sizeof *a);
    free(b);                            /* free second array */
}
  • Related