"""JSON serialization helpers for torch tensors and numpy arrays."""
import torch
[docs]
def convert_to_serializable(obj):
    """Convert tensors and numpy arrays to JSON-serializable types.

    Recursively walks dicts, lists, and tuples, converting ``torch.Tensor``,
    ``numpy.ndarray``, and numpy scalar types (floating, integer, and
    boolean) to plain Python objects that ``json.dump`` can handle.

    Parameters
    ----------
    obj : object
        Arbitrary Python object (tensor, array, dict, list, scalar, ...).

    Returns
    -------
    object
        The converted object. Tensors with exactly one element become a
        Python scalar; all other tensors (including empty ones) and all
        numpy arrays become (nested) lists. Dicts are rebuilt as plain
        ``dict`` with their keys left as-is; lists and tuples keep their
        concrete type (namedtuples included). Any object of another type
        is returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        # ``item()`` is only valid for exactly one element; an empty tensor
        # must go through ``tolist()`` or it would raise.
        return obj.item() if obj.numel() == 1 else obj.tolist()

    # numpy is optional: keep only the import inside the ``try`` so that
    # unrelated errors from the checks below are never masked.
    try:
        import numpy as np
    except ImportError:
        np = None

    if np is not None:
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # ``np.bool_`` is neither ``np.integer`` nor ``np.floating`` and is
        # rejected by ``json.dump`` unless converted explicitly.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)

    if isinstance(obj, dict):
        return {k: convert_to_serializable(v) for k, v in obj.items()}

    if isinstance(obj, (list, tuple)):
        converted = [convert_to_serializable(v) for v in obj]
        # namedtuple constructors take their fields positionally, not as a
        # single iterable, so they need to be unpacked.
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            return type(obj)(*converted)
        return type(obj)(converted)

    return obj