torchref.model.parameter_wrappers module

A module containing wrapper classes for handling crystallographic parameters (occupancies, xyz coordinates, B-factors, etc.).

class torchref.model.parameter_wrappers.MixedTensor(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None)

Bases: DeviceMixin, CachedForwardMixin, Module

A wrapper class for tensors with mixed fixed and refinable elements.

Stores a mask indicating which elements can be refined and maintains both fixed and refinable components separately. The full tensor is reconstructed on-the-fly when accessed.
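Here is a minimal sketch of this split, with plain Python lists standing in for tensors. It assumes that fixed_values is stored at full length, as the forward() fast paths suggest. It also assumes the refinable part is a compact vector holding only the masked entries. split_values and assemble are hypothetical helpers, not torchref functions.

```python
# Minimal sketch of the fixed/refinable split, using lists instead of tensors.
# split_values and assemble are hypothetical helpers, not part of torchref.


def split_values(values, refinable_mask):
    """Split a full vector into full-length fixed storage and a compact refinable vector."""
    fixed_values = list(values)  # assumed full-length buffer
    refinable_params = [v for v, m in zip(values, refinable_mask) if m]
    return fixed_values, refinable_params


def assemble(fixed_values, refinable_params, refinable_mask):
    """Rebuild the full vector on the fly: fixed values overlaid with refinable ones."""
    full = list(fixed_values)
    refinable_idx = [i for i, m in enumerate(refinable_mask) if m]
    for k, i in enumerate(refinable_idx):
        full[i] = refinable_params[k]
    return full


mask = [False, True, True, False]
fixed, params = split_values([1.0, 2.0, 3.0, 4.0], mask)
params[0] += 0.5  # an optimizer step only ever touches the refinable part
print(assemble(fixed, params, mask))  # [1.0, 2.5, 3.0, 4.0]
```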

Parameters:
  • initial_values (torch.Tensor, optional) – Initial tensor values for all elements. Optional for empty init.

  • refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined. If None, all elements are refinable.

  • requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for the tensor. Defaults to the dtype of initial_values.

  • device (torch.device, optional) – Device for the tensor. Defaults to the device of initial_values.

  • name (str, optional) – Optional name for this parameter (useful for debugging/logging).

refinable_mask

Boolean mask indicating refinable elements.

Type:

torch.Tensor

fixed_mask

Boolean mask indicating fixed elements (inverse of refinable_mask).

Type:

torch.Tensor

fixed_values

Buffer containing fixed values.

Type:

torch.Tensor

refinable_params

Parameter containing refinable values.

Type:

nn.Parameter

Examples

Empty initialization for state_dict loading:

mixed = MixedTensor()
mixed.load_state_dict(torch.load('mixed.pt'))

Full initialization with values:

import torch
from torchref.model.parameter_wrappers import MixedTensor

mask = torch.zeros(100, dtype=torch.bool)
mask[20:30] = True
initial_values = torch.randn(100)
mixed = MixedTensor(initial_values, refinable_mask=mask, requires_grad=True)
optimizer = torch.optim.Adam([mixed.refinable_params], lr=0.01)

__init__(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None)

Initialize a MixedTensor.

If initial_values is provided, fully initializes the tensor. If not provided (empty init), creates a shell ready for load_state_dict().

Parameters:
  • initial_values (torch.Tensor, optional) – Initial tensor values for all elements. Optional for empty init.

  • refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined. If None, all elements are refinable.

  • requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for the tensor. Defaults to the dtype of initial_values.

  • device (torch.device, optional) – Device for the tensor. Defaults to the device of initial_values.

  • name (str, optional) – Optional name for this parameter (useful for debugging/logging).

forward()

Reconstruct and return the full tensor.

Three fast paths, in priority order:

  1. All atoms refinable: return refinable_params directly (no clone, no scatter). This is the common case in standard refinement, where no atoms are frozen. It saves a full-tensor clone plus an index_put_ per call, and replaces the index_put_ backward (sort + atomic scatter) with an identity.

  2. No refinable atoms: return fixed_values.clone(). The clone preserves the caller-must-not-mutate contract even though the result is detached from autograd.

  3. Mixed: go through _AssembleMixedTensor, whose backward is a single index_select (gather) instead of PyTorch’s default index_put_ backward (radix sort + scatter).
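The dispatch above can be sketched on plain lists. The sketch also shows why the mixed-case backward reduces to a gather. Each refinable parameter lands in exactly one output position, so its gradient is simply the upstream gradient at that position. This is an illustration, not the library's code: forward and backward_mixed are hypothetical names, and no autograd is involved.

```python
# Sketch of the three forward() fast paths with lists standing in for tensors.
# The real class uses torch and a custom autograd Function (_AssembleMixedTensor);
# here the mixed-path backward is written out by hand as a gather.


def forward(fixed_values, refinable_params, refinable_mask):
    n_refinable = sum(refinable_mask)
    if n_refinable == len(refinable_mask):
        # Path 1: everything is refinable, so the parameters are the full tensor.
        return refinable_params
    if n_refinable == 0:
        # Path 2: nothing is refinable; return a copy so callers cannot alter storage.
        return list(fixed_values)
    # Path 3: copy the fixed values, then scatter the refinable entries in.
    full = list(fixed_values)
    refinable_idx = [i for i, m in enumerate(refinable_mask) if m]
    for k, i in enumerate(refinable_idx):
        full[i] = refinable_params[k]
    return full


def backward_mixed(grad_full, refinable_mask):
    # Gradient w.r.t. refinable_params: gather grad_full at the refinable positions.
    return [g for g, m in zip(grad_full, refinable_mask) if m]


mask = [True, False, True]
print(forward([0.0, 5.0, 0.0], [1.0, 3.0], mask))  # [1.0, 5.0, 3.0]
print(backward_mixed([0.1, 0.2, 0.3], mask))       # [0.1, 0.3]
```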

__getitem__(key)

Get values at specified indices/mask from the full tensor.

Parameters:

key (int, slice, torch.Tensor, or tuple) – Index specification. Can be:
  - int: single element
  - slice: range of elements (e.g., 5:10, :, ::2)
  - torch.Tensor: boolean mask or integer indices
  - tuple: multi-dimensional indexing

Returns:

Selected values from the full tensor.

Return type:

torch.Tensor

Examples

model.b[5]           # Get B-factor for atom 5
model.b[5:10]        # Get B-factors for atoms 5-9
model.b[mask]        # Get B-factors where mask is True
model.xyz[:, 0]      # Get all x-coordinates

Notes

Subclasses may override _get_values() to customize value retrieval.

__setitem__(key, value)

Set values at specified indices/mask.

This method updates both fixed_values and refinable_params at the specified positions. Supports various indexing styles including slices, boolean masks, and integer indices.

Parameters:
  • key (int, slice, torch.Tensor, or tuple) – Index specification. Can be:
      - int: single element
      - slice: range of elements (e.g., 5:10, :, ::2)
      - torch.Tensor: boolean mask or integer indices
      - tuple: multi-dimensional indexing

  • value (torch.Tensor, float, or int) – Values to assign. Can be:
      - scalar: broadcast to all selected positions
      - tensor: must match the shape of the selected region

Examples

model.b[:] = 30.0           # Set all B-factors to 30
model.b[5:10] = 25.0        # Set B-factors 5-9 to 25
model.b[mask] = new_values  # Set B-factors where mask is True
model.xyz[mask] = new_coords  # Set coordinates for masked atoms
model.xyz[:, 0] += 1.0      # Shift all x-coordinates (read-modify-write)

Notes

This method modifies the tensor in-place. The refinable_params parameter is replaced with a new Parameter containing the updated values. An optimizer created before the call may therefore still hold the old Parameter, and should be rebuilt afterwards.

Subclasses may override _set_values() to customize value handling (e.g., PositiveMixedTensor converts to log-space).

set(values, mask)

Set values at positions specified by a boolean mask.

Updates both fixed_values and refinable_params at the positions specified by the mask. This is useful for applying coordinate shifts, B-factor corrections, or any other updates to specific atoms.

Parameters:
  • values (torch.Tensor) – New values to assign. Shape must match:
      - for 1D tensors: (n_selected,), where n_selected = mask.sum()
      - for 2D tensors (e.g., xyz): (n_selected, d), where d is the second dimension size (e.g., 3 for coordinates)

  • mask (torch.Tensor) – Boolean mask of shape (n_atoms,) indicating which elements to update. True positions will receive the new values.

Raises:

ValueError – If mask shape doesn’t match tensor’s first dimension, or if values shape doesn’t match the number of selected elements.

Examples

# Update coordinates for selected atoms
mask = model.get_selection_mask("chain A")
new_coords = original_coords[mask] + shift
model.xyz.set(new_coords, mask)

# Update B-factors for specific residues
mask = model.get_selection_mask("resseq 10:20")
new_b = torch.ones(int(mask.sum())) * 30.0
model.b.set(new_b, mask)

Notes

This method modifies the tensor in-place. The refinable_params parameter is replaced with a new Parameter containing the updated values. An optimizer created before the call may therefore still hold the old Parameter, and should be rebuilt afterwards.
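The validation and update rules above can be sketched on plain lists. This is a hypothetical helper (masked_set), not the library's code, and it assumes 1D data. It shows the two checks listed under Raises and the update of both storages. It also shows why a stale reference does not see the new values: old_params here, or an optimizer's parameter list in practice.

```python
# Sketch of set(values, mask) on 1D lists. masked_set is a hypothetical helper;
# it returns a new refinable vector to mimic the Parameter being replaced.


def masked_set(fixed_values, refinable_params, refinable_mask, values, mask):
    if len(mask) != len(fixed_values):
        raise ValueError("mask length must match the tensor's first dimension")
    n_selected = sum(mask)
    if len(values) != n_selected:
        raise ValueError(f"expected {n_selected} values, got {len(values)}")
    # Map each refinable position to its slot in the compact parameter vector.
    slot = {}
    for i, m in enumerate(refinable_mask):
        if m:
            slot[i] = len(slot)
    new_params = list(refinable_params)
    selected = [i for i, m in enumerate(mask) if m]
    for i, v in zip(selected, values):
        fixed_values[i] = v          # fixed storage is updated in place
        if i in slot:
            new_params[slot[i]] = v  # refinable entry goes into the new vector
    return new_params


fixed = [1.0, 2.0, 3.0]
old_params = [2.0]
new_params = masked_set(fixed, old_params, [False, True, False], [7.0, 8.0], [True, True, False])
print(fixed, new_params, old_params)  # [7.0, 8.0, 3.0] [8.0] [2.0]
```

Because old_params is left untouched, anything still holding it keeps working on stale values.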

property shape

Return the shape of the full tensor.

property dtype

Return the dtype of the tensor.

property device

Return the device of the tensor.

get_refinable_count()

Return the number of refinable parameters.

get_fixed_count()

Return the number of fixed parameters.

update_fixed_values(new_values)

Update the fixed values (does not affect refinable parameters).

Parameters:

new_values (torch.Tensor) – New tensor values. Only fixed positions will be updated.

Raises:

ValueError – If new_values shape doesn’t match tensor shape.

update_refinable_mask(new_mask, reset_refinable=False)

Update which elements are refinable.

This is an advanced operation that modifies the refinable/fixed split.

Parameters:
  • new_mask (torch.Tensor) – New boolean mask indicating refinable elements.

  • reset_refinable (bool, optional) – If True, reset refinable parameters to current fixed values. If False, keep existing refinable parameter values where possible. Default is False.
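Here is a list-based sketch of what keeping values "where possible" could look like. This is an assumption, not the documented algorithm. The current full values are rebuilt first. The new refinable vector is then gathered from them, or from fixed_values when reset_refinable=True. The helper name repartition is hypothetical.

```python
# Sketch of repartitioning the refinable/fixed split on 1D lists.
# repartition is a hypothetical helper; the semantics of reset_refinable follow
# the parameter description above.


def repartition(fixed_values, refinable_params, old_mask, new_mask, reset_refinable=False):
    # Current full values: fixed storage overlaid with the refinable entries.
    full = list(fixed_values)
    old_idx = [i for i, m in enumerate(old_mask) if m]
    for k, i in enumerate(old_idx):
        full[i] = refinable_params[k]
    # Either keep the refined values or fall back to the stored fixed values.
    source = list(fixed_values) if reset_refinable else full
    new_params = [source[i] for i, m in enumerate(new_mask) if m]
    return source, new_params


fixed, params = [1.0, 2.0, 3.0], [20.0]  # element 1 has been refined to 20.0
old, new = [False, True, False], [True, True, False]
print(repartition(fixed, params, old, new))                        # ([1.0, 20.0, 3.0], [1.0, 20.0])
print(repartition(fixed, params, old, new, reset_refinable=True))  # ([1.0, 2.0, 3.0], [1.0, 2.0])
```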

detach()

Return a detached copy of the full tensor.

clone()

Create a deep copy of this MixedTensor.

copy()

Create a deep copy of this MixedTensor.

Creates a complete independent copy with all buffers and parameters. Alias for clone().

Returns:

New MixedTensor instance with copied data.

Return type:

MixedTensor

clip(min_value=None, max_value=None)

Clip the full tensor values between min_value and max_value.

to(*args, **kwargs)

Move via DeviceMixin and rebuild the index cache.

refine(selection, reset_values=False)

Make a selection of the tensor refinable.

Parameters:
  • selection (slice, torch.Tensor, or tuple) – Selection indicating which elements should become refinable. Can be:
      - a boolean tensor of the same shape as the full tensor
      - a slice object (e.g., slice(10, 20))
      - a tuple of indices for multidimensional tensors
      - integer indices

  • reset_values (bool, optional) – If True, reset the selected elements to their current fixed values before making them refinable. Default is False.
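The following hypothetical helper shows how the 1D selection forms listed above (slice, boolean mask, integer indices) can all be normalised to a single boolean mask. Tuples for multidimensional tensors are left out. The name selection_to_mask is an assumption and is not part of the API.

```python
# Hypothetical helper that normalises the 1D selection forms (slice, boolean
# mask, integer indices) into a boolean mask of length n.


def selection_to_mask(selection, n):
    if isinstance(selection, slice):
        mask = [False] * n
        for i in range(*selection.indices(n)):
            mask[i] = True
        return mask
    items = list(selection)
    if items and all(isinstance(x, bool) for x in items):
        if len(items) != n:
            raise ValueError("boolean mask must have length n")
        return items  # already a boolean mask
    mask = [False] * n
    for i in items:
        mask[i] = True  # integer index; negative values count from the end
    return mask


print(selection_to_mask(slice(1, 3), 5))  # [False, True, True, False, False]
print(selection_to_mask([0, 4], 5))       # [True, False, False, False, True]
```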

Examples

mixed.refine(slice(10, 20))  # Make elements 10-19 refinable
mixed.refine(mask)  # Make elements where mask is True refinable

fix(selection, freeze_at_current=True)

Make a selection of the tensor fixed (non-refinable).

Parameters:
  • selection (slice, torch.Tensor, or tuple) – Selection indicating which elements should become fixed. Can be:
      - a boolean tensor of the same shape as the full tensor
      - a slice object (e.g., slice(10, 20))
      - a tuple of indices for multidimensional tensors
      - integer indices

  • freeze_at_current (bool, optional) – If True (default), freeze the selected elements at their current values. If False, they revert to the original fixed values.

Examples

mixed.fix(slice(10, 20))  # Fix elements 10-19
mixed.fix(mask)  # Fix elements where mask is True

refine_all()

Make all elements refinable.

fix_all(freeze_at_current=True)

Make all elements fixed.

property name: str | None

Return the name of this parameter.

__str__()

More detailed string representation.

parameters()

Return refinable parameters for optimizer.

class torchref.model.parameter_wrappers.PositiveMixedTensor(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, epsilon=0.1)

Bases: MixedTensor

A MixedTensor subclass ensuring all values are positive via log-space parametrization.

Useful for parameters that must be strictly positive (e.g., B-factors, scale factors, sigma values). Values are stored as logarithms internally and converted to normal space via exp() when accessed.

Reparametrization:

internal_value = log(desired_value)
output_value = exp(internal_value)

This ensures output_value > 0 always, with smooth gradient flow.
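For a single scalar, the mapping above and its effect on gradients can be sketched with the standard library. This is an illustration, not the class's code. The helper names are hypothetical, and the epsilon parameter is not modelled.

```python
import math

# Sketch of log-space reparametrization for a single scalar. The epsilon
# offset documented for this class is deliberately left out.


def to_internal(value):
    if value <= 0:
        raise ValueError("values must be positive")
    return math.log(value)


def to_output(theta):
    return math.exp(theta)


def grad_wrt_theta(grad_output, theta):
    # Chain rule: d exp(theta) / d theta = exp(theta), so the gradient is
    # scaled by the current value itself.
    return grad_output * math.exp(theta)


theta = to_internal(20.0)
print(to_output(theta))         # approximately 20.0
print(to_output(theta - 50.0))  # tiny but still positive
```

One consequence of the chain rule shown in grad_wrt_theta is that small values receive proportionally small gradient steps, so they approach zero without crossing it.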

Parameters:
  • initial_values (torch.Tensor, optional) – Initial tensor values in NORMAL space. Optional for empty init.

  • refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined.

  • requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for the tensor.

  • device (torch.device, optional) – Device for the tensor.

  • name (str, optional) – Optional name for this parameter.

  • epsilon (float, optional) – Small value to add before taking log to avoid log(0). Default is 1e-1.

Examples

Empty initialization for state_dict loading:

b = PositiveMixedTensor()
b.load_state_dict(torch.load('b_factors.pt'))

Full initialization with values:

import torch
from torchref.model.parameter_wrappers import PositiveMixedTensor

initial_b = torch.tensor([20.0, 30.0, 15.0])
b = PositiveMixedTensor(initial_b)
output = b()  # Returns exp(log_b), i.e. positive values
assert (b() > 0).all()

__init__(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, epsilon=0.1)

Initialize a PositiveMixedTensor.

If initial_values is provided, fully initializes the tensor. If not provided (empty init), creates a shell ready for load_state_dict().

Parameters:
  • initial_values (torch.Tensor, optional) – Initial tensor values in NORMAL space. Optional for empty init.

  • refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined.

  • requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for the tensor.

  • device (torch.device, optional) – Device for the tensor.

  • name (str, optional) – Optional name for this parameter.

  • epsilon (float, optional) – Small value to add before taking log to avoid log(0). Default is 1e-1.

Raises:

ValueError – If any initial values are not positive.

forward()

Return the full tensor in NORMAL space.

Applies exponential transformation to the log-space values.

Returns:

Tensor with positive values.

Return type:

torch.Tensor

fix(mask, freeze_at_current=True)

Fix (freeze) specific elements.

Converts current normal-space values to log space for storage.

Parameters:
  • mask (torch.Tensor) – Boolean mask indicating which elements to fix.

  • freeze_at_current (bool, optional) – If True, freeze at current values. Default is True.

refine(mask)

Make specific elements refinable.

Preserves current log-space values.

Parameters:

mask (torch.Tensor) – Boolean mask indicating which elements to make refinable.

set(values, mask)

Set values at positions specified by a boolean mask.

Values are provided in NORMAL space (e.g., actual B-factors) and automatically converted to log-space for internal storage.

Parameters:
  • values (torch.Tensor) – New values to assign in NORMAL space (positive values). Shape must be (n_selected,) where n_selected = mask.sum().

  • mask (torch.Tensor) – Boolean mask of shape (n_atoms,) indicating which elements to update. True positions will receive the new values.

Raises:

ValueError – If mask shape doesn’t match tensor’s first dimension, if values shape doesn’t match the number of selected elements, or if any values are not positive.

Examples

# Update B-factors for selected atoms
mask = model.get_selection_mask("name CA")
new_b = torch.ones(int(mask.sum())) * 30.0  # Set CA B-factors to 30
model.b.set(new_b, mask)

Notes

This method modifies the tensor in-place. Values are automatically converted to log-space internally to maintain the positivity constraint.

get_log_values()

Return the internal log-space representation.

Useful for debugging or when direct access to the parametrization space is needed.

Returns:

Tensor with log-space values.

Return type:

torch.Tensor

update_refinable_mask(new_mask, reset_refinable=False)

Update which elements are refinable.

Properly handles log-space conversion.

Parameters:
  • new_mask (torch.Tensor) – New boolean mask indicating refinable elements.

  • reset_refinable (bool, optional) – If True, reset refinable parameters to current fixed values. If False, keep existing refinable parameter values where possible. Default is False.

copy()

Create a deep copy of this PositiveMixedTensor.

Properly handles the log-space reparametrization.

Returns:

New PositiveMixedTensor instance with copied data.

Return type:

PositiveMixedTensor

__str__()

More detailed string representation.

class torchref.model.parameter_wrappers.CholeskyMixedTensor(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, epsilon=0.001)

Bases: MixedTensor

A MixedTensor for anisotropic ADPs (U tensors) kept positive-definite.

The six U components (u11, u22, u33, u12, u13, u23) are stored internally as the six free parameters of a lower-triangular Cholesky factor L, and the public tensor is reconstructed as U = L Lᵀ. With the diagonal of L mapped through exp(x) + epsilon (strictly positive), U is positive-definite by construction for any value of the free parameters, so unconstrained optimisation (e.g. LBFGS line search) can never drive U indefinite. An indefinite U otherwise makes the per-atom anisotropic B-matrix singular, so its inverse and the Gaussian exponent blow up and the structure-factor FFT returns NaN. This is the anisotropic analogue of PositiveMixedTensor, which keeps isotropic B positive the same way.
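The forward map described above can be sketched for a single atom with plain Python. The ordering of the six free parameters inside L is an assumption made for illustration; the source only fixes the order of the U components. params_to_u and is_positive_definite are hypothetical names.

```python
import math

# Sketch of the six-parameter Cholesky map for one atom. The layout of the free
# parameters (x[0:3] -> diagonal of L, x[3:6] -> l21, l31, l32) is an assumption.


def params_to_u(x, epsilon=1e-3):
    l11, l22, l33 = (math.exp(v) + epsilon for v in x[:3])
    l21, l31, l32 = x[3:6]
    L = [[l11, 0.0, 0.0],
         [l21, l22, 0.0],
         [l31, l32, l33]]
    # U = L L^T, written out for a 3x3 matrix.
    U = [[sum(L[i][k] * L[j][k] for k in range(3)) for j in range(3)] for i in range(3)]
    # Return components in the documented order (u11, u22, u33, u12, u13, u23).
    return (U[0][0], U[1][1], U[2][2], U[0][1], U[0][2], U[1][2])


def is_positive_definite(u):
    u11, u22, u33, u12, u13, u23 = u
    # Sylvester's criterion: every leading principal minor must be positive.
    m1 = u11
    m2 = u11 * u22 - u12 * u12
    m3 = (u11 * (u22 * u33 - u23 * u23)
          - u12 * (u12 * u33 - u23 * u13)
          + u13 * (u12 * u23 - u22 * u13))
    return m1 > 0 and m2 > 0 and m3 > 0


# Arbitrary, even extreme-looking, free parameters still give a valid U.
print(is_positive_definite(params_to_u([-5.0, 2.0, 0.0, 10.0, -7.0, 3.0])))  # True
print(is_positive_definite((1.0, 1.0, 1.0, 2.0, 0.0, 0.0)))                  # False
```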

Rows that are entirely non-finite (isotropic atoms carry U = NaN) are passed through unchanged in both directions, preserving the iso/aniso split.

Notes

The eigen-decomposition / Cholesky needed to map U -> L runs only at construction and on freeze/unfreeze; the forward (hot) path is just exp and a handful of products, so gradients flow cleanly to refinable_params with no matrix factorisation in the autograd graph.
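The inverse map (U to free parameters) at construction time can be sketched with an explicit 3x3 Cholesky factorisation. The sketch assumes a hypothetical layout: x[0:3] map to the diagonal of L through exp(x) + epsilon, and x[3:6] are l21, l31, l32. The eigen-decomposition step mentioned in the Notes is not reproduced, and u_to_params is a hypothetical name.

```python
import math

# Sketch of the inverse map U -> free parameters via a 3x3 Cholesky factorisation.
# Assumed layout: x[0:3] feed the diagonal of L through exp(x) + epsilon and
# x[3:6] are l21, l31, l32. The eigen-decomposition step is not reproduced.


def u_to_params(u, epsilon=1e-3):
    u11, u22, u33, u12, u13, u23 = u
    # math.sqrt raises ValueError on a negative argument, i.e. when U is not positive-definite.
    l11 = math.sqrt(u11)
    l21 = u12 / l11
    l31 = u13 / l11
    l22 = math.sqrt(u22 - l21 * l21)
    l32 = (u23 - l31 * l21) / l22
    l33 = math.sqrt(u33 - l31 * l31 - l32 * l32)
    # Invert diag = exp(x) + epsilon; this needs every diagonal entry of L to exceed epsilon.
    diag = [math.log(l - epsilon) for l in (l11, l22, l33)]
    return diag + [l21, l31, l32]


# U built from L = [[2, 0, 0], [1, 3, 0], [0.5, -1, 1.5]]
print(u_to_params((4.0, 10.0, 3.5, 2.0, 1.0, -2.5), epsilon=0.0))
```

Because this factorisation runs only at construction and on freeze/unfreeze, its square roots and divisions never enter the autograd graph of the forward pass.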

__init__(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, epsilon=0.001)

Initialize a CholeskyMixedTensor.

If initial_values is provided, fully initializes the tensor. If not provided (empty init), creates a shell ready for load_state_dict().

Parameters:
  • initial_values (torch.Tensor, optional) – Initial tensor values for all elements. Optional for empty init.

  • refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined. If None, all elements are refinable.

  • requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for the tensor. Defaults to the dtype of initial_values.

  • device (torch.device, optional) – Device for the tensor. Defaults to the device of initial_values.

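  • name (str, optional) – Optional name for this parameter (useful for debugging/logging).

  • epsilon (float, optional) – Positive offset added to the diagonal of the Cholesky factor L, whose diagonal entries are computed as exp(x) + epsilon. Default is 0.001.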

forward()

Return the full U tensor (positive-definite per finite row).

fix(mask, freeze_at_current=True)

Freeze rows, storing their current value in Cholesky space.

refine(mask)

Make rows refinable, preserving their current value in Cholesky space.

set(values, mask)

Set U-space values for masked rows (converted to Cholesky internally).

update_refinable_mask(new_mask, reset_refinable=False)

Repartition refinable/fixed elements, preserving values in U space.

The base implementation writes the forward() output directly back into storage. Here that would double-transform, because U values would be written into Cholesky-parameter storage. This override therefore converts them to Cholesky parameters first, mirroring PositiveMixedTensor.update_refinable_mask().
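The pitfall is easiest to see in the one-dimensional log-space case used by PositiveMixedTensor. This sketch is illustrative only. The same reasoning applies to the Cholesky map, where U values would be misread as Cholesky parameters.

```python
import math

# Illustration of the double-transform pitfall, using the simpler log-space case
# (internal theta, public value exp(theta)) rather than the Cholesky map.
theta = math.log(2.0)     # internal storage for a public value of 2.0
public = math.exp(theta)  # what forward() would return, about 2.0

# Wrong: store the public value as if it were already in parameter space.
wrong_theta = public
print(math.exp(wrong_theta))  # about 7.389 (exp(2.0)), not 2.0

# Right: convert back to parameter space before storing.
right_theta = math.log(public)
print(math.exp(right_theta))  # about 2.0
```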

copy()

Deep-copy, preserving the Cholesky parametrization.

Rebuilds from the U-space values (__init__ reconverts to Cholesky parameters), so the copy stays positive-definite rather than degrading to a plain unconstrained MixedTensor.

class torchref.model.parameter_wrappers.OccupancyTensor(initial_values=None, sharing_groups=None, altloc_groups=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, use_sigmoid=True)

Bases: MixedTensor

A specialized MixedTensor for handling occupancy parameters in crystallographic refinement.

Handles specific constraints and requirements for occupancy including value bounds [0, 1] via sigmoid reparameterization, atom sharing, and alternative conformation sum-to-1 constraints.

Features:
  • Values bounded between 0 and 1 using sigmoid reparameterization
  • Atoms can share occupancies (e.g., all atoms in a residue)
  • Alternative conformations automatically sum to 1.0 via normalization
  • Memory-efficient collapsed storage (one parameter per sharing group)
  • Fully vectorized collapse/expand operations

Parameters:
  • initial_values (torch.Tensor, optional) – Initial occupancy values for ALL atoms (should be in [0, 1]). Optional for empty init.

  • sharing_groups (torch.Tensor, optional) – Tensor of shape (n_atoms,) where each value is the collapsed index for that atom. If None, each atom has independent occupancy.

  • altloc_groups (list of tuple, optional) – List of tuples of atom index lists representing alternative conformations. Each tuple contains the atom indices for each conformation.

  • refinable_mask (torch.Tensor, optional) – Boolean mask for which ATOMS can be refined (in full tensor space).

  • requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for the tensor.

  • device (torch.device, optional) – Device for the tensor.

  • name (str, optional) – Optional name for this parameter.

  • use_sigmoid (bool, optional) – If True, use sigmoid parameterization to bound values to [0,1]. Default is True.

expansion_mask

Maps atoms to collapsed indices.

Type:

torch.Tensor

linked_occ_sizes

List of altloc group sizes present.

Type:

list

collapse_counts

Count of atoms per collapsed index.

Type:

torch.Tensor
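The relationship between expansion_mask, collapse_counts and the collapsed storage can be sketched on plain lists. Averaging the atoms of a group on collapse is an assumption. The source only states that collapse_counts holds the number of atoms per collapsed index. expand and collapse are hypothetical helper names.

```python
# Sketch of collapsed storage for shared occupancies, on plain lists.
# expansion_mask maps each atom to its group; averaging on collapse is an
# assumption, not documented behaviour.


def expand(collapsed, expansion_mask):
    # Every atom reads the value stored for its sharing group.
    return [collapsed[g] for g in expansion_mask]


def collapse(full, expansion_mask, n_groups):
    sums = [0.0] * n_groups
    collapse_counts = [0] * n_groups
    for value, g in zip(full, expansion_mask):
        sums[g] += value
        collapse_counts[g] += 1
    return [s / c for s, c in zip(sums, collapse_counts)], collapse_counts


expansion_mask = [0, 0, 1, 1, 2, 2]
collapsed, counts = collapse([1.0, 1.0, 0.7, 0.7, 0.3, 0.3], expansion_mask, 3)
print(collapsed, counts)                  # [1.0, 0.7, 0.3] [2, 2, 2]
print(expand(collapsed, expansion_mask))  # [1.0, 1.0, 0.7, 0.7, 0.3, 0.3]
```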

Examples

import torch
from torchref.model.parameter_wrappers import OccupancyTensor

sharing_groups = torch.tensor([0, 0, 1, 1, 2, 2])
occ = OccupancyTensor(
    initial_values=torch.tensor([1.0, 1.0, 0.7, 0.7, 0.3, 0.3]),
    sharing_groups=sharing_groups,
    altloc_groups=[([2, 3], [4, 5])],
)
result = occ()  # Conformers A (atoms 2-3) and B (atoms 4-5): occupancies sum to 1.0

__init__(initial_values=None, sharing_groups=None, altloc_groups=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, use_sigmoid=True)

Initialize an OccupancyTensor with collapsed storage and altloc support.

If initial_values is provided, fully initializes the tensor. If not provided (empty init), creates a shell ready for load_state_dict().

Parameters:
  • initial_values (torch.Tensor, optional) – Initial occupancy values for ALL atoms (should be in [0, 1]). Optional for empty init.

  • sharing_groups (torch.Tensor, optional) – Tensor of shape (n_atoms,) where each value is the collapsed index for that atom. If None, each atom has independent occupancy. Example: tensor([0, 0, 0, 1, 1, 2]) means atoms 0,1,2 share one occupancy, atoms 3,4 share another, and atom 5 is independent.

  • altloc_groups (list of tuple, optional) – List of tuples of atom index lists representing alternative conformations. Example: [([10,11], [12,13])] means atoms 10,11 (conf A) and 12,13 (conf B) are altlocs that sum to 1.0.

  • refinable_mask (torch.Tensor, optional) – Boolean mask for which ATOMS can be refined (in full tensor space). If any atom in a group is refinable, the entire group becomes refinable.

  • requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for the tensor.

  • device (torch.device, optional) – Device for the tensor.

  • name (str, optional) – Optional name for this parameter.

  • use_sigmoid (bool, optional) – If True, use sigmoid parameterization to bound values to [0,1]. Default is True.

forward()

Reconstruct full occupancy tensor with sigmoid and altloc constraints.

For alternative conformations, applies sigmoid then normalizes within each group to enforce sum-to-1 constraint.

Returns:

Full occupancy tensor with values in [0, 1] and shape (n_atoms,).

Return type:

torch.Tensor
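Here is a list-based sketch of the forward pass described above. It applies a sigmoid per collapsed group, expands to atoms, and then normalises across each altloc set. It assumes that each conformer's atoms share one sharing group, as in the class example, so the first atom of a conformer stands for all of them. sigmoid and occupancies are hypothetical helpers.

```python
import math

# Sketch of the occupancy forward pass on plain lists: sigmoid per collapsed
# group, expansion to atoms, then normalisation across each altloc set.
# Assumes each conformer's atoms share one group, so its first atom represents it.


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def occupancies(logits, expansion_mask, altloc_groups):
    group_occ = [sigmoid(z) for z in logits]
    full = [group_occ[g] for g in expansion_mask]
    for conformers in altloc_groups:
        raw = [full[atoms[0]] for atoms in conformers]
        total = sum(raw)
        for atoms, r in zip(conformers, raw):
            for i in atoms:
                full[i] = r / total  # conformer occupancies now sum to 1
    return full


occ = occupancies([5.0, 2.0, -1.0], [0, 0, 1, 1, 2, 2], [([2, 3], [4, 5])])
print(occ[2] + occ[4])  # approximately 1.0
```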

property shape

Return the shape of the FULL tensor (not collapsed).

property collapsed_shape

Return the shape of the collapsed internal storage.

clamp(min_value=0.0, max_value=1.0)

Clamp occupancy values to specified range and return a new OccupancyTensor.

Parameters:
  • min_value (float, optional) – Minimum occupancy value. Default is 0.0.

  • max_value (float, optional) – Maximum occupancy value. Default is 1.0.

Returns:

New OccupancyTensor with clamped values.

Return type:

OccupancyTensor

set_group_occupancy(group_idx, value)

Set the occupancy for all atoms in a specific collapsed group.

Parameters:
  • group_idx (int) – Collapsed index of the group.

  • value (float) – Occupancy value to set (must be in [0, 1]).

Raises:

ValueError – If group_idx is out of range or value is not in [0, 1].

get_group_occupancy(group_idx)

Get the current occupancy value for a collapsed group.

Parameters:

group_idx (int) – Collapsed index of the group.

Returns:

Current occupancy value for the group.

Return type:

float

Raises:

ValueError – If group_idx is out of range.

freeze(mask=None)

Freeze occupancy parameters, making them non-refinable.

The mask is supplied in UNCOMPRESSED (full atom) form but freezing operates on the COMPRESSED data structure. This method handles the conversion.

Parameters:

mask (torch.Tensor, optional) – Boolean mask in FULL (uncompressed) atom space indicating which atoms to freeze. If None, freeze all parameters. Shape must be (n_atoms,).

Notes

If ANY atom in a sharing group is frozen, the ENTIRE group is frozen because all atoms in a group share the same compressed parameter.
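The any-atom rule can be sketched as a conversion between full-atom and compressed (per-group) masks. This is an illustration on plain lists. The helper names are hypothetical, and the real methods operate on torch tensors.

```python
# Sketch of converting a full-atom mask to the compressed (per-group) space
# with the "any atom selects the whole group" rule. Helper names are hypothetical.


def atom_mask_to_group_mask(atom_mask, expansion_mask, n_groups):
    group_mask = [False] * n_groups
    for flag, g in zip(atom_mask, expansion_mask):
        if flag:
            group_mask[g] = True
    return group_mask


def group_mask_to_atom_mask(group_mask, expansion_mask):
    return [group_mask[g] for g in expansion_mask]


expansion_mask = [0, 0, 1, 1, 2]
freeze_groups = atom_mask_to_group_mask([False, True, False, False, False], expansion_mask, 3)
refinable = [r and not f for r, f in zip([True, True, True], freeze_groups)]
print(refinable)                                           # [False, True, True]
print(group_mask_to_atom_mask(refinable, expansion_mask))  # [False, False, True, True, True]
```

Freezing only atom 1 freezes atom 0 as well, because both read the same compressed parameter.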

Examples

# Freeze atoms 0-10 (in full atom space)
freeze_mask = torch.zeros(n_atoms, dtype=torch.bool)
freeze_mask[0:11] = True
occ.freeze(freeze_mask)
# Freeze all atoms
occ.freeze()

unfreeze(mask=None)

Unfreeze occupancy parameters, making them refinable.

The mask is supplied in UNCOMPRESSED (full atom) form but unfreezing operates on the COMPRESSED data structure. This method handles the conversion.

Parameters:

mask (torch.Tensor, optional) – Boolean mask in FULL (uncompressed) atom space indicating which atoms to unfreeze. If None, unfreeze all parameters. Shape must be (n_atoms,).

Notes

If ANY atom in a sharing group is unfrozen, the ENTIRE group becomes refinable because all atoms in a group share the same compressed parameter.

Examples

# Unfreeze atoms 100-200 (in full atom space)
unfreeze_mask = torch.zeros(n_atoms, dtype=torch.bool)
unfreeze_mask[100:201] = True
occ.unfreeze(unfreeze_mask)
# Unfreeze all atoms
occ.unfreeze()

freeze_all()

Freeze all occupancy parameters.

Convenience method equivalent to freeze(None).

unfreeze_all()

Unfreeze all occupancy parameters.

Convenience method equivalent to unfreeze(None).

get_refinable_atoms()

Get a boolean mask in FULL atom space indicating refinable atoms.

Returns:

Boolean tensor of shape (n_atoms,) where True indicates that the atom’s occupancy is refinable (the value itself is shared with the other atoms in its sharing group).

Return type:

torch.Tensor

get_frozen_atoms()

Get a boolean mask in FULL atom space indicating frozen atoms.

Returns:

Boolean tensor of shape (n_atoms,) where True indicates the atom’s occupancy is frozen.

Return type:

torch.Tensor

get_refinable_count()

Get the number of refinable parameters in COMPRESSED space.

This is the number of refinable groups, not atoms. Use get_refinable_atoms().sum() to get the number of refinable atoms.

Returns:

Number of refinable compressed parameters.

Return type:

int

get_fixed_count()

Get the number of fixed parameters in COMPRESSED space.

This is the number of fixed groups, not atoms. Use get_frozen_atoms().sum() to get the number of frozen atoms.

Returns:

Number of fixed compressed parameters.

Return type:

int

update_refinable_mask(new_mask, in_compressed_space=False)

Directly update the refinable mask with a new mask.

Allows more direct control over which parameters are refinable, compared to freeze/unfreeze which modify the existing state.

Parameters:
  • new_mask (torch.Tensor) – Boolean tensor indicating which parameters should be refinable:
      - if in_compressed_space=False: shape (n_atoms,), in full atom space
      - if in_compressed_space=True: shape (n_groups,), in compressed space

  • in_compressed_space (bool, optional) – If True, new_mask is in compressed space. If False (default), new_mask is in full atom space and will be collapsed.

Examples

Full atom space:

atom_mask = torch.zeros(n_atoms, dtype=torch.bool)
atom_mask[:100] = True
occ.update_refinable_mask(atom_mask, in_compressed_space=False)

Compressed space:

group_mask = torch.zeros(n_groups, dtype=torch.bool)
group_mask[::2] = True
occ.update_refinable_mask(group_mask, in_compressed_space=True)

static from_residue_groups(initial_values, pdb_dataframe, refinable_mask=None, **kwargs)

Create an OccupancyTensor where all atoms in each residue share occupancy.

Common use case where all atoms in a residue should have the same occupancy.

Parameters:
  • initial_values (torch.Tensor) – Initial occupancy values for all atoms.

  • pdb_dataframe (pandas.DataFrame) – DataFrame with PDB data (must have ‘resname’, ‘resseq’, ‘chainid’).

  • refinable_mask (torch.Tensor, optional) – Mask for refinable atoms.

  • **kwargs – Additional arguments passed to OccupancyTensor constructor.

Returns:

OccupancyTensor with residue-based sharing groups.

Return type:

OccupancyTensor
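Here is a sketch of how residue-based sharing groups can be derived, using a list of dicts in place of the pandas DataFrame. The grouping key (chainid, resseq, resname) follows the required columns listed above. Numbering groups in first-seen order and the name residue_sharing_groups are assumptions.

```python
# Sketch of residue-based sharing groups, using a list of dicts in place of
# the pandas DataFrame. First-seen group numbering is an assumption.


def residue_sharing_groups(atoms):
    group_of = {}
    sharing_groups = []
    for atom in atoms:
        key = (atom["chainid"], atom["resseq"], atom["resname"])
        if key not in group_of:
            group_of[key] = len(group_of)
        sharing_groups.append(group_of[key])
    return sharing_groups


atoms = [
    {"chainid": "A", "resseq": 1, "resname": "ALA"},
    {"chainid": "A", "resseq": 1, "resname": "ALA"},
    {"chainid": "A", "resseq": 2, "resname": "GLY"},
    {"chainid": "B", "resseq": 1, "resname": "ALA"},
]
print(residue_sharing_groups(atoms))  # [0, 0, 1, 2]
```

With pandas, df.groupby(["chainid", "resseq", "resname"], sort=False).ngroup() is one way to obtain a comparable first-seen numbering directly from the DataFrame.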

copy()

Create a deep copy of this OccupancyTensor.

Creates a complete independent copy with all buffers and parameters, including sharing groups, altloc groups, and collapsed storage structures.

Returns:

New OccupancyTensor instance with copied data.

Return type:

OccupancyTensor

class torchref.model.parameter_wrappers.PassThroughTensor(initial_values, requires_grad=True, dtype=None, device=None, name=None)

Bases: DeviceMixin, Module

A simple parameter wrapper that passes the parameter through unchanged.

Useful as a placeholder or for parameters that do not require any special handling.

Parameters:
  • initial_values (torch.Tensor) – Initial tensor values.

  • requires_grad (bool, optional) – Whether the parameter requires gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type of the tensor.

  • device (torch.device, optional) – Device to place the tensor on.

  • name (str, optional) – Optional name for the parameter.

__init__(initial_values, requires_grad=True, dtype=None, device=None, name=None)

Initialize the PassThroughTensor.

Parameters:
  • initial_values (torch.Tensor) – Initial tensor values.

  • requires_grad (bool, optional) – Whether the parameter requires gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type of the tensor.

  • device (torch.device, optional) – Device to place the tensor on.

  • name (str, optional) – Optional name for the parameter.

forward()

Return the parameter value unchanged.

Returns:

The parameter tensor.

Return type:

torch.Tensor