Mixed-precision CG

Mixed-precision conjugate-gradient solver.

Provides cg_mixed_precision with three modes:

  • ‘float64’ : reference path (default; preserves existing behaviour).

  • ‘float32’ : full single precision (memory-bandwidth + tensor-core friendly).

  • ‘mixed’ : float32 matvec inside CG; float64 residual and convergence test, with iterative refinement to recover most of float64 accuracy.
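The two single-precision-or-not paths can be pictured as one textbook CG loop whose vectors and matvec results all live in a chosen dtype. The following is a minimal NumPy sketch under stated assumptions, not phast's implementation: the name `cg_sketch`, its `dtype` argument, and the use of a relative 2-norm residual test are all hypothetical.

```python
import numpy as np


def cg_sketch(matvec, rhs, x0=None, tol=1e-8, max_iter=1000, dtype=np.float64):
    """Plain CG with every vector and every matvec result held in ``dtype``.

    Hypothetical illustration of the 'float64' / 'float32' paths only.
    """
    b = np.asarray(rhs, dtype=dtype)
    x = np.zeros_like(b) if x0 is None else np.array(x0, dtype=dtype)  # np.array copies
    b_norm = np.linalg.norm(b)
    r = b - matvec(x)                    # initial residual, same dtype as x
    if np.linalg.norm(r) <= tol * b_norm:
        return x, 0, True
    p = r.copy()                         # first search direction
    rs = r @ r
    for k in range(1, max_iter + 1):
        Ap = matvec(p)
        alpha = rs / (p @ Ap)            # exact line search along p
        x += alpha * p
        r -= alpha * Ap                  # recursively updated residual
        rs_new = r @ r
        if np.sqrt(rs_new) <= tol * b_norm:   # relative 2-norm test (assumed)
            return x, k, True
        p = r + (rs_new / rs) * p        # next A-conjugate direction
        rs = rs_new
    return x, max_iter, False
```

Calling it with `dtype=np.float32` mimics the 'float32' mode: the whole iteration, including the recursively updated residual, runs in single precision, so the true residual stagnates around float32 roundoff (about 6e-8) times a condition-dependent factor. The float64 residual check of the 'mixed' mode exists to get past that limit.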

The mixed mode follows the three-precision iterative refinement pattern of Carson & Higham 2018, “Accelerating the solution of linear systems by iterative refinement in three precisions”, SIAM J. Sci. Comput. 40(2).
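The refinement idea can be sketched as an outer float64 loop around a float32 CG correction solve. This is a minimal NumPy illustration, not phast's code: the names `mixed_refine_sketch` and `_inner_cg`, the parameters `inner_tol` and `inner_max_iter`, and the convention of returning the number of refinement steps are all hypothetical. It also uses only two precisions (float32 for the correction solve, float64 for everything else), a special case of the three-precision scheme. As in the 'mixed' mode described above, the residual b - A x and the convergence test are computed in float64, only the correction equation A d = r is solved in float32, and the correction is added back to x in float64.

```python
import numpy as np


def _inner_cg(matvec, r, tol, max_iter):
    """Approximately solve A d = r with CG, entirely in r's dtype (float32 here)."""
    d = np.zeros_like(r)
    res = r.copy()
    p = res.copy()
    rs = res @ res
    stop = tol * np.sqrt(rs)             # relative to the initial correction residual
    for _ in range(max_iter):
        if np.sqrt(rs) <= stop:
            break
        Ap = matvec(p)                   # dtype-preserving matvec -> float32 result
        alpha = rs / (p @ Ap)
        d += alpha * p
        res -= alpha * Ap
        rs_new = res @ res
        p = res + (rs_new / rs) * p
        rs = rs_new
    return d


def mixed_refine_sketch(matvec, rhs, tol=1e-8, max_refine=5,
                        inner_tol=1e-4, inner_max_iter=1000):
    """Hypothetical iterative refinement: float64 residual/update, float32 solve.

    Returns (x, refinement_steps, converged).
    """
    b = np.asarray(rhs, dtype=np.float64)
    x = np.zeros_like(b)
    b_norm = np.linalg.norm(b)
    for step in range(max_refine):
        r = b - matvec(x)                        # residual in float64
        if np.linalg.norm(r) <= tol * b_norm:    # convergence test in float64
            return x, step, True
        d = _inner_cg(matvec, r.astype(np.float32), inner_tol, inner_max_iter)
        x += d.astype(np.float64)                # correction applied in float64
    r = b - matvec(x)
    return x, max_refine, bool(np.linalg.norm(r) <= tol * b_norm)
```

Each outer step can shrink the float64 relative residual by roughly the inner tolerance, provided A is well enough conditioned for float32 (roughly, κ(A) well below 1/u32 ≈ 1.7e7). So a tolerance such as 1e-10 is typically reached after a few outer steps, whereas a single float32 solve stalls near float32 roundoff.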

phast.solvers.mixed_precision_cg.cg_mixed_precision(matvec, rhs, x0=None, tol=1e-08, max_iter=1000, precision='float64', max_refine=5)

Solve A x = rhs with the conjugate-gradient method, optionally using a reduced-precision matvec. As CG requires, A must be symmetric positive definite.

Parameters:
matvec : callable v -> A @ v; must accept v in whichever dtype it is passed and return a result of that same dtype.

rhs : torch.Tensor; its dtype defines the "high" working precision.

x0 : optional initial guess (same dtype as rhs).

tol : relative residual tolerance.

max_iter : maximum number of CG iterations.

precision : 'float64' | 'float32' | 'mixed'.

max_refine : maximum number of iterative-refinement steps (used only in 'mixed' mode).

Returns:

(x, iters, converged): the solution, the iteration count, and whether the tolerance was met.
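The matvec contract (return a result in the same dtype as v) is easiest to meet with a matrix-free operator that builds its output from v alone. A minimal NumPy sketch, using a hypothetical 1-D Laplacian tridiag(-1, 2, -1) as the operator; it is symmetric positive definite and therefore a valid CG operator:

```python
import numpy as np


def laplacian_1d_matvec(v):
    """Apply the n-by-n tridiag(-1, 2, -1) matrix to v without forming it.

    Every intermediate is computed from v itself, so the result keeps v's dtype.
    """
    out = 2 * v                 # diagonal term; new array in v's dtype
    out[1:] -= v[:-1]           # sub-diagonal contribution
    out[:-1] -= v[1:]           # super-diagonal contribution
    return out
```

By contrast, a closure such as `lambda v: A @ v` over a float64 NumPy array `A` silently promotes a float32 `v` to float64, which defeats a float32 matvec. Casting the operator (`A.astype(v.dtype) @ v`) or storing a float32 copy of `A` avoids that.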