@@ -135,35 +135,49 @@ def __init__(
135135 lon .size == kwargs [kwvar ].size
136136 ), f"{ kwvar } and positions (lon, lat, depth) don't have the same lengths."
137137
138- self ._data = {
139- "lon" : lon .astype (lonlatdepth_dtype ),
140- "lat" : lat .astype (lonlatdepth_dtype ),
141- "depth" : depth .astype (lonlatdepth_dtype ),
142- "time" : time ,
143- "dt" : np .timedelta64 (1 , "ns" ) * np .ones (len (trajectory_ids )),
144- # "ei": (["trajectory", "ngrid"], np.zeros((len(trajectory_ids), len(fieldset.gridset)), dtype=np.int32)),
145- "state" : np .zeros ((len (trajectory_ids )), dtype = np .int32 ),
146- "lon_nextloop" : lon .astype (lonlatdepth_dtype ),
147- "lat_nextloop" : lat .astype (lonlatdepth_dtype ),
148- "depth_nextloop" : depth .astype (lonlatdepth_dtype ),
149- "time_nextloop" : time ,
150- "trajectory" : trajectory_ids ,
151- }
152- self ._ptype = pclass .getPType ()
138+ self ._ds = xr .Dataset (
139+ {
140+ "lon" : (["trajectory" ], lon .astype (lonlatdepth_dtype )),
141+ "lat" : (["trajectory" ], lat .astype (lonlatdepth_dtype )),
142+ "depth" : (["trajectory" ], depth .astype (lonlatdepth_dtype )),
143+ "time" : (["trajectory" ], time ),
144+ "dt" : (["trajectory" ], np .timedelta64 (1 , "ns" ) * np .ones (len (trajectory_ids ))),
145+ "ei" : (["trajectory" , "ngrid" ], np .zeros ((len (trajectory_ids ), len (fieldset .gridset )), dtype = np .int32 )),
146+ "state" : (["trajectory" ], np .zeros ((len (trajectory_ids )), dtype = np .int32 )),
147+ "lon_nextloop" : (["trajectory" ], lon .astype (lonlatdepth_dtype )),
148+ "lat_nextloop" : (["trajectory" ], lat .astype (lonlatdepth_dtype )),
149+ "depth_nextloop" : (["trajectory" ], depth .astype (lonlatdepth_dtype )),
150+ "time_nextloop" : (["trajectory" ], time ),
151+ },
152+ coords = {
153+ "trajectory" : ("trajectory" , trajectory_ids ),
154+ },
155+ attrs = {
156+ "ngrid" : len (fieldset .gridset ),
157+ "ptype" : pclass .getPType (),
158+ },
159+ )
153160 # add extra fields from the custom Particle class
154161 for v in pclass .__dict__ .values ():
155162 if isinstance (v , Variable ):
156163 if isinstance (v .initial , attrgetter ):
157- initial = v .initial (self )
164+ initial = v .initial (self ). values
158165 else :
159166 initial = v .initial * np .ones (len (trajectory_ids ), dtype = v .dtype )
160- self ._data [v .name ] = initial
167+ self ._ds [v .name ] = ([ "trajectory" ], initial )
161168
162169 # update initial values provided on ParticleSet creation
163170 for kwvar , kwval in kwargs .items ():
164171 if not hasattr (pclass , kwvar ):
165172 raise RuntimeError (f"Particle class does not have Variable { kwvar } " )
166- self ._data [kwvar ][:] = kwval
173+ self ._ds [kwvar ][:] = kwval
174+
175+ # also keep a struct of numpy arrays for faster access (see parcels-benchmarks/pull/1)
176+ self ._data = {}
177+ for v in self ._ds .keys ():
178+ self ._data [v ] = self ._ds [v ].data
179+ self ._data ["trajectory" ] = self ._ds ["trajectory" ].data
180+ self ._ptype = self ._ds .attrs ["ptype" ]
167181
168182 self ._kernel = None
169183
0 commit comments