Looks like a nice overview. I've implemented neural ODEs in JAX for low-dimensional problems and it works well, but I keep looking for a good, fast, CPU-first implementation suited to models that fit in cache and don't need a GPU or the big Torch/TF machinery.
Did you use https://github.com/patrick-kidger/diffrax ?
JAX Talk: Diffrax https://www.youtube.com/watch?v=Jy5Jw8hNiAQ
Anecdotally, I used diffrax (and equinox) throughout last year after jumping between a few differential equation solvers in Python, for a project based on Dynamic Field Theory [1]. I only scratched the surface, but so far, it's been a pleasure to use, and it's quite fast. It also introduced me to equinox [2], by the same author, which I'm using to get the JAX-friendly equivalent of dataclasses.
`vmap`-able differential equation solving is really cool; there's a rough sketch of what that looks like after the links below.
[1]: https://dynamicfieldtheory.org/ [2]: https://github.com/patrick-kidger/equinox
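For concreteness, a minimal sketch of the pattern (mine, not from the project above; the `Decay` module, the rate value, and the time grid are made up for illustration):

```python
import jax
import jax.numpy as jnp
import diffrax
import equinox as eqx

class Decay(eqx.Module):
    """Equinox modules act like JAX-friendly dataclasses (they are pytrees)."""
    rate: float

    def __call__(self, t, y, args):
        return -self.rate * y  # toy vector field: dy/dt = -rate * y

def solve(y0):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(Decay(rate=0.5)),
        diffrax.Tsit5(),
        t0=0.0, t1=5.0, dt0=0.1,
        y0=y0,
        saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 5.0, 50)),
    )
    return sol.ys

# One vmap call integrates a whole batch of initial conditions at once.
batched_trajectories = jax.vmap(solve)(jnp.linspace(0.1, 1.0, 8))
```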
Thanks, that looks neat.
Kidger's thesis is wonderful: https://arxiv.org/abs/2202.02435
No, I wrote it by hand to work with my own Heun implementation, since it's used within stochastic delay systems.
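Not the commenter's code, but for anyone curious, here's a plain deterministic Heun step in JAX (the toy decay vector field and step counts are made up); a stochastic delayed version would additionally add a noise increment and read delayed states from a history buffer:

```python
import jax
import jax.numpy as jnp

def heun_step(f, t, y, dt):
    """One Heun (explicit trapezoidal) step for dy/dt = f(t, y)."""
    k1 = f(t, y)                       # slope at the start of the step
    y_pred = y + dt * k1               # Euler predictor
    k2 = f(t + dt, y_pred)             # slope at the predicted endpoint
    return y + 0.5 * dt * (k1 + k2)    # trapezoidal corrector

def integrate(y0, dt, n_steps):
    f = lambda t, y: -y                # toy vector field: dy/dt = -y
    def body(carry, _):
        t, y = carry
        y_next = heun_step(f, t, y, dt)
        return (t + dt, y_next), y_next
    # lax.scan keeps the stepping loop jit-friendly.
    _, ys = jax.lax.scan(body, (0.0, y0), None, length=n_steps)
    return ys

trajectory = integrate(jnp.array(1.0), 0.01, 500)
```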
jax is fun but as effective as i’d like for CPU
Not as effective as I'd like?
ha, yeah, thanks.
How would you describe what a neural ODE is in the simplest possible terms? Let's say I know what an NN and a DE are :).
A classic NN takes a vector of data through a stack of layers to make a prediction, and backprop adjusts the network's weights until the predictions come out right.
A neural ODE keeps that training story but changes what the forward pass looks like. Instead of a fixed stack of discrete layers, you specify the rate at which the hidden state changes, and that rate of change is itself a small neural network. An ODE solver then integrates the hidden state from input to output, so "depth" becomes continuous time and the solver decides how many steps to take. The weights of that small network are still what training adjusts, with backprop either through the solver's steps or via the adjoint method.
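A minimal sketch of that idea (mine, not from this thread), using diffrax and equinox as mentioned above; the layer sizes, time span, and loss are arbitrary:

```python
import jax
import jax.numpy as jnp
import diffrax
import equinox as eqx

class VectorField(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, key):
        # Small MLP giving the rate of change of the hidden state.
        self.mlp = eqx.nn.MLP(in_size=2, out_size=2, width_size=32, depth=2, key=key)

    def __call__(self, t, y, args):
        return self.mlp(y)

def neural_ode(field, y0):
    # The "forward pass": integrate dh/dt = field(h) from t=0 to t=1.
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(field), diffrax.Tsit5(),
        t0=0.0, t1=1.0, dt0=0.05, y0=y0,
    )
    return sol.ys[-1]

field = VectorField(jax.random.PRNGKey(0))

# Training differentiates a loss through the solve w.r.t. the MLP's weights.
def loss(f, y0, target):
    return jnp.sum((neural_ode(f, y0) - target) ** 2)

grads = eqx.filter_grad(loss)(field, jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0]))
```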
Pretty cool approach, looking more into it, thank you!