TensorFlow Probability on JAX
https://www.tensorflow.org/probability/examples/TensorFlow_Probability_on_JAX
TensorFlow Probability (TFP) is a library for probabilistic reasoning and statistical analysis that now also works on JAX! For those not familiar with it, JAX is a library for accelerated numerical computing based on composable function transformations.
We have ported a lot of TFP's most useful functionality to JAX while preserving the abstractions and APIs that many TFP users are now comfortable with.