Skip to content

Commit

Permalink
bias/var plots
Browse files Browse the repository at this point in the history
  • Loading branch information
fs446 committed Dec 17, 2024
1 parent b952e9f commit dc0c035
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 22 deletions.
32 changes: 21 additions & 11 deletions bias_variance_linear_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 64,
"id": "7d969206",
"metadata": {},
"outputs": [],
Expand All @@ -343,7 +343,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 65,
"id": "f18dc6d5",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -391,7 +391,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 66,
"id": "5e9a45d9",
"metadata": {},
"outputs": [],
Expand All @@ -418,15 +418,17 @@
"source": [
"# generate 'true' data with the design matrix of 'true' model\n",
"y = X @ beta\n",
"plt.figure(figsize=(5, 3))\n",
"plt.figure(figsize=(6, 3))\n",
"plt.plot(y, \"k-\")\n",
"plt.xlabel(\"independent features' input variable x\")\n",
"plt.ylabel((\"dependent variable y, true data\"))\n",
"plt.title(\"true model data as linear model (x -> 4 features + intercept)\")\n",
"plt.xlim(0, M)\n",
"plt.ylim(-2, 8)\n",
"plt.grid(True)\n",
"print(X.shape, y.shape)"
"print(X.shape, y.shape)\n",
"plt.tight_layout()\n",
"plt.savefig('bias_variance_plots/true_data.png', dpi=300)"
]
},
{
Expand Down Expand Up @@ -583,7 +585,9 @@
"X = np.copy(x)[:, None]\n",
"fig, axs = plt.subplots(2, 2, figsize=(10, 5))\n",
"bias_variance_of_model(X)\n",
"axs[0, 0].set_title(\"underfit, too low model complexity, high bias, low var\");"
"axs[0, 0].set_title(\"underfit, too low model complexity, high bias, low var\");\n",
"plt.tight_layout()\n",
"plt.savefig('bias_variance_plots/too_simple_model.png', dpi=300)"
]
},
{
Expand Down Expand Up @@ -618,7 +622,9 @@
"# note that intercept is only added in function bias_variance_of_model(X)\n",
"fig, axs = plt.subplots(2, 2, figsize=(10, 5))\n",
"bias_variance_of_model(X)\n",
"axs[0, 0].set_title(\"overfit, too high model complexity, low bias, high var\");"
"axs[0, 0].set_title(\"overfit, too high model complexity, low bias, high var\");\n",
"plt.tight_layout()\n",
"plt.savefig('bias_variance_plots/too_complex_model.png', dpi=300)"
]
},
{
Expand Down Expand Up @@ -652,7 +658,9 @@
"bias_variance_of_model(X) # lowest possible bias^2+variance, because we\n",
"# know the true model (again: which in practice likely never will occur)\n",
"# the remaining variance is from the added noise\n",
"axs[0, 0].set_title(\"true model features, lowest bias, lowest var\");"
"axs[0, 0].set_title(\"true model features, lowest bias, lowest var\");\n",
"plt.tight_layout()\n",
"plt.savefig('bias_variance_plots/true_model.png', dpi=300)"
]
},
{
Expand Down Expand Up @@ -684,7 +692,9 @@
"# note that intercept is only added in function bias_variance_of_model(X)\n",
"fig, axs = plt.subplots(2, 2, figsize=(10, 5))\n",
"bias_variance_of_model(X)\n",
"axs[0, 0].set_title(\"reasonable bias/var trade-off if true model is unknown\");"
"axs[0, 0].set_title(\"reasonable bias/var trade-off if true model is unknown\");\n",
"plt.tight_layout()\n",
"plt.savefig('bias_variance_plots/robust_model.png', dpi=300)"
]
},
{
Expand Down Expand Up @@ -721,7 +731,7 @@
"kernelspec": {
"display_name": "myddasp",
"language": "python",
"name": "myddasp"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -733,7 +743,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added bias_variance_plots/robust_model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added bias_variance_plots/too_complex_model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added bias_variance_plots/too_simple_model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added bias_variance_plots/true_data.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added bias_variance_plots/true_model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
25 changes: 14 additions & 11 deletions bias_variance_ridge_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"id": "c86e15f6",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -174,7 +174,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"id": "6764d6da",
"metadata": {},
"outputs": [],
Expand All @@ -185,7 +185,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"id": "d8ef032a",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -263,7 +263,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"id": "e21f304f",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -295,7 +295,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 18,
"id": "9afb4067",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -337,7 +337,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 19,
"id": "9c3ccf25",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -371,7 +371,7 @@
"outputs": [],
"source": [
"fig, axs = plt.subplots(1, 1, figsize=(8, 4))\n",
"axs.plot(alpha2_vec, bias_squared, \"C0\", label=r\"bias$^2$\", lw=2)\n",
"axs.plot(alpha2_vec, bias_squared, \"C0\", label=r\"bias$^2$\", lw=3)\n",
"axs.plot(alpha2_vec, variance, \"C1\", label=r\"var\")\n",
"axs.plot(alpha2_vec, bias_squared + variance, \"C2\", label=r\"bias$^2$+var\")\n",
"\n",
Expand All @@ -381,12 +381,15 @@
"\n",
"axs.set_xscale(\"log\")\n",
"axs.set_yscale(\"log\")\n",
"axs.set_xlabel(r\"regularization value $\\alpha^2$\")\n",
"axs.set_xlabel(r\"underfit region regularization value $\\alpha^2$ overfit region\")\n",
"axs.set_title(r\"$\\alpha^2_\\mathrm{opt}$=\" + \"{:4.3f}\".format(alpha2_vec[idx]))\n",
"axs.legend()\n",
"axs.set_xlim(10**alpha2_min, 10**alpha2_max)\n",
"axs.set_ylim(1e-2, 1e1)\n",
"axs.grid(True)"
"axs.grid(True)\n",
"axs.xaxis.set_inverted(True)\n",
"plt.tight_layout()\n",
"plt.savefig('bias_variance_plots/bias_var_l2_regularisation.png', dpi=600)"
]
},
{
Expand Down Expand Up @@ -668,7 +671,7 @@
"kernelspec": {
"display_name": "myddasp",
"language": "python",
"name": "myddasp"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -680,7 +683,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down

0 comments on commit dc0c035

Please sign in to comment.