@@ -832,7 +832,9 @@ def get_params_dict(params: Union[str, List[str]]) -> dict:
832
832
"""
833
833
params_list = get_params_list (params ) if isinstance (params , str ) else params
834
834
return {
835
- split_result [0 ]: split_result [1 ] if len (split_result ) > 1 else UNKNOWN
835
+ split_result [0 ]: " " .join (split_result [1 :])
836
+ if len (split_result ) > 1
837
+ else UNKNOWN
836
838
for split_result in (x .split () for x in params_list )
837
839
}
838
840
@@ -881,7 +883,9 @@ def build_params_string(params: dict) -> str:
881
883
A params string.
882
884
"""
883
885
return (
884
- " " .join (f"{ name } { value } " for name , value in params .items ()).strip ()
886
+ " " .join (
887
+ f"{ name } { value } " if value else f"{ name } " for name , value in params .items ()
888
+ ).strip ()
885
889
if params
886
890
else UNKNOWN
887
891
)
@@ -1158,9 +1162,11 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
1158
1162
1159
1163
1160
1164
def build_pydantic_error_message (ex : ValidationError ):
1161
- """Added to handle error messages from pydantic model validator.
1165
+ """
1166
+ Added to handle error messages from pydantic model validator.
1162
1167
Combine both loc and msg for errors where loc (field) is present in error details, else only build error
1163
- message using msg field."""
1168
+ message using msg field.
1169
+ """
1164
1170
1165
1171
return {
1166
1172
"." .join (map (str , e ["loc" ])): e ["msg" ]
@@ -1185,67 +1191,71 @@ def is_pydantic_model(obj: object) -> bool:
1185
1191
1186
1192
@cached (cache = TTLCache (maxsize = 1 , ttl = timedelta (minutes = 5 ), timer = datetime .now ))
1187
1193
def load_gpu_shapes_index (
1188
- auth : Optional [Dict ] = None ,
1194
+ auth : Optional [Dict [ str , Any ] ] = None ,
1189
1195
) -> GPUShapesIndex :
1190
1196
"""
1191
- Loads the GPU shapes index from Object Storage or a local resource folder .
1197
+ Load the GPU shapes index, preferring the OS bucket copy over the local one .
1192
1198
1193
- The function first attempts to load the file from an Object Storage bucket using fsspec.
1194
- If the loading fails (due to connection issues, missing file, etc.), it falls back to
1195
- loading the index from a local file.
1199
+ Attempts to read `gpu_shapes_index.json` from OCI Object Storage first;
1200
+ if that succeeds, those entries will override the local defaults.
1196
1201
1197
1202
Parameters
1198
1203
----------
1199
- auth: (Dict, optional). Defaults to None.
1200
- The default authentication is set using `ads.set_auth` API. If you need to override the
1201
- default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
1202
- authentication signer and kwargs required to instantiate IdentityClient object.
1204
+ auth
1205
+ Optional auth dict (as returned by `ads.common.auth.default_signer()`)
1206
+ to pass through to `fsspec.open()`.
1203
1207
1204
1208
Returns
1205
1209
-------
1206
- GPUShapesIndex: The parsed GPU shapes index.
1210
+ GPUShapesIndex
1211
+ Merged index where any shape present remotely supersedes the local entry.
1207
1212
1208
1213
Raises
1209
1214
------
1210
- FileNotFoundError: If the GPU shapes index cannot be found in either Object Storage or locally.
1211
- json.JSONDecodeError: If the JSON is malformed.
1215
+ json.JSONDecodeError
1216
+ If any of the JSON is malformed.
1212
1217
"""
1213
1218
file_name = "gpu_shapes_index.json"
1214
- data : Dict [str , Any ] = {}
1215
1219
1216
- # Check if the CONDA_BUCKET_NS environment variable is set.
1220
+ # Try remote load
1221
+ remote_data : Dict [str , Any ] = {}
1217
1222
if CONDA_BUCKET_NS :
1218
1223
try :
1219
1224
auth = auth or authutil .default_signer ()
1220
- # Construct the object storage path. Adjust bucket name and path as needed.
1221
1225
storage_path = (
1222
1226
f"oci://{ CONDA_BUCKET_NAME } @{ CONDA_BUCKET_NS } /service_pack/{ file_name } "
1223
1227
)
1224
- logger .debug ("Loading GPU shapes index from Object Storage" )
1225
- with fsspec .open (storage_path , mode = "r" , ** auth ) as file_obj :
1226
- data = json .load (file_obj )
1227
- logger .debug ("Successfully loaded GPU shapes index." )
1228
- except Exception as ex :
1229
1228
logger .debug (
1230
- f"Failed to load GPU shapes index from Object Storage. Details: { ex } "
1231
- )
1232
-
1233
- # If loading from Object Storage failed, load from the local resource folder.
1234
- if not data :
1235
- try :
1236
- local_path = os .path .join (
1237
- os .path .dirname (__file__ ), "../resources" , file_name
1229
+ "Loading GPU shapes index from Object Storage: %s" , storage_path
1238
1230
)
1239
- logger .debug (f"Loading GPU shapes index from { local_path } ." )
1240
- with open (local_path ) as file_obj :
1241
- data = json .load (file_obj )
1242
- logger .debug ("Successfully loaded GPU shapes index." )
1243
- except Exception as e :
1231
+ with fsspec .open (storage_path , mode = "r" , ** auth ) as f :
1232
+ remote_data = json .load (f )
1244
1233
logger .debug (
1245
- f"Failed to load GPU shapes index from { local_path } . Details: { e } "
1234
+ "Loaded %d shapes from Object Storage" ,
1235
+ len (remote_data .get ("shapes" , {})),
1246
1236
)
1237
+ except Exception as ex :
1238
+ logger .debug ("Remote load failed (%s); falling back to local" , ex )
1239
+
1240
+ # Load local copy
1241
+ local_data : Dict [str , Any ] = {}
1242
+ local_path = os .path .join (os .path .dirname (__file__ ), "../resources" , file_name )
1243
+ try :
1244
+ logger .debug ("Loading GPU shapes index from local file: %s" , local_path )
1245
+ with open (local_path ) as f :
1246
+ local_data = json .load (f )
1247
+ logger .debug (
1248
+ "Loaded %d shapes from local file" , len (local_data .get ("shapes" , {}))
1249
+ )
1250
+ except Exception as ex :
1251
+ logger .debug ("Local load GPU shapes index failed (%s)" , ex )
1252
+
1253
+ # Merge: remote shapes override local
1254
+ local_shapes = local_data .get ("shapes" , {})
1255
+ remote_shapes = remote_data .get ("shapes" , {})
1256
+ merged_shapes = {** local_shapes , ** remote_shapes }
1247
1257
1248
- return GPUShapesIndex (** data )
1258
+ return GPUShapesIndex (shapes = merged_shapes )
1249
1259
1250
1260
1251
1261
def get_preferred_compatible_family (selected_families : set [str ]) -> str :
0 commit comments