DCMNet API
Core DCMNet modules for training distributed charge multipole networks.
Data Preparation
- mmml.dcmnet.dcmnet.data.prepare_datasets(key, num_train, num_valid, filename, clean=False, esp_mask=False, clip_esp=False, natoms=60)
Prepare datasets for training and validation.
- Parameters:
key – Random key for dataset shuffling
num_train – Number of training samples
num_valid – Number of validation samples
filename – Filename(s) to load datasets from
clean – Whether to filter out failed calculations
esp_mask – Whether to create ESP masks
clip_esp – Whether to keep only the first 1000 ESP grid points
natoms – Maximum number of atoms per system
- Returns:
Tuple of (train_data, valid_data) dictionaries
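The internals of prepare_datasets are not documented here. The sketch below shows one plausible shuffle-and-split step that num_train and num_valid imply. It assumes each dataset is a dict of per-sample arrays padded to natoms atoms. It uses a NumPy generator in place of the JAX random key. The helper names (split_indices, split_dataset) and the keys "R" and "N" are hypothetical.

```python
import numpy as np

def split_indices(seed, n_total, num_train, num_valid):
    """Hypothetical helper: shuffle sample indices and split them into
    disjoint training and validation index sets."""
    if num_train + num_valid > n_total:
        raise ValueError("num_train + num_valid exceeds the number of samples")
    rng = np.random.default_rng(seed)  # NumPy stand-in for the JAX random key
    perm = rng.permutation(n_total)
    return perm[:num_train], perm[num_train:num_train + num_valid]

def split_dataset(data, train_idx, valid_idx):
    """Apply the same split to every per-sample array in a dict."""
    train = {k: v[train_idx] for k, v in data.items()}
    valid = {k: v[valid_idx] for k, v in data.items()}
    return train, valid

# Toy dataset: 10 systems, each padded to natoms = 4 atoms.
natoms = 4
data = {
    "R": np.zeros((10, natoms, 3)),  # hypothetical key: padded coordinates
    "N": np.full(10, 3),             # hypothetical key: real atom count
}
train_idx, valid_idx = split_indices(0, 10, num_train=7, num_valid=3)
train_data, valid_data = split_dataset(data, train_idx, valid_idx)
print(train_data["R"].shape, valid_data["R"].shape)  # (7, 4, 3) (3, 4, 3)
```

Shuffling before slicing keeps the two sets disjoint. Drawing them from one permutation also makes the split reproducible for a fixed key.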
- mmml.dcmnet.dcmnet.data.prepare_batches(key, data, batch_size, include_id=False, data_keys=None, num_atoms=60, dst_idx=None, src_idx=None)
Prepare batches for training.
- Parameters:
key – Random key for shuffling
data – Dictionary containing the dataset
batch_size – Size of each batch
include_id – Whether to include each sample's ID in the batch
data_keys – List of dataset keys to include in each batch
num_atoms – Number of atoms per system
dst_idx – Destination indices for message passing
src_idx – Source indices for message passing
- Returns:
List of batch dictionaries
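The dst_idx and src_idx parameters suggest the common sparse message-passing layout. In that layout, the atoms of a batch are flattened into one array, and the pair indices of each system are offset so that systems never exchange messages. The sketch below illustrates that layout only and is not the library's code. It assumes every array in data is per-atom with shape (n_samples, num_atoms, ...) and that pairs form a fully connected graph without self-loops. The names all_pairs, make_batches, and the batch_segments key are hypothetical.

```python
import numpy as np

def all_pairs(num_atoms):
    """Destination/source indices of a fully connected graph without self-loops."""
    dst, src = np.nonzero(~np.eye(num_atoms, dtype=bool))
    return dst, src

def make_batches(seed, data, batch_size, num_atoms):
    """Hypothetical sketch: shuffle, drop the incomplete last batch, flatten the
    per-atom arrays and offset the pair indices so systems stay disconnected."""
    n = len(next(iter(data.values())))
    perm = np.random.default_rng(seed).permutation(n)
    dst_idx, src_idx = all_pairs(num_atoms)
    offsets = (np.arange(batch_size) * num_atoms)[:, None]  # (batch_size, 1)
    batch_dst = (dst_idx[None, :] + offsets).reshape(-1)
    batch_src = (src_idx[None, :] + offsets).reshape(-1)
    segments = np.repeat(np.arange(batch_size), num_atoms)  # system id per atom
    batches = []
    for step in range(n // batch_size):
        idx = perm[step * batch_size:(step + 1) * batch_size]
        # Every array is assumed per-atom: shape (n, num_atoms, ...).
        batch = {k: v[idx].reshape(batch_size * num_atoms, *v.shape[2:])
                 for k, v in data.items()}
        batch.update(dst_idx=batch_dst, src_idx=batch_src, batch_segments=segments)
        batches.append(batch)
    return batches

rng = np.random.default_rng(1)
data = {"R": rng.normal(size=(5, 3, 3)), "Z": np.ones((5, 3), dtype=int)}
batches = make_batches(0, data, batch_size=2, num_atoms=3)
```

All batches have the same size, so the offset pair indices are computed once and reused. Fixed shapes like these also avoid recompilation under jax.jit.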
Training Functions
- mmml.dcmnet.dcmnet.training.train_model(key, model, train_data, valid_data, num_epochs, learning_rate, batch_size, writer, ndcm, esp_w=1.0, restart_params=None, ema_decay=0.999)
Train DCMNet with ESP and monopole losses.
- Parameters:
key – Random key for training
model – MessagePassingModel instance
train_data – Training dataset
valid_data – Validation dataset
num_epochs – Number of training epochs
learning_rate – Learning rate
batch_size – Batch size
writer – TensorBoard writer
ndcm – Number of distributed multipoles per atom
esp_w – ESP loss weight
restart_params – Parameters to restart from
ema_decay – Decay rate of the exponential moving average (EMA) of the parameters
- Returns:
Tuple of (final_params, final_valid_loss)
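The ema_decay argument suggests that train_model keeps an exponential moving average of the parameters. The exact update used by the library is not documented here. The sketch below shows the standard EMA step, ema ← decay · ema + (1 − decay) · params. It runs on a nested dict of arrays, the way Flax parameters are structured. The tree_map helper is a minimal stand-in for jax.tree_util.tree_map, so the example runs without JAX. The ema_update name is hypothetical.

```python
import numpy as np

def tree_map(fn, *trees):
    """Minimal stand-in for jax.tree_util.tree_map over nested dicts of arrays."""
    first = trees[0]
    if isinstance(first, dict):
        return {k: tree_map(fn, *(t[k] for t in trees)) for k in first}
    return fn(*trees)

def ema_update(ema_params, params, decay=0.999):
    """One EMA step: ema <- decay * ema + (1 - decay) * params."""
    return tree_map(lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params)

params = {"dense": {"kernel": np.ones((2, 2)), "bias": np.zeros(2)}}
ema = tree_map(np.copy, params)  # EMA commonly starts from the initial parameters
# Pretend an optimizer step moved every parameter by +1.
params = tree_map(lambda p: p + 1.0, params)
ema = ema_update(ema, params, decay=0.9)
print(ema["dense"]["bias"])  # [0.1 0.1]
```

A decay close to 1, such as the default 0.999, averages over roughly the last 1000 steps. That smooths out optimizer noise, which is why EMA parameters are often used for validation.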
- mmml.dcmnet.dcmnet.training.train_model_dipo(key, model, train_data, valid_data, num_epochs, learning_rate, batch_size, writer, ndcm, esp_w=1.0, restart_params=None)
Train DCMNet with dipole-augmented losses.
- Parameters:
key – Random key for training
model – MessagePassingModel instance
train_data – Training dataset (must include Dxyz, com, espMask)
valid_data – Validation dataset
num_epochs – Number of training epochs
learning_rate – Learning rate
batch_size – Batch size
writer – TensorBoard writer
ndcm – Number of distributed multipoles per atom
esp_w – ESP loss weight
restart_params – Parameters to restart from
- Returns:
Tuple of (final_params, final_valid_loss)
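train_model_dipo requires Dxyz and com in the data. This suggests that the dipole term compares the dipole of the predicted distributed charges, taken about the centre of mass, with the reference dipole. That reading is an assumption. The sketch below computes that dipole, μ = Σ q_i (r_i − com), for one system. It assumes charges of shape (n_atoms, n_dcm), positions of shape (n_atoms, n_dcm, 3), and an optional padding mask. The name dcm_dipole is hypothetical.

```python
import numpy as np

def dcm_dipole(charges, positions, com, atom_mask=None):
    """Dipole of a set of distributed point charges about the centre of mass.

    charges:   (n_atoms, n_dcm) distributed charges
    positions: (n_atoms, n_dcm, 3) charge positions
    com:       (3,) centre of mass
    atom_mask: optional (n_atoms,), 1 for real atoms and 0 for padding
    """
    q = charges if atom_mask is None else charges * atom_mask[:, None]
    # mu_k = sum over atoms i and charges j of q_ij * (r_ijk - com_k)
    return np.einsum("ij,ijk->k", q, positions - com)

q = np.array([[1.0, -1.0]])                            # one atom, n_dcm = 2
pos = np.array([[[0.0, 0.0, 0.5], [0.0, 0.0, -0.5]]])
print(dcm_dipole(q, pos, com=np.zeros(3)))             # [0. 0. 1.]
```

The dipole is taken about com because, for a charged system, the dipole moment depends on the choice of origin. A shared reference point keeps predictions and Dxyz comparable.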
Model Architecture
- class mmml.dcmnet.dcmnet.modules.MessagePassingModel(features, max_degree, num_iterations, num_basis_functions, cutoff, n_dcm, include_pseudotensors=False)
E(3)-equivariant message passing model for distributed multipoles.
- Parameters:
features – Number of features per atom
max_degree – Maximum spherical harmonic degree
num_iterations – Number of message passing iterations
num_basis_functions – Number of radial basis functions
cutoff – Distance cutoff for interactions
n_dcm – Number of distributed multipoles per atom
include_pseudotensors – Whether to include pseudotensor features in addition to ordinary tensors
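num_basis_functions and cutoff control how interatomic distances are encoded before message passing. The radial basis that MessagePassingModel actually uses is not stated here, and e3x provides its own radial functions. The sketch below is a generic illustration only. It uses evenly spaced Gaussians multiplied by a smooth cosine cutoff, so pairs beyond cutoff contribute nothing. The names cosine_cutoff and gaussian_rbf and the Gaussian width rule are assumptions.

```python
import numpy as np

def cosine_cutoff(r, cutoff):
    """Decays smoothly from 1 at r = 0 to 0 at r = cutoff; zero beyond."""
    return np.where(r < cutoff, 0.5 * (np.cos(np.pi * r / cutoff) + 1.0), 0.0)

def gaussian_rbf(r, num_basis_functions, cutoff):
    """Expand distances on evenly spaced Gaussians in [0, cutoff]."""
    centers = np.linspace(0.0, cutoff, num_basis_functions)
    width = cutoff / num_basis_functions  # assumed width rule
    expanded = np.exp(-((r[..., None] - centers) ** 2) / (2 * width ** 2))
    return expanded * cosine_cutoff(r, cutoff)[..., None]

r = np.array([0.5, 2.0, 4.0, 6.0])  # the last pair lies beyond the cutoff
features = gaussian_rbf(r, num_basis_functions=8, cutoff=5.0)
print(features.shape)  # (4, 8)
```

The smooth cutoff keeps the learned energy surface differentiable when atoms move in and out of range. A hard cutoff would introduce jumps.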
Loss Functions
- mmml.dcmnet.dcmnet.loss.esp_mono_loss(dipo_prediction, mono_prediction, vdw_surface, esp_target, mono, ngrid, n_atoms, batch_size, esp_w, n_dcm)
Combined ESP and monopole loss function.
- Parameters:
dipo_prediction – Predicted distributed dipoles
mono_prediction – Predicted monopoles
vdw_surface – Grid points on the van der Waals surface at which the ESP is evaluated
esp_target – Target ESP values
mono – Reference monopoles
ngrid – Number of grid points per system
n_atoms – Number of atoms per system
batch_size – Batch size
esp_w – ESP loss weight
n_dcm – Number of distributed multipoles per atom
- Returns:
Total loss value
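Here is one plausible form of the combined ESP and monopole loss for a single system, written under stated assumptions. The predicted distributed charges generate a Coulomb potential V(g) = Σ q_i / |g − r_i| on the surface grid. That potential is compared with esp_target. The charges on each atom are also summed and compared with the reference monopole mono. The sketch assumes atomic units, with positions and grid in Bohr. It also assumes the esp_w weighting shown; the library's exact reduction and weighting may differ. The names esp_from_charges and esp_mono_loss_sketch are hypothetical.

```python
import numpy as np

def esp_from_charges(charges, positions, grid):
    """Coulomb potential V(g) = sum_i q_i / |g - r_i| at each grid point.

    charges: (n_charges,), positions: (n_charges, 3), grid: (ngrid, 3).
    Assumes atomic units (Bohr in, Hartree/e out).
    """
    dist = np.linalg.norm(grid[:, None, :] - positions[None, :, :], axis=-1)
    return (charges[None, :] / dist).sum(axis=1)

def esp_mono_loss_sketch(charges, positions, grid, esp_target, mono, esp_w):
    """Hypothetical single-system loss: esp_w * ESP MSE + per-atom monopole MSE.

    charges: (n_atoms, n_dcm), positions: (n_atoms, n_dcm, 3), mono: (n_atoms,).
    """
    esp_pred = esp_from_charges(charges.reshape(-1), positions.reshape(-1, 3), grid)
    esp_loss = np.mean((esp_pred - esp_target) ** 2)
    mono_loss = np.mean((charges.sum(axis=1) - mono) ** 2)  # charges per atom
    return esp_w * esp_loss + mono_loss

charges = np.array([[0.3, 0.2]])                         # 1 atom, n_dcm = 2
positions = np.array([[[0.0, 0.0, 0.2], [0.0, 0.0, -0.2]]])
grid = np.array([[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 4.0]])
target = esp_from_charges(charges.reshape(-1), positions.reshape(-1, 3), grid)
loss = esp_mono_loss_sketch(charges, positions, grid, target, np.array([0.5]), esp_w=1.0)
```

The monopole term ties the summed charges of each atom to a reference value, while the ESP term lets those charges move off the nucleus. That freedom is what allows distributed charges to represent anisotropic features such as lone pairs.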
- mmml.dcmnet.dcmnet.loss.dipo_esp_mono_loss(dipo_prediction, mono_prediction, vdw_surface, esp_target, mono, Dxyz, com, espMask, n_atoms, batch_size, esp_w, n_dcm)
Dipole-augmented ESP and monopole loss function.
- Parameters:
dipo_prediction – Predicted distributed dipoles
mono_prediction – Predicted monopoles
vdw_surface – Grid points on the van der Waals surface at which the ESP is evaluated
esp_target – Target ESP values
mono – Reference monopoles
Dxyz – Reference molecular dipole moments
com – Center of mass coordinates
espMask – Masks marking which ESP grid points enter the loss
n_atoms – Number of atoms per system
batch_size – Batch size
esp_w – ESP loss weight
n_dcm – Number of distributed multipoles per atom
- Returns:
Tuple of (esp_loss, mono_loss, dipole_loss)
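If espMask marks the real grid points in a padded batch, as assumed here, the ESP term should average only over those points. Padded entries then neither add error nor dilute the mean. The sketch below shows that masked mean squared error, pooled over the whole batch. It assumes arrays of shape (batch_size, max_ngrid). The name masked_esp_mse is hypothetical, and the library may instead average per system.

```python
import numpy as np

def masked_esp_mse(esp_pred, esp_target, esp_mask):
    """Mean squared ESP error over valid grid points only.

    esp_pred, esp_target: (batch_size, max_ngrid) padded ESP values
    esp_mask:             (batch_size, max_ngrid), 1 for real points, 0 for padding
    """
    mask = esp_mask.astype(bool)
    # np.where also protects against NaN or garbage values in padded slots.
    sq_err = np.where(mask, (esp_pred - esp_target) ** 2, 0.0)
    # Normalise by the number of real points, not the padded length.
    return sq_err.sum() / max(mask.sum(), 1)

pred = np.array([[1.0, 2.0, 3.0], [1.0, 0.0, 100.0]])
target = np.array([[1.0, 2.0, 2.0], [0.0, 0.0, 0.0]])
mask = np.array([[1, 1, 1], [1, 1, 0]])  # last point of system 2 is padding
print(masked_esp_mse(pred, target, mask))  # 0.4
```

Dividing by the count of real points keeps the loss scale independent of how much padding a batch contains.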