From 4f36f9ab0d273446a349729044fb753451e8faca Mon Sep 17 00:00:00 2001
From: zhuwq
Date: Tue, 22 Oct 2024 23:12:31 -0700
Subject: [PATCH] for quakeflow batch processing

---
 phasenet/data_reader.py |  64 +++++----------------
 phasenet/predict.py     | 119 +++++++++++++++++++---------------------
 2 files changed, 71 insertions(+), 112 deletions(-)

diff --git a/phasenet/data_reader.py b/phasenet/data_reader.py
index d5c55e0..f406b70 100755
--- a/phasenet/data_reader.py
+++ b/phasenet/data_reader.py
@@ -21,10 +21,10 @@ from tqdm import tqdm
 
 # token_json = f"{os.environ['HOME']}/.config/gcloud/application_default_credentials.json"
-token_json = "application_default_credentials.json"
-with open(token_json, "r") as fp:
-    token = json.load(fp)
-fs_gs = fsspec.filesystem("gs", token=token)
+# token_json = "application_default_credentials.json"
+# with open(token_json, "r") as fp:
+#     token = json.load(fp)
+# fs_gs = fsspec.filesystem("gs", token=token)
 
 # client = Client("SCEDC")
 client = Client("NCEDC")
 client_iris = Client("IRIS")  ## HardCode: IRIS for response file
@@ -196,7 +196,8 @@ def __init__(
         self.highpass_filter = highpass_filter
         # self.response_xml = response_xml
         if response_xml is not None:
-            self.response = obspy.read_inventory(response_xml)
+            # self.response = obspy.read_inventory(response_xml)
+            self.response = response_xml
         else:
             self.response = None
         self.sampling_rate = sampling_rate
@@ -367,50 +368,14 @@ def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=10
             stream = stream.merge(fill_value="latest")
 
             ## FIX: hard code for response file
-            ## NCEDC
-            station, network, channel = files[0].split("/")[-1].split(".")[:3]
-            response_xml = f"gs://quakeflow_catalog/NC/FDSNstationXML/{network}/{network}.{station}.xml"
-            # response_xml = (
-            #     f"gs://quakeflow_dataset/NC/FDSNstationXML/{network}.info/{network}.FDSN.xml/{network}.{station}.xml"
-            # )
-
-            ## SCEDC
-            # fname = files[0].split("/")[-1]
-            # network = fname[:2]
-            # station = fname[2:7].rstrip("_")
-            # instrument = fname[7:9]
-            # channel = fname[9]
-            # location = fname[10:12].rstrip("_")
-            # year = fname[13:17]
-            # jday = fname[17:20]
-            # response_xml = f"gs://quakeflow_catalog/SC/FDSNstationXML/{network}/{network}_{station}.xml"
-
-            redownload = True
-            if fs_gs.exists(response_xml):
-                try:
-                    with fs_gs.open(response_xml, "rb") as fp:
-                        response = obspy.read_inventory(fp)
-                    stream = stream.remove_sensitivity(response)
-                    redownload = False
-                except Exception as e:
-                    print(f"Error removing sensitivity: {e}")
-            else:
-                redownload = True
-            if redownload:
-                try:
-                    response = client.get_stations(network=network, station=station, level="response")
-                except Exception as e:
-                    print(f"Error downloading response: {e}")
-                    print(f"Retry downloading response from IRIS...")
-                    try:
-                        response = client_iris.get_stations(network=network, station=station, level="response")
-                    except Exception as e:
-                        print(f"Error downloading response from IRIS: {e}")
-                        raise
-                response.write(f"/tmp/{network}_{station}.xml", format="stationxml")
-                fs_gs.put(f"/tmp/{network}_{station}.xml", response_xml)
-                print(f"Update response file: {response_xml}")
+            station_id = files[0].split("/")[-1].replace(".mseed", "")[:-1]
+            response_xml = f"{response.rstrip('/')}/{station_id}.xml"
+            try:
+                with fsspec.open(response_xml, "rb") as fp:
+                    response = obspy.read_inventory(fp)
                 stream = stream.remove_sensitivity(response)
+            except Exception as e:
+                print(f"Error removing sensitivity: {e}")
 
         except Exception as e:
             print(f"Error reading {fname}:\n{e}")
@@ -539,7 +504,7 @@ def read_mseed_3c(self, fname, response=None, highpass_filter=0.0, sampling_rate
         if len(station_ids) > 1:
             print(f"{station_ids = }")
             raise
-        assert (len(station_ids) == 1, f"Error: {fname} has multiple stations {station_ids}")
+        assert len(station_ids) == 1, f"Error: {fname} has multiple stations {station_ids}"
 
         begin_time = min([st.stats.starttime for st in traces])
         end_time = max([st.stats.endtime for st in traces])
@@ -953,6 +918,7 @@ def __getitem__(self, i):
         # )
         meta = self.read_mseed(
             base_name,
+            response=self.response,
             sampling_rate=self.sampling_rate,
             highpass_filter=self.highpass_filter,
             return_single_station=True,
diff --git a/phasenet/predict.py b/phasenet/predict.py
index d39e6e8..cf61d69 100755
--- a/phasenet/predict.py
+++ b/phasenet/predict.py
@@ -29,10 +29,10 @@ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
 
 
 # token_json = f"{os.environ['HOME']}/.config/gcloud/application_default_credentials.json"
-token_json = "application_default_credentials.json"
-with open(token_json, "r") as fp:
-    token = json.load(fp)
-fs_gs = fsspec.filesystem("gs", token=token)
+# token_json = "application_default_credentials.json"
+# with open(token_json, "r") as fp:
+#     token = json.load(fp)
+# fs_gs = fsspec.filesystem("gs", token=token)
 
 
 def read_args():
@@ -151,28 +151,29 @@ def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
                 dt=1.0 / args.sampling_rate,
             )
 
-            picks.extend(picks_)
+            # picks.extend(picks_)
 
             # ## save pick per file
             if len(fname_batch) == 1:
                 # ### FIX: Hard code for NCEDC and SCEDC
-                tmp = fname_batch[0].decode().split(",")[0].lstrip("s3://").split("/")
-                parant_dir = "/".join(tmp[2:-1])  # remove s3://ncedc-pds/continuous and mseed file name
+                tmp = fname_batch[0].decode().split(",")[0].split("/")
+                subdir = "/".join(tmp[-1-3:-1])
                 fname = tmp[-1].rstrip("\n").rstrip(".mseed").rstrip(".ms") + ".csv"
-                csv_name = f"quakeflow_catalog/NC/phasenet/{parant_dir}/{fname}"
-                # csv_name = f"quakeflow_catalog/SC/phasenet/{parant_dir}/{fname}"
-                if not os.path.exists(os.path.join(args.result_dir, "picks", parant_dir)):
-                    os.makedirs(os.path.join(args.result_dir, "picks", parant_dir), exist_ok=True)
+                # csv_name = f"quakeflow_catalog/NC/phasenet/{subdir}/{fname}"
+                # csv_name = f"quakeflow_catalog/SC/phasenet/{subdir}/{fname}"
+                if not os.path.exists(os.path.join(args.result_dir, "picks", subdir)):
+                    os.makedirs(os.path.join(args.result_dir, "picks", subdir), exist_ok=True)
+                csv_file = os.path.join(args.result_dir, "picks", subdir, fname)
                 if len(picks_) == 0:
-                    with fs_gs.open(csv_name, "w") as fp:
+                    with open(csv_file, "w") as fp:
                         fp.write("")
                 else:
                     df = pd.DataFrame(picks_)
                     df = df[df["phase_index"] > 10]
                     if len(df) == 0:
-                        with fs_gs.open(csv_name, "w") as fp:
+                        with open(csv_file, "w") as fp:
                             fp.write("")
                     else:
                         df["phase_amplitude"] = df["phase_amplitude"].apply(lambda x: f"{x:.3e}")
@@ -189,24 +190,16 @@ def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
                             ]
                         ]
                         df.sort_values(by=["phase_time"], inplace=True)
-                        df.to_csv(
-                            os.path.join(
-                                args.result_dir,
-                                "picks",
-                                parant_dir,
-                                fname,
-                            ),
-                            index=False,
-                        )
-                        fs_gs.put(
-                            os.path.join(
-                                args.result_dir,
-                                "picks",
-                                parant_dir,
-                                fname,
-                            ),
-                            csv_name,
-                        )
+                        df.to_csv(csv_file, index=False)
+                        # fs_gs.put(
+                        #     os.path.join(
+                        #         args.result_dir,
+                        #         "picks",
+                        #         subdir,
+                        #         fname,
+                        #     ),
+                        #     csv_name,
+                        # )
 
             if args.plot_figure:
                 if not (isinstance(fname_batch, np.ndarray) or isinstance(fname_batch, list)):
@@ -230,38 +223,38 @@ def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
                 fname_batch = [x.decode() for x in fname_batch]
                 save_prob_h5(pred_batch, fname_batch, prob_h5)
 
-    if len(picks) > 0:
-        # save_picks(picks, args.result_dir, amps=amps, fname=args.result_fname+".csv")
-        # save_picks_json(picks, args.result_dir, dt=data_reader.dt, amps=amps, fname=args.result_fname+".json")
-        df = pd.DataFrame(picks)
-        # df["fname"] = df["file_name"]
-        # df["id"] = df["station_id"]
-        # df["timestamp"] = df["phase_time"]
-        # df["prob"] = df["phase_prob"]
-        # df["type"] = df["phase_type"]
-
-        base_columns = [
-            "station_id",
-            "begin_time",
-            "phase_index",
-            "phase_time",
-            "phase_score",
-            "phase_type",
-            "file_name",
-        ]
-        if args.amplitude:
-            base_columns.append("phase_amplitude")
-            base_columns.append("phase_amp")
-            df["phase_amp"] = df["phase_amplitude"]
-
-        df = df[base_columns]
-        df.to_csv(os.path.join(args.result_dir, args.result_fname + ".csv"), index=False)
-
-        print(
-            f"Done with {len(df[df['phase_type'] == 'P'])} P-picks and {len(df[df['phase_type'] == 'S'])} S-picks"
-        )
-    else:
-        print(f"Done with 0 P-picks and 0 S-picks")
+    # if len(picks) > 0:
+    #     # save_picks(picks, args.result_dir, amps=amps, fname=args.result_fname+".csv")
+    #     # save_picks_json(picks, args.result_dir, dt=data_reader.dt, amps=amps, fname=args.result_fname+".json")
+    #     df = pd.DataFrame(picks)
+    #     # df["fname"] = df["file_name"]
+    #     # df["id"] = df["station_id"]
+    #     # df["timestamp"] = df["phase_time"]
+    #     # df["prob"] = df["phase_prob"]
+    #     # df["type"] = df["phase_type"]
+
+    #     base_columns = [
+    #         "station_id",
+    #         "begin_time",
+    #         "phase_index",
+    #         "phase_time",
+    #         "phase_score",
+    #         "phase_type",
+    #         "file_name",
+    #     ]
+    #     if args.amplitude:
+    #         base_columns.append("phase_amplitude")
+    #         base_columns.append("phase_amp")
+    #         df["phase_amp"] = df["phase_amplitude"]
+
+    #     df = df[base_columns]
+    #     df.to_csv(os.path.join(args.result_dir, args.result_fname + ".csv"), index=False)
+
+    #     print(
+    #         f"Done with {len(df[df['phase_type'] == 'P'])} P-picks and {len(df[df['phase_type'] == 'S'])} S-picks"
+    #     )
+    # else:
+    #     print(f"Done with 0 P-picks and 0 S-picks")
 
 
     return 0
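
Below is a minimal sketch, not part of the patch, that restates the new response-file lookup from read_mseed as a standalone helper. It assumes the response value handed to the data reader is a local directory or fsspec-style URL prefix holding one StationXML file per station id; the helper name, example file name, and prefix are hypothetical.

    import fsspec
    import obspy

    def resolve_response(mseed_path: str, response_prefix: str):
        # Mirrors the patched logic: strip ".mseed" and the trailing component code
        # to get the station id, then read "{prefix}/{station_id}.xml" through fsspec.
        station_id = mseed_path.split("/")[-1].replace(".mseed", "")[:-1]
        response_xml = f"{response_prefix.rstrip('/')}/{station_id}.xml"
        with fsspec.open(response_xml, "rb") as fp:
            return obspy.read_inventory(fp)

    # Hypothetical usage: "NC.MCB..HHZ.mseed" resolves to "<prefix>/NC.MCB..HH.xml".
    # inventory = resolve_response("NC.MCB..HHZ.mseed", "./FDSNstationXML/NC")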