Source code for kapaow.basis

"""How to extend a pseudoatomic basis set with additional orbitals."""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping
from enum import IntEnum
from pathlib import Path
from typing import Any, Self

from pydantic import field_validator, model_validator

from kapaow.pydantic import BaseModel

# Public API of this module. The abstract ``Basis`` base class and the private
# ``_valence_subshells_from_z_valence`` helper are not listed here.
__all__: list[str] = [
    "AngularMomentum",
    "AtomicBasis",
    "PseudoatomicBasis",
    "Subshell",
    "ordered_subshells",
]


class AngularMomentum(IntEnum):
    """Angular momentum quantum number ``l`` for the s, p, d, f and g channels."""

    S = 0
    P = 1
    D = 2
    F = 3
    G = 4

    def __add__(self, other: int | AngularMomentum) -> AngularMomentum:
        """Return the channel obtained by raising ``l`` by ``other``.

        ``other`` may be a plain integer or another channel, in which case its
        integer value is used. The result is always an ``AngularMomentum``;
        constructing it raises ``ValueError`` when the sum falls outside S..G.
        """
        offset = other.value if isinstance(other, AngularMomentum) else other
        return AngularMomentum(self.value + offset)
class Basis(BaseModel, ABC):
    """Abstract interface shared by the basis-set models in this module.

    Concrete subclasses must report the largest angular momentum and the
    largest principal quantum number they contain.
    """

    @property
    @abstractmethod
    def l_max(self) -> AngularMomentum:
        """Largest angular momentum quantum number present in the basis."""
        ...

    @property
    @abstractmethod
    def n_max(self) -> int:
        """Largest principal quantum number present in the basis."""
        ...
def _coerce_angular_momentum_key(key: Any) -> Any:
    """Map an int or decimal-string key onto :class:`AngularMomentum`.

    Decimal strings show up when a model is re-loaded from JSON, where object
    keys are always strings. Any other key is returned untouched so that
    pydantic can reject it with a normal validation error.
    """
    if isinstance(key, int):
        return AngularMomentum(key)
    if isinstance(key, str) and key.isdecimal():
        return AngularMomentum(int(key))
    return key


class PseudoatomicBasis(Basis):
    """A pseudoatomic basis set.

    Only the number of orbitals per angular momentum channel is tracked. After
    validation every channel of :class:`AngularMomentum` is present as a key,
    possibly with a count of zero.
    """

    number_of_orbitals: dict[AngularMomentum, int]

    @field_validator("number_of_orbitals", mode="before")
    @classmethod
    def ensure_all_channels_present(cls, v: Any) -> Any:
        """Ensure all angular momentum channels are present in the dictionary.

        Missing channels are added with a count of zero. A new dict is
        returned, so the caller's mapping is never mutated. Keys are coerced
        here as well, so the result does not depend on the order in which
        pydantic runs this validator and :meth:`coerce_keys`. Non-mapping input
        is passed through unchanged for pydantic to reject.
        """
        if not isinstance(v, Mapping):
            return v
        completed = {_coerce_angular_momentum_key(k): count for k, count in v.items()}
        for ang_mtm in AngularMomentum:
            completed.setdefault(ang_mtm, 0)
        return completed

    @field_validator("number_of_orbitals", mode="before")
    @classmethod
    def coerce_keys(cls, v: Any) -> Any:
        """Coerce integer (or decimal-string, e.g. from JSON) keys to AngularMomentum."""
        if isinstance(v, Mapping):
            return {_coerce_angular_momentum_key(k): val for k, val in v.items()}
        return v

    @property
    def l_max(self) -> AngularMomentum:
        """The maximum angular momentum quantum number in the basis set.

        i.e. the largest l for which ``number_of_orbitals[l] > 0``.

        Raises:
            ValueError: if no channel has any orbitals.
        """
        occupied = [ang for ang, count in self.number_of_orbitals.items() if count > 0]
        if not occupied:
            raise ValueError("l_max is undefined for a basis with no orbitals")
        # IntEnum members order by value, so plain max() picks the highest l.
        return max(occupied)

    @property
    def n_max(self) -> int:
        """The maximum principal quantum number in the basis set.

        Note that this is for the pseudo-wavefunction, so n_max is 1 if the
        basis has only 2s and 2p orbitals (i.e. it is the largest per-channel
        orbital count).
        """
        return max(self.number_of_orbitals.values(), default=0)

    def __len__(self) -> int:
        """Sum of the per-channel orbital counts (no m-degeneracy factor)."""
        return sum(self.number_of_orbitals.values())

    def extend(
        self, s: int = 0, p: int = 0, d: int = 0, f: int = 0, g: int = 0
    ) -> PseudoatomicBasis:
        """Return a new PseudoatomicBasis with added orbitals.

        The receiver is left unchanged; each argument is added to the count of
        the corresponding channel.
        """
        new_number_of_orbitals = self.number_of_orbitals.copy()
        new_number_of_orbitals[AngularMomentum.S] += s
        new_number_of_orbitals[AngularMomentum.P] += p
        new_number_of_orbitals[AngularMomentum.D] += d
        new_number_of_orbitals[AngularMomentum.F] += f
        new_number_of_orbitals[AngularMomentum.G] += g
        return PseudoatomicBasis(number_of_orbitals=new_number_of_orbitals)

    @property
    def total_number_of_orbitals(self) -> int:
        """Total number of orbitals per atom, accounting for m-degeneracy (2l + 1 each)."""
        return sum(
            count * (2 * ang_mtm.value + 1)
            for ang_mtm, count in self.number_of_orbitals.items()
        )

    @property
    def l_values(self) -> list[int]:
        """List of l values in the basis set, repeated according to the number of orbitals.

        Channels appear in S, P, D, F, G order.
        """
        return [
            ang_mtm.value
            for ang_mtm in AngularMomentum
            for _ in range(self.number_of_orbitals[ang_mtm])
        ]
class Subshell(BaseModel):
    """A single (n, l) subshell, e.g. 3d is ``Subshell(n=3, l=AngularMomentum.D)``."""

    n: int  # principal quantum number
    l: AngularMomentum  # angular momentum channel

    @field_validator("l", mode="before")
    @classmethod
    def coerce_l(cls, v: Any) -> Any:
        """Turn a plain integer ``l`` into the matching AngularMomentum member."""
        return AngularMomentum(v) if isinstance(v, int) else v

    @model_validator(mode="after")
    def validate_n_l(self) -> Self:
        """Reject subshells whose ``n`` does not exceed ``l`` (such as 2d or 0s)."""
        if self.l.value >= self.n:
            raise ValueError(
                f"Invalid subshell with n={self.n} and l={self.l}: n must be greater than l"
            )
        return self
# Subshells in the order they are filled under the Madelung (n + l, then n)
# rule, from 1s up to 8s.
ordered_subshells: list[Subshell] = [
    Subshell(n=int(label[:-1]), l=AngularMomentum[label[-1].upper()])
    for label in "1s 2s 2p 3s 3p 4s 3d 4p 5s 4d 5p 6s 4f 5d 6p 7s 5f 6d 7p 8s".split()
]


def _valence_subshells_from_z_valence(*, element: str, z_valence: float) -> list[Subshell]:
    """Guess which subshells a pseudopotential treats as valence.

    Used when a UPF advertises ``number_of_wfc=0`` and so has no ``<PP_CHI>``
    blocks to read the baseline basis from. The neutral atom's Z electrons are
    distributed over :data:`ordered_subshells` (each holding at most
    ``2 * (2l + 1)``). Whole subshells are then taken from the outermost end
    until their electrons add up to ``z_valence``.

    Returns:
        The valence subshells, in Madelung order.

    Raises:
        ValueError: if ``z_valence`` is not (close to) an integer, if
            :data:`ordered_subshells` is too short for the element, or if no
            run of outermost subshells holds exactly ``z_valence`` electrons.
    """
    from ase.data import atomic_numbers

    z = atomic_numbers[element.strip()]
    z_val_int = round(z_valence)
    if abs(z_valence - z_val_int) > 1e-3:
        raise ValueError(f"z_valence={z_valence} is not (close to) an integer")

    # Step 1: occupy subshells in Madelung order until all Z electrons are placed.
    occupations: list[tuple[Subshell, int]] = []
    left_to_place = z
    for shell in ordered_subshells:
        if left_to_place <= 0:
            break
        occupancy = min(2 * (2 * shell.l + 1), left_to_place)
        occupations.append((shell, occupancy))
        left_to_place -= occupancy
    if left_to_place > 0:
        raise ValueError(
            f"ordered_subshells does not extend far enough to fill {element} "
            f"(Z={z}); please add more entries."
        )

    # Step 2: walk back from the outermost subshell, collecting whole subshells
    # until at least z_valence electrons are gathered.
    picked: list[Subshell] = []
    tally = 0
    for shell, occupancy in reversed(occupations):
        if tally >= z_val_int:
            break
        picked.insert(0, shell)  # prepend so the result stays in Madelung order
        tally += occupancy
    if tally != z_val_int:
        raise ValueError(
            f"Cannot partition {element} (Z={z}) Madelung filling into a "
            f"valence shell of {z_val_int} electrons (got {tally}); the "
            "pseudo's valence partition does not align with Madelung order."
        )
    return picked
class AtomicBasis(Basis):
    """An atomic basis set.

    Unlike a pseudoatomic basis, this keeps track of the (n, l) values of each
    subshell rather than only a count per angular momentum channel.
    """

    subshells: list[Subshell]

    @property
    def l_max(self) -> AngularMomentum:
        """The maximum angular momentum quantum number in the basis set.

        Raises:
            ValueError: if the basis contains no subshells.
        """
        if not self.subshells:
            raise ValueError("l_max is undefined for a basis with no subshells")
        # IntEnum members order by value, so plain max() picks the highest l.
        return max(subshell.l for subshell in self.subshells)

    @property
    def n_max(self) -> int:
        """The maximum principal quantum number in the basis set (0 if empty)."""
        return max((subshell.n for subshell in self.subshells), default=0)

    def to_pseudoatomic_basis(self) -> PseudoatomicBasis:
        """Convert to a PseudoatomicBasis by counting subshells per l channel."""
        number_of_orbitals: dict[AngularMomentum, int] = {}
        for subshell in self.subshells:
            number_of_orbitals[subshell.l] = number_of_orbitals.get(subshell.l, 0) + 1
        return PseudoatomicBasis(number_of_orbitals=number_of_orbitals)

    @classmethod
    def from_upf(cls, upf_path: Path) -> AtomicBasis:
        """Construct an AtomicBasis from a UPF pseudopotential file.

        If the UPF ships pseudoatomic wavefunctions (``<PP_CHI>`` blocks), the
        basis is read directly from their ``n`` and ``l`` entries. Otherwise
        (e.g. SG15 ONCV, where ``PP_HEADER`` advertises ``number_of_wfc=0``), the
        baseline is reconstructed from the neutral-atom Madelung filling,
        keeping only the outermost subshells whose electron count sums to
        ``z_valence``.
        """
        from upf_tools import UPFDict

        upf_dict = UPFDict.from_upf(upf_path)
        header = upf_dict["header"]
        # The ``pswfc`` section may be missing or empty when there are no
        # <PP_CHI> blocks; treat both cases as "no chi blocks".
        pswfc = upf_dict.get("pswfc") or {}
        chi_blocks = pswfc.get("chi")
        if chi_blocks:
            return cls(subshells=[Subshell(n=chi["n"], l=chi["l"]) for chi in chi_blocks])
        return cls(
            subshells=_valence_subshells_from_z_valence(
                element=str(header["element"]).strip(),
                z_valence=float(header["z_valence"]),
            )
        )

    def extend(self, subshells: Subshell | Iterable[Subshell]) -> AtomicBasis:
        """Return a new AtomicBasis with ``subshells`` appended.

        The receiver is left unchanged. ``subshells`` may be a single
        :class:`Subshell` or any iterable of them (list, tuple, generator, ...).
        """
        # A pydantic model is itself iterable (over its fields), so a single
        # Subshell must be wrapped explicitly rather than unpacked.
        if isinstance(subshells, Subshell):
            subshells = [subshells]
        return AtomicBasis(subshells=[*self.subshells, *subshells])

    def __contains__(self, subshell: Subshell) -> bool:
        """Whether an equal subshell is part of this basis."""
        return subshell in self.subshells

    def __len__(self) -> int:
        """Number of subshells in the basis."""
        return len(self.subshells)