• marmaduke 11 hours ago |
    looks like a nice overview. i’ve implemented neural ODEs in JAX for low-dimensional problems and it works well, but I keep looking for a fast, CPU-first implementation suited to models that fit in cache and don’t require a GPU or the big Torch/TF machinery.
    • sitkack 11 hours ago |
      • marmaduke 8 hours ago |
        no, wrote it by hand to work with my own Heun implementation, since it’s meant for stochastic delayed systems.

        jax is fun but as effective as i’d like for CPU
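
For readers unfamiliar with the Heun scheme mentioned above, here is a minimal JAX sketch of a hand-written Heun (predictor-corrector) integrator with additive noise. It is not the commenter's code: the names `heun_step` and `integrate`, the additive-noise form, and the test equation are all assumptions made for illustration, and the delay terms of a stochastic delayed system are left out to keep it short.

```python
import jax
import jax.numpy as jnp


def heun_step(f, x, dt, noise):
    """One Heun step for dx = f(x) dt + dW, with `noise` the pre-scaled increment."""
    # Predictor: explicit Euler step plus the noise increment.
    x_pred = x + dt * f(x) + noise
    # Corrector: average the drift at the current and the predicted state.
    return x + 0.5 * dt * (f(x) + f(x_pred)) + noise


def integrate(f, x0, dt, n_steps, sigma=0.0, key=None):
    """Integrate for n_steps and return the trajectory (hypothetical helper)."""
    key = jax.random.PRNGKey(0) if key is None else key
    # Pre-draw Wiener increments: sigma * sqrt(dt) * N(0, 1) per step.
    noise = sigma * jnp.sqrt(dt) * jax.random.normal(key, (n_steps,) + x0.shape)

    def body(x, dw):
        x_next = heun_step(f, x, dt, dw)
        return x_next, x_next  # carry the state, record it in the trajectory

    _, xs = jax.lax.scan(body, x0, noise)
    return xs


# Illustrative test problem (an assumption): linear decay dx/dt = -x, no noise.
decay = lambda x: -x
xs = integrate(decay, jnp.array([1.0]), dt=0.01, n_steps=100)
```

With `sigma=0.0` this reduces to the deterministic Heun method, so the last entry of `xs` should sit close to `exp(-1)`. Passing a nonzero `sigma` and a fresh `key` gives one noisy sample path.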

        • Iwan-Zotow an hour ago |
          Not as effective as I'd like?
      • yberreby 6 hours ago |
        Anecdotally, I used diffrax (and equinox) throughout last year after jumping between a few differential equation solvers in Python, for a project based on Dynamic Field Theory [1]. I only scratched the surface, but so far, it's been a pleasure to use, and it's quite fast. It also introduced me to equinox [2], by the same author, which I'm using to get the JAX-friendly equivalent of dataclasses.

        `vmap`-able differential equation solving is really cool.

        [1]: https://dynamicfieldtheory.org/
        [2]: https://github.com/patrick-kidger/equinox
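
To illustrate what "`vmap`-able" solving means, here is a rough sketch in plain JAX. It deliberately does not use diffrax: `solve` is a hypothetical hand-rolled fixed-step RK4 routine, and the decay equation and parameter values are assumptions chosen only so the result is easy to check. The point is that a solver written for a single problem can be batched over parameters and initial conditions with `jax.vmap`.

```python
import jax
import jax.numpy as jnp


def solve(rate, y0, dt=0.01, n_steps=100):
    """Integrate dy/dt = -rate * y with fixed-step RK4 and return the final state."""
    f = lambda y: -rate * y

    def body(y, _):
        # Classic fourth-order Runge-Kutta stages.
        k1 = f(y)
        k2 = f(y + 0.5 * dt * k1)
        k3 = f(y + 0.5 * dt * k2)
        k4 = f(y + dt * k3)
        return y + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4), None

    y_final, _ = jax.lax.scan(body, y0, None, length=n_steps)
    return y_final


# Batch over decay rates and initial conditions in a single compiled call.
rates = jnp.array([0.5, 1.0, 2.0])
y0s = jnp.ones(3)
batched_solve = jax.jit(jax.vmap(solve))
finals = batched_solve(rates, y0s)
```

Each entry of `finals` should be close to `exp(-rate)` for the matching rate, since every problem is integrated to `t = 1`. A library solver such as diffrax can be wrapped in `jax.vmap` in a similar way, though the exact call would differ from this sketch.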