Charge-Spin Conditioned PhysNet#

Multi-state energy and force predictions with charge and spin conditioning.

Overview#

The charge-spin conditioned PhysNet extends the standard PhysNet model to accept:

  • Total molecular charge (Q): e.g., 0, +1, -1

  • Total spin multiplicity (S): e.g., 1 (singlet), 2 (doublet), 3 (triplet)

This enables predictions across different electronic states from a single model.

Quick Start#

from mmml.physnetjax.physnetjax.models.model_charge_spin import EF_ChargeSpinConditioned
import jax.numpy as jnp
import e3x

# Create model
model = EF_ChargeSpinConditioned(
    features=128,
    num_iterations=3,
    natoms=60,
    charge_range=(-2, 2),    # Support charges -2 to +2
    spin_range=(1, 5),       # Support singlet to quintet
)

# Initialize
dst_idx, src_idx = e3x.ops.sparse_pairwise_indices(num_atoms)
params = model.init(
    key, Z, R, dst_idx, src_idx,
    total_charges=jnp.array([0.0]),
    total_spins=jnp.array([1.0]),
)

# Predict
outputs = model.apply(
    params, Z, R, dst_idx, src_idx,
    total_charges=jnp.array([0, 1, -1]),  # Neutral, cation, anion
    total_spins=jnp.array([1, 2, 2]),     # Singlet, doublet, doublet
)

# outputs["energy"]: (3,) - energies for each state
# outputs["forces"]: (num_atoms, 3) - forces

Use Cases#

Ionization Energies#

# Neutral vs cation
Q = jnp.array([0, 1])
S = jnp.array([1, 2])

E = model.apply(params, Z, R, ..., Q, S)["energy"]
IE = E[1] - E[0]  # Ionization energy

Singlet-Triplet Gaps#

# Compare spin states
Q = jnp.array([0, 0])
S = jnp.array([1, 3])  # Singlet vs triplet

E = model.apply(params, Z, R, ..., Q, S)["energy"]
gap = E[1] - E[0]  # S-T gap

Prediction Options#

Control what gets computed:

# Energy only (faster)
outputs = model.apply(
    params, Z, R, ..., Q, S,
    predict_energy=True,
    predict_forces=False,
)

# Forces only
outputs = model.apply(
    params, Z, R, ..., Q, S,
    predict_energy=False,
    predict_forces=True,
)

# Both (default)
outputs = model.apply(
    params, Z, R, ..., Q, S,
    predict_energy=True,
    predict_forces=True,
)

Model Parameters#

  • charge_range: Tuple of (min_charge, max_charge) to support

  • spin_range: Tuple of (min_spin, max_spin) multiplicities

  • charge_embed_dim: Dimension of charge embedding (default: 16)

  • spin_embed_dim: Dimension of spin embedding (default: 16)

All other parameters same as standard PhysNet (features, num_iterations, etc.)

Examples#

See examples/train_charge_spin_simple.py for a complete working example.

See examples/predict_options_demo.py for prediction mode examples.

API Reference#