1515
1616logger = get_logger (__name__ )
1717
18- DTYPE_MAP = dict (bfloat16 = torch .bfloat16 , float16 = torch .float16 , float32 = torch .float32 , auto = "auto" )
19-
18+ DTYPE_MAP = dict (bfloat16 = torch .bfloat16 , float16 = torch .float16 , float32 = torch .float32 , int8 = torch .int8 , auto = "auto" )
2019
2120def main ():
2221 parser = argparse .ArgumentParser (description = "Load bloom layers and convert to 8-bit using torch quantization." )
2322
2423 parser .add_argument ("--model" , type = str , default = "bigscience/bloom-6b3" , help = "Model name for from_pretrained" )
2524 parser .add_argument ("--revision" , type = str , default = None , help = "Optional commit id from HF hub" )
26- parser .add_argument ("--torch_dtype" , type = str , default = "auto" , help = "Load initial model in this dtype" )
27- parser .add_argument ("--output_path" , type = str , default = "./converted_model" , help = "Track output repo to this folder" )
25+ parser .add_argument ("--torch_dtype" , type = str , choices = DTYPE_MAP .keys (), default = "auto" ,
26+ help = "Load initial model in this dtype" )
27+ parser .add_argument ("--output_path" , type = str , default = "./converted_model" ,
28+ help = "Track output repo to this folder" )
2829 parser .add_argument ("--output_repo" , type = str , default = "bigscience/test-bloomd" , help = "Push to this HF hub repo" )
2930 parser .add_argument ("--client_branch" , type = str , default = CLIENT_BRANCH , help = "Save client version to this branch" )
3031 parser .add_argument (
@@ -41,7 +42,6 @@ def main():
4142 if args .model == "bigscience/bloom" and free_ram_gb < 400 :
4243 logger .warning (f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have { free_ram_gb :.3f} free" )
4344
44- assert args .torch_dtype in DTYPE_MAP , f"torch_dtype must be one of { list (DTYPE_MAP .keys ())} "
4545 if os .path .exists (args .output_path ) and (
4646 len (os .listdir (args .output_path )) != 0 or not os .path .isdir (args .output_path )
4747 ):
@@ -54,8 +54,15 @@ def main():
5454 config .dht_prefix = args .output_repo
5555
5656 model = BloomModel .from_pretrained (
57- args .model , use_auth_token = args .use_auth_token , revision = args .revision , torch_dtype = DTYPE_MAP [args .torch_dtype ]
57+ args .model , use_auth_token = args .use_auth_token , revision = args .revision ,
58+ torch_dtype = DTYPE_MAP [args .torch_dtype ] if args .torch_dtype != "int8" else "float16" ,
59+ load_in_8bit = args .torch_dtype == "int8" ,
60+ device_map = {"word_embeddings" : "cuda" , "word_embeddings_layernorm" : "cuda" , "h" : "cuda" , "ln_f" : "cuda" }
5861 )
62+ if args .torch_dtype == "int8" :
63+ # trigger weight quantization
64+ model = model .cuda ()
65+
5966 if args .resize_token_embeddings :
6067 logger .info (f"Resizing token embeddings, new size = { args .resize_token_embeddings } " )
6168 model .resize_token_embeddings (args .resize_token_embeddings )
0 commit comments