diff --git a/notebooks/tutorials/7-decoding-strategies.ipynb b/notebooks/tutorials/7-decoding-strategies.ipynb
index 47c9f305..95e4937f 100644
--- a/notebooks/tutorials/7-decoding-strategies.ipynb
+++ b/notebooks/tutorials/7-decoding-strategies.ipynb
@@ -7,7 +7,7 @@
"source": [
"# RL4CO Decoding Strategies Notebook\n",
"\n",
- "This notebook demonstrates how to utilize the different decoding strategies available in rl4co/models/nn/dec_strategies.py during the different phases of model development. We will also demonstrate how to evaluate the model for different decoding strategies on the test dataset. \n",
+ "This notebook demonstrates how to use the different decoding strategies available in [rl4co/models/nn/dec_strategies.py](../../rl4co/models/nn/dec_strategies.py) during the different phases of model development. We will also demonstrate how to evaluate the model with different decoding strategies on the test dataset.\n",
"\n",
"\n"
]
@@ -22,7 +22,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"id": "7538da3e-67df-4c72-9acb-345a3bc9fba1",
"metadata": {},
"outputs": [],
@@ -36,7 +36,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"id": "4380f62f-bde8-4fc5-aa1a-072d5be58a32",
"metadata": {},
"outputs": [],
@@ -58,14 +58,14 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
"id": "40c9a2ac-a2cc-4a90-a810-75a092fa4890",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"# RL4CO env based on TorchRL\n",
- "env = TSPEnv(num_loc=10) \n",
+ "env = TSPEnv(num_loc=20) \n",
"\n",
"# Policy: neural network, in this case with encoder-decoder architecture\n",
"policy = AttentionModelPolicy(env.name, \n",
@@ -80,7 +80,7 @@
" batch_size = 128,\n",
" val_batch_size = 512,\n",
" test_batch_size = 512,\n",
- " train_data_size=1_00,\n",
+ " train_data_size=20_000, # fast training for demo\n",
" val_data_size=1_000,\n",
" test_data_size=1_000,\n",
" optimizer_kwargs={\"lr\": 1e-4},\n",
@@ -102,14 +102,126 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 4,
"id": "38e7840f-c3b7-4f47-b694-f00db7f25896",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using 16bit Automatic Mixed Precision (AMP)\n",
+ "GPU available: True (cuda), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "IPU available: False, using: 0 IPUs\n",
+ "HPU available: False, using: 0 HPUs\n",
+ "/datasets/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n",
+ "val_file not set. Generating dataset instead\n",
+ "test_file not set. Generating dataset instead\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "--------------------------------------------------\n",
+ "0 | env | TSPEnv | 0 \n",
+ "1 | policy | AttentionModelPolicy | 710 K \n",
+ "2 | baseline | WarmupBaseline | 710 K \n",
+ "--------------------------------------------------\n",
+ "1.4 M Trainable params\n",
+ "0 Non-trainable params\n",
+ "1.4 M Total params\n",
+ "5.681 Total estimated model params size (MB)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "16097f1a18e046ee993da874c54263e6",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Sanity Checking: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/datasets/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.\n",
+ "/datasets/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "db2c392d44f840d5a4c1b6987315a8b3",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Training: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "5996e8f68f7645a889561cdc197256c8",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b4f95d1630324b2d90ce1fad6b67f226",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2ebe27fbc2c2445098425b46e77a3516",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "`Trainer.fit` stopped: `max_epochs=3` reached.\n"
+ ]
+ }
+ ],
"source": [
"trainer = RL4COTrainer(\n",
- " max_epochs=2,\n",
- " logger=None,\n",
+ " max_epochs=3,\n",
+ " devices=1,\n",
")\n",
"\n",
"trainer.fit(model)"
@@ -125,12 +237,68 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 5,
"id": "ea794e13-5af6-41cc-b3cc-1a6b42520086",
"metadata": {
"scrolled": true
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "val_file not set. Generating dataset instead\n",
+ "test_file not set. Generating dataset instead\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
+ "/datasets/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "6e3647dbbbb840f7988c5813cd9221b9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Testing: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃ Test metric ┃ DataLoader 0 ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ test/reward │ -4.0203351974487305 │\n",
+ "└───────────────────────────┴───────────────────────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test/reward \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m -4.0203351974487305 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "└───────────────────────────┴───────────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[{'test/reward': -4.0203351974487305}]"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"# here we evaluate the model on the test set using the beam search decoding strategy as declared in the model constructor\n",
"trainer.test(model=model)"
@@ -138,10 +306,65 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 6,
"id": "36ec98df-17d3-4250-9a9f-c1d510175934",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "val_file not set. Generating dataset instead\n",
+ "test_file not set. Generating dataset instead\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2bc29132ff0b4f838994391e4dcacdcf",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Testing: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃ Test metric ┃ DataLoader 0 ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ test/reward │ -4.104068756103516 │\n",
+ "└───────────────────────────┴───────────────────────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test/reward \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m -4.104068756103516 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "└───────────────────────────┴───────────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[{'test/reward': -4.104068756103516}]"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"# we can simply change the decoding type of the current model instance\n",
"model.policy.test_decode_type = \"greedy\"\n",
@@ -153,34 +376,66 @@
"id": "08f2744e-83ea-402a-83ec-94cbf12a0870",
"metadata": {},
"source": [
- "### Manual Test Loop\n",
+ "## Test Loop\n",
"\n",
"Let's compare beam search with a greedy decoding strategy by manually looping over our test dataset:"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "338b12d4",
+ "metadata": {},
+ "source": [
+ "### Beam search decoding"
+ ]
+ },
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 13,
"id": "f2b89503-f416-4b73-a5dc-1f1dfd7369e3",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Average reward is tensor(-4.1084)\n"
+ ]
+ }
+ ],
"source": [
"bs_rewards = []\n",
"for batch in model.test_dataloader():\n",
" td = env.reset(batch)\n",
" with torch.no_grad():\n",
" # in a manual loop we can dynamically specify the decode type\n",
- " out = model(td, decode_type=\"beam_search\", beam_width=10)\n",
+ " out = model(td, decode_type=\"beam_search\", beam_width=20)\n",
" bs_rewards.append(out[\"reward\"])\n",
"print(\"Average reward is %s\" % torch.cat(bs_rewards).mean())"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "a5bab7f4",
+ "metadata": {},
+ "source": [
+ "### Greedy decoding\n"
+ ]
+ },
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 14,
"id": "2103e579-d35f-4496-b401-47e9d3be7caa",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Average reward is tensor(-4.2222)\n"
+ ]
+ }
+ ],
"source": [
"bs_rewards = []\n",
"for batch in model.test_dataloader():\n",
@@ -191,19 +446,37 @@
"print(\"Average reward is %s\" % torch.cat(bs_rewards).mean())"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "e1f1361c",
+ "metadata": {},
+ "source": [
+ "### Greedy multistart decoding\n",
+ "\n",
+ "Start decoding from several different initial nodes, as done in POMO, and keep the best reward obtained for each instance."
+ ]
+ },
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 15,
"id": "41b902d8-446a-4d3e-b14e-673560ca7af1",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Average reward is tensor(-4.0608)\n"
+ ]
+ }
+ ],
"source": [
"bs_rewards = []\n",
"for batch in model.test_dataloader():\n",
" td = env.reset(batch)\n",
" bs = batch.batch_size[0]\n",
" with torch.no_grad():\n",
- " out = model(td, decode_type=\"multistart_greedy\", num_starts=10, return_actions=True)\n",
+ " out = model(td, decode_type=\"multistart_greedy\", num_starts=20, return_actions=True)\n",
" rewards = torch.stack(out[\"reward\"].split(bs), 1).max(1).values\n",
" bs_rewards.append(rewards)\n",
"print(\"Average reward is %s\" % torch.cat(bs_rewards).mean())"
@@ -242,37 +515,20 @@
},
{
"cell_type": "code",
- "execution_count": null,
- "id": "9d476c7a-aa23-45bb-b3e9-441992dcdf81",
- "metadata": {},
- "outputs": [],
- "source": [
- "td = env.reset(batch)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b9e59a46-8aea-4739-a16d-913e3d3b8d0f",
- "metadata": {},
- "outputs": [],
- "source": [
- "bs = batch.batch_size[0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
+ "execution_count": 10,
"id": "01727816-25ff-4a21-b092-b3a72be19343",
"metadata": {},
"outputs": [],
"source": [
+ "td = env.reset(batch)\n",
+ "bs = batch.batch_size[0]\n",
+ "\n",
"out = model(td, decode_type=\"beam_search\", beam_width=5, select_best=False, return_actions=True)"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
"id": "5366d787-2ac7-4d0c-bffc-926d64c82b63",
"metadata": {},
"outputs": [],
@@ -284,10 +540,61 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 12,
"id": "7ee02330-965b-48f3-8cf8-ed20ed7e1af6",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "