@@ -344,6 +344,55 @@ def __init__(
344344 format , config = quantization_config
345345 )
346346
347+ def get_missing_module_keys (self , model : Module ) -> List [str ]:
348+ """
349+ Identifies the expected missing weight keys in the compressed state_dict.
350+
351+ When a model undergoes sparsity or quantization compression, certain
352+ weight tensors may be absent from the checkpoint by virtue of compression.
353+ This function determines which weight keys are missing based on the
354+ applied compression techniques.
355+
356+ :param model: The PyTorch model to check for missing keys.
357+ :return: A list of missing keys expected in the compressed state_dict.
358+ """
359+ missing_keys = set ()
360+
361+ # Determine missing keys due to sparsity compression
362+ if (
363+ self .sparsity_compressor
364+ and self .sparsity_config .format != CompressionFormat .dense .value
365+ ):
366+ sparse_targets = match_named_modules (
367+ model = model ,
368+ targets = self .sparsity_config .targets ,
369+ ignore = self .sparsity_config .ignore ,
370+ )
371+
372+ missing_keys .update (
373+ merge_names (target_name , "weight" )
374+ for target_name , _module in sparse_targets
375+ )
376+
377+ # Determine missing keys due to pack quantization
378+ if (
379+ self .quantization_compressor
380+ and self .quantization_config .format
381+ == CompressionFormat .pack_quantized .value
382+ ):
383+ for scheme in self .quantization_config .config_groups .values ():
384+ quant_targets = match_named_modules (
385+ model = model ,
386+ targets = scheme .targets ,
387+ ignore = self .quantization_config .ignore ,
388+ )
389+ missing_keys .update (
390+ merge_names (target_name , "weight" )
391+ for target_name , _module in quant_targets
392+ )
393+
394+ return list (missing_keys )
395+
347396 def get_unexpected_file_keys (self , model : Module ) -> List [str ]:
348397 """
349398 Identifies extra keys introduced by the compression process in the
0 commit comments