Efficient GPU-computing simulation platform JAX-PF for differentiable phase field model
Abstract
We present JAX-PF, an open-source, GPU-accelerated, and differentiable phase field (PF) software package that supports both explicit and implicit time-stepping schemes. Built on the JAX framework, JAX-PF achieves high performance through array programming and GPU acceleration. With explicit schemes, it delivers a ~5x speedup over PRISMS-PF running with MPI on 24 CPU cores for systems with ~4.19 million degrees of freedom. With implicit schemes, it scales efficiently to large problems. A key feature of JAX-PF is automatic differentiation (AD), which eliminates the manual derivation of free-energy functional derivatives and Jacobians. Beyond forward simulation, JAX-PF shows potential for inverse design by providing sensitivities for gradient-based optimization. We demonstrate, for the first time, the calibration of PF material parameters using AD-based sensitivities, highlighting the package's capability for high-dimensional inverse problems. By combining efficiency, flexibility, and full differentiability, JAX-PF offers a fast, practical, and integrated tool for forward simulation and inverse design, advancing the co-design of materials and manufacturing processes and supporting the goals of the Materials Genome Initiative.
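The two AD ideas in the abstract can be illustrated with a small example. The first is obtaining the PF driving force from the free energy without hand derivation. The second is differentiating a whole simulation with respect to a material parameter for calibration. The sketch below is a minimal illustration written directly in JAX. It does not use JAX-PF's API. It assumes a 1D periodic Allen-Cahn model, a double-well bulk energy, forward-Euler (explicit) time stepping, and a synthetic "observation" generated with a known gradient-energy coefficient `kappa`. All names (`free_energy`, `explicit_step`, `simulate`, `loss`) and parameter values are hypothetical.

```python
import jax
import jax.numpy as jnp

# Double precision so the AD gradient can be checked against finite differences.
jax.config.update("jax_enable_x64", True)

# Hypothetical 1D periodic grid and model parameters (illustrative only).
N, DX, DT, STEPS = 64, 1.0, 0.05, 200
W = 1.0          # double-well barrier height (assumed)
MOBILITY = 1.0   # Allen-Cahn mobility L (assumed, held fixed)


def free_energy(eta, kappa):
    """Discretized free energy F = sum[W eta^2 (1-eta)^2 + kappa/2 |grad eta|^2] * dx."""
    bulk = W * eta**2 * (1.0 - eta) ** 2
    grad_eta = (jnp.roll(eta, -1) - eta) / DX  # periodic forward difference
    return jnp.sum(bulk + 0.5 * kappa * grad_eta**2) * DX


def explicit_step(eta, kappa):
    # jax.grad gives dF/d(eta_i) = (dF/deta)_i * DX, i.e. f'(eta) - kappa * Laplacian(eta)
    # scaled by DX, so dividing by DX recovers the chemical potential with no hand derivation.
    mu = jax.grad(free_energy)(eta, kappa) / DX
    return eta - DT * MOBILITY * mu  # forward-Euler Allen-Cahn update


def simulate(kappa, eta0):
    """Run STEPS explicit steps; lax.scan keeps the loop traceable and differentiable."""
    def body(eta, _):
        return explicit_step(eta, kappa), None

    eta_final, _ = jax.lax.scan(body, eta0, None, length=STEPS)
    return eta_final


# Initial condition: a small sinusoidal perturbation about the unstable state eta = 0.5.
x = jnp.arange(N)
ETA0 = 0.5 + 0.3 * jnp.sin(2.0 * jnp.pi * x / N)

# Synthetic target produced with a "true" kappa that calibration should recover.
KAPPA_TRUE = 2.0
TARGET = simulate(KAPPA_TRUE, ETA0)


def loss(kappa):
    """Mean-squared mismatch between the simulated and target final fields."""
    return jnp.mean((simulate(kappa, ETA0) - TARGET) ** 2)


# Reverse-mode AD through the entire trajectory (grad of a function that itself uses grad).
loss_and_grad = jax.jit(jax.value_and_grad(loss))

value, dloss_dkappa = loss_and_grad(1.0)
print(f"loss = {float(value):.3e}, dloss/dkappa = {float(dloss_dkappa):.3e}")
```

`dloss_dkappa` is the kind of sensitivity that a gradient-based optimizer would consume for parameter calibration. Writing the time loop with `jax.lax.scan` instead of a Python loop keeps the traced graph compact under `jax.jit`. An implicit scheme would instead need a nonlinear solve per step, where the Jacobian can likewise come from AD (for example via `jax.jacfwd`). That variant is not shown here.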