Mixed-precision CG
Mixed-precision conjugate-gradient solver.
Provides cg_mixed_precision with three modes, selected by the precision argument:
- 'float64' : reference path (default; preserves existing behaviour).
- 'float32' : full single precision (memory-bandwidth and tensor-core friendly).
- 'mixed' : float32 matvec inside CG; float64 residual and convergence test, with iterative refinement to recover most of float64 accuracy.
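The 'mixed' mode can be illustrated with a small NumPy sketch. This is not the phast implementation: NumPy stands in for torch, a dense matrix stands in for the matvec callable, and the helper names cg_inner and mixed_refine, the inner tolerance inner_tol, and the choice to restart each inner CG from zero are all assumptions made for illustration.

```python
import numpy as np


def cg_inner(matvec, b, tol, max_iter):
    """Hypothetical helper: plain CG from a zero start, computed in b's dtype."""
    x = np.zeros_like(b)
    r = b.copy()                      # r = b - A @ 0
    p = r.copy()
    rs = r @ r
    b_norm = np.linalg.norm(b)
    for k in range(max_iter):
        if np.sqrt(rs) <= tol * b_norm:
            return x, k
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x, max_iter


def mixed_refine(A64, b64, tol=1e-10, inner_tol=1e-5, max_iter=1000, max_refine=5):
    """Hypothetical sketch: float32 CG corrections, float64 residual and test."""
    A32 = A64.astype(np.float32)      # low-precision copy used only inside CG

    def matvec32(v):
        return A32 @ v                # float32 in, float32 out

    x = np.zeros_like(b64)            # iterate kept in float64 (assumption)
    b_norm = np.linalg.norm(b64)
    total_iters = 0
    for _ in range(max_refine):
        r = b64 - A64 @ x             # residual recomputed in float64
        if np.linalg.norm(r) <= tol * b_norm:
            return x, total_iters, True
        d32, k = cg_inner(matvec32, r.astype(np.float32), inner_tol, max_iter)
        x += d32.astype(np.float64)   # correction applied in float64
        total_iters += k
    converged = np.linalg.norm(b64 - A64 @ x) <= tol * b_norm
    return x, total_iters, bool(converged)
```

Because each pass recomputes the residual in float64, the float32 inner solves only need to reach inner_tol; the final accuracy is set by the float64 residual. The return tuple mirrors the documented (x, iters, converged), here counting total inner CG iterations.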
The mixed mode follows the three-precision iterative refinement pattern of Carson & Higham 2018, “Accelerating the solution of linear systems by iterative refinement in three precisions”, SIAM J. Sci. Comput. 40(2).
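Writing b for rhs, one refinement step as this section's 'mixed' mode reads can be summarized as follows. The per-line precision labels follow the mode description above; keeping the iterate in float64 and using the 2-norm in the stopping test are assumptions.

```latex
\begin{aligned}
r_i &= b - A x_i && \text{(float64)} \\
d_i &\approx A^{-1} r_i && \text{(CG with float32 matvecs)} \\
x_{i+1} &= x_i + d_i && \text{(float64)}
\end{aligned}
\qquad \text{stop when } \lVert r_i \rVert_2 \le \mathrm{tol}\,\lVert b \rVert_2 .
```

Errors made in the float32 inner solve show up in the next float64 residual and are corrected on the following pass instead of accumulating, which is why refinement can recover most of the float64 accuracy when A is not too ill-conditioned for float32.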
phast.solvers.mixed_precision_cg.cg_mixed_precision(matvec, rhs, x0=None, tol=1e-08, max_iter=1000, precision='float64', max_refine=5)
Solve A x = rhs by conjugate gradients, with an optional reduced-precision matvec.
- Parameters:
- matvec : callable v -> A @ v (must accept the dtype of v and return the same dtype).
- rhs : torch.Tensor; its dtype defines the "high" working precision.
- x0 : optional initial guess (same dtype as rhs).
- tol : relative residual tolerance.
- max_iter : maximum number of CG iterations.
- precision : 'float64' | 'float32' | 'mixed'.
- max_refine : maximum number of iterative-refinement steps in 'mixed' mode.
- Return type:
(x, iters, converged): the solution, the iteration count, and a convergence flag.
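The matvec contract (accept the dtype of v, return the same dtype) matters because 'mixed' mode calls matvec with float32 vectors while rhs is float64. Two common ways to satisfy it are sketched here with NumPy in place of torch; both helpers, make_dense_matvec and laplacian_1d, are hypothetical. With torch tensors the cast would be A.to(v.dtype).

```python
import numpy as np


def make_dense_matvec(A):
    """Hypothetical helper: wrap a dense matrix so A @ v runs in v's dtype.

    One cast copy of A is cached per dtype, so repeated float32 calls
    do not re-cast the matrix on every CG iteration.
    """
    cache = {}

    def matvec(v):
        A_cast = cache.get(v.dtype)
        if A_cast is None:
            A_cast = cache[v.dtype] = A.astype(v.dtype)
        return A_cast @ v

    return matvec


def laplacian_1d(v):
    """Hypothetical matrix-free matvec: tridiag(-1, 2, -1) applied to v.

    Only arithmetic on v itself is used, so the result keeps v's dtype.
    """
    out = 2 * v
    out[1:] -= v[:-1]
    out[:-1] -= v[1:]
    return out


n = 5
A = 2 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
mv = make_dense_matvec(A)
print(mv(np.arange(n, dtype=np.float32)).dtype)  # float32
```

The matrix-free form preserves dtype without any casting, which is often the simpler option when A is only available as an operator.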