Particles splits script #69

Open · wants to merge 3 commits into base: master
@@ -7,6 +7,7 @@
 
 import os
 import re
+import warnings
 os.environ["OPENBLAS_NUM_THREADS"] = "1"
 
 import h5py
@@ -68,7 +69,7 @@ def load_hbt_config(config_path):
 
     return config
 
-def load_overflow_data(path_to_split_log_files):
+def load_overflow_data(path_to_split_log_files, tree_size):
     '''
     Loads particle splitting information for particles which
     split enough times to overflow the SplitTrees dataset.
@@ -90,11 +91,17 @@ def load_overflow_data(path_to_split_log_files):
 
     overflow_data = {}
     for filename in os.listdir(path_to_split_log_files):
-        if not re.match(r'^splits_\d{4}\.hdf5', filename):
+        if not re.match(r'^splits_\d{4}\.txt', filename):
             continue
-        file_data = np.loadtxt(f'{path_to_split_log_files}/{filename}', dtype=np.int64)
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore')
+            file_data = np.loadtxt(f'{path_to_split_log_files}/{filename}', dtype=np.int64)
         for row in file_data:
             _, new_prog_id, old_prog_id, count, tree = row
+            # Handle negative numbers
+            tree = int(tree)
+            if tree < 0:
+                tree = (1 << tree_size) + tree
             overflow_data[(count, new_prog_id)] = {
                 'progenitor_id': old_prog_id,
                 'tree': tree,
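
Note on the sign handling added above: the split log stores each tree as a fixed-width bit field, but np.loadtxt reads it as a signed integer, so any tree with the top bit set comes back negative. Adding 1 << tree_size reinterprets the value as unsigned. A minimal standalone sketch of the arithmetic, not part of this diff (tree_size = 8 is chosen only for readability):

tree_size = 8    # pretend the on-disk field is 8 bits wide
tree = -127      # bit pattern 0b10000001 read back as a signed 8-bit value

# Same correction as in the diff: map the negative signed value back to
# the unsigned bit pattern it represents.
if tree < 0:
    tree = (1 << tree_size) + tree  # 256 + (-127) = 129

assert tree == 0b10000001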
@@ -104,7 +111,7 @@ def generate_path_to_snapshot(config, snapshot_index):
 
 def generate_path_to_snapshot(config, snapshot_index):
     '''
-    Returns the path to the virtual file of a snapshot to analyse.
+    Returns the path of the snapshot files to analyse.
 
     Parameters
     ----------
@@ -118,7 +125,7 @@ def generate_path_to_snapshot(config, snapshot_index):
     Returns
     -------
     str
-        Path to the virtual file of the snapshot.
+        Path to the snapshot files.
     '''
     if 'SnapshotDirBase' in config:
         subdirectory = f"{config['SnapshotDirBase']}_{config['SnapshotIdList'][snapshot_index]:04d}"
@@ -245,14 +252,15 @@ def get_corrected_split_trees(split_data, overflow_data):
     # Create copy of arrays to be updated
     progenitor_ids = split_data["progenitor_ids"].copy()
     # Set datatype as object so we can have arbitrarily large values
+    tree_size = split_data["trees"].itemsize * 8
     trees = split_data["trees"].astype('object')
+    trees[trees < 0] = (1 << tree_size) + trees[trees < 0]
 
     # Return if the rank has no data
     if (split_data["counts"].shape[0] == 0) or (overflow_data is None):
         return progenitor_ids, trees
 
     # Calculate the maximum number of times a particle has overflowed
-    tree_size = split_data["trees"].itemsize * 8
     n_overflow = np.max((split_data["counts"] - 1) // tree_size)
 
     while n_overflow > 0:
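
One detail worth noting in the hunk above: tree_size is now computed from the original fixed-width array before the astype('object') call. Python ints stored in an object array can grow arbitrarily large, so itemsize on the converted array would report the pointer size rather than the width of the on-disk integer type. A standalone sketch of the idea, assuming int64 input as in the script:

import numpy as np

trees = np.array([3, -5], dtype=np.int64)

# Width in bits of the fixed-size integer type: 8 bytes * 8 = 64.
tree_size = trees.itemsize * 8

# Object arrays hold Python ints, which never overflow; map negative
# (signed) values back to the unsigned bit patterns they represent.
trees = trees.astype('object')
trees[trees < 0] = (1 << tree_size) + trees[trees < 0]

print(tree_size)  # 64
print(trees)      # [3 18446744073709551611]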
@@ -300,7 +308,9 @@ def update_overflow_split_trees(split_data, overflow_data=None):
         split_data['trees'] = split_data['trees'].astype('object')
         return split_data
     print(f'{np.sum(invalid_trees)} particles with invalid trees skipped on rank {comm_rank}')
-    split_data['trees'] = split_data['trees'][~invalid_trees].astype('object')
+    trees = split_data['trees'][~invalid_trees].astype('object')
+    trees[trees < 0] = (1 << tree_size) + trees[trees < 0]
+    split_data['trees'] = trees
     split_data['counts'] = split_data['counts'][~invalid_trees]
     split_data['progenitor_ids'] = split_data['progenitor_ids'][~invalid_trees]
     split_data['particle_ids'] = split_data['particle_ids'][~invalid_trees]
@@ -309,6 +319,8 @@ def update_overflow_split_trees(split_data, overflow_data=None):
     split_data['trees'] = trees
     split_data['progenitor_ids'] = progenitor_ids
 
+    assert np.min(split_data['trees']) >= 0
+
     return split_data
 
 def get_splits_of_existing_tree(progenitor_particle_ids, progenitor_split_trees, progenitor_split_counts, descendant_particle_ids, descendant_split_trees):
@@ -403,26 +415,10 @@ def get_descendant_particle_ids(old_snapshot_data, new_snapshot_data):
 
         # If we have a new tree, all new particle IDs have as their progenitor the
        # particle ID that originated this unique tree.
-        # NOTE: Disabled because SWIFT runs had incorrect ParticleProgenitorIDs
-        # progenitor_id_old = tree_progenitor_ID
+        progenitor_id = tree_progenitor_ID
+
         # Remove the progenitor particle from the list of IDs to prevent infinite loop
         new_ids = new_snapshot_data['particle_ids'][tree_index]
-
-        # We get the ID that has all 0s in its split tree (it retained the ID of
-        # the original particle)
-        progenitor_id = new_ids[new_snapshot_data["trees"][tree_index] == 0]
-
-        # We should have either 1 or 0 progenitor ids.
-        assert(len(progenitor_id) < 2)
-
-        # Print out a warning if no progenitor ID was found... We cannot do much more
-        # than that.
-        if len(progenitor_id) == 0:
-            local_no_progenitors_found += 1
-            continue
-
-        progenitor_id = progenitor_id[0]
-
         new_ids = new_ids[new_ids != progenitor_id]
 
         # We could encounter cases where a particle has split and its descendants
@@ -441,10 +437,6 @@ def get_descendant_particle_ids(old_snapshot_data, new_snapshot_data):
                 old_snapshot_data['counts'][tree_index_old],
                 new_snapshot_data['particle_ids'][tree_index],
                 new_snapshot_data['trees'][tree_index]))
-    global_no_progenitors_found = comm.allreduce(local_no_progenitors_found)
-    if global_no_progenitors_found > 0:
-        if comm_rank == 0:
-            print(f"We could not find progenitors for {global_no_progenitors_found} new split trees.")
 
     return new_splits
 
@@ -644,7 +636,10 @@ def generate_split_file(path_to_config, snapshot_index, path_to_split_log_files)
     #==========================================================================
     overflow_data = None
     if comm_rank == 0:
-        overflow_data = load_overflow_data(path_to_split_log_files)
+        ref_snapshot_path = generate_path_to_snapshot(config, snapshot_index)
+        with h5py.File(ref_snapshot_path.format(file_nr=0), 'r') as file:
+            tree_size = file['PartType0/SplitTrees'].dtype.itemsize * 8
+        overflow_data = load_overflow_data(path_to_split_log_files, tree_size)
     overflow_data = comm.bcast(overflow_data, root=0)
 
     #==========================================================================
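
The replacement above reads the width of the SplitTrees datatype from the snapshot itself instead of assuming 64 bits, so the overflow correction stays consistent with whatever integer type the simulation wrote. A standalone sketch of the same probe against a hypothetical file (the name example.hdf5 and the single-file layout are assumptions for illustration):

import h5py
import numpy as np

# Create a hypothetical snapshot file with the dataset the diff expects.
with h5py.File('example.hdf5', 'w') as file:
    file.create_dataset('PartType0/SplitTrees', data=np.zeros(4, dtype=np.int64))

with h5py.File('example.hdf5', 'r') as file:
    # dtype.itemsize is in bytes, so multiply by 8 to get bits.
    tree_size = file['PartType0/SplitTrees'].dtype.itemsize * 8

print(tree_size)  # 64 for an int64 dataset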
@@ -725,6 +720,9 @@ def generate_split_file(path_to_config, snapshot_index, path_to_split_log_files)
 
 if __name__ == "__main__":
 
+    if comm_rank == 0:
+        print(f"Running generate_particle_splitting_information.py with {comm_size} ranks")
+
     from virgo.mpi.util import MPIArgumentParser
 
     parser = MPIArgumentParser(comm, description="Generate HDF5 files that contain information about which particles split between consecutive outputs analysed by HBT+.")