Skip to content

Commit 966672b

Browse files
authored
Example notebook for QET training (#742)
* Adding README.md for PyG TensorNet * added missing TrajectoryObserver * fix ruff * improve unit tests * convert float into Tensor * Corrected expected values * Corrected expected values again * update torch<=2.9.1 * Convert PyG TensorNet Embedding and Interaction blocks into pure Torch * Pure Pytorch TensorNet for MLIPs is added * Fix united tests * Refactor the PyG TensorNet components to ensure compatibility with pure PyTorch. * Cleanup MGLDataset * Improve the handling of stress unit in PESCalculator * cleanup test_ase_pyg.py * update PESCalculator unit tests * Improve logging in PESCalculator * fix linting tests * fixed linting and united tests * include_ref_charge keyword is added in MGLDataset * Update Relaxations and Simulations using the QET Universal Potential.ipynb Signed-off-by: Tsz Wai Ko <47970742+kenko911@users.noreply.github.com> * fix ruff * avoid crashing in MatCalc united tests by adding calc_charge and compute_charge in PyG Potential and PESCalculator * QET training module added * fix united tests for QET training * make sure predicted and target charges into the same dimension during the training * added backed the self.charge_weight * change torch.vstack into torch.hstack for reference charges * fixed the unit test * fix the predict_structure function in QET model * example notebook for QET potential training --------- Signed-off-by: Tsz Wai Ko <47970742+kenko911@users.noreply.github.com>
1 parent 7dfb3a9 commit 966672b

1 file changed

Lines changed: 399 additions & 0 deletions

File tree

Lines changed: 399 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,399 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "0",
6+
"metadata": {
7+
"id": "35c97a76"
8+
},
9+
"source": [
10+
"# Introduction\n",
11+
"\n",
12+
"This notebook demonstrates how to fit a QET potential using PyTorch Lightning with MatGL."
13+
]
14+
},
15+
{
16+
"cell_type": "code",
17+
"execution_count": null,
18+
"id": "1",
19+
"metadata": {
20+
"id": "6355190a",
21+
"tags": []
22+
},
23+
"outputs": [],
24+
"source": [
25+
"from __future__ import annotations\n",
26+
"\n",
27+
"import os\n",
28+
"import shutil\n",
29+
"import warnings\n",
30+
"from functools import partial\n",
31+
"\n",
32+
"import lightning as L\n",
33+
"import numpy as np\n",
34+
"from dgl.data.utils import split_dataset\n",
35+
"from lightning.pytorch.loggers import CSVLogger\n",
36+
"from mp_api.client import MPRester\n",
37+
"\n",
38+
"import matgl\n",
39+
"matgl.config.BACKEND = \"DGL\":\n",
40+
"from matgl.config import DEFAULT_ELEMENTS\n",
41+
"from matgl.ext._pymatgen_dgl import Structure2Graph\n",
42+
"from matgl.graph._data_dgl import MGLDataLoader, MGLDataset, collate_fn_pes\n",
43+
"from matgl.models._qet_dgl import QET\n",
44+
"from matgl.utils.training import PotentialLightningModule\n",
45+
"\n",
46+
"# To suppress warnings for clearer output\n",
47+
"warnings.simplefilter(\"ignore\")"
48+
]
49+
},
50+
{
51+
"cell_type": "markdown",
52+
"id": "2",
53+
"metadata": {
54+
"id": "eaafc0bd"
55+
},
56+
"source": [
57+
"For the purposes of demonstration, we will download all Si-O compounds in the Materials Project via the MPRester. The forces and stresses are set to zero, though in a real context, these would be non-zero and obtained from DFT calculations."
58+
]
59+
},
60+
{
61+
"cell_type": "code",
62+
"execution_count": null,
63+
"id": "3",
64+
"metadata": {
65+
"colab": {
66+
"base_uri": "https://localhost:8080/",
67+
"height": 67,
68+
"referenced_widgets": [
69+
"aac9f06228444b8dbd7dd798e6c1a93f",
70+
"591a338fe6304870bebd845eb8d8e2a9",
71+
"19fe8f71e71048bf9a923291ad9b1bb4",
72+
"0ab337e54e0943cb8bc940922fb425f5",
73+
"ed3ba68de443454b8be182c1901f8cfa",
74+
"a93bf23e7f224a3092094a1b0961251a",
75+
"a5503afc1d2b427d9fd0e83bf733387d",
76+
"fbad168bdfc34eb1a439bc3334748369",
77+
"4deece8c90b249f384722bce145a6a08",
78+
"d2777074f8d148c591652654e68e6d9f",
79+
"2eda31d46a5d440281571f3ea1240228"
80+
]
81+
},
82+
"id": "bd0ce8a2-ec68-4160-9457-823fb9e6a35d",
83+
"outputId": "2252a59c-9a70-4673-926f-9ed8fc69ed0d",
84+
"tags": []
85+
},
86+
"outputs": [],
87+
"source": [
88+
"# Obtain your API key here: https://next-gen.materialsproject.org/api\n",
89+
"mpr = MPRester(api_key=\"YOUR_API_KEY\")\n",
90+
"entries = mpr.get_entries_in_chemsys([\"Si\", \"O\"])\n",
91+
"structures = [e.structure for e in entries]\n",
92+
"energies = [e.energy for e in entries]\n",
93+
"forces = [np.zeros((len(s), 3)).tolist() for s in structures]\n",
94+
"stresses = [np.zeros((3, 3)).tolist() for s in structures]\n",
95+
"charges = [np.zeros(len(s)).tolist() for s in structures]\n",
96+
"labels = {\n",
97+
" \"energies\": energies,\n",
98+
" \"forces\": forces,\n",
99+
" \"stresses\": stresses,\n",
100+
" \"charges\": charges,\n",
101+
"}\n",
102+
"\n",
103+
"print(f\"{len(structures)} downloaded from MP.\")"
104+
]
105+
},
106+
{
107+
"cell_type": "markdown",
108+
"id": "4",
109+
"metadata": {
110+
"id": "f666cb23"
111+
},
112+
"source": [
113+
"We will first setup the QET model and the LightningModule."
114+
]
115+
},
116+
{
117+
"cell_type": "code",
118+
"execution_count": null,
119+
"id": "5",
120+
"metadata": {
121+
"colab": {
122+
"base_uri": "https://localhost:8080/"
123+
},
124+
"id": "e9dc84cb",
125+
"outputId": "b9f93f24-0fd6-4737-a8e4-e87804cd3ad2",
126+
"tags": []
127+
},
128+
"outputs": [],
129+
"source": [
130+
"element_types = DEFAULT_ELEMENTS\n",
131+
"converter = Structure2Graph(element_types=element_types, cutoff=5.0)\n",
132+
"dataset = MGLDataset(\n",
133+
" structures=structures,\n",
134+
" converter=converter,\n",
135+
" labels=labels,\n",
136+
" include_ref_charge=True,\n",
137+
")\n",
138+
"train_data, val_data, test_data = split_dataset(\n",
139+
" dataset,\n",
140+
" frac_list=[0.8, 0.1, 0.1],\n",
141+
" shuffle=True,\n",
142+
" random_state=42,\n",
143+
")\n",
144+
"# if you are not intended to use stress for training, switch include_stress=False!\n",
145+
"my_collate_fn = partial(collate_fn_pes, include_charge=True, include_stress=True)\n",
146+
"train_loader, val_loader, test_loader = MGLDataLoader(\n",
147+
" train_data=train_data,\n",
148+
" val_data=val_data,\n",
149+
" test_data=test_data,\n",
150+
" collate_fn=my_collate_fn,\n",
151+
" batch_size=2,\n",
152+
" num_workers=0,\n",
153+
")\n",
154+
"model = QET(element_types=element_types, is_intensive=False, use_smooth=True, rbf_type=\"SphericalBessel\")\n",
155+
"# if you are not intended to use stress and charge for training, set stress_weight=0.0 and charge_weight=0.0!\n",
156+
"lit_module = PotentialLightningModule(model=model, stress_weight=0.01, charge_weight=0.001)"
157+
]
158+
},
159+
{
160+
"cell_type": "markdown",
161+
"id": "6",
162+
"metadata": {
163+
"id": "01be4689"
164+
},
165+
"source": [
166+
"Finally, we will initialize the Pytorch Lightning trainer and run the fitting. Here, the max_epochs is set to 2 just for demonstration purposes. In a real fitting, this would be a much larger number. Also, the `accelerator=\"cpu\"` was set just to ensure compatibility with M1 Macs. In a real world use case, please remove the kwarg or set it to cuda for GPU based training."
167+
]
168+
},
169+
{
170+
"cell_type": "code",
171+
"execution_count": null,
172+
"id": "7",
173+
"metadata": {
174+
"colab": {
175+
"base_uri": "https://localhost:8080/",
176+
"height": 423,
177+
"referenced_widgets": [
178+
"ca0304013c864637b614ca03d131f920",
179+
"e6d2aa7fa6d644f39fa9eefaabbbef48",
180+
"cbaf8681aaa1410d9150f70f755628af",
181+
"f4b8c28792544919b47bc832cd93d4a6",
182+
"9d73d246aa1a47a9af3d1e2caa961b5d",
183+
"c6dc5ee98e1346a08067491930872eff",
184+
"2a89da23f17c4b0b85ce81bcfd69cc37",
185+
"5161471448ae4ec1b8ba2d2152f42153",
186+
"641cc028432f453a8ef3eeab201a9218",
187+
"f562f5b7289d4fcf9f310426aa620521",
188+
"3b95d47d64994a9eb812488ddbac61ec",
189+
"4ce50de72a8a4b5f995e72487f34c560",
190+
"f4ec778fed4749e890cdac57ef0a16f5",
191+
"daaabf9a40b64db6a2240f7e2ae27d08",
192+
"bb84fc3ffb3a4add91ce6af226f9c4af",
193+
"b32866d1699a46fe814e42ef2e5b8477",
194+
"bd3cb443099643b194a13b5cd3ad037f",
195+
"f7b6a5d231bc45c98c7eec9e40c5943c",
196+
"79e719731a74439da8bff2cc57c9898e",
197+
"1b2cc76759ce4164b28d407911590d73",
198+
"ac65110f92bc494d843ef347c980031f",
199+
"0c466ca4cfb34eb58db798b750fecfaa",
200+
"67f78887930f4ba8a9ad25035ab48801",
201+
"ea74e30a47d043c991770a3d82aadbb5",
202+
"56949107dcd94055bc6115404a0de63f",
203+
"9f6d6018de5a42afbfca719fa7b2a740",
204+
"2b8242683dd54bf9aa82e3e186fde7c5",
205+
"b3359b9ef6384c0cae5849f0b80d6cb1",
206+
"169f8ac55aa14dc393b5381e56783dc5",
207+
"53be090dca604f3c98e944bcfb35e34d",
208+
"9de7ce8be6ea4f5298870cc326895412",
209+
"1ee31acfe8dd4002a3917561f83b91bd",
210+
"bf983f21b4b34030b9c14e912a8d5450"
211+
]
212+
},
213+
"id": "7472d071",
214+
"outputId": "9d10c152-752f-4afc-8759-c5174ea446b9",
215+
"tags": []
216+
},
217+
"outputs": [],
218+
"source": [
219+
"# If you wish to disable GPU or MPS (M1 mac) training, use the accelerator=\"cpu\" kwarg.\n",
220+
"logger = CSVLogger(\"logs\", name=\"QET_training\")\n",
221+
"# Inference mode = False is required for calculating forces, stress in test mode and prediction mode\n",
222+
"trainer = L.Trainer(max_epochs=1, accelerator=\"cpu\", logger=logger, inference_mode=False)\n",
223+
"trainer.fit(model=lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader)"
224+
]
225+
},
226+
{
227+
"cell_type": "code",
228+
"execution_count": null,
229+
"id": "8",
230+
"metadata": {
231+
"tags": []
232+
},
233+
"outputs": [],
234+
"source": [
235+
"# test the model, remember to set inference_mode=False in trainer (see above)\n",
236+
"trainer.test(dataloaders=test_loader)"
237+
]
238+
},
239+
{
240+
"cell_type": "code",
241+
"execution_count": null,
242+
"id": "9",
243+
"metadata": {},
244+
"outputs": [],
245+
"source": [
246+
"# save trained model\n",
247+
"model_export_path = \"./trained_model/\"\n",
248+
"lit_module.model.save(model_export_path)\n",
249+
"\n",
250+
"# load trained model\n",
251+
"model = matgl.load_model(path=model_export_path)"
252+
]
253+
},
254+
{
255+
"cell_type": "markdown",
256+
"id": "10",
257+
"metadata": {},
258+
"source": [
259+
"## Finetuning a pre-trained QET\n",
260+
"In the previous cells, we demonstrated the process of training an QET from scratch. Next, let's see how to perform additional training on an QET that has already been trained using Materials Project data."
261+
]
262+
},
263+
{
264+
"cell_type": "code",
265+
"execution_count": null,
266+
"id": "11",
267+
"metadata": {
268+
"id": "85ef3a34-e6fb-452b-82cc-65012e80bce6"
269+
},
270+
"outputs": [],
271+
"source": [
272+
"# download a pre-trained M3GNet\n",
273+
"qet_nnp = matgl.load_model(\"QET-MatQ-PES\")\n",
274+
"model_pretrained = qet_nnp.model\n",
275+
"# obtain element energy offset\n",
276+
"property_offset = qet_nnp.element_refs.property_offset\n",
277+
"# you should test whether including the original property_offset helps improve training and validation accuracy\n",
278+
"lit_module_finetune = PotentialLightningModule(model=model, stress_weight=0.01, charge_weight=0.001)"
279+
]
280+
},
281+
{
282+
"cell_type": "code",
283+
"execution_count": null,
284+
"id": "12",
285+
"metadata": {
286+
"colab": {
287+
"base_uri": "https://localhost:8080/",
288+
"height": 423,
289+
"referenced_widgets": [
290+
"9d537bceb5994f0ca8908e6a8fa8dda0",
291+
"783e4a0f711e4e63b68bdfead3dd29dd",
292+
"b835f08350de4bfab6ce0e06b15c0e01",
293+
"ee3af67a2be244988ca27a46472f18c5",
294+
"30631c73610d4438a894a676ed31e49a",
295+
"ac1633b9d785471098cb40db339edb3f",
296+
"dfcaa58a05ab49a0b82108f04984e48b",
297+
"a8318cebd921434caee81d717a5f01fa",
298+
"479321dba45e48648bba9bad86ad4c4c",
299+
"8ac7f6294ff949b1a10da96532387ebb",
300+
"edafb607f21045d98a4b0c2059d27958",
301+
"d23643bb47234c969df66ad2ad223b4f",
302+
"fac4d3e8a6b64524bb0aa08036aa0588",
303+
"20ccc55a5b264bc3b193d324caeae685",
304+
"44f14d0ce3af41f5bb8f048a69da6a5d",
305+
"e8f06e2730cf4aef8427932a033252ab",
306+
"b662754e5b484ea58433fef79d401ce6",
307+
"4977216b2e9f4e94a41520127fde4a03",
308+
"c62f28b519c04081831d1a2507003e38",
309+
"e23184852e70488698cc75272c7057ae",
310+
"b7e24728fbb444c39a7115988c289883",
311+
"6921ddc707334f5d99b13e519d10d437",
312+
"db8cc5bcec1d4d0b999b48eb7a36f0de",
313+
"12511f8f4b134a7f80feada2b65240d0",
314+
"34a896f642fa47a78fe858ba8ab95c9e",
315+
"3a43080c70054e7793e73b5d49e48448",
316+
"88dddbd1837c4bfbbb1870beb73aa08a",
317+
"c59c41ae3d224b5aaa888a7526e1a130",
318+
"9cb7a81f00dc4d97b3f0639ec6dde1a4",
319+
"dba16d8574f8475f9d8f973fbf968b59",
320+
"d4df21b4ef1a4d3b938c5657337b7ec1",
321+
"b6a3965a166940d289ab4a50f2e78a60",
322+
"b5813fa692a74bf2aec3fbdf03088f3d"
323+
]
324+
},
325+
"id": "4133225a-5990-4b97-9d73-88195df87a1a",
326+
"outputId": "f149a68a-eef7-4726-b3a1-723525dc908f"
327+
},
328+
"outputs": [],
329+
"source": [
330+
"# If you wish to disable GPU or MPS (M1 mac) training, use the accelerator=\"cpu\" kwarg.\n",
331+
"logger = CSVLogger(\"logs\", name=\"QET_finetuning\")\n",
332+
"trainer = L.Trainer(max_epochs=1, accelerator=\"cpu\", logger=logger, inference_mode=False)\n",
333+
"trainer.fit(model=lit_module_finetune, train_dataloaders=train_loader, val_dataloaders=val_loader)"
334+
]
335+
},
336+
{
337+
"cell_type": "code",
338+
"execution_count": null,
339+
"id": "13",
340+
"metadata": {
341+
"id": "252f6456-3ecf-47f0-84ca-c8e9dcc66ccc"
342+
},
343+
"outputs": [],
344+
"source": [
345+
"# save trained model\n",
346+
"model_save_path = \"./finetuned_model/\"\n",
347+
"lit_module_finetune.model.save(model_save_path)\n",
348+
"# load trained model\n",
349+
"trained_model = matgl.load_model(path=model_save_path)"
350+
]
351+
},
352+
{
353+
"cell_type": "code",
354+
"execution_count": null,
355+
"id": "14",
356+
"metadata": {
357+
"id": "cd11b92f"
358+
},
359+
"outputs": [],
360+
"source": [
361+
"# This code just performs cleanup for this notebook.\n",
362+
"\n",
363+
"for fn in (\"dgl_graph.bin\", \"lattice.pt\", \"dgl_line_graph.bin\", \"state_attr.pt\", \"labels.json\"):\n",
364+
" try:\n",
365+
" os.remove(fn)\n",
366+
" except FileNotFoundError:\n",
367+
" pass\n",
368+
"\n",
369+
"shutil.rmtree(\"logs\")\n",
370+
"shutil.rmtree(\"trained_model\")\n",
371+
"shutil.rmtree(\"finetuned_model\")"
372+
]
373+
}
374+
],
375+
"metadata": {
376+
"colab": {
377+
"provenance": []
378+
},
379+
"kernelspec": {
380+
"display_name": "Python 3 (ipykernel)",
381+
"language": "python",
382+
"name": "python3"
383+
},
384+
"language_info": {
385+
"codemirror_mode": {
386+
"name": "ipython",
387+
"version": 3
388+
},
389+
"file_extension": ".py",
390+
"mimetype": "text/x-python",
391+
"name": "python",
392+
"nbconvert_exporter": "python",
393+
"pygments_lexer": "ipython3",
394+
"version": "3.11.0"
395+
}
396+
},
397+
"nbformat": 4,
398+
"nbformat_minor": 5
399+
}

0 commit comments

Comments
 (0)