speed issue after conversion #54

Open
thegodone opened this issue Jul 16, 2024 · 7 comments
@thegodone

After conversion, I have a speed issue:

PyTorch Inference Time: 0.7236480712890625
         558606 function calls (471806 primitive calls) in 0.723 seconds

   Ordered by: internal time
   List reduced from 77 to 25 due to restriction <25>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     9700    0.268    0.000    0.268    0.000 {built-in method torch._C._nn.linear}
124600/111900    0.056    0.000    0.552    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:285(decorator)
     3200    0.040    0.000    0.040    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3aca41da0}
     3600    0.030    0.000    0.030    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3ac9dfba0}
        1    0.028    0.028    0.723    0.723 /Users/tgg/Github/atr_igor/testkeras5.py:73(profile_pytorch_inference)
     1800    0.027    0.000    0.027    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3aca413a0}
     9600    0.020    0.000    0.020    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3ac9ddee0}
25200/500    0.020    0.000    0.680    0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1534(_call_impl)
    50400    0.019    0.000    0.019    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1696(__getattr__)
     1800    0.016    0.000    0.387    0.000 /Users/tgg/Github/atr_igor/transformer.py:66(forward)
     3200    0.015    0.000    0.015    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3aca1fa60}
25200/500    0.014    0.000    0.679    0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:241(forward)
     3200    0.012    0.000    0.112    0.000 /Users/tgg/Github/atr_igor/transformer.py:47(forward)
     9700    0.012    0.000    0.289    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/linear.py:115(forward)
25200/500    0.011    0.000    0.680    0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1528(_wrapped_call_impl)
   149800    0.010    0.000    0.010    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:212(is_tracing_enabled)
     9000    0.009    0.000    0.009    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3aca43740}
     1800    0.009    0.000    0.104    0.000 /Users/tgg/Github/atr_igor/transformer.py:77(attention)
     4200    0.008    0.000    0.014    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/functional.py:1279(dropout)
     7200    0.008    0.000    0.008    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3aca58900}
     5000    0.008    0.000    0.008    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3ac9f1260}
     3400    0.008    0.000    0.008    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3ac9dfce0}
     3000    0.007    0.000    0.641    0.000 /Users/tgg/Github/atr_igor/transformer.py:26(forward)
     1800    0.006    0.000    0.006    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3aca1f420}
     1800    0.005    0.000    0.005    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3ac9de7a0}


****************************************************************************************************
Keras Inference Time: 2.867401123046875
         3564959 function calls (3444008 primitive calls) in 2.861 seconds

   Ordered by: internal time
   List reduced from 1691 to 25 due to restriction <25>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      100    1.592    0.016    1.592    0.016 {built-in method tensorflow.python._pywrap_tfe.TFE_Py_Execute}
      400    0.144    0.000    0.145    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/constant_op.py:70(convert_to_eager_tensor)
     3083    0.040    0.000    0.092    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/util/tf_decorator.py:272(unwrap)
460826/460807    0.031    0.000    0.096    0.000 {built-in method builtins.isinstance}
   223876    0.024    0.000    0.029    0.000 {built-in method builtins.hasattr}
     5218    0.020    0.000    0.028    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/typing.py:1911(_get_protocol_attrs)
   236394    0.018    0.000    0.019    0.000 {built-in method builtins.getattr}
    63091    0.017    0.000    0.029    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/util/tf_decorator.py:187(_has_tf_decorator_attr)
     2450    0.016    0.000    0.063    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/ops.py:959(_create_c_op)
77423/13928    0.014    0.000    0.031    0.000 {built-in method builtins.hash}
     2705    0.013    0.000    0.118    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/util/tf_decorator.py:179(_get_bound_instance)
   1174/5    0.013    0.000    0.858    0.172 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/keras/src/engine/base_layer.py:1005(__call__)
    36466    0.013    0.000    0.029    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/dtypes.py:793(as_dtype)
     2450    0.012    0.000    0.012    0.000 {built-in method tensorflow.python.client._pywrap_tf_session.TF_FinishOperation}
     1374    0.012    0.000    0.380    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/op_def_library.py:752(_apply_op_helper)
     2439    0.012    0.000    0.013    0.000 {built-in method tensorflow.python.client._pywrap_tf_session.TF_OperationGetAttrValueProto}
   180059    0.011    0.000    0.011    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/util/tf_decorator.py:343(decorated_target)
     1374    0.010    0.000    0.120    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/op_def_library.py:411(_ExtractInputsAndAttrs)
5822/5520    0.010    0.000    0.040    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/inspect.py:2428(_signature_from_callable)
    23844    0.009    0.000    0.025    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/op_def_library.py:55(<genexpr>)
     2705    0.008    0.000    0.173    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/util/tf_decorator.py:115(make_decorator)
     2450    0.008    0.000    0.178    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/ops.py:2593(_create_op_internal)
     3172    0.008    0.000    0.020    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/inspect.py:2333(_signature_from_function)
    60699    0.008    0.000    0.011    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/inspect.py:300(ismethod)
     2450    0.008    0.000    0.108    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/ops.py:1051(from_node_def)

What is also strange is the difference in the number of function calls between torch and keras.

@AlexanderLutsenko
Owner

When Tensorflow performance sucks, these are the usual culprits:

  1. In Pytorch, transformers typically call scaled_dot_product_attention, which may leverage highly optimized kernels (e.g. FlashAttention). Sadly, there is no such thing in Tensorflow, so Nobuco computes attention with a naive algorithm.
  2. Advanced tensor slicing: Tensorflow lacks a good implementation for it. You can run this example and see how bulky the output graph is (a hypothetical illustration of such slicing follows below).
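
For illustration, a hypothetical PyTorch module (not from the repo, names made up) whose slicing pattern tends to expand into a long chain of strided-slice/concat/scatter ops in the converted Tensorflow graph:

import torch
import torch.nn as nn

class SliceHeavy(nn.Module):
    # Hypothetical module: advanced slicing like this usually converts into
    # many strided-slice / concat / scatter ops on the Tensorflow side.
    def forward(self, x):
        even = x[:, :, ::2]                 # strided slice along the last dim
        odd = x[:, :, 1::2]                 # second strided slice
        x = torch.cat([odd, even], dim=-1)  # reassemble the tensor
        x = x.clone()
        x[:, 0, :] = 0.0                    # slice assignment -> scatter-like ops in TF
        return x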

@thegodone
Author

thegodone commented Jul 18, 2024

I use tf.keras instead of the keras import as much as I can.

I compared the speed of the two implementations:

def converter_scaled_dot_product_attention1(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    def func(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        D = tf.shape(query)[-1]


        if scale is None:
            scale = tf.cast(D, query.dtype) ** -0.5

        # Corby's numerically more stable attention
        # See: https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/118
        s_scale = tf.cast(tf.sqrt(scale), query.dtype)
        query = query * s_scale
        key = key * s_scale

        sim = tf.matmul(query, key, transpose_b=True)

        if attn_mask is not None:
            sim += attn_mask * -1e9
        elif is_causal:
            L = tf.shape(query)[-2]
            S = tf.shape(key)[-2]
            causal_mask = tf.linalg.band_part(tf.ones((L, S)), -1, 0)
            sim = sim * causal_mask + (1.0 - causal_mask) * -1e9

        attn = tf.nn.softmax(sim, axis=-1)
        if dropout_p > 0:
            attn = tf.keras.layers.Dropout(dropout_p)(attn)

        return tf.matmul(attn, value)
    
    return func

and your code (slightly modified):


def tril(h, w):
    y = tf.range(0, h)[:, None]
    x = tf.range(0, w)[None, :]
    return y >= x


def converter_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    def func(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        D = tf.shape(query)[-1]

        if scale is None:
            scale = tf.cast(D, query.dtype) ** -0.5

        # Corby's numerically more stable attention
        # See: https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/118
        s_scale = tf.cast(tf.sqrt(scale), query.dtype)
        query = query * s_scale
        key = key * s_scale

        sim = query @ tf.experimental.numpy.swapaxes(key, -2, -1)

        if attn_mask is not None:
            sim = tf.where(attn_mask, sim, float("-inf"))
        elif is_causal:
            L = tf.shape(query)[-2]
            S = tf.shape(key)[-2]
            causal_mask = tril(L, S)
            sim = tf.where(causal_mask, sim, float("-inf"))

        attn = tf.nn.softmax(sim, axis=-1)
        attn = tf.keras.layers.Dropout(dropout_p)(attn)
        return attn @ value
    return func

I got almost the same speed, with a very small improvement in version "1".

@AlexanderLutsenko
Owner

AlexanderLutsenko commented Jul 18, 2024

@thegodone One important thing I almost forgot about: Tensorflow really hates dynamic tensor shapes. To infer language models with varying context length, you should do input padding (see this example and the accompanying issue).
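
For reference, a minimal padding sketch under that constraint (pad_to_fixed_length, max_len and pad_id are placeholder names, not from the linked example):

import numpy as np

def pad_to_fixed_length(token_ids, max_len=128, pad_id=0):
    # token_ids: list of token-id sequences with varying lengths.
    # Returns a [batch, max_len] array with a fixed static shape, so the
    # converted Keras model never sees a dynamic sequence dimension.
    batch = np.full((len(token_ids), max_len), pad_id, dtype=np.int64)
    for i, seq in enumerate(token_ids):
        seq = seq[:max_len]
        batch[i, :len(seq)] = seq
    return batch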

@AlexanderLutsenko
Owner

Turns out, the inference is much faster if the Keras model is exported as a SavedModel artifact:

keras_model.export(model_path)

saved_model = tf.saved_model.load(model_path)
saved_model.serve(inputs)
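
For instance, a rough timing sketch building on the snippet above (keras_model, model_path, inputs and the iteration count are placeholders):

import time
import tensorflow as tf

keras_model.export(model_path)                 # writes a SavedModel with an optimized graph
saved_model = tf.saved_model.load(model_path)

start = time.time()
for _ in range(100):
    out = saved_model.serve(inputs)            # 'serve' is the exported inference signature
print("SavedModel inference time:", time.time() - start)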

@thegodone
Author

Nice catch, it looks like there are optimizations applied during the SavedModel export. I will try that, thanks a lot.

@thegodone
Author

thegodone commented Jul 19, 2024

Indeed, almost everything is faster now, except the "first line", the built-in method tensorflow.python._pywrap_tfe.TFE_Py_Execute, which is strange:

(mlxgraphenv-py311) tgg@macbook-pro atr_igor % python testkeras2_savedmodel.py
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.container.ModuleList' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'transformer.MultiHeadAttention' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.dropout.Dropout' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.sparse.Embedding' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
src transform: [[9, 20, 28, 20]]
PyTorch Inference Time: 0.6295859813690186
Generator: tensor([[-29.3162,  -8.5767,  -8.5767,  -8.1641,  -8.7406,  -8.5813,  -8.7283,
          -8.8958,  -7.8562,  -8.5062,  -8.2777,  -8.4021,  -8.7413,  -8.6585,
          -8.9101,  -8.7277,  -8.8235,  -7.7042,  -8.4626,  -7.2766,  -0.8226,
          -6.3867,  -8.3603,  -7.5078,  -8.5743,  -8.7478,  -4.1617,  -1.3617,
          -7.9011,  -7.7502,  -5.9109,  -8.5652,  -1.3151,  -8.6553,  -8.7373,
          -8.6551,  -8.6658,  -8.5221,  -4.9030,  -7.4738,  -8.4073,  -8.0457]],
       grad_fn=<SelectBackward0>)
****************************************************************************************************
Keras Inference Time: 2.041451930999756
Generator: tf.Tensor(
[[-29.316212   -8.576733   -8.576746   -8.164137   -8.740633   -8.581325
   -8.728261   -8.89583    -7.85616    -8.506172   -8.277732   -8.402089
   -8.741335   -8.658456   -8.910144   -8.727736   -8.823455   -7.704239
   -8.462603   -7.2765865  -0.8225823  -6.386689   -8.360339   -7.5077925
   -8.574331   -8.747827   -4.161747   -1.3616791  -7.901073   -7.7501645
   -5.910877   -8.565204   -1.3150749  -8.655296   -8.737296   -8.655108
   -8.665803   -8.522052   -4.9029703  -7.4737854  -8.407324   -8.04574  ]], shape=(1, 42), dtype=float32)
****************************************************************************************************
Exception ignored in: <function AtomicFunction.__del__ at 0x16ef13880>
Traceback (most recent call last):
  File "/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/atomic_function.py", line 291, in __del__
TypeError: 'NoneType' object is not subscriptable
(mlxgraphenv-py311) tgg@macbook-pro atr_igor % python testkeras2_savedmodel.py
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.container.ModuleList' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'transformer.MultiHeadAttention' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.dropout.Dropout' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.sparse.Embedding' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
src transform: [[9, 20, 28, 20]]
Traceback (most recent call last):
  File "/Users/tgg/Github/atr_igor/testkeras2_savedmodel.py", line 113, in <module>
    profiler = cProfile.Profile()
               ^^^^^^^^
NameError: name 'cProfile' is not defined
(mlxgraphenv-py311) tgg@macbook-pro atr_igor % python testkeras2_savedmodel.py
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.container.ModuleList' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'transformer.MultiHeadAttention' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.dropout.Dropout' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/serialization.py:1113: SourceChangeWarning: source code of class 'torch.nn.modules.sparse.Embedding' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
  warnings.warn(msg, SourceChangeWarning)
src transform: [[9, 20, 28, 20]]
PyTorch Inference Time: 0.7369470596313477
         558604 function calls (471804 primitive calls) in 0.737 seconds

   Ordered by: internal time
   List reduced from 76 to 25 due to restriction <25>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     9700    0.276    0.000    0.276    0.000 {built-in method torch._C._nn.linear}
124600/111900    0.053    0.000    0.560    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:285(decorator)
     1800    0.046    0.000    0.046    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f40f40}
     3200    0.039    0.000    0.039    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f41940}
        1    0.030    0.030    0.737    0.737 /Users/tgg/Github/atr_igor/testkeras2_savedmodel.py:80(profile_pytorch_inference)
     3600    0.029    0.000    0.029    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4edf740}
25200/500    0.022    0.000    0.699    0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1534(_call_impl)
    50400    0.020    0.000    0.020    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1696(__getattr__)
     9600    0.020    0.000    0.020    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4edda80}
     1800    0.017    0.000    0.411    0.000 /Users/tgg/Github/atr_igor/transformer.py:66(forward)
25200/500    0.014    0.000    0.698    0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:241(forward)
     3200    0.013    0.000    0.013    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f23600}
     3200    0.012    0.000    0.109    0.000 /Users/tgg/Github/atr_igor/transformer.py:47(forward)
     9700    0.011    0.000    0.297    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/linear.py:115(forward)
25200/500    0.011    0.000    0.699    0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1528(_wrapped_call_impl)
   149800    0.010    0.000    0.010    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:212(is_tracing_enabled)
     1800    0.009    0.000    0.121    0.000 /Users/tgg/Github/atr_igor/transformer.py:77(attention)
     9000    0.009    0.000    0.009    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f432e0}
     4200    0.008    0.000    0.014    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/functional.py:1279(dropout)
     7200    0.008    0.000    0.008    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f584a0}
     5000    0.008    0.000    0.008    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4ef0e00}
     3000    0.007    0.000    0.664    0.000 /Users/tgg/Github/atr_igor/transformer.py:26(forward)
     3400    0.006    0.000    0.006    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4edf880}
     1800    0.005    0.000    0.005    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4f22fc0}
     1800    0.005    0.000    0.005    0.000 {function Tracer.op_tracing_decorator.<locals>.decorator at 0x3b4ede340}


****************************************************************************************************
Keras Inference Time: 2.2881689071655273
         1447802 function calls (1311645 primitive calls) in 2.286 seconds

   Ordered by: internal time
   List reduced from 1591 to 25 due to restriction <25>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      100    1.671    0.017    1.671    0.017 {built-in method tensorflow.python._pywrap_tfe.TFE_Py_Execute}
      400    0.128    0.000    0.129    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/constant_op.py:70(convert_to_eager_tensor)
      667    0.030    0.000    0.030    0.000 {built-in method tensorflow.python.client._pywrap_tf_session.TF_FinishOperation}
      168    0.029    0.000    0.029    0.000 {method '_numpy_internal' of 'tensorflow.python.framework.ops.EagerTensor' objects}
10721/155    0.016    0.000    0.096    0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/ast.py:488(generic_visit)
103263/30224    0.015    0.000    0.033    0.000 {built-in method builtins.hash}
178603/178541    0.012    0.000    0.032    0.000 {built-in method builtins.isinstance}
   113900    0.010    0.000    0.011    0.000 {built-in method builtins.getattr}
 5098/199    0.009    0.000    0.021    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/autograph/pyct/ast_util.py:33(copy)
    34870    0.008    0.000    0.024    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/tensor.py:894(__hash__)
  6316/44    0.008    0.000    0.090    0.002 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/autograph/pyct/transformer.py:417(visit)
 16072/82    0.008    0.000    0.123    0.001 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/ast.py:414(visit)
    59573    0.007    0.000    0.007    0.000 {built-in method builtins.hasattr}
    26280    0.007    0.000    0.012    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/autograph/pyct/anno.py:130(hasanno)
    46432    0.007    0.000    0.009    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/ast.py:255(iter_fields)
     1098    0.004    0.000    0.006    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/typing.py:1911(_get_protocol_attrs)
    24987    0.004    0.000    0.005    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/enum.py:1230(__hash__)
    34870    0.004    0.000    0.005    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/tensor_shape.py:1508(__hash__)
      667    0.004    0.000    0.043    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/ops.py:959(_create_c_op)
      231    0.004    0.000    0.004    0.000 {built-in method tensorflow.python.client._pywrap_tf_session.TF_GraphCopyFunction}
     3502    0.004    0.000    0.009    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/nobuco/trace/trace.py:285(decorator)
      797    0.003    0.000    0.004    0.000 {built-in method tensorflow.python.client._pywrap_tf_session.TF_OperationGetAttrValueProto}
  2030/98    0.003    0.000    0.013    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/gast/astn.py:17(generic_visit)
      200    0.003    0.000    0.006    0.000 /Users/tgg/Github/atr_igor/data.py:237(pad_pack)
     7339    0.003    0.000    0.003    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/framework/dtypes.py:264(__eq__)


****************************************************************************************************
Exception ignored in: <function AtomicFunction.__del__ at 0x16c817880>
Traceback (most recent call last):
  File "/Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/atomic_function.py", line 291, in __del__
TypeError: 'NoneType' object is not subscriptable

@johndpope

Probably deserves mentioning in the README that performance suffers until the model is exported as a SavedModel.
