Source code for scri.SpEC.file_io.corotating_paired_xor

# Copyright (c) 2020, Michael Boyle
# See LICENSE file for details: <https://github.com/moble/spherical_functions/blob/master/LICENSE>

import warnings
import os
import numpy as np
import quaternion
import spherical_functions as sf
from ... import WaveformModes, DataType, DataNames

# Format identifiers understood by this module.  `load` compares both the JSON
# sidecar's "sxs_format" entry and the H5 file's "sxs_format" attribute
# against this list.
sxs_formats = ["corotating_paired_xor"]


def save(w, file_name=None, L2norm_fractional_tolerance=1e-10, log_frame=None, compress=True):
    """Write a waveform in the 'corotating_paired_xor' format (H5 data file + JSON sidecar).

    Parameters
    ----------
    w : WaveformModes
        Waveform to save.  When the tolerance is nonzero, an Inertial-frame
        waveform is first transformed to the corotating frame.  Any frame type
        other than Inertial or Corotating is rejected.
    file_name : str or path-like, optional
        Output location.  Its suffix is replaced by ".h5" for the data file
        and ".json" for the metadata file.  Missing parent directories are
        created.  If None, both files are written to a temporary directory
        that is deleted before returning; this is mostly useful for debugging
        and for reporting the compressed size.
    L2norm_fractional_tolerance : float, optional
        Tolerance passed to `w.truncate`.  It also sets the power-of-two grid
        to which a freshly computed log(frame) is rounded.  Zero skips all of
        the lossy preprocessing; see the review note in the body.
    log_frame : array_like, optional
        Precomputed log(frame).  It is presumably expected in the same layout
        the function computes itself: the vector part
        ``as_float_array(log(frame))[:, 1:]``.  This is not checked.  If None,
        it is derived from `w.frame`.
    compress : bool, optional
        If True (default), datasets use gzip level 9 with the shuffle and
        HDF5 fletcher32 filters.

    Returns
    -------
    WaveformModes
        The waveform as prepared for writing.  For a nonzero tolerance this is
        a modified copy of the input: converted to the corotating frame,
        conjugate-paired, truncated, and XOR-encoded.  Its `t`/`data` arrays
        therefore hold encoded bit patterns, not physical values.

    Raises
    ------
    ValueError
        If the (nonzero-tolerance) waveform is neither Corotating nor
        Inertial.
    """
    import tempfile
    import contextlib
    import pathlib
    import json
    import numpy
    import scipy
    import h5py
    import sxs
    import scri
    from ...utilities import xor_timeseries, fletcher32

    # Resolve the output path and make sure its directory exists
    if file_name is None:
        # We'll just be creating a temp directory below, to check
        warnings.warn(
            "Input `file_name` is None. Running in temporary directory.\n"
            "Note that this option is mostly for debugging purposes."
        )
    else:
        h5_path = pathlib.Path(file_name).expanduser().resolve().with_suffix(".h5")
        # exist_ok avoids a check-then-create race with concurrent writers
        h5_path.parent.mkdir(parents=True, exist_ok=True)

    compression_options = (
        {"compression": "gzip", "compression_opts": 9, "shuffle": True, "fletcher32": True}
        if compress
        else {}
    )

    if L2norm_fractional_tolerance == 0.0:
        # NOTE(review): this branch skips the corotating-frame conversion, the
        # conjugate-pair conversion and the XOR encoding.  `load`, however,
        # unconditionally reverses the XOR and converts from conjugate pairs,
        # so zero-tolerance files do not appear to round-trip.  Confirm before
        # relying on this path.
        log_frame = quaternion.as_float_array(np.log(w.frame))[:, 1:]
    else:
        # We need this storage anyway, so let's just make a copy and work in-place
        w = w.copy()
        if log_frame is not None:
            log_frame = log_frame.copy()

        # Ensure waveform is in corotating frame
        if w.frameType == scri.Inertial:
            try:
                initial_time = w.t[0]
                relaxation_time = w.metadata.relaxation_time
                max_norm_time = w.max_norm_time()
                z_alignment_region = (
                    (relaxation_time - initial_time) / (max_norm_time - initial_time),
                    0.95,
                )
            except Exception:
                # Missing metadata (or similar): fall back to a fixed alignment region
                z_alignment_region = (0.1, 0.95)
            # NOTE(review): this tolerance is fixed at 1e-10 rather than following
            # `L2norm_fractional_tolerance` -- confirm that is intended.
            w, log_frame = w.to_corotating_frame(
                tolerance=1e-10, z_alignment_region=z_alignment_region, truncate_log_frame=True
            )
            # Keep only the vector part; the scalar part is re-inserted as 0 on load
            log_frame = log_frame[:, 1:]
        if w.frameType != scri.Corotating:
            raise ValueError(
                "Frame type of input waveform must be 'Corotating' or 'Inertial'; "
                f"it is {w.frame_type_string}"
            )

        # Convert mode structure to conjugate pairs
        w.convert_to_conjugate_pairs()

        # Set bits below the desired significance level to 0
        w.truncate(tol=L2norm_fractional_tolerance)

        # Compute log(frame)
        if log_frame is None:
            log_frame = quaternion.as_float_array(np.log(w.frame))[:, 1:]
            # Round onto a power-of-two grid ten times finer than the tolerance,
            # which zeroes the low-order mantissa bits.  The scale is a Python
            # float: a NumPy int64 power would silently wrap for tolerances
            # below ~2e-18 and produce a zero/negative scale (NaN output).
            power_of_2 = 2.0 ** int(-np.floor(np.log2(L2norm_fractional_tolerance / 10)))
            log_frame = np.round(log_frame * power_of_2) / power_of_2

        # Change -0.0 to 0.0 (~.5% compression for non-precessing systems)
        w.t += 0.0
        w.data += 0.0
        log_frame += 0.0

        # XOR successive instants in time (xor_timeseries works in place)
        xor_timeseries(w.t)
        xor_timeseries(w.data)
        xor_timeseries(log_frame)

    # Make sure we have a place to keep all this
    with contextlib.ExitStack() as context:
        if file_name is None:
            temp_dir = context.enter_context(tempfile.TemporaryDirectory())
            h5_path = pathlib.Path(f"{temp_dir}") / "test.h5"
        else:
            print(f'Saving H5 to "{h5_path}"')

        # Write the H5 file; float/complex arrays are stored as their raw
        # uint64 bit patterns so the encoded bits are preserved verbatim.
        with h5py.File(h5_path, "w") as f:
            f.attrs["sxs_format"] = "corotating_paired_xor"
            warnings.warn('sxs_format is being set to "corotating_paired_xor"')
            f.create_dataset("time", data=w.t.view(np.uint64), chunks=(w.n_times,), **compression_options)
            f.create_dataset(
                "modes", data=w.data.view(np.uint64), chunks=(w.n_times, 1), **compression_options
            )
            f["modes"].attrs["ell_min"] = w.ell_min
            f["modes"].attrs["ell_max"] = w.ell_max
            if log_frame.size > 1:
                f.create_dataset(
                    "log_frame", data=log_frame.view(np.uint64), chunks=(w.n_times, 1), **compression_options
                )

        # Get some numbers for the JSON file
        h5_size = os.stat(h5_path).st_size
        if file_name is None:
            print(f"Output H5 file size: {h5_size:_} B")
        fletcher32_dict = {}
        with h5py.File(h5_path, "r") as f:
            fletcher32_dict["time"] = fletcher32(f["time"][:])
            fletcher32_dict["modes"] = fletcher32(f["modes"][:])
            if "log_frame" in f:
                fletcher32_dict["log_frame"] = fletcher32(f["log_frame"][:])

        # Write the corresponding JSON file
        json_path = h5_path.with_suffix(".json")
        json_data = {
            "sxs_format": "corotating_paired_xor",
            "data_info": {
                "data_type": w.data_type_string,
                "spin_weight": int(w.spin_weight),
                "ell_min": int(w.ell_min),
                "ell_max": int(w.ell_max),
            },
            "transformations": {
                "truncation": L2norm_fractional_tolerance,
                # see below for 'boost_velocity'
                # see below for 'space_translation'
            },
            "version_info": {
                "numpy": numpy.__version__,
                "scipy": scipy.__version__,
                "h5py": h5py.__version__,
                "quaternion": quaternion.__version__,
                "spherical_functions": sf.__version__,
                "scri": scri.__version__,
                "sxs": sxs.__version__,
                # see below 'spec_version_hist'
            },
            "validation": {"h5_file_size": h5_size, "n_times": w.n_times, "fletcher32": fletcher32_dict},
        }
        if hasattr(w, "boost_velocity"):
            json_data["transformations"]["boost_velocity"] = w.boost_velocity.tolist()
        if hasattr(w, "space_translation"):
            json_data["transformations"]["space_translation"] = w.space_translation.tolist()
        if hasattr(w, "version_hist"):
            json_data["version_info"]["spec_version_history"] = w.version_hist
        if file_name is not None:
            print(f'Saving JSON to "{json_path}"')
        with json_path.open("w") as f:
            json.dump(json_data, f, indent=2, separators=(",", ": "), ensure_ascii=True)

    return w
def load(file_name, ignore_validation=False):
    """Read a waveform stored in the 'corotating_paired_xor' format.

    Parameters
    ----------
    file_name : str or path-like
        Path to the data.  Its suffix is replaced by ".h5" for the data file.
        The JSON sidecar is looked up next to it with a ".json" suffix.
    ignore_validation : bool, optional
        If True, problems found while validating against the JSON sidecar
        produce warnings instead of exceptions.  The checks cover a missing
        file, a wrong format name, an unknown data type, file size, number of
        time steps, and checksums.  A wrong `sxs_format` attribute inside the
        H5 file always raises.

    Returns
    -------
    WaveformModes
        Waveform with frameType Corotating.  The XOR encoding and the
        conjugate-pair mode layout have been undone.  The parsed JSON dict is
        attached as `w.json_data`.

    Raises
    ------
    FileNotFoundError
        If the H5 file does not exist.
    ValueError
        On validation failures (unless `ignore_validation`), or if the H5
        file's `sxs_format` attribute is not one of `sxs_formats`.
    """
    import pathlib
    import json
    import h5py
    import scri
    from ...utilities import xor_timeseries_reverse, fletcher32

    def invalid(message):
        # Validation problems are fatal unless the caller opted out
        if ignore_validation:
            warnings.warn(message)
        else:
            raise ValueError(message)

    h5_path = pathlib.Path(file_name).expanduser().resolve().with_suffix(".h5")
    json_path = h5_path.with_suffix(".json")

    # This will be used for validation
    h5_size = os.stat(h5_path).st_size

    if not json_path.exists():
        invalid(f'JSON file "{json_path}" cannot be found, but is expected for this data format.')
        json_data = {}
    else:
        with json_path.open("r") as f:
            json_data = json.load(f)

    data_type_name = json_data.get("data_info", {}).get("data_type", "UnknownDataType")
    if data_type_name not in scri.DataNames:
        # Previously a cryptic `list.index` ValueError; report it like other validation issues
        invalid(f'Unrecognized `data_info/data_type` "{data_type_name}" in JSON file; using "UnknownDataType".')
        data_type_name = "UnknownDataType"
    dataType = scri.DataType[scri.DataNames.index(data_type_name)]

    # Make sure this is our format
    sxs_format = json_data.get("sxs_format", "")
    if sxs_format not in sxs_formats:
        invalid(
            f'The `sxs_format` found in JSON file is "{sxs_format}"; it should be one of\n'
            f" {sxs_formats}."
        )

    json_h5_file_size = json_data.get("validation", {}).get("h5_file_size", 0)
    if json_h5_file_size != h5_size:
        invalid(
            f"Mismatch between `validation/h5_file_size` key in JSON file ({json_h5_file_size}) "
            f'and observed file size ({h5_size}) of "{h5_path}".'
        )

    with h5py.File(h5_path, "r") as f:
        # Make sure this is our format (a missing attribute is reported, not a KeyError)
        sxs_format = f.attrs.get("sxs_format", "")
        if isinstance(sxs_format, bytes):
            # Some h5py versions/writers return fixed-length strings as bytes
            sxs_format = sxs_format.decode()
        if sxs_format not in sxs_formats:
            raise ValueError(
                f'The `sxs_format` found in H5 file is "{sxs_format}"; it should be one of\n'
                f" {sxs_formats}."
            )

        # Ensure that the 'validation' keys from the JSON file are the same as in this file.
        # `Dataset.size` gives the element count without reading the data.
        json_n_times = json_data.get("validation", {}).get("n_times", 0)
        n_times = f["time"].size
        if json_n_times != n_times:
            invalid(
                f"Number of time steps in H5 file ({n_times}) "
                f"does not match expected value from JSON ({json_n_times})."
            )
        for dataset, checksum in json_data.get("validation", {}).get("fletcher32", {}).items():
            if dataset not in f:
                invalid(f'Dataset "{dataset}" listed in JSON checksums is missing from the H5 file.')
                continue
            observed_checksum = fletcher32(f[dataset][:])
            if checksum != observed_checksum:
                invalid(f'Checksum of "{dataset}" dataset does not match expected value from JSON.')

        # Read the data: stored as raw uint64 bit patterns, reinterpreted here
        time = f["time"][:].view(np.float64)
        modes = f["modes"][:].view(complex)
        ell_min = f["modes"].attrs["ell_min"]
        ell_max = f["modes"].attrs["ell_max"]
        if "log_frame" in f:
            log_frame = f["log_frame"][:].view(np.float64)
        else:
            log_frame = np.empty((0, 3), dtype=np.float64)

    # Undo the XOR encoding of successive time steps (in place)
    xor_timeseries_reverse(time)
    xor_timeseries_reverse(modes)
    xor_timeseries_reverse(log_frame)

    # Rebuild full quaternions by re-inserting a zero scalar part, then exponentiate
    frame = np.exp(quaternion.as_quat_array(np.insert(log_frame, 0, 0.0, axis=1)))

    w = WaveformModes(
        t=time,
        frame=frame,
        data=modes,
        frameType=scri.Corotating,
        dataType=dataType,
        m_is_scaled_out=True,
        r_is_scaled_out=True,
        ell_min=ell_min,
        ell_max=ell_max,
    )
    w.convert_from_conjugate_pairs()
    w.json_data = json_data
    return w