Is there a way to improve the performance of this fractal calculation algorithm?

Time:10-20

Yesterday I came across the new 3Blue1Brown video about Newton's fractal and I was really mesmerized by his live representation of the fractal. (Here's the video link for anybody interested, it's at 13:40: https://www.youtube.com/watch?v=-RdOwhmqP5s)

I wanted to have a go at it myself and tried to code it in Python (I think he uses Python too).

I spent a few hours trying to improve my naive implementation and got to a point where I just don't know how I could make it faster.

The code looks like this:

import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from time import time


def print_fractal(state):
    fig = plt.figure(figsize=(8, 8))
    gs = GridSpec(1, 1)
    axs = [fig.add_subplot(gs[0, 0])]
    fig.tight_layout(pad=5)
    axs[0].matshow(state)
    axs[0].set_xticks([])
    axs[0].set_yticks([])
    plt.show()
    plt.close()


def get_function_value(z):
    return z**5 + z**2 - z + 1


def get_function_derivative_value(z):
    return 5*z**4 + 2*z - 1


def check_distance(state, roots):
    roots2 = np.zeros((roots.shape[0], state.shape[0], state.shape[1]), dtype=complex)
    for r in range(roots.shape[0]):
        roots2[r] = np.full((state.shape[0], state.shape[1]), roots[r])
    dist_2 = np.abs((roots2 - state))
    original_state = np.argmin(dist_2, axis=0) + 1
    return original_state


def static():
    time_start = time()
    s = 4
    c = [0, 0]
    n = 800
    polynomial = [1, 0, 0, 1, -1, 1]
    roots = np.roots(polynomial)
    state = np.transpose((np.linspace(c[0] - s/2, c[0] + s/2, n)[:, None] + 1j*np.linspace(c[1] - s/2, c[1] + s/2, n)))
    n_steps = 15
    time_setup = time()
    for _ in range(n_steps):
        state -= (get_function_value(state) / get_function_derivative_value(state))
    time_evolution = time()
    original_state = check_distance(state, roots)
    time_check = time()
    print_fractal(original_state)
    print("{0:<40}".format("Time to setup the initial configuration:"), "{:20.3f}".format(time_setup - time_start))
    print("{0:<40}".format("Time to evolve the state:"), "{:20.3f}".format(time_evolution - time_setup))
    print("{0:<40}".format("Time to check the closest roots:"), "{:20.3f}".format(time_check - time_evolution))

An average output looks like this:

Time to setup the initial configuration: 0.004

Time to evolve the state: 0.796

Time to check the closest roots: 0.094

It's clear that it's the evolution part that bottlenecks the process. It's not "slow", but I think it's not enough to render something live like in the video. I already did what I could by using numpy vectors and avoiding loops but I guess it's not enough. What other tricks could be applied here?

Note: I tried using the numpy.polynomial.Polynomial class to evaluate the function, but it was slower than this version.
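For reference, the comparison mentioned in the note could be set up roughly as below. This is only a sketch: the names `p`, `dp`, `f_explicit` and `df_explicit` and the small test grid are assumptions made for illustration, not the code that was actually timed. Note that `Polynomial` takes its coefficients in ascending order of degree, the opposite of the list passed to `np.roots`.

```python
import numpy as np
from numpy.polynomial import Polynomial

# Coefficients in ascending order of degree: 1 - z + z**2 + z**5
p = Polynomial([1, -1, 1, 0, 0, 1])
dp = p.deriv()  # derivative polynomial: -1 + 2*z + 5*z**4


def f_explicit(z):
    return z**5 + z**2 - z + 1


def df_explicit(z):
    return 5*z**4 + 2*z - 1


# Small hypothetical test grid (much smaller than the 800x800 grid above)
x = np.linspace(-2, 2, 50)
z = x[:, None] + 1j * x[None, :]

# Both ways of evaluating should agree up to floating-point rounding
same_value = np.allclose(p(z), f_explicit(z))
same_derivative = np.allclose(dp(z), df_explicit(z))
```

Both versions compute the same values, so any difference between them is purely one of speed.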

CodePudding user response:

for _ in range(n_steps):
    state -= (get_function_value(state) / get_function_derivative_value(state))

If you have enough memory, you can try to vectorize this loop and store each iteration step in a matrix.
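A minimal sketch of the "store each iteration step" part, assuming a preallocated array called `history` (a hypothetical name) and a smaller grid than in the question to keep memory use modest. Each Newton step depends on the previous one, so the loop over steps stays; every step is already vectorized over the whole grid.

```python
import numpy as np


def f(z):
    return z**5 + z**2 - z + 1


def df(z):
    return 5*z**4 + 2*z - 1


n = 200        # smaller than the question's 800, to limit memory
n_steps = 15
s = 4

x = np.linspace(-s/2, s/2, n)
state = np.transpose(x[:, None] + 1j * x)

# One (n, n) slice per step, plus the initial configuration
history = np.empty((n_steps + 1, n, n), dtype=complex)
history[0] = state
for k in range(n_steps):
    prev = history[k]
    history[k + 1] = prev - f(prev) / df(prev)

# history[-1] holds the same result as the in-place loop in the question
```

With the question's settings (n = 800, 15 steps, double-precision complex) such an array takes roughly 16 * 800 * 800 * 16 bytes, about 164 MB.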

CodePudding user response:

  1. I got an improvement (~40% faster) by using single-precision complex numbers (np.complex64).
(...)
state = np.transpose((np.linspace(c[0] - s/2, c[0] + s/2, n)[:, None]
                      + 1j*np.linspace(c[1] - s/2, c[1] + s/2, n)))
state = state.astype(np.complex64)
(...)
  2. 3Blue1Brown added this link in the description: https://codepen.io/mherreshoff/pen/RwZPazd You can take a look at how it was done there (side note: the author of this pen used single precision as well).
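To check that the single-precision change from point 1 does not noticeably alter the picture, something along these lines could be used. It is a sketch under assumptions: the helper names `newton` and `closest_root` are made up for this example, and the grid is smaller than in the question.

```python
import numpy as np


def f(z):
    return z**5 + z**2 - z + 1


def df(z):
    return 5*z**4 + 2*z - 1


def newton(state, n_steps=15):
    # Work on a copy so the caller's grid is left untouched
    state = state.copy()
    for _ in range(n_steps):
        state -= f(state) / df(state)
    return state


def closest_root(state, roots):
    # Broadcast (n_roots, 1, 1) against (n, n), then take the nearest root per pixel
    return np.argmin(np.abs(roots[:, None, None] - state), axis=0) + 1


n, s = 200, 4
x = np.linspace(-s/2, s/2, n)
grid64 = np.transpose(x[:, None] + 1j * x)   # complex128 (double precision)
grid32 = grid64.astype(np.complex64)         # single precision

roots = np.roots([1, 0, 0, 1, -1, 1])
final32 = newton(grid32)
labels64 = closest_root(newton(grid64), roots)
labels32 = closest_root(final32, roots)

# Fraction of pixels assigned to the same root in both precisions
agreement = np.mean(labels64 == labels32)
```

The array stays `complex64` through the iteration because the Python scalars in `f` and `df` do not force an upcast; pixels near the fractal boundary may still end up at a different root than in double precision.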