Conversation

@hiddely (Contributor) commented Oct 14, 2025

This is an implementation of Huggingface's BERT transformer that I've been meaning to contribute back to MP-SPDZ; I implemented it as part of the evaluation of this work. The implementation includes all necessary BERT layers and a complete inference example. The example uses M-FAC/bert-tiny-finetuned-qnli, but it also works with larger BERT variants with more layers, e.g. gchhablani/bert-base-cased-finetuned-qnli.

I am happy to adapt this PR to better fit the philosophy of the ML library. Currently the implementation does not use the new approach relying on torch.fx.trace, because Huggingface's BERT implementation cannot be traced symbolically (its forward pass contains data-dependent if statements, which symbolic tracing does not support). If you have any feedback on how to improve the integration, please let me know.
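
For context, here is a minimal sketch of the tracing limitation, using torch.fx.symbolic_trace and a toy module of my own (not the Huggingface code): a branch on a tensor value raises a TraceError, which is the kind of failure referred to above.

```python
import torch
import torch.fx

class ToyModel(torch.nn.Module):
    def forward(self, x):
        # Data-dependent branch: its outcome cannot be resolved at trace time.
        if x.sum() > 0:
            return x + 1
        return x - 1

try:
    torch.fx.symbolic_trace(ToyModel())
except torch.fx.proxy.TraceError as e:
    print("tracing failed:", e)
```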

Highlights of Changes

  • In Compiler/ml.py:
    • BertLayer: Complete BERT encoder layer with multi-head attention and feed-forward network
    • MultiHeadAttention: Self-attention mechanism with multiple attention heads
    • Gelu: Implementation of a Gelu layer, with a version using the polynomial approximation from Puma (plaintext reference formulas for Gelu, Tanh, and LayerNorm are sketched after this list)
    • Tanh: Activation using the existing sigmoid function
    • FlexDense: Version of the Dense layer that supports a flexible number of dimensions.
      I kept it as a separate class to avoid introducing bugs into existing models that rely on the current implementation of Dense
    • FlexDropout: Version of Dropout supporting a flexible number of dimensions.
    • LayerNorm: Implementation of layer normalization for transformer architectures
    • Extend ConvBase with optional bias instead of mandatory
    • Add backward pass support to Add layer
  • There is an inference example in Programs/Source/bert_inference.mpc using Huggingface's BERT model,
    which runs inference on 25 samples and, for a single sample, compares the output of each layer with the PyTorch output.
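
For orientation, here is a plaintext NumPy sketch of the functions that the Gelu, Tanh, and LayerNorm layers approximate under MPC. This is not the MP-SPDZ code from this PR, and the function names are mine: in particular, the Puma polynomial coefficients are not reproduced here, and the widely used tanh-based GELU approximation is shown in their place.

```python
import numpy as np

def gelu_tanh_approx(x):
    # GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def tanh_via_sigmoid(x):
    # tanh(x) = 2 * sigmoid(2x) - 1, the identity a Tanh layer can reuse
    sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
    return 2.0 * sigmoid(2.0 * x) - 1.0

def layer_norm(x, gamma, beta, eps=1e-12):
    # Normalize over the last dimension, then scale and shift (BERT uses eps=1e-12)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta
```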

I took quite some care to make sure the implementation is correct; the layer-by-layer forward pass closely matches Huggingface's implementation (modulo precision and approximations).
The per-layer error increases at the BertIntermediate layer, but it appears to be fairly evenly distributed across the layer,
so I assume it is a result of the approximations. In subsequent layers the error decreases again (in absolute terms).
The backward pass is also implemented and runs without errors, but it is harder to check for correctness.
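
For reference, the per-layer metrics in the comparison below can be reproduced with something like the following; this is my reading of the printed output, not necessarily the exact code in bert_inference.mpc.

```python
import numpy as np

def compare_layer(pt_out, mpc_out):
    pt_out, mpc_out = np.asarray(pt_out), np.asarray(mpc_out)
    total_abs_diff = np.abs(pt_out - mpc_out).sum()   # "Total Abs Diff"
    pt_magnitude = np.abs(pt_out).sum()               # "PT Total Magnitude"
    return total_abs_diff, pt_magnitude
```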

I hope this can be a useful addition to MP-SPDZ's ML library. Let me know if it is; I will then clean it up further to make it ready to merge.

Example output of `bert_inference.mpc`
Trying to run 64-bit computation
Using SGD

=== Starting MPC Inference ===
Samples: 25
Batch size: 1
Running MPC inference...

=== Per-Sample Comparison ===
Sample | True Label | PyTorch Pred | MPC Pred | PT Correct | MPC Correct | Match
--------------------------------------------------------------------------------
0 | 0 | 0 | 0 | 1 | 1 | 1
1 | 1 | 1 | 1 | 1 | 1 | 1
2 | 1 | 0 | 1 | 0 | 1 | 0
3 | 0 | 1 | 1 | 0 | 0 | 1
4 | 1 | 0 | 0 | 0 | 0 | 1
5 | 1 | 1 | 1 | 1 | 1 | 1
6 | 1 | 0 | 0 | 0 | 0 | 1
7 | 1 | 1 | 1 | 1 | 1 | 1
8 | 1 | 0 | 0 | 0 | 0 | 1
9 | 0 | 0 | 0 | 1 | 1 | 1
10 | 1 | 1 | 1 | 1 | 1 | 1
11 | 0 | 0 | 0 | 1 | 1 | 1
12 | 0 | 0 | 0 | 1 | 1 | 1
13 | 1 | 1 | 1 | 1 | 1 | 1
14 | 1 | 1 | 1 | 1 | 1 | 1
15 | 0 | 1 | 1 | 0 | 0 | 1
16 | 0 | 0 | 0 | 1 | 1 | 1
17 | 0 | 1 | 1 | 0 | 0 | 1
18 | 0 | 0 | 0 | 1 | 1 | 1
19 | 1 | 1 | 1 | 1 | 1 | 1
20 | 0 | 0 | 0 | 1 | 1 | 1
21 | 1 | 1 | 1 | 1 | 1 | 1
22 | 1 | 0 | 0 | 0 | 0 | 1
23 | 0 | 0 | 0 | 1 | 1 | 1
24 | 0 | 0 | 0 | 1 | 1 | 1

=== Results Summary ===
PyTorch Accuracy: 0.68
MP-SPDZ Correct: 18/25
MP-SPDZ Accuracy: 0.72
MPC-PyTorch Match: 24/25 = 0.96

=== Layer-by-Layer Comparison ===
Running MPC forward pass for layer comparison...

Layer-by-layer comparison (Sample 0 only):
====================================================================================================

0.BertAttention
  Shape: [1, 64, 128], Elements: 8192
  Total Abs Diff: 1.53526
  PT Total Magnitude: 8674.72
  First 8 PT:  [-1.21268, -0.460876, -12.1223, -1.91841, 0.127014, -0.024826, -0.137589, 0.477814]
  First 8 MPC: [-1.21248, -0.4608, -12.1216, -1.91823, 0.126953, -0.0248108, -0.137558, 0.477692]

1.BertIntermediate
  Shape: [1, 64, 512], Elements: 32768
  Total Abs Diff: 145.205
  PT Total Magnitude: 9812.32
  First 8 PT:  [0.997009, -0.0773926, -0.104431, 0.0157928, -0.098938, 0.701431, 0.235107, 0.77919]
  First 8 MPC: [0.988831, -0.0701447, -0.0984344, 0.0243073, -0.0926971, 0.694016, 0.238785, 0.770889]

2.BertOutput
  Shape: [1, 64, 128], Elements: 8192
  Total Abs Diff: 29.1119
  PT Total Magnitude: 5764.84
  First 8 PT:  [-0.28949, -0.164566, -2.69733, -0.339325, 0.208282, 0.211044, 0.242966, -0.150711]
  First 8 MPC: [-0.287323, -0.163422, -2.69806, -0.337036, 0.208786, 0.209091, 0.242615, -0.15065]

3.BertLayer
  Shape: [1, 64, 128], Elements: 8192
  Total Abs Diff: 29.1119
  PT Total Magnitude: 5764.84
  First 8 PT:  [-0.28949, -0.164566, -2.69733, -0.339325, 0.208282, 0.211044, 0.242966, -0.150711]
  First 8 MPC: [-0.287323, -0.163422, -2.69806, -0.337036, 0.208786, 0.209091, 0.242615, -0.15065]

4.BertAttention
  Shape: [1, 64, 128], Elements: 8192
  Total Abs Diff: 32.178
  PT Total Magnitude: 7418.11
  First 8 PT:  [-0.945374, -0.995819, -1.84308, -2.34529, -0.588837, -0.417725, 0.179123, -1.99127]
  First 8 MPC: [-0.943817, -0.991333, -1.842, -2.3483, -0.597382, -0.428589, 0.174408, -1.99342]

5.BertIntermediate
  Shape: [1, 64, 512], Elements: 32768
  Total Abs Diff: 161.508
  PT Total Magnitude: 7922.42
  First 8 PT:  [0.233093, -0.111954, 1.01945, -0.057663, -0.165878, -0.0859375, -0.165787, 0.163406]
  First 8 MPC: [0.237289, -0.114639, 1.01451, -0.0514221, -0.167816, -0.0839844, -0.168732, 0.167877]

6.BertOutput
  Shape: [1, 64, 128], Elements: 8192
  Total Abs Diff: 56.9968
  PT Total Magnitude: 6668.43
  First 8 PT:  [-0.859436, -0.14003, -1.71599, -2.30341, -0.678909, -0.406189, 0.51741, -1.19707]
  First 8 MPC: [-0.865967, -0.150848, -1.7186, -2.30569, -0.675507, -0.414032, 0.513931, -1.19035]

7.BertLayer
  Shape: [1, 64, 128], Elements: 8192
  Total Abs Diff: 208.727
  PT Total Magnitude: 6515.35
  First 8 PT:  [-0.859436, -0.14003, -1.71599, -2.30341, -0.678909, -0.406189, 0.51741, -1.19707]
  First 8 MPC: [-0.865967, -0.150848, -1.7186, -2.30569, -0.675507, -0.414032, 0.513931, -1.19035]

8.BertPooler
  Shape: [1, 128], Elements: 128
  Total Abs Diff: 0.448059
  PT Total Magnitude: 66.9722
  First 8 PT:  [-0.991974, 0.0929565, -0.432709, -0.0628052, -0.976685, -0.0447235, -0.74736, 0.661972]
  First 8 MPC: [-0.992035, 0.0930481, -0.431732, -0.0752563, -0.976196, -0.0568848, -0.744385, 0.653748]
9.Dropout | Skipped (dropout)

10.Linear
  Shape: [1, 2], Elements: 2
  Total Abs Diff: 0.00212097
  PT Total Magnitude: 0.23735
  First 8 PT:  [0.182358, -0.0549927]
  First 8 MPC: [0.184189, -0.0552826]

=== Inference Complete ===
The following benchmarks are including preprocessing (offline phase).
Time = 196.373 seconds 
Data sent = 3353.93 MB in ~2252613 rounds (party 0 only; rounds counted double due to multi-threading; use '-v' for more details)
Global data sent = 10050.1 MB (all parties)
This program might benefit from some protocol options.
Consider adding the following at the beginning of your code:
        program.use_trunc_pr = True
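
As a side note, the compiler's suggestion above amounts to a one-line change near the top of the .mpc program; probabilistic truncation reduces communication at the cost of probabilistic rounding, so the numbers may differ slightly.

```python
# At the beginning of bert_inference.mpc, as suggested by the compiler output above.
program.use_trunc_pr = True
```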

@mkskeller (Member) commented:
This looks great, thank you for your efforts! Just a few questions:

  • Can you revert the change to __repr__? The logic here is that accessing _Y doesn't allocate memory while accessing Y does (via the property in Layer). You should be able to avoid any issues by simply assigning _Y instead. (A generic sketch of this pattern follows after this list.)
  • What is the reason for changing Y in MultiOutputLayer? It expects a one-hot vector there, so sint is the natural type.
  • What makes you worried about introducing bugs in the dense layer? Could you check your version against torch_mnist_dense?
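
To illustrate the point about __repr__ and _Y, here is a generic sketch of the lazy-allocation pattern described in the first question. This is not MP-SPDZ's actual Layer code; allocate_output is a placeholder of mine.

```python
class Layer:
    def __init__(self):
        self._Y = None          # no output storage allocated yet

    @property
    def Y(self):
        # accessing Y allocates on first use
        if self._Y is None:
            self._Y = self.allocate_output()
        return self._Y

    @Y.setter
    def Y(self, value):
        self._Y = value

    def allocate_output(self):
        # placeholder for the real allocation (e.g. a multi-array of sfix values)
        return []

    def __repr__(self):
        # use _Y so that printing the layer never triggers an allocation
        return '%s(_Y=%r)' % (type(self).__name__, self._Y)
```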

@hiddely (Contributor, Author) commented Oct 21, 2025

Hi, thanks for the feedback! The __repr__ change was an accidental edit I made while debugging, and the MultiOutputBase change was unrelated work that shouldn’t have been included in this branch. I have reverted both now.

I’ve integrated the FlexDense changes into the Dense layer (it already had a d parameter, so the only update is that it now supports d > 1). I also fixed a few issues in the backward loop when batch values were non-contiguous, as in SGD. The example script torch_mnist_dense now produces the same outputs and operations as the original implementation.
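
For reference, the general technique behind a dense layer that accepts d > 1 leading dimensions is to flatten all leading dimensions, apply the 2-D matrix multiplication, and restore the shape afterwards. A NumPy sketch (not the PR's code):

```python
import numpy as np

def dense_forward(x, W, b):
    # x: (..., n_in), W: (n_in, n_out), b: (n_out,)
    lead_shape = x.shape[:-1]
    y = x.reshape(-1, x.shape[-1]) @ W + b   # collapse leading dims, multiply
    return y.reshape(*lead_shape, W.shape[1])  # restore leading dims
```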

Let me know if you have any other comments!

Output of torch_mnist_dense.mpc with changes
Hash: eed3fd1e3f0232c4ccafd2e1e8f6c1ba6aea4db73532531ea94abaa20a3b580b
Program requires at most:
      318283 integer opens
    55580000 integer inputs from player 0
    20505472 integer 3-way splits
      116616 integer 3-way split rounds
  2374153344 bit triples
    26384512 integer bit2As
      342024 integer bit2A rounds
      236032 integer bits
 16466262784 integer triples
    55624192 integer simple multiplications
           8 matrix multiplications (1x784 * 784x128)
    89346208 integer dot products
           4 matrix multiplications (2x784 * 784x128)
           8 matrix multiplications (1x128 * 128x128)
           4 matrix multiplications (2x128 * 128x128)
           8 matrix multiplications (1x128 * 128x10)
           4 matrix multiplications (2x128 * 128x10)
        2188 matrix multiplications (10x784 * 784x128)
        4376 matrix multiplications (11x784 * 784x128)
        5940 matrix multiplications (10x128 * 128x128)
       11880 matrix multiplications (11x128 * 128x128)
        4064 matrix multiplications (10x128 * 128x10)
        8128 matrix multiplications (11x128 * 128x10)
        1876 matrix multiplications (10x10 * 10x128)
        3752 matrix multiplications (11x10 * 10x128)
        3752 matrix multiplications (65x128 * 128x128)
        1876 matrix multiplications (66x128 * 128x128)
     1813113 virtual machine rounds
Compilation finished, running program...
Running /Users/hidde/PhD/auditing/MP-SPDZ-SYNC-FORK/Scripts/../replicated-ring-party.x 0 torch_mnist_dense-1-12 -pn 12372 -h localhost
Running /Users/hidde/PhD/auditing/MP-SPDZ-SYNC-FORK/Scripts/../replicated-ring-party.x 1 torch_mnist_dense-1-12 -pn 12372 -h localhost
Running /Users/hidde/PhD/auditing/MP-SPDZ-SYNC-FORK/Scripts/../replicated-ring-party.x 2 torch_mnist_dense-1-12 -pn 12372 -h localhost
Trying to run 64-bit computation
Using SGD
test loss: 0.39827
acc: 0.9262 (9262/10000)

Commits added:
  • Update mnist_full examples using Dropout
  • Clean up code style
@hiddely (Contributor, Author) commented Oct 21, 2025

I've now also merged FlexDropout into Dropout: its __init__ now takes a shape argument, a list of dimensions whose first entry is the number of samples.

Note that this change is not backwards compatible with older programs that use the ml.Dropout class directly, since it slightly changes Dropout's __init__ signature. Adapting them is easy, though: it is enough to wrap the individual dimension parameters in a list. I went through the example programs and updated their use of Dropout accordingly. Alternatively, I could add a custom constructor that takes a shape and keep the old __init__, but then the class's __init__ would no longer be the preferred way to construct it.
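
As a hedged illustration of the adaptation (the dimension values are illustrative, not taken from a specific program, and the import path is whatever the surrounding .mpc program already uses for the ml module):

```python
from Compiler import ml   # assumes compilation within MP-SPDZ

batch_size = 128          # illustrative number of samples
# new signature: a single shape list, with the number of samples as the first entry
dropout = ml.Dropout([batch_size, 128])
# an old-style call such as ml.Dropout(batch_size, 128) simply wraps its dimension
# arguments into a list, as above
```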
