Charge-Spin Conditioned PhysNet
Multi-state energy and force predictions with charge and spin conditioning.
Overview
The charge-spin conditioned PhysNet extends the standard PhysNet model to accept:
Total molecular charge (Q): e.g., 0, +1, -1
Total spin multiplicity (S): e.g., 1 (singlet), 2 (doublet), 3 (triplet)
This enables predictions across different electronic states from a single model.
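Not every (Q, S) pair is physically possible for a given molecule. The electron count is N = sum(Z) - Q, and a multiplicity S corresponds to S - 1 unpaired electrons. So N - (S - 1) must be even and non-negative. The helpers `electron_count` and `is_valid_state` below are a hypothetical sketch for checking a requested state before calling the model. They are not part of the model's API.

```python
def electron_count(Z, total_charge):
    """Number of electrons: sum of nuclear charges minus the total charge Q."""
    return int(sum(Z)) - int(total_charge)


def is_valid_state(Z, total_charge, multiplicity):
    """Hypothetical helper: can N electrons realize spin multiplicity S = 2s + 1?"""
    n_electrons = electron_count(Z, total_charge)
    n_unpaired = int(multiplicity) - 1  # S = 1 -> 0 unpaired, S = 2 -> 1, S = 3 -> 2, ...
    if n_unpaired < 0 or n_unpaired > n_electrons:
        return False
    # The remaining electrons must pair up, so their count must be even.
    return (n_electrons - n_unpaired) % 2 == 0


# Water (Z = 8, 1, 1) has 10 electrons when neutral.
water = [8, 1, 1]
print(is_valid_state(water, 0, 1))  # True  (neutral singlet)
print(is_valid_state(water, 0, 2))  # False (neutral doublet)
print(is_valid_state(water, 1, 2))  # True  (cation doublet)
```

A neutral molecule with an even electron count can only be a singlet, triplet, quintet, and so on. This is why the doublet states in the examples are paired with charges +1 and -1.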
Quick Start
from mmml.physnetjax.physnetjax.models.model_charge_spin import EF_ChargeSpinConditioned
import jax
import jax.numpy as jnp
import e3x
# Create model
model = EF_ChargeSpinConditioned(
    features=128,
    num_iterations=3,
    natoms=60,
    charge_range=(-2, 2),  # Support charges -2 to +2
    spin_range=(1, 5),     # Support singlet to quintet
)
# Placeholder inputs (illustrative only): replace with a real molecule
num_atoms = 60  # assumed equal to natoms here
init_key, pos_key = jax.random.split(jax.random.PRNGKey(0))
Z = jnp.full((num_atoms,), 6, dtype=jnp.int32)  # atomic numbers, shape (num_atoms,)
R = jax.random.normal(pos_key, (num_atoms, 3))  # positions, shape (num_atoms, 3)
# Initialize
dst_idx, src_idx = e3x.ops.sparse_pairwise_indices(num_atoms)
params = model.init(
    init_key, Z, R, dst_idx, src_idx,
    total_charges=jnp.array([0.0]),
    total_spins=jnp.array([1.0]),
)
# Predict
outputs = model.apply(
    params, Z, R, dst_idx, src_idx,
    total_charges=jnp.array([0, 1, -1]),  # Neutral, cation, anion
    total_spins=jnp.array([1, 2, 2]),     # Singlet, doublet, doublet
)
# outputs["energy"]: (3,) - energies for each state
# outputs["forces"]: (num_atoms, 3) - forces
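The dst_idx and src_idx arrays list the atom pairs that exchange messages. We assume that `e3x.ops.sparse_pairwise_indices(num_atoms)` returns every ordered pair (i, j) with i != j, which we understand to be its default behavior. Under that assumption, the pure-JAX sketch below builds an equivalent pair list. The name `all_pairs` is hypothetical, and the pair ordering may differ from e3x's.

```python
import jax.numpy as jnp


def all_pairs(num_atoms):
    """Hypothetical stand-in: all ordered pairs (dst, src) of distinct atoms."""
    dst, src = jnp.meshgrid(jnp.arange(num_atoms), jnp.arange(num_atoms), indexing="ij")
    not_self = dst != src  # drop i == j self-pairs
    return dst[not_self], src[not_self]


dst_idx, src_idx = all_pairs(3)
print(dst_idx)  # [0 0 1 1 2 2]
print(src_idx)  # [1 2 0 2 0 1]
```

The number of pairs grows as num_atoms * (num_atoms - 1). For natoms=60, the fully connected list has 3540 entries.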
Use Cases
Ionization Energies
# Neutral vs cation
Q = jnp.array([0, 1])
S = jnp.array([1, 2])
E = model.apply(params, Z, R, ..., total_charges=Q, total_spins=S)["energy"]
IE = E[1] - E[0]  # Vertical ionization energy (both states at geometry R)
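Both states are evaluated at the same geometry R, so this is a vertical ionization energy. The model may instead be called once per (Q, S) pair. In that case, `jax.vmap` can map a single-state energy function over the state arrays. In the sketch below, `toy_energy` is a hypothetical placeholder with made-up coefficients, not the PhysNet energy. In practice it would wrap a model call for one state; whether the model supports this is an assumption.

```python
import jax
import jax.numpy as jnp


def toy_energy(q, s):
    """Hypothetical single-state energy; a real version would wrap a model call."""
    return -10.0 + 2.0 * q**2 + 0.5 * (s - 1.0)


Q = jnp.array([0.0, 1.0])  # neutral, cation
S = jnp.array([1.0, 2.0])  # singlet, doublet
E = jax.vmap(toy_energy)(Q, S)  # one energy per (q, s) pair, shape (2,)
IE = E[1] - E[0]                # vertical IE: E(cation) - E(neutral) at fixed R
print(IE)  # 2.5
```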
Singlet-Triplet Gaps
# Compare spin states
Q = jnp.array([0, 0])
S = jnp.array([1, 3])  # Singlet vs triplet
E = model.apply(params, Z, R, ..., total_charges=Q, total_spins=S)["energy"]
gap = E[1] - E[0]  # S-T gap: E(triplet) - E(singlet)
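With this sign convention, the gap is positive when the singlet lies below the triplet. To find the lowest-energy multiplicity among several spin states at one charge, compare the energies directly. `ground_spin_state` below is a hypothetical helper, and the energies are illustrative numbers, not model output.

```python
import jax.numpy as jnp


def ground_spin_state(energies, spins):
    """Hypothetical helper: lowest-energy multiplicity and gaps relative to it.

    energies[k] is the energy of multiplicity spins[k] at a fixed charge and geometry.
    """
    k = jnp.argmin(energies)
    return spins[k], energies - energies[k]


spins = jnp.array([1, 3, 5])              # singlet, triplet, quintet
energies = jnp.array([-5.0, -5.3, -4.9])  # illustrative values only
mult, rel = ground_spin_state(energies, spins)
print(int(mult))  # 3
```

With the batched call pattern shown above, these energies would come from a single model.apply call with Q = jnp.array([0, 0, 0]) and S = spins, assuming the model accepts three states at once as in the Quick Start.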
Prediction Options
Use the predict_energy and predict_forces flags to control what is computed:
# Energy only (faster: forces are not computed)
outputs = model.apply(
    params, Z, R, ..., total_charges=Q, total_spins=S,
    predict_energy=True,
    predict_forces=False,
)
# Forces only
outputs = model.apply(
    params, Z, R, ..., total_charges=Q, total_spins=S,
    predict_energy=False,
    predict_forces=True,
)
# Both (default)
outputs = model.apply(
    params, Z, R, ..., total_charges=Q, total_spins=S,
    predict_energy=True,
    predict_forces=True,
)
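We assume forces are obtained as the negative gradient of the energy with respect to positions, as in standard PhysNet. Computing them then needs an extra backward pass, which is why energy-only prediction can be faster. The sketch below shows that split with `jax.value_and_grad`. `toy_energy` (a smooth pair potential) and `energy_and_forces` are hypothetical stand-ins, not the model's internals.

```python
import jax
import jax.numpy as jnp


def toy_energy(R):
    """Hypothetical stand-in energy: (d_ij^2 - 1)^2 summed over atom pairs i < j."""
    diff = R[:, None, :] - R[None, :, :]
    d2 = jnp.sum(diff**2, axis=-1)            # squared distances (no sqrt, smooth everywhere)
    upper = jnp.triu(jnp.ones_like(d2), k=1)  # count each pair once
    return jnp.sum(upper * (d2 - 1.0) ** 2)


def energy_and_forces(R, predict_forces=True):
    if not predict_forces:
        return toy_energy(R), None                # forward pass only
    E, dE_dR = jax.value_and_grad(toy_energy)(R)  # forward + backward pass
    return E, -dE_dR                              # forces = -dE/dR


R = jnp.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [0.0, 1.0, 0.0]])
E, F = energy_and_forces(R)
print(F.shape)  # (3, 3)
```

The toy energy depends only on interatomic distances, so the resulting forces sum to zero over all atoms. This is a quick sanity check for any gradient-based force implementation.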
Model Parameters
charge_range: Tuple (min_charge, max_charge) of supported total charges, inclusive
spin_range: Tuple (min_spin, max_spin) of supported spin multiplicities, inclusive
charge_embed_dim: Dimension of the charge embedding (default: 16)
spin_embed_dim: Dimension of the spin embedding (default: 16)
All other parameters (features, num_iterations, etc.) are the same as in the standard PhysNet model.
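Both ranges are inclusive, so charge_range=(-2, 2) covers 5 charge states and spin_range=(1, 5) covers 5 multiplicities. One plausible way to turn these into learned embeddings is an index lookup into a table of shape (num_states, embed_dim), shifted by the range minimum. The sketch below is an assumption about the mechanism, not the module's actual code. `num_states`, `state_index`, and the random tables are hypothetical.

```python
import jax
import jax.numpy as jnp


def num_states(value_range):
    lo, hi = value_range
    return hi - lo + 1  # inclusive range


def state_index(values, value_range):
    """Hypothetical: map charges or multiplicities to table rows 0..num_states-1."""
    lo, hi = value_range
    values = jnp.asarray(values)
    if bool(jnp.any((values < lo) | (values > hi))):
        # JAX indexing does not raise on out-of-range indices, so check explicitly.
        raise ValueError(f"values outside supported range {value_range}")
    return jnp.round(values).astype(jnp.int32) - lo


charge_range, spin_range = (-2, 2), (1, 5)
charge_key, spin_key = jax.random.split(jax.random.PRNGKey(0))
charge_table = jax.random.normal(charge_key, (num_states(charge_range), 16))  # charge_embed_dim=16
spin_table = jax.random.normal(spin_key, (num_states(spin_range), 16))        # spin_embed_dim=16
q_emb = charge_table[state_index(jnp.array([0, 1, -1]), charge_range)]  # shape (3, 16)
s_emb = spin_table[state_index(jnp.array([1, 2, 2]), spin_range)]       # shape (3, 16)
```

Under this assumption, a charge or multiplicity outside the configured range has no embedding row. charge_range and spin_range should therefore cover every state in the training and prediction data.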
Examples
See examples/train_charge_spin_simple.py for a complete working example.
See examples/predict_options_demo.py for prediction mode examples.