import os
import PIL.Image
import mindspore
import mindspore as ms
import numpy as np
from mindnlp.core import ops
from mindnlp.transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor

from mindnlp.configs import use_pyboost, set_pyboost

# disable mindnlp's pyboost kernel path for this run
set_pyboost(False)
print('use_pyboost:', use_pyboost())
mindspore.set_context(
    mode=mindspore.PYNATIVE_MODE,
    # max_device_memory="15GB",
    pynative_synchronize=True,  # synchronous execution so Ascend errors surface at the failing op
    device_target="Ascend",
    # mode=mindspore.GRAPH_MODE,
    # jit_config={"jit_level":"O2"},
    ascend_config={"precision_mode": "allow_mix_precision"})
print(mindspore.get_context("mode"))
# specify the path to the model
model_path = "/home/HwHiAiUser/Janus-Pro-1B"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, ms_dtype=mindspore.float16
)
print('loaded processor and ckpt')


conversation = [
    {
        "role": "<|User|>",
        "content": "A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair",
        # "content": "sun under blue sky",
    },
    {"role": "<|Assistant|>", "content": ""},
]

sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
    conversations=conversation,
    sft_format=vl_chat_processor.sft_format,
    system_prompt="",
)
prompt = sft_format + vl_chat_processor.image_start_tag
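# The prompt ends with the processor's image_start_tag, signalling the model to
# begin emitting image tokens instead of text after the conversation prefix.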
from mindnlp.core import no_grad

# @torch.inference_mode() in the PyTorch original; no_grad plays the same role here
with no_grad():
    def generate(
        mmgpt: MultiModalityCausalLM,
        vl_chat_processor: VLChatProcessor,
        prompt: str,
        temperature: float = 1,
        parallel_size: int = 1,  # 16,
        cfg_weight: float = 5,
        # image_token_num_per_image: int = 8,  # 576,
        image_token_num_per_image: int = 576,
        img_size: int = 384,
        patch_size: int = 16,
    ):
        input_ids = vl_chat_processor.tokenizer.encode(prompt)
        input_ids = ms.Tensor(input_ids, dtype=ms.int64)

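        # Build two token streams per sample for classifier-free guidance:
        # even rows keep the full prompt (conditional), odd rows keep only the
        # first and last tokens and pad the rest (unconditional).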
        tokens = ops.zeros(parallel_size * 2, len(input_ids), dtype=ms.int32)
        for i in range(parallel_size * 2):
            tokens[i, :] = input_ids
            if i % 2 != 0:
                tokens[i, 1:-1] = vl_chat_processor.pad_id

        inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)  # (parallel_size*2, len(input_ids), 2048)

        generated_tokens = ops.zeros(parallel_size, image_token_num_per_image, dtype=ms.int32)

        # Autoregressive decoding: one image token per step, reusing the KV cache.
        for i in range(image_token_num_per_image):
            print(str(i) + '=' * 60)
            outputs = mmgpt.language_model.model(
                inputs_embeds=inputs_embeds,
                use_cache=True,
                past_key_values=outputs.past_key_values if i != 0 else None,  # no cache on the first step
            )
            hidden_states = outputs.last_hidden_state  # (parallel_size*2, seq_len, 2048)

            logits = mmgpt.gen_head(hidden_states[:, -1, :])  # last position only => (parallel_size*2, vocab_size)
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]

            # classifier-free guidance: push the conditional logits away from the unconditional ones
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = ops.softmax(logits / temperature, dim=-1)

            next_token = ops.multinomial(probs, num_samples=1)  # (parallel_size, 1)
            generated_tokens[:, i] = next_token.squeeze(axis=-1)

            # duplicate each sampled token for the conditional and unconditional streams
            next_token = ops.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)  # (parallel_size*2,)
            img_embeds = mmgpt.prepare_gen_img_embeds(next_token)  # (parallel_size*2, 2048)
            # print("img_embeds.shape:", img_embeds.shape)
            # print("img_embeds.dtype:", img_embeds.dtype)
            inputs_embeds = img_embeds.unsqueeze(dim=1)  # (parallel_size*2, 1, 2048)

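        # The sampled IDs are VQ codebook indices; decode_code reshapes them into
        # an (img_size // patch_size) x (img_size // patch_size) = 24x24 latent grid
        # (576 positions, 8 latent channels) and runs the VQ decoder to get pixels.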
        if image_token_num_per_image == 576:
            dec = mmgpt.gen_vision_model.decode_code(
                generated_tokens.astype(ms.int32),
                shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size])
        else:
            # fewer than 576 tokens were generated: tile the last token to fill the 24x24 grid
            pad_last_token = generated_tokens[:, -1].unsqueeze(dim=1).tile((1, 576 - image_token_num_per_image))
            cat_generated_tokens = ops.cat([generated_tokens, pad_last_token], dim=1)
            print("cat_generated_tokens.shape:", cat_generated_tokens.shape)  # (1, 576)
            dec = mmgpt.gen_vision_model.decode_code(
                cat_generated_tokens.astype(ms.int32),
                shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size])
        dec = dec.astype(ms.float32).asnumpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC

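        # The decoder outputs values in [-1, 1]; rescale them to [0, 255] for uint8 images.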
        dec = np.clip((dec + 1) / 2 * 255, 0, 255)

        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
        visual_img[:, :, :] = dec

        os.makedirs('generated_samples', exist_ok=True)
        for i in range(parallel_size):
            save_path = os.path.join('generated_samples', "img_{}.jpg".format(i))
            PIL.Image.fromarray(visual_img[i]).save(save_path)

    generate(
        vl_gpt,
        vl_chat_processor,
        prompt,
    )
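# With the defaults above (parallel_size=1, img_size=384), this writes one
# 384x384 image to generated_samples/img_0.jpg.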