Home > Software engineering > Efficient way to erase and re-create (part of, if possible) a subplot inside a loop using matplotlib?
Efficient way to erase and re-create (part of, if possible) a subplot inside a loop using matplotlib?

Time: 05-09

The code below creates a scatter plot from `X` and, based on the values of `w` and `b`, draws lines over it.

I have tried a couple of combinations, such as:

fig.canvas.draw()
fig.canvas.flush_events()

plt.clf
plt.cla

But they either draw multiple sets of lines on top of each other or delete the whole figure/axes.

Is it possible to draw the scatter plot only once and have only the lines change based on `w` and `b`?

Below is the code that I have used:

from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
import time
from IPython.display import display, clear_output

def get_hyperplane_value(x, w, b, offset):
    '''
    Generate Hyperplane for the plot.

    Solve ``w[0] * x + w[1] * x1 - b == offset`` for ``x1``. This gives the
    second coordinate of the line ``w . p - b = offset`` at first
    coordinate ``x``.

    Parameters
    ----------
    x : first coordinate (scalar or NumPy array).
    w : indexable with at least two entries; ``w[1]`` must be non-zero.
    b : bias term.
    offset : 0 for the hyperplane itself, -1 / +1 for the two margins.

    Returns the corresponding second coordinate ``x1``.
    '''
    # The scraped source had lost both '+' operators, which was a SyntaxError.
    return (-w[0] * x + b + offset) / w[1]


def plot_now(ax, W, b):
    '''
    Visualise the results: draw the hyperplane, its two margins and the data.

    Parameters
    ----------
    ax : the matplotlib Axes to draw on.
    W : weight vector passed to ``get_hyperplane_value`` (uses ``W[0]``, ``W[1]``).
    b : bias passed to ``get_hyperplane_value``.

    Returns the same ``ax``.

    Reads the module-level ``X`` (the first two columns are plotted) and ``y``
    (used as the scatter colours).
    '''
    # Span the lines across the data's range along the first feature.
    x0_1 = np.amin(X[:, 0])
    x0_2 = np.amax(X[:, 0])

    x1_1 = get_hyperplane_value(x0_1, W, b, 0)
    x1_2 = get_hyperplane_value(x0_2, W, b, 0)

    x1_1_m = get_hyperplane_value(x0_1, W, b, -1)
    x1_2_m = get_hyperplane_value(x0_2, W, b, -1)

    x1_1_p = get_hyperplane_value(x0_1, W, b, 1)
    x1_2_p = get_hyperplane_value(x0_2, W, b, 1)

    ax.plot([x0_1, x0_2], [x1_1, x1_2], "y--")      # hyperplane (offset 0)
    ax.plot([x0_1, x0_2], [x1_1_m, x1_2_m], "k")    # margin at offset -1
    ax.plot([x0_1, x0_2], [x1_1_p, x1_2_p], "k")    # margin at offset +1

    # Pin the y-range to the data with 3 units of padding, so the view does
    # not follow the line endpoints. The scraped source had lost the '+'
    # here, which was a SyntaxError.
    x1_min = np.amin(X[:, 1])
    x1_max = np.amax(X[:, 1])
    ax.set_ylim([x1_min - 3, x1_max + 3])

    ax.scatter(X[:, 0], X[:, 1], marker="o", c=y)
    return ax



# Two-blob toy dataset; labels are remapped from {0, 1} to {-1, 1}.
X, y = datasets.make_blobs(
    n_samples=50,
    n_features=2,
    centers=2,
    cluster_std=1.05,
    random_state=40,
)
y = np.where(y == 0, -1, 1)

# A single figure holding one Axes that is wiped and redrawn every frame.
fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111)

for i in range(50):
    # Fresh random parameters for this frame (W is drawn before b).
    W = np.random.randn(2)
    b = np.random.randn()

    # Clear the previous frame, then redraw the scatter and lines from scratch.
    ax.cla()
    ax = plot_now(ax, W, b)

    # Show the figure, clear the output once new output arrives, then pause.
    display(fig)
    clear_output(wait=True)
    plt.pause(0.25)

Answer:

It looks like you are trying to animate a figure, so you should use `FuncAnimation`. The basic principle of animation is to create your line objects once and then update their data on every frame, instead of clearing and redrawing everything.

from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

def get_hyperplane_value(x, w, b, offset):
    '''
    Generate Hyperplane for the plot.

    Solve ``w[0] * x + w[1] * x1 - b == offset`` for ``x1``. This gives the
    second coordinate of the line ``w . p - b = offset`` at first
    coordinate ``x``.

    Parameters
    ----------
    x : first coordinate (scalar or NumPy array).
    w : indexable with at least two entries; ``w[1]`` must be non-zero.
    b : bias term.
    offset : 0 for the hyperplane itself, -1 / +1 for the two margins.

    Returns the corresponding second coordinate ``x1``.
    '''
    # The scraped source had lost both '+' operators, which was a SyntaxError.
    return (-w[0] * x + b + offset) / w[1]

def get_weights_bias(i):
    '''
    Return a random ``(W, b)`` pair for animation frame ``i``.

    ``i`` is currently ignored. Each call draws from NumPy's global random
    generator: a length-2 ``W`` first, then a scalar ``b``.
    '''
    # Tuple elements are evaluated left to right, so W is drawn before b.
    return np.random.randn(2), np.random.randn()

def plot_now(i):
    '''
    FuncAnimation callback: move the three lines to frame ``i``'s parameters.

    Reads the module-level ``X``, ``ax``, ``line1``, ``line2`` and ``line3``.
    No new artists are created; only the existing lines' data is replaced.

    Returns the updated line artists. FuncAnimation ignores this when
    ``blit=False`` and needs it when ``blit=True``.
    '''
    # retrieve weights and bias at iteration i
    W, b = get_weights_bias(i)

    # Span the lines across the data's range along the first feature.
    x0_1 = np.amin(X[:, 0])
    x0_2 = np.amax(X[:, 0])

    x1_1 = get_hyperplane_value(x0_1, W, b, 0)
    x1_2 = get_hyperplane_value(x0_2, W, b, 0)

    x1_1_m = get_hyperplane_value(x0_1, W, b, -1)
    x1_2_m = get_hyperplane_value(x0_2, W, b, -1)

    x1_1_p = get_hyperplane_value(x0_1, W, b, 1)
    x1_2_p = get_hyperplane_value(x0_2, W, b, 1)

    line1.set_data([x0_1, x0_2], [x1_1, x1_2])        # hyperplane (offset 0)
    line2.set_data([x0_1, x0_2], [x1_1_m, x1_2_m])    # margin at offset -1
    line3.set_data([x0_1, x0_2], [x1_1_p, x1_2_p])    # margin at offset +1

    # Pin the y-range to the data with 3 units of padding. The scraped source
    # had lost the '+' here, which was a SyntaxError.
    x1_min = np.amin(X[:, 1])
    x1_max = np.amax(X[:, 1])
    ax.set_ylim([x1_min - 3, x1_max + 3])

    return line1, line2, line3

# Two-blob toy dataset; labels are remapped from {0, 1} to {-1, 1}.
X, y = datasets.make_blobs(n_samples=50, n_features=2, centers=2, cluster_std=1.05, random_state=40)
y = np.where(y == 0, -1, 1)

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
# Draw the static scatter plot exactly once, on the explicit Axes object
# rather than through pyplot's implicit "current axes" state.
ax.scatter(X[:, 0], X[:, 1], marker="o", c=y)

# Initialize empty lines; plot_now fills in their data on each frame.
line1, = ax.plot([], [], "y--")   # hyperplane (offset 0)
line2, = ax.plot([], [], "k")     # margin at offset -1
line3, = ax.plot([], [], "k")     # margin at offset +1

# Create an animation with 10 frames. Keep the reference in `anim`: if the
# FuncAnimation object is garbage-collected, the animation stops.
anim = FuncAnimation(fig, plot_now, frames=range(10), repeat=False)
plt.show()
  • Related