How to sort a large list of dictionaries without loading into memory in Python


I have a CSV file with about 50 million rows and I'm trying to manipulate the data and write to a new CSV file. Here's the code below:

import csv
from itertools import groupby

def main():
    with open("input.csv", "r") as csvfile:
        rows = csv.DictReader(csvfile)
        sorted_rows = sorted(rows, key=lambda row: row["name"])
        grouping = groupby(sorted_rows, lambda row: row["name"])

        with open("output.csv", "w") as final_csvfile:
            fieldnames = ["name", "number"]
            writer = csv.DictWriter(final_csvfile, fieldnames=fieldnames)

            for group, items in grouping:
                total = sum(int(item["number"]) for item in items)
                writer.writerow(
                    {
                        "name": group,
                        "number": str(total),
                    }
                )


if __name__ == "__main__":
    main()

This works well on a modest number of rows, but when I run it against the actual CSV with 50 million rows, it becomes very slow and the program is eventually killed.

Now, the line sorted_rows = sorted(rows, key=lambda row: row["name"]) is the main problem, because it loads all 50 million rows into memory (as a list) so they can be sorted. I have come to understand that the first thing sorted() does is convert any generator given to it into a list, so how do I go about this, please? Any pointers?

CodePudding user response:

Can you try this and check if there is any improvement? I have done away with the sorting and groupby by using a dictionary, and the dictionary only stores each name and its running total, not the other columns, so less memory is used.

import csv
from collections import defaultdict

def main():
    sums = defaultdict(int)
    with open("input.csv", "r") as csvfile:
        rows = csv.DictReader(csvfile)

        for row in rows:
            sums[row["name"]] += int(row["number"])  # accumulate the total per name

        with open("output.csv", "w") as final_csvfile:
            fieldnames = ["name", "number"]
            writer = csv.DictWriter(final_csvfile, fieldnames=fieldnames)
            writer.writerows(
                {"name": name, "number": number} for name, number in sums.items()
            )


if __name__ == "__main__":
    main()
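
One possible tweak, if the output also needs to be in name order: sums holds only one entry per unique name, so it is usually small enough to sort in memory. For example:

# Drop-in replacement for the writer.writerows(...) call above:
# sorted(sums.items()) yields the (name, total) pairs ordered by name.
writer.writerows(
    {"name": name, "number": number} for name, number in sorted(sums.items())
)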

In case anyone is wondering about timings: for a five-million-row CSV with 4 columns (name, number, col_3, col_4), the dictionary approach above took 10 seconds to create the output file, while OP's code took 16 seconds.
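
If you genuinely need the rows sorted by name without holding all 50 million in memory (the question in the title), the usual technique is an external merge sort: sort manageable chunks, spill each sorted chunk to a temporary file, then merge the chunk files lazily with heapq.merge. A rough sketch along those lines (the chunk size, the sorted_chunk_files helper, and the file names are illustrative, and it assumes the same input.csv layout as above):

import csv
import heapq
import tempfile
from itertools import groupby, islice
from operator import itemgetter

CHUNK_SIZE = 1_000_000  # rows sorted in memory at a time; tune to available RAM

def sorted_chunk_files(path):
    """Read the CSV in chunks, sort each chunk by name, and spill it to a temp CSV."""
    chunk_files = []
    with open(path, newline="") as csvfile:
        reader = csv.DictReader(csvfile)
        fieldnames = reader.fieldnames
        while True:
            chunk = list(islice(reader, CHUNK_SIZE))
            if not chunk:
                break
            chunk.sort(key=itemgetter("name"))
            tmp = tempfile.NamedTemporaryFile("w+", newline="", suffix=".csv")
            writer = csv.DictWriter(tmp, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(chunk)
            tmp.flush()
            tmp.seek(0)
            chunk_files.append(tmp)
    return chunk_files

def main():
    chunk_files = sorted_chunk_files("input.csv")

    # Each temp file is already sorted by name, so heapq.merge can combine them
    # into one globally sorted stream while keeping only one row per file in memory.
    readers = [csv.DictReader(f) for f in chunk_files]
    merged = heapq.merge(*readers, key=itemgetter("name"))

    with open("output.csv", "w", newline="") as final_csvfile:
        writer = csv.DictWriter(final_csvfile, fieldnames=["name", "number"])
        for name, items in groupby(merged, key=itemgetter("name")):
            total = sum(int(item["number"]) for item in items)
            writer.writerow({"name": name, "number": str(total)})

    for f in chunk_files:
        f.close()  # temp files are deleted on close

if __name__ == "__main__":
    main()

This keeps memory use roughly proportional to CHUNK_SIZE rather than to the whole file, at the cost of writing the data to disk one extra time.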
