Skip to content

Commit

Permalink
update notebooks. extend hdnnp2nd
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Feb 20, 2024
1 parent d7b9f3f commit 160b3df
Show file tree
Hide file tree
Showing 30 changed files with 3,946 additions and 3,744 deletions.
3 changes: 3 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ causing clashes with built-in functions. We catch defaults to be at least as bac
* Added ``kgcnn.__safe_scatter_max_min_to_zero__`` for tensorflow and jax backend scattering with default to True.
* Added simple ragged support for loss and metrics.
* Added simple ragged support for ``train_force.py``
* Implemented random equivariant initialize for PAiNN
* Implemented charge and dipole output for HDNNP2nd


v4.0.0

Expand Down
9 changes: 4 additions & 5 deletions docs/source/data.ipynb

Large diffs are not rendered by default.

136 changes: 85 additions & 51 deletions docs/source/forces.ipynb

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions docs/source/layers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -330,10 +330,10 @@
"text": [
"tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
" 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [-0.2240, -0.1971, 0.4667, -0.2024, -0.1236, 0.4036, -0.0160, 0.2721,\n",
" -0.0063, -0.1154, 0.6441, 0.4041, -0.2673, 0.4717, -0.2080, 0.1283],\n",
" [-0.7468, -0.1389, 0.2592, -0.3825, 0.2881, 0.7621, 0.9968, 0.7264,\n",
" 0.4894, 0.4421, -0.4755, 0.6927, -0.3123, -0.3772, 0.4574, 1.0335]],\n",
" [ 0.1194, -0.0455, 0.2886, -0.2412, 0.5782, -0.0314, 0.0691, -0.0883,\n",
" -0.0427, 0.4910, 0.5271, 0.1340, 0.1813, -0.1867, -0.0742, 0.1329],\n",
" [ 0.3357, 0.2820, 0.3638, -0.1271, -0.1052, -0.3881, 0.3368, -0.5367,\n",
" -0.0495, 0.7888, 0.6429, -0.6687, 0.0486, -0.1593, -0.5402, 0.4580]],\n",
" device='cuda:0', grad_fn=<ScatterReduceBackward0>)\n"
]
}
Expand Down Expand Up @@ -363,10 +363,10 @@
{
"data": {
"text/plain": [
"tensor([[-0.2240, -0.1971, 0.4667, -0.2024, -0.1236, 0.4036, -0.0160, 0.2721,\n",
" -0.0063, -0.1154, 0.6441, 0.4041, -0.2673, 0.4717, -0.2080, 0.1283],\n",
" [-0.7468, -0.1389, 0.2592, -0.3825, 0.2881, 0.7621, 0.9968, 0.7264,\n",
" 0.4894, 0.4421, -0.4755, 0.6927, -0.3123, -0.3772, 0.4574, 1.0335]],\n",
"tensor([[ 0.1194, -0.0455, 0.2886, -0.2412, 0.5782, -0.0314, 0.0691, -0.0883,\n",
" -0.0427, 0.4910, 0.5271, 0.1340, 0.1813, -0.1867, -0.0742, 0.1329],\n",
" [ 0.3357, 0.2820, 0.3638, -0.1271, -0.1052, -0.3881, 0.3368, -0.5367,\n",
" -0.0495, 0.7888, 0.6429, -0.6687, 0.0486, -0.1593, -0.5402, 0.4580]],\n",
" device='cuda:0', grad_fn=<SliceBackward0>)"
]
},
Expand Down
100 changes: 71 additions & 29 deletions docs/source/literature.ipynb

Large diffs are not rendered by default.

73 changes: 36 additions & 37 deletions docs/source/models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -239,16 +239,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 818ms/step\n"
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 399ms/step\n"
]
},
{
"data": {
"text/plain": [
"array([[1.340108 ],\n",
" [0.26912418],\n",
" [0.53824836],\n",
" [1.0764967 ]], dtype=float32)"
"array([[-1.5849516 ],\n",
" [-0.20575221],\n",
" [-0.41150442],\n",
" [-0.82300884]], dtype=float32)"
]
},
"execution_count": 6,
Expand All @@ -271,25 +271,25 @@
"output_type": "stream",
"text": [
"Epoch 1/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 8ms/step - loss: 0.2488 \n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 5ms/step - loss: 1.4705 \n",
"Epoch 2/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 3ms/step - loss: 0.1036 \n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - loss: 1.1647 \n",
"Epoch 3/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step - loss: 0.0680 \n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - loss: 0.8158 \n",
"Epoch 4/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 3ms/step - loss: 0.1017 \n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - loss: 0.7461 \n",
"Epoch 5/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 3ms/step - loss: 0.0918 \n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - loss: 0.4058 \n",
"Epoch 6/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step - loss: 0.0945 \n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - loss: 0.3058 \n",
"Epoch 7/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step - loss: 0.0903 \n"
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - loss: 0.2068 \n"
]
},
{
"data": {
"text/plain": [
"<keras.src.callbacks.history.History at 0x2601c652dd0>"
"<keras.src.callbacks.history.History at 0x1877228c400>"
]
},
"execution_count": 7,
Expand Down Expand Up @@ -350,16 +350,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 352ms/step\n"
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 207ms/step\n"
]
},
{
"data": {
"text/plain": [
"array([[-2.2719753],\n",
" [-0.5563201],\n",
" [-1.1126401],\n",
" [-2.2252803]], dtype=float32)"
"array([[0.40027413],\n",
" [0.05910223],\n",
" [0.11820446],\n",
" [0.23640892]], dtype=float32)"
]
},
"execution_count": 9,
Expand All @@ -382,25 +382,25 @@
"output_type": "stream",
"text": [
"Epoch 1/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step - loss: 1.6347 \n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step - loss: 0.2512 \n",
"Epoch 2/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 3ms/step - loss: 1.6999 \n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - loss: 0.0791 \n",
"Epoch 3/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 3ms/step - loss: 1.6529 \n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 3ms/step - loss: 0.0788 \n",
"Epoch 4/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step - loss: 1.4860 \n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - loss: 0.0627 \n",
"Epoch 5/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step - loss: 1.1121 \n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - loss: 0.0493 \n",
"Epoch 6/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step - loss: 1.2790 \n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 3ms/step - loss: 0.0367 \n",
"Epoch 7/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 3ms/step - loss: 1.1712 \n"
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.0497 \n"
]
},
{
"data": {
"text/plain": [
"<keras.src.callbacks.history.History at 0x2601f90ece0>"
"<keras.src.callbacks.history.History at 0x1877355f5b0>"
]
},
"execution_count": 10,
Expand Down Expand Up @@ -493,15 +493,18 @@
"output_type": "stream",
"text": [
"Epoch 1/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 905ms/step - loss: 0.3999\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 626ms/step - loss: 0.2121\n",
"Epoch 2/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 26ms/step - loss: 0.2608\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 16ms/step - loss: 0.1763\n",
"Epoch 3/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 25ms/step - loss: 0.2581\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 16ms/step - loss: 0.1677\n",
"Epoch 4/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 26ms/step - loss: 0.2573\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 15ms/step - loss: 0.1813\n",
"Epoch 5/7\n",
"\u001b[1m1/2\u001b[0m \u001b[32m━━━━━━━━━━\u001b[0m\u001b[37m━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 30ms/step - loss: 0.4163"
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 15ms/step - loss: 0.1444\n",
"Epoch 6/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 15ms/step - loss: 0.1416\n",
"Epoch 7/7\n"
]
},
{
Expand All @@ -516,17 +519,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 26ms/step - loss: 0.2573\n",
"Epoch 6/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 26ms/step - loss: 0.2686\n",
"Epoch 7/7\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 24ms/step - loss: 0.2421\n"
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 16ms/step - loss: 0.1393\n"
]
},
{
"data": {
"text/plain": [
"<keras.src.callbacks.history.History at 0x2601fb2fe50>"
"<keras.src.callbacks.history.History at 0x187756ec2e0>"
]
},
"execution_count": 13,
Expand Down
Loading

0 comments on commit 160b3df

Please sign in to comment.