
Commit ab0c765

Add Janus Pro Model (#1952)
1 parent 91ddff1 commit ab0c765

24 files changed: +4179, -89 lines

llm/inference/janus_pro/.gitignore

+139
@@ -0,0 +1,139 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site
*__pycache__*
*kernel_meta*
# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

llm/inference/janus_pro/generation.py

+121
@@ -0,0 +1,121 @@
import os

import PIL.Image
import mindspore
import mindspore as ms
import numpy as np
from mindnlp.core import ops
from mindnlp.transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
import mindspore.context as context

from mindnlp.configs import use_pyboost, set_pyboost

set_pyboost(False)
print('use_pyboost:', use_pyboost())
mindspore.set_context(
    mode=mindspore.PYNATIVE_MODE,
    # max_device_memory="15GB",
    pynative_synchronize=True,
    device_target="Ascend",
    # mode=mindspore.GRAPH_MODE,
    # jit_config={"jit_level": "O2"},
    ascend_config={"precision_mode": "allow_mix_precision"})
print(mindspore.get_context("mode"))

# specify the path to the model
model_path = "/home/HwHiAiUser/Janus-Pro-1B"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, ms_dtype=mindspore.float16
)
print('loaded processor and ckpt')


conversation = [
    {
        "role": "<|User|>",
        "content": "A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair",
        # "content": "sun under blue sky",
    },
    {"role": "<|Assistant|>", "content": ""},
]

sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
    conversations=conversation,
    sft_format=vl_chat_processor.sft_format,
    system_prompt="",
)
prompt = sft_format + vl_chat_processor.image_start_tag

from mindnlp.core import no_grad


# @torch.inference_mode()
with no_grad():
    def generate(
        mmgpt: MultiModalityCausalLM,
        vl_chat_processor: VLChatProcessor,
        prompt: str,
        temperature: float = 1,
        parallel_size: int = 1,  # 16,
        cfg_weight: float = 5,
        # image_token_num_per_image: int = 8,  # 576,
        image_token_num_per_image: int = 576,
        img_size: int = 384,
        patch_size: int = 16,
    ):
        input_ids = vl_chat_processor.tokenizer.encode(prompt)
        input_ids = ms.Tensor(input_ids, dtype=ms.int64)

        # Interleave conditional / unconditional rows for classifier-free guidance:
        # odd rows keep only the first and last tokens and pad the rest of the prompt.
        tokens = ops.zeros(parallel_size * 2, len(input_ids), dtype=ms.int32)
        for i in range(parallel_size * 2):
            tokens[i, :] = input_ids
            if i % 2 != 0:
                tokens[i, 1:-1] = vl_chat_processor.pad_id

        inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)  # (parallel_size*2, len(input_ids), 2048)

        generated_tokens = ops.zeros(parallel_size, image_token_num_per_image, dtype=ms.int32)

        for i in range(image_token_num_per_image):
            print(str(i) + '=' * 60)
            outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True,
                                                 past_key_values=outputs.past_key_values if i != 0 else None)
            hidden_states = outputs.last_hidden_state  # (parallel_size*2, len(input_ids), 2048)

            logits = mmgpt.gen_head(hidden_states[:, -1, :])  # feed the last position into gen_head => (parallel_size*2, vocab_size)
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]

            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = ops.softmax(logits / temperature, dim=-1)

            next_token = ops.multinomial(probs, num_samples=1)  # (parallel_size, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(axis=-1)

            next_token = ops.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)  # (parallel_size*2,)
            img_embeds = mmgpt.prepare_gen_img_embeds(next_token)  # (parallel_size*2, 2048)
            # print("img_embeds.shape:", img_embeds.shape)
            # print("img_embeds.dtype:", img_embeds.dtype)
            inputs_embeds = img_embeds.unsqueeze(dim=1)  # (parallel_size*2, 1, 2048)

        if image_token_num_per_image == 576:
            dec = mmgpt.gen_vision_model.decode_code(generated_tokens.astype(ms.int32),
                                                     shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size])
        else:
            # The vision decoder expects a full 24x24 = 576 token grid, so pad with the last generated token.
            pad_last_token = generated_tokens[:, -1].unsqueeze(dim=1).tile((1, 576 - image_token_num_per_image))
            cat_generated_tokens = ops.cat([generated_tokens, pad_last_token], dim=1)
            print("cat_generated_tokens.shape:", cat_generated_tokens.shape)  # (1, 576)
            dec = mmgpt.gen_vision_model.decode_code(cat_generated_tokens.astype(ms.int32),
                                                     shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size])
        dec = dec.astype(ms.float32).asnumpy().transpose(0, 2, 3, 1)

        dec = np.clip((dec + 1) / 2 * 255, 0, 255)

        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
        visual_img[:, :, :] = dec

        os.makedirs('generated_samples', exist_ok=True)
        for i in range(parallel_size):
            save_path = os.path.join('generated_samples', "img_{}.jpg".format(i))
            PIL.Image.fromarray(visual_img[i]).save(save_path)

    generate(
        vl_gpt,
        vl_chat_processor,
        prompt,
    )
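
The core of generate() is a classifier-free guidance loop: even batch rows carry the prompt-conditioned context, odd rows carry a pad-masked (unconditional) copy, and the two logit streams are blended before sampling each image token. The following self-contained NumPy sketch (not part of this commit; all sizes and values are illustrative assumptions) shows just that blend-and-sample step in isolation.

import numpy as np

# Toy illustration of the CFG step in generate(); shapes are assumptions.
parallel_size, vocab_size = 2, 8
cfg_weight, temperature = 5.0, 1.0

rng = np.random.default_rng(0)
logits = rng.normal(size=(parallel_size * 2, vocab_size))  # stand-in for gen_head output

logit_cond = logits[0::2, :]    # even rows: prompt-conditioned
logit_uncond = logits[1::2, :]  # odd rows: prompt replaced by pad_id

guided = logit_uncond + cfg_weight * (logit_cond - logit_uncond)  # same formula as in the script

probs = np.exp(guided / temperature)
probs /= probs.sum(axis=-1, keepdims=True)  # softmax with temperature
next_token = np.array([rng.choice(vocab_size, p=p) for p in probs])
print(next_token.shape)  # (parallel_size,) -- one image token per parallel sample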
+31
@@ -0,0 +1,31 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


# check if the Python version is 3.10 or above
import sys

if sys.version_info >= (3, 10):
    print("Python version is above 3.10, patching the collections module.")
    # Monkey patch collections so the ABC aliases removed in 3.10 are available again.
    import collections
    import collections.abc

    for type_name in collections.abc.__all__:
        setattr(collections, type_name, getattr(collections.abc, type_name))
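
The patch above exists because Python 3.10 removed the ABC aliases that used to live directly on collections (collections.Mapping, collections.Iterable, and so on), which older dependencies may still reference. A small stand-alone demonstration (not from this commit) of the failure and the fix:

import sys
import collections
import collections.abc

if sys.version_info >= (3, 10):
    # Before the patch, the old alias is gone on 3.10+ ...
    print(hasattr(collections, "Mapping"))  # False
    # ... re-creating the aliases, exactly as the file above does, restores it.
    for type_name in collections.abc.__all__:
        setattr(collections, type_name, getattr(collections.abc, type_name))
    print(collections.Mapping is collections.abc.Mapping)  # True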
@@ -0,0 +1,28 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from .image_processing_vlm import VLMImageProcessor
from .modeling_vlm import MultiModalityCausalLM
from .processing_vlm import VLChatProcessor

__all__ = [
    "VLMImageProcessor",
    "VLChatProcessor",
    "MultiModalityCausalLM",
]
