Skip to content

Commit

Permalink
Sync dag tutorial markdown with notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
merajhashemi committed Feb 10, 2025
1 parent d9c8139 commit a72fe77
Show file tree
Hide file tree
Showing 2 changed files with 303 additions and 90 deletions.
106 changes: 16 additions & 90 deletions docs/source/notebooks/plot_dag_learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cooper-org/cooper/blob/master/docs/source/notebooks/plot_dag_learning.ipynb)\n",
"\n",
"\n",
"Consider the problem of learning a Directed Acyclic Graph (DAG) on data. This is a common problem in causal inference, where we are interested in learning the causal relationships between variables. In this notebook, we will demonstrate how to learn a DAG on data using a {py:class}`~cooper.formulations.QuadraticPenalty` formulation in **Cooper**.\n"
"Consider the problem of learning a Directed Acyclic Graph (DAG) on data. This is a common problem in causal inference, where we are interested in learning the causal relationships between variables. In this notebook, we will demonstrate how to learn a DAG on data using a {py:class}`~cooper.formulations.QuadraticPenalty` formulation in **Cooper**."
]
},
{
"cell_type": "code",
"execution_count": 92,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -25,7 +25,7 @@
},
{
"cell_type": "code",
"execution_count": 93,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -52,7 +52,6 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"Consider a $d$-dimensional random vector ${X_1, X_2, ..., X_d}$. Given $n$ observations of the random vector $X \\in \\mathbb{R}^{n \\times d}$, we are interested in learning a DAG $G = (V, E)$ whose edges represent the dependencies between the variables. We model the DAG via an adjacency matrix $A \\in \\{0, 1\\}^{d \\times d}$, where $A_{ij} = 1$ if there is an edge from $X_i$ to $X_j$ and $A_{ij} = 0$ otherwise.\n",
"\n",
"This problem can be formulated as the following optimization problem:\n",
Expand All @@ -75,7 +74,7 @@
},
{
"cell_type": "code",
"execution_count": 94,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -121,7 +120,7 @@
},
{
"cell_type": "code",
"execution_count": 103,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -200,26 +199,9 @@
},
{
"cell_type": "code",
"execution_count": 102,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],\n",
" [1., 0., 1., 1., 1., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 1., 1., 0., 0., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 0., 1., 0.]])\n"
]
}
],
"outputs": [],
"source": [
"D = 10\n",
"N = 1000\n",
Expand All @@ -240,7 +222,7 @@
},
{
"cell_type": "code",
"execution_count": 97,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -280,27 +262,9 @@
},
{
"cell_type": "code",
"execution_count": 99,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 0, loss=2201.7742, violation=0.5068, penalty_coefficient=1.0100\n",
"Step 100, loss=1128.9160, violation=0.0110, penalty_coefficient=2.7319\n",
"Step 200, loss=693.1588, violation=0.5280, penalty_coefficient=7.3892\n",
"Step 300, loss=505.0096, violation=0.2386, penalty_coefficient=19.9863\n",
"Step 400, loss=400.4205, violation=0.0640, penalty_coefficient=54.0591\n",
"Step 500, loss=333.4678, violation=0.0208, penalty_coefficient=146.2198\n",
"Step 600, loss=286.8203, violation=0.0070, penalty_coefficient=395.4971\n",
"Step 700, loss=252.3343, violation=0.0024, penalty_coefficient=1069.7448\n",
"Step 800, loss=225.7109, violation=0.0008, penalty_coefficient=2893.4583\n",
"Step 900, loss=204.4717, violation=0.0003, penalty_coefficient=7826.2573\n",
"Final loss: 787.2999, violation: 0.5077, penalty_coefficient: 20342.5723\n"
]
}
],
"outputs": [],
"source": [
"A = torch.nn.Parameter(torch.randn(D, D, device=DEVICE) / math.sqrt(D))\n",
"\n",
Expand Down Expand Up @@ -356,7 +320,7 @@
},
{
"cell_type": "code",
"execution_count": 109,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -389,45 +353,9 @@
},
{
"cell_type": "code",
"execution_count": 110,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 640x480 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAC8CAYAAADl2K3eAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAFVFJREFUeJzt3X1w1NW9x/HPsgmbbMgDDwmCYoDwGGrBQsHh8iTaYSSCIBZBmZtAEFG8CIMo2lbAwSqGjE8VMUVIqWOdWlArA2q1WHSGVi1SZRRqMKEgXhLIA4EAIdlz/+Bm63YTE04SN3jer5nMuL/8vr/v2eX8lo+/3R/HY4wxAgAAzmoX6QEAAIDIIgwAAOA4wgAAAI4jDAAA4DjCAAAAjiMMAADgOMIAAACOIwwAAOA4wgAAAI4jDETIuHHjNG7cuEgPA4gozgO0NI/HoxUrVkR6GA3q2bOnrr/++kgPI0xEw0B+fr48Ho8++uijSA7je+f111/XpEmT1LVrV7Vv316dOnXSmDFjlJubqxMnTqioqEgej6dJP0VFRZF+Ok7gXMB3qbCwUHfddZf69esnv98vv9+v9PR0LViwQJ988kmkh9fqjhw5ohUrVmjPnj2tcvzPPvtMK1asuKjeP6MiPQC0nEAgoOzsbOXn5+uKK67QnXfeqR49eqiyslK7du3Sz3/+c23btk1//OMf9dvf/jakNjc3V4cPH9bjjz8esj05Ofm7fAoAWtnWrVt18803KyoqSrfeeqsGDx6sdu3aad++fdqyZYueffZZFRYWKjU1NdJDbTVHjhzRypUr1bNnTw0ZMqTFj//ZZ59p5cqVGjdunHr27Nnix28NhAFJNTU1CgQCat++faSH0iyPPfaY8vPztXjxYuXm5srj8QR/d/fdd+vrr7/Wpk2bFBcXp1mzZoXUvvTSSyorKwvbDnd8X84DNOzAgQOaMWOGUlNT9c4776hbt24hv1+9erXWrl2rdu2+/aLxqVOnFBcX15pDbVOqqqrk9/sjPYxWdVF8Z+Crr77SnDlz1LVrV/l8Pg0aNEgbNmwI2ae6uloPPvighg4dqsTERMXFxWn06NHasWNHyH51l8jXrFmjJ554QmlpafL5fMHLOh6PRwUFBcrKylJSUpISExM1e/ZsVVVVhY3rhRde0NChQxUbG6tOnTppxowZOnToUNh+eXl5SktLU2xsrIYPH6733nuvZV8gnZ+sq1ev1qBBg5STkxMSBOp069ZN9913X4v3xneD8wDN9dhjj+nUqVPauHFjWBCQpKioKC1cuFA9evQIbsvKylKHDh104MABTZw4UfHx8br11lslnQ8FS5YsUY8ePeTz+dS/f3+tWbNG31wMt26u5efnh/X7z8/3L2TunT17VosXL1ZycrLi4+M1efJkHT58uNHX4N1339WPf/xjSdLs2bODH4nWjW/cuHH6wQ9+oL///e8aM2aM/H6/HnjggXrHW6dnz57KysqSdP4jv5/+9KeSpKuvvjp4/HfffTek5v3339fw4cMVExOj3r17a9OmTY2OvTW1+SsDR48e1VVXXSWPx6O77rpLycnJ2r59u7Kzs3XixAktWrRIknTixAmtX79eM2fO1G233abKyko9//zzmjBhgj744IOwS0EbN27UmTNnNG/ePPl8PnXq1Cn4u+nTp6tXr1565JFHtHv3bq1fv14pKSlavXp1cJ+HH35Yv/jFLzR9+nTNnTtXJSUlevrppzVmzBh9/PHHSkpKkiQ9//zzuv322zVy5EgtWrRIX375pSZPnqxOnTqFnHDN9f7776u8vFz33HOPvF5vix0XbQPnAVrC1q1b1adPH40YMeKC6mpqajRhwgSNGjVKa9askd/vlzFGkydP1o4dO5Sdna0hQ4bozTff1NKlS/XVV1+FfeR4IZoy9+bOnasXXnhBt9xyi0aOHKk///nPysjIaPTYAwcO1EMPPaQHH3xQ8+bN0+jRoyVJI0eODO5z/PhxXXfddZoxY4ZmzZqlrl27NnnsY8aM0cKFC/XUU0/pgQce0MCBA4N96xQUFOimm25Sdna2MjMztWHDBmVlZWno0KEaNGhQk3u1KBNBGzduNJLMhx9+2OA+2dnZplu3bubYsWMh22fMmGESExNNVVWVMcaYmpoac/bs2ZB9ysrKTNeuXc2cOXOC2woLC40kk5CQYIqLi0P2X758uZEUsr8xxkydOtV07tw5+LioqMh4vV7z8MMPh+z36aefmqioqOD26upqk5KSYoYMGRIytry8PCPJjB07tsHnfaGefPJJI8m8+uqrIdtrampMSUlJyE8gEAirz8jIMKmpqS02HlyYxs4FzgM0V0VFhZFkpkyZEva7srKykPeIuvlkjDGZmZlGklm2bFlIzauvvmokmVWrVoVsv+mmm4zH4zEFBQXGmH/PtY0bN4b1lWSWL18efNzUubdnzx4jydx5550h+91yyy1hx6zPhx9+2OCYxo4daySZdevWNTreOqmpqSYzMzP4+OWXXzaSzI4dO+rdV5LZuXNncFtxcbHx+XxmyZIl3zru1tSmPyYwxmjz5s2aNGmSjDE6duxY8GfChAmqqKjQ7t27JUlerzf4WWcgEFBpaalqamo0bNiw4D7fNG3atAa/HDd//vyQx6NHj9bx48d14sQJSdKWLVsUCAQ0ffr0kDFdcskl6tu3b/CS7EcffaTi4mLNnz8/5HPYrKwsJSYmNv8F+oa6sXXo0CFk+6effqrk5OSQn+PHj7dob7QuzgO0hIbeI6Tzl8a/+R7xzDPPhO1zxx13hDzetm2bvF6vFi5cGLJ9yZIlMsZo+/bt1mNtbO5t27ZNksJ6110hay6fz6fZs2e3yLHqk56eHrwiIZ3/onb//v315ZdftlrPxrTpjwlKSkpUXl6uvLw85eXl1btPcXFx8L9/85vfKDc3V/v27dO5c+eC23v16hVWV9+2OpdffnnI444dO0qSysrKlJCQoC+++ELGGPXt27fe+ujoaEnSwYMHJSlsv+joaPXu3bvB/nVKS0tVXV0dfBwbG9vgm2d8fLwk6eTJkyHb+/Tpoz/96U+SpE2bNoXdRYC2z/XzAC2jofcISXruuedUWVmpo0eP1vsl4qioKF122WUh2w4ePKju3bsHj1un7nJ43Z+7jcbm3sGDB9WuXTulpaWF7Ne/f3/rnt906aWXtuoXaf/z+Unnn2NZWVmr9WxMmw4DgUBAkjRr1ixlZmbWu88Pf/hDSee/xJSVlaUpU6Zo6dKlSklJkdfr1SOPPKIDBw6E1cXGxjbYt6HP3M3/fykmEAjI4/Fo+/bt9e5bX/K2ceONN+ovf/lL8HFmZma9X8KRpAEDBkiS9u7dqxtuuCFkLNdee62k898rwMXH9fMALSMxMVHdunXT3r17w35X9x2Chu6L9/l8jd5h0JD6vswsSbW1tQ3WNDb3Wtu3nRf1+bbnUp9IP7/6tOkwUPct0dra2uBfaA35wx/+oN69e2vLli0hk2/58uUtPq60tDQZY9SrVy/169evwf3q7tP94osvNH78+OD2c+fOqbCwUIMHD/7WPrm5uSFJsXv37g3uO3r0aCUmJuqll17S/fffb33iou1x/TxAy8nIyND69ev1wQcfaPjw4c06Vmpqqt5++21VVlaGXB3Yt29f8PfSv/+vvry8PKS+OVcOUlNTFQgEdODAgZCrAfv3729SfUMBpTEdO3YMex7V1dX6+uuvW+T4kdSm/8bwer2aNm2aNm/eXG+aLSkpCdlXCk1Wf/vb37Rr164WH9eNN94or9erlStXhiU5Y0zwM/lhw4YpOTlZ69atC7ncn5+fHzah6jN06FBde+21wZ/09PQG9/X7/br33nu1d+9eLVu2rN6EGcnUCXuunwdoOffee6/8fr/mzJmjo0ePhv3+Qt4jJk6cqNraWv3qV78K2f7444/L4/HouuuukyQlJCSoS5cu2rlzZ8h+a9eutXgG59Ud+6mnngrZ/sQTTzSpvu7fSLjQ+ZeWlhb2PPLy8sKuDNgeP5LaxJWBDRs26I033gjbfvfdd+vRRx/Vjh07NGLECN12221KT09XaWmpdu/erbffflulpaWSpOuvv15btmzR1KlTlZGRocLCQq1bt07p6en1fkbWHGlpaVq1apXuv/9+FRUVacqUKYqPj1dhYaFeeeUVzZs3T/fcc4+io6O1atUq3X777Ro/frxuvvlmFRYWauPGja3yWemyZcv0+eefKycnR2+99ZamTZumyy67TGVlZdq9e7defvllpaSkKCYmpsV7o2U0dC6sWLGC8wDN1rdvX7344ouaOXOm+vfvH/wXCI0xKiws1Isvvqh27dqFfT+gPpMmTdLVV1+tn/3sZyoqKtLgwYP11ltv6bXXXtOiRYtCPs+fO3euHn30Uc2dO1fDhg3Tzp079c9//tP6eQwZMkQzZ87U2rVrVVFRoZEjR+qdd95RQUFBk+rT0tKUlJSkdevWKT4+XnFxcRoxYsS3foem7nnMnz9f06ZN009+8hP94x//0JtvvqkuXbqEjc/r9Wr16tWqqKiQz+fT+PHjlZKSYv2cW913eOdCmLrbqRr6OXTokDHGmKNHj5oFCxaYHj16mOjoaHPJJZeYa665xuTl5QWPFQgEzC9/+UuTmppqfD6fufLKK83WrVtNZmZmyC1zdbe55OTkhI2n7raWkpKSesdZWFgYsn3z5s1m1KhRJi4uzsTFxZkBAwaYBQsWmP3794fst3btWtOrVy/j8/nMsGHDzM6dO83YsWNb7ZaqV155xUycONEkJyebqKgok5SUZEaNGmVycnJMeXl5vTXcWhhZTTkXOA/QUgoKCswdd9xh+vTpY2JiYkxsbKwZMGCAmT9/vtmzZ0/IvpmZmSYuLq7e41RWVprFixeb7t27m+joaNO3b1+Tk5MTdvtyVVWVyc7ONomJiSY+Pt5Mnz7dFBcXN3hrYVPm3unTp83ChQtN586dTVxcnJk0aZI5dOhQk24tNMaY1157zaSnp5uoqKiQ2wzHjh1rBg0aVG9NbW2tue+++0yXLl2M3+83EyZMMAUFBWG3FhpjzK9//WvTu3dv4/V6Q24zTE1NNRkZGWHHjvS54DGGa8cAALisTX9nAAAAtD7CAAAAjiMMAADgOMIAAACOIwwAAOA4wgAAAI4jDAAA4Lgm/wuEfy0ot27ii65/UYbGHKk8bd0z2e+zqqs4e67xnRpw8lyNdW31BS50Uad9AwteNMXBCrvX98CxM9Y9n5460LrWFnO3cczdxkVi7v7PK59/5z3x/dOUucuVAQAAHEcYAADAcYQBAAAcRxgAAMBxhAEAABxHGAAAwHGEAQAAHEcYAADAcYQBAAAcRxgAAMBxhAEAABxHGAAAwHGEAQAAHEcYAADAcU1ewvhfJ6usm1zewW9V5/V4rHseO33Wqq6i2n4Z2PjoJr+cYYyxy2W1xlj37BRrN97qpPbWPSOBuds45i7gNq4MAADgOMIAAACOIwwAAOA4wgAAAI4jDAAA4DjCAAAAjiMMAADgOMIAAACOIwwAAOA4wgAAAI4jDAAA4DjCAAAAjiMMAADguCYv/dWvU7x1kyR/tF1huXVLVdXUWtV1jLdf1exEM1aNS2hv9xq1b2ef5yrP1VjVfXjopHXPSGDuNo65C7iNKwMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOOavIRxYfkp6yapxm9Vt+B3H1v3fGbmlVZ1KYkx1j3jzjT55QzzvydPW9VVVtst5SpJZWerrWsvJpGYux1i7OeCztiVMXfREnImDbSuXfr65y04EnyXuDIAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjmrxU2fyn3rNu4vF4rOpeXnqNdc+S02eta211iLVf+a0mYKzqknzR1j1d4fPaZ97ik3bzKC7afi6UR2BFPuZu29ScFQRtsfKgm7gyAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4rsnrlv7unvHWTWyXcz1Yecq6Z+eY9lZ105/cad3z93ePsa4tPWO3bO2ZmlrrnpcnxFnXXkxivF7r2otp7lYHAtY9T56usa5l7rYeV5YTbs5Sza68Rq2NKwMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOOavD7r1Ie2WTd546FJVnWPbd1v3fPMGbslWb3NWO62Q4zdcreS1EcdrOpO1dgvPXuy2q72koRo656RcOzMWeva2Ci7+ZCenGDds7rGbini09X2SwIzd9umSCzt25yeuHhxZQAAAMcRBgAAcBxhAAAAxxEGAABwHGEAAADHEQYAAHAcYQAAAMcRBgAAcBxhAAAAxxEGAABwHGEAAADHEQYAAHAcYQAAAMc1eamy6Pb2q33993N/taqrKD1h3TOhY7xV3etLx1r3LCw9ZV3b3muXyxKb8edy+NRpq7rOfvsV7iLBZ/naSpK3nceq7mBZlXXPLrE+q7pEv/1cYO5+/1xMqw/arrCIlsOVAQAAHEcYAADAcYQBAAAcRxgAAMBxhAEAABxHGAAAwHGEAQAAHEcYAADAcYQBAAAcRxgAAMBxhAEAABxHGAAAwHGEAQAAHEcYAADAcU1ez3PX6husm3Tu0N6q7kcPbLfueemlCVZ1Xo/dkrWS5I/2Wtc+u+tfVnWBgLHuee/YNKu6Mb27WPeMhPSUROta27lbfSRg3bNLvGXPGvuezZm7V1xq9/o2Y+pKxXZlF9vcjcTSvhfT0sfNFYnn2laXa+bKAAAAjiMMAADgOMIAAACOIwwAAOA4wgAAAI4jDAAA4DjCAAAAjiMMAADgOMIAAACOIwwAAOA4wgAAAI4jDAAA4DjCAAAAjvMYY5q0dljGcx9YN+kQE21VZ79+oPTJ/hKrun69O1n3TE6Isa49UlZlVZd/64+se/7XQ29b1W2ad5V1z6v6JFnX2jpcVm1d64u2y8vNmbsHj9nNhUs7xlr37BDT5AVMw5RXnbOqs31tJamo+JRVXW0zlkqMxNw9U/Odt0QTNGflQduVEpvT8+mpjffkygAAAI4jDAAA4DjCAAAAjiMMAADgOMIAAACOIwwAAOA4wgAAAI4jDAAA4DjCAAAAjiMMAADgOMIAAACOIwwAAOA4wgAAAI4jDAAA4LgmL2EMAAC+n7gyAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDj/g8IPF4S4bcT/AAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 640x480 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],\n",
" [1., 0., 1., 1., 1., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 1., 1., 0., 0., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 0., 1., 0.]])\n"
]
}
],
"outputs": [],
"source": [
"# Plot\n",
"\n",
Expand All @@ -441,15 +369,13 @@
"source": [
"## References\n",
"- NOTEARS\n",
"- Seb's paper?\n"
"- Seb's paper?"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "cooper",
"language": "python",
"name": "python3"
"jupytext": {
"formats": "ipynb,md:myst"
},
"language_info": {
"codemirror_mode": {
Expand Down
Loading

0 comments on commit a72fe77

Please sign in to comment.