DCMNet API

Core DCMNet modules for training distributed charge multipole networks.

Data Preparation

mmml.dcmnet.dcmnet.data.prepare_datasets(key, num_train, num_valid, filename, clean=False, esp_mask=False, clip_esp=False, natoms=60)

Prepare datasets for training and validation.

Parameters:
  • key – Random key for dataset shuffling

  • num_train – Number of training samples

  • num_valid – Number of validation samples

  • filename – Filename(s) to load datasets from

  • clean – Whether to filter failed calculations

  • esp_mask – Whether to create ESP masks

  • clip_esp – Whether to keep only the first 1000 ESP grid points

  • natoms – Maximum number of atoms per system

Returns:

Tuple of (train_data, valid_data) dictionaries
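The shuffle-and-split step of prepare_datasets can be pictured with a short sketch. It is not the library's implementation. It assumes each dataset is a dictionary of NumPy arrays whose first axis indexes samples, uses a seeded NumPy Generator in place of the random key, and treats the hypothetical keys "esp" and "vdw_surface" as the per-grid-point arrays affected by clip_esp. The name split_train_valid is also hypothetical.

```python
import numpy as np

def split_train_valid(data, num_train, num_valid, seed=0, clip_esp=False):
    """Hypothetical sketch: shuffle samples and split them into train/valid dicts.

    `data` maps keys to arrays whose first axis indexes samples. A seeded NumPy
    Generator stands in for the random key used by prepare_datasets.
    """
    n_samples = len(next(iter(data.values())))
    if num_train + num_valid > n_samples:
        raise ValueError("num_train + num_valid exceeds the number of samples")
    perm = np.random.default_rng(seed).permutation(n_samples)
    train_idx = perm[:num_train]
    valid_idx = perm[num_train:num_train + num_valid]
    if clip_esp:
        # Assumption: grid arrays have shape (n_samples, n_grid, ...); keep 1000 points.
        data = {k: (v[:, :1000] if k in ("esp", "vdw_surface") else v)
                for k, v in data.items()}
    train = {k: v[train_idx] for k, v in data.items()}
    valid = {k: v[valid_idx] for k, v in data.items()}
    return train, valid
```

Because both splits come from one permutation, no sample appears in both the training and validation sets.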

mmml.dcmnet.dcmnet.data.prepare_batches(key, data, batch_size, include_id=False, data_keys=None, num_atoms=60, dst_idx=None, src_idx=None)

Prepare batches for training.

Parameters:
  • key – Random key for shuffling

  • data – Dictionary containing the dataset

  • batch_size – Size of each batch

  • include_id – Whether to include sample IDs in each batch

  • data_keys – List of keys to include

  • num_atoms – Number of atoms per system

  • dst_idx – Destination indices for message passing

  • src_idx – Source indices for message passing

Returns:

List of batch dictionaries
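The role of dst_idx and src_idx is easiest to see on a fully connected graph. The sketch below is hypothetical: make_batches and fully_connected_pairs are illustrative names, and dropping the incomplete final batch is an assumption. It shuffles samples, slices them into batches, and offsets each system's pair indices so they address atoms in the flattened batch of batch_size * num_atoms atoms.

```python
import numpy as np

def fully_connected_pairs(num_atoms):
    """Return (dst_idx, src_idx) for all ordered atom pairs i != j in one system."""
    dst, src = np.nonzero(~np.eye(num_atoms, dtype=bool))
    return dst, src

def make_batches(data, batch_size, num_atoms, seed=0):
    """Hypothetical sketch: shuffled fixed-size batches with batch-level pair indices."""
    n_samples = len(next(iter(data.values())))
    perm = np.random.default_rng(seed).permutation(n_samples)
    n_batches = n_samples // batch_size  # assumption: drop the incomplete final batch
    dst, src = fully_connected_pairs(num_atoms)
    # Shift each system's indices by its offset within the flattened batch.
    offsets = np.arange(batch_size)[:, None] * num_atoms
    batch_dst = (dst[None, :] + offsets).reshape(-1)
    batch_src = (src[None, :] + offsets).reshape(-1)
    batches = []
    for b in range(n_batches):
        idx = perm[b * batch_size:(b + 1) * batch_size]
        batch = {k: v[idx] for k, v in data.items()}
        batch["dst_idx"] = batch_dst
        batch["src_idx"] = batch_src
        batches.append(batch)
    return batches
```

Since every batch has the same size and atom count, the pair indices can be computed once and reused, which keeps array shapes static between training steps.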

Training Functions

mmml.dcmnet.dcmnet.training.train_model(key, model, train_data, valid_data, num_epochs, learning_rate, batch_size, writer, ndcm, esp_w=1.0, restart_params=None, ema_decay=0.999)

Train DCMNet with ESP and monopole losses.

Parameters:
  • key – Random key for training

  • model – MessagePassingModel instance

  • train_data – Training dataset

  • valid_data – Validation dataset

  • num_epochs – Number of training epochs

  • learning_rate – Learning rate

  • batch_size – Batch size

  • writer – TensorBoard writer

  • ndcm – Number of distributed multipoles

  • esp_w – ESP loss weight

  • restart_params – Parameters to restart from

  • ema_decay – Exponential moving average decay

Returns:

Tuple of (final_params, final_valid_loss)
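The ema_decay parameter suggests that train_model keeps an exponential moving average of the parameters. Assuming the standard update ema = decay * ema + (1 - decay) * params, a minimal sketch over a flat dictionary of parameters (rather than a nested JAX pytree) looks like this; ema_update is a hypothetical name.

```python
def ema_update(ema_params, params, decay=0.999):
    """Hypothetical sketch: blend current parameters into an EMA copy, key by key."""
    return {k: decay * ema_params[k] + (1.0 - decay) * params[k] for k in params}

# Example: track a single scalar parameter over a few updates.
ema = {"w": 0.0}
for _ in range(3):
    ema = ema_update(ema, {"w": 1.0}, decay=0.5)
```

With ema_decay=0.999, each update moves the average 0.1% of the way toward the current parameters, so it effectively averages over roughly the last 1000 updates.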

mmml.dcmnet.dcmnet.training.train_model_dipo(key, model, train_data, valid_data, num_epochs, learning_rate, batch_size, writer, ndcm, esp_w=1.0, restart_params=None)

Train DCMNet with dipole-augmented losses.

Parameters:
  • key – Random key for training

  • model – MessagePassingModel instance

  • train_data – Training dataset (must include Dxyz, com, espMask)

  • valid_data – Validation dataset

  • num_epochs – Number of training epochs

  • learning_rate – Learning rate

  • batch_size – Batch size

  • writer – TensorBoard writer

  • ndcm – Number of distributed multipoles

  • esp_w – ESP loss weight

  • restart_params – Parameters to restart from

Returns:

Tuple of (final_params, final_valid_loss)
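Because train_model_dipo requires extra keys in its training data, it can help to check a dataset before training. The helper below is a hypothetical sketch, not part of the documented API; it only looks for the three keys named in the train_data description.

```python
REQUIRED_DIPO_KEYS = ("Dxyz", "com", "espMask")

def missing_dipo_keys(data):
    """Hypothetical helper: return the dipole-specific keys absent from `data`."""
    return [k for k in REQUIRED_DIPO_KEYS if k not in data]

# Example: raise early instead of failing inside the training loop.
dataset = {"Dxyz": None, "com": None, "espMask": None}
missing = missing_dipo_keys(dataset)
if missing:
    raise KeyError(f"dataset is missing keys: {missing}")
```

Running the check on train_data (and presumably valid_data, if validation uses the same loss) surfaces a missing key before any training step runs.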

Model Architecture

class mmml.dcmnet.dcmnet.modules.MessagePassingModel(features, max_degree, num_iterations, num_basis_functions, cutoff, n_dcm, include_pseudotensors=False)

E(3)-equivariant message passing model for distributed multipoles.

Parameters:
  • features – Number of features per atom

  • max_degree – Maximum spherical harmonic degree

  • num_iterations – Number of message passing iterations

  • num_basis_functions – Number of radial basis functions

  • cutoff – Distance cutoff for interactions

  • n_dcm – Number of distributed multipoles per atom

  • include_pseudotensors – Whether to include pseudotensors
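To illustrate what num_basis_functions and cutoff control, here is a generic radial expansion: Gaussian basis functions on evenly spaced centers, multiplied by a smooth cosine envelope that vanishes at the cutoff. This is an assumption for illustration only; MessagePassingModel may use a different radial basis or cutoff function, and gaussian_rbf is a hypothetical name.

```python
import numpy as np

def gaussian_rbf(distances, num_basis_functions, cutoff):
    """Hypothetical sketch: expand distances in Gaussian radial basis functions,
    scaled by a cosine envelope that decays smoothly to zero at `cutoff`."""
    d = np.asarray(distances, dtype=float)[..., None]      # (..., 1)
    centers = np.linspace(0.0, cutoff, num_basis_functions)
    width = cutoff / num_basis_functions
    rbf = np.exp(-((d - centers) ** 2) / (2.0 * width ** 2))
    envelope = np.where(d < cutoff, 0.5 * (np.cos(np.pi * d / cutoff) + 1.0), 0.0)
    return rbf * envelope                                    # (..., num_basis_functions)
```

Each pair distance becomes a vector of num_basis_functions values, and pairs beyond the cutoff contribute exactly zero, so they can be excluded from message passing.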

Loss Functions

mmml.dcmnet.dcmnet.loss.esp_mono_loss(dipo_prediction, mono_prediction, vdw_surface, esp_target, mono, ngrid, n_atoms, batch_size, esp_w, n_dcm)

Combined ESP and monopole loss function.

Parameters:
  • dipo_prediction – Predicted distributed dipoles

  • mono_prediction – Predicted monopoles

  • vdw_surface – Surface grid points

  • esp_target – Target ESP values

  • mono – Reference monopoles

  • ngrid – Number of grid points per system

  • n_atoms – Number of atoms per system

  • batch_size – Batch size

  • esp_w – ESP loss weight

  • n_dcm – Number of distributed multipoles

Returns:

Total loss value
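The physics behind esp_mono_loss can be sketched for a single system. The sketch assumes that dipo_prediction holds the Cartesian positions of the n_dcm charges around each atom, with shape (n_atoms, n_dcm, 3); that mono_prediction holds their magnitudes, with shape (n_atoms, n_dcm); and that the loss is esp_w times the ESP mean squared error plus the mean squared error between each atom's summed charges and its reference monopole. Quantities are in atomic units, so V_k = sum_j q_j / |r_k - x_j|. None of this is confirmed by the documentation above, and the function names are hypothetical.

```python
import numpy as np

def esp_from_charges(positions, charges, grid):
    """ESP (atomic units) of point charges at grid points: V_k = sum_j q_j / |r_k - x_j|.

    positions: (n_charges, 3), charges: (n_charges,), grid: (n_grid, 3).
    """
    diff = grid[:, None, :] - positions[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)            # (n_grid, n_charges)
    return (charges[None, :] / dist).sum(axis=-1)   # (n_grid,)

def esp_mono_loss_sketch(dipo_prediction, mono_prediction, vdw_surface,
                         esp_target, mono, esp_w):
    """Hypothetical single-system loss: esp_w * ESP MSE + per-atom charge MSE."""
    n_atoms, n_dcm = mono_prediction.shape
    positions = dipo_prediction.reshape(n_atoms * n_dcm, 3)
    charges = mono_prediction.reshape(n_atoms * n_dcm)
    esp_pred = esp_from_charges(positions, charges, vdw_surface)
    esp_loss = np.mean((esp_pred - esp_target) ** 2)
    # Each atom's distributed charges should sum to its reference monopole.
    mono_loss = np.mean((mono_prediction.sum(axis=1) - mono) ** 2)
    return esp_w * esp_loss + mono_loss
```

The monopole term ties the distributed charges back to per-atom reference charges, while esp_w sets how strongly the ESP fit dominates that constraint.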

mmml.dcmnet.dcmnet.loss.dipo_esp_mono_loss(dipo_prediction, mono_prediction, vdw_surface, esp_target, mono, Dxyz, com, espMask, n_atoms, batch_size, esp_w, n_dcm)

Dipole-augmented ESP and monopole loss function.

Parameters:
  • dipo_prediction – Predicted distributed dipoles

  • mono_prediction – Predicted monopoles

  • vdw_surface – Surface grid points

  • esp_target – Target ESP values

  • mono – Reference monopoles

  • Dxyz – Reference molecular dipole moment (x, y, z components)

  • com – Center of mass coordinates

  • espMask – ESP evaluation masks

  • n_atoms – Number of atoms per system

  • batch_size – Batch size

  • esp_w – ESP loss weight

  • n_dcm – Number of distributed multipoles

Returns:

Tuple of (esp_loss, mono_loss, dipole_loss)
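The two extra ingredients of dipo_esp_mono_loss can be sketched separately: a molecular dipole computed from the distributed charges about com, and an ESP error averaged only over grid points selected by espMask. This assumes Dxyz is the reference dipole moment vector and espMask is a boolean array over grid points; the function names are hypothetical and not part of the documented API.

```python
import numpy as np

def dipole_from_charges(positions, charges, com):
    """Dipole of point charges about the center of mass: D = sum_j q_j (x_j - com)."""
    return (charges[:, None] * (positions - com[None, :])).sum(axis=0)

def masked_esp_mse(esp_pred, esp_target, esp_mask):
    """Mean squared ESP error over grid points where esp_mask is True."""
    mask = np.asarray(esp_mask, dtype=bool)
    return np.mean((esp_pred[mask] - esp_target[mask]) ** 2)

def dipole_mse(positions, charges, com, Dxyz):
    """Hypothetical dipole loss: squared error between predicted and reference dipoles."""
    return np.mean((dipole_from_charges(positions, charges, com) - Dxyz) ** 2)
```

For a neutral system the dipole does not depend on the reference point, so com matters only when the distributed charges carry a net charge.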