Skip to content

Commit b2e7df2

Browse files
rootwanfengcxz
authored andcommitted
change graph
1 parent fec0eb8 commit b2e7df2

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def capture(self, **kwargs):
8080
context = self.ctx_mgr.current_context()
8181
self.update_Camb_context(self.meta, context)
8282
current_stream = torch.cuda.current_stream()
83+
8384
# warmup
8485
self.model(**padded_kwargs)
8586

@@ -292,13 +293,13 @@ def get_graph_key(self, input_ids: torch.Tensor,
292293
def __call__(self, **kwargs):
293294
"""call."""
294295
enable_graph = self.enable_graph(**kwargs)
295-
296-
if not enable_graph:
297-
return self.model(**kwargs)
298-
299296
graph_key = self.get_graph_key(**kwargs)
300297
max_tokens = graph_key[0]
301298
is_decoding = graph_key[1]
299+
300+
if (not enable_graph) or (not is_decoding):
301+
return self.model(**kwargs)
302+
302303
if graph_key not in self._runner_map:
303304
max_batches = max_tokens if is_decoding else self.max_batches
304305
runner = CAMBSingleGraphRunner(self.model,
@@ -312,6 +313,7 @@ def __call__(self, **kwargs):
312313
self._runner_map[graph_key] = runner
313314
else:
314315
runner = self._runner_map[graph_key]
316+
315317
output = runner.forward(**kwargs)
316318
return output
317319

0 commit comments

Comments
 (0)