I want to plot points on the interval x in [0, 4]. My function does something interesting for very small values, so I would like a non-linear scale that gives more space to small values of x. A logarithmic scale would be a great solution, but my x-axis must include 0, which a logarithmic axis cannot show.
So I considered using a power scale. After some googling I came across the following solution.
import numpy as np
import matplotlib.pyplot as plt

def stratify(ax, power=2):
    # Forward transform and its inverse for the x-axis
    f = lambda x: (x + 1)**(1 / power)
    f_inv = lambda y: y**power - 1
    ax.set_xscale('function', functions=(f, f_inv))

x = np.linspace(0, 4, 100)
y = np.sqrt(x)

fig, ax = plt.subplots()
ax.plot(x, y)
stratify(ax, 2)
plt.show()
The function stratify changes the x-scale of the plot to the square root function. This looks kind of correct. Below is a minimal example plot corresponding to the above code (not actual data).
I would like to control the nonlinearity of the x-scale, which is why I introduced the power parameter. However, when I change the power parameter to a value other than 2, the plot does not change at all. This is very surprising to me. I would appreciate it if somebody could advise me how to control the extent of non-linearity in the x-axis. If possible, I would like it to be even more non-linear, so that the interval [0, 0.5] takes up half of the plot.
EDIT: While the current solution by @Thomas works as intended, the plotting routine throws a lot of errors when one attempts to do anything with it. For example, I would like to insert extra ticks, as the original only has integer ticks for whatever reason. Inserting extra ticks via ax.set_xticks(ax.get_xticks() + [0.5]) results in the error "posx and posy should be finite values". What is this error, and how can it be bypassed?
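One possible reading of this, offered as a sketch and not a confirmed diagnosis: ax.get_xticks() returns a NumPy array, so adding [0.5] to it shifts every tick by 0.5 instead of appending one tick. The returned ticks may also include values outside the plotted range. If any resulting tick lies where the forward scale function is undefined (for a log-based scale, values at or below -scale), its transformed position is not finite, which could be what the message refers to. The helper below, merged_ticks, is a hypothetical name, and the tick values are made up for illustration. It builds the tick list with plain Python and keeps only ticks inside [0, 4].

```python
def merged_ticks(current, extra, lo, hi):
    """Return sorted, de-duplicated ticks from both inputs that lie in [lo, hi]."""
    combined = {float(t) for t in current} | {float(t) for t in extra}
    return sorted(t for t in combined if lo <= t <= hi)

# Hypothetical tick values such as an automatic locator might return;
# note the entries outside the plotted range [0, 4].
current = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0]

# Append 0.5 as a new tick and drop everything outside [0, 4].
ticks = merged_ticks(current, [0.5], lo=0.0, hi=4.0)
print(ticks)
```

The resulting list could then be passed on with ax.set_xticks(ticks), using list(ax.get_xticks()) as the current argument.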
Answer:
For me, the plot does change when switching from power=2 to power=10. Just be careful to edit the value in the right place, i.e. in the call stratify(ax, X) rather than in the default argument, since the call passes the power explicitly.
Here's power=2:
Here's power=10:
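To put a number on how much the power matters, here is a small sketch that computes which fraction of the axis the interval [0, 0.5] occupies. It assumes the forward transform from the question is (x + 1)**(1 / power) and that the axis spans exactly [0, 4] with no margins. The names axis_fraction and power_scale are made up for this illustration. For large powers the fraction approaches ln(1.5) / ln(5), about a quarter, so this particular transform cannot give [0, 0.5] half of the plot.

```python
import math

def axis_fraction(f, x, lo=0.0, hi=4.0):
    """Fraction of the axis between lo and x under the forward scale function f."""
    return (f(x) - f(lo)) / (f(hi) - f(lo))

def power_scale(power):
    # Assumed forward transform of the question's power scale
    return lambda x: (x + 1) ** (1 / power)

# Limit of the fraction for [0, 0.5] on [0, 4] as power grows without bound
limit = math.log(1.5) / math.log(5.0)

for power in (1, 2, 10, 100):
    frac = axis_fraction(power_scale(power), 0.5)
    print(f"power={power:>3}: [0, 0.5] covers {frac:.1%} of the axis")
print(f"limit for large power: {limit:.1%}")
```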
However, here's another suggestion that is slightly more aggressive:
import numpy as np
import matplotlib.pyplot as plt

def stratify(ax, scale=1):
    # log(x / scale + 1) is 0 at x = 0; a smaller scale stretches small x more
    f = lambda x: np.log(x / scale + 1)
    f_inv = lambda y: scale * (np.exp(y) - 1)
    ax.set_xscale('function', functions=(f, f_inv))

x = np.linspace(0, 4, 100)
y = np.sqrt(x)

fig, axs = plt.subplots(1, 3)
for i, scale in enumerate([10, 1, 0.1]):
    ax = axs[i]
    ax.set_title(f'Scale={scale}')
    ax.plot(x, y)
    stratify(ax, scale=scale)
plt.show()
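The question asks for [0, 0.5] to take half of the plot. With the log-based transform this can be solved for directly, at least as a sketch that ignores axis margins. Requiring log(a / s + 1) to be half of log(b / s + 1) gives (1 + a / s)**2 = 1 + b / s, and therefore s = a**2 / (b - 2 * a), which needs b > 2 * a. The function names below are hypothetical and not part of any library.

```python
import math

def half_axis_scale(a, b):
    """Scale s for which log(x / s + 1) places x = a at the midpoint of [0, b]."""
    if not 0 < 2 * a < b:
        raise ValueError("need 0 < 2*a < b")
    return a**2 / (b - 2 * a)

def log_fraction(x, scale, hi):
    """Fraction of the axis [0, hi] covered by [0, x] under log(x / scale + 1)."""
    return math.log(x / scale + 1) / math.log(hi / scale + 1)

scale = half_axis_scale(0.5, 4.0)      # a = 0.5, b = 4 gives s = 1/12
print(scale)
print(log_fraction(0.5, scale, 4.0))   # check: should be close to 0.5
```

The value could then be used as stratify(ax, scale=scale). Default axis margins would shift the split slightly away from exactly one half.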