1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import contextvars
1516import inspect
1617import time
1718from functools import wraps
3132
3233T = TypeVar ("T" , covariant = True )
3334tracer = trace .get_tracer (__name__ )
35+ _current_span_context : contextvars .ContextVar = contextvars .ContextVar (
36+ "current_span_context"
37+ )
3438
3539
3640def get_remote_func (func ): # type: ignore
@@ -74,7 +78,12 @@ def task(
7478
7579 def task_wrapper (func ): # type: ignore
7680 async def async_exec (* args : Any , ** kwargs : Any ) -> Any :
77- with tracer .start_as_current_span (name = func .__qualname__ ) as span :
81+ parent_ctx = _current_span_context .get (None )
82+ with tracer .start_as_current_span (
83+ name = func .__qualname__ , context = parent_ctx
84+ ) as span :
85+ _current_span_context .set (trace .set_span_in_context (span ))
86+
7887 input = _update_kwargs (args , kwargs , func )
7988 try :
8089 result = await (get_remote_func (func ) if distributed else func )(
@@ -98,7 +107,12 @@ async def async_exec(*args: Any, **kwargs: Any) -> Any:
98107 raise e
99108
100109 def sync_exec (* args : Any , ** kwargs : Any ) -> Any :
101- with tracer .start_as_current_span (name = func .__qualname__ ) as span :
110+ parent_ctx = _current_span_context .get (None )
111+ with tracer .start_as_current_span (
112+ name = func .__qualname__ , context = parent_ctx
113+ ) as span :
114+ _current_span_context .set (trace .set_span_in_context (span ))
115+
102116 input = _update_kwargs (args , kwargs , func )
103117 try :
104118 result = func (* args , ** kwargs )
@@ -121,9 +135,13 @@ def sync_exec(*args: Any, **kwargs: Any) -> Any:
121135
122136 @wraps (func )
123137 async def async_iter_task (* args : Any , ** kwargs : Any ) -> AsyncGenerator [T , None ]:
138+ parent_ctx = _current_span_context .get (None )
124139 span = tracer .start_span (
125- name = func .__qualname__ + ".first_iter" , start_time = time .time_ns ()
140+ name = func .__qualname__ + ".first_iter" ,
141+ start_time = time .time_ns (),
142+ context = parent_ctx ,
126143 )
144+ _current_span_context .set (trace .set_span_in_context (span ))
127145 input = _update_kwargs (args , kwargs , func )
128146 try :
129147 async for i , resp in aenumerate (func (* args , ** kwargs )): # type: ignore
@@ -142,12 +160,17 @@ async def async_iter_task(*args: Any, **kwargs: Any) -> AsyncGenerator[T, None]:
142160 custom_attributes = custom_attributes ,
143161 )
144162 span .end (end_time = time .time_ns ())
163+ _current_span_context .set (parent_ctx )
145164 yield resp
146165
147166 if trace_all :
167+ parent_ctx = _current_span_context .get ()
148168 span = tracer .start_span (
149- name = func .__qualname__ , start_time = time .time_ns ()
169+ name = func .__qualname__ ,
170+ start_time = time .time_ns (),
171+ context = parent_ctx ,
150172 )
173+ _current_span_context .set (trace .set_span_in_context (span ))
151174 except Exception as e :
152175 if not trace_all :
153176 span = tracer .start_span (
@@ -160,7 +183,12 @@ async def async_iter_task(*args: Any, **kwargs: Any) -> AsyncGenerator[T, None]:
160183
161184 @wraps (func )
162185 def iter_task (* args : Any , ** kwargs : Any ) -> Iterable [T ]:
163- span = tracer .start_span (name = func .__qualname__ , start_time = time .time_ns ())
186+ parent_ctx = _current_span_context .get (None )
187+ span = tracer .start_span (
188+ name = func .__qualname__ , start_time = time .time_ns (), context = parent_ctx
189+ )
190+ _current_span_context .set (trace .set_span_in_context (span ))
191+
164192 input = _update_kwargs (args , kwargs , func )
165193 try :
166194 for i , resp in enumerate (func (* args , ** kwargs )):
@@ -179,11 +207,16 @@ def iter_task(*args: Any, **kwargs: Any) -> Iterable[T]:
179207 custom_attributes = custom_attributes ,
180208 )
181209 span .end (end_time = time .time_ns ())
210+ _current_span_context .set (parent_ctx )
182211 yield resp
183212 if trace_all :
213+ parent_ctx = _current_span_context .get ()
184214 span = tracer .start_span (
185- name = func .__qualname__ , start_time = time .time_ns ()
215+ name = func .__qualname__ ,
216+ start_time = time .time_ns (),
217+ context = parent_ctx ,
186218 )
219+ _current_span_context .set (trace .set_span_in_context (span ))
187220 except Exception as e :
188221 if not trace_all :
189222 span = tracer .start_span (
0 commit comments