torchref.utils.utils module

class torchref.utils.utils.ModuleReference(module)[source]

Bases: object

A wrapper class to hold references to PyTorch modules without registering them.

When you assign an nn.Module to an attribute of another nn.Module, PyTorch automatically registers it as a submodule, which adds its parameters to the parent’s parameter tree. This wrapper prevents that automatic registration.

This is useful when you want to:

  • Hold references to modules without including their parameters

  • Avoid circular dependencies in the module tree

  • Reference external modules that should be managed separately

_wrapped_module

The wrapped PyTorch module.

Type:

torch.nn.Module

Examples

model = MyModel()
scaler = Scaler()
scaler._model = ModuleReference(model)  # Won't register as submodule
# Access the module via .module property
output = scaler._model.module(input_data)
__init__(module)[source]

Wrap a module to prevent automatic registration.

Parameters:

module (torch.nn.Module) – The PyTorch module to wrap.

property module

Access the wrapped module.

__getattr__(name)[source]

Forward attribute access to the wrapped module.

__call__(*args, **kwargs)[source]

Forward calls to the wrapped module.
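
The registration behavior that ModuleReference works around can be reproduced with plain PyTorch. The sketch below is only an illustration: `Ref` is a hypothetical stand-in for ModuleReference (it exposes just the `.module` property, not the attribute or call forwarding), and `Scaler` is a made-up parent module.

```python
import torch
from torch import nn


class Ref:
    """Hypothetical stand-in for ModuleReference: a plain object, not an nn.Module."""

    def __init__(self, module):
        self._wrapped_module = module

    @property
    def module(self):
        return self._wrapped_module


class Scaler(nn.Module):
    """Made-up parent module with a single parameter of its own."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))


model = nn.Linear(4, 2)

# Direct assignment: nn.Module.__setattr__ registers `model` as a submodule.
registered = Scaler()
registered.model = model

# Wrapped assignment: the wrapper is not an nn.Module, so nothing is registered.
referenced = Scaler()
referenced._model = Ref(model)

n_registered = len(list(registered.parameters()))  # scale + weight + bias
n_referenced = len(list(referenced.parameters()))  # scale only
```

With the wrapper, `referenced.parameters()` and `referenced.state_dict()` contain only the parent's own entries, while the wrapped model stays reachable through `referenced._model.module`.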

class torchref.utils.utils.CIFReader(filepath=None)[source]

Bases: object

A dictionary-like reader for CIF/mmCIF files.

Loops are stored as pandas DataFrames. Other data is stored in a hierarchical dictionary structure.

data

Dictionary storing parsed CIF data.

Type:

dict

filepath

Path to the loaded CIF file.

Type:

pathlib.Path or None

__init__(filepath=None)[source]

Initialize CIF reader.

Parameters:

filepath (str, optional) – Path to CIF file to load immediately.

load(filepath)[source]

Load and parse a CIF file.

Parameters:

filepath (str) – Path to CIF file.

write(filepath)[source]

Write the CIF data back to a file.

Parameters:

filepath (str) – Output file path.

__getitem__(key)[source]

Get item by key.

__setitem__(key, value)[source]

Set item by key.

__contains__(key)[source]

Check if key exists.

__len__()[source]

Return number of top-level categories.

keys()[source]

Return dictionary keys.

values()[source]

Return dictionary values.

items()[source]

Return dictionary items.

get(key, default=None)[source]

Get item with default value.

__repr__()[source]

String representation.

summary()[source]

Print a summary of the CIF contents.

torchref.utils.utils.save_map(array, cell, filename)[source]

Save a 3D map to a CCP4 file.

Parameters:
Returns:

True if save was successful.

Return type:

bool

class torchref.utils.utils.TensorDict(initial_dict=None)[source]

Bases: Module

A dictionary-like container for PyTorch tensors that:

  • Supports standard dict syntax

  • Automatically moves with the module

  • Registers tensors as buffers so they are included in state_dict

__init__(initial_dict=None)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

keys()[source]
values()[source]
items()[source]
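
One way to get the behavior described for TensorDict is to store every entry through `register_buffer`. The sketch below assumes that approach; `BufferDict` is a hypothetical name and this is not necessarily how torchref implements the class. Keys must be valid buffer names (no dots, no clash with existing attributes such as `keys`).

```python
import torch
from torch import nn


class BufferDict(nn.Module):
    """Hypothetical sketch: dict-style access backed by registered buffers."""

    def __init__(self, initial_dict=None):
        super().__init__()
        for key, tensor in (initial_dict or {}).items():
            self[key] = tensor

    def __setitem__(self, key, tensor):
        if key in self._buffers:
            # Overwrite an existing buffer in place of re-registering it.
            self._buffers[key] = tensor
        else:
            self.register_buffer(key, tensor)

    def __getitem__(self, key):
        return self._buffers[key]

    def keys(self):
        return self._buffers.keys()


d = BufferDict({"fobs": torch.zeros(3)})
d["sigma"] = torch.ones(3)

# Buffers follow the module through .to(), .cuda(), .cpu().
d = d.to(torch.float64)
state_keys = list(d.state_dict().keys())
```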
class torchref.utils.utils.TensorMasks(data=None, device=None)[source]

Bases: DeviceMixin, dict

A dictionary for managing boolean mask tensors with device support.

This is a lightweight dict subclass that:

  • Ensures all tensors are boolean dtype

  • Supports device movement via to(), cuda(), cpu()

  • Provides combined mask via __call__()

Parameters:
  • data (dict, optional) – Initial mask data.

  • device (str or torch.device, optional) – Device for tensors. Defaults to the configured device.current.

Examples

masks = TensorMasks(device='cuda')
masks['valid'] = torch.ones(100, dtype=torch.bool)
masks['rfree'] = rfree_flags > 0
combined = masks()  # Get combined mask (AND of all)
masks.cpu()  # Move all to CPU
__init__(data=None, device=None)[source]
__setitem__(key, tensor)[source]

Set mask tensor, ensuring boolean dtype and correct device.

reset_cache()[source]

Invalidate the cached combined mask.

__call__()[source]

Return combined mask (AND of all masks).

Returns:

Combined boolean mask, or None if no masks.

Return type:

torch.Tensor
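
The combined-mask behavior can be sketched with a small dict subclass. This is an illustration under assumptions, not the torchref implementation: `Masks` is a hypothetical name, device handling is left out, and the cache is assumed to be invalidated on every assignment.

```python
import torch


class Masks(dict):
    """Hypothetical sketch: boolean masks with a cached AND-combination."""

    def __init__(self, data=None):
        super().__init__()
        self._combined = None
        for key, tensor in (data or {}).items():
            self[key] = tensor

    def __setitem__(self, key, tensor):
        # Coerce to boolean dtype and drop the stale combined mask.
        super().__setitem__(key, torch.as_tensor(tensor).bool())
        self._combined = None

    def __call__(self):
        if not self:
            return None  # no masks stored
        if self._combined is None:
            combined = None
            for mask in self.values():
                combined = mask.clone() if combined is None else combined & mask
            self._combined = combined
        return self._combined


masks = Masks()
empty_result = masks()
masks["valid"] = torch.tensor([1, 1, 0])            # integers are cast to bool
masks["rfree"] = torch.tensor([True, False, False])
combined = masks()
```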

torchref.utils.utils.sanitize_pdb_dataframe(pdb, verbose=0)[source]

Sanitize a PDB DataFrame to ensure unique atom identifiers.

This function fixes common issues in PDB/CIF files:

  1. HETATM records (especially waters) with duplicate resseq values (e.g., all 0)

  2. Residue names longer than 3 characters (truncates to 3)

  3. Ensures unique (chainid, resseq, name, altloc) combinations

Parameters:
  • pdb (pandas.DataFrame) – DataFrame with PDB data (must have columns: ATOM, chainid, resseq, name, altloc, resname, serial).

  • verbose (int, default 0) – Verbosity level (0=silent, 1=info, 2=debug).

Returns:

Sanitized DataFrame with unique atom identifiers.

Return type:

pandas.DataFrame

Examples

from torchref.model import Model
from torchref.utils import sanitize_pdb_dataframe
model = Model()
model.load_cif('structure.cif')
model.pdb = sanitize_pdb_dataframe(model.pdb, verbose=1)
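
To check whether a DataFrame has the duplicate-identifier problem this function targets, the uniqueness condition from item 3 can be tested directly with pandas. The table below is made-up data and the check is only a diagnostic sketch; it does not reproduce the renumbering that sanitize_pdb_dataframe performs.

```python
import pandas as pd

# Made-up example: two waters in chain W share resseq 0.
pdb = pd.DataFrame(
    {
        "chainid": ["A", "A", "W", "W"],
        "resseq": [1, 2, 0, 0],
        "name": ["CA", "CA", "O", "O"],
        "altloc": ["", "", "", ""],
    }
)

key = ["chainid", "resseq", "name", "altloc"]

# keep=False flags every row that shares its key with another row.
dup = pdb.duplicated(subset=key, keep=False)
n_duplicates = int(dup.sum())
```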
torchref.utils.utils.parse_phenix_selection(selection_string, pdb_df)[source]

Parse Phenix-style atom selection syntax and return a boolean mask.

Supports common Phenix selection keywords:

  • chain <id>: Select atoms by chain ID (e.g., “chain A”)

  • resseq <num>: Select atoms by residue sequence number (e.g., “resseq 10”)

  • resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)

  • resname <name>: Select atoms by residue name (e.g., “resname ALA”)

  • name <atom>: Select atoms by atom name (e.g., “name CA”)

  • element <elem>: Select atoms by element (e.g., “element C”)

  • altloc <id>: Select atoms by alternate location (e.g., “altloc A”)

  • all: Select all atoms

  • not <selection>: Negate selection

  • <sel1> and <sel2>: Intersection of selections

  • <sel1> or <sel2>: Union of selections

  • Parentheses for grouping: (selection)

Parameters:
  • selection_string (str) – Phenix-style selection string.

  • pdb_df (pandas.DataFrame) – DataFrame containing atomic data with columns: ‘chainid’, ‘resseq’, ‘resname’, ‘name’, ‘element’, ‘altloc’.

Returns:

Boolean tensor of shape (n_atoms,) where True indicates selected atoms.

Return type:

torch.Tensor

Raises:

ValueError – If selection syntax is invalid.

Examples

# Select chain A
mask = parse_phenix_selection("chain A", pdb_df)

# Select residues 10-20 in chain A
mask = parse_phenix_selection("chain A and resseq 10:20", pdb_df)

# Select all CA atoms
mask = parse_phenix_selection("name CA", pdb_df)

# Select backbone atoms
mask = parse_phenix_selection("name CA or name C or name N or name O", pdb_df)

# Select everything except water
mask = parse_phenix_selection("not resname HOH", pdb_df)

# Use parentheses for grouping
mask = parse_phenix_selection("chain A and (name CA or name CB)", pdb_df)
torchref.utils.utils.create_selection_mask(selection_string, pdb_df, current_mask=None, mode='set')[source]

Create or modify a refinable mask based on a Phenix-style selection.

This function allows you to update refinable masks by selecting specific atoms using Phenix-style syntax. You can either replace the current mask, add to it, or remove from it.

Parameters:
  • selection_string (str) – Phenix-style selection string.

  • pdb_df (pandas.DataFrame) – DataFrame containing atomic data.

  • current_mask (torch.Tensor, optional) – Current refinable mask. If None, starts with all False.

  • mode (str, default 'set') –

    How to combine with current mask:

    • 'set': Replace mask with selection (default)

    • 'add': Add selection to current mask (OR operation)

    • 'remove': Remove selection from current mask (AND NOT operation)

Returns:

Updated boolean mask of shape (n_atoms,).

Return type:

torch.Tensor

Raises:

ValueError – If mode is not one of ‘set’, ‘add’, ‘remove’.

Examples

# Create new mask selecting chain A
mask = create_selection_mask("chain A", pdb_df, mode='set')

# Add residues 10-20 to existing mask
mask = create_selection_mask("resseq 10:20", pdb_df, current_mask=mask, mode='add')

# Remove water from mask
mask = create_selection_mask("resname HOH", pdb_df, current_mask=mask, mode='remove')
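
The three modes correspond to elementwise boolean operations on tensors. The sketch below shows those operations on two made-up masks; it illustrates the logic described above rather than the function's internals.

```python
import torch

current = torch.tensor([True, True, False, False])    # made-up current mask
selection = torch.tensor([False, True, True, False])  # made-up selection result

as_set = selection.clone()         # 'set': replace with the selection
as_add = current | selection       # 'add': OR
as_remove = current & ~selection   # 'remove': AND NOT
```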
torchref.utils.utils.state_dict_to_json_serializable(sd)[source]

Convert a state_dict with tensors to a JSON-serializable format.

Parameters:

sd (Dict[str, torch.Tensor]) – State dict with tensor values.

Returns:

JSON-serializable dictionary.

Return type:

Dict[str, Any]

torchref.utils.utils.dict_to_state_dict(sd_raw)[source]

Convert a dict with serialized tensor info to a PyTorch state_dict.

Parameters:

sd_raw (dict) – Dictionary where values are dicts with ‘data’, ‘dtype’, ‘shape’ keys.

Returns:

State dict with torch.Tensor values.

Return type:

dict
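
A round trip through JSON could look like the sketch below. It uses the ‘data’, ‘dtype’, ‘shape’ layout named above, but the helper names `to_jsonable` and `from_jsonable`, the flattened data list, and the dtype string format (for example "float32") are assumptions made for illustration; the actual torchref encoding may differ.

```python
import json

import torch


def to_jsonable(sd):
    """Hypothetical encoder: tensor -> {'data', 'dtype', 'shape'}."""
    return {
        key: {
            "data": value.detach().cpu().flatten().tolist(),
            "dtype": str(value.dtype).replace("torch.", ""),  # e.g. "float32"
            "shape": list(value.shape),
        }
        for key, value in sd.items()
    }


def from_jsonable(raw):
    """Hypothetical decoder: rebuild tensors from the encoded entries."""
    return {
        key: torch.tensor(
            entry["data"], dtype=getattr(torch, entry["dtype"])
        ).reshape(entry["shape"])
        for key, entry in raw.items()
    }


sd = {"weight": torch.arange(6, dtype=torch.float32).reshape(2, 3)}
text = json.dumps(to_jsonable(sd))
restored = from_jsonable(json.loads(text))
```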

torchref.utils.utils.json_to_state_dicts_separate(json_path)[source]

Parse hyperparameter JSON and return state_dicts for component_weighting, geometry_target, and adp_target.

Parameters:

json_path (str) – Path to the JSON file containing hyperparameters.

Returns:

Three state_dicts and a list of unassigned keys: (component_weighting_state, geometry_target_state, adp_target_state, unassigned_keys)

Return type:

Tuple[Dict, Dict, Dict, list]

torchref.utils.utils.disable_grad_outside_optimizer(optimized_params, all_params)[source]

Set requires_grad=False on parameters not being optimized.

Call this once after creating the optimizer. Non-optimized parameters will no longer contribute to the autograd graph, which means ModelFT.get_structure_factor will produce structure factors without grad_fn for frozen models — enabling indefinite caching until parameters change.

Parameters:
  • optimized_params (iterable of torch.Tensor) – Parameters passed to the optimizer.

  • all_params (iterable of torch.Tensor) – All model parameters (e.g. model.parameters()).
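
The effect on the autograd graph can be seen on a toy model. In the sketch below, `freeze_unoptimized` is a hypothetical reimplementation written for illustration (parameters are compared by identity), and the two-layer model stands in for a real refinement setup.

```python
import torch
from torch import nn


def freeze_unoptimized(optimized_params, all_params):
    """Hypothetical sketch: turn off grad for parameters the optimizer does not own."""
    optimized_ids = {id(p) for p in optimized_params}
    for p in all_params:
        if id(p) not in optimized_ids:
            p.requires_grad_(False)


model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 1))

# Only the second layer is handed to the optimizer.
optimized = list(model[1].parameters())
optimizer = torch.optim.SGD(optimized, lr=0.1)
freeze_unoptimized(optimized, model.parameters())

x = torch.ones(2, 3)
frozen_out = model[0](x)  # depends only on frozen parameters
full_out = model(x)       # still depends on the optimized layer
```

Because `frozen_out` has no `grad_fn`, it carries no graph and can be reused across iterations until the frozen parameters change.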