torchref.utils.utils module
- class torchref.utils.utils.ModuleReference(module)
Bases: object
A wrapper class to hold references to PyTorch modules without registering them.
When you assign an nn.Module to an attribute of another nn.Module, PyTorch automatically registers it as a submodule, which adds its parameters to the parent's parameter tree. This wrapper prevents that automatic registration.
This is useful when you want to:
Hold references to modules without including their parameters
Avoid circular dependencies in the module tree
Reference external modules that should be managed separately
- _wrapped_module
The wrapped PyTorch module.
- Type:
torch.nn.Module
Examples
model = MyModel()
scaler = Scaler()
scaler._model = ModuleReference(model)  # Won't register as submodule
# Access the module via .module property
output = scaler._model.module(input_data)
- __init__(module)
Wrap a module to prevent automatic registration.
- Parameters:
module (torch.nn.Module) – The PyTorch module to wrap.
- property module
Access the wrapped module.
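The registration behaviour described above can be checked with a small self-contained sketch. `SimpleModuleReference` is a hypothetical stand-in written for illustration, not torchref's implementation; it relies only on the fact that `nn.Module.__setattr__` registers values that are `nn.Module` instances, while plain Python objects are stored as ordinary attributes.

```python
import torch
import torch.nn as nn


class SimpleModuleReference:
    """Hypothetical stand-in for ModuleReference (illustration only)."""

    def __init__(self, module: nn.Module):
        # Held by a plain object, so the owning nn.Module never sees an nn.Module value.
        self._wrapped_module = module

    @property
    def module(self) -> nn.Module:
        return self._wrapped_module


class WrappedScaler(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self._model = SimpleModuleReference(model)  # not registered as a submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.scale * self._model.module(x)


class DirectScaler(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self._model = model  # registered: the model's parameters join this tree


model = nn.Linear(3, 2)
wrapped = WrappedScaler(model)
wrapped_names = [name for name, _ in wrapped.named_parameters()]
direct_names = [name for name, _ in DirectScaler(model).named_parameters()]
print(wrapped_names)  # ['scale']
print(direct_names)   # ['scale', '_model.weight', '_model.bias']
```

Because the wrapped model's parameters are not part of `WrappedScaler`'s tree, an optimizer built from `wrapped.parameters()` would leave the wrapped model untouched.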
- class torchref.utils.utils.CIFReader(filepath=None)
Bases: object
A dictionary-like reader for CIF/mmCIF files.
Loops are stored as pandas DataFrames. Other data is stored in a hierarchical dictionary structure.
- filepath
Path to the loaded CIF file.
- Type:
pathlib.Path or None
- __init__(filepath=None)
Initialize CIF reader.
- Parameters:
filepath (str, optional) – Path to CIF file to load immediately.
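The split between loops (as DataFrames) and other data (as nested dictionaries) can be illustrated with a deliberately simplified parser. `parse_simple_cif` is a hypothetical sketch, not CIFReader's code: it assumes a single data block, tags of the form `_category.item` with the value on the same line, and tokenizes values with `shlex`, so quoting follows shell rules rather than the full CIF grammar. Multi-line `;` text fields and save frames are not handled, and grouping scalar items by category is only one plausible reading of "hierarchical"; CIFReader's actual key layout may differ.

```python
import shlex

import pandas as pd

CIF_TEXT = """data_demo
_cell.length_a 50.0
_cell.length_b 60.0
loop_
_atom_site.group_PDB
_atom_site.id
_atom_site.label_atom_id
_atom_site.Cartn_x
ATOM 1 N 10.0
ATOM 2 CA 11.5
"""


def parse_simple_cif(text):
    """Hypothetical sketch: scalars -> nested dict, loops -> DataFrames."""
    data, loops = {}, {}
    stripped = (ln.strip() for ln in text.splitlines())
    lines = [s for s in stripped if s and not s.startswith("#")]
    i = 0
    while i < len(lines):
        line = lines[i]
        if line.startswith("data_"):
            i += 1  # single data block assumed; its name is ignored
        elif line == "loop_":
            i += 1
            tags = []
            while i < len(lines) and lines[i].startswith("_"):
                tags.append(lines[i])
                i += 1
            values = []
            while i < len(lines) and not lines[i].startswith(("_", "loop_", "data_")):
                values.extend(shlex.split(lines[i]))
                i += 1
            category = tags[0].split(".")[0]
            columns = [t.split(".", 1)[1] for t in tags]
            n = len(columns)
            rows = [values[k:k + n] for k in range(0, len(values), n)]
            loops[category] = pd.DataFrame(rows, columns=columns)
        else:
            # Scalar item; assumes the value sits on the same line as its tag.
            tag, value = line.split(None, 1)
            category, item = tag.split(".", 1)
            data.setdefault(category, {})[item] = shlex.split(value)[0]
            i += 1
    return data, loops


data, loops = parse_simple_cif(CIF_TEXT)
atoms = loops["_atom_site"]
```

Values are kept as strings here; numeric conversion (for example with `pd.to_numeric`) is left to the caller.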
- torchref.utils.utils.save_map(array, cell, filename)
Save a 3D map to a CCP4 file.
- Parameters:
array (numpy.ndarray or torch.Tensor) – 3D array representing the map.
cell (list, tuple, numpy.ndarray, torch.Tensor, or gemmi.UnitCell) – Unit cell parameters [a, b, c, alpha, beta, gamma].
filename (str) – Output CCP4 file name.
- Returns:
True if save was successful.
- Return type:
bool
- class torchref.utils.utils.TensorDict(initial_dict=None)
Bases: torch.nn.Module
A dictionary-like container for PyTorch tensors that:
- Supports standard dict syntax
- Automatically moves with the module
- Registers tensors as buffers so they are included in state_dict
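The three properties listed above follow from storing each entry with `nn.Module.register_buffer`. The class below is a hypothetical minimal sketch of that idea (`MiniTensorDict` is not torchref's class, and its exact method set is an assumption):

```python
import torch
import torch.nn as nn


class MiniTensorDict(nn.Module):
    """Hypothetical sketch of a dict-like container that stores tensors as buffers."""

    def __init__(self, initial_dict=None):
        super().__init__()
        for key, value in (initial_dict or {}).items():
            self[key] = value

    def __setitem__(self, key, value):
        # register_buffer puts the tensor in state_dict and lets .to()/.double() convert it.
        # Keys must be valid buffer names (no '.') and must not shadow existing attributes.
        self.register_buffer(key, torch.as_tensor(value))

    def __getitem__(self, key):
        return self._buffers[key]

    def __contains__(self, key):
        return key in self._buffers

    def __len__(self):
        return len(self._buffers)

    def keys(self):
        return self._buffers.keys()

    def items(self):
        return self._buffers.items()


td = MiniTensorDict({"weights": torch.ones(3)})
td["scale"] = torch.tensor(2.0)
state_keys = sorted(td.state_dict().keys())  # ['scale', 'weights']
td = td.double()  # floating-point buffers are converted together with the module
```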
- class torchref.utils.utils.TensorMasks(data=None, device=None)
Bases: DeviceMixin, dict
A dictionary for managing boolean mask tensors with device support.
This is a lightweight dict subclass that:
- Ensures all tensors are boolean dtype
- Supports device movement via to(), cuda(), cpu()
- Provides a combined mask via __call__()
- Parameters:
data (dict, optional) – Initial mask data.
device (str or torch.device, optional) – Device for tensors. Defaults to the configured device.current.
Examples
masks = TensorMasks(device='cuda')
masks['valid'] = torch.ones(100, dtype=torch.bool)
masks['rfree'] = rfree_flags > 0
combined = masks()  # Get combined mask (AND of all)
masks.cpu()  # Move all to CPU
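The bool coercion and the AND-combined `__call__()` can be sketched as a plain `dict` subclass. `MiniMasks` is hypothetical: it omits `DeviceMixin` entirely, and returning `None` for an empty container is an assumption rather than documented behaviour.

```python
import torch


class MiniMasks(dict):
    """Hypothetical sketch of a dict of boolean masks with a combined-mask __call__."""

    def __init__(self, data=None, device="cpu"):
        super().__init__()
        self.device = torch.device(device)
        for key, value in (data or {}).items():
            self[key] = value  # route through __setitem__ so values are coerced

    def __setitem__(self, key, value):
        # Coerce every entry to a bool tensor on the container's device.
        super().__setitem__(key, torch.as_tensor(value, device=self.device).bool())

    def __call__(self):
        # Combined mask = logical AND of all stored masks (None if empty).
        masks = list(self.values())
        if not masks:
            return None
        combined = masks[0].clone()
        for mask in masks[1:]:
            combined &= mask
        return combined

    def to(self, device):
        self.device = torch.device(device)
        for key, value in list(self.items()):
            super().__setitem__(key, value.to(self.device))
        return self

    def cpu(self):
        return self.to("cpu")


masks = MiniMasks()
masks["valid"] = torch.tensor([True, True, False, True])
masks["rfree"] = torch.tensor([0, 1, 1, 0]) > 0
masks["work"] = [1, 1, 1, 0]  # ints are coerced to bool
combined = masks()  # tensor([False,  True, False, False])
```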
- torchref.utils.utils.sanitize_pdb_dataframe(pdb, verbose=0)
Sanitize a PDB DataFrame to ensure unique atom identifiers.
This function fixes common issues in PDB/CIF files:
HETATM records (especially waters) with duplicate resseq values (e.g., all 0)
Residue names longer than 3 characters (truncates to 3)
Ensures unique (chainid, resseq, name, altloc) combinations
- Parameters:
pdb (pandas.DataFrame) – DataFrame with PDB data (must have columns: ATOM, chainid, resseq, name, altloc, resname, serial).
verbose (int, default 0) – Verbosity level (0=silent, 1=info, 2=debug).
- Returns:
Sanitized DataFrame with unique atom identifiers.
- Return type:
pandas.DataFrame
Examples
from torchref.model import Model
from torchref.utils import sanitize_pdb_dataframe
model = Model()
model.load_cif('structure.cif')
model.pdb = sanitize_pdb_dataframe(model.pdb, verbose=1)
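The first two fixes listed above can be sketched with pandas. `renumber_duplicate_hetatm` is a hypothetical helper, not torchref's algorithm: it truncates residue names to three characters and gives each HETATM row whose (chainid, resseq, name, altloc) identifier clashes a fresh resseq after the highest non-clashing number in its chain. Assigning one residue number per row assumes single-atom residues such as waters.

```python
import pandas as pd

KEY = ["chainid", "resseq", "name", "altloc"]


def renumber_duplicate_hetatm(pdb: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical sketch: truncate resnames and renumber clashing HETATM rows."""
    df = pdb.copy()
    df["resname"] = df["resname"].astype(str).str[:3]
    # HETATM rows whose (chainid, resseq, name, altloc) identifier is not unique.
    dup = (df.duplicated(subset=KEY, keep=False) & (df["ATOM"] == "HETATM")).to_numpy()
    for chain in df.loc[dup, "chainid"].unique():
        in_chain = (df["chainid"] == chain).to_numpy()
        rows = df.index[dup & in_chain]
        existing = df.loc[in_chain & ~dup, "resseq"]
        start = int(existing.max()) + 1 if len(existing) else 1
        # Assumes single-atom residues (e.g. waters): each row gets its own resseq.
        df.loc[rows, "resseq"] = list(range(start, start + len(rows)))
    return df


pdb = pd.DataFrame({
    "ATOM": ["ATOM", "ATOM", "HETATM", "HETATM", "HETATM"],
    "chainid": ["A"] * 5,
    "resseq": [1, 1, 0, 0, 0],  # three waters all numbered 0
    "name": ["N", "CA", "O", "O", "O"],
    "altloc": [""] * 5,
    "resname": ["ALA", "ALA", "HOH", "HOH", "WATER"],
    "serial": [1, 2, 3, 4, 5],
})
clean = renumber_duplicate_hetatm(pdb)  # resseq -> [1, 1, 2, 3, 4]
```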
- torchref.utils.utils.parse_phenix_selection(selection_string, pdb_df)
Parse Phenix-style atom selection syntax and return a boolean mask.
Supports common Phenix selection keywords:
chain <id>: Select atoms by chain ID (e.g., “chain A”)
resseq <num>: Select atoms by residue sequence number (e.g., “resseq 10”)
resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)
resname <name>: Select atoms by residue name (e.g., “resname ALA”)
name <atom>: Select atoms by atom name (e.g., “name CA”)
element <elem>: Select atoms by element (e.g., “element C”)
altloc <id>: Select atoms by alternate location (e.g., “altloc A”)
all: Select all atoms
not <selection>: Negate selection
<sel1> and <sel2>: Intersection of selections
<sel1> or <sel2>: Union of selections
Parentheses for grouping: (selection)
- Parameters:
selection_string (str) – Phenix-style selection string.
pdb_df (pandas.DataFrame) – DataFrame containing atomic data with columns: ‘chainid’, ‘resseq’, ‘resname’, ‘name’, ‘element’, ‘altloc’.
- Returns:
Boolean tensor of shape (n_atoms,) where True indicates selected atoms.
- Return type:
torch.Tensor
- Raises:
ValueError – If selection syntax is invalid.
Examples
# Select chain A
mask = parse_phenix_selection("chain A", pdb_df)
# Select residues 10-20 in chain A
mask = parse_phenix_selection("chain A and resseq 10:20", pdb_df)
# Select all CA atoms
mask = parse_phenix_selection("name CA", pdb_df)
# Select backbone atoms
mask = parse_phenix_selection("name CA or name C or name N or name O", pdb_df)
# Select everything except water
mask = parse_phenix_selection("not resname HOH", pdb_df)
# Use parentheses for grouping
mask = parse_phenix_selection("chain A and (name CA or name CB)", pdb_df)
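A grammar like the one listed above maps naturally onto a small recursive-descent parser. `parse_selection_sketch` is a hypothetical illustration, not torchref's parser: it covers the listed keywords, treats keywords as lowercase, and assumes `not` binds tighter than `and`, which binds tighter than `or`. Phenix's own precedence rules may differ.

```python
import re

import numpy as np
import pandas as pd
import torch

_FIELDS = {"chain": "chainid", "resname": "resname", "name": "name",
           "element": "element", "altloc": "altloc"}


def parse_selection_sketch(selection, pdb_df):
    """Hypothetical sketch: Phenix-like selection string -> bool tensor."""
    tokens = re.findall(r"\(|\)|[^\s()]+", selection)
    pos = 0

    def peek():
        return tokens[pos] if pos < len(tokens) else None

    def take():
        nonlocal pos
        if pos >= len(tokens):
            raise ValueError("Unexpected end of selection")
        pos += 1
        return tokens[pos - 1]

    def parse_or():  # or-expression: lowest precedence
        mask = parse_and()
        while peek() == "or":
            take()
            mask = mask | parse_and()
        return mask

    def parse_and():
        mask = parse_atom()
        while peek() == "and":
            take()
            mask = mask & parse_atom()
        return mask

    def parse_atom():  # not / parentheses / keyword-value pairs
        tok = take()
        if tok == "not":
            return ~parse_atom()
        if tok == "(":
            mask = parse_or()
            if take() != ")":
                raise ValueError("Expected ')'")
            return mask
        if tok == "all":
            return np.ones(len(pdb_df), dtype=bool)
        if tok == "resseq":
            value = take()
            resseq = pdb_df["resseq"].to_numpy()
            if ":" in value:
                start, end = (int(v) for v in value.split(":"))
                return (resseq >= start) & (resseq <= end)
            return resseq == int(value)
        if tok in _FIELDS:
            return (pdb_df[_FIELDS[tok]].astype(str) == take()).to_numpy()
        raise ValueError(f"Unknown keyword: {tok!r}")

    mask = parse_or()
    if peek() is not None:
        raise ValueError(f"Unexpected token: {peek()!r}")
    return torch.as_tensor(np.array(mask, dtype=bool))


pdb_df = pd.DataFrame({
    "chainid": ["A", "A", "A", "B", "B"],
    "resseq": [10, 10, 15, 21, 300],
    "resname": ["ALA", "ALA", "GLY", "LEU", "HOH"],
    "name": ["N", "CA", "CA", "CB", "O"],
    "element": ["N", "C", "C", "C", "O"],
    "altloc": [""] * 5,
})
mask = parse_selection_sketch("chain A and (name CA or name CB)", pdb_df)
```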
- torchref.utils.utils.create_selection_mask(selection_string, pdb_df, current_mask=None, mode='set')
Create or modify a refinable mask based on a Phenix-style selection.
This function allows you to update refinable masks by selecting specific atoms using Phenix-style syntax. You can either replace the current mask, add to it, or remove from it.
- Parameters:
selection_string (str) – Phenix-style selection string.
pdb_df (pandas.DataFrame) – DataFrame containing atomic data.
current_mask (torch.Tensor, optional) – Current refinable mask. If None, starts with all False.
mode (str, default 'set') –
How to combine with current mask:
‘set’: Replace mask with selection (default)
‘add’: Add selection to current mask (OR operation)
‘remove’: Remove selection from current mask (AND NOT operation)
- Returns:
Updated boolean mask of shape (n_atoms,).
- Return type:
torch.Tensor
- Raises:
ValueError – If mode is not one of ‘set’, ‘add’, ‘remove’.
Examples
# Create new mask selecting chain A
mask = create_selection_mask("chain A", pdb_df, mode='set')
# Add residues 10-20 to existing mask
mask = create_selection_mask("resseq 10:20", pdb_df, current_mask=mask, mode='add')
# Remove water from mask
mask = create_selection_mask("resname HOH", pdb_df, current_mask=mask, mode='remove')
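The three combination modes reduce to elementwise boolean operations on the selection and the current mask. `combine_selection` is a hypothetical helper that takes an already-computed selection tensor, so it isolates the documented 'set' / 'add' / 'remove' rules from the parsing step:

```python
import torch


def combine_selection(selection, current_mask=None, mode="set"):
    """Hypothetical sketch of the 'set' / 'add' / 'remove' combination rules."""
    if mode not in ("set", "add", "remove"):
        raise ValueError(f"mode must be 'set', 'add' or 'remove', got {mode!r}")
    selection = selection.bool()
    if current_mask is None:
        # Documented default: start from an all-False mask.
        current_mask = torch.zeros_like(selection)
    current_mask = current_mask.bool()
    if mode == "set":
        return selection.clone()
    if mode == "add":
        return current_mask | selection  # OR
    return current_mask & ~selection  # AND NOT


chain_a = torch.tensor([True, True, True, False])
resseq_sel = torch.tensor([False, False, False, True])
water = torch.tensor([False, False, True, True])

mask = combine_selection(chain_a, mode="set")                        # [T, T, T, F]
mask = combine_selection(resseq_sel, current_mask=mask, mode="add")  # [T, T, T, T]
mask = combine_selection(water, current_mask=mask, mode="remove")    # [T, T, F, F]
```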
- torchref.utils.utils.state_dict_to_json_serializable(sd)
Convert a state_dict with tensors to a JSON-serializable format.
- Parameters:
sd (Dict[str, torch.Tensor]) – State dict with tensor values.
- Returns:
JSON-serializable dictionary.
- Return type:
Dict[str, Any]
- torchref.utils.utils.dict_to_state_dict(sd_raw)
Convert a dict with serialized tensor info to a PyTorch state_dict.
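The serialized format used by these two functions is not documented here. As a hypothetical illustration only, one common round-trip stores each tensor as `{"data", "dtype", "shape"}` and rebuilds it with `torch.tensor(...).reshape(...)`; the field names and both helper names below are assumptions, not torchref's schema.

```python
import json

import torch


def to_json_serializable(sd):
    """Hypothetical sketch: tensors -> plain lists plus dtype/shape metadata."""
    out = {}
    for key, value in sd.items():
        if isinstance(value, torch.Tensor):
            t = value.detach().cpu()
            out[key] = {"data": t.tolist(),
                        "dtype": str(t.dtype).replace("torch.", ""),
                        "shape": list(t.shape)}
        else:
            out[key] = value  # already JSON-friendly (assumed)
    return out


def from_serialized(sd_raw):
    """Hypothetical inverse: rebuild tensors from the metadata dicts."""
    sd = {}
    for key, value in sd_raw.items():
        if isinstance(value, dict) and {"data", "dtype", "shape"}.issubset(value):
            dtype = getattr(torch, value["dtype"])
            sd[key] = torch.tensor(value["data"], dtype=dtype).reshape(value["shape"])
        else:
            sd[key] = value
    return sd


sd = {"weight": torch.arange(6, dtype=torch.float32).reshape(2, 3),
      "count": torch.tensor(5)}
text = json.dumps(to_json_serializable(sd))
restored = from_serialized(json.loads(text))
```

Storing the dtype and shape explicitly matters because `tolist()` loses both: integers and floats become Python numbers, and a scalar tensor becomes a bare number.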
- torchref.utils.utils.json_to_state_dicts_separate(json_path)
Parse hyperparameter JSON and return state_dicts for component_weighting, geometry_target, and adp_target.
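The JSON layout this function expects is not documented here. Purely as an assumption for illustration, if the file held a flat mapping whose keys are prefixed with the component name (`component_weighting.`, `geometry_target.`, `adp_target.`), splitting it into one dictionary per component could look like the following hypothetical helper; in the real function the values would presumably also pass through a tensor-rebuilding step such as `dict_to_state_dict`.

```python
import json

COMPONENTS = ("component_weighting", "geometry_target", "adp_target")


def split_by_prefix(flat, prefixes=COMPONENTS):
    """Hypothetical sketch: {'prefix.key': v} -> {prefix: {'key': v}}."""
    groups = {p: {} for p in prefixes}
    for key, value in flat.items():
        prefix, sep, rest = key.partition(".")
        if sep and prefix in groups:
            groups[prefix][rest] = value  # keys without a known prefix are dropped
    return groups


raw = json.loads(
    '{"component_weighting.xray": 1.0, "geometry_target.bond_weight": 0.5,'
    ' "adp_target.sigma": 2.0, "unrelated": 3}'
)
groups = split_by_prefix(raw)
```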
- torchref.utils.utils.disable_grad_outside_optimizer(optimized_params, all_params)
Set requires_grad=False on parameters not being optimized.
Call this once after creating the optimizer. Non-optimized parameters will no longer contribute to the autograd graph, which means ModelFT.get_structure_factor will produce structure factors without a grad_fn for frozen models, enabling indefinite caching until parameters change.
- Parameters:
optimized_params (iterable of torch.Tensor) – Parameters passed to the optimizer.
all_params (iterable of torch.Tensor) – All model parameters (e.g. model.parameters()).
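The freezing step can be sketched as below. `freeze_non_optimized` is a hypothetical name; it matches parameters by object identity (`id`), which is an assumption about how optimized and non-optimized tensors are distinguished, and it does not reproduce any caching done by `ModelFT`.

```python
import torch
import torch.nn as nn


def freeze_non_optimized(optimized_params, all_params):
    """Hypothetical sketch: turn off requires_grad for parameters outside the optimizer."""
    optimized_ids = {id(p) for p in optimized_params}
    frozen = 0
    for p in all_params:
        if id(p) not in optimized_ids:
            p.requires_grad_(False)
            frozen += 1
    return frozen


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))
optimizer = torch.optim.Adam(model[1].parameters(), lr=1e-3)
# Collect exactly the tensors the optimizer will update.
optimized = [p for group in optimizer.param_groups for p in group["params"]]
n_frozen = freeze_non_optimized(optimized, model.parameters())
hidden = model[0](torch.randn(2, 4))  # no grad_fn: inputs and weights are all frozen
```

Reading the parameters back from `optimizer.param_groups` avoids passing an already-exhausted generator such as a second call's worth of `model[1].parameters()`.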