Skip to content

Commit 9a6d19f

Browse files
committed
It shows a user-friendly UI when generating token.
The update avoids producing too many multi-line outputs when the program generates 576 tokens, thus avoiding useful printed information being squeezed out of the buffer displayed on the screen. so, It shows a relatively friendly User Interface when generating token. Signed-off-by: manjucc <[email protected]>
1 parent 1294f5e commit 9a6d19f

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

llm/inference/janus_pro/generation.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import os
23
import PIL.Image
34
import mindspore
@@ -77,8 +78,8 @@ def generate(
7778

7879
generated_tokens = ops.zeros(parallel_size, image_token_num_per_image, dtype=ms.int32)
7980

81+
print("Generating tokens: ")
8082
for i in range(image_token_num_per_image):
81-
print(f"generating token {i}")
8283
outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
8384
hidden_states = outputs.last_hidden_state # (parallel_size*2, len(input_ids), 2048)
8485

@@ -97,7 +98,8 @@ def generate(
9798
# print("img_embeds.shape:", img_embeds.shape)
9899
# print("img_embeds.dtype:", img_embeds.dtype)
99100
inputs_embeds = img_embeds.unsqueeze(dim=1) #(parallel_size*2, 2048)
100-
print("generated one token")
101+
sys.stdout.write('.'); sys.stdout.flush()
102+
print(f"Generated {i+1} tokens.\n")
101103

102104
if image_token_num_per_image==576:
103105
dec = mmgpt.gen_vision_model.decode_code(generated_tokens.astype(ms.int32), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size])
@@ -121,4 +123,4 @@ def generate(
121123
vl_gpt,
122124
vl_chat_processor,
123125
prompt,
124-
)
126+
)

0 commit comments

Comments
 (0)