[Operation Question] How to separate truncation and matmul operations #672

Open
wilyub opened this issue Apr 30, 2024 · 9 comments

wilyub commented Apr 30, 2024

Issue Type

Support

Modules Involved

MPC protocol, SPU runtime

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

spu 0.9.0b1

OS Platform and Distribution

Linux Ubuntu 22.04

Python Version

3.11.5

Compiler Version

No response

Current Behavior?

I'd like to do some latency testing for multiplication and truncation operations. Here is a short overview of how it would look:

Input x and weight y are each represented in 16-bit fixed point, so their product needs a 32-bit fixed-point representation. With a 64-bit ring, however, you could perform one more 32-bit × 32-bit multiplication before overflowing the ring.
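A quick way to sanity-check this bit-width argument is to compute worst-case signed widths. The sketch below is plain Python, not an SPU API; the helper `product_width` is hypothetical, and it assumes signed two's-complement integers in a 64-bit ring. It also accounts for the extra bits a matmul accumulates when it sums `inner_dim` products, which the scalar argument above leaves out.

```python
import math

RING_BITS = 64  # assumption: a 64-bit ring, as with FM64


def product_width(a_bits: int, b_bits: int, inner_dim: int = 1) -> int:
    """Worst-case signed bit width of sum_{k < inner_dim} a_k * b_k.

    One product of signed a_bits- and b_bits-wide integers fits in
    a_bits + b_bits bits; summing inner_dim such products (as a matmul
    does) can add up to ceil(log2(inner_dim)) more bits.
    """
    return a_bits + b_bits + math.ceil(math.log2(inner_dim))


# Scalar case from the question: 16 x 16 -> 32, then 32 x 32 -> 64.
w1 = product_width(16, 16)
w2 = product_width(w1, w1)
print(w1, w2, w2 <= RING_BITS)  # 32 64 True

# Matmul case: each output sums inner_dim products
# (e.g. 4096, the hidden size of Llama-7B).
m1 = product_width(16, 16, inner_dim=4096)
m2 = product_width(m1, m1, inner_dim=4096)
print(m1, m2, m2 <= RING_BITS)  # 44 100 False
```

Under this worst-case bound, deferring truncation across two chained matmuls only stays within 64 bits if the actual value ranges are much smaller than the bound or the inputs use fewer bits.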

How can I do this in SecretFlow? In particular, I want to control the fixed-point representation size of each input to the matrix multiplication, and truncate only after specific multiplications (not after every multiplication).
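To illustrate what truncating only after specific multiplications means, the sketch below simulates fixed-point arithmetic in the clear over a 64-bit ring (no secret sharing, and not SPU's API). It assumes 8 fractional bits and uses an exact arithmetic right shift as the truncation; all helper names are hypothetical.

```python
RING = 1 << 64
FRAC = 8  # assumed number of fractional bits


def to_signed(x: int) -> int:
    """Interpret a ring element mod 2**64 as a signed integer."""
    return x - RING if x >= RING // 2 else x


def encode(v: float) -> int:
    """Fixed-point encode with FRAC fractional bits, mod 2**64."""
    return round(v * (1 << FRAC)) % RING


def decode(x: int) -> float:
    return to_signed(x) / (1 << FRAC)


def mul(x: int, y: int) -> int:
    """Ring multiplication; the operands' fractional bits add up."""
    return (x * y) % RING


def trunc(x: int, bits: int) -> int:
    """Exact truncation: arithmetic right shift of the signed value."""
    return (to_signed(x) >> bits) % RING


a, b, c = encode(1.5), encode(-2.25), encode(0.75)

# Eager: truncate after every product (2 truncations for 3 factors).
eager = trunc(mul(trunc(mul(a, b), FRAC), c), FRAC)

# Deferred: the scale grows to 3 * FRAC bits, then one truncation by 2 * FRAC.
deferred = trunc(mul(mul(a, b), c), 2 * FRAC)

print(decode(eager), decode(deferred))  # -2.53125 -2.53125
```

In the clear both schedules give the same result here. In MPC, the deferred schedule saves one truncation protocol call but needs enough headroom in the ring for the larger intermediate value.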

Thank you for your help!

Standalone code to reproduce the issue

N/A

Relevant log output

N/A
Chrisdehe assigned tpppppub and 6fj and unassigned tpppppub May 6, 2024
6fj (Member) commented May 6, 2024

Hi @wilyub

I don't think you can separate matmul and truncation in the Python bindings of SPU. You may have to dive into the kernels of the SPU runtime.

wilyub (Author) commented May 6, 2024

Thanks for the advice. Any idea where I could start looking? I'm not too familiar with the inner workings of SPU, as I've only used it from Python. Thanks!

6fj (Member) commented May 7, 2024

I would suggest having a look at https://github.com/secretflow/spu/blob/main/REPO_LAYOUT.md.

fionser (Collaborator) commented May 7, 2024

@wilyub To benchmark only the matmul, without truncation, you can use an integer matrix instead of a floating-point matrix.
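As a plaintext illustration of why integer inputs avoid truncation, the sketch below computes the same 2×2 matmul over a 64-bit ring twice: once on integers (scale 2^0) and once on fixed-point encodings with 16 fractional bits (an assumed value). The integer result is final as-is, while the fixed-point result carries 32 fractional bits and needs one truncation per output entry to return to the input scale. This models only the scale bookkeeping, not how SPU dispatches its kernels.

```python
RING = 1 << 64
FRAC = 16  # assumed number of fractional bits


def to_signed(x: int) -> int:
    """Interpret a ring element mod 2**64 as a signed integer."""
    return x - RING if x >= RING // 2 else x


def ring_matmul(a, b):
    """Plain matmul with every output reduced mod 2**64 (no truncation)."""
    k, m = len(b), len(b[0])
    return [[sum(row[t] * b[t][j] for t in range(k)) % RING for j in range(m)]
            for row in a]


a_int = [[1, 2], [3, 4]]
b_int = [[5, 6], [7, 8]]

# Integer path: the ring result is already the answer.
c_int = [[to_signed(v) for v in row] for row in ring_matmul(a_int, b_int)]
print(c_int)  # [[19, 22], [43, 50]]

# Fixed-point path: encode with FRAC fractional bits; products carry 2 * FRAC.
a_fxp = [[v << FRAC for v in row] for row in a_int]
b_fxp = [[v << FRAC for v in row] for row in b_int]
c_raw = ring_matmul(a_fxp, b_fxp)
# One truncation (shift by FRAC) per output entry restores the input scale.
c_fxp = [[to_signed(v) >> FRAC for v in row] for row in c_raw]
print([[v / (1 << FRAC) for v in row] for row in c_fxp])  # [[19.0, 22.0], [43.0, 50.0]]
```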

wilyub (Author) commented May 7, 2024

Thanks for the update, I'll take a look at integer containers. Another thing that has been confusing me is benchmarking the time for softmax, SiLU, division, etc. From my reading of the literature (PUMA, CipherGPT, etc.), these operations should introduce a high latency cost because they require a lot of communication. However, in my tests these operations do not seem to add significant time to the end-to-end runtime of one transformer layer; most of the latency is attributed to the matrix multiplications instead.

I've added a screenshot of my log from benchmarking one layer of Llama-7B. The nonlinearities are implemented in JAX (jax.nn.softmax and jax.nn.silu). However, as the log shows, the vast majority of the time is spent in the mmul_aa operation (I assume this is matrix multiplication). Any idea whether I've messed something up here?

Note: I'm using SEMI2K with FM64. I also did some 3PC testing with ABY3 (also FM64) and saw similar results.
[Screenshot: mpc_weird]
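One way to reason about the expected cost (not a diagnosis of the log above) is multiplicative depth: the number of sequential secure multiplications, each needing at least one communication round. The sketch below uses textbook approximations common in the MPC literature, exp(x) ≈ (1 + x/2^n)^(2^n) and Newton-Raphson for the reciprocal; SPU's actual kernels and iteration counts may differ, and the secure comparisons needed for the max are not counted. A matmul typically has depth 1 regardless of size, but its communication volume grows with the matrix dimensions, so whether rounds or bandwidth dominate depends on the network setting.

```python
def exp_approx(x: float, n: int = 8) -> tuple[float, int]:
    """exp(x) ~= (1 + x / 2**n) ** (2**n) via n successive squarings.

    Returns (approximation, multiplicative depth). Assumes -2**n < x <= 0,
    which holds for max-shifted softmax inputs with a modest spread.
    """
    y = 1.0 + x / (1 << n)
    for _ in range(n):
        y = y * y  # one sequential multiplication per squaring
    return y, n


def reciprocal_newton(d: float, init: float, iters: int = 10) -> tuple[float, int]:
    """1 / d via Newton-Raphson r <- r * (2 - d * r); converges for 0 < init < 2 / d.

    Each iteration has two sequential multiplications (d * r, then r * (...)).
    """
    r = init
    for _ in range(iters):
        r = r * (2.0 - d * r)
    return r, 2 * iters


def softmax_approx(xs: list[float]) -> tuple[list[float], int]:
    m = max(xs)  # needs secure comparisons in MPC; not counted here
    exps = [exp_approx(x - m)[0] for x in xs]  # independent, so evaluated in parallel
    exp_depth = exp_approx(0.0)[1]
    denom = sum(exps)  # additions are local on additive shares
    # After max-shifting, 1 <= denom <= len(xs), so init = 1 / len(xs) converges.
    inv, recip_depth = reciprocal_newton(denom, init=1.0 / len(xs))
    probs = [e * inv for e in exps]
    return probs, exp_depth + recip_depth + 1  # +1 for the final scaling multiply


probs, depth = softmax_approx([0.5, -1.0, 2.0])
print(depth)  # 29 sequential multiplications, versus 1 for a matmul
```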

anakinxc (Collaborator) commented May 8, 2024


@Ye-D

wilyub (Author) commented May 8, 2024

Please let me know if you find anything out about the nonlinearities issue above. I can also provide my code if it helps with reproducibility (although, as I said earlier, I just used the out-of-the-box jnn.softmax() and jnn.silu() alongside some matrix multiplications).

One other thing that has been weird for me is the matmul benchmark. I tried two tests. In the first, I did a single matrix multiplication of the hidden states with the value weight matrix (one matmul call). In the second, I did that matmul plus matmuls of the hidden states with the key weight matrix and with the query weight matrix (three matmul calls). However, the logs show nearly identical matmul latency for both tests. Even weirder, the number of bytes sent/received is exactly the same. I would expect the matmul time to triple with three matmul calls. I have attached snapshots of my logs and my code.
[Screenshots: attn_code, matmul_1, matmul_2]

anakinxc (Collaborator) commented May 8, 2024


Both query_states and key_states are unused values, so the matmuls that define them are dead code and should be eliminated during optimization.
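The effect described above can be illustrated with a toy dead-code elimination pass over a straight-line program. This is only a sketch of the general idea, not the actual pass in XLA or in SPU's compiler, and the weight names (w_q, w_k, w_v) are hypothetical. Ops whose results never reach an output are dropped, so to benchmark all three matmuls their results need to be live, for example by returning all three from the function.

```python
def eliminate_dead_code(ops, outputs):
    """Keep only ops whose results are (transitively) needed by `outputs`.

    ops: list of (result_name, op_name, input_names) in program order.
    """
    live = set(outputs)
    kept = []
    for result, op, inputs in reversed(ops):  # walk backwards from the outputs
        if result in live:
            kept.append((result, op, inputs))
            live.update(inputs)  # this op's inputs are now needed too
    return list(reversed(kept))


program = [
    ("query_states", "matmul", ("hidden_states", "w_q")),
    ("key_states", "matmul", ("hidden_states", "w_k")),
    ("value_states", "matmul", ("hidden_states", "w_v")),
]

only_value = eliminate_dead_code(program, outputs=["value_states"])
print([r for r, _, _ in only_value])  # ['value_states']

all_three = eliminate_dead_code(program, outputs=["query_states", "key_states", "value_states"])
print(len(all_three))  # 3
```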

wilyub (Author) commented May 10, 2024

I tested the integer matmul and it works (no truncation shows up in the log). Which function should I call to get trunc_a to show up in the log by itself? I tried jnp.trunc and didn't get that result. Thanks! Also, any update on why the nonlinearities seem so cheap compared to matmuls?
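For intuition about what a standalone truncation does on secret shares, here is a sketch of the local probabilistic truncation from SecureML for two-party additive sharing over a 64-bit ring. This is a swapped-in textbook protocol: SPU's trunc_a for SEMI2K or ABY3 may use a different protocol, and this sketch does not answer which Python-level call emits trunc_a on its own. The fractional-bit count and helper names are assumptions.

```python
import random

RING = 1 << 64
FRAC = 16  # assumed number of fractional bits


def to_signed(x: int) -> int:
    """Interpret a ring element mod 2**64 as a signed integer."""
    return x - RING if x >= RING // 2 else x


def share(x: int, rng: random.Random) -> tuple[int, int]:
    """Split x into two additive shares: x = s0 + s1 mod 2**64."""
    s0 = rng.randrange(RING)
    return s0, (x - s0) % RING


def local_trunc(s0: int, s1: int, bits: int) -> tuple[int, int]:
    """SecureML-style truncation: each party shifts its own share, no messages."""
    t0 = s0 >> bits                             # party 0: floor(s0 / 2**bits)
    t1 = (RING - ((RING - s1) >> bits)) % RING  # party 1: -floor((2**64 - s1) / 2**bits) mod 2**64
    return t0, t1


rng = random.Random(0)
# A value at product scale (2 * FRAC fractional bits), e.g. the output of a mul.
x = round(-3.375 * (1 << (2 * FRAC))) % RING
s0, s1 = share(x, rng)
t0, t1 = local_trunc(s0, s1, FRAC)
result = to_signed((t0 + t1) % RING) / (1 << FRAC)
print(result)  # close to -3.375; off by at most 2**-16 with high probability
```

The reconstructed result can be off by one unit in the last place, and it fails with small probability that grows with the magnitude of x relative to the ring size, which is why the value range matters for truncation as well as for multiplication.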

Labels: None yet
Projects: None yet
Development: No branches or pull requests
5 participants