diff --git a/official/colab/bert.ipynb b/official/colab/bert.ipynb new file mode 100644 index 00000000000..88c92ddb16d --- /dev/null +++ b/official/colab/bert.ipynb @@ -0,0 +1,383 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "How-to Guide: Using a PIP package for fine-tuning a BERT model", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "5T_-iFRIqliG", + "colab_type": "text" + }, + "source": [ + "## How-to Guide: Using a PIP package for fine-tuning a BERT model\n", + "\n", + "Author: [Chen Chen](https://github.com/chenGitHuber)\n", + "\n", + "In this example, we will work through fine-tuning a BERT model using the tensorflow-models PIP package." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mY1vX5VAq4SS", + "colab_type": "text" + }, + "source": [ + "## License\n", + "\n", + "Copyright 2020 The TensorFlow Authors. All Rights Reserved.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + " http://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XV3k63bt0ihl", + "colab_type": "text" + }, + "source": [ + "## Learning objectives\n", + "\n", + "In this Colab notebook, you will learn how to fine-tune a BERT model using the TensorFlow Model Garden PIP package." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MHA-RWherfG4", + "colab_type": "text" + }, + "source": [ + "## Enable the GPU acceleration\n", + "Please enable GPU for better performance.\n", + "* Navigate to Edit 🡒 Notebook settings\n", + "* Select GPU from the \"Hardware Accelerator\" drop-down list\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l2B4N5Djrs2l", + "colab_type": "text" + }, + "source": [ + "## Install the Model Garden PIP package\n", + "\n", + "Install the Model Garden PIP package (tf-models-nightly) and other necessary PIP packages." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "VMYZDly6rx97", + "colab_type": "code", + "outputId": "146956ab-4568-4de6-c78e-cf25f115a5a8", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + } + }, + "source": [ + "pip install tf-models-nightly" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Collecting tf-models-nightly\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/bd/7c/1390d4e05d4d370e91d32dd9700d3a462dbc560c7f4e95a6477592b17def/tf_models_nightly-2.2.0.dev20200326-py2.py3-none-any.whl (710kB)\n", + "\u001b[K |████████████████████████████████| 716kB 2.8MB/s \n", + "\u001b[?25hRequirement already satisfied: scipy>=0.19.1 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.4.1)\n", + "Collecting opencv-python-headless\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/0b/23/5f10b30a48b218a4884bc84188c14381ac71288b210f6f8079a54f7a05e8/opencv_python_headless-4.2.0.32-cp36-cp36m-manylinux1_x86_64.whl (21.6MB)\n", + "\u001b[K |████████████████████████████████| 21.6MB 1.3MB/s \n", + "\u001b[?25hRequirement already satisfied: tensorflow-hub>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.7.0)\n", + "Requirement already satisfied: typing in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (3.6.6)\n", + "Collecting tensorflow-model-optimization>=0.2.1\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/8f/c4/4c3d011e432bd9c19f0323f7da7d3f783402615e4c3b5a98416c7da9cb05/tensorflow_model_optimization-0.2.1-py2.py3-none-any.whl (93kB)\n", + "\u001b[K |████████████████████████████████| 102kB 10.2MB/s \n", + "\u001b[?25hRequirement already satisfied: numpy>=1.15.4 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.18.2)\n", + "Requirement already satisfied: pandas>=0.22.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.25.3)\n", + "Requirement already satisfied: gin-config in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.3.0)\n", + "Requirement already satisfied: google-cloud-bigquery>=0.31.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.21.0)\n", + "Collecting mlperf-compliance==0.0.10\n", + " Downloading https://files.pythonhosted.org/packages/f4/08/f2febd8cbd5c9371f7dab311e90400d83238447ba7609b3bf0145b4cb2a2/mlperf_compliance-0.0.10-py3-none-any.whl\n", + "Requirement already satisfied: kaggle>=1.3.9 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.5.6)\n", + "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.12.0)\n", + "Requirement already satisfied: google-api-python-client>=1.6.7 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (1.7.12)\n", + "Requirement already satisfied: tensorflow-addons in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.8.3)\n", + "Requirement already satisfied: psutil>=5.4.3 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (5.4.8)\n", + "Requirement already satisfied: tensorflow-datasets in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (2.1.0)\n", + "Collecting sentencepiece\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)\n", + "\u001b[K |████████████████████████████████| 1.0MB 58.3MB/s \n", + "\u001b[?25hCollecting tf-nightly\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/39/1c/4408b4c4b0d8008a7de62162e35089d59d19cc7543cfd1b23a70121f3086/tf_nightly-2.2.0.dev20200325-cp36-cp36m-manylinux2010_x86_64.whl (516.1MB)\n", + "\u001b[K |████████████████████████████████| 516.1MB 21kB/s \n", + "\u001b[?25hCollecting py-cpuinfo>=3.3.0\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/42/60/63f28a5401da733043abe7053e7d9591491b4784c4f87c339bf51215aa0a/py-cpuinfo-5.0.0.tar.gz (82kB)\n", + "\u001b[K |████████████████████████████████| 92kB 13.7MB/s \n", + "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (7.0.0)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (3.13)\n", + "Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (3.2.1)\n", + "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.7)\n", + "Requirement already satisfied: Cython in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (0.29.15)\n", + "Requirement already satisfied: oauth2client>=4.1.2 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly) (4.1.3)\n", + "Requirement already satisfied: protobuf>=3.4.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub>=0.6.0->tf-models-nightly) (3.10.0)\n", + "Collecting enum34~=1.1\n", + " Downloading https://files.pythonhosted.org/packages/63/f6/ccb1c83687756aeabbf3ca0f213508fcfb03883ff200d201b3a4c60cedcc/enum34-1.1.10-py3-none-any.whl\n", + "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22.0->tf-models-nightly) (2.8.1)\n", + "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22.0->tf-models-nightly) (2018.9)\n", + "Requirement already satisfied: google-resumable-media!=0.4.0,<0.5.0dev,>=0.3.1 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly) (0.4.1)\n", + "Requirement already satisfied: google-cloud-core<2.0dev,>=1.0.3 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly) (1.0.3)\n", + "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly) (1.24.3)\n", + "Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly) (2019.11.28)\n", + "Requirement already satisfied: python-slugify in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly) (4.0.0)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly) (4.38.0)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly) (2.21.0)\n", + "Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly) (0.0.3)\n", + "Requirement already satisfied: httplib2<1dev,>=0.17.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly) (0.17.0)\n", + "Requirement already satisfied: uritemplate<4dev,>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly) (3.0.1)\n", + "Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly) (1.7.2)\n", + "Requirement already satisfied: typeguard in /usr/local/lib/python3.6/dist-packages (from tensorflow-addons->tf-models-nightly) (2.7.1)\n", + "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (0.16.0)\n", + "Requirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (0.9.0)\n", + "Requirement already satisfied: wrapt in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (1.12.1)\n", + "Requirement already satisfied: promise in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (2.3)\n", + "Requirement already satisfied: attrs>=18.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (19.3.0)\n", + "Requirement already satisfied: termcolor in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (1.1.0)\n", + "Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (0.21.1)\n", + "Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-nightly) (0.3.1.1)\n", + "Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (0.3.3)\n", + "Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (0.34.2)\n", + "Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (1.27.2)\n", + "Collecting tb-nightly<2.3.0a0,>=2.2.0a0\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/52/b6/aa559ea9edbc6129a64ff752dbf6567bcd62fba34566defca00fdff4345e/tb_nightly-2.2.0a20200324-py3-none-any.whl (2.8MB)\n", + "\u001b[K |████████████████████████████████| 2.8MB 51.9MB/s \n", + "\u001b[?25hRequirement already satisfied: h5py<2.11.0,>=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (2.10.0)\n", + "Collecting tf-estimator-nightly\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/bf/eb/0d6c06d1181cd9b52d7c82535073887d68d224f3dbeeb00adefd04762a9c/tf_estimator_nightly-2.3.0.dev2020032501-py2.py3-none-any.whl (455kB)\n", + "\u001b[K |████████████████████████████████| 460kB 53.5MB/s \n", + "\u001b[?25hRequirement already satisfied: google-pasta>=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (0.2.0)\n", + "Requirement already satisfied: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (1.6.3)\n", + "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (3.2.0)\n", + "Requirement already satisfied: keras-preprocessing>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tf-models-nightly) (1.1.0)\n", + "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly) (2.4.6)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly) (1.1.0)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly) (0.10.0)\n", + "Requirement already satisfied: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.6/dist-packages (from oauth2client>=4.1.2->tf-models-nightly) (0.2.8)\n", + "Requirement already satisfied: rsa>=3.1.4 in /usr/local/lib/python3.6/dist-packages (from oauth2client>=4.1.2->tf-models-nightly) (4.0)\n", + "Requirement already satisfied: pyasn1>=0.1.7 in /usr/local/lib/python3.6/dist-packages (from oauth2client>=4.1.2->tf-models-nightly) (0.4.8)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.4.0->tensorflow-hub>=0.6.0->tf-models-nightly) (46.0.0)\n", + "Requirement already satisfied: google-api-core<2.0.0dev,>=1.14.0 in /usr/local/lib/python3.6/dist-packages (from google-cloud-core<2.0dev,>=1.0.3->google-cloud-bigquery>=0.31.0->tf-models-nightly) (1.16.0)\n", + "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.6/dist-packages (from python-slugify->kaggle>=1.3.9->tf-models-nightly) (1.3)\n", + "Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->kaggle>=1.3.9->tf-models-nightly) (2.8)\n", + "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->kaggle>=1.3.9->tf-models-nightly) (3.0.4)\n", + "Requirement already satisfied: cachetools<3.2,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth>=1.4.1->google-api-python-client>=1.6.7->tf-models-nightly) (3.1.1)\n", + "Requirement already satisfied: googleapis-common-protos in /usr/local/lib/python3.6/dist-packages (from tensorflow-metadata->tensorflow-datasets->tf-models-nightly) (1.51.0)\n", + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<2.3.0a0,>=2.2.0a0->tf-nightly->tf-models-nightly) (3.2.1)\n", + "Collecting tensorboard-plugin-wit>=1.6.0\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/41/ec/3da49289b93963bd8b32d29ed108f1809436ff3d9cd4e29c90bac4a7292f/tensorboard_plugin_wit-1.6.0.post2-py3-none-any.whl (775kB)\n", + "\u001b[K |████████████████████████████████| 778kB 43.8MB/s \n", + "\u001b[?25hRequirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<2.3.0a0,>=2.2.0a0->tf-nightly->tf-models-nightly) (0.4.1)\n", + "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<2.3.0a0,>=2.2.0a0->tf-nightly->tf-models-nightly) (1.0.0)\n", + "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<2.3.0a0,>=2.2.0a0->tf-nightly->tf-models-nightly) (1.3.0)\n", + "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<2.3.0a0,>=2.2.0a0->tf-nightly->tf-models-nightly) (3.1.0)\n", + "Building wheels for collected packages: py-cpuinfo\n", + " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for py-cpuinfo: filename=py_cpuinfo-5.0.0-cp36-none-any.whl size=18684 sha256=093853bc49757be8f9facd8d38eab95b3cc2de13514dafa221ad7898725491ff\n", + " Stored in directory: /root/.cache/pip/wheels/01/7e/a9/b982d0fea22b7e4ae5619de949570cde5ad55420cec16e86a5\n", + "Successfully built py-cpuinfo\n", + "Installing collected packages: opencv-python-headless, enum34, tensorflow-model-optimization, mlperf-compliance, sentencepiece, tensorboard-plugin-wit, tb-nightly, tf-estimator-nightly, tf-nightly, py-cpuinfo, tf-models-nightly\n", + "Successfully installed enum34-1.1.10 mlperf-compliance-0.0.10 opencv-python-headless-4.2.0.32 py-cpuinfo-5.0.0 sentencepiece-0.1.85 tb-nightly-2.2.0a20200324 tensorboard-plugin-wit-1.6.0.post2 tensorflow-model-optimization-0.2.1 tf-estimator-nightly-2.3.0.dev2020032501 tf-models-nightly-2.2.0.dev20200326 tf-nightly-2.2.0.dev20200325\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "enum" + ] + } + } + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "w4oyRCTji-aa", + "colab_type": "text" + }, + "source": [ + "## BERT Fine-tuning\n", + "\n", + "The following code import necessary modules for fine-tuning a BERT model on a classification task.\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "rV8MIX7g078-", + "colab_type": "code", + "colab": {} + }, + "source": [ + "%tensorflow_version 2.x\n", + "import tensorflow as tf\n", + "\n", + "import json\n", + "import math\n", + "\n", + "from official.utils.misc import distribution_utils\n", + "from official.nlp import optimization\n", + "from official.nlp.bert import bert_models\n", + "from official.nlp.bert import configs as bert_configs\n", + "from official.nlp.bert import run_classifier\n", + "from official.modeling import activations\n", + "from official.nlp.modeling import networks\n", + "from official.nlp.modeling.models import bert_classifier" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DIuS8nYD08n3", + "colab_type": "text" + }, + "source": [ + "This section of code performs the following tasks:\n", + "* Load data for fine-tuning\n", + "* Fine-tune a BERT model\n", + "* Save the fine-tuned model to a TensorFlow SavedModel file\n", + "\n", + "Please check [create_finetuning_data.py](https://github.com/tensorflow/models/blob/master/official/nlp/data/create_finetuning_data.py) if you want to know how the train/eval data are created." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "PAby1RTCi_1e", + "colab_type": "code", + "outputId": "e663e830-cc9b-4b5d-99db-4504fd66d5f3", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 258 + } + }, + "source": [ + "\n", + "train_data_path = \"gs://cloud-tpu-checkpoints/bert/classification/mrpc_train.tf_record\"\n", + "eval_data_path = \"gs://cloud-tpu-checkpoints/bert/classification/mrpc_eval.tf_record\"\n", + "input_meta_path = \"gs://cloud-tpu-checkpoints/bert/classification/mrpc_meta_data\"\n", + "\n", + "bert_config_file = \"gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/bert_config.json\"\n", + "ckpt_path = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/bert_model.ckpt'\n", + "\n", + "with tf.io.gfile.GFile(input_meta_path, 'rb') as reader:\n", + " input_meta_data = json.loads(reader.read().decode('utf-8'))\n", + "\n", + "max_seq_length = input_meta_data['max_seq_length']\n", + "num_classes = input_meta_data['num_labels']\n", + "batch_size = 32\n", + "eval_batch_size = 32\n", + "train_input_fn = run_classifier.get_dataset_fn(train_data_path, max_seq_length, batch_size, is_training=True)\n", + "eval_input_fn = run_classifier.get_dataset_fn(eval_data_path, max_seq_length, eval_batch_size, is_training=False)\n", + "\n", + "strategy = distribution_utils.get_distribution_strategy(\n", + " distribution_strategy='one_device', num_gpus=1)\n", + "\n", + "with strategy.scope():\n", + " training_dataset = train_input_fn()\n", + " evaluation_dataset = eval_input_fn()\n", + " bert_config = bert_configs.BertConfig.from_json_file(bert_config_file)\n", + " classifier_model, encoder = bert_models.classifier_model(\n", + " bert_config, num_classes, max_seq_length)\n", + "\n", + " checkpoint = tf.train.Checkpoint(model=encoder)\n", + " checkpoint.restore(ckpt_path).assert_consumed()\n", + "\n", + " epochs = 3\n", + " train_data_size = input_meta_data['train_data_size']\n", + " eval_data_size = input_meta_data['eval_data_size']\n", + " steps_per_epoch = int(train_data_size / batch_size)\n", + " warmup_steps = int(epochs * train_data_size * 0.1 / batch_size)\n", + " optimizer = optimization.create_optimizer(\n", + " 2e-5, num_train_steps=steps_per_epoch * epochs, num_warmup_steps=warmup_steps)\n", + "\n", + " def metric_fn():\n", + " return tf.keras.metrics.SparseCategoricalAccuracy(\n", + " 'test_accuracy', dtype=tf.float32)\n", + "\n", + " classifier_model.compile(optimizer=optimizer,\n", + " loss=run_classifier.get_loss_fn(num_classes=2),\n", + " metrics=[metric_fn()])\n", + " classifier_model.fit(\n", + " x=training_dataset,\n", + " validation_data=evaluation_dataset,\n", + " steps_per_epoch=steps_per_epoch,\n", + " epochs=epochs,\n", + " validation_steps=int(eval_data_size / eval_batch_size))\n", + "\n", + " classifier_model.save('/tmp/saved_model', include_optimizer=False, save_format='tf')" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "WARNING:tensorflow:BertClassifier inputs must come from `tf.keras.Input` (thus holding past layer metadata), they cannot be the output of a previous non-Input layer. Here, a tensor specified as input to \"bert_classifier\" was not an Input tensor, it was generated by layer input_mask.\n", + "Note that input tensors are instantiated via `tensor = tf.keras.Input(shape)`.\n", + "The tensor that caused the issue was: input_mask:0\n", + "WARNING:tensorflow:BertClassifier inputs must come from `tf.keras.Input` (thus holding past layer metadata), they cannot be the output of a previous non-Input layer. Here, a tensor specified as input to \"bert_classifier\" was not an Input tensor, it was generated by layer input_type_ids.\n", + "Note that input tensors are instantiated via `tensor = tf.keras.Input(shape)`.\n", + "The tensor that caused the issue was: input_type_ids:0\n", + "Epoch 1/3\n", + "114/114 [==============================] - 96s 840ms/step - loss: 0.5932 - test_accuracy: 0.6960 - val_loss: 0.5083 - val_test_accuracy: 0.7604\n", + "Epoch 2/3\n", + "114/114 [==============================] - 100s 878ms/step - loss: 0.4225 - test_accuracy: 0.8183 - val_loss: 0.4020 - val_test_accuracy: 0.8438\n", + "Epoch 3/3\n", + "114/114 [==============================] - 100s 880ms/step - loss: 0.2482 - test_accuracy: 0.9134 - val_loss: 0.4065 - val_test_accuracy: 0.8151\n", + "INFO:tensorflow:Assets written to: /tmp/saved_model/assets\n" + ], + "name": "stdout" + } + ] + } + ] +} \ No newline at end of file diff --git a/official/pip_package/setup.py b/official/pip_package/setup.py index bfd95a1d3cd..e708bc47bff 100644 --- a/official/pip_package/setup.py +++ b/official/pip_package/setup.py @@ -78,6 +78,7 @@ def _get_requirements(): 'official.r1*', 'official.pip_package*', 'official.benchmark*', + 'official.colab*', ]), exclude_package_data={ '': ['*_test.py',],