|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "id": "0", |
| 6 | + "metadata": { |
| 7 | + "id": "35c97a76" |
| 8 | + }, |
| 9 | + "source": [ |
| 10 | + "# Introduction\n", |
| 11 | + "\n", |
| 12 | + "This notebook demonstrates how to fit a QET potential using PyTorch Lightning with MatGL." |
| 13 | + ] |
| 14 | + }, |
| 15 | + { |
| 16 | + "cell_type": "code", |
| 17 | + "execution_count": null, |
| 18 | + "id": "1", |
| 19 | + "metadata": { |
| 20 | + "id": "6355190a", |
| 21 | + "tags": [] |
| 22 | + }, |
| 23 | + "outputs": [], |
| 24 | + "source": [ |
| 25 | + "from __future__ import annotations\n", |
| 26 | + "\n", |
| 27 | + "import os\n", |
| 28 | + "import shutil\n", |
| 29 | + "import warnings\n", |
| 30 | + "from functools import partial\n", |
| 31 | + "\n", |
| 32 | + "import lightning as L\n", |
| 33 | + "import numpy as np\n", |
| 34 | + "from dgl.data.utils import split_dataset\n", |
| 35 | + "from lightning.pytorch.loggers import CSVLogger\n", |
| 36 | + "from mp_api.client import MPRester\n", |
| 37 | + "\n", |
| 38 | + "import matgl\n", |
| 39 | + "matgl.config.BACKEND = \"DGL\"\n", |
| 40 | + "from matgl.config import DEFAULT_ELEMENTS\n", |
| 41 | + "from matgl.ext._pymatgen_dgl import Structure2Graph\n", |
| 42 | + "from matgl.graph._data_dgl import MGLDataLoader, MGLDataset, collate_fn_pes\n", |
| 43 | + "from matgl.models._qet_dgl import QET\n", |
| 44 | + "from matgl.utils.training import PotentialLightningModule\n", |
| 45 | + "\n", |
| 46 | + "# To suppress warnings for clearer output\n", |
| 47 | + "warnings.simplefilter(\"ignore\")" |
| 48 | + ] |
| 49 | + }, |
| 50 | + { |
| 51 | + "cell_type": "markdown", |
| 52 | + "id": "2", |
| 53 | + "metadata": { |
| 54 | + "id": "eaafc0bd" |
| 55 | + }, |
| 56 | + "source": [ |
| 57 | + "For the purposes of demonstration, we will download all Si-O compounds in the Materials Project via the MPRester. The forces and stresses are set to zero, though in a real context, these would be non-zero and obtained from DFT calculations." |
| 58 | + ] |
| 59 | + }, |
| 60 | + { |
| 61 | + "cell_type": "code", |
| 62 | + "execution_count": null, |
| 63 | + "id": "3", |
| 64 | + "metadata": { |
| 65 | + "colab": { |
| 66 | + "base_uri": "https://localhost:8080/", |
| 67 | + "height": 67, |
| 68 | + "referenced_widgets": [ |
| 69 | + "aac9f06228444b8dbd7dd798e6c1a93f", |
| 70 | + "591a338fe6304870bebd845eb8d8e2a9", |
| 71 | + "19fe8f71e71048bf9a923291ad9b1bb4", |
| 72 | + "0ab337e54e0943cb8bc940922fb425f5", |
| 73 | + "ed3ba68de443454b8be182c1901f8cfa", |
| 74 | + "a93bf23e7f224a3092094a1b0961251a", |
| 75 | + "a5503afc1d2b427d9fd0e83bf733387d", |
| 76 | + "fbad168bdfc34eb1a439bc3334748369", |
| 77 | + "4deece8c90b249f384722bce145a6a08", |
| 78 | + "d2777074f8d148c591652654e68e6d9f", |
| 79 | + "2eda31d46a5d440281571f3ea1240228" |
| 80 | + ] |
| 81 | + }, |
| 82 | + "id": "bd0ce8a2-ec68-4160-9457-823fb9e6a35d", |
| 83 | + "outputId": "2252a59c-9a70-4673-926f-9ed8fc69ed0d", |
| 84 | + "tags": [] |
| 85 | + }, |
| 86 | + "outputs": [], |
| 87 | + "source": [ |
| 88 | + "# Obtain your API key here: https://next-gen.materialsproject.org/api\n", |
| 89 | + "mpr = MPRester(api_key=\"YOUR_API_KEY\")\n", |
| 90 | + "entries = mpr.get_entries_in_chemsys([\"Si\", \"O\"])\n", |
| 91 | + "structures = [e.structure for e in entries]\n", |
| 92 | + "energies = [e.energy for e in entries]\n", |
| 93 | + "forces = [np.zeros((len(s), 3)).tolist() for s in structures]\n", |
| 94 | + "stresses = [np.zeros((3, 3)).tolist() for s in structures]\n", |
| 95 | + "charges = [np.zeros(len(s)).tolist() for s in structures]\n", |
| 96 | + "labels = {\n", |
| 97 | + " \"energies\": energies,\n", |
| 98 | + " \"forces\": forces,\n", |
| 99 | + " \"stresses\": stresses,\n", |
| 100 | + " \"charges\": charges,\n", |
| 101 | + "}\n", |
| 102 | + "\n", |
| 103 | + "print(f\"{len(structures)} downloaded from MP.\")" |
| 104 | + ] |
| 105 | + }, |
| 106 | + { |
| 107 | + "cell_type": "markdown", |
| 108 | + "id": "4", |
| 109 | + "metadata": { |
| 110 | + "id": "f666cb23" |
| 111 | + }, |
| 112 | + "source": [ |
| 113 | + "We will first set up the QET model and the LightningModule." |
| 114 | + ] |
| 115 | + }, |
| 116 | + { |
| 117 | + "cell_type": "code", |
| 118 | + "execution_count": null, |
| 119 | + "id": "5", |
| 120 | + "metadata": { |
| 121 | + "colab": { |
| 122 | + "base_uri": "https://localhost:8080/" |
| 123 | + }, |
| 124 | + "id": "e9dc84cb", |
| 125 | + "outputId": "b9f93f24-0fd6-4737-a8e4-e87804cd3ad2", |
| 126 | + "tags": [] |
| 127 | + }, |
| 128 | + "outputs": [], |
| 129 | + "source": [ |
| 130 | + "element_types = DEFAULT_ELEMENTS\n", |
| 131 | + "converter = Structure2Graph(element_types=element_types, cutoff=5.0)\n", |
| 132 | + "dataset = MGLDataset(\n", |
| 133 | + " structures=structures,\n", |
| 134 | + " converter=converter,\n", |
| 135 | + " labels=labels,\n", |
| 136 | + " include_ref_charge=True,\n", |
| 137 | + ")\n", |
| 138 | + "train_data, val_data, test_data = split_dataset(\n", |
| 139 | + " dataset,\n", |
| 140 | + " frac_list=[0.8, 0.1, 0.1],\n", |
| 141 | + " shuffle=True,\n", |
| 142 | + " random_state=42,\n", |
| 143 | + ")\n", |
| 144 | + "# If you do not intend to use stress for training, set include_stress=False!\n", |
| 145 | + "my_collate_fn = partial(collate_fn_pes, include_charge=True, include_stress=True)\n", |
| 146 | + "train_loader, val_loader, test_loader = MGLDataLoader(\n", |
| 147 | + " train_data=train_data,\n", |
| 148 | + " val_data=val_data,\n", |
| 149 | + " test_data=test_data,\n", |
| 150 | + " collate_fn=my_collate_fn,\n", |
| 151 | + " batch_size=2,\n", |
| 152 | + " num_workers=0,\n", |
| 153 | + ")\n", |
| 154 | + "model = QET(element_types=element_types, is_intensive=False, use_smooth=True, rbf_type=\"SphericalBessel\")\n", |
| 155 | + "# If you do not intend to use stress and charge for training, set stress_weight=0.0 and charge_weight=0.0!\n", |
| 156 | + "lit_module = PotentialLightningModule(model=model, stress_weight=0.01, charge_weight=0.001)" |
| 157 | + ] |
| 158 | + }, |
| 159 | + { |
| 160 | + "cell_type": "markdown", |
| 161 | + "id": "6", |
| 162 | + "metadata": { |
| 163 | + "id": "01be4689" |
| 164 | + }, |
| 165 | + "source": [ |
| 166 | + "Finally, we will initialize the PyTorch Lightning trainer and run the fitting. Here, `max_epochs` is set to 1 just for demonstration purposes. In a real fitting, this would be a much larger number. Also, `accelerator=\"cpu\"` was set just to ensure compatibility with M1 Macs. In a real-world use case, please remove the kwarg or set it to `\"cuda\"` for GPU-based training." |
| 167 | + ] |
| 168 | + }, |
| 169 | + { |
| 170 | + "cell_type": "code", |
| 171 | + "execution_count": null, |
| 172 | + "id": "7", |
| 173 | + "metadata": { |
| 174 | + "colab": { |
| 175 | + "base_uri": "https://localhost:8080/", |
| 176 | + "height": 423, |
| 177 | + "referenced_widgets": [ |
| 178 | + "ca0304013c864637b614ca03d131f920", |
| 179 | + "e6d2aa7fa6d644f39fa9eefaabbbef48", |
| 180 | + "cbaf8681aaa1410d9150f70f755628af", |
| 181 | + "f4b8c28792544919b47bc832cd93d4a6", |
| 182 | + "9d73d246aa1a47a9af3d1e2caa961b5d", |
| 183 | + "c6dc5ee98e1346a08067491930872eff", |
| 184 | + "2a89da23f17c4b0b85ce81bcfd69cc37", |
| 185 | + "5161471448ae4ec1b8ba2d2152f42153", |
| 186 | + "641cc028432f453a8ef3eeab201a9218", |
| 187 | + "f562f5b7289d4fcf9f310426aa620521", |
| 188 | + "3b95d47d64994a9eb812488ddbac61ec", |
| 189 | + "4ce50de72a8a4b5f995e72487f34c560", |
| 190 | + "f4ec778fed4749e890cdac57ef0a16f5", |
| 191 | + "daaabf9a40b64db6a2240f7e2ae27d08", |
| 192 | + "bb84fc3ffb3a4add91ce6af226f9c4af", |
| 193 | + "b32866d1699a46fe814e42ef2e5b8477", |
| 194 | + "bd3cb443099643b194a13b5cd3ad037f", |
| 195 | + "f7b6a5d231bc45c98c7eec9e40c5943c", |
| 196 | + "79e719731a74439da8bff2cc57c9898e", |
| 197 | + "1b2cc76759ce4164b28d407911590d73", |
| 198 | + "ac65110f92bc494d843ef347c980031f", |
| 199 | + "0c466ca4cfb34eb58db798b750fecfaa", |
| 200 | + "67f78887930f4ba8a9ad25035ab48801", |
| 201 | + "ea74e30a47d043c991770a3d82aadbb5", |
| 202 | + "56949107dcd94055bc6115404a0de63f", |
| 203 | + "9f6d6018de5a42afbfca719fa7b2a740", |
| 204 | + "2b8242683dd54bf9aa82e3e186fde7c5", |
| 205 | + "b3359b9ef6384c0cae5849f0b80d6cb1", |
| 206 | + "169f8ac55aa14dc393b5381e56783dc5", |
| 207 | + "53be090dca604f3c98e944bcfb35e34d", |
| 208 | + "9de7ce8be6ea4f5298870cc326895412", |
| 209 | + "1ee31acfe8dd4002a3917561f83b91bd", |
| 210 | + "bf983f21b4b34030b9c14e912a8d5450" |
| 211 | + ] |
| 212 | + }, |
| 213 | + "id": "7472d071", |
| 214 | + "outputId": "9d10c152-752f-4afc-8759-c5174ea446b9", |
| 215 | + "tags": [] |
| 216 | + }, |
| 217 | + "outputs": [], |
| 218 | + "source": [ |
| 219 | + "# If you wish to disable GPU or MPS (M1 mac) training, use the accelerator=\"cpu\" kwarg.\n", |
| 220 | + "logger = CSVLogger(\"logs\", name=\"QET_training\")\n", |
| 221 | + "# Inference mode = False is required for calculating forces, stress in test mode and prediction mode\n", |
| 222 | + "trainer = L.Trainer(max_epochs=1, accelerator=\"cpu\", logger=logger, inference_mode=False)\n", |
| 223 | + "trainer.fit(model=lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader)" |
| 224 | + ] |
| 225 | + }, |
| 226 | + { |
| 227 | + "cell_type": "code", |
| 228 | + "execution_count": null, |
| 229 | + "id": "8", |
| 230 | + "metadata": { |
| 231 | + "tags": [] |
| 232 | + }, |
| 233 | + "outputs": [], |
| 234 | + "source": [ |
| 235 | + "# test the model, remember to set inference_mode=False in trainer (see above)\n", |
| 236 | + "trainer.test(dataloaders=test_loader)" |
| 237 | + ] |
| 238 | + }, |
| 239 | + { |
| 240 | + "cell_type": "code", |
| 241 | + "execution_count": null, |
| 242 | + "id": "9", |
| 243 | + "metadata": {}, |
| 244 | + "outputs": [], |
| 245 | + "source": [ |
| 246 | + "# save trained model\n", |
| 247 | + "model_export_path = \"./trained_model/\"\n", |
| 248 | + "lit_module.model.save(model_export_path)\n", |
| 249 | + "\n", |
| 250 | + "# load trained model\n", |
| 251 | + "model = matgl.load_model(path=model_export_path)" |
| 252 | + ] |
| 253 | + }, |
| 254 | + { |
| 255 | + "cell_type": "markdown", |
| 256 | + "id": "10", |
| 257 | + "metadata": {}, |
| 258 | + "source": [ |
| 259 | + "## Finetuning a pre-trained QET\n", |
| 260 | + "In the previous cells, we demonstrated the process of training a QET from scratch. Next, let's see how to perform additional training on a QET that has already been trained using Materials Project data." |
| 261 | + ] |
| 262 | + }, |
| 263 | + { |
| 264 | + "cell_type": "code", |
| 265 | + "execution_count": null, |
| 266 | + "id": "11", |
| 267 | + "metadata": { |
| 268 | + "id": "85ef3a34-e6fb-452b-82cc-65012e80bce6" |
| 269 | + }, |
| 270 | + "outputs": [], |
| 271 | + "source": [ |
| 272 | + "# download a pre-trained QET\n", |
| 273 | + "qet_nnp = matgl.load_model(\"QET-MatQ-PES\")\n", |
| 274 | + "model_pretrained = qet_nnp.model\n", |
| 275 | + "# obtain element energy offset\n", |
| 276 | + "property_offset = qet_nnp.element_refs.property_offset\n", |
| 277 | + "# you should test whether including the original property_offset helps improve training and validation accuracy\n", |
| 278 | + "lit_module_finetune = PotentialLightningModule(model=model_pretrained, stress_weight=0.01, charge_weight=0.001)" |
| 279 | + ] |
| 280 | + }, |
| 281 | + { |
| 282 | + "cell_type": "code", |
| 283 | + "execution_count": null, |
| 284 | + "id": "12", |
| 285 | + "metadata": { |
| 286 | + "colab": { |
| 287 | + "base_uri": "https://localhost:8080/", |
| 288 | + "height": 423, |
| 289 | + "referenced_widgets": [ |
| 290 | + "9d537bceb5994f0ca8908e6a8fa8dda0", |
| 291 | + "783e4a0f711e4e63b68bdfead3dd29dd", |
| 292 | + "b835f08350de4bfab6ce0e06b15c0e01", |
| 293 | + "ee3af67a2be244988ca27a46472f18c5", |
| 294 | + "30631c73610d4438a894a676ed31e49a", |
| 295 | + "ac1633b9d785471098cb40db339edb3f", |
| 296 | + "dfcaa58a05ab49a0b82108f04984e48b", |
| 297 | + "a8318cebd921434caee81d717a5f01fa", |
| 298 | + "479321dba45e48648bba9bad86ad4c4c", |
| 299 | + "8ac7f6294ff949b1a10da96532387ebb", |
| 300 | + "edafb607f21045d98a4b0c2059d27958", |
| 301 | + "d23643bb47234c969df66ad2ad223b4f", |
| 302 | + "fac4d3e8a6b64524bb0aa08036aa0588", |
| 303 | + "20ccc55a5b264bc3b193d324caeae685", |
| 304 | + "44f14d0ce3af41f5bb8f048a69da6a5d", |
| 305 | + "e8f06e2730cf4aef8427932a033252ab", |
| 306 | + "b662754e5b484ea58433fef79d401ce6", |
| 307 | + "4977216b2e9f4e94a41520127fde4a03", |
| 308 | + "c62f28b519c04081831d1a2507003e38", |
| 309 | + "e23184852e70488698cc75272c7057ae", |
| 310 | + "b7e24728fbb444c39a7115988c289883", |
| 311 | + "6921ddc707334f5d99b13e519d10d437", |
| 312 | + "db8cc5bcec1d4d0b999b48eb7a36f0de", |
| 313 | + "12511f8f4b134a7f80feada2b65240d0", |
| 314 | + "34a896f642fa47a78fe858ba8ab95c9e", |
| 315 | + "3a43080c70054e7793e73b5d49e48448", |
| 316 | + "88dddbd1837c4bfbbb1870beb73aa08a", |
| 317 | + "c59c41ae3d224b5aaa888a7526e1a130", |
| 318 | + "9cb7a81f00dc4d97b3f0639ec6dde1a4", |
| 319 | + "dba16d8574f8475f9d8f973fbf968b59", |
| 320 | + "d4df21b4ef1a4d3b938c5657337b7ec1", |
| 321 | + "b6a3965a166940d289ab4a50f2e78a60", |
| 322 | + "b5813fa692a74bf2aec3fbdf03088f3d" |
| 323 | + ] |
| 324 | + }, |
| 325 | + "id": "4133225a-5990-4b97-9d73-88195df87a1a", |
| 326 | + "outputId": "f149a68a-eef7-4726-b3a1-723525dc908f" |
| 327 | + }, |
| 328 | + "outputs": [], |
| 329 | + "source": [ |
| 330 | + "# If you wish to disable GPU or MPS (M1 mac) training, use the accelerator=\"cpu\" kwarg.\n", |
| 331 | + "logger = CSVLogger(\"logs\", name=\"QET_finetuning\")\n", |
| 332 | + "trainer = L.Trainer(max_epochs=1, accelerator=\"cpu\", logger=logger, inference_mode=False)\n", |
| 333 | + "trainer.fit(model=lit_module_finetune, train_dataloaders=train_loader, val_dataloaders=val_loader)" |
| 334 | + ] |
| 335 | + }, |
| 336 | + { |
| 337 | + "cell_type": "code", |
| 338 | + "execution_count": null, |
| 339 | + "id": "13", |
| 340 | + "metadata": { |
| 341 | + "id": "252f6456-3ecf-47f0-84ca-c8e9dcc66ccc" |
| 342 | + }, |
| 343 | + "outputs": [], |
| 344 | + "source": [ |
| 345 | + "# save trained model\n", |
| 346 | + "model_save_path = \"./finetuned_model/\"\n", |
| 347 | + "lit_module_finetune.model.save(model_save_path)\n", |
| 348 | + "# load trained model\n", |
| 349 | + "trained_model = matgl.load_model(path=model_save_path)" |
| 350 | + ] |
| 351 | + }, |
| 352 | + { |
| 353 | + "cell_type": "code", |
| 354 | + "execution_count": null, |
| 355 | + "id": "14", |
| 356 | + "metadata": { |
| 357 | + "id": "cd11b92f" |
| 358 | + }, |
| 359 | + "outputs": [], |
| 360 | + "source": [ |
| 361 | + "# This code just performs cleanup for this notebook.\n", |
| 362 | + "\n", |
| 363 | + "for fn in (\"dgl_graph.bin\", \"lattice.pt\", \"dgl_line_graph.bin\", \"state_attr.pt\", \"labels.json\"):\n", |
| 364 | + " try:\n", |
| 365 | + " os.remove(fn)\n", |
| 366 | + " except FileNotFoundError:\n", |
| 367 | + " pass\n", |
| 368 | + "\n", |
| 369 | + "shutil.rmtree(\"logs\", ignore_errors=True)\n", |
| 370 | + "shutil.rmtree(\"trained_model\", ignore_errors=True)\n", |
| 371 | + "shutil.rmtree(\"finetuned_model\", ignore_errors=True)" |
| 372 | + ] |
| 373 | + } |
| 374 | + ], |
| 375 | + "metadata": { |
| 376 | + "colab": { |
| 377 | + "provenance": [] |
| 378 | + }, |
| 379 | + "kernelspec": { |
| 380 | + "display_name": "Python 3 (ipykernel)", |
| 381 | + "language": "python", |
| 382 | + "name": "python3" |
| 383 | + }, |
| 384 | + "language_info": { |
| 385 | + "codemirror_mode": { |
| 386 | + "name": "ipython", |
| 387 | + "version": 3 |
| 388 | + }, |
| 389 | + "file_extension": ".py", |
| 390 | + "mimetype": "text/x-python", |
| 391 | + "name": "python", |
| 392 | + "nbconvert_exporter": "python", |
| 393 | + "pygments_lexer": "ipython3", |
| 394 | + "version": "3.11.0" |
| 395 | + } |
| 396 | + }, |
| 397 | + "nbformat": 4, |
| 398 | + "nbformat_minor": 5 |
| 399 | +} |
0 commit comments