diff --git a/mldaikon/instrumentor/tracer.py b/mldaikon/instrumentor/tracer.py index 96355e1e..08278eb6 100644 --- a/mldaikon/instrumentor/tracer.py +++ b/mldaikon/instrumentor/tracer.py @@ -29,7 +29,7 @@ funcs_to_be_replaced, is_funcs_to_be_unproxied, ) -from mldaikon.proxy_wrapper.proxy_basics import is_proxied, unproxy_func +from mldaikon.proxy_wrapper.proxy_basics import is_proxied from mldaikon.proxy_wrapper.proxy_config import enable_C_level_observer from mldaikon.proxy_wrapper.proxy_registry import get_global_registry from mldaikon.utils import get_timestamp_ns, get_unique_id, typename @@ -261,15 +261,16 @@ def find_proxy_in_args(args): add_observer_to_func, # import here to avoid circular import ) - original_function = add_observer_to_func(original_function, unproxy=True) + original_function = add_observer_to_func(original_function, unproxy=False) elif is_funcs_to_be_unproxied(original_function): - original_function = unproxy_func( - original_function, inspect_torch_module=True - ) + # original_function = unproxy_func( + # original_function, inspect_torch_module=True + # ) + pass elif is_builtin: # proxy objects being passed to backend will cause seg fault: TODO: replace with unproxy func - original_function = unproxy_func(original_function) - + # original_function = unproxy_func(original_function) + pass try: if COLLECT_OVERHEAD_METRICS: ORIG_ENTER_PERF_TIME = time.perf_counter() @@ -407,8 +408,8 @@ def core_wrapper(original_function, is_builtin, handle_proxy, *args, **kwargs): if DISABLE_WRAPPER: return original_function(*args, **kwargs) - if handle_proxy and is_builtin: - original_function = unproxy_func(original_function) + # if handle_proxy and is_builtin: + # original_function = unproxy_func(original_function) return original_function(*args, **kwargs)