How to scatter plot with 1 million points

Time:10-29

I'm trying to write a program that draws a graph from points stored in a CSV file. Each row contains 4 strings (point number, x position, y position, color). The time it takes is ridiculously high, so I'm looking for ideas to make it faster.

from matplotlib import pyplot as plt    
from matplotlib import style   
import csv

style.use('ggplot')

s = 0.5
with open('total.csv') as f:
  f_reader = csv.reader(f, delimiter=',')
  for row in f_reader:
    plt.scatter(str(row[1]), str(row[2]), color=str(row[3]), s=s)
plt.savefig("graph.png", dpi=1000)

CodePudding user response:

The first step would be to call scatter once instead of once per point. Without adding a dependency on numpy or pandas, it could look like:

from matplotlib import pyplot as plt
from matplotlib import style
import csv

style.use("ggplot")

s = 0.5
x = []
y = []
c = []
with open("total.csv") as f:
    f_reader = csv.reader(f, delimiter=",")
    for row in f_reader:
        # Convert to numbers; matplotlib plots strings as categories.
        x.append(float(row[1]))
        y.append(float(row[2]))
        c.append(row[3])
plt.scatter(x, y, color=c, s=s)
plt.savefig("graph.png", dpi=1000)

Then maybe try pandas.read_csv, which gives you a pandas DataFrame and lets you access the columns of your CSV without a for loop; that would probably be faster.

Each time you try a variation, measure the time it takes (possibly on a smaller file) to know what helps and what doesn't. In other words, don't try to improve performance blindly.
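
One way to do that measuring is sketched below, using only the standard library. The helper names write_sample and timed are made up for this illustration: the first copies the first rows of the big file into a smaller one, the second runs any function once and reports the wall-clock time.

```python
import csv
import itertools
import time


def write_sample(src_path, dst_path, n_rows):
    """Copy the first n_rows rows of src_path into dst_path; return rows written."""
    written = 0
    with open(src_path, newline="") as src, open(dst_path, "w", newline="") as dst:
        writer = csv.writer(dst)
        # islice stops after n_rows, so the large file is never fully read.
        for row in itertools.islice(csv.reader(src), n_rows):
            writer.writerow(row)
            written += 1
    return written


def timed(label, func, *args, **kwargs):
    """Run func once, print the elapsed wall-clock time, return (result, seconds)."""
    start = time.perf_counter()
    result = func(*args, **kwargs)
    elapsed = time.perf_counter() - start
    print(f"{label}: {elapsed:.3f} s")
    return result, elapsed
```

For instance, you could call write_sample("total.csv", "sample.csv", 10000), wrap each plotting variant in its own function taking a file name, and compare them with timed("one scatter call", variant, "sample.csv"). A single run is noisy, so repeat it a few times before drawing conclusions.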

Using pandas, and assuming the CSV has a header row naming the columns x, y and color, it would look like:

from matplotlib import pyplot as plt
from matplotlib import style
import pandas as pd

style.use("ggplot")

total = pd.read_csv("total.csv")
plt.scatter(total.x, total.y, color=total.color, s=0.5)
plt.savefig("graph.png", dpi=1000)
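
The question describes rows of (point number, x, y, color) and does not mention a header line. If the file really has none, a possible variant is to name the columns while reading. This is only a sketch: the column names and the load_points helper are assumptions chosen for illustration.

```python
import pandas as pd

# Hypothetical column names; the file is assumed to have no header row.
COLUMNS = ["n", "x", "y", "color"]


def load_points(path):
    """Read a headerless CSV of (point number, x, y, color) rows."""
    return pd.read_csv(
        path,
        header=None,    # the first line is data, not column names
        names=COLUMNS,  # assign names so columns can be accessed by attribute
        dtype={"x": float, "y": float, "color": str},
    )
```

With total = load_points("total.csv"), the same plt.scatter(total.x, total.y, color=total.color, s=0.5) call applies unchanged.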

If you want to learn more about pandas good practices for performance, I like the "No more sad pandas" talk; take a look at it.
