Add BERT to ML Library #1736
Conversation
This looks great, thank you for your efforts! Just a few questions:
Hi, thanks for the feedback! I've integrated the suggested changes. Let me know if you have any other comments!

Output of torch_mnist_dense.mpc with changes
Update mnist_full examples using Dropout
Clean up code style
I've also now merged the latest changes. Note that this change is not backwards compatible with older programs that rely on the previous behaviour.
This is an implementation of Huggingface's BERT transformer that I've been meaning to contribute back to MP-SPDZ. I implemented it as part of the evaluation of this work. The implementation includes all necessary BERT layers and a complete inference example. The example uses `M-FAC/bert-tiny-finetuned-qnli`, but it can also be used with larger versions of BERT with more layers, e.g. `gchhablani/bert-base-cased-finetuned-qnli`.

I am happy to adapt this PR to better fit with the philosophy of the ML library. Currently the implementation does not use the new approach relying on `torch.fx.trace`, because the Huggingface implementation of BERT causes errors when used with `torch.fx.trace` (e.g. it has if statements in the model forward pass; see the sketch below). However, if you have any feedback on how to improve the integration, please let me know.
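For context, here is a minimal, generic illustration (not code from this PR) of the kind of failure that symbolic tracing hits on data-dependent control flow; the `Branchy` module and its condition are made up for this sketch and only stand in for the if statements in the BERT forward pass:

```python
# Minimal illustration (hypothetical module, not from this PR) of why
# torch.fx symbolic tracing fails on data-dependent control flow.
import torch
import torch.fx


class Branchy(torch.nn.Module):
    def forward(self, x):
        # The branch depends on a tensor value; during tracing x is only a
        # Proxy, so converting the comparison to bool raises a TraceError.
        if x.sum() > 0:
            return x * 2
        return x - 1


try:
    torch.fx.symbolic_trace(Branchy())
except Exception as e:  # raises a torch.fx TraceError at the if statement
    print("tracing failed:", e)
```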
Highlights of Changes

Compiler/ml.py:

- `BertLayer`: Complete BERT encoder layer with multi-head attention and feed-forward network
- `MultiHeadAttention`: Self-attention mechanism with multiple attention heads
- `Gelu`: Implementation of the Gelu layer, with a version using the polynomial approximation from Puma (a plaintext reference sketch follows this list)
- `Tanh`: Activation using the existing `sigmoid` function
- `FlexDense`: Version of the `Dense` layer that supports a flexible number of dimensions. I kept it as a separate class to not introduce any bugs with existing models potentially relying on the existing implementation of `Dense`
- `FlexDropout`: Version of `Dropout` supporting a flexible number of dimensions
- `LayerNorm`: Implementation of layer normalization for transformer architectures
- `ConvBase` with optional bias instead of a mandatory `Add` layer

Programs/Source/bert_inference.mpc: inference example using Huggingface's BERT model, which computes 25 samples and compares the output at each layer of a single sample with the PyTorch output.
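As a point of reference only (plaintext PyTorch, not the secure MP-SPDZ implementation in this PR), the following sketch shows the math that the `Gelu` and `LayerNorm` entries above correspond to. The widely used tanh approximation of GELU stands in here for the Puma polynomial, whose coefficients are not reproduced, and the `eps` default follows the Huggingface BERT config:

```python
# Plaintext reference (not MP-SPDZ code) for two of the layers listed above.
import math
import torch


def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
    # Common tanh approximation of GELU; the PR additionally offers a
    # polynomial variant following Puma, not shown here.
    return 0.5 * x * (1.0 + torch.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def layer_norm(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor,
               eps: float = 1e-12) -> torch.Tensor:
    # Normalization over the last dimension, as used inside each BERT layer
    # (eps=1e-12 is the Huggingface BERT default).
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return gamma * (x - mean) / torch.sqrt(var + eps) + beta


x = torch.randn(2, 4, 8)
print(gelu_tanh(x).shape, layer_norm(x, torch.ones(8), torch.zeros(8)).shape)
```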
I took quite some care to make sure the implementation is correct; the forward pass layer-by-layer is pretty close
to Huggingface's implementation (modulo precision and approximations).
The per-layer error increases for the BertIntermediate layer, but it seems to be fairly evenly distributed across the layer, so I assume it is a result of the approximations. In subsequent layers the error seems to decrease again (in absolute terms).
The backward pass is also implemented and runs without errors, but it is harder to check for correctness.
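For illustration, per-layer reference outputs of the kind mentioned above can be obtained from the Huggingface model with the standard `transformers` API. This sketch is not the comparison code used by `bert_inference.mpc`, and the example sentence pair is arbitrary:

```python
# Sketch of obtaining per-layer PyTorch reference outputs for comparison
# (standard transformers API; not the PR's comparison code).
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

name = "M-FAC/bert-tiny-finetuned-qnli"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(
    name, output_hidden_states=True)
model.eval()

# Arbitrary QNLI-style (question, sentence) pair.
inputs = tokenizer("Is this a question?", "This is the sentence.",
                   return_tensors="pt")
with torch.no_grad():
    out = model(**inputs)

# out.hidden_states[0] is the embedding output; hidden_states[i] for i >= 1
# is the output of encoder layer i. out.logits is the final prediction.
for i, h in enumerate(out.hidden_states):
    print(f"layer {i}: shape {tuple(h.shape)}")
print("logits:", out.logits)
```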
I hope this can be a useful addition to MP-SPDZ's ML library. Let me know if it is; I will then clean it up further to make it ready to merge.
Example output of `bert_inference.mpc`