"""Pseudoatomic and atomic basis-set models, and how to extend them with additional orbitals."""
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import IntEnum
from pathlib import Path
from typing import Any, Self
from pydantic import field_validator, model_validator
from kapaow.pydantic import BaseModel
# Names exported by ``from <this module> import *``; everything else is internal.
__all__: list[str] = [
    "AngularMomentum",
    "AtomicBasis",
    "PseudoatomicBasis",
    "Subshell",
    "ordered_subshells",
]
class AngularMomentum(IntEnum):
    """Orbital angular momentum quantum number ``l`` (s, p, d, f, g channels)."""

    S = 0
    P = 1
    D = 2
    F = 3
    G = 4

    def __add__(self, other: int | AngularMomentum) -> AngularMomentum:
        """Shift this channel up by ``other`` and return the resulting channel.

        The result is always an AngularMomentum member; ``AngularMomentum(...)``
        raises ValueError when the sum is not a defined channel (e.g. ``G + 1``).
        """
        offset = other.value if isinstance(other, AngularMomentum) else other
        return AngularMomentum(self.value + offset)
class Basis(BaseModel, ABC):
    """Abstract interface shared by the basis-set models in this module.

    Concrete subclasses report the largest angular momentum channel and the largest
    principal quantum number that they contain.
    """

    @property
    @abstractmethod
    def l_max(self) -> AngularMomentum:
        """Largest angular momentum quantum number ``l`` present in the basis set."""
        ...

    @property
    @abstractmethod
    def n_max(self) -> int:
        """Largest principal quantum number ``n`` present in the basis set."""
        ...
def _coerce_angular_momentum_key(key: Any) -> Any:
    """Coerce one ``number_of_orbitals`` key to :class:`AngularMomentum` when possible.

    ``int`` keys (AngularMomentum members included) and decimal strings such as ``"2"``
    (JSON object keys are always strings) are converted; any other key is returned
    unchanged so that pydantic's own validation reports it.

    Raises
    ------
    ValueError
        If the integer value is not a defined angular momentum channel.
    """
    if isinstance(key, int):
        return AngularMomentum(key)
    if isinstance(key, str) and key.strip().isdecimal():
        return AngularMomentum(int(key))
    return key


class PseudoatomicBasis(Basis):
    """A pseudoatomic basis set.

    Only need to keep track of the number of orbitals per angular momentum channel.
    """

    # Orbital count per channel, not counting the (2l + 1) m-degeneracy. After
    # validation every AngularMomentum member is a key; unspecified channels map to 0.
    number_of_orbitals: dict[AngularMomentum, int]

    @field_validator("number_of_orbitals", mode="before")
    @classmethod
    def ensure_all_channels_present(cls, v: Any) -> Any:
        """Ensure all angular momentum channels are present in the dictionary.

        Keys are coerced here as well, instead of relying on ``coerce_keys`` having
        run first. Otherwise a key such as ``"0"`` would sit next to an added
        ``AngularMomentum.S: 0`` entry that can shadow its count. A new dict is
        returned, so the caller's mapping is never mutated. Non-dict input is passed
        through untouched so pydantic rejects it with a ValidationError rather than
        this validator raising TypeError.
        """
        if not isinstance(v, dict):
            return v
        completed = {_coerce_angular_momentum_key(k): count for k, count in v.items()}
        for ang_mtm in AngularMomentum:
            completed.setdefault(ang_mtm, 0)
        return completed

    @field_validator("number_of_orbitals", mode="before")
    @classmethod
    def coerce_keys(cls, v: Any) -> Any:
        """Coerce integer (and decimal-string) keys to AngularMomentum.

        Returns a new dict for dict input; anything else is returned unchanged.
        """
        if isinstance(v, dict):
            return {_coerce_angular_momentum_key(k): val for k, val in v.items()}
        return v

    @property
    def l_max(self) -> AngularMomentum:
        """The maximum angular momentum quantum number in the basis set.

        i.e. return the largest l for which number_of_orbitals[l] > 0

        Raises
        ------
        ValueError
            If no channel has a positive orbital count.
        """
        occupied = [l for l, count in self.number_of_orbitals.items() if count > 0]
        if not occupied:
            raise ValueError("l_max is undefined for a basis with no orbitals")
        # IntEnum members order by their integer value.
        return max(occupied)

    @property
    def n_max(self) -> int:
        """The maximum principal quantum number in the basis set.

        Note that this is for the pseudo-wavefunction, so n_max is 1 if the basis has
        only 2s and 2p orbitals, i.e. it is the largest per-channel orbital count
        (0 if the mapping is empty).
        """
        return max(self.number_of_orbitals.values(), default=0)

    def __len__(self) -> int:
        """Number of orbitals, not counting m-degeneracy (see total_number_of_orbitals)."""
        return sum(self.number_of_orbitals.values())

    def extend(
        self, s: int = 0, p: int = 0, d: int = 0, f: int = 0, g: int = 0
    ) -> PseudoatomicBasis:
        """Return a new PseudoatomicBasis with added orbitals.

        ``self`` is left unchanged. Each argument is added to the count of the
        corresponding channel; no check is made that resulting counts stay >= 0.
        """
        additions = {
            AngularMomentum.S: s,
            AngularMomentum.P: p,
            AngularMomentum.D: d,
            AngularMomentum.F: f,
            AngularMomentum.G: g,
        }
        new_number_of_orbitals = dict(self.number_of_orbitals)
        for ang_mtm, extra in additions.items():
            new_number_of_orbitals[ang_mtm] = new_number_of_orbitals.get(ang_mtm, 0) + extra
        return PseudoatomicBasis(number_of_orbitals=new_number_of_orbitals)

    @property
    def total_number_of_orbitals(self) -> int:
        """Total number of orbitals per atom, accounting for m-degeneracy."""
        return sum(count * (2 * l.value + 1) for l, count in self.number_of_orbitals.items())

    @property
    def l_values(self) -> list[int]:
        """List of l values in the basis set, repeated according to the number of orbitals.

        Ordered by channel (s first, then p, d, f, g).
        """
        return [l.value for l in AngularMomentum for _ in range(self.number_of_orbitals[l])]
class Subshell(BaseModel):
    """One (n, l) subshell, e.g. 3d is ``Subshell(n=3, l=AngularMomentum.D)``."""

    # Principal quantum number.
    n: int
    # Angular momentum channel; bare ints are converted by ``coerce_l``.
    l: AngularMomentum

    @field_validator("l", mode="before")
    @classmethod
    def coerce_l(cls, v: Any) -> Any:
        """Turn a bare integer ``l`` into the matching AngularMomentum member."""
        return AngularMomentum(v) if isinstance(v, int) else v

    @model_validator(mode="after")
    def validate_n_l(self) -> Self:
        """Reject subshells that break the rule ``l < n``."""
        if self.l.value < self.n:
            return self
        raise ValueError(
            f"Invalid subshell with n={self.n} and l={self.l}: n must be greater than l"
        )
# Order of subshells to add following the Madelung rule: every (n, l) with l < n and
# n + l <= 8, sorted by increasing n + l and, for equal n + l, by increasing n. This
# gives 1s 2s 2p 3s 3p 4s 3d 4p 5s 4d 5p 6s 4f 5d 6p 7s 5f 6d 7p 8s.
ordered_subshells: list[Subshell] = [
    Subshell(n=n, l=AngularMomentum(l))
    for n, l in sorted(
        ((n, l) for n in range(1, 9) for l in range(n) if n + l <= 8),
        key=lambda nl: (nl[0] + nl[1], nl[0]),
    )
]
def _valence_subshells_from_z_valence(*, element: str, z_valence: float) -> list[Subshell]:
    """Reconstruct the valence subshells of a neutral-atom Madelung filling.

    Used when a UPF advertises ``number_of_wfc=0`` and therefore has no ``<PP_CHI>``
    blocks to read the baseline basis from. The neutral atom's Z electrons are
    poured into :data:`ordered_subshells` in order. Subshells are then taken back off
    the end of that filling until their electrons add up to ``z_valence``; those are
    the subshells the pseudo treats as valence.

    Parameters
    ----------
    element
        Chemical symbol, looked up in ``ase.data.atomic_numbers`` after stripping
        surrounding whitespace.
    z_valence
        Number of valence electrons; must be within 1e-3 of an integer.

    Returns
    -------
    list[Subshell]
        The valence subshells, in the same order as :data:`ordered_subshells`.

    Raises
    ------
    ValueError
        If ``z_valence`` is not (close to) an integer, if ``ordered_subshells`` is
        too short to hold Z electrons, or if no run of trailing filled subshells
        holds exactly ``z_valence`` electrons.
    """
    from ase.data import atomic_numbers

    z = atomic_numbers[element.strip()]
    z_val_int = round(z_valence)
    if abs(z_valence - z_val_int) > 1e-3:
        raise ValueError(f"z_valence={z_valence} is not (close to) an integer")

    # Madelung filling: each subshell takes min(capacity, electrons still unplaced),
    # where a subshell's capacity is 2 * (2l + 1).
    occupations: list[tuple[Subshell, int]] = []
    unplaced = z
    for subshell in ordered_subshells:
        if unplaced <= 0:
            break
        take = min(2 * (2 * subshell.l + 1), unplaced)
        occupations.append((subshell, take))
        unplaced -= take
    if unplaced > 0:
        raise ValueError(
            f"ordered_subshells does not extend far enough to fill {element} "
            f"(Z={z}); please add more entries."
        )

    # Walk back from the last-filled subshell until z_valence electrons are covered;
    # occupations[cut:] is then the valence tail.
    taken = 0
    cut = len(occupations)
    while cut > 0 and taken < z_val_int:
        cut -= 1
        taken += occupations[cut][1]
    if taken != z_val_int:
        raise ValueError(
            f"Cannot partition {element} (Z={z}) Madelung filling into a "
            f"valence shell of {z_val_int} electrons (got {taken}); the "
            "pseudo's valence partition does not align with Madelung order."
        )
    return [subshell for subshell, _ in occupations[cut:]]
class AtomicBasis(Basis):
    """An atomic basis set.

    Need to keep track of the (n, l) values of each subshell.
    """

    # Subshells in insertion order; duplicates are not rejected.
    subshells: list[Subshell]

    @property
    def l_max(self) -> AngularMomentum:
        """The maximum angular momentum quantum number in the basis set.

        Raises
        ------
        ValueError
            If the basis contains no subshells.
        """
        if not self.subshells:
            raise ValueError("l_max is undefined for a basis with no subshells")
        # IntEnum members order by their integer value.
        return max(s.l for s in self.subshells)

    @property
    def n_max(self) -> int:
        """The maximum principal quantum number in the basis set (0 if it is empty)."""
        return max((s.n for s in self.subshells), default=0)

    def to_pseudoatomic_basis(self) -> PseudoatomicBasis:
        """Convert to a PseudoatomicBasis by counting subshells per angular momentum channel."""
        number_of_orbitals: dict[AngularMomentum, int] = {}
        for subshell in self.subshells:
            number_of_orbitals[subshell.l] = number_of_orbitals.get(subshell.l, 0) + 1
        return PseudoatomicBasis(number_of_orbitals=number_of_orbitals)

    @classmethod
    def from_upf(cls, upf_path: Path) -> AtomicBasis:
        """Construct an AtomicBasis from a UPF pseudopotential file.

        If the UPF ships pseudoatomic wavefunctions (``<PP_CHI>`` blocks), the basis
        is read directly from their ``n`` and ``l`` entries. Otherwise (e.g. SG15
        ONCV, where ``PP_HEADER`` advertises ``number_of_wfc=0``), the baseline is
        reconstructed from the neutral-atom Madelung filling, keeping only the
        outermost subshells whose electron count sums to ``z_valence``.
        """
        from upf_tools import UPFDict

        upf_dict = UPFDict.from_upf(upf_path)
        header = upf_dict["header"]
        # A missing ``pswfc`` section and one that is present but empty/None both
        # mean "no chi blocks"; previously a None section raised AttributeError.
        pswfc = upf_dict.get("pswfc") or {}
        chi_blocks = pswfc.get("chi")
        if chi_blocks:
            return cls(subshells=[Subshell(n=chi["n"], l=chi["l"]) for chi in chi_blocks])
        return cls(
            subshells=_valence_subshells_from_z_valence(
                element=str(header["element"]).strip(),
                z_valence=float(header["z_valence"]),
            )
        )

    def extend(self, subshells: list[Subshell]) -> AtomicBasis:
        """Return a new AtomicBasis with ``subshells`` appended; ``self`` is unchanged.

        ``subshells`` may be any iterable of Subshell, or a single Subshell.
        """
        if isinstance(subshells, Subshell):
            subshells = [subshells]
        return AtomicBasis(subshells=[*self.subshells, *subshells])

    def __contains__(self, subshell: Subshell) -> bool:
        """Whether an equal subshell is already in the basis."""
        return subshell in self.subshells

    def __len__(self) -> int:
        """Number of subshells in the basis."""
        return len(self.subshells)