Python - time difference (JAX library)

Time:11-30

I'm trying to compare execution times between functions:

  • simpleExponentialSmoothing - my implementation of simple exponential smoothing (SES) using the JAX library
  • simpleExponentialSmoothingJit - the same implementation, compiled with JAX's JIT
  • SimpleExpSmoothing - the implementation from the statsmodels library
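For readers unfamiliar with SES, here is one way such a function could be written in JAX. It is a minimal sketch, not the asker's code: the name `ses` and its `(alpha, series, level0)` signature are hypothetical, and it uses the standard recurrence level_t = alpha * y_t + (1 - alpha) * level_{t-1} expressed with `jax.lax.scan`.

```python
import jax
import jax.numpy as jnp


def ses(alpha, series, level0):
    """Hypothetical SES sketch: returns the smoothed level after each observation."""
    # Cast the scalars to the series dtype so the scan carry keeps a single type.
    alpha = jnp.asarray(alpha, dtype=series.dtype)
    level0 = jnp.asarray(level0, dtype=series.dtype)

    def step(level, y):
        new_level = alpha * y + (1.0 - alpha) * level
        return new_level, new_level  # (new carry, per-step output)

    _, levels = jax.lax.scan(step, level0, series)
    return levels


# jax.jit traces and compiles ses on the first call for each input shape/dtype.
ses_jit = jax.jit(ses)

series = jnp.array([3.0, 5.0, 4.0, 6.0])
levels = ses_jit(0.5, series, 3.0).block_until_ready()
```

With alpha = 0.5 and an initial level of 3.0, each level is the average of the new observation and the previous level.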

I have tried using %timeit, %time, and my own timing function based on datetime, but the results confuse me. My own function and %timeit report the same execution time, whereas %time reports a very different one. I have read that %time measures only a single run and is less accurate than %timeit, but how does that apply to asynchronous functions like those in JAX? I block on the results until the computation finishes, but I'm not sure that is enough.
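A hand-rolled single-call timer can be sketched as below. It assumes `time.perf_counter()` instead of `datetime`: `perf_counter` is a monotonic, high-resolution clock intended for measuring intervals, and it is also the default timer used by Python's `timeit`. The helper name `time_once` is hypothetical; it calls `block_until_ready()` only when the result has that method, as JAX arrays do.

```python
import time


def time_once(fn, *args, **kwargs):
    """Hypothetical helper: time a single call of fn, in seconds."""
    start = time.perf_counter()  # monotonic clock suited to measuring intervals
    result = fn(*args, **kwargs)
    if hasattr(result, "block_until_ready"):
        # JAX arrays: wait until the asynchronous computation has finished.
        result.block_until_ready()
    elapsed = time.perf_counter() - start
    return result, elapsed
```

For example, `time_once(simpleExponentialSmoothingJit, params, timeSeries, initState)` would return the result and the wall time of that one call, which is roughly what %time reports as "Wall time".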

I need advice on these measurements: which one should I take as the actual execution time?

%time

%time timeSeriesSes = simpleExponentialSmoothing(params, timeSeries, initState).block_until_ready()
%time timeSeriesSesJit = simpleExponentialSmoothingJit(params, timeSeries, initState).block_until_ready()
%time timeSeriesSesSm = SimpleExpSmoothing(timeSeries).fit()
CPU times: user 82.4 ms, sys: 4.03 ms, total: 86.4 ms
Wall time: 97.6 ms
CPU times: user 199 µs, sys: 0 ns, total: 199 µs
Wall time: 214 µs
CPU times: user 6.12 ms, sys: 0 ns, total: 6.12 ms
Wall time: 6.2 ms

%timeit

%timeit timeSeriesSes = simpleExponentialSmoothing(params, timeSeries, initState).block_until_ready()
%timeit timeSeriesSesJit = simpleExponentialSmoothingJit(params, timeSeries, initState).block_until_ready()
%timeit timeSeriesSesSm = SimpleExpSmoothing(timeSeries).fit()
48.8 ms ± 904 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
15.5 µs ± 222 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
3.4 ms ± 62.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Answer:

For your JAX-specific question: using block_until_ready() should be enough to account for JAX's asynchronous execution.
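The effect of asynchronous dispatch can be seen with a sketch like the following, loosely modelled on the example in the JAX documentation. The matrix size is an arbitrary assumption and the actual numbers depend on the backend (CPU/GPU/TPU): without blocking, the timer typically measures only how long it takes to dispatch the work, not to compute it.

```python
import time

import jax.numpy as jnp
from jax import random

x = random.normal(random.PRNGKey(0), (1000, 1000)).block_until_ready()
jnp.dot(x, x).block_until_ready()  # warm-up, so neither timing below includes compilation

start = time.perf_counter()
y = jnp.dot(x, x)  # may return as soon as the work is dispatched
dispatch_time = time.perf_counter() - start

start = time.perf_counter()
y = jnp.dot(x, x).block_until_ready()  # waits for the result to be computed
blocked_time = time.perf_counter() - start
```

Only `blocked_time` reflects the cost of the computation itself, which is why the `.block_until_ready()` calls in your %time and %timeit lines matter.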

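Be careful about JIT compilation as well: the first time you call a JIT-compiled function with arguments of a given shape and dtype, JAX traces and compiles it, so the time of that call includes compilation. Later calls with the same shapes and dtypes reuse the cached compiled version, so make one warm-up call before timing if you only want to measure execution.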
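A minimal sketch of the compile-versus-cached difference, assuming a hypothetical jitted function `smooth_sum` that stands in for any JIT-compiled function such as `simpleExponentialSmoothingJit`. The first call and the call with a new shape both include tracing and compilation; the second call reuses the cached executable.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def smooth_sum(x):
    # Hypothetical placeholder computation for a JIT-compiled function.
    return jnp.cumsum(x * 0.5).sum()


def timed(fn, x):
    start = time.perf_counter()
    fn(x).block_until_ready()
    return time.perf_counter() - start


x = jnp.ones(10_000)
first = timed(smooth_sum, x)  # includes tracing + XLA compilation
second = timed(smooth_sum, x)  # same shape/dtype: cached executable reused
new_shape = timed(smooth_sum, jnp.ones(20_000))  # new shape: compiled again
```

Because %timeit runs the statement many times, a one-off compilation is averaged away, whereas a single %time call on a not-yet-compiled function would include it in full.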

As for your more general question: %timeit is built on Python's timeit module, and the timeit documentation describes one important difference from a plain %time run:

By default, timeit() temporarily turns off garbage collection during the timing. The advantage of this approach is that it makes independent timings more comparable. The disadvantage is that GC may be an important component of the performance of the function being measured.

(From https://docs.python.org/3/library/timeit.html#timeit.Timer.timeit)
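The same documentation shows how to turn garbage collection back on by passing `gc.enable()` as the setup statement. A minimal sketch, assuming an arbitrary allocation-heavy statement as the workload, compares the two settings:

```python
import timeit

# Hypothetical workload that allocates many small container objects.
stmt = "[[i] for i in range(10_000)]"

# Default behaviour: timeit disables garbage collection while timing.
no_gc = timeit.timeit(stmt, number=200)

# 'gc.enable()' as setup re-enables GC during the timed loop.
with_gc = timeit.timeit(stmt, setup="gc.enable()", number=200)
```

Whether `with_gc` ends up noticeably larger depends on how much garbage the measured code creates; for code that allocates little, the two numbers should be close.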

So if you want to measure performance with Python garbage collection enabled, from a single execution, use %time. If you want to measure performance without garbage collection, over many executions for more statistical rigor, use %timeit.
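To make the two approaches concrete, here is a rough stdlib-only approximation of what each magic measures. It is a sketch under assumptions: `work` is a hypothetical stand-in for the function being timed, `time.process_time()` gives user + sys CPU time combined (where %time reports them separately), and IPython picks its loop count slightly differently from `Timer.autorange()`.

```python
import statistics
import time
import timeit


def work():
    # Hypothetical stand-in for the function being measured.
    return sum(i * i for i in range(10_000))


# %time-style: a single execution, GC left as it is, CPU time and wall time.
cpu0, wall0 = time.process_time(), time.perf_counter()
work()
cpu_time = time.process_time() - cpu0
wall_time = time.perf_counter() - wall0

# %timeit-style: choose a loop count, then 7 runs (timeit disables GC while timing).
timer = timeit.Timer(work)
number, _ = timer.autorange()  # loops per run so one run takes at least 0.2 s
per_loop = [t / number for t in timer.repeat(repeat=7, number=number)]
mean, std = statistics.mean(per_loop), statistics.stdev(per_loop)
print(f"{mean * 1e6:.1f} µs ± {std * 1e6:.1f} µs per loop "
      f"(mean ± std. dev. of 7 runs, {number} loops each)")
```

A single %time measurement is exposed to one-off costs (first-call overhead, compilation, a GC pass), while the %timeit mean spreads them over many loops, which is one reason the two can disagree as much as they do in the question.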
