"""How to extend a pseudoatomic basis set with additional orbitals."""
import enum
from abc import ABC, abstractmethod
from pydantic import Field
from kapaow.basis import (
AngularMomentum,
AtomicBasis,
PseudoatomicBasis,
ordered_subshells,
)
from kapaow.pydantic import BaseModel
# Names re-exported by ``from <module> import *``: the extension model
# hierarchy, the CLI kind enum, and the ``--add`` parser.
__all__: list[str] = [
    "BasisExtension",
    "BasisExtensionType",
    "BasisExtensionViaAddition",
    "BasisExtensionViaChannel",
    "BasisExtensionViaPolarization",
    "parse_extension",
]
class BasisExtensionType(enum.Enum):
    """Kind of basis extension.

    The member values are the strings understood by ``parse_extension``.
    ``SUBSHELL`` and ``POLARIZATION`` are generic kinds, while ``S`` through
    ``G`` each target a single angular-momentum channel.
    """

    SUBSHELL = "subshell"
    POLARIZATION = "polarization"
    S = "s"
    P = "p"
    D = "d"
    F = "f"
    G = "g"

    @property
    def angular_momentum(self) -> AngularMomentum | None:
        """Channel targeted by a per-channel kind, or ``None`` otherwise.

        ``SUBSHELL`` and ``POLARIZATION`` are not tied to one channel and so
        map to ``None``.
        """
        per_channel = {
            BasisExtensionType.S: AngularMomentum.S,
            BasisExtensionType.P: AngularMomentum.P,
            BasisExtensionType.D: AngularMomentum.D,
            BasisExtensionType.F: AngularMomentum.F,
            BasisExtensionType.G: AngularMomentum.G,
        }
        return per_channel.get(self)
class BasisExtension(BaseModel, ABC):
    """An extension to a basis set."""

    @abstractmethod
    def extend(self, basis: AtomicBasis | PseudoatomicBasis) -> PseudoatomicBasis:
        """Produce an enlarged basis from *basis*.

        Concrete extensions must override this. The input may be either an
        atomic or a pseudoatomic basis; the result is a pseudoatomic basis.
        """
        ...
class BasisExtensionViaAddition(BasisExtension):
    """Add the next subshell to an atomic basis set."""

    increment: int = Field(default=1, description="number of subshells to add")

    def extend_atomic(self, basis: AtomicBasis) -> AtomicBasis:
        """Extend the provided basis by adding the next subshell(s), returning an AtomicBasis.

        First checks for gaps in Madelung order between basis subshells
        (e.g. 5s missing between 4p and 4d for Pd), filtering out core
        subshells. A subshell counts as core if its n is below the smallest
        n in the basis, or if its l channel already holds a higher n in the
        basis. Valid gaps are used first. Any remaining subshells are taken,
        in order, from those following the outermost basis subshell.

        Parameters
        ----------
        basis
            The atomic basis to extend.

        Returns
        -------
        AtomicBasis
            The result of ``basis.extend`` called with the chosen subshells.

        Raises
        ------
        ValueError
            If ``increment`` is negative, if *basis* has no subshells, or if
            fewer than ``increment`` candidate subshells are available.
        """
        if self.increment < 0:
            # Guard: a negative slice bound below would silently select
            # "all candidates but the last few" instead of failing.
            raise ValueError(f"increment must be non-negative, got {self.increment}.")

        # Materialize once; the subshells are traversed several times below.
        subshells = list(basis.subshells)
        if not subshells:
            # With no subshells there is no innermost/outermost reference point;
            # indexing the empty list would otherwise raise an opaque IndexError.
            raise ValueError("Cannot extend an empty basis by addition.")

        indices = sorted(ordered_subshells.index(s) for s in subshells)
        i_innermost, i_outermost = indices[0], indices[-1]
        min_n = min(s.n for s in subshells)

        # For each l channel, record the max n present in the basis
        max_n_per_l: dict[AngularMomentum, int] = {}
        for s in subshells:
            if s.l not in max_n_per_l or s.n > max_n_per_l[s.l]:
                max_n_per_l[s.l] = s.n

        # Look for valid gaps between the innermost and outermost basis entries.
        # Skip core subshells: n < min_n, or same l channel already has higher n.
        gaps = []
        for subshell in ordered_subshells[i_innermost:i_outermost]:
            if subshell in basis:
                continue
            if subshell.n < min_n:
                continue
            if subshell.l in max_n_per_l and subshell.n < max_n_per_l[subshell.l]:
                continue
            gaps.append(subshell)

        # Use gaps first, then continue past the outermost subshell. Unpacking
        # works whether ordered_subshells is a list or a tuple.
        candidates = [*gaps, *ordered_subshells[i_outermost + 1 :]]
        to_add = candidates[: self.increment]
        if len(to_add) < self.increment:
            raise ValueError(f"Cannot add {self.increment} subshell(s) beyond the current basis.")
        return basis.extend(to_add)

    def extend(self, basis: AtomicBasis | PseudoatomicBasis) -> PseudoatomicBasis:
        """Extend the provided basis by adding the next subshell(s).

        The atomic basis is extended via :meth:`extend_atomic` and the result
        converted with ``to_pseudoatomic_basis``.

        Raises
        ------
        TypeError
            If *basis* is a :class:`PseudoatomicBasis`, because the l channel
            to add to cannot be determined from it.
        ValueError
            Propagated from :meth:`extend_atomic`.
        """
        if isinstance(basis, PseudoatomicBasis):
            raise TypeError(
                "Cannot extend pseudoatomic bases by addition because we can't know"
                " what l channel to add to"
            )
        return self.extend_atomic(basis).to_pseudoatomic_basis()
class BasisExtensionViaChannel(BasisExtension):
    """Add orbitals in a specific angular momentum channel."""

    channel: AngularMomentum = Field(description="angular momentum channel to extend")
    increment: int = Field(default=1, description="number of radial functions to add")

    def extend(self, basis: AtomicBasis | PseudoatomicBasis) -> PseudoatomicBasis:
        """Add ``increment`` radial functions to ``channel`` of *basis*.

        An atomic basis is first converted with ``to_pseudoatomic_basis``.
        The pseudoatomic basis's ``extend`` is then called with the
        lower-cased channel name as the keyword (e.g. ``p=2``).
        """
        target = basis.to_pseudoatomic_basis() if isinstance(basis, AtomicBasis) else basis
        keyword = self.channel.name.lower()
        return target.extend(**{keyword: self.increment})
class BasisExtensionViaPolarization(BasisExtension):
    """Add polarization orbitals to a pseudoatomic basis set."""

    increment: int = Field(
        default=1, description="number of polarization orbitals to add per angular momentum channel"
    )

    def extend(self, basis: AtomicBasis | PseudoatomicBasis) -> PseudoatomicBasis:
        """Apply ``increment`` rounds of polarization to *basis*.

        An atomic basis is first converted with ``to_pseudoatomic_basis``.
        Each round adds one radial function to every channel ``l`` satisfying
        ``l <= l_max + 1``. Here ``l_max`` is read from the basis as it stands
        at the start of that round. A non-positive ``increment`` returns the
        (converted) basis unchanged.
        """
        result = basis.to_pseudoatomic_basis() if isinstance(basis, AtomicBasis) else basis
        for _ in range(self.increment):
            # Fix the channel set before this round starts mutating `result`.
            ceiling = result.l_max + 1
            targets = [ang for ang in AngularMomentum if ang <= ceiling]
            for ang in targets:
                result = result.extend(**{ang.name.lower(): 1})
        return result
def parse_extension(add: tuple[str, ...]) -> BasisExtension | None:
    """Translate repeated ``--add`` CLI flags into a basis extension.

    Every flag value must be the string form of a :class:`BasisExtensionType`.
    The number of repetitions becomes the extension's ``increment``. All
    flags in one call must name the same kind.

    Parameters
    ----------
    add
        The collected flag strings, e.g. ``("subshell", "subshell")`` or
        ``("p",)``. An empty tuple yields ``None``.

    Returns
    -------
    BasisExtension | None
        The extension to apply, or ``None`` when *add* is empty.

    Raises
    ------
    ValueError
        If more than one distinct kind appears in *add*, or if the kind is
        not a valid :class:`BasisExtensionType` value.
    """
    if not add:
        return None

    distinct = set(add)
    if len(distinct) != 1:
        raise ValueError("Cannot mix different --add types in the same call.")
    (flag,) = distinct

    kind = BasisExtensionType(flag)
    count = len(add)

    if kind is BasisExtensionType.SUBSHELL:
        return BasisExtensionViaAddition(increment=count)
    if kind is BasisExtensionType.POLARIZATION:
        return BasisExtensionViaPolarization(increment=count)

    # Every remaining kind names a single angular-momentum channel.
    channel = kind.angular_momentum
    return None if channel is None else BasisExtensionViaChannel(channel=channel, increment=count)