11#! /usr/bin/env python
22"""Basic Model Interface implementation for the 2D heat model."""
33
4+ from typing import Any
5+
46import numpy as np
57from bmipy import Bmi
8+ from numpy .typing import NDArray
69
710from .heat import Heat
811
@@ -14,20 +17,21 @@ class BmiHeat(Bmi):
1417 _input_var_names = ("plate_surface__temperature" ,)
1518 _output_var_names = ("plate_surface__temperature" ,)
1619
17- def __init__ (self ):
20+ def __init__ (self ) -> None :
1821 """Create a BmiHeat model that is ready for initialization."""
19- self ._model = None
20- self ._values = {}
21- self ._var_units = {}
22- self ._var_loc = {}
23- self ._grids = {}
24- self ._grid_type = {}
22+ # self._model: Heat | None = None
23+ self ._model : Heat
24+ self ._values : dict [str , NDArray [Any ]] = {}
25+ self ._var_units : dict [str , str ] = {}
26+ self ._var_loc : dict [str , str ] = {}
27+ self ._grids : dict [int , list [str ]] = {}
28+ self ._grid_type : dict [int , str ] = {}
2529
2630 self ._start_time = 0.0
27- self ._end_time = np .finfo ("d" ).max
31+ self ._end_time = float ( np .finfo ("d" ).max )
2832 self ._time_units = "s"
2933
30- def initialize (self , filename = None ):
34+ def initialize (self , filename : str | None = None ) -> None :
3135 """Initialize the Heat model.
3236
3337 Parameters
@@ -39,7 +43,7 @@ def initialize(self, filename=None):
3943 self ._model = Heat ()
4044 elif isinstance (filename , str ):
4145 with open (filename ) as file_obj :
42- self ._model = Heat .from_file_like (file_obj . read () )
46+ self ._model = Heat .from_file_like (file_obj )
4347 else :
4448 self ._model = Heat .from_file_like (filename )
4549
@@ -49,11 +53,11 @@ def initialize(self, filename=None):
4953 self ._grids = {0 : ["plate_surface__temperature" ]}
5054 self ._grid_type = {0 : "uniform_rectilinear" }
5155
52- def update (self ):
56+ def update (self ) -> None :
5357 """Advance model by one time step."""
5458 self ._model .advance_in_time ()
5559
56- def update_frac (self , time_frac ) :
60+ def update_frac (self , time_frac : float ) -> None :
5761 """Update model by a fraction of a time step.
5862
5963 Parameters
@@ -66,7 +70,7 @@ def update_frac(self, time_frac):
6670 self .update ()
6771 self ._model .time_step = time_step
6872
69- def update_until (self , then ) :
73+ def update_until (self , then : float ) -> None :
7074 """Update model until a particular time.
7175
7276 Parameters
@@ -80,11 +84,12 @@ def update_until(self, then):
8084 self .update ()
8185 self .update_frac (n_steps - int (n_steps ))
8286
83- def finalize (self ):
87+ def finalize (self ) -> None :
8488 """Finalize model."""
85- self ._model = None
89+ del self ._model
90+ # self._model = None
8691
87- def get_var_type (self , var_name ) :
92+ def get_var_type (self , var_name : str ) -> str :
8893 """Data type of variable.
8994
9095 Parameters
@@ -99,7 +104,7 @@ def get_var_type(self, var_name):
99104 """
100105 return str (self .get_value_ptr (var_name ).dtype )
101106
102- def get_var_units (self , var_name ) :
107+ def get_var_units (self , var_name : str ) -> str :
103108 """Get units of variable.
104109
105110 Parameters
@@ -114,7 +119,7 @@ def get_var_units(self, var_name):
114119 """
115120 return self ._var_units [var_name ]
116121
117- def get_var_nbytes (self , var_name ) :
122+ def get_var_nbytes (self , var_name : str ) -> int :
118123 """Get units of variable.
119124
120125 Parameters
@@ -129,13 +134,13 @@ def get_var_nbytes(self, var_name):
129134 """
130135 return self .get_value_ptr (var_name ).nbytes
131136
132- def get_var_itemsize (self , name ) :
137+ def get_var_itemsize (self , name : str ) -> int :
133138 return np .dtype (self .get_var_type (name )).itemsize
134139
135- def get_var_location (self , name ) :
140+ def get_var_location (self , name : str ) -> str :
136141 return self ._var_loc [name ]
137142
138- def get_var_grid (self , var_name ) :
143+ def get_var_grid (self , var_name : str ) -> int | None :
139144 """Grid id for a variable.
140145
141146 Parameters
@@ -151,8 +156,9 @@ def get_var_grid(self, var_name):
151156 for grid_id , var_name_list in self ._grids .items ():
152157 if var_name in var_name_list :
153158 return grid_id
159+ return None
154160
155- def get_grid_rank (self , grid_id ) :
161+ def get_grid_rank (self , grid_id : int ) -> int :
156162 """Rank of grid.
157163
158164 Parameters
@@ -167,7 +173,7 @@ def get_grid_rank(self, grid_id):
167173 """
168174 return len (self ._model .shape )
169175
170- def get_grid_size (self , grid_id ) :
176+ def get_grid_size (self , grid_id : int ) -> int :
171177 """Size of grid.
172178
173179 Parameters
@@ -182,7 +188,7 @@ def get_grid_size(self, grid_id):
182188 """
183189 return int (np .prod (self ._model .shape ))
184190
185- def get_value_ptr (self , var_name ) :
191+ def get_value_ptr (self , var_name : str ) -> NDArray [ Any ] :
186192 """Reference to values.
187193
188194 Parameters
@@ -197,7 +203,7 @@ def get_value_ptr(self, var_name):
197203 """
198204 return self ._values [var_name ]
199205
200- def get_value (self , var_name , dest ) :
206+ def get_value (self , var_name : str , dest : NDArray [ Any ]) -> NDArray [ Any ] :
201207 """Copy of values.
202208
203209 Parameters
@@ -215,7 +221,9 @@ def get_value(self, var_name, dest):
215221 dest [:] = self .get_value_ptr (var_name ).flatten ()
216222 return dest
217223
218- def get_value_at_indices (self , var_name , dest , indices ):
224+ def get_value_at_indices (
225+ self , var_name : str , dest : NDArray [Any ], indices : NDArray [np .int_ ]
226+ ) -> NDArray [Any ]:
219227 """Get values at particular indices.
220228
221229 Parameters
@@ -235,7 +243,7 @@ def get_value_at_indices(self, var_name, dest, indices):
235243 dest [:] = self .get_value_ptr (var_name ).take (indices )
236244 return dest
237245
238- def set_value (self , var_name , src ) :
246+ def set_value (self , var_name : str , src : NDArray [ Any ]) -> None :
239247 """Set model values.
240248
241249 Parameters
@@ -248,7 +256,9 @@ def set_value(self, var_name, src):
248256 val = self .get_value_ptr (var_name )
249257 val [:] = src .reshape (val .shape )
250258
251- def set_value_at_indices (self , name , inds , src ):
259+ def set_value_at_indices (
260+ self , name : str , inds : NDArray [np .int_ ], src : NDArray [Any ]
261+ ) -> None :
252262 """Set model values at particular indices.
253263
254264 Parameters
@@ -263,76 +273,80 @@ def set_value_at_indices(self, name, inds, src):
263273 val = self .get_value_ptr (name )
264274 val .flat [inds ] = src
265275
266- def get_component_name (self ):
276+ def get_component_name (self ) -> str :
267277 """Name of the component."""
268278 return self ._name
269279
270- def get_input_item_count (self ):
280+ def get_input_item_count (self ) -> int :
271281 """Get names of input variables."""
272282 return len (self ._input_var_names )
273283
274- def get_output_item_count (self ):
284+ def get_output_item_count (self ) -> int :
275285 """Get names of output variables."""
276286 return len (self ._output_var_names )
277287
278- def get_input_var_names (self ):
288+ def get_input_var_names (self ) -> tuple [ str , ...] :
279289 """Get names of input variables."""
280290 return self ._input_var_names
281291
282- def get_output_var_names (self ):
292+ def get_output_var_names (self ) -> tuple [ str , ...] :
283293 """Get names of output variables."""
284294 return self ._output_var_names
285295
286- def get_grid_shape (self , grid_id , shape ) :
296+ def get_grid_shape (self , grid_id : int , shape : NDArray [ np . int_ ]) -> NDArray [ np . int_ ] :
287297 """Number of rows and columns of uniform rectilinear grid."""
288298 var_name = self ._grids [grid_id ][0 ]
289299 shape [:] = self .get_value_ptr (var_name ).shape
290300 return shape
291301
292- def get_grid_spacing (self , grid_id , spacing ):
302+ def get_grid_spacing (
303+ self , grid_id : int , spacing : NDArray [np .float64 ]
304+ ) -> NDArray [np .float64 ]:
293305 """Spacing of rows and columns of uniform rectilinear grid."""
294306 spacing [:] = self ._model .spacing
295307 return spacing
296308
297- def get_grid_origin (self , grid_id , origin ):
309+ def get_grid_origin (
310+ self , grid_id : int , origin : NDArray [np .float64 ]
311+ ) -> NDArray [np .float64 ]:
298312 """Origin of uniform rectilinear grid."""
299313 origin [:] = self ._model .origin
300314 return origin
301315
302- def get_grid_type (self , grid_id ) :
316+ def get_grid_type (self , grid_id : int ) -> str :
303317 """Type of grid."""
304318 return self ._grid_type [grid_id ]
305319
306- def get_start_time (self ):
320+ def get_start_time (self ) -> float :
307321 """Start time of model."""
308322 return self ._start_time
309323
310- def get_end_time (self ):
324+ def get_end_time (self ) -> float :
311325 """End time of model."""
312326 return self ._end_time
313327
314- def get_current_time (self ):
328+ def get_current_time (self ) -> float :
315329 return self ._model .time
316330
317- def get_time_step (self ):
331+ def get_time_step (self ) -> float :
318332 return self ._model .time_step
319333
320- def get_time_units (self ):
334+ def get_time_units (self ) -> str :
321335 return self ._time_units
322336
323- def get_grid_edge_count (self , grid ) :
337+ def get_grid_edge_count (self , grid : int ) -> int :
324338 raise NotImplementedError ("get_grid_edge_count" )
325339
326- def get_grid_edge_nodes (self , grid , edge_nodes ) :
340+ def get_grid_edge_nodes (self , grid : int , edge_nodes : NDArray [ np . int_ ]) -> None :
327341 raise NotImplementedError ("get_grid_edge_nodes" )
328342
329- def get_grid_face_count (self , grid ) :
343+ def get_grid_face_count (self , grid : int ) -> None :
330344 raise NotImplementedError ("get_grid_face_count" )
331345
332- def get_grid_face_nodes (self , grid , face_nodes ) :
346+ def get_grid_face_nodes (self , grid : int , face_nodes : NDArray [ np . int_ ]) -> None :
333347 raise NotImplementedError ("get_grid_face_nodes" )
334348
335- def get_grid_node_count (self , grid ) :
349+ def get_grid_node_count (self , grid : int ) -> int :
336350 """Number of grid nodes.
337351
338352 Parameters
@@ -347,17 +361,19 @@ def get_grid_node_count(self, grid):
347361 """
348362 return self .get_grid_size (grid )
349363
350- def get_grid_nodes_per_face (self , grid , nodes_per_face ):
364+ def get_grid_nodes_per_face (
365+ self , grid : int , nodes_per_face : NDArray [np .int_ ]
366+ ) -> None :
351367 raise NotImplementedError ("get_grid_nodes_per_face" )
352368
353- def get_grid_face_edges (self , grid , face_edges ) :
369+ def get_grid_face_edges (self , grid : int , face_edges : NDArray [ np . int_ ]) -> None :
354370 raise NotImplementedError ("get_grid_face_edges" )
355371
356- def get_grid_x (self , grid , x ) :
372+ def get_grid_x (self , grid : int , x : NDArray [ np . float64 ]) -> None :
357373 raise NotImplementedError ("get_grid_x" )
358374
359- def get_grid_y (self , grid , y ) :
375+ def get_grid_y (self , grid : int , y : NDArray [ np . float64 ]) -> None :
360376 raise NotImplementedError ("get_grid_y" )
361377
362- def get_grid_z (self , grid , z ) :
378+ def get_grid_z (self , grid : int , z : NDArray [ np . float64 ]) -> None :
363379 raise NotImplementedError ("get_grid_z" )
0 commit comments