torchref.model.parameter_wrappers module
Wrapper classes for handling crystallographic parameters (occupancies, coordinates, B-factors, etc.).
- class torchref.model.parameter_wrappers.MixedTensor(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None)
Bases: DeviceMixin, CachedForwardMixin, Module
A wrapper class for tensors with mixed fixed and refinable elements.
Stores a mask indicating which elements can be refined and maintains both fixed and refinable components separately. The full tensor is reconstructed on-the-fly when accessed.
- Parameters:
initial_values (torch.Tensor, optional) – Initial tensor values for all elements. Optional for empty init.
refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined. If None, all elements are refinable.
requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for the tensor. Default is same as initial_values.
device (torch.device, optional) – Device for the tensor. Default is same as initial_values.
name (str, optional) – Optional name for this parameter (useful for debugging/logging).
- refinable_mask
Boolean mask indicating refinable elements.
- Type:
torch.Tensor
- fixed_mask
Boolean mask indicating fixed elements (inverse of refinable_mask).
- Type:
torch.Tensor
- fixed_values
Buffer containing fixed values.
- Type:
torch.Tensor
- refinable_params
Parameter containing refinable values.
- Type:
nn.Parameter
Examples
Empty initialization for state_dict loading:
mixed = MixedTensor()
mixed.load_state_dict(torch.load('mixed.pt'))
Full initialization with values:
mask = torch.zeros(100, dtype=torch.bool)
mask[20:30] = True
initial_values = torch.randn(100)
mixed = MixedTensor(initial_values, refinable_mask=mask, requires_grad=True)
optimizer = torch.optim.Adam([mixed.refinable_params], lr=0.01)
- __init__(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None)
Initialize a MixedTensor.
If initial_values is provided, fully initializes the tensor. If not provided (empty init), creates a shell ready for load_state_dict().
- Parameters:
initial_values (torch.Tensor, optional) – Initial tensor values for all elements. Optional for empty init.
refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined. If None, all elements are refinable.
requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for the tensor. Default is same as initial_values.
device (torch.device, optional) – Device for the tensor. Default is same as initial_values.
name (str, optional) – Optional name for this parameter (useful for debugging/logging).
- forward()
Reconstruct and return the full tensor.
Three code paths, checked in priority order:
All atoms refinable: return refinable_params directly (no clone, no scatter). This is the common case in standard refinement, where no atoms are frozen. It saves a full-tensor clone plus an index_put_ per call, and replaces the index_put backward (sort + atomic scatter) with an identity.
No refinable atoms: return fixed_values.clone(). The clone preserves the caller-must-not-mutate contract even though the result is detached from autograd.
Mixed: go through _AssembleMixedTensor, whose backward is a single index_select (gather) instead of PyTorch's default index_put_ backward (radix sort + scatter).
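The three paths can be sketched with plain PyTorch. This is a minimal sketch, assuming the refinable mask selects along the first dimension and that refinable_params holds the refinable rows in order; AssembleMixed and assemble are hypothetical names, not the torchref implementation of _AssembleMixedTensor.

```python
import torch


class AssembleMixed(torch.autograd.Function):
    """Hypothetical stand-in for _AssembleMixedTensor, not the torchref code.

    Forward scatters the refinable rows into a copy of the fixed buffer;
    backward gathers the incoming gradient with a single index_select.
    """

    @staticmethod
    def forward(ctx, refinable, fixed, idx):
        out = fixed.clone()
        out[idx] = refinable  # scatter refinable rows into place
        ctx.save_for_backward(idx)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        (idx,) = ctx.saved_tensors
        # out[idx] is a copy of refinable, so its gradient is a plain gather.
        return grad_out.index_select(0, idx), None, None


def assemble(refinable, fixed, mask):
    """Three-path reconstruction, assuming `mask` selects along dimension 0."""
    if bool(mask.all()):
        return refinable  # all refinable: no clone, no scatter
    if not bool(mask.any()):
        return fixed.clone()  # nothing refinable: copy without autograd history
    idx = mask.nonzero(as_tuple=True)[0]
    return AssembleMixed.apply(refinable, fixed, idx)


fixed = torch.zeros(5)
refinable = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
mask = torch.tensor([False, True, False, True, False])

full = assemble(refinable, fixed, mask)
(full * torch.arange(5.0)).sum().backward()
print(full.tolist())            # [0.0, 1.0, 0.0, 2.0, 0.0]
print(refinable.grad.tolist())  # [1.0, 3.0]
```

Because out[idx] is just a copy of refinable, the backward reduces to a gather, which is the property the docstring attributes to _AssembleMixedTensor.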
- __getitem__(key)
Get values at specified indices/mask from the full tensor.
- Parameters:
key (int, slice, torch.Tensor, or tuple) – Index specification. Can be:
  - int: Single element
  - slice: Range of elements (e.g., 5:10, :, ::2)
  - torch.Tensor: Boolean mask or integer indices
  - tuple: Multi-dimensional indexing
- Returns:
Selected values from the full tensor.
- Return type:
torch.Tensor
Examples
model.b[5]       # Get B-factor for atom 5
model.b[5:10]    # Get B-factors for atoms 5-9
model.b[mask]    # Get B-factors where mask is True
model.xyz[:, 0]  # Get all x-coordinates
Notes
Subclasses may override _get_values() to customize value retrieval.
- __setitem__(key, value)
Set values at specified indices/mask.
This method updates both fixed_values and refinable_params at the specified positions. Supports various indexing styles including slices, boolean masks, and integer indices.
- Parameters:
key (int, slice, torch.Tensor, or tuple) – Index specification. Can be:
  - int: Single element
  - slice: Range of elements (e.g., 5:10, :, ::2)
  - torch.Tensor: Boolean mask or integer indices
  - tuple: Multi-dimensional indexing
value (torch.Tensor, float, or int) – Values to assign. Can be:
  - Scalar: Broadcast to all selected positions
  - Tensor: Must match the shape of the selected region
Examples
model.b[:] = 30.0             # Set all B-factors to 30
model.b[5:10] = 25.0          # Set B-factors 5-9 to 25
model.b[mask] = new_values    # Set B-factors where mask is True
model.xyz[mask] = new_coords  # Set coordinates for masked atoms
model.xyz[:, 0] += 1.0        # Shift all x-coordinates (read-modify-write)
Notes
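This method modifies the tensor in place. The refinable_params parameter is replaced with a new Parameter containing the updated values, so an optimizer constructed beforehand may still reference the old Parameter and its state; re-create the optimizer afterwards if needed.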
Subclasses may override _set_values() to customize value handling (e.g., PositiveMixedTensor converts to log-space).
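The optimizer caveat can be illustrated with plain PyTorch. Holder is a hypothetical module that replaces its Parameter the way the Notes describe; it is not MixedTensor itself, and whether torchref carries optimizer state across the replacement is not covered here.

```python
import torch
from torch import nn


class Holder(nn.Module):
    """Hypothetical module that swaps in a new Parameter on assignment,
    mirroring what the Notes describe for MixedTensor.__setitem__."""

    def __init__(self, n):
        super().__init__()
        self.refinable_params = nn.Parameter(torch.zeros(n))

    def assign(self, idx, value):
        data = self.refinable_params.detach().clone()
        data[idx] = value
        # Assigning a new nn.Parameter replaces the registered parameter object.
        self.refinable_params = nn.Parameter(data)


h = Holder(3)
opt = torch.optim.Adam([h.refinable_params], lr=0.01)
h.assign(0, 5.0)

# The optimizer still holds the old Parameter, so it would not update the new one.
stale = opt.param_groups[0]["params"][0] is not h.refinable_params
print(stale)  # True

# Rebuilding the optimizer after the edit points it at the current Parameter.
opt = torch.optim.Adam([h.refinable_params], lr=0.01)
```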
- set(values, mask)
Set values at positions specified by a boolean mask.
Updates both fixed_values and refinable_params at the positions specified by the mask. This is useful for applying coordinate shifts, B-factor corrections, or any other updates to specific atoms.
- Parameters:
values (torch.Tensor) –
New values to assign. Shape must match:
  - For 1D tensors: (n_selected,) where n_selected = mask.sum()
  - For 2D tensors (e.g., xyz): (n_selected, d) where d is the second dimension size (e.g., 3 for coordinates)
mask (torch.Tensor) – Boolean mask of shape (n_atoms,) indicating which elements to update. True positions will receive the new values.
- Raises:
ValueError – If mask shape doesn’t match tensor’s first dimension, or if values shape doesn’t match the number of selected elements.
Examples
# Update coordinates for selected atoms
mask = model.get_selection_mask("chain A")
shift = torch.tensor([0.5, 0.0, 0.0])
new_coords = model.xyz[mask] + shift
model.xyz.set(new_coords, mask)
# Update B-factors for specific residues
mask = model.get_selection_mask("resseq 10:20")
new_b = torch.ones(mask.sum()) * 30.0
model.b.set(new_b, mask)
Notes
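This method modifies the tensor in place. The refinable_params parameter is replaced with a new Parameter containing the updated values, so an optimizer constructed beforehand may still reference the old Parameter and its state; re-create the optimizer afterwards if needed.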
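The shape rules for set() can be expressed as a few explicit checks. masked_set is a hypothetical helper that returns a new tensor instead of modifying a MixedTensor in place; it only mirrors the documented ValueError conditions.

```python
import torch


def masked_set(full, values, mask):
    """Hypothetical sketch of the checks set() documents; returns a new tensor."""
    if mask.dtype != torch.bool or mask.shape != full.shape[:1]:
        raise ValueError(f"mask must be a bool tensor of shape ({full.shape[0]},)")
    n_selected = int(mask.sum())
    # 1D: (n_selected,); 2D such as xyz: (n_selected, d)
    expected = (n_selected, *full.shape[1:])
    if tuple(values.shape) != expected:
        raise ValueError(f"values has shape {tuple(values.shape)}, expected {expected}")
    out = full.clone()
    out[mask] = values.to(out.dtype)
    return out


xyz = torch.zeros(4, 3)
mask = torch.tensor([True, False, True, False])
shift = torch.tensor([0.5, 0.0, 0.0])
new_xyz = masked_set(xyz, xyz[mask] + shift, mask)
print(new_xyz[:, 0].tolist())  # [0.5, 0.0, 0.5, 0.0]
```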
- property shape
Return the shape of the full tensor.
- property dtype
Return the dtype of the tensor.
- property device
Return the device of the tensor.
- update_fixed_values(new_values)
Update the fixed values (does not affect refinable parameters).
- Parameters:
new_values (torch.Tensor) – New tensor values. Only fixed positions will be updated.
- Raises:
ValueError – If new_values shape doesn’t match tensor shape.
- update_refinable_mask(new_mask, reset_refinable=False)
Update which elements are refinable.
This is an advanced operation that modifies the refinable/fixed split.
- Parameters:
new_mask (torch.Tensor) – New boolean mask indicating refinable elements.
reset_refinable (bool, optional) – If True, reset refinable parameters to current fixed values. If False, keep existing refinable parameter values where possible. Default is False.
- copy()
Create a deep copy of this MixedTensor.
Creates a complete independent copy with all buffers and parameters. Alias for clone().
- Returns:
New MixedTensor instance with copied data.
- Return type:
MixedTensor
- clip(min_value=None, max_value=None)
Clip the full tensor values between min_value and max_value.
- refine(selection, reset_values=False)
Make a selection of the tensor refinable.
- Parameters:
selection (slice, torch.Tensor, or tuple) – Selection indicating which elements should become refinable. Can be:
  - Boolean tensor of same shape as the full tensor
  - Slice object (e.g., slice(10, 20))
  - Tuple of indices for multidimensional tensors
  - Integer indices
reset_values (bool, optional) – If True, reset the selected elements to their current fixed values before making them refinable. Default is False.
Examples
mixed.refine(slice(10, 20))  # Make elements 10-19 refinable
mixed.refine(mask)           # Make elements where mask is True refinable
- fix(selection, freeze_at_current=True)
Make a selection of the tensor fixed (non-refinable).
- Parameters:
selection (slice, torch.Tensor, or tuple) – Selection indicating which elements should become fixed. Can be:
  - Boolean tensor of same shape as the full tensor
  - Slice object (e.g., slice(10, 20))
  - Tuple of indices for multidimensional tensors
  - Integer indices
freeze_at_current (bool, optional) – If True (default), freeze the selected elements at their current values. If False, they revert to the original fixed values.
Examples
mixed.fix(slice(10, 20))  # Fix elements 10-19
mixed.fix(mask)           # Fix elements where mask is True
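How refine() and fix() change the refinable/fixed split can be sketched on plain tensors. This assumes, hypothetically, that fixed_values is a full-size buffer and that selections index the first dimension; refine and fix below are illustrative functions, not the MixedTensor methods.

```python
import torch


def refine(refinable_mask, selection):
    """Hypothetical: mark `selection` (slice, bool mask or indices) as refinable."""
    new_mask = refinable_mask.clone()
    new_mask[selection] = True
    return new_mask


def fix(refinable_mask, fixed_values, current, selection, freeze_at_current=True):
    """Hypothetical: mark `selection` as fixed.

    With freeze_at_current=True the current (possibly refined) values are copied
    into the fixed buffer; otherwise the buffer keeps its original entries.
    """
    new_mask = refinable_mask.clone()
    new_mask[selection] = False
    new_fixed = fixed_values.clone()
    if freeze_at_current:
        new_fixed[selection] = current[selection]
    return new_mask, new_fixed


mask = torch.zeros(6, dtype=torch.bool)
fixed_values = torch.arange(6.0)
mask = refine(mask, slice(2, 4))  # elements 2-3 become refinable
current = fixed_values.clone()
current[mask] += 10.0             # pretend refinement moved them
mask, fixed_values = fix(mask, fixed_values, current, slice(2, 3))
print(mask.tolist())          # [False, False, False, True, False, False]
print(fixed_values.tolist())  # [0.0, 1.0, 12.0, 3.0, 4.0, 5.0]
```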
- class torchref.model.parameter_wrappers.PositiveMixedTensor(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, epsilon=0.1)
Bases: MixedTensor
A MixedTensor subclass ensuring all values are positive via log-space parametrization.
Useful for parameters that must be strictly positive (e.g., B-factors, scale factors, sigma values). Values are stored as logarithms internally and converted to normal space via exp() when accessed.
Reparametrization:
internal_value = log(desired_value)
output_value = exp(internal_value)
This ensures output_value > 0 always, with smooth gradient flow.
- Parameters:
initial_values (torch.Tensor, optional) – Initial tensor values in NORMAL space. Optional for empty init.
refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined.
requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for the tensor.
device (torch.device, optional) – Device for the tensor.
name (str, optional) – Optional name for this parameter.
epsilon (float, optional) – Small value to add before taking log to avoid log(0). Default is 1e-1.
Examples
Empty initialization for state_dict loading:
b = PositiveMixedTensor()
b.load_state_dict(torch.load('b_factors.pt'))
Full initialization with values:
initial_b = torch.tensor([20.0, 30.0, 15.0])
b = PositiveMixedTensor(initial_b)
output = b()  # Returns exp(log_b), i.e. positive values
assert (b() > 0).all()
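The log-space reparametrization used by PositiveMixedTensor can be illustrated without torchref. This sketch omits the epsilon offset (its exact placement is not spelled out here) and uses hypothetical names. It only shows that a gradient step taken on the log values cannot produce a negative output.

```python
import torch

# Hypothetical standalone illustration; the epsilon offset is omitted.
b_initial = torch.tensor([20.0, 30.0, 15.0])
log_b = torch.nn.Parameter(torch.log(b_initial))  # internal_value = log(desired_value)


def b_values():
    return torch.exp(log_b)  # output_value = exp(internal_value)


# A loss that pulls every B towards -50, i.e. into the forbidden region.
loss = (b_values() + 50.0).pow(2).sum()
loss.backward()
with torch.no_grad():
    log_b -= 1e-3 * log_b.grad  # one plain gradient step, taken in log space

print(b_values().detach())  # all three entries shrink but remain > 0
```

The gradient reaching log_b is the normal-space gradient multiplied by exp(log_b), so steps naturally shrink as a value approaches zero.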
- __init__(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, epsilon=0.1)
Initialize a PositiveMixedTensor.
If initial_values is provided, fully initializes the tensor. If not provided (empty init), creates a shell ready for load_state_dict().
- Parameters:
initial_values (torch.Tensor, optional) – Initial tensor values in NORMAL space. Optional for empty init.
refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined.
requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for the tensor.
device (torch.device, optional) – Device for the tensor.
name (str, optional) – Optional name for this parameter.
epsilon (float, optional) – Small value to add before taking log to avoid log(0). Default is 1e-1.
- Raises:
ValueError – If any initial values are not positive.
- forward()
Return the full tensor in NORMAL space.
Applies exponential transformation to the log-space values.
- Returns:
Tensor with positive values.
- Return type:
torch.Tensor
- fix(mask, freeze_at_current=True)
Fix (freeze) specific elements.
Converts current normal-space values to log space for storage.
- Parameters:
mask (torch.Tensor) – Boolean mask indicating which elements to fix.
freeze_at_current (bool, optional) – If True, freeze at current values. Default is True.
- refine(mask)
Make specific elements refinable.
Preserves current log-space values.
- Parameters:
mask (torch.Tensor) – Boolean mask indicating which elements to make refinable.
- set(values, mask)
Set values at positions specified by a boolean mask.
Values are provided in NORMAL space (e.g., actual B-factors) and automatically converted to log-space for internal storage.
- Parameters:
values (torch.Tensor) – New values to assign in NORMAL space (positive values). Shape must be (n_selected,) where n_selected = mask.sum().
mask (torch.Tensor) – Boolean mask of shape (n_atoms,) indicating which elements to update. True positions will receive the new values.
- Raises:
ValueError – If mask shape doesn’t match tensor’s first dimension, if values shape doesn’t match the number of selected elements, or if any values are not positive.
Examples
# Update B-factors for selected atoms
mask = model.get_selection_mask("name CA")
new_b = torch.ones(mask.sum()) * 30.0  # Set CA B-factors to 30
model.b.set(new_b, mask)
Notes
This method modifies the tensor in-place. Values are automatically converted to log-space internally to maintain the positivity constraint.
- get_log_values()
Return the internal log-space representation.
Useful for debugging or when direct access to the parametrization space is needed.
- Returns:
Tensor with log-space values.
- Return type:
torch.Tensor
- update_refinable_mask(new_mask, reset_refinable=False)
Update which elements are refinable.
Properly handles log-space conversion.
- Parameters:
new_mask (torch.Tensor) – New boolean mask indicating refinable elements.
reset_refinable (bool, optional) – If True, reset refinable parameters to current fixed values. If False, keep existing refinable parameter values where possible. Default is False.
- class torchref.model.parameter_wrappers.CholeskyMixedTensor(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, epsilon=0.001)
Bases: MixedTensor
A MixedTensor for anisotropic ADPs (U tensors) kept positive-definite.
The six U components (u11, u22, u33, u12, u13, u23) are stored internally as the six free parameters of a lower-triangular Cholesky factor L, and the public tensor is reconstructed as U = L Lᵀ. Because the diagonal of L is mapped through exp(x) + epsilon (strictly positive), U is positive-definite by construction for any value of the free parameters, so unconstrained optimisation (e.g. an LBFGS line search) can never drive U indefinite. Without this guarantee, an indefinite U makes the per-atom anisotropic B-matrix singular, so its inverse and the Gaussian exponent blow up and the structure-factor FFT returns NaN. This is the anisotropic analogue of PositiveMixedTensor, which keeps isotropic B positive the same way.
Rows that are entirely non-finite (isotropic atoms carry U = NaN) are passed through unchanged in both directions, preserving the iso/aniso split.
Notes
The eigen-decomposition / Cholesky factorisation needed to map U -> L runs only at construction and on freeze/unfreeze; the forward (hot) path is just exp and a handful of products, so gradients flow cleanly to refinable_params with no matrix factorisation in the autograd graph.
- __init__(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, epsilon=0.001)
Initialize a CholeskyMixedTensor.
If initial_values is provided, fully initializes the tensor. If not provided (empty init), creates a shell ready for load_state_dict().
- Parameters:
initial_values (torch.Tensor, optional) – Initial U-space values (u11, u22, u33, u12, u13, u23) for all elements; converted to Cholesky parameters internally. Optional for empty init.
refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined. If None, all elements are refinable.
requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for the tensor. Default is same as initial_values.
device (torch.device, optional) – Device for the tensor. Default is same as initial_values.
name (str, optional) – Optional name for this parameter (useful for debugging/logging).
epsilon (float, optional) – Positive offset added to the exponentiated diagonal of L, keeping it strictly positive. Default is 1e-3.
- fix(mask, freeze_at_current=True)
Freeze rows, storing their current value in Cholesky space.
- update_refinable_mask(new_mask, reset_refinable=False)
Repartition refinable/fixed elements, preserving values in U space.
The base implementation re-stores forward() output directly, which would double-transform here (U written back into Cholesky-parameter storage); this override converts to Cholesky parameters first, mirroring PositiveMixedTensor.update_refinable_mask().
- copy()
Deep-copy, preserving the Cholesky parametrization.
Rebuilds from the U-space values (__init__ reconverts to Cholesky parameters), so the copy stays positive-definite rather than degrading to a plain unconstrained MixedTensor.
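The U = L Lᵀ construction described for CholeskyMixedTensor can be sketched by expanding the product by hand, as its Notes describe for the forward path. The ordering of the six free parameters (three log-diagonal entries, then l21, l31, l32) is an assumption for illustration; torchref may store them differently, and the pass-through of non-finite (isotropic) rows is not reproduced.

```python
import torch


def u_from_cholesky_params(p, epsilon=1e-3):
    """Map free parameters (..., 6) to U components (..., 6) via U = L L^T.

    Assumed (hypothetical) layout: p = (x1, x2, x3, l21, l31, l32) with
    diag(L) = exp(x) + epsilon. Output order: (u11, u22, u33, u12, u13, u23).
    """
    d1, d2, d3 = (torch.exp(p[..., i]) + epsilon for i in range(3))
    l21, l31, l32 = p[..., 3], p[..., 4], p[..., 5]
    # L = [[d1, 0, 0], [l21, d2, 0], [l31, l32, d3]]; products expanded by hand,
    # so no matrix factorisation appears in the autograd graph.
    return torch.stack([
        d1 * d1,                          # u11
        l21 * l21 + d2 * d2,              # u22
        l31 * l31 + l32 * l32 + d3 * d3,  # u33
        d1 * l21,                         # u12
        d1 * l31,                         # u13
        l21 * l31 + d2 * l32,             # u23
    ], dim=-1)


def to_matrix(u):
    """Assemble symmetric 3x3 matrices from (u11, u22, u33, u12, u13, u23)."""
    u11, u22, u33, u12, u13, u23 = u.unbind(-1)
    return torch.stack([
        torch.stack([u11, u12, u13], dim=-1),
        torch.stack([u12, u22, u23], dim=-1),
        torch.stack([u13, u23, u33], dim=-1),
    ], dim=-2)


torch.manual_seed(0)
params = torch.randn(1000, 6, dtype=torch.float64)  # arbitrary, unconstrained values
eigvals = torch.linalg.eigvalsh(to_matrix(u_from_cholesky_params(params)))
print(bool((eigvals > 0).all()))  # True: every U is positive-definite
```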
- class torchref.model.parameter_wrappers.OccupancyTensor(initial_values=None, sharing_groups=None, altloc_groups=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, use_sigmoid=True)
Bases: MixedTensor
A specialized MixedTensor for handling occupancy parameters in crystallographic refinement.
Handles specific constraints and requirements for occupancy including value bounds [0, 1] via sigmoid reparameterization, atom sharing, and alternative conformation sum-to-1 constraints.
Features:
  - Values bounded between 0 and 1 using sigmoid reparameterization
  - Atoms can share occupancies (e.g., all atoms in a residue)
  - Alternative conformations automatically sum to 1.0 via normalization
  - Memory-efficient collapsed storage (one parameter per sharing group)
  - Fully vectorized collapse/expand operations
- Parameters:
initial_values (torch.Tensor, optional) – Initial occupancy values for ALL atoms (should be in [0, 1]). Optional for empty init.
sharing_groups (torch.Tensor, optional) – Tensor of shape (n_atoms,) where each value is the collapsed index for that atom. If None, each atom has independent occupancy.
altloc_groups (list of tuple, optional) – List of tuples of atom index lists representing alternative conformations. Each tuple contains the atom indices for each conformation.
refinable_mask (torch.Tensor, optional) – Boolean mask for which ATOMS can be refined (in full tensor space).
requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for the tensor.
device (torch.device, optional) – Device for the tensor.
name (str, optional) – Optional name for this parameter.
use_sigmoid (bool, optional) – If True, use sigmoid parameterization to bound values to [0,1]. Default is True.
- expansion_mask
Maps atoms to collapsed indices.
- Type:
torch.Tensor
- collapse_counts
Count of atoms per collapsed index.
- Type:
torch.Tensor
Examples
sharing_groups = torch.tensor([0, 0, 1, 1, 2, 2])
occ = OccupancyTensor(
    initial_values=torch.tensor([1.0, 1.0, 0.7, 0.7, 0.3, 0.3]),
    sharing_groups=sharing_groups,
    altloc_groups=[([2, 3], [4, 5])],
)
result = occ()  # Conformation A (atoms 2-3) and B (atoms 4-5) occupancies sum to 1.0
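The occupancy constraints listed under Features (sigmoid bounding, one stored value per sharing group, altloc sum-to-1) can be sketched as a single function. occupancy_forward is hypothetical; in particular, normalising by the mean occupancy of each conformation is an assumption about how the real forward() normalizes within a group.

```python
import torch


def occupancy_forward(logits, sharing_groups, altloc_groups):
    """Hypothetical sketch: sigmoid -> expand groups -> normalise altlocs.

    logits: collapsed storage, one value per sharing group, shape (n_groups,)
    sharing_groups: (n_atoms,) collapsed index of each atom
    altloc_groups: list of tuples of atom-index lists, one list per conformation
    """
    occ = torch.sigmoid(logits)[sharing_groups]  # bounded to (0, 1), full atom space
    for confs in altloc_groups:
        # Occupancy of each conformation, taken as the mean over its atoms.
        conf_occ = torch.stack([occ[list(c)].mean() for c in confs])
        idx = torch.tensor([i for c in confs for i in c])
        # Divide every atom of the group by the total so conformations sum to 1.
        occ = occ.index_put((idx,), occ[idx] / conf_occ.sum())
    return occ


sharing_groups = torch.tensor([0, 0, 1, 1, 2, 2])
# Conformations A (atoms 2-3) and B (atoms 4-5) have drifted to 0.6 each;
# eps clamps the inputs so that logit(1.0) stays finite.
logits = torch.nn.Parameter(torch.logit(torch.tensor([1.0, 0.6, 0.6]), eps=1e-4))
occ = occupancy_forward(logits, sharing_groups, [([2, 3], [4, 5])])
print(occ.tolist())  # approximately [0.9999, 0.9999, 0.5, 0.5, 0.5, 0.5]
```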
- __init__(initial_values=None, sharing_groups=None, altloc_groups=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, use_sigmoid=True)
Initialize an OccupancyTensor with collapsed storage and altloc support.
If initial_values is provided, fully initializes the tensor. If not provided (empty init), creates a shell ready for load_state_dict().
- Parameters:
initial_values (torch.Tensor, optional) – Initial occupancy values for ALL atoms (should be in [0, 1]). Optional for empty init.
sharing_groups (torch.Tensor, optional) – Tensor of shape (n_atoms,) where each value is the collapsed index for that atom. If None, each atom has independent occupancy. Example: tensor([0, 0, 0, 1, 1, 2]) means atoms 0,1,2 share one occupancy, atoms 3,4 share another, and atom 5 is independent.
altloc_groups (list of tuple, optional) – List of tuples of atom index lists representing alternative conformations. Example: [([10,11], [12,13])] means atoms 10,11 (conf A) and 12,13 (conf B) are altlocs that sum to 1.0.
refinable_mask (torch.Tensor, optional) – Boolean mask for which ATOMS can be refined (in full tensor space). If any atom in a group is refinable, the entire group becomes refinable.
requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for the tensor.
device (torch.device, optional) – Device for the tensor.
name (str, optional) – Optional name for this parameter.
use_sigmoid (bool, optional) – If True, use sigmoid parameterization to bound values to [0,1]. Default is True.
- forward()
Reconstruct full occupancy tensor with sigmoid and altloc constraints.
For alternative conformations, applies sigmoid then normalizes within each group to enforce sum-to-1 constraint.
- Returns:
Full occupancy tensor with values in [0, 1] and shape (n_atoms,).
- Return type:
torch.Tensor
- property shape
Return the shape of the FULL tensor (not collapsed).
- property collapsed_shape
Return the shape of the collapsed internal storage.
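- clamp(min_value=0.0, max_value=1.0)
Clamp occupancy values to the specified range and return a new OccupancyTensor.
- Parameters:
min_value (float, optional) – Lower bound. Default is 0.0.
max_value (float, optional) – Upper bound. Default is 1.0.
- Returns:
New OccupancyTensor with clamped values.
- Return type:
OccupancyTensor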
- set_group_occupancy(group_idx, value)
Set the occupancy for all atoms in a specific collapsed group.
- Parameters:
group_idx (int) – Collapsed index of the group.
value (float) – Occupancy to assign; must be in [0, 1].
- Raises:
ValueError – If group_idx is out of range or value is not in [0, 1].
- get_group_occupancy(group_idx)
Get the current occupancy value for a collapsed group.
- Parameters:
group_idx (int) – Collapsed index of the group.
- Returns:
Current occupancy value for the group.
- Return type:
- Raises:
ValueError – If group_idx is out of range.
- freeze(mask=None)
Freeze occupancy parameters, making them non-refinable.
The mask is supplied in UNCOMPRESSED (full atom) form but freezing operates on the COMPRESSED data structure. This method handles the conversion.
- Parameters:
mask (torch.Tensor, optional) – Boolean mask in FULL (uncompressed) atom space indicating which atoms to freeze. If None, freeze all parameters. Shape must be (n_atoms,).
Notes
If ANY atom in a sharing group is frozen, the ENTIRE group is frozen because all atoms in a group share the same compressed parameter.
Examples
# Freeze atoms 0-10 (in full atom space)
freeze_mask = torch.zeros(n_atoms, dtype=torch.bool)
freeze_mask[0:11] = True
occ.freeze(freeze_mask)
# Freeze all atoms
occ.freeze()
- unfreeze(mask=None)
Unfreeze occupancy parameters, making them refinable.
The mask is supplied in UNCOMPRESSED (full atom) form but unfreezing operates on the COMPRESSED data structure. This method handles the conversion.
- Parameters:
mask (torch.Tensor, optional) – Boolean mask in FULL (uncompressed) atom space indicating which atoms to unfreeze. If None, unfreeze all parameters. Shape must be (n_atoms,).
Notes
If ANY atom in a sharing group is unfrozen, the ENTIRE group becomes refinable because all atoms in a group share the same compressed parameter.
Examples
# Unfreeze atoms 100-200 (in full atom space)
unfreeze_mask = torch.zeros(n_atoms, dtype=torch.bool)
unfreeze_mask[100:201] = True
occ.unfreeze(unfreeze_mask)
# Unfreeze all atoms
occ.unfreeze()
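The group rule in the Notes for freeze() and unfreeze() (selecting any atom of a sharing group affects the whole group) amounts to collapsing a full-atom mask onto groups and expanding it back. expand_to_groups is a hypothetical helper, not part of the OccupancyTensor API.

```python
import torch


def expand_to_groups(atom_mask, sharing_groups, n_groups):
    """Hypothetical: a group is selected if ANY of its atoms is selected.

    Returns the compressed (per-group) mask and its expansion to full atom space.
    """
    group_mask = torch.zeros(n_groups, dtype=torch.bool)
    group_mask[sharing_groups[atom_mask]] = True   # collapse: atoms -> groups
    return group_mask, group_mask[sharing_groups]  # expand: groups -> atoms


sharing_groups = torch.tensor([0, 0, 0, 1, 1, 2])
freeze_mask = torch.zeros(6, dtype=torch.bool)
freeze_mask[1] = True  # freeze a single atom of group 0

frozen_groups, frozen_atoms = expand_to_groups(freeze_mask, sharing_groups, 3)
print(frozen_atoms.tolist())        # [True, True, True, False, False, False]
print(int((~frozen_groups).sum()))  # 2 refinable groups
print(int((~frozen_atoms).sum()))   # 3 refinable atoms
```

The last two lines also show why get_refinable_count() (groups) and get_refinable_atoms().sum() (atoms) can differ.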
- freeze_all()
Freeze all occupancy parameters.
Convenience method equivalent to freeze(None).
- unfreeze_all()
Unfreeze all occupancy parameters.
Convenience method equivalent to unfreeze(None).
- get_refinable_atoms()
Get a boolean mask in FULL atom space indicating refinable atoms.
- Returns:
Boolean tensor of shape (n_atoms,) where True indicates the atom’s occupancy is refinable (though it shares with others in its group).
- Return type:
torch.Tensor
- get_frozen_atoms()
Get a boolean mask in FULL atom space indicating frozen atoms.
- Returns:
Boolean tensor of shape (n_atoms,) where True indicates the atom’s occupancy is frozen.
- Return type:
torch.Tensor
- get_refinable_count()
Get the number of refinable parameters in COMPRESSED space.
This is the number of refinable groups, not atoms. Use get_refinable_atoms().sum() to get the number of refinable atoms.
- Returns:
Number of refinable compressed parameters.
- Return type:
- get_fixed_count()
Get the number of fixed parameters in COMPRESSED space.
This is the number of fixed groups, not atoms. Use get_frozen_atoms().sum() to get the number of frozen atoms.
- Returns:
Number of fixed compressed parameters.
- Return type:
- update_refinable_mask(new_mask, in_compressed_space=False)
Directly update the refinable mask with a new mask.
Allows more direct control over which parameters are refinable, compared to freeze/unfreeze which modify the existing state.
- Parameters:
new_mask (torch.Tensor) – Boolean tensor indicating which parameters should be refinable. If in_compressed_space=False: shape (n_atoms,) in full atom space. If in_compressed_space=True: shape (n_groups,) in compressed space.
in_compressed_space (bool, optional) – If True, new_mask is in compressed space. If False (default), new_mask is in full atom space and will be collapsed.
Examples
Full atom space:
atom_mask = torch.zeros(n_atoms, dtype=torch.bool)
atom_mask[:100] = True
occ.update_refinable_mask(atom_mask, in_compressed_space=False)
Compressed space:
group_mask = torch.zeros(n_groups, dtype=torch.bool)
group_mask[::2] = True
occ.update_refinable_mask(group_mask, in_compressed_space=True)
- static from_residue_groups(initial_values, pdb_dataframe, refinable_mask=None, **kwargs)
Create an OccupancyTensor where all atoms in each residue share occupancy.
Common use case where all atoms in a residue should have the same occupancy.
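Residue-based sharing groups can be derived by numbering each distinct (chainid, resseq, resname) triple in order of first appearance. residue_sharing_groups is a hypothetical helper; the exact key torchref uses (for example, whether insertion codes or altlocs are included) is not documented here.

```python
import torch


def residue_sharing_groups(df):
    """Hypothetical helper: one sharing-group index per (chainid, resseq, resname).

    `df` can be a pandas DataFrame or any mapping of column name -> sequence,
    because only column iteration is used.
    """
    index_of = {}
    groups = []
    for key in zip(df["chainid"], df["resseq"], df["resname"]):
        # New residues get the next free index, in order of first appearance.
        groups.append(index_of.setdefault(key, len(index_of)))
    return torch.tensor(groups, dtype=torch.long)


atoms = {
    "chainid": ["A", "A", "A", "A", "B"],
    "resseq": [1, 1, 2, 2, 1],
    "resname": ["GLY", "GLY", "ALA", "ALA", "GLY"],
}
print(residue_sharing_groups(atoms).tolist())  # [0, 0, 1, 1, 2]
```

The resulting tensor has the (n_atoms,) shape expected by the sharing_groups argument of OccupancyTensor.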
- Parameters:
initial_values (torch.Tensor) – Initial occupancy values for all atoms.
pdb_dataframe (pandas.DataFrame) – DataFrame with PDB data (must have ‘resname’, ‘resseq’, ‘chainid’).
refinable_mask (torch.Tensor, optional) – Mask for refinable atoms.
**kwargs – Additional arguments passed to OccupancyTensor constructor.
- Returns:
OccupancyTensor with residue-based sharing groups.
- Return type:
OccupancyTensor
- class torchref.model.parameter_wrappers.PassThroughTensor(initial_values, requires_grad=True, dtype=None, device=None, name=None)
Bases: DeviceMixin, Module
A simple parameter wrapper that passes the parameter through unchanged.
Useful as a placeholder or for parameters that do not require any special handling.
- Parameters:
initial_values (torch.Tensor) – Initial tensor values.
requires_grad (bool, optional) – Whether the parameter requires gradients. Default is True.
dtype (torch.dtype, optional) – Data type of the tensor.
device (torch.device, optional) – Device to place the tensor on.
name (str, optional) – Optional name for the parameter.
- __init__(initial_values, requires_grad=True, dtype=None, device=None, name=None)
Initialize the PassThroughTensor.
- Parameters:
initial_values (torch.Tensor) – Initial tensor values.
requires_grad (bool, optional) – Whether the parameter requires gradients. Default is True.
dtype (torch.dtype, optional) – Data type of the tensor.
device (torch.device, optional) – Device to place the tensor on.
name (str, optional) – Optional name for the parameter.
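For comparison with the constrained wrappers above, a pass-through wrapper reduces to a module whose forward() returns its parameter unchanged. PassThroughSketch is a hypothetical minimal analogue; DeviceMixin and any other torchref behaviour are omitted.

```python
import torch
from torch import nn


class PassThroughSketch(nn.Module):
    """Hypothetical minimal analogue of PassThroughTensor (DeviceMixin omitted)."""

    def __init__(self, initial_values, requires_grad=True, dtype=None, device=None, name=None):
        super().__init__()
        values = torch.as_tensor(initial_values, dtype=dtype, device=device)
        self.param = nn.Parameter(values.clone(), requires_grad=requires_grad)
        self.name = name

    def forward(self):
        return self.param  # no transform: the stored parameter itself


scale = PassThroughSketch([1.0, 2.0], name="scale")
out = scale()
out.sum().backward()
print(out is scale.param, scale.param.grad.tolist())  # True [1.0, 1.0]
```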