"""
Adaptive Mesh Refinement (AMR) via Newest Vertex Bisection (NVB).
Provides element-level refinement indicators, conforming mesh bisection,
and field interpolation routines for phase-field fracture simulations.
Algorithm overview
------------------
Newest Vertex Bisection (NVB) is a standard AMR strategy for triangular meshes.
Each marked triangle is bisected by inserting a node at the midpoint of its
*refinement edge* (here chosen as the longest edge). Conforming closure
refines every triangle that shares a bisected edge, so that the edge is split
consistently on both sides and the resulting mesh has no hanging nodes.
Typical usage::

    from phast.adaptive import (
        compute_refinement_indicator, refine_mesh,
        interpolate_field, interpolate_elem_field,
    )
    marked = compute_refinement_indicator(mesh, d)
    if marked.any():
        new_mesh, parent_map, child_map = refine_mesh(mesh, marked)
        d = interpolate_field(mesh, new_mesh, d, parent_map)
        H = interpolate_elem_field(mesh, new_mesh, H, child_map)
        mesh = new_mesh
References
----------
- Mitchell, W.F. (1989). A comparison of adaptive refinement techniques
for elliptic problems. *ACM TOMS* 15(4), 326-347.
- Bartels, S. (2015). *Numerical Methods for Nonlinear PDEs*. Springer.
Chapter 4: Adaptive finite element methods.
"""
import torch
import numpy as np
from typing import Tuple, Dict, Iterable, List, Optional, Sequence
from ..core.mesh import FEMMesh
# -------------------------------------------------------------------- #
# 1. Composable refinement criteria (Galvis-style)
# -------------------------------------------------------------------- #
#
# These criterion helpers (originally landed in ``mesh_adaptivity.py``,
# #111 scaffold) return *which* elements to refine as element-ID lists.
# They are integrated here to live alongside the newest-vertex-bisection
# pipeline that consumes a boolean mask.
#
# Reference
# ---------
# Galvis et al. 2026 (Eng Fract Mech): FreeFEM++ ``adaptmesh()`` on the
# damage-gradient indicator with relative tol 0.04 and max-vertex cap.
[docs]
def damage_gradient_criterion(
    d_field: torch.Tensor,
    mesh: 'FEMMesh',
    threshold: float = 0.1,
) -> List[int]:
    """Return the elements whose damage gradient is steeper than ``threshold``.

    On each P1 triangle the damage gradient is constant and equals
    ``sum_j grad_phi_j * d_j`` over the element's three vertices. An element
    is flagged when the Euclidean norm of that gradient is strictly greater
    than ``threshold``.

    Parameters
    ----------
    d_field : torch.Tensor, shape (N,)
        Nodal damage in [0, 1].
    mesh : FEMMesh
        Must expose ``elements`` (E, 3) and ``grad_phi`` (E, 3, 2).
    threshold : float
        Gradient-magnitude threshold (units: 1/length).

    Returns
    -------
    list of int
        Ascending element indices with ``||grad d|| > threshold``.
    """
    nodal_d = d_field[mesh.elements]  # (E, 3): damage at each vertex
    dphi = mesh.grad_phi              # (E, 3, 2): shape-function gradients
    dd_dx = (dphi[:, :, 0] * nodal_d).sum(dim=1)
    dd_dy = (dphi[:, :, 1] * nodal_d).sum(dim=1)
    steep = torch.sqrt(dd_dx * dd_dx + dd_dy * dd_dy) > threshold
    return steep.nonzero(as_tuple=True)[0].cpu().tolist()
[docs]
def crack_tip_neighborhood_criterion(
    d_field: torch.Tensor,
    mesh: 'FEMMesh',
    radius: float = 3.0,
    d_tip_threshold: float = 0.5,
    *,
    max_pairs_per_chunk: int = 1 << 22,
) -> List[int]:
    """Flag elements within ``radius * h_min`` of any cracked element.

    A "cracked" element is one whose maximum nodal damage exceeds
    ``d_tip_threshold``. Every element whose centroid lies within
    ``radius * mesh.h_min`` of any cracked-element centroid is flagged.
    Cracked elements are themselves included, since they lie at distance
    zero from their own centroid.

    Parameters
    ----------
    d_field : torch.Tensor, shape (N,)
        Nodal damage in [0, 1].
    mesh : FEMMesh
        Must expose ``nodes``, ``elements``, ``h_min``.
    radius : float
        Neighbourhood radius in units of ``h_min``.
    d_tip_threshold : float
        Damage value above which an element is treated as cracked.
    max_pairs_per_chunk : int, keyword-only
        Upper bound on the number of (element, cracked-element) distance
        pairs evaluated at once. The element axis is processed in chunks,
        so the temporary distance tensor holds at most about this many
        entries instead of ``E * T``.

    Returns
    -------
    list of int
        Ascending element indices within the crack-tip neighbourhood.

    Raises
    ------
    ValueError
        If ``max_pairs_per_chunk`` is smaller than 1.
    """
    if max_pairs_per_chunk < 1:
        raise ValueError("max_pairs_per_chunk must be a positive integer")
    elems = mesh.elements
    centroids = mesh.nodes[elems].mean(dim=1)  # (E, 2)
    d_elem = d_field[elems]  # (E, 3)
    is_cracked = d_elem.max(dim=1).values > d_tip_threshold
    if not bool(is_cracked.any()):
        return []
    tip_centroids = centroids[is_cracked]  # (T, 2)
    n_elems = centroids.shape[0]
    n_tips = tip_centroids.shape[0]
    cutoff = radius * float(mesh.h_min)

    # Process elements in row blocks so that the (rows, T, 2) difference
    # tensor stays bounded; the per-pair arithmetic is identical to the
    # unchunked computation, so the result does not depend on chunking.
    rows_per_chunk = max(1, max_pairs_per_chunk // n_tips)
    near = torch.zeros(n_elems, dtype=torch.bool, device=centroids.device)
    for start in range(0, n_elems, rows_per_chunk):
        stop = start + rows_per_chunk
        diff = centroids[start:stop].unsqueeze(1) - tip_centroids.unsqueeze(0)
        dist = torch.linalg.norm(diff, dim=2)  # (rows, T)
        near[start:stop] = dist.min(dim=1).values <= cutoff
    return torch.where(near)[0].cpu().tolist()
[docs]
def union_refine_set(criteria_results: Iterable[Sequence[int]]) -> List[int]:
    """Merge element-ID lists produced by several refinement criteria.

    Parameters
    ----------
    criteria_results : iterable of int sequences
        Output of one or more criterion functions; each entry is coerced
        with ``int`` so tensors / NumPy integers are accepted too.

    Returns
    -------
    list of int
        Every element ID that appears in at least one input, ascending and
        without duplicates.
    """
    return sorted({int(elem_id) for ids in criteria_results for elem_id in ids})
# -------------------------------------------------------------------- #
# 2. Refinement indicator (boolean mask consumed by refine_mesh)
# -------------------------------------------------------------------- #
[docs]
def compute_refinement_indicator(
    mesh: FEMMesh,
    d: torch.Tensor,
    grad_d_threshold: float = 0.3,
    d_threshold: float = 0.5,
    *,
    use_damage: bool = True,
    use_gradient: bool = True,
    use_neighborhood: bool = False,
    neighborhood_radius: float = 3.0,
) -> torch.Tensor:
    """Build the boolean refinement mask consumed by ``refine_mesh``.

    The enabled criteria are OR-ed together:

    - damage (``use_damage``): the largest nodal damage on the element is
      greater than ``d_threshold`` (crack body);
    - gradient (``use_gradient``): the norm of the constant P1 damage
      gradient on the element is greater than ``grad_d_threshold``
      (crack front);
    - neighbourhood (``use_neighborhood``, off by default): the element
      is returned by ``crack_tip_neighborhood_criterion`` with
      ``radius=neighborhood_radius`` and ``d_tip_threshold=d_threshold``.

    Parameters
    ----------
    mesh : FEMMesh
        Current mesh (must have ``elements``, ``grad_phi``).
    d : torch.Tensor, shape (N,)
        Nodal damage field, values in [0, 1].
    grad_d_threshold : float
        Gradient magnitude threshold for crack-front marking.
    d_threshold : float
        Maximum nodal damage threshold for crack-body marking.
    use_damage, use_gradient, use_neighborhood : bool
        Per-criterion switches; the defaults give the legacy
        "damage OR gradient" behaviour.
    neighborhood_radius : float
        Radius (in units of ``h_min``) for the neighbourhood criterion.

    Returns
    -------
    marked : torch.Tensor, shape (E,), dtype=torch.bool
        True for each element that should be refined.
    """
    elems = mesh.elements        # (E, 3)
    vertex_d = d[elems]          # (E, 3)
    marked = torch.zeros(elems.shape[0], dtype=torch.bool, device=d.device)

    if use_damage:
        # Crack body: at least one vertex is already strongly damaged.
        marked |= vertex_d.max(dim=1).values > d_threshold

    if use_gradient:
        # Crack front: grad d_e = sum_j grad_phi_j * d_j is steep.
        gphi = mesh.grad_phi
        gx = (gphi[:, :, 0] * vertex_d).sum(dim=1)
        gy = (gphi[:, :, 1] * vertex_d).sum(dim=1)
        marked |= torch.sqrt(gx ** 2 + gy ** 2) > grad_d_threshold

    if use_neighborhood:
        near_tip = crack_tip_neighborhood_criterion(
            d, mesh, radius=neighborhood_radius,
            d_tip_threshold=d_threshold,
        )
        if near_tip:
            marked[torch.as_tensor(near_tip, dtype=torch.long,
                                   device=d.device)] = True

    return marked
# -------------------------------------------------------------------- #
# 2b. Edge-to-element adjacency (private helpers)
# -------------------------------------------------------------------- #
def _build_edge_to_elem(elements: torch.Tensor) -> Dict[Tuple[int, int], list]:
"""Build a mapping from sorted edge tuples to element indices.
Vectorized with NumPy: builds all edges, sorts, and groups by unique edge.
Parameters
----------
elements : (E, 3) long tensor
Returns
-------
edge_to_elem : dict mapping (ni, nj) -> [elem_idx, ...]
Edge nodes are sorted so (min, max).
"""
import numpy as np
elems_np = elements.cpu().numpy()
E = elems_np.shape[0]
# Build all 3E edges: (node_min, node_max, elem_idx)
e0, e1, e2 = elems_np[:, 0], elems_np[:, 1], elems_np[:, 2]
edges_a = np.stack([e0, e1], axis=1) # edge 0-1
edges_b = np.stack([e1, e2], axis=1) # edge 1-2
edges_c = np.stack([e2, e0], axis=1) # edge 2-0
all_edges = np.concatenate([edges_a, edges_b, edges_c], axis=0) # (3E, 2)
all_edges.sort(axis=1) # sort each edge so (min, max)
elem_ids = np.tile(np.arange(E), 3) # (3E,)
# Group by unique edge using structured array
edge_keys = all_edges[:, 0].astype(np.int64) * (elems_np.max() + 1) + all_edges[:, 1]
sort_idx = np.argsort(edge_keys)
sorted_keys = edge_keys[sort_idx]
sorted_edges = all_edges[sort_idx]
sorted_elems = elem_ids[sort_idx]
# Find boundaries between unique edges
breaks = np.concatenate([[0], np.where(np.diff(sorted_keys) != 0)[0] + 1, [len(sorted_keys)]])
edge_to_elem: Dict[Tuple[int, int], list] = {}
for i in range(len(breaks) - 1):
s, e = breaks[i], breaks[i + 1]
key = (int(sorted_edges[s, 0]), int(sorted_edges[s, 1]))
edge_to_elem[key] = sorted_elems[s:e].tolist()
return edge_to_elem
def _longest_edge(nodes: torch.Tensor, tri: torch.Tensor) -> int:
"""Return the local index (0, 1, 2) of the vertex opposite the longest edge.
The longest edge of triangle (v0, v1, v2) is determined, and the
local index of the vertex *opposite* that edge is returned. This
vertex becomes the "newest vertex" in the NVB convention, and the
longest edge is the refinement (bisection) edge.
Local edge convention:
- Edge 0: v1-v2 (opposite v0)
- Edge 1: v0-v2 (opposite v1)
- Edge 2: v0-v1 (opposite v2)
Parameters
----------
nodes : (N, 2) node coordinates
tri : (3,) long tensor of node indices
Returns
-------
int : local vertex index opposite the longest edge (0, 1, or 2)
"""
p = nodes[tri] # (3, 2)
# Edge lengths squared (avoid sqrt for comparison)
e0_sq = ((p[1] - p[2]) ** 2).sum() # opposite v0
e1_sq = ((p[0] - p[2]) ** 2).sum() # opposite v1
e2_sq = ((p[0] - p[1]) ** 2).sum() # opposite v2
lengths_sq = torch.stack([e0_sq, e1_sq, e2_sq])
return int(lengths_sq.argmax().item())
# -------------------------------------------------------------------- #
# 3. Core refinement
# -------------------------------------------------------------------- #
[docs]
def refine_mesh(
    mesh: FEMMesh,
    marked_elements: torch.Tensor,
) -> Tuple['FEMMesh', Dict[int, Tuple[int, int]], Dict[int, int]]:
    """Refine the mesh by newest-vertex bisection with conforming closure.

    Every triangle's refinement edge is its longest edge.

    1. Closure (on edges): the refinement edges of the marked elements are
       marked. Then, repeatedly, any element that contains a marked edge
       also gets its own refinement edge marked.
    2. Bisection: each element whose refinement edge is marked is first
       bisected along that edge, giving two children that share the new
       midpoint ``m``. If one or both of the element's other two edges are
       also marked, the child containing that edge is bisected once more
       along it. That edge is the one opposite the child's newest vertex
       ``m``. Each refined element therefore yields 2, 3 or 4 children.

    Because every marked edge is split with one shared midpoint in every
    element that contains it, and unmarked edges are never split, the
    result has no hanging nodes. Children are listed as cyclic rotations of
    their parent's vertex order, so element orientation is preserved.

    Parameters
    ----------
    mesh : FEMMesh
        Current mesh.
    marked_elements : torch.Tensor, shape (E,), dtype=torch.bool
        Elements to refine (from ``compute_refinement_indicator``).

    Returns
    -------
    new_mesh : FEMMesh
        Refined mesh with all precomputed quantities.
    parent_map : dict
        ``{new_node_idx: (old_node_a, old_node_b)}`` for midpoint nodes,
        with ``old_node_a < old_node_b``. Every bisected edge is an edge of
        the *old* mesh, so both parents are always old nodes. Existing
        nodes keep their indices and are NOT in this dict.
    child_map : dict
        ``{new_elem_idx: old_elem_idx}`` for every new element.
        Unrefined elements map to themselves (same vertex order).
    """
    # Read-only NumPy views; nothing below writes into them.
    nodes_np = mesh.nodes.cpu().numpy()
    elems_np = mesh.elements.cpu().numpy()

    # node -> names of the node-sets containing it (used in Phase 3).
    node_to_sets: Dict[int, set] = {}
    for set_name, idx_tensor in mesh.node_sets.items():
        for ni in idx_tensor.cpu().numpy().tolist():
            node_to_sets.setdefault(ni, set()).add(set_name)

    # --- Phase 1: refinement edges and conforming closure ---
    # Each row is (apex, v_a, v_b): v_a-v_b is the refinement edge.
    rotated = _nvb_rotate_to_refinement_edge(nodes_np, elems_np).tolist()
    ref_edges = [(min(a, b), max(a, b)) for _, a, b in rotated]
    seeds = torch.where(marked_elements)[0].cpu().tolist()
    marked_edges = _nvb_close_edge_marks(
        ref_edges, seeds, _build_edge_to_elem(mesh.elements))

    # --- Phase 2: bisect ---
    new_nodes_list = list(nodes_np)  # grows by appending midpoints
    edge_to_midnode: Dict[Tuple[int, int], int] = {}
    parent_map: Dict[int, Tuple[int, int]] = {}

    def _get_or_create_midnode(na: int, nb: int) -> int:
        """Index of the midpoint of old edge na-nb, created on first use."""
        edge = (na, nb) if na < nb else (nb, na)
        mid_idx = edge_to_midnode.get(edge)
        if mid_idx is None:
            mid_idx = len(new_nodes_list)
            new_nodes_list.append(0.5 * (nodes_np[na] + nodes_np[nb]))
            edge_to_midnode[edge] = mid_idx
            parent_map[mid_idx] = edge
        return mid_idx

    new_elems_list: List[List[int]] = []
    child_map: Dict[int, int] = {}

    def _emit(tri, parent: int) -> None:
        """Append one new element and record which old element it came from."""
        child_map[len(new_elems_list)] = parent
        new_elems_list.append([int(v) for v in tri])

    for ei, (apex, v_a, v_b) in enumerate(rotated):
        if ref_edges[ei] not in marked_edges:
            # The closure guarantees no edge of this element is marked.
            _emit(elems_np[ei].tolist(), ei)
            continue
        v_mid = _get_or_create_midnode(v_a, v_b)

        # Left child (apex, v_a, v_mid) == (v_mid, apex, v_a) cyclically;
        # bisect its edge apex-v_a (opposite the newest vertex v_mid).
        if (min(apex, v_a), max(apex, v_a)) in marked_edges:
            m_left = _get_or_create_midnode(apex, v_a)
            _emit((v_mid, apex, m_left), ei)
            _emit((v_mid, m_left, v_a), ei)
        else:
            _emit((apex, v_a, v_mid), ei)

        # Right child (apex, v_mid, v_b) == (v_mid, v_b, apex) cyclically;
        # bisect its edge v_b-apex (opposite the newest vertex v_mid).
        if (min(v_b, apex), max(v_b, apex)) in marked_edges:
            m_right = _get_or_create_midnode(v_b, apex)
            _emit((v_mid, v_b, m_right), ei)
            _emit((v_mid, m_right, apex), ei)
        else:
            _emit((apex, v_mid, v_b), ei)

    # --- Phase 3: assign boundary node-sets to midpoint nodes ---
    new_node_sets: Dict[str, list] = {
        name: idx_tensor.cpu().numpy().tolist()
        for name, idx_tensor in mesh.node_sets.items()
    }
    for mid_idx, (na, nb) in parent_map.items():
        # A midpoint inherits every node-set that BOTH endpoints belong to.
        # NOTE(review): this also tags midpoints of interior edges whose two
        # endpoints happen to lie in the same set; confirm that node-sets are
        # boundary sets for which this cannot occur.
        common_sets = node_to_sets.get(na, set()) & node_to_sets.get(nb, set())
        for s in common_sets:
            new_node_sets[s].append(mid_idx)

    # Empty node-sets are dropped, as before.
    node_sets_tensors = {
        name: torch.tensor(sorted(idx_list), dtype=torch.long, device=mesh.device)
        for name, idx_list in new_node_sets.items()
        if idx_list
    }

    # --- Phase 4: build new FEMMesh from tensors ---
    new_nodes_tensor = torch.tensor(
        np.asarray(new_nodes_list).reshape(-1, nodes_np.shape[1]),
        dtype=mesh.dtype, device=mesh.device)
    new_elems_tensor = torch.tensor(
        np.asarray(new_elems_list, dtype=np.int64).reshape(-1, 3),
        dtype=torch.long, device=mesh.device)
    new_mesh = FEMMesh.from_tensors(
        nodes=new_nodes_tensor,
        elements=new_elems_tensor,
        node_sets=node_sets_tensors,
        device=mesh.device,
        dtype=mesh.dtype,
    )
    return new_mesh, parent_map, child_map


def _nvb_rotate_to_refinement_edge(
    nodes_np: np.ndarray,
    elems_np: np.ndarray,
) -> np.ndarray:
    """Rotate every triangle cyclically to ``(apex, v_a, v_b)``.

    ``v_a-v_b`` is the element's longest edge (its refinement edge) and
    ``apex`` is the vertex opposite it. Ties resolve to the lowest local
    index, since ``argmax`` returns the first maximum. A cyclic rotation
    keeps the element's orientation.

    Parameters
    ----------
    nodes_np : (N, 2) array of node coordinates.
    elems_np : (E, 3) integer array of triangle connectivity.

    Returns
    -------
    (E, 3) integer array of rotated connectivity.
    """
    p = nodes_np[elems_np]  # (E, 3, 2)
    e0_sq = ((p[:, 1] - p[:, 2]) ** 2).sum(axis=1)  # opposite v0
    e1_sq = ((p[:, 0] - p[:, 2]) ** 2).sum(axis=1)  # opposite v1
    e2_sq = ((p[:, 0] - p[:, 1]) ** 2).sum(axis=1)  # opposite v2
    apex_loc = np.stack([e0_sq, e1_sq, e2_sq], axis=1).argmax(axis=1)
    cols = (apex_loc[:, None] + np.arange(3)) % 3  # (E, 3): k, k+1, k+2
    return np.take_along_axis(elems_np, cols, axis=1)


def _nvb_close_edge_marks(
    ref_edges: List[Tuple[int, int]],
    seed_elems: Iterable[int],
    edge_to_elem: Dict[Tuple[int, int], list],
) -> set:
    """Return the closed set of (sorted) edges that must be bisected.

    Starts from the refinement edges of ``seed_elems`` and adds the
    refinement edge of every element that contains an already-marked edge,
    until nothing changes. Each edge enters the work stack at most once,
    so the loop terminates after O(number of edges) steps.
    """
    marked: set = set()
    stack: List[Tuple[int, int]] = []

    def _mark(edge: Tuple[int, int]) -> None:
        if edge not in marked:
            marked.add(edge)
            stack.append(edge)

    for ei in seed_elems:
        _mark(ref_edges[ei])
    while stack:
        edge = stack.pop()
        for nj in edge_to_elem.get(edge, ()):
            _mark(ref_edges[nj])
    return marked
# -------------------------------------------------------------------- #
# 4. Field interpolation (nodal)
# -------------------------------------------------------------------- #
[docs]
def interpolate_field(
    old_mesh: FEMMesh,
    new_mesh: FEMMesh,
    field: torch.Tensor,
    parent_map: Dict[int, Tuple[int, int]],
) -> torch.Tensor:
    """Interpolate a nodal field from the old mesh to the refined mesh.

    Existing nodes retain their original values (they keep their indices
    ``0 .. N_old-1``). New midpoint nodes receive the average of their two
    parent node values, which is exact for P1 fields. Nodes beyond
    ``N_old`` that are not in ``parent_map`` are left at zero.

    Parameters
    ----------
    old_mesh : FEMMesh
        Mesh before refinement.
    new_mesh : FEMMesh
        Mesh after refinement.
    field : torch.Tensor
        Nodal field on ``old_mesh``. Shape ``(N_old,)`` for scalar fields, or
        ``(N_old, ...)`` with any trailing dimensions for vector/tensor fields.
    parent_map : dict
        ``{new_node_idx: (old_node_a, old_node_b)}`` from ``refine_mesh``.

    Returns
    -------
    new_field : torch.Tensor
        Field on ``new_mesh`` with the same dtype, device and trailing
        dimensions as ``field``.
    """
    n_new = new_mesh.n_nodes
    n_old = old_mesh.n_nodes
    new_field = field.new_zeros((n_new,) + tuple(field.shape[1:]))
    # Copy values for existing (old) nodes
    new_field[:n_old] = field
    if parent_map:
        # One gather/scatter instead of a Python loop over midpoints.
        mid_idx = torch.tensor(
            list(parent_map.keys()), dtype=torch.long, device=field.device)
        parents = torch.tensor(
            list(parent_map.values()), dtype=torch.long, device=field.device)
        new_field[mid_idx] = 0.5 * (field[parents[:, 0]] + field[parents[:, 1]])
    return new_field
# -------------------------------------------------------------------- #
# 5. Field interpolation (element)
# -------------------------------------------------------------------- #
[docs]
def interpolate_elem_field(
    old_mesh: FEMMesh,
    new_mesh: FEMMesh,
    field: torch.Tensor,
    child_map: Dict[int, int],
) -> torch.Tensor:
    """Interpolate an element field from the old mesh to the refined mesh.

    Each child element inherits its parent element's value. This is
    appropriate for piecewise-constant element data such as the history
    variable H. New elements missing from ``child_map`` are left at zero.

    Parameters
    ----------
    old_mesh : FEMMesh
        Mesh before refinement (unused; kept for API symmetry with
        ``interpolate_field``).
    new_mesh : FEMMesh
        Mesh after refinement.
    field : torch.Tensor
        Element field on ``old_mesh``, shape ``(E_old,)`` or ``(E_old, ...)``
        with any trailing dimensions.
    child_map : dict
        ``{new_elem_idx: old_elem_idx}`` from ``refine_mesh``.

    Returns
    -------
    new_field : torch.Tensor
        Field on ``new_mesh`` with the same dtype, device and trailing
        dimensions as ``field``.
    """
    new_field = field.new_zeros((new_mesh.n_elems,) + tuple(field.shape[1:]))
    if child_map:
        # One gather/scatter instead of a Python loop over new elements.
        new_idx = torch.tensor(
            list(child_map.keys()), dtype=torch.long, device=field.device)
        old_idx = torch.tensor(
            list(child_map.values()), dtype=torch.long, device=field.device)
        new_field[new_idx] = field[old_idx]
    return new_field