5252 # Testing only
5353 '0x0056' ,
5454 '0x0062' ,
55+ # TPU 7x
56+ '0x0076'
5557]
5658
5759
@@ -188,7 +190,10 @@ def version() -> int:
188190 except requests .HTTPError as e :
189191 raise EnvironmentError ('Failed to get TPU metadata' ) from e
190192
191- match = re .match (r'^v(\d)([A-Za-z]?){7}-(\d+)$' , env [xenv .ACCELERATOR_TYPE ])
193+ match = re .match (r'^(?:v|tpu)(\d)([A-Za-z]?){7}-(\d+)$' ,
194+ env [xenv .ACCELERATOR_TYPE ])
195+ if not match :
196+ raise EnvironmentError ('Failed to parse TPU version from metadata' )
192197 return int (match .groups ()[0 ])
193198
194199
@@ -254,7 +259,8 @@ def configure_topology(local_rank: int,
254259 tpu_env = get_tpu_env ()
255260
256261 accelerator_type = tpu_env [xenv .ACCELERATOR_TYPE ]
257- if version () >= 4 :
262+ tpu_version = version ()
263+ if tpu_version >= 4 :
258264 # Process bounds with 4 chips per process
259265 default_process_bounds = MeshShape .from_string (
260266 tpu_env [xenv .TPU_PROCESS_BOUNDS ])
@@ -270,8 +276,11 @@ def configure_topology(local_rank: int,
270276 process_bounds = default_process_bounds * chips_per_process
271277
272278 os .environ .setdefault (xenv .TPU_CHIPS_PER_PROCESS_BOUNDS , '1,1,1' )
273- os .environ .setdefault (xenv .TPU_PROCESS_BOUNDS ,
274- ',' .join (str (dim ) for dim in process_bounds ))
279+ process_bounds_str = ',' .join (str (dim ) for dim in process_bounds )
280+ if tpu_version == 7 :
281+ process_bounds_str += ',2'
282+
283+ os .environ .setdefault (xenv .TPU_PROCESS_BOUNDS , process_bounds_str )
275284
276285 # Assume each TPU has the same number of local processes with the same ports
277286 worker_id = int (tpu_env [xenv .WORKER_ID ])
0 commit comments