@@ -393,6 +393,8 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding,
             sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        apply_jit_context (Optional[ContextManager]): a context manager that
+            will surround the application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -413,13 +415,15 @@ def __init__(
         ] = None,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
         enqueue_batch_after_forward: bool = False,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
         self._device = device
         self._execute_all_batches = execute_all_batches
         self._apply_jit = apply_jit
         self._enqueue_batch_after_forward = enqueue_batch_after_forward
+        self._apply_jit_context = apply_jit_context
 
         if device.type == "cuda":
             # use two data streams to support two concurrent batches
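With this change a caller can wrap the `torch.jit.script` pass in an arbitrary context. A minimal usage sketch: the `jit_timing_context` helper and the stand-in `model`/`optimizer` below are illustrative assumptions, not part of this diff (real use would pass an already-sharded DMP model).

```python
import time
from contextlib import contextmanager
from typing import Iterator

import torch
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist


@contextmanager
def jit_timing_context() -> Iterator[None]:
    # Hypothetical helper: report how long JIT application takes.
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"JIT application took {time.perf_counter() - start:.2f}s")


model = torch.nn.Linear(8, 1).cuda()  # stand-in; real use passes a sharded model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

pipeline = TrainPipelineSparseDist(
    model=model,
    optimizer=optimizer,
    device=torch.device("cuda"),
    apply_jit=True,
    apply_jit_context=jit_timing_context(),
)
```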
@@ -716,6 +720,7 @@ def _pipeline_model(
             apply_jit=self._apply_jit,
             pipelined_forward=pipelined_forward,
             pipeline_postproc=self._pipeline_postproc,
+            apply_jit_context=self._apply_jit_context,
         )
         # initializes input dist, so we can override input dist forwards
         self.start_sparse_data_dist(batch, context)
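Inside `_rewrite_model` the supplied context presumably brackets the actual scripting call. The diff does not show that function's body, so the sketch below is an illustrative reconstruction, with `contextlib.nullcontext` as an assumed fallback when no context is supplied:

```python
from contextlib import nullcontext
from typing import ContextManager, Optional

import torch


def _script_with_context(
    module: torch.nn.Module,
    apply_jit_context: Optional[ContextManager[None]] = None,
) -> torch.jit.ScriptModule:
    # Enter the caller-supplied context (if any) around torch.jit.script,
    # so caller setup/teardown logic surrounds the JIT application.
    with apply_jit_context if apply_jit_context is not None else nullcontext():
        return torch.jit.script(module)
```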
@@ -904,6 +909,8 @@ class TrainPipelineFusedSparseDist(TrainPipelineSparseDist[In, Out]):
         TODO: pipeline_postproc, custom_model_fwd, strict
         emb_lookup_stream (str): if "new", invoke the compute_and_output_dist
             (for batch i+1) on a new stream, else re-use the data_dist stream
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -922,6 +929,7 @@ def __init__(
         ] = None,
         strict: bool = False,
         emb_lookup_stream: str = "data_dist",  # new, current, data_dist (default)
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -932,6 +940,7 @@ def __init__(
             context_type=EmbeddingTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         if emb_lookup_stream == "new":
             self._emb_lookup_stream: Optional[torch.Stream] = (
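The fused pipeline threads the same argument through to the base class alongside its stream-selection knob. A hedged construction example, reusing the stand-in `model`/`optimizer` and the hypothetical `jit_timing_context` from the first sketch, and assuming `TrainPipelineFusedSparseDist` is exported alongside the other pipelines:

```python
import torch
from torchrec.distributed.train_pipeline import TrainPipelineFusedSparseDist

# "new" runs compute_and_output_dist for batch i+1 on a fresh stream;
# the default "data_dist" re-uses the data-dist stream.
pipeline = TrainPipelineFusedSparseDist(
    model=model,          # stand-in from the first sketch
    optimizer=optimizer,  # stand-in from the first sketch
    device=torch.device("cuda"),
    apply_jit=True,
    emb_lookup_stream="new",
    apply_jit_context=jit_timing_context(),  # hypothetical helper from earlier
)
```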
@@ -1066,6 +1075,8 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding,
             sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1086,6 +1097,7 @@ def __init__(
         ] = None,
         strict: bool = False,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -1097,6 +1109,7 @@ def __init__(
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
             dmp_collection_sync_interval_batches=dmp_collection_sync_interval_batches,
+            apply_jit_context=apply_jit_context,
         )
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
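For the semi-sync pipeline the new argument travels with the existing 2D-sharding sync knob. A short sketch under the same assumptions as the earlier snippets (`model`, `optimizer`, and `jit_timing_context` are hypothetical stand-ins):

```python
import torch
from torchrec.distributed.train_pipeline import TrainPipelineSemiSync

pipeline = TrainPipelineSemiSync(
    model=model,          # stand-in from the first sketch
    optimizer=optimizer,  # stand-in from the first sketch
    device=torch.device("cuda"),
    apply_jit=True,
    dmp_collection_sync_interval_batches=8,  # sync DMPs every 8 batches
    apply_jit_context=jit_timing_context(),  # hypothetical helper from earlier
)
```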
@@ -1378,6 +1391,8 @@ class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1394,6 +1409,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -1404,6 +1420,7 @@ def __init__(
             context_type=PrefetchTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         self._context = PrefetchTrainPipelineContext(version=0)
         self._prefetch_stream: Optional[torch.Stream] = (
@@ -1535,6 +1552,8 @@ class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         device (torch.device): device where device transfer, sparse data dist, and
             forward/backward pass will happen.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (Optional[ContextManager]): a context manager that
+            will surround the application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1546,8 +1565,16 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         device: torch.device,
         apply_jit: bool = False,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
-        super().__init__(model, optimizer, device, True, apply_jit)
+        super().__init__(
+            model,
+            optimizer,
+            device,
+            True,
+            apply_jit,
+            apply_jit_context=apply_jit_context,
+        )
         self._batch_loader: Optional[DataLoadingThread[In]] = None
 
     def __del__(self) -> None:
@@ -1909,6 +1936,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model,
@@ -1919,6 +1947,7 @@ def __init__(
             context_type,
             pipeline_postproc,
             custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
 
         torch._logging.set_logs(compiled_autograd_verbose=True)