Adding a vertical line to a time series plot in Python

Time:12-13

I am plotting time series data, which will be split into a training and a test data set. Now I would like to draw a vertical line in the plot that indicates where the training/test split happens.

split_point is the integer row position where the data should be split.

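import pandas as pd
import matplotlib.pyplot as plt

# Load the monthly series with 'date' parsed into a DatetimeIndex
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'], index_col='date')

data_size = len(df)

# Integer position of the first test row: the first two thirds are training data
split_point = data_size - data_size // 3
print(split_point)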

# Draw Plot
def plot_df(df, x, y, title="", xlabel='Date', ylabel='Value', dpi=100):
    plt.figure(figsize=(16,5), dpi=dpi)
    plt.plot(x, y, color='tab:red')
    plt.gca().set(title=title, xlabel=xlabel, ylabel=ylabel)
    plt.show()

plot_df(df, x=df.index, y=df.value, title='Monthly anti-diabetic drug sales in Australia from 1992 to 2008.') 

How can this be added to the plot? I tried using plt.axvline, but don't know how to go from the split point to the date. Any ideas?

plt.axvline(split_point)
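The attempt above puts the line in the wrong place because, once the x values are dates, Matplotlib measures the x axis in days since its date epoch (1970-01-01 by default in Matplotlib 3.3 and later). A bare integer is therefore read as a date long before the data starts, not as a row position. A minimal sketch of the conversion using matplotlib.dates, assuming a hypothetical split date of 2002-11-01 at position 136:

```python
import matplotlib.dates as mdates
import pandas as pd

split_point = 136                        # integer row position, as in the question
split_date = pd.Timestamp('2002-11-01')  # hypothetical date found at that position

# On a date axis a bare number means "days since the epoch", not a row position
wrong = mdates.num2date(split_point)
print(wrong)  # a date in 1970 with the default epoch, far left of the data

# Converting the date itself gives the float the axis actually uses
x = mdates.date2num(split_date)
print(mdates.num2date(x).date())  # → 2002-11-01
```

Either the Timestamp itself or its date2num float can be passed to plt.axvline on a date axis; the missing step is looking up the date at split_point.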

Answer:

You are almost there. Simply extract the row at position split_point, like this:

split = df.iloc[[split_point]]

Gives:

               value
date                
2002-11-01  13.28764

So the date is the index. Extract the index value as follows:

split = df.index[split_point]

Gives:

2002-11-01 00:00:00
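The two lookups differ in return type: double brackets in iloc keep a one-row DataFrame, while indexing df.index at the same position yields a single Timestamp, which is what axvline needs. The same integer position also slices the frame into training and test sets, so the line marks the first test row. A minimal sketch, assuming a hypothetical synthetic monthly frame of 204 rows instead of the CSV so it runs offline:

```python
import numpy as np
import pandas as pd

# Hypothetical monthly frame standing in for the CSV (204 rows from 1991-07-01)
dates = pd.date_range('1991-07-01', periods=204, freq='MS', name='date')
df = pd.DataFrame({'value': np.arange(204, dtype=float)}, index=dates)

data_size = len(df)
split_point = data_size - data_size // 3   # 136

row_frame = df.iloc[[split_point]]   # one-row DataFrame, as in the first attempt
split = df.index[split_point]        # single pandas Timestamp, usable by axvline

# The same integer position slices the frame into training and test sets
train = df.iloc[:split_point]
test = df.iloc[split_point:]

print(type(row_frame).__name__, type(split).__name__)  # → DataFrame Timestamp
print(len(train), len(test), split)  # → 136 68 2002-11-01 00:00:00
print(test.index[0] == split)  # → True
```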

Then plot it using plt.axvline().

Complete code:

import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'], index_col='date')

data_size = len(df)

# Integer position of the first test row
split_point = data_size - data_size // 3
print(split_point)

# Look up the date label at that position in the DatetimeIndex
split = df.index[split_point]
print(split)

# Draw Plot
def plot_df(df, x, y, title="", xlabel='Date', ylabel='Value', dpi=100):
    plt.figure(figsize=(16,5), dpi=dpi)
    plt.plot(x, y, color='tab:red')
    plt.gca().set(title=title, xlabel=xlabel, ylabel=ylabel)
    # Vertical line at the split date; the date axis converts the Timestamp
    plt.axvline(split)
    plt.show()

plot_df(df, x=df.index, y=df.value, title='Monthly anti-diabetic drug sales in Australia from 1992 to 2008.')

Gives: [figure: the sales series with a vertical line at the split date]
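Passing the split date into plot_df as a parameter, instead of reading a global, makes the function reusable and easy to check without a display. A sketch under stated assumptions: the split parameter, the returned (fig, ax) pair, the two-year tick locator and the dashed labelled line are additions, the data is a hypothetical synthetic monthly series, and the non-interactive Agg backend is selected so it runs headless:

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so this runs without a display
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_df(df, x, y, split=None, title="", xlabel='Date', ylabel='Value', dpi=100):
    """Plot y against x; if split is given, mark it with a dashed vertical line."""
    fig, ax = plt.subplots(figsize=(16, 5), dpi=dpi)
    ax.plot(x, y, color='tab:red')
    ax.set(title=title, xlabel=xlabel, ylabel=ylabel)
    # Label every second year so the dates stay readable
    ax.xaxis.set_major_locator(mdates.YearLocator(2))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    if split is not None:
        # Matplotlib converts the Timestamp to date-axis units
        ax.axvline(split, color='black', linestyle='--', label='train/test split')
        ax.legend()
    return fig, ax


# Hypothetical monthly data standing in for the CSV
dates = pd.date_range('1991-07-01', periods=204, freq='MS', name='date')
df = pd.DataFrame({'value': np.linspace(3.0, 30.0, 204)}, index=dates)

split_point = len(df) - len(df) // 3
split = df.index[split_point]

fig, ax = plot_df(df, x=df.index, y=df.value, split=split, title='Synthetic monthly series')
```

Afterwards, call fig.savefig() with a file name of your choice, or drop the matplotlib.use('Agg') line and call plt.show() to display the figure.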
