torchref.refinement.weighting package
Weighting schemes for loss component aggregation.
This module provides weighting schemes that adjust the relative importance of different loss components during refinement.
All weighting schemes inherit from BaseWeighting and receive data through LossState rather than direct refinement references.
- class torchref.refinement.weighting.BaseWeighting(device=None, **kwargs)
Bases: DeviceMixin, Module, ABC
Abstract base class for weighting schemes using LossState.
Weighting schemes:
- are set up without state (only the device and hyperparameters);
- receive state only when computing weights via forward();
- return a weights dict and do not modify the state directly.
All tunable parameters should be registered as buffers using register_buffer() so they can be accessed/modified via state_dict notation.
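The contract above can be illustrated with a small torch-free sketch. It is not torchref code: `LossStateSketch`, `WeightingSketch` and `FixedGeometryWeighting` are hypothetical names, and a plain dict stands in for `register_buffer()`, since the real class is a `torch.nn.Module`. What it shows is the split between construction (hyperparameters only) and `forward(state)` (reads the state, returns a weights dict, leaves the state untouched).

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class LossStateSketch:
    """Hypothetical stand-in for LossState: a plain bag of metrics."""
    d_min: float = 2.0
    rwork: float = 0.20
    rfree: float = 0.25
    component_losses: dict = field(default_factory=dict)


class WeightingSketch(ABC):
    """Torch-free mirror of the BaseWeighting contract (hypothetical)."""

    def __init__(self, **hyperparams):
        # Set-up receives only hyperparameters, never state. A flat dict
        # stands in for register_buffer(), so tunables stay addressable by name.
        self.buffers = dict(hyperparams)

    @abstractmethod
    def forward(self, state):
        """Return a {component_name: weight} dict without mutating `state`."""

    def __call__(self, state):
        # nn.Module routes __call__ to forward(); emulate that here.
        return self.forward(state)


class FixedGeometryWeighting(WeightingSketch):
    """Toy subclass that always returns one geometry weight."""

    def forward(self, state):
        return {"geometry": self.buffers["w_geometry"]}


scheme = FixedGeometryWeighting(w_geometry=10.0)
scheme.buffers["w_geometry"] = 5.0  # tune by name, like a state_dict entry
print(scheme(LossStateSketch()))  # {'geometry': 5.0}
```

In the real module the same by-name access would go through the module's `state_dict()` rather than a `buffers` dict.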
- Parameters:
device (torch.device, optional) – Computation device. Defaults to the configured device.current.
- device
Computation device.
- Type:
torch.device
- torchref.refinement.weighting.WeightingScheme
alias of BaseWeighting
- class torchref.refinement.weighting.ResolutionWeighting(device=None, base_w_geometry=1.0, base_w_adp=1.0, d_ref=2.0, alpha=0.0)
Bases: BaseWeighting
Base prior strength with optional resolution-dependent correction.
With properly summed NLLs and perfectly calibrated sigmas, w = 1 would be the pure Bayesian answer. However, an empirical sweep on 50 structures (1.15-3.0 Å) showed that the monomer library sigmas lead to too-loose geometry at w = 1. The optimal base weights are:
- w_geometry ~= 10 (from the median Rfree minimum)
- w_adp ~= 1 (lower is better for most resolutions)
Geometry and ADP use different base weights because the monomer-library geometry sigmas and the ADP restraint sigmas are miscalibrated by different amounts, so each needs its own compensation.
Optional resolution dependence:
w_geometry = base_w_geometry * (d_min / d_ref) ^ alpha
w_adp = base_w_adp * (d_min / d_ref) ^ alpha
alpha = 0 disables the resolution correction. The sweep found the resolution dependence to be weak (<20% effect), so alpha = 0 is the default.
- Parameters:
device (torch.device, optional) – Computation device.
base_w_geometry (float, optional) – Base geometry weight. The signature default is 1.0; the empirical sweep suggests about 10.0.
base_w_adp (float, optional) – Base ADP weight. Default 1.0 (from empirical sweep).
d_ref (float, optional) – Reference resolution for the optional power-law correction. Default 2.0 Å.
alpha (float, optional) – Resolution sensitivity. Default 0.0 (disabled).
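The power law above is easy to check by hand. The helper below is a hypothetical, torch-free sketch (`resolution_weights` is not a torchref function); its defaults mirror the class signature, and the sweep-suggested geometry base of 10.0 is passed explicitly.

```python
def resolution_weights(d_min, base_w_geometry=1.0, base_w_adp=1.0, d_ref=2.0, alpha=0.0):
    """Power-law prior weights from the formulas above (hypothetical helper)."""
    scale = (d_min / d_ref) ** alpha  # alpha = 0 gives scale = 1 at any resolution
    return {
        "geometry": base_w_geometry * scale,
        "adp": base_w_adp * scale,
    }


# Sweep-suggested bases with the default alpha = 0: no resolution dependence.
print(resolution_weights(3.0, base_w_geometry=10.0))
# {'geometry': 10.0, 'adp': 1.0}

# Enabling alpha = 1 at 3.0 Å scales both weights by 3.0 / 2.0 = 1.5.
print(resolution_weights(3.0, base_w_geometry=10.0, alpha=1.0))
# {'geometry': 15.0, 'adp': 1.5}
```

Because both weights share one scale factor, alpha only changes the prior strength relative to the data term; the geometry-to-ADP ratio stays fixed at base_w_geometry / base_w_adp.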
- class torchref.refinement.weighting.OverfittingWeighting(device=None, target_gap=0.05, min_weight=1.0, sharpness=30.0, geom_share=1.0, smoothing=0.8)
Bases: BaseWeighting
Dynamic overfitting correction based on the Rfree - Rwork gap.
Uses R-factors (scale-invariant and normalized per reflection) rather than NLL values. Summed NLLs grow with the number of reflections in each set, so they are not comparable between the work and test sets.
When the Rfree - Rwork gap exceeds target_gap, the scheme increases regularization exponentially. The correction is meant to act primarily on ADP weights and only weakly on geometry weights: in crystallographic refinement, overfitting is typically driven by B-factors (one parameter per atom, with relatively weak restraints) rather than by coordinates, which the geometry prior holds tightly.
- Effective correction:
factor = min_weight + exp(sharpness * (gap - target_gap))
w_adp *= factor
w_geometry *= 1 + geom_share * (factor - 1)
For example, with geom_share = 0.2 only 20% of the correction (factor - 1) is applied to geometry, keeping most of the effect on ADP. The default geom_share = 1.0 applies the full factor to both.
Tunable parameters (registered as buffers):
- target_gap: R-factor gap threshold. Default 0.05 (5%).
- min_weight: base correction factor. Default 1.0.
- sharpness: steepness of the exponential response. Default 30.0.
- geom_share: fraction of the correction applied to geometry. Default 1.0.
- smoothing: EMA smoothing factor (0-1). Default 0.8.
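The correction formula can be sketched in plain Python as follows. `OverfittingSketch` is hypothetical, not the torchref class. The documentation names the `smoothing` EMA but does not say what it smooths; this sketch assumes it is applied to the Rfree - Rwork gap, and that assumption is flagged in the code.

```python
import math


class OverfittingSketch:
    """Hypothetical re-implementation of the correction formula above."""

    def __init__(self, target_gap=0.05, min_weight=1.0, sharpness=30.0,
                 geom_share=1.0, smoothing=0.8):
        self.target_gap = target_gap
        self.min_weight = min_weight
        self.sharpness = sharpness
        self.geom_share = geom_share
        self.smoothing = smoothing
        self._gap_ema = None  # smoothed Rfree - Rwork gap

    def update_gap(self, rwork, rfree):
        # ASSUMPTION: `smoothing` is an EMA over the gap; torchref may differ.
        gap = rfree - rwork
        if self._gap_ema is None:
            self._gap_ema = gap
        else:
            self._gap_ema = self.smoothing * self._gap_ema + (1.0 - self.smoothing) * gap
        return self._gap_ema

    def correction(self, gap):
        factor = self.min_weight + math.exp(self.sharpness * (gap - self.target_gap))
        return {
            "adp": factor,
            "geometry": 1.0 + self.geom_share * (factor - 1.0),
        }

    def apply(self, weights, rwork, rfree):
        """Multiply incoming weights by the correction; returns a new dict."""
        corr = self.correction(self.update_gap(rwork, rfree))
        return {name: w * corr.get(name, 1.0) for name, w in weights.items()}


scheme = OverfittingSketch(geom_share=0.2)
print(scheme.correction(0.05))  # {'adp': 2.0, 'geometry': 1.2}
```

Under this formula the factor is always strictly greater than min_weight: it equals min_weight + 1 when the gap sits exactly at target_gap and grows exponentially above it.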
- class torchref.refinement.weighting.ManualWeighting(weights, device=None)
Bases: BaseWeighting
Apply fixed manual weights.
This scheme needs no state data; it simply returns the weights it was constructed with.
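ManualWeighting's behaviour fits in a few lines. The sketch below is a hypothetical, torch-free stand-in (the class name and the copy-on-return choice are assumptions, not torchref code): it ignores whatever state it is given and returns the weights passed at construction.

```python
class ManualWeightingSketch:
    """Hypothetical stand-in: fixed weights, state ignored."""

    def __init__(self, weights):
        # Copy so later edits to the caller's dict do not leak in.
        self._weights = dict(weights)

    def forward(self, state=None):
        # The state argument is accepted for interface compatibility only.
        return dict(self._weights)  # fresh copy each call


manual = ManualWeightingSketch({"geometry": 10.0, "adp": 1.0})
print(manual.forward(state=object()))  # {'geometry': 10.0, 'adp': 1.0}
```

Returning a fresh copy means downstream code that scales the returned dict cannot silently change the preset weights for later steps.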
- class torchref.refinement.weighting.ComponentWeighting(device=None, weights=None, component_weights=None, schemes=None, initial_xray_loss=None)
Bases: DeviceMixin, Module
Combines multiple weighting schemes using nn.ModuleDict.
Holds weighting schemes but does NOT hold a refinement reference. Weighting is computed via forward(state) which receives a LossState.
Default schemes:
- ‘resolution’: ResolutionWeighting, resolution-dependent prior strength
- ‘overfitting’: OverfittingWeighting, prevents overfitting via the Rfree gap
- Parameters:
device (torch.device, optional) – Computation device.
weights (dict, optional) – Manual weight overrides.
component_weights (dict, optional) – Manual weight overrides for specific components.
schemes (list of BaseWeighting, optional) – Additional custom weighting schemes.
- schemes
Dictionary of weighting schemes.
- Type:
nn.ModuleDict
- __init__(device=None, weights=None, component_weights=None, schemes=None, initial_xray_loss=None)
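One plausible way for ComponentWeighting to fold several scheme outputs into a single weights dict and a total loss is sketched below. This is a hypothetical sketch: `combine_weights` and `total_loss` are not torchref API (the class exposes forward() and total_loss_from_state(), whose exact semantics may differ). It assumes that weights for the same component multiply across schemes (suggested by the `w_adp *= factor` form of OverfittingWeighting), that manual overrides replace the combined value, and that unweighted components such as the X-ray term count with weight 1.0.

```python
def combine_weights(scheme_outputs, overrides=None):
    """Fold per-scheme weight dicts into one (hypothetical helper).

    ASSUMPTION: weights for the same component multiply across schemes, and
    manual overrides replace the combined value outright.
    """
    combined = {}
    for out in scheme_outputs:
        for name, w in out.items():
            combined[name] = combined.get(name, 1.0) * w
    combined.update(overrides or {})
    return combined


def total_loss(component_losses, weights):
    """Weighted sum of component losses; unweighted components count once."""
    return sum(weights.get(name, 1.0) * loss for name, loss in component_losses.items())


outputs = [
    {"geometry": 10.0, "adp": 1.0},  # e.g. a ResolutionWeighting-style base
    {"geometry": 1.2, "adp": 2.0},   # e.g. an OverfittingWeighting-style correction
]
weights = combine_weights(outputs)
losses = {"xray": 100.0, "geometry": 2.0, "adp": 3.0}
print(weights)
print(total_loss(losses, weights))
```

Keeping the combination multiplicative lets each scheme express its effect as a relative factor without knowing what the other schemes returned.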
Submodules
- torchref.refinement.weighting.base_weighting module
- torchref.refinement.weighting.component_weighting module
BaseWeighting, WeightingScheme, ResolutionWeighting, OverfittingWeighting, ManualWeighting, ComponentWeighting
ComponentWeighting.schemes, ComponentWeighting.__init__(), ComponentWeighting.__getitem__(), ComponentWeighting.__contains__(), ComponentWeighting.keys(), ComponentWeighting.values(), ComponentWeighting.items(), ComponentWeighting.add_scheme(), ComponentWeighting.forward(), ComponentWeighting.total_loss_from_state(), ComponentWeighting.stats()
- torchref.refinement.weighting.es_policy_training module
ESConfig, ESResult, GenerationResult, ESPolicyTrainer
ESPolicyTrainer.__init__(), ESPolicyTrainer.history, ESPolicyTrainer.get_flat_params(), ESPolicyTrainer.set_flat_params(), ESPolicyTrainer.sample_perturbations(), ESPolicyTrainer.create_perturbed_policy(), ESPolicyTrainer.compute_advantages(), ESPolicyTrainer.compute_gradient(), ESPolicyTrainer.update(), ESPolicyTrainer.evaluate_generation(), ESPolicyTrainer.save_checkpoint(), ESPolicyTrainer.load_checkpoint()
create_policy_network(), initialize_policy_sensibly()
- torchref.refinement.weighting.policy_weighting module
PolicyComponentWeighting
PolicyComponentWeighting.policy, PolicyComponentWeighting.sample, PolicyComponentWeighting.temperature, PolicyComponentWeighting.current_step, PolicyComponentWeighting.last_log_weights, PolicyComponentWeighting.last_weights, PolicyComponentWeighting.trajectory, PolicyComponentWeighting.name, PolicyComponentWeighting.__init__(), PolicyComponentWeighting.set_training_mode(), PolicyComponentWeighting.set_eval_mode(), PolicyComponentWeighting.reset_trajectory(), PolicyComponentWeighting.start_recording(), PolicyComponentWeighting.stop_recording(), PolicyComponentWeighting.increment_step(), PolicyComponentWeighting.extract_state_features(), PolicyComponentWeighting.forward(), PolicyComponentWeighting.compute_weights(), PolicyComponentWeighting.apply_to_state(), PolicyComponentWeighting.stats()
StepState
StepState.step, StepState.progress, StepState.improvement_rate, StepState.rwork, StepState.rfree, StepState.rfree_gap, StepState.delta_rfree, StepState.xray_loss, StepState.xray_loss_test, StepState.xray_work_test_ratio, StepState.geometry_loss, StepState.adp_loss, StepState.bond_rmsd, StepState.angle_rmsd, StepState.mean_adp_normalized, StepState.adp_std_normalized, StepState.component_losses, StepState.component_weights, StepState.__init__()
StepRecord, TrajectoryData
TrajectoryData.pdb_id, TrajectoryData.structure_path, TrajectoryData.sf_path, TrajectoryData.steps, TrajectoryData.initial_rfree, TrajectoryData.final_rfree, TrajectoryData.initial_rwork, TrajectoryData.final_rwork, TrajectoryData.total_time, TrajectoryData.random_seed, TrajectoryData.policy_version, TrajectoryData.success, TrajectoryData.error_message, TrajectoryData.__init__()
trajectory_to_dict()
- torchref.refinement.weighting.random_weighting module
- Weight Sampling Strategy
RandomWeightingScheme
RandomWeightingScheme.base_log_weights, RandomWeightingScheme.current_log_weights, RandomWeightingScheme.current_weights, RandomWeightingScheme.sample_count, RandomWeightingScheme.name, RandomWeightingScheme.__init__(), RandomWeightingScheme.apply_step_perturbation(), RandomWeightingScheme.resample_trajectory(), RandomWeightingScheme.resample_weights(), RandomWeightingScheme.forward(), RandomWeightingScheme.sample_count, RandomWeightingScheme.get_log_weights(), RandomWeightingScheme.get_base_log_weights(), RandomWeightingScheme.get_weights(), RandomWeightingScheme.stats()
RandomComponentWeighting
RandomComponentWeighting.__init__(), RandomComponentWeighting.random_scheme, RandomComponentWeighting.resample_trajectory(), RandomComponentWeighting.resample_weights(), RandomComponentWeighting.compute_weights(), RandomComponentWeighting.get_sampled_log_weights(), RandomComponentWeighting.get_base_log_weights(), RandomComponentWeighting.get_sampled_weights(), RandomComponentWeighting.stats()