Skip to content

Commit fd860f0

Browse files
authored
cast to fp16 and change file path (#1955)
1 parent ab0c765 commit fd860f0

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

llm/inference/janus_pro/understanding.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
)
2929
print('loaded processor and ckpt ')
3030
question = 'describe this image'
31-
image = "/home/HwHiAiUser/janus-pro-mindspore/inpain_model_cat.png"
31+
image = "./inpain_model_cat.png"
3232
conversation = [
3333
{
3434
"role": "<|User|>",

mindnlp/core/ops/array.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ def argwhere(input):
2121
# cat
2222
has_cat = hasattr(mindspore.mint, 'cat')
2323
def cat(tensors, dim=0):
24-
if use_pyboost() and has_cat:
25-
return mindspore.mint.cat(tensors, dim)
24+
# if use_pyboost() and has_cat:
25+
# return mindspore.mint.cat(tensors, dim)
2626
return ops.cat(tensors, dim)
2727

2828
# concat

mindnlp/transformers/cache_utils.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,10 @@ def update(
364364
# Update the number of seen tokens
365365
if layer_idx == 0:
366366
self._seen_tokens += key_states.shape[-2]
367-
367+
if key_states.dtype!=mindspore.float16:
368+
key_states = key_states.astype(mindspore.float16)
369+
if key_states.dtype!=mindspore.float16:
370+
value_states = value_states.astype(mindspore.float16)
368371
# Update the cache
369372
if len(self.key_cache) <= layer_idx:
370373
self.key_cache.append(key_states)
@@ -375,7 +378,8 @@ def update(
375378
self.key_cache[layer_idx] = key_states
376379
self.value_cache[layer_idx] = value_states
377380
else:
378-
self.key_cache[layer_idx] = ops.cat([self.key_cache[layer_idx], key_states], dim=-2)
381+
self.key_cache[layer_idx] = ops.cat(
382+
[self.key_cache[layer_idx].astype(mindspore.float16), key_states.astype(mindspore.float16)], dim=-2)
379383
self.value_cache[layer_idx] = ops.cat([self.value_cache[layer_idx], value_states], dim=-2)
380384

381385
return self.key_cache[layer_idx], self.value_cache[layer_idx]

0 commit comments

Comments
 (0)