From a91a5f338714bfd0cd3a5fad4d836fe27b4f83e3 Mon Sep 17 00:00:00 2001 From: Samir Kamal <1954121+skamril@users.noreply.github.com> Date: Thu, 30 Nov 2023 17:24:31 +0100 Subject: [PATCH 01/43] feat(ui-storages): use percentage values instead of ratio values (#1846) For `efficiency` and `initialLevel` fields --- .../Modelization/Areas/Storages/Fields.tsx | 8 +++--- .../Modelization/Areas/Storages/Form.tsx | 26 ++++++++++++++++--- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Fields.tsx b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Fields.tsx index 0e04b378a0..8485fd29e6 100644 --- a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Fields.tsx +++ b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Fields.tsx @@ -101,8 +101,8 @@ function Fields() { message: t("form.field.minValue", { 0: 0 }), }, max: { - value: 1, - message: t("form.field.maxValue", { 0: 1 }), + value: 100, + message: t("form.field.maxValue", { 0: 100 }), }, }} /> @@ -116,8 +116,8 @@ function Fields() { message: t("form.field.minValue", { 0: 0 }), }, max: { - value: 1, - message: t("form.field.maxValue", { 0: 1 }), + value: 100, + message: t("form.field.maxValue", { 0: 100 }), }, }} /> diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Form.tsx b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Form.tsx index c3f2f717dd..0314fc0df1 100644 --- a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Form.tsx +++ b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Form.tsx @@ -3,6 +3,7 @@ import { Box, Button } from "@mui/material"; import { useParams, useOutletContext, useNavigate } from "react-router-dom"; import ArrowBackIcon from "@mui/icons-material/ArrowBack"; import { useTranslation } from "react-i18next"; +import * as RA from "ramda-adjunct"; import { StudyMetadata } from "../../../../../../../common/types"; import Form from "../../../../../../common/Form"; import Fields from "./Fields"; @@ -27,17 +28,34 @@ function StorageForm() { }); // prevent re-fetch while useNavigateOnCondition event occurs - const defaultValues = useCallback(() => { - return getStorage(study.id, areaId, storageId); + const defaultValues = useCallback( + async () => { + const storage = await getStorage(study.id, areaId, storageId); + return { + ...storage, + // Convert to percentage ([0-1] -> [0-100]) + efficiency: storage.efficiency * 100, + initialLevel: storage.initialLevel * 100, + }; + }, // eslint-disable-next-line react-hooks/exhaustive-deps - }, []); + [], + ); //////////////////////////////////////////////////////////////// // Event handlers //////////////////////////////////////////////////////////////// const handleSubmit = ({ dirtyValues }: SubmitHandlerPlus) => { - return updateStorage(study.id, areaId, storageId, dirtyValues); + const newValues = { ...dirtyValues }; + // Convert to ratio ([0-100] -> [0-1]) + if (RA.isNumber(newValues.efficiency)) { + newValues.efficiency /= 100; + } + if (RA.isNumber(newValues.initialLevel)) { + newValues.initialLevel /= 100; + } + return updateStorage(study.id, areaId, storageId, newValues); }; //////////////////////////////////////////////////////////////// From 637a77dca0151aaddca6bed3359189547493e8d3 Mon Sep 17 00:00:00 2001 From: Hatim Dinia Date: Tue, 5 Dec 2023 14:54:56 +0100 Subject: [PATCH 02/43] chore(deps): upgrade material-react-table (#1851) fixes #1822 --- webapp/package-lock.json | 199 +++++++++++++----- webapp/package.json | 3 +- .../common/GroupedDataTable/index.tsx | 3 +- 3 files changed, 153 insertions(+), 52 deletions(-) diff --git a/webapp/package-lock.json b/webapp/package-lock.json index 19ba7c4221..ed11beab16 100644 --- a/webapp/package-lock.json +++ b/webapp/package-lock.json @@ -14,6 +14,7 @@ "@mui/icons-material": "5.14.11", "@mui/lab": "5.0.0-alpha.146", "@mui/material": "5.14.11", + "@mui/x-date-pickers": "6.18.3", "@reduxjs/toolkit": "1.9.6", "@types/d3": "5.16.0", "@types/draft-convert": "2.1.5", @@ -44,7 +45,7 @@ "js-cookie": "3.0.5", "jwt-decode": "3.1.2", "lodash": "4.17.21", - "material-react-table": "1.15.0", + "material-react-table": "2.0.5", "moment": "2.29.4", "notistack": "3.0.1", "os": "0.1.2", @@ -2191,9 +2192,9 @@ "integrity": "sha512-x/rqGMdzj+fWZvCOYForTghzbtqPDZ5gPwaoNGHdgDfF2QA/XZbCBp4Moo5scrkAMPhB7z26XM/AaHuIJdgauA==" }, "node_modules/@babel/runtime": { - "version": "7.23.1", - "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.23.1.tgz", - "integrity": "sha512-hC2v6p8ZSI/W0HUzh3V8C5g+NwSKzKPtJwSpTjwl0o297GP9+ZLQSkdvHz46CM3LqyoXxq+5G9komY+eSqSO0g==", + "version": "7.23.5", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.23.5.tgz", + "integrity": "sha512-NdUTHcPe4C99WxPub+K9l9tK5/lV4UXIoaHSYgzco9BCyjKAAwzdBI+wWtYqHt7LJdbo74ZjRPJgzVweq1sz0w==", "dependencies": { "regenerator-runtime": "^0.14.0" }, @@ -2769,9 +2770,9 @@ } }, "node_modules/@floating-ui/react-dom": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/@floating-ui/react-dom/-/react-dom-2.0.2.tgz", - "integrity": "sha512-5qhlDvjaLmAst/rKb3VdlCinwTF4EYMiVxuuc/HVUjs46W0zgtbMmAZ1UTsDrRTxRmUEzl92mOtWbeeXL26lSQ==", + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/@floating-ui/react-dom/-/react-dom-2.0.4.tgz", + "integrity": "sha512-CF8k2rgKeh/49UrnIBs4BdxPUV6vize/Db1d/YbCLyp9GiVZ0BEwf5AiDSxJRCr6yOkGqTFHtmrULxkEfYZ7dQ==", "dependencies": { "@floating-ui/dom": "^1.5.1" }, @@ -3570,11 +3571,11 @@ } }, "node_modules/@mui/types": { - "version": "7.2.4", - "resolved": "https://registry.npmjs.org/@mui/types/-/types-7.2.4.tgz", - "integrity": "sha512-LBcwa8rN84bKF+f5sDyku42w1NTxaPgPyYKODsh01U1fVstTClbUoSA96oyRBnSNyEiAVjKm6Gwx9vjR+xyqHA==", + "version": "7.2.10", + "resolved": "https://registry.npmjs.org/@mui/types/-/types-7.2.10.tgz", + "integrity": "sha512-wX1vbDC+lzF7FlhT6A3ffRZgEoKWPF8VqRoTu4lZwouFX2t90KyCMsgepMw5DxLak1BSp/KP86CmtZttikb/gQ==", "peerDependencies": { - "@types/react": "*" + "@types/react": "^17.0.0 || ^18.0.0" }, "peerDependenciesMeta": { "@types/react": { @@ -3583,12 +3584,12 @@ } }, "node_modules/@mui/utils": { - "version": "5.14.11", - "resolved": "https://registry.npmjs.org/@mui/utils/-/utils-5.14.11.tgz", - "integrity": "sha512-fmkIiCPKyDssYrJ5qk+dime1nlO3dmWfCtaPY/uVBqCRMBZ11JhddB9m8sjI2mgqQQwRJG5bq3biaosNdU/s4Q==", + "version": "5.14.20", + "resolved": "https://registry.npmjs.org/@mui/utils/-/utils-5.14.20.tgz", + "integrity": "sha512-Y6yL5MoFmtQml20DZnaaK1znrCEwG6/vRSzW8PKOTrzhyqKIql0FazZRUR7sA5EPASgiyKZfq0FPwISRXm5NdA==", "dependencies": { - "@babel/runtime": "^7.22.15", - "@types/prop-types": "^15.7.5", + "@babel/runtime": "^7.23.4", + "@types/prop-types": "^15.7.11", "prop-types": "^15.8.1", "react-is": "^18.2.0" }, @@ -3597,7 +3598,7 @@ }, "funding": { "type": "opencollective", - "url": "https://opencollective.com/mui" + "url": "https://opencollective.com/mui-org" }, "peerDependencies": { "@types/react": "^17.0.0 || ^18.0.0", @@ -3609,6 +3610,102 @@ } } }, + "node_modules/@mui/x-date-pickers": { + "version": "6.18.3", + "resolved": "https://registry.npmjs.org/@mui/x-date-pickers/-/x-date-pickers-6.18.3.tgz", + "integrity": "sha512-DmJrAAr6EfhuWA9yubANAdeQayAbUppCezdhxkYKwn38G8+HJPZBol0V5fKji+B4jMxruO78lkQYsGUxVxaR7A==", + "dependencies": { + "@babel/runtime": "^7.23.2", + "@mui/base": "^5.0.0-beta.22", + "@mui/utils": "^5.14.16", + "@types/react-transition-group": "^4.4.8", + "clsx": "^2.0.0", + "prop-types": "^15.8.1", + "react-transition-group": "^4.4.5" + }, + "engines": { + "node": ">=14.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/mui" + }, + "peerDependencies": { + "@emotion/react": "^11.9.0", + "@emotion/styled": "^11.8.1", + "@mui/material": "^5.8.6", + "@mui/system": "^5.8.0", + "date-fns": "^2.25.0", + "date-fns-jalali": "^2.13.0-0", + "dayjs": "^1.10.7", + "luxon": "^3.0.2", + "moment": "^2.29.4", + "moment-hijri": "^2.1.2", + "moment-jalaali": "^0.7.4 || ^0.8.0 || ^0.9.0 || ^0.10.0", + "react": "^17.0.0 || ^18.0.0", + "react-dom": "^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@emotion/react": { + "optional": true + }, + "@emotion/styled": { + "optional": true + }, + "date-fns": { + "optional": true + }, + "date-fns-jalali": { + "optional": true + }, + "dayjs": { + "optional": true + }, + "luxon": { + "optional": true + }, + "moment": { + "optional": true + }, + "moment-hijri": { + "optional": true + }, + "moment-jalaali": { + "optional": true + } + } + }, + "node_modules/@mui/x-date-pickers/node_modules/@mui/base": { + "version": "5.0.0-beta.26", + "resolved": "https://registry.npmjs.org/@mui/base/-/base-5.0.0-beta.26.tgz", + "integrity": "sha512-gPMRKC84VRw+tjqYoyBzyrBUqHQucMXdlBpYazHa5rCXrb91fYEQk5SqQ2U5kjxx9QxZxTBvWAmZ6DblIgaGhQ==", + "dependencies": { + "@babel/runtime": "^7.23.4", + "@floating-ui/react-dom": "^2.0.4", + "@mui/types": "^7.2.10", + "@mui/utils": "^5.14.20", + "@popperjs/core": "^2.11.8", + "clsx": "^2.0.0", + "prop-types": "^15.8.1" + }, + "engines": { + "node": ">=12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/mui-org" + }, + "peerDependencies": { + "@types/react": "^17.0.0 || ^18.0.0", + "react": "^17.0.0 || ^18.0.0", + "react-dom": "^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/@mui/x-tree-view": { "version": "6.0.0-alpha.1", "resolved": "https://registry.npmjs.org/@mui/x-tree-view/-/x-tree-view-6.0.0-alpha.1.tgz", @@ -4596,11 +4693,11 @@ } }, "node_modules/@tanstack/react-table": { - "version": "8.10.3", - "resolved": "https://registry.npmjs.org/@tanstack/react-table/-/react-table-8.10.3.tgz", - "integrity": "sha512-Qya1cJ+91arAlW7IRDWksRDnYw28O446jJ/ljkRSc663EaftJoBCAU10M+VV1K6MpCBLrXq1BD5IQc1zj/ZEjA==", + "version": "8.10.7", + "resolved": "https://registry.npmjs.org/@tanstack/react-table/-/react-table-8.10.7.tgz", + "integrity": "sha512-bXhjA7xsTcsW8JPTTYlUg/FuBpn8MNjiEPhkNhIGCUR6iRQM2+WEco4OBpvDeVcR9SE+bmWLzdfiY7bCbCSVuA==", "dependencies": { - "@tanstack/table-core": "8.10.3" + "@tanstack/table-core": "8.10.7" }, "engines": { "node": ">=12" @@ -4615,24 +4712,25 @@ } }, "node_modules/@tanstack/react-virtual": { - "version": "3.0.0-beta.60", - "resolved": "https://registry.npmjs.org/@tanstack/react-virtual/-/react-virtual-3.0.0-beta.60.tgz", - "integrity": "sha512-F0wL9+byp7lf/tH6U5LW0ZjBqs+hrMXJrj5xcIGcklI0pggvjzMNW9DdIBcyltPNr6hmHQ0wt8FDGe1n1ZAThA==", + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/@tanstack/react-virtual/-/react-virtual-3.0.1.tgz", + "integrity": "sha512-IFOFuRUTaiM/yibty9qQ9BfycQnYXIDHGP2+cU+0LrFFGNhVxCXSQnaY6wkX8uJVteFEBjUondX0Hmpp7TNcag==", "dependencies": { - "@tanstack/virtual-core": "3.0.0-beta.60" + "@tanstack/virtual-core": "3.0.0" }, "funding": { "type": "github", "url": "https://github.com/sponsors/tannerlinsley" }, "peerDependencies": { - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + "react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0" } }, "node_modules/@tanstack/table-core": { - "version": "8.10.3", - "resolved": "https://registry.npmjs.org/@tanstack/table-core/-/table-core-8.10.3.tgz", - "integrity": "sha512-hJ55YfJlWbfzRROfcyA/kC1aZr/shsLA8XNAwN8jXylhYWGLnPmiJJISrUfj4dMMWRiFi0xBlnlC7MLH+zSrcw==", + "version": "8.10.7", + "resolved": "https://registry.npmjs.org/@tanstack/table-core/-/table-core-8.10.7.tgz", + "integrity": "sha512-KQk5OMg5OH6rmbHZxuNROvdI+hKDIUxANaHlV+dPlNN7ED3qYQ/WkpY2qlXww1SIdeMlkIhpN/2L00rof0fXFw==", "engines": { "node": ">=12" }, @@ -4642,9 +4740,9 @@ } }, "node_modules/@tanstack/virtual-core": { - "version": "3.0.0-beta.60", - "resolved": "https://registry.npmjs.org/@tanstack/virtual-core/-/virtual-core-3.0.0-beta.60.tgz", - "integrity": "sha512-QlCdhsV1+JIf0c0U6ge6SQmpwsyAT0oQaOSZk50AtEeAyQl9tQrd6qCHAslxQpgphrfe945abvKG8uYvw3hIGA==", + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/@tanstack/virtual-core/-/virtual-core-3.0.0.tgz", + "integrity": "sha512-SYXOBTjJb05rXa2vl55TTwO40A6wKu0R5i1qQwhJYNDIqaIGF7D0HsLw+pJAyi2OvntlEIVusx3xtbbgSUi6zg==", "funding": { "type": "github", "url": "https://github.com/sponsors/tannerlinsley" @@ -5238,9 +5336,9 @@ "integrity": "sha512-+68kP9yzs4LMp7VNh8gdzMSPZFL44MLGqiHWvttYJe+6qnuVr4Ek9wSBQoveqY/r+LwjCcU29kNVkidwim+kYA==" }, "node_modules/@types/prop-types": { - "version": "15.7.8", - "resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.8.tgz", - "integrity": "sha512-kMpQpfZKSCBqltAJwskgePRaYRFukDkm1oItcAbC3gNELR20XIBcN9VRgg4+m8DKsTfkWeA4m4Imp4DDuWy7FQ==" + "version": "15.7.11", + "resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.11.tgz", + "integrity": "sha512-ga8y9v9uyeiLdpKddhxYQkxNDrfvuPrlFb0N1qnZZByvcElJaXthF1UhvCh9TLWJBEHeNtdnbysW7Y6Uq8CVng==" }, "node_modules/@types/q": { "version": "1.5.6", @@ -5332,9 +5430,9 @@ } }, "node_modules/@types/react-transition-group": { - "version": "4.4.7", - "resolved": "https://registry.npmjs.org/@types/react-transition-group/-/react-transition-group-4.4.7.tgz", - "integrity": "sha512-ICCyBl5mvyqYp8Qeq9B5G/fyBSRC0zx3XM3sCC6KkcMsNeAHqXBKkmat4GqdJET5jtYUpZXrxI5flve5qhi2Eg==", + "version": "4.4.9", + "resolved": "https://registry.npmjs.org/@types/react-transition-group/-/react-transition-group-4.4.9.tgz", + "integrity": "sha512-ZVNmWumUIh5NhH8aMD9CR2hdW0fNuYInlocZHaZ+dgk/1K49j1w/HoAuK1ki+pgscQrOFRTlXeoURtuzEkV3dg==", "dependencies": { "@types/react": "*" } @@ -15102,29 +15200,30 @@ "integrity": "sha512-6qE4B9deFBIa9YSpOc9O0Sgc43zTeVYbgDT5veRKSlB2+ZuHNoVVxA1L/ckMUayV9Ay9y7Z/SZCLcGteW9i7bg==" }, "node_modules/material-react-table": { - "version": "1.15.0", - "resolved": "https://registry.npmjs.org/material-react-table/-/material-react-table-1.15.0.tgz", - "integrity": "sha512-f59XPZ+jFErRAs3ym3cHsK6kBLCrYJGX6GoF473V1/gCpsNbkWEEdmCVMpB8ycOUNDEXtnRDMZzk3LjTMd6wpg==", + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/material-react-table/-/material-react-table-2.0.5.tgz", + "integrity": "sha512-axRrqa/2QQ+AO3SiJbOtSyemlHX0S03X+IXW72z344d3LT+u/jsKiAmdWMLTN8ARScYMAN5NgrArujiLEmftSQ==", "dependencies": { "@tanstack/match-sorter-utils": "8.8.4", - "@tanstack/react-table": "8.10.3", - "@tanstack/react-virtual": "3.0.0-beta.60", + "@tanstack/react-table": "8.10.7", + "@tanstack/react-virtual": "3.0.1", "highlight-words": "1.2.2" }, "engines": { - "node": ">=14" + "node": ">=16" }, "funding": { "type": "github", "url": "https://github.com/sponsors/kevinvandy" }, "peerDependencies": { - "@emotion/react": ">=11", - "@emotion/styled": ">=11", - "@mui/icons-material": ">=5", - "@mui/material": ">=5", - "react": ">=17.0", - "react-dom": ">=17.0" + "@emotion/react": ">=11.11", + "@emotion/styled": ">=11.11", + "@mui/icons-material": ">=5.11", + "@mui/material": ">=5.13", + "@mui/x-date-pickers": ">=6.15.0", + "react": ">=18.0", + "react-dom": ">=18.0" } }, "node_modules/math-log2": { diff --git a/webapp/package.json b/webapp/package.json index 1bc82aae3e..e113491672 100644 --- a/webapp/package.json +++ b/webapp/package.json @@ -12,6 +12,7 @@ "@mui/icons-material": "5.14.11", "@mui/lab": "5.0.0-alpha.146", "@mui/material": "5.14.11", + "@mui/x-date-pickers": "6.18.3", "@reduxjs/toolkit": "1.9.6", "@types/d3": "5.16.0", "@types/draft-convert": "2.1.5", @@ -42,7 +43,7 @@ "js-cookie": "3.0.5", "jwt-decode": "3.1.2", "lodash": "4.17.21", - "material-react-table": "1.15.0", + "material-react-table": "2.0.5", "moment": "2.29.4", "notistack": "3.0.1", "os": "0.1.2", diff --git a/webapp/src/components/common/GroupedDataTable/index.tsx b/webapp/src/components/common/GroupedDataTable/index.tsx index 45c23ab55d..17f1df92a1 100644 --- a/webapp/src/components/common/GroupedDataTable/index.tsx +++ b/webapp/src/components/common/GroupedDataTable/index.tsx @@ -5,7 +5,8 @@ import AddIcon from "@mui/icons-material/Add"; import { Button } from "@mui/material"; import DeleteIcon from "@mui/icons-material/Delete"; import ContentCopyIcon from "@mui/icons-material/ContentCopy"; -import MaterialReactTable, { +import { + MaterialReactTable, MRT_RowSelectionState, MRT_ToggleFiltersButton, MRT_ToggleGlobalFilterButton, From 41a29fc2c53563bfbe24d942ebe5cdc6af053f22 Mon Sep 17 00:00:00 2001 From: Hatim Dinia Date: Tue, 5 Dec 2023 16:53:39 +0100 Subject: [PATCH 03/43] fix(ui-common): prevent matrices float values to be converted (#1850) --- .../common/EditableMatrix/index.tsx | 31 +++++++++++-------- .../components/common/EditableMatrix/utils.ts | 15 +++++---- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/webapp/src/components/common/EditableMatrix/index.tsx b/webapp/src/components/common/EditableMatrix/index.tsx index 91bfeb7d81..8050c9bd39 100644 --- a/webapp/src/components/common/EditableMatrix/index.tsx +++ b/webapp/src/components/common/EditableMatrix/index.tsx @@ -13,7 +13,11 @@ import "handsontable/dist/handsontable.min.css"; import MatrixGraphView from "./MatrixGraphView"; import { Root } from "./style"; import "./style.css"; -import { computeStats, createDateFromIndex, slice } from "./utils"; +import { + computeStats, + createDateFromIndex, + cellChangesToMatrixEdits, +} from "./utils"; import Handsontable from "../Handsontable"; const logError = debug("antares:editablematrix:error"); @@ -68,18 +72,19 @@ function EditableMatrix(props: PropTypes) { // Event Handlers //////////////////////////////////////////////////////////////// - const handleSlice = (change: CellChange[], source: string) => { - const isChanged = change.map((item) => { - if (parseFloat(item[2]) === parseFloat(item[3])) { - return; - } - return item; - }); - if (onUpdate) { - const edit = slice( - isChanged.filter((e) => e !== undefined) as CellChange[], - ); - onUpdate(edit, source); + const handleSlice = (changes: CellChange[], source: string) => { + if (!onUpdate) { + return; + } + + const filteredChanges = changes.filter( + ([, , oldValue, newValue]) => + parseFloat(oldValue) !== parseFloat(newValue), + ); + + if (filteredChanges.length > 0) { + const edits = cellChangesToMatrixEdits(filteredChanges); + onUpdate(edits, source); } }; diff --git a/webapp/src/components/common/EditableMatrix/utils.ts b/webapp/src/components/common/EditableMatrix/utils.ts index 020455cd5e..13fa17e6db 100644 --- a/webapp/src/components/common/EditableMatrix/utils.ts +++ b/webapp/src/components/common/EditableMatrix/utils.ts @@ -82,14 +82,13 @@ export const createDateFromIndex = ( return date; }; -export const slice = (tab: CellChange[]): MatrixEditDTO[] => { - return tab.map((cell) => { - return { - coordinates: [[cell[0] as number, (cell[1] as number) - 1]], - operation: { operation: Operator.EQ, value: parseInt(cell[3], 10) }, - }; - }); -}; +export const cellChangesToMatrixEdits = ( + cellChanges: CellChange[], +): MatrixEditDTO[] => + cellChanges.map(([row, column, , value]) => ({ + coordinates: [[row, (column as number) - 1]], + operation: { operation: Operator.EQ, value: parseFloat(value) }, + })); export const computeStats = ( statsType: string, From 243e0164012556dfc3bc16fe2a5ec29588819df2 Mon Sep 17 00:00:00 2001 From: Hatim Dinia Date: Wed, 6 Dec 2023 09:35:39 +0100 Subject: [PATCH 04/43] feat(ui-modelization): add dynamic area selection on Areas tab click (#1835) ## Dynamic Handling of Area Selection in `Modelization` Component ### Description: This PR introduces a significant change in the way areas are handled within the `Modelization`. The current system has a default area selection mechanism when navigating to the map view. However, this approach has led to a couple of issues: - Inability to Access the Areas View without an Active Selection: Previously, if a user navigated to the map view and deselected the currently active area, they found themselves unable to navigate back to the areas view unless they selected an area again. This behavior was unintuitive and potentially frustrating for users who might want to browse different areas or had no specific area to focus on initially. - Errors When Creating New Layers with a Preselected Area: The existing system also posed a problem when creating new layers while an area was selected. If the selected area was not present in the newly created layer, the application attempted to retrieve UI information for the non-existent area in the context of that layer, leading to errors. ### Solution: To address these issues, we have implemented a dynamic area handling mechanism. Now, the area is determined dynamically when the "Areas" tab is clicked. The key changes include: - No Default Area on Initial Render: No area is preselected. - Dynamic Area Selection on Tab Click: The `"Areas"` tab click now triggers a dynamic process to determine and set the current area. If no area is currently selected (i.e., `areaId` is empty), the first area from the list of available areas is automatically selected and dispatched to the Redux store. This selection is then reflected in the navigation, taking the user to the view for the selected area. --- .../Modelization/Areas/Hydro/index.tsx | 2 +- .../explore/Modelization/Areas/Hydro/style.ts | 1 + .../explore/Modelization/index.tsx | 60 ++++++++++++------- .../App/Singlestudy/explore/TabWrapper.tsx | 42 ++++++++----- webapp/src/redux/ducks/studySyntheses.ts | 9 +-- 5 files changed, 68 insertions(+), 46 deletions(-) diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Hydro/index.tsx b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Hydro/index.tsx index 42a84ca56c..8b65d57f44 100644 --- a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Hydro/index.tsx +++ b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Hydro/index.tsx @@ -65,7 +65,7 @@ function Hydro() { return ( - + ); } diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Hydro/style.ts b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Hydro/style.ts index 3103e60f79..0bd3dab7d4 100644 --- a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Hydro/style.ts +++ b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Hydro/style.ts @@ -4,6 +4,7 @@ export const Root = styled(Box)(({ theme }) => ({ width: "100%", height: "100%", padding: theme.spacing(2), + paddingTop: 0, display: "flex", overflow: "auto", })); diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/index.tsx b/webapp/src/components/App/Singlestudy/explore/Modelization/index.tsx index 5d81c34155..d298c94a31 100644 --- a/webapp/src/components/App/Singlestudy/explore/Modelization/index.tsx +++ b/webapp/src/components/App/Singlestudy/explore/Modelization/index.tsx @@ -1,57 +1,73 @@ import { useMemo } from "react"; -import { useOutletContext } from "react-router-dom"; +import { useNavigate, useOutletContext } from "react-router-dom"; import { Box } from "@mui/material"; import { useTranslation } from "react-i18next"; import { StudyMetadata } from "../../../../../common/types"; import TabWrapper from "../TabWrapper"; import useAppSelector from "../../../../../redux/hooks/useAppSelector"; -import { getCurrentAreaId } from "../../../../../redux/selectors"; +import { getAreas, getCurrentAreaId } from "../../../../../redux/selectors"; +import useAppDispatch from "../../../../../redux/hooks/useAppDispatch"; +import { setCurrentArea } from "../../../../../redux/ducks/studySyntheses"; function Modelization() { const { study } = useOutletContext<{ study: StudyMetadata }>(); - const areaId = useAppSelector(getCurrentAreaId); const [t] = useTranslation(); + const dispatch = useAppDispatch(); + const navigate = useNavigate(); + const areas = useAppSelector((state) => getAreas(state, study.id)); + const areaId = useAppSelector(getCurrentAreaId); + + const tabList = useMemo(() => { + const basePath = `/studies/${study.id}/explore/modelization`; - const tabList = useMemo( - () => [ + const handleAreasClick = () => { + if (areaId.length === 0 && areas.length > 0) { + const firstAreaId = areas[0].id ?? null; + + if (firstAreaId) { + dispatch(setCurrentArea(firstAreaId)); + navigate(`${basePath}/area/${firstAreaId}`, { replace: true }); + } + } + }; + + return [ { label: t("study.modelization.map"), - path: `/studies/${study?.id}/explore/modelization/map`, + path: `${basePath}/map`, }, { label: t("study.areas"), - path: `/studies/${study?.id}/explore/modelization/area/${areaId}`, + path: `${basePath}/area/${areaId}`, + onClick: handleAreasClick, }, { label: t("study.links"), - path: `/studies/${study?.id}/explore/modelization/links`, + path: `${basePath}/links`, }, { label: t("study.bindingconstraints"), - path: `/studies/${study?.id}/explore/modelization/bindingcontraint`, + path: `${basePath}/bindingcontraint`, }, { label: t("study.debug"), - path: `/studies/${study?.id}/explore/modelization/debug`, + path: `${basePath}/debug`, }, { label: t("study.modelization.tableMode"), - path: `/studies/${study?.id}/explore/modelization/tablemode`, + path: `${basePath}/tablemode`, }, - ], - [areaId, study?.id, t], - ); + ]; + }, [areaId, areas, dispatch, navigate, study?.id, t]); return ( diff --git a/webapp/src/components/App/Singlestudy/explore/TabWrapper.tsx b/webapp/src/components/App/Singlestudy/explore/TabWrapper.tsx index 7a61f90fb6..482801c482 100644 --- a/webapp/src/components/App/Singlestudy/explore/TabWrapper.tsx +++ b/webapp/src/components/App/Singlestudy/explore/TabWrapper.tsx @@ -1,5 +1,5 @@ /* eslint-disable react/jsx-props-no-spreading */ -import { useEffect } from "react"; +import { useEffect, useState } from "react"; import * as React from "react"; import { styled, SxProps, Theme } from "@mui/material"; import Tabs from "@mui/material/Tabs"; @@ -28,19 +28,32 @@ export const StyledTab = styled(Tabs, { }), ); +interface TabItem { + label: string; + path: string; + onClick?: () => void; +} + interface Props { study: StudyMetadata | undefined; - tabList: Array<{ label: string; path: string }>; + tabList: TabItem[]; border?: boolean; tabStyle?: "normal" | "withoutBorder"; sx?: SxProps; + isScrollable?: boolean; } -function TabWrapper(props: Props) { - const { study, tabList, border, tabStyle, sx } = props; +function TabWrapper({ + study, + tabList, + border, + tabStyle, + sx, + isScrollable = false, +}: Props) { const location = useLocation(); const navigate = useNavigate(); - const [selectedTab, setSelectedTab] = React.useState(0); + const [selectedTab, setSelectedTab] = useState(0); useEffect(() => { const getTabIndex = (): number => { @@ -66,6 +79,11 @@ function TabWrapper(props: Props) { const handleChange = (event: React.SyntheticEvent, newValue: number) => { setSelectedTab(newValue); navigate(tabList[newValue].path); + + const onTabClick = tabList[newValue].onClick; + if (onTabClick) { + onTabClick(); + } }; //////////////////////////////////////////////////////////////// @@ -87,16 +105,15 @@ function TabWrapper(props: Props) { )} > {tabList.map((tab) => ( @@ -108,9 +125,4 @@ function TabWrapper(props: Props) { ); } -TabWrapper.defaultProps = { - border: undefined, - tabStyle: "normal", -}; - export default TabWrapper; diff --git a/webapp/src/redux/ducks/studySyntheses.ts b/webapp/src/redux/ducks/studySyntheses.ts index 9aad642d31..a7ae6cdc52 100644 --- a/webapp/src/redux/ducks/studySyntheses.ts +++ b/webapp/src/redux/ducks/studySyntheses.ts @@ -87,14 +87,7 @@ const initDefaultAreaLinkSelection = ( studyData?: FileStudyTreeConfigDTO, ): void => { if (studyData) { - // Set current area - const areas = Object.keys(studyData.areas); - if (areas.length > 0) { - dispatch(setCurrentArea(areas[0])); - } else { - dispatch(setCurrentArea("")); - } - + dispatch(setCurrentArea("")); dispatch(setCurrentLink("")); } else { dispatch(setCurrentArea("")); From bf9bcb3eb995d31dd8afcdce284c22f4edd34f98 Mon Sep 17 00:00:00 2001 From: hatim dinia Date: Tue, 28 Nov 2023 14:17:31 +0100 Subject: [PATCH 05/43] fix(ui-studies): resolve explore button visibility issue in StudyCard --- .../src/components/App/Studies/StudyCard.tsx | 66 ++++++++----------- 1 file changed, 28 insertions(+), 38 deletions(-) diff --git a/webapp/src/components/App/Studies/StudyCard.tsx b/webapp/src/components/App/Studies/StudyCard.tsx index d6f6e91dd6..ce59e316ad 100644 --- a/webapp/src/components/App/Studies/StudyCard.tsx +++ b/webapp/src/components/App/Studies/StudyCard.tsx @@ -16,6 +16,7 @@ import { ListItemText, Tooltip, Chip, + Divider, } from "@mui/material"; import { styled } from "@mui/material/styles"; import { indigo } from "@mui/material/colors"; @@ -218,7 +219,13 @@ const StudyCard = memo((props: Props) => { )} { flexFlow: "nowrap", px: 0.5, paddingBottom: 0.5, + width: "90%", + whiteSpace: "nowrap", + textOverflow: "ellipsis", + overflow: "hidden", }} > {study.folder} @@ -308,8 +319,6 @@ const StudyCard = memo((props: Props) => { sx={{ display: "flex", maxWidth: "65%", - flexDirection: "row", - justifyContent: "flex-start", alignItems: "center", }} > @@ -321,38 +330,19 @@ const StudyCard = memo((props: Props) => { - + {buildModificationDate(study.modificationDate, t, i18n.language)} - - - - - + + {study.owner.name} + + {`v${displayVersionName(study.version)}`} - {`v${displayVersionName(study.version)}`} { icon={} label={t("studies.variant").toLowerCase()} color="primary" + size="small" /> )} - {study.tags && - study.tags.map((elm) => ( - - ))} + {study.tags?.map((tag) => ( + + ))} From 7131b4f09d385ae65373c89ca76fbf306d81c2da Mon Sep 17 00:00:00 2001 From: hatim dinia Date: Tue, 28 Nov 2023 14:18:06 +0100 Subject: [PATCH 06/43] feat(ui-studies): make study title clickable --- .../src/components/App/Studies/StudyCard.tsx | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/webapp/src/components/App/Studies/StudyCard.tsx b/webapp/src/components/App/Studies/StudyCard.tsx index ce59e316ad..976ee7acd8 100644 --- a/webapp/src/components/App/Studies/StudyCard.tsx +++ b/webapp/src/components/App/Studies/StudyCard.tsx @@ -1,5 +1,5 @@ import { memo, useState } from "react"; -import { NavLink } from "react-router-dom"; +import { NavLink, useNavigate } from "react-router-dom"; import { AxiosError } from "axios"; import { useSnackbar } from "notistack"; import { useTranslation } from "react-i18next"; @@ -94,6 +94,7 @@ const StudyCard = memo((props: Props) => { const study = useAppSelector((state) => getStudy(state, id)); const isFavorite = useAppSelector((state) => isStudyFavorite(state, id)); const dispatch = useAppDispatch(); + const navigate = useNavigate(); //////////////////////////////////////////////////////////////// // Event Handlers @@ -243,6 +244,7 @@ const StudyCard = memo((props: Props) => { noWrap variant="h6" component="div" + onClick={() => navigate(`/studies/${study.id}`)} sx={{ color: "white", boxSizing: "border-box", @@ -251,6 +253,11 @@ const StudyCard = memo((props: Props) => { whiteSpace: "nowrap", textOverflow: "ellipsis", overflow: "hidden", + cursor: "pointer", + "&:hover": { + color: "primary.main", + textDecoration: "underline", + }, }} > {study.name} @@ -338,12 +345,20 @@ const StudyCard = memo((props: Props) => { {buildModificationDate(study.modificationDate, t, i18n.language)} - - {study.owner.name} - {`v${displayVersionName(study.version)}`} + + + {study.owner.name} + { flexWrap: "wrap", justifyContent: "flex-start", alignItems: "center", - overflowX: "hidden", - overflowY: "auto", + gap: 0.5, ".MuiChip-root": { color: "black", From 39803ffd9f296ff3915b8c03d5c1e9c5c4d9e624 Mon Sep 17 00:00:00 2001 From: hatim dinia Date: Tue, 5 Dec 2023 16:20:58 +0100 Subject: [PATCH 07/43] feat(ui): add manual submit on clusters form --- .../Modelization/Areas/Renewables/Form.tsx | 27 +++++++++---------- .../Modelization/Areas/Storages/Form.tsx | 27 +++++++++---------- .../Modelization/Areas/Thermal/Form.tsx | 27 +++++++++---------- 3 files changed, 36 insertions(+), 45 deletions(-) diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Renewables/Form.tsx b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Renewables/Form.tsx index 0f98902d24..30dc1bc7b7 100644 --- a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Renewables/Form.tsx +++ b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Renewables/Form.tsx @@ -65,24 +65,21 @@ function RenewablesForm() { key={study.id + areaId} config={{ defaultValues }} onSubmit={handleSubmit} - autoSubmit + enableUndoRedo > - - - + + + ); } diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Form.tsx b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Form.tsx index 0314fc0df1..40aa166664 100644 --- a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Form.tsx +++ b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Form.tsx @@ -79,24 +79,21 @@ function StorageForm() { defaultValues, }} onSubmit={handleSubmit} - autoSubmit + enableUndoRedo > - - - + + + ); } diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Thermal/Form.tsx b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Thermal/Form.tsx index ee16dbbf6d..810338b581 100644 --- a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Thermal/Form.tsx +++ b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Thermal/Form.tsx @@ -63,24 +63,21 @@ function ThermalForm() { key={study.id + areaId} config={{ defaultValues }} onSubmit={handleSubmit} - autoSubmit + enableUndoRedo > - - - + + + ); } From 6edef65124a781436b8b5644e8d7c20502749b70 Mon Sep 17 00:00:00 2001 From: hatim dinia Date: Wed, 6 Dec 2023 10:26:41 +0100 Subject: [PATCH 08/43] feat(ui-thermal): add steps on volatility fields --- .../Singlestudy/explore/Modelization/Areas/Thermal/Fields.tsx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Thermal/Fields.tsx b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Thermal/Fields.tsx index 3253167d5a..cf5cb2fc66 100644 --- a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Thermal/Fields.tsx +++ b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Thermal/Fields.tsx @@ -233,6 +233,7 @@ function Fields() { message: t("form.field.maxValue", { 0: 1 }), }, }} + inputProps={{ step: 0.1 }} /> Date: Thu, 7 Dec 2023 13:38:53 +0100 Subject: [PATCH 09/43] doc(config): enhance application configuration documentation (#1710) Co-authored-by: Mohamed Abdel Wedoud --- docs/install/1-CONFIG.md | 695 +++++++++++++++++++++++++++++- resources/application.yaml | 72 +--- resources/deploy/config.prod.yaml | 87 ---- resources/deploy/config.yaml | 69 +-- scripts/package_antares_web.sh | 4 +- tests/core/test_config.py | 253 ----------- 6 files changed, 714 insertions(+), 466 deletions(-) diff --git a/docs/install/1-CONFIG.md b/docs/install/1-CONFIG.md index 7c82b044dd..8af3e22615 100644 --- a/docs/install/1-CONFIG.md +++ b/docs/install/1-CONFIG.md @@ -1,11 +1,690 @@ -# Application Configuration +# Application Configuration Documentation -Almost all the configuration of the application can be found in the -[application.yaml](https://github.com/AntaresSimulatorTeam/AntaREST/blob/master/resources/application.yaml) file. -If the path to this configuration file is not explicitly provided (through the `-c` option), -the application will try to look for files in the following location (in order): +In the following, we will be exploring how to edit your application configuration file.
+As explained in the main documentation readme file, you can use the following command line +to start the API: - 1. `./config.yaml` - 2. `../config.yaml` - 3. `$HOME/.antares/config.yaml` +```shell +python3 antarest/main.py -c resources/application.yaml --auto-upgrade-db --no-front +``` +The `-c` option here describes the path towards the configuration `.yaml` file. If this option is +not fed to the program, it will to look for files in the following locations (in order): + +1. `./config.yaml` +2. `../config.yaml` +3. `$HOME/.antares/config.yaml` + +
+In this documentation, you will have a global overview of the configuration +file structure and details for each of the `.yaml` fields with specifications regarding +type of data and the default values, and descriptions of those fields. + + +# File Structure + +- [Security](#security) +- [Database](#db) +- [Storage](#storage) +- [Launcher](#launcher) +- [Logging](#logging) +- [Root Path](#root_path) +- [Optional sections](#debug) + +# security + +This section defines the settings for application security, authentication, and groups. + +## **disabled** + +- **Type:** Boolean +- **Default value:** false +- **Description:** If set to `false`, user identification will be required when launching the app. + +## **jwt** + +### **key** + +- **Type:** String (usually a Base64-encoded one) +- **Default value:** "" +- **Description:** JWT (Json Web Token) secret key for authentication. + +## **login** + +### **admin** + +#### **pwd** + +- **Type:** String +- **Default value:** "" +- **Description:** Admin user's password. + +## **external_auth** + +This subsection is about setting up an external authentication service that lets you connect to an LDAP using a web +service. The group names and their IDs are obtained from the LDAP directory. + +### **url** + +- **Type:** String +- **Default value:** "" +- **Description:** External authentication URL. If you want to enable local authentication, you should write "". + +### **default_group_role** + +- **Type:** Integer +- **Default value:** 10 +- **Description:** Default user role for external authentication + - `ADMIN = 40` + - `WRITER = 30` + - `RUNNER = 20` + - `READER = 10` + +### **add_ext_groups** + +- **Type:** Boolean +- **Default value:** false +- **Description:** Whether to add external groups to user roles. + +### **group_mapping** + +- **Type:** Dictionary +- **Default value:** {} +- **Description:** Groups of the application: Keys = Ids, Values = Names. Example: + - 00000001: espace_commun + - 00001188: drd + - 00001574: cnes + +```yaml +# example for security settings +security: + disabled: false + jwt: + key: best-key + login: + admin: + pwd: root + external_auth: + url: "" + default_group_role: 10 + group_mapping: + id_ext: id_int + add_ext_groups: false +``` + +# db + +This section presents the configuration of application's database connection. + +## **url** + +- **Type:** String +- **Default value:** "" +- **Description:** The Database URL. For example, `sqlite:///database.db` for a local SQLite DB + or `postgresql://postgres_user:postgres_password@postgres_host:postgres_port/postgres_db` for a PostgreSQL DB. + +## **admin_url** + +- **Type:** String +- **Default value:** None +- **Description:** The URL you can use to directly access your database. + +## **pool_use_null** + +- **Type:** Boolean +- **Default value:** false +- **Description:** If set to `true`, connections are not pooled. This parameter should be kept at `false` to avoid + issues. + +## **db_connect_timeout** + +- **Type:** Integer +- **Default value:** 10 +- **Description:** The timeout (in seconds) for database connection creation. + +## **pool_recycle** + +- **Type:** Integer +- **Default value:** None +- **Description:** Prevents the pool from using a particular connection that has passed a certain time in seconds. An + often-used value is 3600, which corresponds to an hour. *Not used for SQLite DB.* + +## **pool_size** + +- **Type:** Integer +- **Default value:** 5 +- **Description:** The maximum number of permanent connections to keep. *Not used for SQLite DB.* + +## **pool_use_lifo** + +- **Type:** Boolean +- **Default value:** false +- **Description:** Specifies whether the Database should use the Last-in-First-out method. It is commonly used in cases + where the most recent data entry is the most important and applies to the application context. Therefore, it's better + to set this parameter to `true`. *Not used for SQLite DB.* + +## **pool_pre_ping** + +- **Type:** Boolean +- **Default value:** false +- **Description:** Connections that are closed from the server side are gracefully handled by the connection pool and + replaced with a new connection. *Not used for SQLite DB.* + +## **pool_max_overflow** + +- **Type:** Integer +- **Default value:** 10 +- **Description:** Temporarily exceeds the set pool_size if no connections are available. *Not used for SQLite DB.* + +```yaml +# example for db settings +db: + url: "postgresql://postgres:My:s3Cr3t/@127.0.0.1:30432/antares" + admin_url: "postgresql://{{postgres_owner}}:{{postgres_owner_password}}@{{postgres_host}}:{{postgres_port}}/{{postgres_db}}" + pool_recycle: 3600 + pool_max_overflow: 10 + pool_size: 5 + pool_use_lifo: true + pool_use_null: false +``` + +# storage + +The following section configuration parameters define the application paths and services options. + +## **tmp_dir** + +- **Type:** Path +- **Default value:** `tempfile.gettempdir()` ( + documentation [here](https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir)) +- **Description:** The temporary directory for storing temporary files. An often-used value is `./tmp`. + +## **matrixstore** + +- **Type:** Path +- **Default value:** `./matrixstore` +- **Description:** Antares Web extracts matrices data and shares them between managed studies to save space. These + matrices are stored here. + +## **archive_dir** + +- **Type:** Path +- **Default value:** `./archives` +- **Description:** The directory for archived (zipped) studies. + +## **workspaces** + +- **Type:** Dictionary +- **Default value:** {} +- **Description:** Different workspaces where the application expects to find studies. Keys = Folder names, Values = + WorkspaceConfig object. Such an object has 4 fields: + - `groups`: List of groups corresponding to the workspace (default []) + - `path`: Path of the workspace (default `Path()`) + - `filter_in`: List of regex. If a folder does not contain a file whose name matches one of the regex, it's not + scanned (default [".*"]) + - `filter_out`: List of regex. If a folder contains any file whose name matches one of the regex, it's not scanned ( + default []) + +> NOTE: If a directory is to be ignored by the watcher, place a file named `AW_NO_SCAN` inside. + +Examples: + +```yaml +default: + path: /home/john/Projects/antarest_data/internal_studies/ +studies: + path: /home/john/Projects/antarest_data/studies/ +staging_studies: + path: /home/john/Projects/antarest_data/staging_studies/ +``` + +```yaml +default: + path: /studies/internal +"public": + path: /mounts/public + filter_in: + - .* + filter_out: + - ^R$ + - System Volume Information + - .*RECYCLE.BIN + - .Rproj.* + - ^.git$ + - ^areas$ +"aws_share_2": + path: /mounts/aws_share_2 + groups: + - test +"sedre_archive": + path: /mounts/sedre_archive + groups: + - sedre +``` + +## **allow_deletion** + +- **Type:** Boolean +- **Default value:** false +- **Description:** Indicates if studies found in non-default workspace can be deleted by the application. + +## **matrix_gc_sleeping_time** + +- **Type:** Integer +- **Default value:** 3600 (corresponds to 1 hour) +- **Description:** Time in seconds to sleep between two garbage collections (which means matrix suppression). + +## **matrix_gc_dry_run** + +- **Type:** Boolean +- **Default value:** false +- **Description:** If `true`, matrices will never be removed. Else, the ones that are unused will. + +## **auto_archive_sleeping_time** + +- **Type:** Integer +- **Default value:** 3600 (corresponds to 1 hour) +- **Description:** Time in seconds to sleep between two auto_archiver tasks (which means zipping unused studies). + +## **auto_archive_dry_run** + +- **Type:** Boolean +- **Default value:** false +- **Description:** If `true`, studies will never be archived. Else, the ones that no one has accessed for a while will. + +## **auto_archive_threshold_days** + +- **Type:** Integer +- **Default value:** 60 +- **Description:** Number of days after the last study access date before it should be archived. + +## **auto_archive_max_parallel** + +- **Type:** Integer +- **Default value:** 5 +- **Description:** Max auto archival tasks in parallel. + +## **watcher_lock** + +- **Type:** Boolean +- **Default value:** true +- **Description:** If false, it will scan without any delay. Else, its delay will be the value of the + field `watcher_lock_delay`. + +## **watcher_lock_delay** + +- **Type:** Integer +- **Default value:** 10 +- **Description:** Seconds delay between two scans. + +## **download_default_expiration_timeout_minutes** + +- **Type:** Integer +- **Default value:** 1440 (corresponds to 1 day) +- **Description:** Minutes before your study download will be cleared. The value could be less than the default one as a + user should download his study pretty soon after the download becomes available. + +```yaml +# example for storage settings +storage: + tmp_dir: /home/jon/Projects/antarest_data/tmp + matrixstore: /home/jon/Projects/antarest_data/matrices + archive_dir: /home/jon/Projects/antarest_data/archives + allow_deletion: false + matrix_gc_sleeping_time: 3600 + matrix_gc_dry_run: False + workspaces: + default: + path: /home/jon/Projects/antarest_data/internal_studies/ + studies: + path: /home/jon/Projects/antarest_data/studies/ + staging_studies: + path: /home/jon/Projects/antarest_data/staging_studies/ +``` + +# launcher + +This section provides the launcher with specified options and defines the settings for solver binaries. + +## **default** + +- **Type:** String, possible values: `local` or `slurm` +- **Default value:** `local` +- **Description:** Default launcher configuration, if set to `local` then the launcher is defined locally. Otherwise +it is instantiated on shared servers using `slurm`. + +## **local** + +### **enable_nb_cores_detection** + +- **Type:** Boolean +- **Default value:** false +- **Description:** Enables detection of available CPUs for the solver. If so, the default value used will be `max(1, + multiprocessing.cpu_count() - 2)`. Else, it will be 22. To maximize the solver's performance, it is recommended to + activate this option. + +### **binaries** + +- **Type:** Dictionary +- **Default value:** {} +- **Description:** Binary paths for various versions of the launcher. Example: + +```yaml +700: /home/john/Antares/antares_web_data/antares-solver/antares-8.0-solver +800: /home/john/Antares/antares_web_data/antares-solver/antares-8.0-solver +810: /home/john/Antares/antares_web_data/antares-solver/antares-8.3-solver +820: /home/john/Antares/antares_web_data/antares-solver/antares-8.3-solver +830: /home/john/Antares/antares_web_data/antares-solver/antares-8.3-solver +840: /home/john/Antares/antares_web_data/antares-solver/antares-8.4-solver +850: /home/john/Antares/antares_web_data/antares-solver/antares-8.5-solver +860: /home/john/Antares/antares_web_data/antares-solver/antares-8.6-solver +``` + +> NOTE: As you can see, you can use newer solver for older study version thanks to the solver retro-compatibility + +## **slurm** + +SLURM (Simple Linux Utility for Resource Management) is used to interact with a remote environment (for Antares it's +computing server) as a workload manager. + +### **local_workspace** + +- **Type:** Path +- **Default value:** Path +- **Description:** Path to the local SLURM workspace + +### **username** + +- **Type:** String +- **Default value:** "" +- **Description:** Username for SLURM to connect itself with SSH protocol to computing server. + +### **hostname** + +- **Type:** String +- **Default value:** "" +- **Description:** IP address for SLURM to connect itself with SSH protocol to computing server. + +### **port** + +- **Type:** Integer +- **Default value:** 0 +- **Description:** SSH port for SLURM + +Examples: + +- Options to connect SLURM to computing server `prod-server-name` (production): + +``` +username: run-antares +hostname: XX.XXX.XXX.XXX +port: 22 +``` + +- Options to connect SLURM to computing server `dev-server-name` (recette and integration): + +``` +username: dev-antares +hostname: XX.XXX.XXX.XXX +port: 22 +``` + +### **private_key_file** + +- **Type:** Path +- **Default value:** Path() +- **Description:** SSH private key file. If you do not have one, you have to fill the `password` field. + +### **password** + +- **Type:** String +- **Default value:** "" +- **Description:** SSH password for the remote server. You need it or a private key file for SLURM to connect itself. + +### **key_password** + +- **Type:** String +- **Default value:** "" +- **Description:** An optional password to use to decrypt the key file, if it's encrypted + +### **default_wait_time** + +> NOTE: Deprecated as the app is launched with wait_mode=false* + +- **Type:** Integer +- **Default value:** 0 +- **Description:** Default delay (in seconds) of the SLURM loop checking the status of the tasks and recovering those + completed in the loop. Often used value: 900 (15 minutes) + +### **default_time_limit** + +- **Type:** Integer +- **Default value:** 0 +- **Description:** Time limit for SLURM jobs (in seconds). If a jobs exceed this time limit, SLURM kills the job and it + is considered failed. Often used value: 172800 (48 hours) + +### **enable_nb_cores_detection** + +- **Type:** Boolean +- **Default value:** false +- **Description:** Enables detection of available CPUs for the solver (Not implemented yet). + +### **nb_cores** + +#### **min** + +- **Type:** Integer +- **Default value:** 1 +- **Description:** Minimum amount of CPUs to use when launching a simulation. + +#### **default** + +- **Type:** Integer +- **Default value:** 22 +- **Description:** Default amount of CPUs to use when launching a simulation. The user can override this value in the + launch dialog box. + +#### **max** + +- **Type:** Integer +- **Default value:** 24 +- **Description:** Maximum amount of CPUs to use when launching a simulation. + +### **default_json_db_name** + +- **Type:** String +- **Default value:** "" +- **Description:** SLURM local DB name. Often used value : `launcher_db.json` + +### **slurm_script_path** + +- **Type:** String +- **Default value:** "" +- **Description:** Bash script path to execute on remote server. + - If SLURM is connected to `prod-server-name` (*production*), use this path: `/applis/antares/launchAntares.sh` + - If SLURM is connected to `dev-server-name` (*recette* and *integration*), use this + path: `/applis/antares/launchAntaresRec.sh` + +### **antares_versions_on_remote_server** + +- **Type:** List of String +- **Default value:** [] +- **Description:** List of Antares solver versions available on the remote server. Examples: + +```yaml +# example for launcher settings +launcher: + default: local + local: + binaries: + 860: /home/jon/opt/antares-solver_ubuntu20.04/antares-8.6-solver + slurm: + local_workspace: /home/jon/Projects/antarest_data/slurm_workspace + username: jon + hostname: localhost + port: 22 + private_key_file: /home/jon/.ssh/id_rsa + key_password: + default_wait_time: 900 + default_time_limit: 172800 + default_n_cpu: 20 + default_json_db_name: launcher_db.json + slurm_script_path: /applis/antares/launchAntares.sh + db_primary_key: name + antares_versions_on_remote_server: + - '610' + - '700' +``` + +# Logging + +This section sets the configuration for the application logs. + +## **level** + +- **Type:** String, possible values: "DEBUG", "INFO", "WARNING", "ERROR" +- **Default value:** `INFO` +- **Description:** The logging level of the application (INFO, DEBUG, etc.). + +## **logfile** + +- **Type:** Path +- **Default value:** None +- **Description:** The path to the application log file. An often-used value is `.tmp/antarest.log`. + +## **json** + +- **Type:** Boolean +- **Default value:** false +- **Description:** If `true`, the logging format will be `json`; otherwise, it is `console`. + - `console`: The default format used for console output, suitable for Desktop versions or development environments. + - `json`: A specific JSON format suitable for consumption by monitoring tools via a web service. + +```yaml +# example for logging settings +logging: + level: INFO + logfile: ./tmp/antarest.log + json: false +``` + +# root_path + +- **Type:** String +- **Default value:** "" +- **Description:** The root path for FastAPI. To use a remote server, use `/api`, and for a local environment: `api`. + +```yaml +# example for root_path settings +root_path: "/{root_path}" + +``` + +## `Extra optional configuration` + +# debug + +- **Type:** Boolean +- **Default value:** false +- **Description:** This flag determines whether the engine will log all the SQL statements it executes to the console. + If you turn this on by setting it to `true`, you'll see a detailed log of the database queries. + +```yaml +# example for debug settings +debug: false +``` + +# cache + +## **checker_delay** + +- **Type:** Float +- **Default value:** 0.2 +- **Description:** The time in seconds to sleep before checking what needs to be removed from the cache. + +```yaml +# example for cache settings +cache: + checker_delay: 0.2 +``` + +# tasks + +## **max_workers** + +- **Type:** Integer +- **Default value:** 5 +- **Description:** The number of threads for Tasks in the ThreadPoolExecutor. + +## **remote_workers** + +- **Type:** List +- **Default value:** [] +- **Description:** Example: + +```yaml +# example for tasks settings +tasks: + max_workers: 4 + remote_workers: + - name: aws_share_2 + queues: + - unarchive_aws_share_2 + - name: simulator_worker + queues: + - generate-timeseries + - generate-kirshoff-constraints +``` + +# server + +## **worker_threadpool_size** + +- **Type:** Integer +- **Default value:** 5 +- **Description:** The number of threads of the Server in the `ThreadPoolExecutor`. + +## **services** + +- **Type:** List of Strings +- **Default value:** [] +- **Description:** Services to enable when launching the application. Possible values: "watcher," "matrix_gc," " + archive_worker," "auto_archiver," "simulator_worker." + +```yaml +#example for server settings +server: + worker_threadpool_size: 5 + services: + - watcher + - matrix_gc +``` + +# redis + +This section is for the settings of Redis backend, which is used for managing the event bus and in-memory caching. + +## **host** + +- **Type:** String +- **Default value:** `localhost` +- **Description:** The Redis server hostname. + +## **port** + +- **Type:** Integer +- **Default value:** 6379 +- **Description:** The Redis server port. + +## **password** + +- **Type:** String +- **Default value:** None +- **Description:** The Redis password. + +```yaml +# example for redis settings +redis: + host: localhost + port: 9862 +``` \ No newline at end of file diff --git a/resources/application.yaml b/resources/application.yaml index a85357634f..6fbdb31f9f 100644 --- a/resources/application.yaml +++ b/resources/application.yaml @@ -1,45 +1,22 @@ +# Documentation about this file can be found in this file: `docs/install/1-CONFIG.md` + security: disabled: true jwt: key: super-secret - login: - admin: - pwd: admin - external_auth: - url: "" - default_group_role: 10 -# group_mapping: -# id_ext: id_int -# ... - add_ext_groups: false - - db: url: "sqlite:///database.db" - #pool_recycle: storage: tmp_dir: ./tmp matrixstore: ./matrices archive_dir: ./examples/archives - allow_deletion: false # indicate if studies found in non default workspace can be deleted by the application - #matrix_gc_sleeping_time: 3600 # time in seconds to sleep between two garbage collection - #matrix_gc_dry_run: False # Skip matrix effective deletion - #auto_archive_sleeping_time: 3600 # time in seconds to sleep between two auto archival checks - #auto_archive_dry_run: True # Skip auto archive effective archival - #auto_archive_threshold_days: 60 # number of days after last study access when the study should be archived - #auto_archive_max_parallel: 5 # max auto archival tasks in parallel workspaces: - default: # required, no filters applied, this folder is not watched + default: path: ./examples/internal_studies/ - # other workspaces can be added - # if a directory is to be ignored by the watcher, place a file named AW_NO_SCAN inside - tmp: + studies: path: ./examples/studies/ - # filter_in: ['.*'] # default to '.*' - # filter_out: [] # default to empty - # groups: [] # default empty launcher: default: local @@ -49,39 +26,8 @@ launcher: 700: path/to/700 enable_nb_cores_detection: true -# slurm: -# local_workspace: path/to/workspace -# username: username -# hostname: 0.0.0.0 -# port: 22 -# private_key_file: path/to/key -# key_password: key_password -# password: password_is_optional_but_necessary_if_key_is_absent -# default_wait_time: 900 -# default_time_limit: 172800 -# enable_nb_cores_detection: False -# nb_cores: -# min: 1 -# default: 22 -# max: 24 -# default_json_db_name: launcher_db.json -# slurm_script_path: /path/to/launchantares_v1.1.3.sh -# db_primary_key: name -# antares_versions_on_remote_server : -# - "610" -# - "700" -# - "710" -# - "720" -# - "800" - - -debug: true - root_path: "api" -#tasks: -# max_workers: 5 - server: worker_threadpool_size: 12 services: @@ -90,12 +36,4 @@ server: logging: level: INFO - logfile: ./tmp/antarest.log -# json: false - -# Uncomment these lines to use redis as a backend for the eventbus -# It is required to use redis when using this application on multiple workers in a preforked model like gunicorn for instance -#eventbus: -# redis: -# host: localhost -# port: 6379 + logfile: ./tmp/antarest.log \ No newline at end of file diff --git a/resources/deploy/config.prod.yaml b/resources/deploy/config.prod.yaml index 02fbb4b8bc..e69de29bb2 100644 --- a/resources/deploy/config.prod.yaml +++ b/resources/deploy/config.prod.yaml @@ -1,87 +0,0 @@ -security: - disabled: false - jwt: - key: secretkeytochange - login: - admin: - pwd: admin - external_auth: - url: "" - default_group_role: 10 - -db: - url: "postgresql://postgres:somepass@postgresql:5432/postgres" - admin_url: "postgresql://postgres:somepass@postgresql:5432/postgres" - pool_recycle: 3600 - -storage: - tmp_dir: /antarest_tmp_dir - archive_dir: /studies/archives - matrixstore: /matrixstore - matrix_gc_dry_run: true - workspaces: - default: # required, no filters applied, this folder is not watched - path: /workspaces/internal_studies/ - # other workspaces can be added - # if a directory is to be ignored by the watcher, place a file named AW_NO_SCAN inside - tmp: - path: /workspaces/studies/ - # filter_in: ['.*'] # default to '.*' - # filter_out: [] # default to empty - # groups: [] # default empty - -launcher: - default: local - - local: - binaries: - 800: /antares_simulator/antares-8.2-solver - enable_nb_cores_detection: true - -# slurm: -# local_workspace: path/to/workspace -# username: username -# hostname: 0.0.0.0 -# port: 22 -# private_key_file: path/to/key -# key_password: key_password -# password: password_is_optional_but_necessary_if_key_is_absent -# default_wait_time: 900 -# default_time_limit: 172800 -# enable_nb_cores_detection: False -# nb_cores: -# min: 1 -# default: 22 -# max: 24 -# default_json_db_name: launcher_db.json -# slurm_script_path: /path/to/launchantares_v1.1.3.sh -# db_primary_key: name -# antares_versions_on_remote_server : -# - "610" -# - "700" -# - "710" -# - "720" -# - "800" - - -debug: false - -root_path: "api" - -#tasks: -# max_workers: 5 -server: - worker_threadpool_size: 12 -# services: -# - watcher - -logging: - level: INFO -# logfile: /logs/antarest.log -# json: true - -# Uncomment these lines to use redis as a backend for the eventbus -# It is required to use redis when using this application on multiple workers in a preforked model like gunicorn for instance -redis: - host: redis - port: 6379 diff --git a/resources/deploy/config.yaml b/resources/deploy/config.yaml index 810e1f8d24..3eaaf891b6 100644 --- a/resources/deploy/config.yaml +++ b/resources/deploy/config.yaml @@ -1,13 +1,9 @@ +# Documentation about this file can be found in this file: `docs/install/1-CONFIG.md` + security: disabled: true jwt: key: super-secret - login: - admin: - pwd: admin - external_auth: - url: "" - default_group_role: 10 db: url: "sqlite:///database.db" @@ -17,43 +13,33 @@ storage: matrixstore: ./matrices archive_dir: ./examples/archives workspaces: - default: # required, no filters applied, this folder is not watched + default: path: ./examples/internal_studies/ - # other workspaces can be added - # if a directory is to be ignored by the watcher, place a file named AW_NO_SCAN inside - tmp: + studies: path: ./examples/studies/ - # filter_in: ['.*'] # default to '.*' - # filter_out: [] # default to empty - # groups: [] # default empty launcher: - default: local - local: binaries: - 700: path/to/700 - enable_nb_cores_detection: true + VER: ANTARES_SOLVER_PATH # slurm: -# local_workspace: path/to/workspace -# username: username -# hostname: 0.0.0.0 -# port: 22 -# private_key_file: path/to/key -# key_password: key_password -# password: password_is_optional_but_necessary_if_key_is_absent -# default_wait_time: 900 -# default_time_limit: 172800 -# enable_nb_cores_detection: False +# local_workspace: /path/to/slurm_workspace # Path to the local SLURM workspace +# username: run-antares # SLURM username +# hostname: 10.134.248.111 # SLURM server hostname +# port: 22 # SSH port for SLURM +# private_key_file: /path/to/ssh_private_key # SSH private key file +# default_wait_time: 900 # Default wait time for SLURM jobs +# default_time_limit: 172800 # Default time limit for SLURM jobs +# enable_nb_cores_detection: False # Enable detection of available CPU cores for SLURM # nb_cores: -# min: 1 -# default: 22 -# max: 24 -# default_json_db_name: launcher_db.json -# slurm_script_path: /path/to/launchantares_v1.1.3.sh -# db_primary_key: name -# antares_versions_on_remote_server : +# min: 1 # Minimum number of CPU cores +# default: 22 # Default number of CPU cores +# max: 24 # Maximum number of CPU cores +# default_json_db_name: launcher_db.json # Default JSON database name for SLURM +# slurm_script_path: /applis/antares/launchAntares.sh # Path to the SLURM script (on distant server) +# db_primary_key: name # Primary key for the SLURM database +# antares_versions_on_remote_server: #List of Antares versions available on the remote SLURM server # - "840" # - "850" @@ -62,20 +48,5 @@ debug: false root_path: "api" -#tasks: -# max_workers: 5 -server: - worker_threadpool_size: 12 - services: - - watcher - logging: - level: INFO logfile: ./tmp/antarest.log -# json: false - -# Uncomment these lines to use redis as a backend for the eventbus -# It is required to use redis when using this application on multiple workers in a preforked model like gunicorn for instance -#redis: -# host: localhost -# port: 6379 diff --git a/scripts/package_antares_web.sh b/scripts/package_antares_web.sh index 5dc2da6adb..31ae7ac0f1 100755 --- a/scripts/package_antares_web.sh +++ b/scripts/package_antares_web.sh @@ -73,9 +73,9 @@ echo "INFO: Copying basic configuration files..." rm -rf "${DIST_DIR}/examples" # in case of replay cp -r "${RESOURCES_DIR}"/deploy/* "${DIST_DIR}" if [[ "$OSTYPE" == "msys"* ]]; then - sed -i "s/700: path\/to\/700/$ANTARES_SOLVER_FULL_VERSION_INT: .\/AntaresWeb\/antares_solver\/antares-$ANTARES_SOLVER_VERSION-solver.exe/g" "${DIST_DIR}/config.yaml" + sed -i "s/VER: ANTARES_SOLVER_PATH/$ANTARES_SOLVER_FULL_VERSION_INT: .\/AntaresWeb\/antares_solver\/antares-$ANTARES_SOLVER_VERSION-solver.exe/g" "${DIST_DIR}/config.yaml" else - sed -i "s/700: path\/to\/700/$ANTARES_SOLVER_FULL_VERSION_INT: .\/AntaresWeb\/antares_solver\/antares-$ANTARES_SOLVER_VERSION-solver/g" "${DIST_DIR}/config.yaml" + sed -i "s/VER: ANTARES_SOLVER_PATH/$ANTARES_SOLVER_FULL_VERSION_INT: .\/AntaresWeb\/antares_solver\/antares-$ANTARES_SOLVER_VERSION-solver/g" "${DIST_DIR}/config.yaml" fi echo "INFO: Creating shortcuts..." diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 00c6f9458d..e69de29bb2 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -1,253 +0,0 @@ -from pathlib import Path -from unittest import mock - -import pytest - -from antarest.core.config import ( - Config, - InvalidConfigurationError, - LauncherConfig, - LocalConfig, - NbCoresConfig, - SlurmConfig, -) -from tests.core.assets import ASSETS_DIR - -LAUNCHER_CONFIG = { - "default": "slurm", - "local": { - "binaries": {"860": Path("/bin/solver-860.exe")}, - "enable_nb_cores_detection": False, - "nb_cores": {"min": 2, "default": 10, "max": 20}, - }, - "slurm": { - "local_workspace": Path("/home/john/antares/workspace"), - "username": "john", - "hostname": "slurm-001", - "port": 22, - "private_key_file": Path("/home/john/.ssh/id_rsa"), - "key_password": "password", - "password": "password", - "default_wait_time": 10, - "default_time_limit": 20, - "default_json_db_name": "antares.db", - "slurm_script_path": "/path/to/slurm/launcher.sh", - "max_cores": 32, - "antares_versions_on_remote_server": ["860"], - "enable_nb_cores_detection": False, - "nb_cores": {"min": 1, "default": 34, "max": 36}, - }, - "batch_size": 100, -} - - -class TestNbCoresConfig: - def test_init__default_values(self): - config = NbCoresConfig() - assert config.min == 1 - assert config.default == 22 - assert config.max == 24 - - def test_init__invalid_values(self): - with pytest.raises(ValueError): - # default < min - NbCoresConfig(min=2, default=1, max=24) - with pytest.raises(ValueError): - # default > max - NbCoresConfig(min=1, default=25, max=24) - with pytest.raises(ValueError): - # min < 0 - NbCoresConfig(min=0, default=22, max=23) - with pytest.raises(ValueError): - # min > max - NbCoresConfig(min=22, default=22, max=21) - - def test_to_json(self): - config = NbCoresConfig() - # ReactJs Material UI expects "min", "defaultValue" and "max" keys - assert config.to_json() == {"min": 1, "defaultValue": 22, "max": 24} - - -class TestLocalConfig: - def test_init__default_values(self): - config = LocalConfig() - assert config.binaries == {}, "binaries should be empty by default" - assert config.enable_nb_cores_detection, "nb cores auto-detection should be enabled by default" - assert config.nb_cores == NbCoresConfig() - - def test_from_dict(self): - config = LocalConfig.from_dict( - { - "binaries": {"860": Path("/bin/solver-860.exe")}, - "enable_nb_cores_detection": False, - "nb_cores": {"min": 2, "default": 10, "max": 20}, - } - ) - assert config.binaries == {"860": Path("/bin/solver-860.exe")} - assert not config.enable_nb_cores_detection - assert config.nb_cores == NbCoresConfig(min=2, default=10, max=20) - - def test_from_dict__auto_detect(self): - with mock.patch("multiprocessing.cpu_count", return_value=8): - config = LocalConfig.from_dict( - { - "binaries": {"860": Path("/bin/solver-860.exe")}, - "enable_nb_cores_detection": True, - } - ) - assert config.binaries == {"860": Path("/bin/solver-860.exe")} - assert config.enable_nb_cores_detection - assert config.nb_cores == NbCoresConfig(min=1, default=6, max=8) - - -class TestSlurmConfig: - def test_init__default_values(self): - config = SlurmConfig() - assert config.local_workspace == Path() - assert config.username == "" - assert config.hostname == "" - assert config.port == 0 - assert config.private_key_file == Path() - assert config.key_password == "" - assert config.password == "" - assert config.default_wait_time == 0 - assert config.default_time_limit == 0 - assert config.default_json_db_name == "" - assert config.slurm_script_path == "" - assert config.max_cores == 64 - assert config.antares_versions_on_remote_server == [], "solver versions should be empty by default" - assert not config.enable_nb_cores_detection, "nb cores auto-detection shouldn't be enabled by default" - assert config.nb_cores == NbCoresConfig() - - def test_from_dict(self): - config = SlurmConfig.from_dict( - { - "local_workspace": Path("/home/john/antares/workspace"), - "username": "john", - "hostname": "slurm-001", - "port": 22, - "private_key_file": Path("/home/john/.ssh/id_rsa"), - "key_password": "password", - "password": "password", - "default_wait_time": 10, - "default_time_limit": 20, - "default_json_db_name": "antares.db", - "slurm_script_path": "/path/to/slurm/launcher.sh", - "max_cores": 32, - "antares_versions_on_remote_server": ["860"], - "enable_nb_cores_detection": False, - "nb_cores": {"min": 2, "default": 10, "max": 20}, - } - ) - assert config.local_workspace == Path("/home/john/antares/workspace") - assert config.username == "john" - assert config.hostname == "slurm-001" - assert config.port == 22 - assert config.private_key_file == Path("/home/john/.ssh/id_rsa") - assert config.key_password == "password" - assert config.password == "password" - assert config.default_wait_time == 10 - assert config.default_time_limit == 20 - assert config.default_json_db_name == "antares.db" - assert config.slurm_script_path == "/path/to/slurm/launcher.sh" - assert config.max_cores == 32 - assert config.antares_versions_on_remote_server == ["860"] - assert not config.enable_nb_cores_detection - assert config.nb_cores == NbCoresConfig(min=2, default=10, max=20) - - def test_from_dict__default_n_cpu__backport(self): - config = SlurmConfig.from_dict( - { - "local_workspace": Path("/home/john/antares/workspace"), - "username": "john", - "hostname": "slurm-001", - "port": 22, - "private_key_file": Path("/home/john/.ssh/id_rsa"), - "key_password": "password", - "password": "password", - "default_wait_time": 10, - "default_time_limit": 20, - "default_json_db_name": "antares.db", - "slurm_script_path": "/path/to/slurm/launcher.sh", - "max_cores": 32, - "antares_versions_on_remote_server": ["860"], - "default_n_cpu": 15, - } - ) - assert config.nb_cores == NbCoresConfig(min=1, default=15, max=24) - - def test_from_dict__auto_detect(self): - with pytest.raises(NotImplementedError): - SlurmConfig.from_dict({"enable_nb_cores_detection": True}) - - -class TestLauncherConfig: - def test_init__default_values(self): - config = LauncherConfig() - assert config.default == "local", "default launcher should be local" - assert config.local is None - assert config.slurm is None - assert config.batch_size == 9999 - - def test_from_dict(self): - config = LauncherConfig.from_dict(LAUNCHER_CONFIG) - assert config.default == "slurm" - assert config.local == LocalConfig( - binaries={"860": Path("/bin/solver-860.exe")}, - enable_nb_cores_detection=False, - nb_cores=NbCoresConfig(min=2, default=10, max=20), - ) - assert config.slurm == SlurmConfig( - local_workspace=Path("/home/john/antares/workspace"), - username="john", - hostname="slurm-001", - port=22, - private_key_file=Path("/home/john/.ssh/id_rsa"), - key_password="password", - password="password", - default_wait_time=10, - default_time_limit=20, - default_json_db_name="antares.db", - slurm_script_path="/path/to/slurm/launcher.sh", - max_cores=32, - antares_versions_on_remote_server=["860"], - enable_nb_cores_detection=False, - nb_cores=NbCoresConfig(min=1, default=34, max=36), - ) - assert config.batch_size == 100 - - def test_init__invalid_launcher(self): - with pytest.raises(ValueError): - LauncherConfig(default="invalid_launcher") - - def test_get_nb_cores__default(self): - config = LauncherConfig.from_dict(LAUNCHER_CONFIG) - # default == "slurm" - assert config.get_nb_cores(launcher="default") == NbCoresConfig(min=1, default=34, max=36) - - def test_get_nb_cores__local(self): - config = LauncherConfig.from_dict(LAUNCHER_CONFIG) - assert config.get_nb_cores(launcher="local") == NbCoresConfig(min=2, default=10, max=20) - - def test_get_nb_cores__slurm(self): - config = LauncherConfig.from_dict(LAUNCHER_CONFIG) - assert config.get_nb_cores(launcher="slurm") == NbCoresConfig(min=1, default=34, max=36) - - def test_get_nb_cores__invalid_configuration(self): - config = LauncherConfig.from_dict(LAUNCHER_CONFIG) - with pytest.raises(InvalidConfigurationError): - config.get_nb_cores("invalid_launcher") - config = LauncherConfig.from_dict({}) - with pytest.raises(InvalidConfigurationError): - config.get_nb_cores("slurm") - - -class TestConfig: - @pytest.mark.parametrize("config_name", ["application-2.14.yaml", "application-2.15.yaml"]) - def test_from_yaml_file(self, config_name: str) -> None: - yaml_path = ASSETS_DIR.joinpath("config", config_name) - config = Config.from_yaml_file(yaml_path) - assert config.security.admin_pwd == "admin" - assert config.storage.workspaces["default"].path == Path("/home/john/antares_data/internal_studies") - assert not config.logging.json - assert config.logging.level == "INFO" From ca0cc4e603372f0b968c256ed4f9d2c0b2c1aa56 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Fri, 1 Dec 2023 14:07:35 +0100 Subject: [PATCH 10/43] fix(bc): correct the name of the binding constraint matrices We correct the matrix constants generator to avoid the following error: The name of a binding constraint matrix whose frequency is "weekly" must be "empty_2nd_member_weekly" and not "empty_2nd_member_daily" or "empty_2nd_member_hourly". --- .../storage/variantstudy/business/matrix_constants_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py index 8cb973785e..419c6779a4 100644 --- a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py +++ b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py @@ -37,7 +37,7 @@ # Binding constraint aliases BINDING_CONSTRAINT_HOURLY = "empty_2nd_member_hourly" BINDING_CONSTRAINT_DAILY = "empty_2nd_member_daily" -BINDING_CONSTRAINT_WEEKLY = "empty_2nd_member_daily" +BINDING_CONSTRAINT_WEEKLY = "empty_2nd_member_weekly" # Short-term storage aliases ST_STORAGE_PMAX_INJECTION = ONES_SCENARIO_MATRIX From a835a40acdc2b6218065b9efaff34a409046ad0e Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Thu, 7 Dec 2023 14:26:44 +0100 Subject: [PATCH 11/43] fix(bc): correct the shape of the binding constraint matrices --- .../bindingconstraints_ini.py | 19 ++++++++++++++++- .../bindingconstraints/bindingcontraints.py | 17 +++++++++------ .../model/filesystem/root/input/input.py | 5 +++++ .../binding_constraint/series.py | 15 +++++++------ .../business/matrix_constants_generator.py | 21 ++++++++++--------- .../command/create_binding_constraint.py | 11 +++++++--- .../test_matrix_constants_generator.py | 6 +++--- .../test_manage_binding_constraints.py | 21 +++++++++---------- 8 files changed, 75 insertions(+), 40 deletions(-) diff --git a/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingconstraints_ini.py b/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingconstraints_ini.py index 5e4059252a..51e426fda2 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingconstraints_ini.py +++ b/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingconstraints_ini.py @@ -3,6 +3,23 @@ from antarest.study.storage.rawstudy.model.filesystem.ini_file_node import IniFileNode +# noinspection SpellCheckingInspection class BindingConstraintsIni(IniFileNode): + """ + Handle the binding constraints configuration file: `/input/bindingconstraints/bindingconstraints.ini`. + + This files contains a list of sections numbered from 1 to n. + + Each section contains the following fields: + + - `name`: the name of the binding constraint. + - `id`: the id of the binding constraint (normalized name in lower case). + - `enabled`: whether the binding constraint is enabled or not. + - `type`: the frequency of the binding constraint ("hourly", "daily" or "weekly") + - `operator`: the operator of the binding constraint ("both", "equal", "greater", "less") + - `comment`: a comment + - and a list of coefficients (one per line) of the form `{area1}%{area2} = {coeff}`. + """ + def __init__(self, context: ContextServer, config: FileStudyTreeConfig): - IniFileNode.__init__(self, context, config, types={}) + super().__init__(context, config, types={}) diff --git a/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingcontraints.py b/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingcontraints.py index e86dedfe18..69fe669183 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingcontraints.py +++ b/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingcontraints.py @@ -7,18 +7,22 @@ BindingConstraintsIni, ) from antarest.study.storage.variantstudy.business.matrix_constants.binding_constraint.series import ( - default_binding_constraint_daily, - default_binding_constraint_hourly, - default_binding_constraint_weekly, + default_bc_hourly, + default_bc_weekly_daily, ) class BindingConstraints(FolderNode): + """ + Handle the binding constraints folder which contains the binding constraints + configuration and matrices. + """ + def build(self) -> TREE: default_matrices = { - BindingConstraintFrequency.HOURLY: default_binding_constraint_hourly, - BindingConstraintFrequency.DAILY: default_binding_constraint_daily, - BindingConstraintFrequency.WEEKLY: default_binding_constraint_weekly, + BindingConstraintFrequency.HOURLY: default_bc_hourly, + BindingConstraintFrequency.DAILY: default_bc_weekly_daily, + BindingConstraintFrequency.WEEKLY: default_bc_weekly_daily, } children: TREE = { binding.id: InputSeriesMatrix( @@ -31,6 +35,7 @@ def build(self) -> TREE: for binding in self.config.bindings } + # noinspection SpellCheckingInspection children["bindingconstraints"] = BindingConstraintsIni( self.context, self.config.next_file("bindingconstraints.ini") ) diff --git a/antarest/study/storage/rawstudy/model/filesystem/root/input/input.py b/antarest/study/storage/rawstudy/model/filesystem/root/input/input.py index 995dbf92f0..88b58c5369 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/root/input/input.py +++ b/antarest/study/storage/rawstudy/model/filesystem/root/input/input.py @@ -18,7 +18,12 @@ class Input(FolderNode): + """ + Handle the input folder which contains all the input data of the study. + """ + def build(self) -> TREE: + # noinspection SpellCheckingInspection children: TREE = { "areas": InputAreas(self.context, self.config.next_file("areas")), "bindingconstraints": BindingConstraints(self.context, self.config.next_file("bindingconstraints")), diff --git a/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/series.py b/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/series.py index e7b20a1137..f093c8e4a3 100644 --- a/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/series.py +++ b/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/series.py @@ -1,10 +1,13 @@ import numpy as np -default_binding_constraint_hourly = np.zeros((8760, 3), dtype=np.float64) -default_binding_constraint_hourly.flags.writeable = False +# Matrice shapes for binding constraints are different from usual shapes, +# because we need to take leap years into account, which contains 366 days and 8784 hours. +# Also, we use the same matrices for "weekly" and "daily" frequencies, +# because the solver calculates the weekly matrix from the daily matrix. +# See https://github.com/AntaresSimulatorTeam/AntaREST/issues/1843 -default_binding_constraint_daily = np.zeros((365, 3), dtype=np.float64) -default_binding_constraint_daily.flags.writeable = False +default_bc_hourly = np.zeros((8784, 3), dtype=np.float64) +default_bc_hourly.flags.writeable = False -default_binding_constraint_weekly = np.zeros((52, 3), dtype=np.float64) -default_binding_constraint_weekly.flags.writeable = False +default_bc_weekly_daily = np.zeros((366, 3), dtype=np.float64) +default_bc_weekly_daily.flags.writeable = False diff --git a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py index 419c6779a4..4048f03fda 100644 --- a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py +++ b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py @@ -36,8 +36,10 @@ # Binding constraint aliases BINDING_CONSTRAINT_HOURLY = "empty_2nd_member_hourly" -BINDING_CONSTRAINT_DAILY = "empty_2nd_member_daily" -BINDING_CONSTRAINT_WEEKLY = "empty_2nd_member_weekly" +"""2D-matrix of shape (8784, 3), filled-in with zeros for hourly binding constraints.""" + +BINDING_CONSTRAINT_WEEKLY_DAILY = "empty_2nd_member_weekly_daily" +"""2D-matrix of shape (366, 3), filled-in with zeros for weekly/daily binding constraints.""" # Short-term storage aliases ST_STORAGE_PMAX_INJECTION = ONES_SCENARIO_MATRIX @@ -90,9 +92,8 @@ def _init(self) -> None: # Binding constraint matrices series = matrix_constants.binding_constraint.series - self.hashes[BINDING_CONSTRAINT_HOURLY] = self.matrix_service.create(series.default_binding_constraint_hourly) - self.hashes[BINDING_CONSTRAINT_DAILY] = self.matrix_service.create(series.default_binding_constraint_daily) - self.hashes[BINDING_CONSTRAINT_WEEKLY] = self.matrix_service.create(series.default_binding_constraint_weekly) + self.hashes[BINDING_CONSTRAINT_HOURLY] = self.matrix_service.create(series.default_bc_hourly) + self.hashes[BINDING_CONSTRAINT_WEEKLY_DAILY] = self.matrix_service.create(series.default_bc_weekly_daily) # Some short-term storage matrices use np.ones((8760, 1)) self.hashes[ONES_SCENARIO_MATRIX] = self.matrix_service.create( @@ -152,16 +153,16 @@ def get_default_miscgen(self) -> str: return MATRIX_PROTOCOL_PREFIX + self.hashes[MISCGEN_TS] def get_binding_constraint_hourly(self) -> str: - """2D-matrix of shape (8760, 3), filled-in with zeros.""" + """2D-matrix of shape (8784, 3), filled-in with zeros.""" return MATRIX_PROTOCOL_PREFIX + self.hashes[BINDING_CONSTRAINT_HOURLY] def get_binding_constraint_daily(self) -> str: - """2D-matrix of shape (365, 3), filled-in with zeros.""" - return MATRIX_PROTOCOL_PREFIX + self.hashes[BINDING_CONSTRAINT_DAILY] + """2D-matrix of shape (366, 3), filled-in with zeros.""" + return MATRIX_PROTOCOL_PREFIX + self.hashes[BINDING_CONSTRAINT_WEEKLY_DAILY] def get_binding_constraint_weekly(self) -> str: - """2D-matrix of shape (52, 3), filled-in with zeros.""" - return MATRIX_PROTOCOL_PREFIX + self.hashes[BINDING_CONSTRAINT_WEEKLY] + """2D-matrix of shape (366, 3), filled-in with zeros, same as daily.""" + return MATRIX_PROTOCOL_PREFIX + self.hashes[BINDING_CONSTRAINT_WEEKLY_DAILY] def get_st_storage_pmax_injection(self) -> str: """2D-matrix of shape (8760, 1), filled-in with ones.""" diff --git a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py index 178c918a0c..ed3125f34b 100644 --- a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py @@ -40,10 +40,15 @@ def check_matrix_values(time_step: BindingConstraintFrequency, values: MatrixTyp If the matrix shape does not match the expected shape for the given time step. If the matrix values contain NaN (Not-a-Number). """ + # Matrice shapes for binding constraints are different from usual shapes, + # because we need to take leap years into account, which contains 366 days and 8784 hours. + # Also, we use the same matrices for "weekly" and "daily" frequencies, + # because the solver calculates the weekly matrix from the daily matrix. + # See https://github.com/AntaresSimulatorTeam/AntaREST/issues/1843 shapes = { - BindingConstraintFrequency.HOURLY: (8760, 3), - BindingConstraintFrequency.DAILY: (365, 3), - BindingConstraintFrequency.WEEKLY: (52, 3), + BindingConstraintFrequency.HOURLY: (8784, 3), + BindingConstraintFrequency.DAILY: (366, 3), + BindingConstraintFrequency.WEEKLY: (366, 3), } # Check the matrix values and create the corresponding matrix link array = np.array(values, dtype=np.float64) diff --git a/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py b/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py index 93a3262259..6b508425df 100644 --- a/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py +++ b/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py @@ -44,14 +44,14 @@ def test_get_binding_constraint(self, tmp_path): hourly = generator.get_binding_constraint_hourly() hourly_matrix_id = hourly.split(MATRIX_PROTOCOL_PREFIX)[1] hourly_matrix_dto = generator.matrix_service.get(hourly_matrix_id) - assert np.array(hourly_matrix_dto.data).all() == series.default_binding_constraint_hourly.all() + assert np.array(hourly_matrix_dto.data).all() == series.default_bc_hourly.all() daily = generator.get_binding_constraint_daily() daily_matrix_id = daily.split(MATRIX_PROTOCOL_PREFIX)[1] daily_matrix_dto = generator.matrix_service.get(daily_matrix_id) - assert np.array(daily_matrix_dto.data).all() == series.default_binding_constraint_daily.all() + assert np.array(daily_matrix_dto.data).all() == series.default_bc_weekly_daily.all() weekly = generator.get_binding_constraint_weekly() weekly_matrix_id = weekly.split(MATRIX_PROTOCOL_PREFIX)[1] weekly_matrix_dto = generator.matrix_service.get(weekly_matrix_id) - assert np.array(weekly_matrix_dto.data).all() == series.default_binding_constraint_weekly.all() + assert np.array(weekly_matrix_dto.data).all() == series.default_bc_weekly_daily.all() diff --git a/tests/variantstudy/model/command/test_manage_binding_constraints.py b/tests/variantstudy/model/command/test_manage_binding_constraints.py index a1309c2e47..3387db8e6d 100644 --- a/tests/variantstudy/model/command/test_manage_binding_constraints.py +++ b/tests/variantstudy/model/command/test_manage_binding_constraints.py @@ -8,9 +8,8 @@ from antarest.study.storage.variantstudy.business.command_extractor import CommandExtractor from antarest.study.storage.variantstudy.business.command_reverter import CommandReverter from antarest.study.storage.variantstudy.business.matrix_constants.binding_constraint.series import ( - default_binding_constraint_daily, - default_binding_constraint_hourly, - default_binding_constraint_weekly, + default_bc_hourly, + default_bc_weekly_daily, ) from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator from antarest.study.storage.variantstudy.model.command.create_area import CreateArea @@ -109,7 +108,7 @@ def test_manage_binding_constraint( "type": "daily", } - weekly_values = default_binding_constraint_weekly.tolist() + weekly_values = default_bc_weekly_daily.tolist() bind_update = UpdateBindingConstraint( id="bd 1", enabled=False, @@ -148,7 +147,7 @@ def test_manage_binding_constraint( def test_match(command_context: CommandContext): - values = default_binding_constraint_daily.tolist() + values = default_bc_weekly_daily.tolist() base = CreateBindingConstraint( name="foo", enabled=False, @@ -231,9 +230,9 @@ def test_match(command_context: CommandContext): def test_revert(command_context: CommandContext): - hourly_values = default_binding_constraint_hourly.tolist() - daily_values = default_binding_constraint_daily.tolist() - weekly_values = default_binding_constraint_weekly.tolist() + hourly_values = default_bc_hourly.tolist() + daily_values = default_bc_weekly_daily.tolist() + weekly_values = default_bc_weekly_daily.tolist() base = CreateBindingConstraint( name="foo", enabled=False, @@ -339,7 +338,7 @@ def test_revert(command_context: CommandContext): def test_create_diff(command_context: CommandContext): - values_a = np.random.rand(365, 3).tolist() + values_a = np.random.rand(366, 3).tolist() base = CreateBindingConstraint( name="foo", enabled=False, @@ -350,7 +349,7 @@ def test_create_diff(command_context: CommandContext): command_context=command_context, ) - values_b = np.random.rand(8760, 3).tolist() + values_b = np.random.rand(8784, 3).tolist() other_match = CreateBindingConstraint( name="foo", enabled=True, @@ -372,7 +371,7 @@ def test_create_diff(command_context: CommandContext): ) ] - values = default_binding_constraint_daily.tolist() + values = default_bc_weekly_daily.tolist() base = UpdateBindingConstraint( id="foo", enabled=False, From 521277f5c28d179042661479c6921f6824390bf0 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE <43534797+laurent-laporte-pro@users.noreply.github.com> Date: Mon, 11 Dec 2023 16:00:22 +0100 Subject: [PATCH 12/43] fix(ui-matrix): calculate the prepend index according to the existence of a time column (#1856) This pull request addresses the column number offset issue in tables without a time column on the left. Changes in `EditableMatrix`: - Modified the `prependIndex` variable to use an integer (0 or 1) instead of a boolean to indicate the offset between the column number being modified and the column index in the matrix. Change in `utils.ts`: - Added this parameter to the `cellChangesToMatrixEdits` function, which converts `CellChange` objects to `MatrixEditDTO` objects, to account for the column offset. This fix resolves the problem for the "Overall Monthly Hydro" and "Daily power" tables, which are not time series. No regressions have been observed in the other time series tables. --- .../src/components/common/EditableMatrix/index.tsx | 2 +- .../src/components/common/EditableMatrix/utils.ts | 14 ++++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/webapp/src/components/common/EditableMatrix/index.tsx b/webapp/src/components/common/EditableMatrix/index.tsx index 8050c9bd39..8176cc0f5c 100644 --- a/webapp/src/components/common/EditableMatrix/index.tsx +++ b/webapp/src/components/common/EditableMatrix/index.tsx @@ -83,7 +83,7 @@ function EditableMatrix(props: PropTypes) { ); if (filteredChanges.length > 0) { - const edits = cellChangesToMatrixEdits(filteredChanges); + const edits = cellChangesToMatrixEdits(filteredChanges, matrixTime); onUpdate(edits, source); } }; diff --git a/webapp/src/components/common/EditableMatrix/utils.ts b/webapp/src/components/common/EditableMatrix/utils.ts index 13fa17e6db..341fb8d2a3 100644 --- a/webapp/src/components/common/EditableMatrix/utils.ts +++ b/webapp/src/components/common/EditableMatrix/utils.ts @@ -84,11 +84,17 @@ export const createDateFromIndex = ( export const cellChangesToMatrixEdits = ( cellChanges: CellChange[], + matrixTime: boolean, ): MatrixEditDTO[] => - cellChanges.map(([row, column, , value]) => ({ - coordinates: [[row, (column as number) - 1]], - operation: { operation: Operator.EQ, value: parseFloat(value) }, - })); + cellChanges.map(([row, column, , value]) => { + const rowIndex = parseFloat(row.toString()); + const colIndex = parseFloat(column.toString()) - (matrixTime ? 1 : 0); + + return { + coordinates: [[rowIndex, colIndex]], + operation: { operation: Operator.EQ, value: parseFloat(value) }, + }; + }); export const computeStats = ( statsType: string, From ccd4475d773dffd3695ec9bd2ee3e39bd085972a Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE <43534797+laurent-laporte-pro@users.noreply.github.com> Date: Mon, 11 Dec 2023 17:07:44 +0100 Subject: [PATCH 13/43] fix(ui-output): add the missing "ST Storages" option in the Display selector in results view (#1855) Added the 'ST Storage' option to the dropdown list for selecting the view to display simulation results related to short-term storage. --- .../App/Singlestudy/explore/Results/ResultDetails/index.tsx | 1 + .../App/Singlestudy/explore/Results/ResultDetails/utils.ts | 1 + 2 files changed, 2 insertions(+) diff --git a/webapp/src/components/App/Singlestudy/explore/Results/ResultDetails/index.tsx b/webapp/src/components/App/Singlestudy/explore/Results/ResultDetails/index.tsx index 7875ff7246..1f6b4a9bdc 100644 --- a/webapp/src/components/App/Singlestudy/explore/Results/ResultDetails/index.tsx +++ b/webapp/src/components/App/Singlestudy/explore/Results/ResultDetails/index.tsx @@ -296,6 +296,7 @@ function ResultDetails() { { value: DataType.Thermal, label: "Thermal plants" }, { value: DataType.Renewable, label: "Ren. clusters" }, { value: DataType.Record, label: "RecordYears" }, + { value: DataType.STStorage, label: "ST Storages" }, ]} size="small" variant="outlined" diff --git a/webapp/src/components/App/Singlestudy/explore/Results/ResultDetails/utils.ts b/webapp/src/components/App/Singlestudy/explore/Results/ResultDetails/utils.ts index f140205498..cba62478c8 100644 --- a/webapp/src/components/App/Singlestudy/explore/Results/ResultDetails/utils.ts +++ b/webapp/src/components/App/Singlestudy/explore/Results/ResultDetails/utils.ts @@ -11,6 +11,7 @@ export enum DataType { Thermal = "details", Renewable = "details-res", Record = "id", + STStorage = "details-STstorage", } export enum Timestep { From 1fd438e790fb394bd9de4998ebe5d6f2f3a41aee Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Mon, 20 Nov 2023 18:41:52 +0100 Subject: [PATCH 14/43] feat(db-init): separate database initialization from global database session (#1805) --- antarest/core/tasks/model.py | 19 +- antarest/login/main.py | 2 +- antarest/login/repository.py | 269 ++++++++++++------ antarest/main.py | 22 +- antarest/matrixstore/repository.py | 62 ++-- antarest/matrixstore/service.py | 22 +- antarest/singleton_services.py | 130 ++++----- antarest/study/main.py | 1 + .../business/command_extractor.py | 3 + .../business/matrix_constants_generator.py | 89 +++--- .../variantstudy/variant_command_extractor.py | 1 + antarest/tools/lib.py | 34 ++- antarest/utils.py | 51 ++-- tests/conftest_services.py | 10 +- tests/login/test_repository.py | 85 ++---- tests/matrixstore/test_repository.py | 23 +- .../storage/business/test_arealink_manager.py | 10 +- tests/storage/integration/conftest.py | 6 +- .../test_matrix_constants_generator.py | 19 +- .../test_variant_study_service.py | 1 + tests/variantstudy/conftest.py | 8 +- 21 files changed, 497 insertions(+), 370 deletions(-) diff --git a/antarest/core/tasks/model.py b/antarest/core/tasks/model.py index af3a46b8f7..1206db9fc4 100644 --- a/antarest/core/tasks/model.py +++ b/antarest/core/tasks/model.py @@ -1,11 +1,12 @@ import uuid from datetime import datetime from enum import Enum -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, Sequence, String # type: ignore -from sqlalchemy.orm import relationship # type: ignore +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.orm import Session, relationship, sessionmaker # type: ignore from antarest.core.persistence import Base @@ -171,3 +172,17 @@ def __repr__(self) -> str: f" result_msg={self.result_msg}," f" result_status={self.result_status}" ) + + +def cancel_orphan_tasks(engine: Engine, session_args: Dict[str, bool]) -> None: + updated_values = { + TaskJob.status: TaskStatus.FAILED.value, + TaskJob.result: False, + TaskJob.result_msg: "Task was interrupted due to server restart", + TaskJob.completion_date: datetime.utcnow(), + } + with sessionmaker(bind=engine, **session_args)() as session: + session.query(TaskJob).filter(TaskJob.status.in_([TaskStatus.RUNNING.value, TaskStatus.PENDING.value])).update( + updated_values, synchronize_session=False + ) + session.commit() diff --git a/antarest/login/main.py b/antarest/login/main.py index 9b487de5b7..d87a082abd 100644 --- a/antarest/login/main.py +++ b/antarest/login/main.py @@ -37,7 +37,7 @@ def build_login( """ if service is None: - user_repo = UserRepository(config) + user_repo = UserRepository() bot_repo = BotRepository() group_repo = GroupRepository() role_repo = RoleRepository() diff --git a/antarest/login/repository.py b/antarest/login/repository.py index 4f68e1924c..edac68d495 100644 --- a/antarest/login/repository.py +++ b/antarest/login/repository.py @@ -1,10 +1,10 @@ import logging -from typing import List, Optional +from typing import Dict, List, Optional from sqlalchemy import exists # type: ignore -from sqlalchemy.orm import joinedload # type: ignore +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.orm import joinedload, Session, sessionmaker # type: ignore -from antarest.core.config import Config from antarest.core.jwt import ADMIN_ID from antarest.core.roles import RoleType from antarest.core.utils.fastapi_sqlalchemy import db @@ -12,43 +12,99 @@ logger = logging.getLogger(__name__) +DB_INIT_DEFAULT_GROUP_ID = "admin" +DB_INIT_DEFAULT_GROUP_NAME = "admin" + +DB_INIT_DEFAULT_USER_ID = ADMIN_ID +DB_INIT_DEFAULT_USER_NAME = "admin" + +DB_INIT_DEFAULT_ROLE_ID = ADMIN_ID +DB_INIT_DEFAULT_ROLE_GROUP_ID = "admin" + + +def init_admin_user(engine: Engine, session_args: Dict[str, bool], admin_password: str) -> None: + with sessionmaker(bind=engine, **session_args)() as session: + group = Group( + id=DB_INIT_DEFAULT_GROUP_ID, + name=DB_INIT_DEFAULT_GROUP_NAME, + ) + user = User( + id=DB_INIT_DEFAULT_USER_ID, + name=DB_INIT_DEFAULT_USER_NAME, + password=Password(admin_password), + ) + role = Role( + type=RoleType.ADMIN, + identity=User(id=DB_INIT_DEFAULT_USER_ID), + group=Group( + id=DB_INIT_DEFAULT_GROUP_ID, + ), + ) + + existing_group = session.query(Group).get(group.id) + if not existing_group: + session.add(group) + session.commit() + + existing_user = session.query(User).get(user.id) + if not existing_user: + session.add(user) + session.commit() + + existing_role = session.query(Role).get((DB_INIT_DEFAULT_USER_ID, DB_INIT_DEFAULT_GROUP_ID)) + if not existing_role: + role.group = session.merge(role.group) + role.identity = session.merge(role.identity) + session.add(role) + + session.commit() + class GroupRepository: """ Database connector to manage Group entity. """ - def __init__(self) -> None: - with db(): - self.save(Group(id="admin", name="admin")) + def __init__( + self, + session: Optional[Session] = None, + ) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session def save(self, group: Group) -> Group: - res = db.session.query(exists().where(Group.id == group.id)).scalar() + res = self.session.query(exists().where(Group.id == group.id)).scalar() if res: - db.session.merge(group) + self.session.merge(group) else: - db.session.add(group) - db.session.commit() + self.session.add(group) + self.session.commit() logger.debug(f"Group {group.id} saved") return group def get(self, id: str) -> Optional[Group]: - group: Group = db.session.query(Group).get(id) + group: Group = self.session.query(Group).get(id) return group def get_by_name(self, name: str) -> Group: - group: Group = db.session.query(Group).filter_by(name=name).first() + group: Group = self.session.query(Group).filter_by(name=name).first() return group def get_all(self) -> List[Group]: - groups: List[Group] = db.session.query(Group).all() + groups: List[Group] = self.session.query(Group).all() return groups def delete(self, id: str) -> None: - g = db.session.query(Group).get(id) - db.session.delete(g) - db.session.commit() + g = self.session.query(Group).get(id) + self.session.delete(g) + self.session.commit() logger.debug(f"Group {id} deleted") @@ -58,35 +114,32 @@ class UserRepository: Database connector to manage User entity. """ - def __init__(self, config: Config) -> None: - # init seed admin user from conf - with db(): - admin_user = self.get_by_name("admin") - if admin_user is None: - self.save( - User( - id=ADMIN_ID, - name="admin", - password=Password(config.security.admin_pwd), - ) - ) - elif not admin_user.password.check(config.security.admin_pwd): # type: ignore - admin_user.password = Password(config.security.admin_pwd) # type: ignore - self.save(admin_user) + def __init__( + self, + session: Optional[Session] = None, + ) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session def save(self, user: User) -> User: - res = db.session.query(exists().where(User.id == user.id)).scalar() + res = self.session.query(exists().where(User.id == user.id)).scalar() if res: - db.session.merge(user) + self.session.merge(user) else: - db.session.add(user) - db.session.commit() + self.session.add(user) + self.session.commit() logger.debug(f"User {user.id} saved") return user - def get(self, id: int) -> Optional[User]: - user: User = db.session.query(User).get(id) + def get(self, id_number: int) -> Optional[User]: + user: User = self.session.query(User).get(id_number) return user def get_by_name(self, name: str) -> Optional[User]: @@ -94,13 +147,13 @@ def get_by_name(self, name: str) -> Optional[User]: return user def get_all(self) -> List[User]: - users: List[User] = db.session.query(User).all() + users: List[User] = self.session.query(User).all() return users def delete(self, id: int) -> None: - u: User = db.session.query(User).get(id) - db.session.delete(u) - db.session.commit() + u: User = self.session.query(User).get(id) + self.session.delete(u) + self.session.commit() logger.debug(f"User {id} deleted") @@ -110,39 +163,54 @@ class UserLdapRepository: Database connector to manage UserLdap entity. """ + def __init__( + self, + session: Optional[Session] = None, + ) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session + def save(self, user_ldap: UserLdap) -> UserLdap: - res = db.session.query(exists().where(UserLdap.id == user_ldap.id)).scalar() + res = self.session.query(exists().where(UserLdap.id == user_ldap.id)).scalar() if res: - db.session.merge(user_ldap) + self.session.merge(user_ldap) else: - db.session.add(user_ldap) - db.session.commit() + self.session.add(user_ldap) + self.session.commit() logger.debug(f"User LDAP {user_ldap.id} saved") return user_ldap - def get(self, id: int) -> Optional[UserLdap]: - user_ldap: Optional[UserLdap] = db.session.query(UserLdap).get(id) + def get(self, id_number: int) -> Optional[UserLdap]: + user_ldap: Optional[UserLdap] = self.session.query(UserLdap).get(id_number) return user_ldap def get_by_name(self, name: str) -> Optional[UserLdap]: - user: UserLdap = db.session.query(UserLdap).filter_by(name=name).first() + user: UserLdap = self.session.query(UserLdap).filter_by(name=name).first() return user def get_by_external_id(self, external_id: str) -> Optional[UserLdap]: - user: UserLdap = db.session.query(UserLdap).filter_by(external_id=external_id).first() + user: UserLdap = self.session.query(UserLdap).filter_by(external_id=external_id).first() return user - def get_all(self) -> List[UserLdap]: - users_ldap: List[UserLdap] = db.session.query(UserLdap).all() + def get_all( + self, + ) -> List[UserLdap]: + users_ldap: List[UserLdap] = self.session.query(UserLdap).all() return users_ldap - def delete(self, id: int) -> None: - u: UserLdap = db.session.query(UserLdap).get(id) - db.session.delete(u) - db.session.commit() + def delete(self, id_number: int) -> None: + u: UserLdap = self.session.query(UserLdap).get(id_number) + self.session.delete(u) + self.session.commit() - logger.debug(f"User LDAP {id} deleted") + logger.debug(f"User LDAP {id_number} deleted") class BotRepository: @@ -150,42 +218,57 @@ class BotRepository: Database connector to manage Bot entity. """ + def __init__( + self, + session: Optional[Session] = None, + ) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session + def save(self, bot: Bot) -> Bot: - res = db.session.query(exists().where(Bot.id == bot.id)).scalar() + res = self.session.query(exists().where(Bot.id == bot.id)).scalar() if res: raise ValueError("Bot already exist") else: - db.session.add(bot) - db.session.commit() + self.session.add(bot) + self.session.commit() logger.debug(f"Bot {bot.id} saved") return bot - def get(self, id: int) -> Optional[Bot]: - bot: Bot = db.session.query(Bot).get(id) + def get(self, id_number: int) -> Optional[Bot]: + bot: Bot = self.session.query(Bot).get(id_number) return bot - def get_all(self) -> List[Bot]: - bots: List[Bot] = db.session.query(Bot).all() + def get_all( + self, + ) -> List[Bot]: + bots: List[Bot] = self.session.query(Bot).all() return bots - def delete(self, id: int) -> None: - u: Bot = db.session.query(Bot).get(id) - db.session.delete(u) - db.session.commit() + def delete(self, id_number: int) -> None: + u: Bot = self.session.query(Bot).get(id_number) + self.session.delete(u) + self.session.commit() - logger.debug(f"Bot {id} deleted") + logger.debug(f"Bot {id_number} deleted") def get_all_by_owner(self, owner: int) -> List[Bot]: - bots: List[Bot] = db.session.query(Bot).filter_by(owner=owner).all() + bots: List[Bot] = self.session.query(Bot).filter_by(owner=owner).all() return bots def get_by_name_and_owner(self, owner: int, name: str) -> Optional[Bot]: - bot: Bot = db.session.query(Bot).filter_by(owner=owner, name=name).first() + bot: Bot = self.session.query(Bot).filter_by(owner=owner, name=name).first() return bot - def exists(self, id: int) -> bool: - res: bool = db.session.query(exists().where(Bot.id == id)).scalar() + def exists(self, id_number: int) -> bool: + res: bool = self.session.query(exists().where(Bot.id == id_number)).scalar() return res @@ -194,29 +277,31 @@ class RoleRepository: Database connector to manage Role entity. """ - def __init__(self) -> None: - with db(): - if self.get(1, "admin") is None: - self.save( - Role( - type=RoleType.ADMIN, - identity=User(id=1), - group=Group(id="admin"), - ) - ) + def __init__( + self, + session: Optional[Session] = None, + ) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session def save(self, role: Role) -> Role: - role.group = db.session.merge(role.group) - role.identity = db.session.merge(role.identity) + role.group = self.session.merge(role.group) + role.identity = self.session.merge(role.identity) - db.session.add(role) - db.session.commit() + self.session.add(role) + self.session.commit() logger.debug(f"Role (user={role.identity}, group={role.group} saved") return role def get(self, user: int, group: str) -> Optional[Role]: - role: Role = db.session.query(Role).get((user, group)) + role: Role = self.session.query(Role).get((user, group)) return role def get_all_by_user(self, /, user_id: int) -> List[Role]: @@ -231,17 +316,17 @@ def get_all_by_user(self, /, user_id: int) -> List[Role]: """ # When we fetch the list of roles, we also need to fetch the associated groups. # We use a SQL query with joins to fetch all these data efficiently. - stm = db.session.query(Role).options(joinedload(Role.group)).filter_by(identity_id=user_id) + stm = self.session.query(Role).options(joinedload(Role.group)).filter_by(identity_id=user_id) roles: List[Role] = stm.all() return roles def get_all_by_group(self, group: str) -> List[Role]: - roles: List[Role] = db.session.query(Role).filter_by(group_id=group).all() + roles: List[Role] = self.session.query(Role).filter_by(group_id=group).all() return roles def delete(self, user: int, group: str) -> None: - r = db.session.query(Role).get((user, group)) - db.session.delete(r) - db.session.commit() + r = self.session.query(Role).get((user, group)) + self.session.delete(r) + self.session.commit() logger.debug(f"Role (user={user}, group={group} deleted") diff --git a/antarest/main.py b/antarest/main.py index 1e0c9183dd..5e1c1ec850 100644 --- a/antarest/main.py +++ b/antarest/main.py @@ -30,15 +30,18 @@ from antarest.core.logging.utils import LoggingMiddleware, configure_logger from antarest.core.requests import RATE_LIMIT_CONFIG from antarest.core.swagger import customize_openapi +from antarest.core.tasks.model import cancel_orphan_tasks +from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.core.utils.utils import get_local_path from antarest.core.utils.web import tags_metadata from antarest.login.auth import Auth, JwtSettings +from antarest.login.repository import init_admin_user from antarest.matrixstore.matrix_garbage_collector import MatrixGarbageCollector -from antarest.singleton_services import SingletonServices +from antarest.singleton_services import start_all_services from antarest.study.storage.auto_archive_service import AutoArchiveService from antarest.study.storage.rawstudy.watcher import Watcher from antarest.tools.admin_lib import clean_locks -from antarest.utils import Module, create_services, init_db +from antarest.utils import SESSION_ARGS, Module, create_services, init_db_engine logger = logging.getLogger(__name__) @@ -246,7 +249,12 @@ def fastapi_app( ) # Database - init_db(config_file, config, auto_upgrade_db, application) + engine = init_db_engine(config_file, config, auto_upgrade_db) + application.add_middleware( + DBSessionMiddleware, + custom_engine=engine, + session_args=dict(SESSION_ARGS), + ) application.add_middleware(LoggingMiddleware) @@ -401,6 +409,7 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: config=RATE_LIMIT_CONFIG, ) + init_admin_user(engine=engine, session_args=dict(SESSION_ARGS), admin_password=config.security.admin_pwd) services = create_services(config, application) if mount_front: @@ -428,6 +437,10 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: auto_archiver.start() customize_openapi(application) + cancel_orphan_tasks( + engine=engine, + session_args=dict(SESSION_ARGS), + ) return application, services @@ -455,8 +468,7 @@ def main() -> None: # noinspection PyTypeChecker uvicorn.run(app, host="0.0.0.0", port=8080, log_config=LOGGING_CONFIG) else: - services = SingletonServices(arguments.config_file, [arguments.module]) - services.start() + start_all_services(arguments.config_file, [arguments.module]) if __name__ == "__main__": diff --git a/antarest/matrixstore/repository.py b/antarest/matrixstore/repository.py index 6301e39c7f..9ab44a69ec 100644 --- a/antarest/matrixstore/repository.py +++ b/antarest/matrixstore/repository.py @@ -7,7 +7,7 @@ from filelock import FileLock from numpy import typing as npt from sqlalchemy import and_, exists # type: ignore -from sqlalchemy.orm import aliased # type: ignore +from sqlalchemy.orm import Session, aliased # type: ignore from antarest.core.utils.fastapi_sqlalchemy import db from antarest.matrixstore.model import Matrix, MatrixContent, MatrixData, MatrixDataSet @@ -20,23 +20,33 @@ class MatrixDataSetRepository: Database connector to manage Matrix metadata entity """ + def __init__(self, session: t.Optional[Session] = None) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session + def save(self, matrix_user_metadata: MatrixDataSet) -> MatrixDataSet: - res: bool = db.session.query(exists().where(MatrixDataSet.id == matrix_user_metadata.id)).scalar() + res: bool = self.session.query(exists().where(MatrixDataSet.id == matrix_user_metadata.id)).scalar() if res: - matrix_user_metadata = db.session.merge(matrix_user_metadata) + matrix_user_metadata = self.session.merge(matrix_user_metadata) else: - db.session.add(matrix_user_metadata) - db.session.commit() + self.session.add(matrix_user_metadata) + self.session.commit() logger.debug(f"Matrix dataset {matrix_user_metadata.id} for user {matrix_user_metadata.owner_id} saved") return matrix_user_metadata - def get(self, id: str) -> t.Optional[MatrixDataSet]: - matrix: MatrixDataSet = db.session.query(MatrixDataSet).get(id) + def get(self, id_number: str) -> t.Optional[MatrixDataSet]: + matrix: MatrixDataSet = self.session.query(MatrixDataSet).get(id_number) return matrix def get_all_datasets(self) -> t.List[MatrixDataSet]: - matrix_datasets: t.List[MatrixDataSet] = db.session.query(MatrixDataSet).all() + matrix_datasets: t.List[MatrixDataSet] = self.session.query(MatrixDataSet).all() return matrix_datasets def query( @@ -54,7 +64,7 @@ def query( Returns: the list of metadata per user, matching the query """ - query = db.session.query(MatrixDataSet) + query = self.session.query(MatrixDataSet) if name is not None: query = query.filter(MatrixDataSet.name.ilike(f"%{name}%")) # type: ignore if owner is not None: @@ -63,9 +73,9 @@ def query( return datasets def delete(self, dataset_id: str) -> None: - dataset = db.session.query(MatrixDataSet).get(dataset_id) - db.session.delete(dataset) - db.session.commit() + dataset = self.session.query(MatrixDataSet).get(dataset_id) + self.session.delete(dataset) + self.session.commit() class MatrixRepository: @@ -73,28 +83,38 @@ class MatrixRepository: Database connector to manage Matrix entity. """ + def __init__(self, session: t.Optional[Session] = None) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session + def save(self, matrix: Matrix) -> Matrix: - if db.session.query(exists().where(Matrix.id == matrix.id)).scalar(): - db.session.merge(matrix) + if self.session.query(exists().where(Matrix.id == matrix.id)).scalar(): + self.session.merge(matrix) else: - db.session.add(matrix) - db.session.commit() + self.session.add(matrix) + self.session.commit() logger.debug(f"Matrix {matrix.id} saved") return matrix def get(self, matrix_hash: str) -> t.Optional[Matrix]: - matrix: Matrix = db.session.query(Matrix).get(matrix_hash) + matrix: Matrix = self.session.query(Matrix).get(matrix_hash) return matrix def exists(self, matrix_hash: str) -> bool: - res: bool = db.session.query(exists().where(Matrix.id == matrix_hash)).scalar() + res: bool = self.session.query(exists().where(Matrix.id == matrix_hash)).scalar() return res def delete(self, matrix_hash: str) -> None: - if g := db.session.query(Matrix).get(matrix_hash): - db.session.delete(g) - db.session.commit() + if g := self.session.query(Matrix).get(matrix_hash): + self.session.delete(g) + self.session.commit() else: logger.warning(f"Trying to delete matrix {matrix_hash}, but was not found in database!") logger.debug(f"Matrix {matrix_hash} deleted") diff --git a/antarest/matrixstore/service.py b/antarest/matrixstore/service.py index 4869ed11fa..c7030160ad 100644 --- a/antarest/matrixstore/service.py +++ b/antarest/matrixstore/service.py @@ -54,6 +54,13 @@ class ISimpleMatrixService(ABC): + def __init__(self, matrix_content_repository: MatrixContentRepository) -> None: + self.matrix_content_repository = matrix_content_repository + + @property + def bucket_dir(self) -> Path: + return self.matrix_content_repository.bucket_dir + @abstractmethod def create(self, data: Union[List[List[MatrixData]], npt.NDArray[np.float64]]) -> str: raise NotImplementedError() @@ -72,15 +79,14 @@ def delete(self, matrix_id: str) -> None: class SimpleMatrixService(ISimpleMatrixService): - def __init__(self, bucket_dir: Path): - self.bucket_dir = bucket_dir - self.content_repo = MatrixContentRepository(bucket_dir) + def __init__(self, matrix_content_repository: MatrixContentRepository): + super().__init__(matrix_content_repository=matrix_content_repository) def create(self, data: Union[List[List[MatrixData]], npt.NDArray[np.float64]]) -> str: - return self.content_repo.save(data) + return self.matrix_content_repository.save(data) def get(self, matrix_id: str) -> MatrixDTO: - data = self.content_repo.get(matrix_id) + data = self.matrix_content_repository.get(matrix_id) return MatrixDTO.construct( id=matrix_id, width=len(data.columns), @@ -91,10 +97,10 @@ def get(self, matrix_id: str) -> MatrixDTO: ) def exists(self, matrix_id: str) -> bool: - return self.content_repo.exists(matrix_id) + return self.matrix_content_repository.exists(matrix_id) def delete(self, matrix_id: str) -> None: - self.content_repo.delete(matrix_id) + self.matrix_content_repository.delete(matrix_id) class MatrixService(ISimpleMatrixService): @@ -108,9 +114,9 @@ def __init__( config: Config, user_service: LoginService, ): + super().__init__(matrix_content_repository=matrix_content_repository) self.repo = repo self.repo_dataset = repo_dataset - self.matrix_content_repository = matrix_content_repository self.user_service = user_service self.file_transfer_manager = file_transfer_manager self.task_service = task_service diff --git a/antarest/singleton_services.py b/antarest/singleton_services.py index 9b702a346b..70a791002d 100644 --- a/antarest/singleton_services.py +++ b/antarest/singleton_services.py @@ -1,90 +1,76 @@ -import logging -import time from pathlib import Path from typing import Dict, List from antarest.core.config import Config from antarest.core.interfaces.service import IService from antarest.core.logging.utils import configure_logger +from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.core.utils.utils import get_local_path from antarest.study.storage.auto_archive_service import AutoArchiveService from antarest.utils import ( + SESSION_ARGS, Module, create_archive_worker, create_core_services, create_matrix_gc, create_simulator_worker, create_watcher, - init_db, + init_db_engine, ) -logger = logging.getLogger(__name__) - -class SingletonServices: - def __init__(self, config_file: Path, services_list: List[Module]) -> None: - self.services_list = self._init(config_file, services_list) - - @staticmethod - def _init(config_file: Path, services_list: List[Module]) -> Dict[Module, IService]: - res = get_local_path() / "resources" - config = Config.from_yaml_file(res=res, file=config_file) - init_db(config_file, config, False, None) - configure_logger(config) - - ( - cache, - event_bus, - task_service, - ft_manager, - login_service, - matrix_service, - study_service, - ) = create_core_services(None, config) - - services: Dict[Module, IService] = {} - - if Module.WATCHER in services_list: - watcher = create_watcher(config=config, application=None, study_service=study_service) - services[Module.WATCHER] = watcher - - if Module.MATRIX_GC in services_list: - matrix_gc = create_matrix_gc( - config=config, - application=None, - study_service=study_service, - matrix_service=matrix_service, - ) - services[Module.MATRIX_GC] = matrix_gc - - if Module.ARCHIVE_WORKER in services_list: - worker = create_archive_worker(config, "test", event_bus=event_bus) - services[Module.ARCHIVE_WORKER] = worker - - if Module.SIMULATOR_WORKER in services_list: - worker = create_simulator_worker(config, matrix_service=matrix_service, event_bus=event_bus) - services[Module.SIMULATOR_WORKER] = worker - - if Module.AUTO_ARCHIVER in services_list: - auto_archive_service = AutoArchiveService(study_service, config) - services[Module.AUTO_ARCHIVER] = auto_archive_service - - return services - - def start(self) -> None: - for service in self.services_list: - self.services_list[service].start(threaded=True) - - self._loop() - - def _loop(self) -> None: - while True: - try: - pass - except Exception as e: - logger.error( - "Unexpected error happened while processing service manager loop", - exc_info=e, - ) - finally: - time.sleep(2) +def _init(config_file: Path, services_list: List[Module]) -> Dict[Module, IService]: + res = get_local_path() / "resources" + config = Config.from_yaml_file(res=res, file=config_file) + engine = init_db_engine( + config_file, + config, + False, + ) + DBSessionMiddleware(None, custom_engine=engine, session_args=dict(SESSION_ARGS)) + configure_logger(config) + + ( + cache, + event_bus, + task_service, + ft_manager, + login_service, + matrix_service, + study_service, + ) = create_core_services(None, config) + + services: Dict[Module, IService] = {} + + if Module.WATCHER in services_list: + watcher = create_watcher(config=config, application=None, study_service=study_service) + services[Module.WATCHER] = watcher + + if Module.MATRIX_GC in services_list: + matrix_gc = create_matrix_gc( + config=config, + application=None, + study_service=study_service, + matrix_service=matrix_service, + ) + services[Module.MATRIX_GC] = matrix_gc + + if Module.ARCHIVE_WORKER in services_list: + worker = create_archive_worker(config, "test", event_bus=event_bus) + services[Module.ARCHIVE_WORKER] = worker + + if Module.SIMULATOR_WORKER in services_list: + worker = create_simulator_worker(config, matrix_service=matrix_service, event_bus=event_bus) + services[Module.SIMULATOR_WORKER] = worker + + if Module.AUTO_ARCHIVER in services_list: + auto_archive_service = AutoArchiveService(study_service, config) + services[Module.AUTO_ARCHIVER] = auto_archive_service + + return services + + +def start_all_services(config_file: Path, services_list: List[Module]) -> None: + services = _init(config_file, services_list) + for service in services: + services[service].start(threaded=True) diff --git a/antarest/study/main.py b/antarest/study/main.py index e4a981afd2..c3b48356af 100644 --- a/antarest/study/main.py +++ b/antarest/study/main.py @@ -81,6 +81,7 @@ def build_study_service( ) generator_matrix_constants = generator_matrix_constants or GeneratorMatrixConstants(matrix_service=matrix_service) + generator_matrix_constants.init_constant_matrices(bucket_dir=generator_matrix_constants.matrix_service.bucket_dir) command_factory = CommandFactory( generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, diff --git a/antarest/study/storage/variantstudy/business/command_extractor.py b/antarest/study/storage/variantstudy/business/command_extractor.py index e0fd1d1e3c..4d3c563799 100644 --- a/antarest/study/storage/variantstudy/business/command_extractor.py +++ b/antarest/study/storage/variantstudy/business/command_extractor.py @@ -48,6 +48,9 @@ class CommandExtractor(ICommandExtractor): def __init__(self, matrix_service: ISimpleMatrixService, patch_service: PatchService): self.matrix_service = matrix_service self.generator_matrix_constants = GeneratorMatrixConstants(self.matrix_service) + self.generator_matrix_constants.init_constant_matrices( + bucket_dir=self.generator_matrix_constants.matrix_service.bucket_dir + ) self.patch_service = patch_service self.command_context = CommandContext( generator_matrix_constants=self.generator_matrix_constants, diff --git a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py index 4048f03fda..cbd76bc9f6 100644 --- a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py +++ b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py @@ -49,6 +49,7 @@ ST_STORAGE_INFLOWS = EMPTY_SCENARIO_MATRIX MATRIX_PROTOCOL_PREFIX = "matrix://" +MATRIX_CONSTANT_INIT_LOCK_FILE_NAME = "matrix_constant_init.lock" # noinspection SpellCheckingInspection @@ -56,49 +57,51 @@ class GeneratorMatrixConstants: def __init__(self, matrix_service: ISimpleMatrixService) -> None: self.hashes: Dict[str, str] = {} self.matrix_service: ISimpleMatrixService = matrix_service - with FileLock(str(Path(tempfile.gettempdir()) / "matrix_constant_init.lock")): - self._init() - - def _init(self) -> None: - self.hashes[HYDRO_COMMON_CAPACITY_MAX_POWER_V7] = self.matrix_service.create( - matrix_constants.hydro.v7.max_power - ) - self.hashes[HYDRO_COMMON_CAPACITY_RESERVOIR_V7] = self.matrix_service.create( - matrix_constants.hydro.v7.reservoir - ) - self.hashes[HYDRO_COMMON_CAPACITY_RESERVOIR_V6] = self.matrix_service.create( - matrix_constants.hydro.v6.reservoir - ) - self.hashes[HYDRO_COMMON_CAPACITY_INFLOW_PATTERN] = self.matrix_service.create( - matrix_constants.hydro.v7.inflow_pattern - ) - self.hashes[HYDRO_COMMON_CAPACITY_CREDIT_MODULATION] = self.matrix_service.create( - matrix_constants.hydro.v7.credit_modulations - ) - self.hashes[PREPRO_CONVERSION] = self.matrix_service.create(matrix_constants.prepro.conversion) - self.hashes[PREPRO_DATA] = self.matrix_service.create(matrix_constants.prepro.data) - self.hashes[THERMAL_PREPRO_DATA] = self.matrix_service.create(matrix_constants.thermals.prepro.data) - - self.hashes[THERMAL_PREPRO_MODULATION] = self.matrix_service.create(matrix_constants.thermals.prepro.modulation) - self.hashes[LINK_V7] = self.matrix_service.create(matrix_constants.link.v7.link) - self.hashes[LINK_V8] = self.matrix_service.create(matrix_constants.link.v8.link) - self.hashes[LINK_DIRECT] = self.matrix_service.create(matrix_constants.link.v8.direct) - self.hashes[LINK_INDIRECT] = self.matrix_service.create(matrix_constants.link.v8.indirect) - - self.hashes[NULL_MATRIX_NAME] = self.matrix_service.create(NULL_MATRIX) - self.hashes[EMPTY_SCENARIO_MATRIX] = self.matrix_service.create(NULL_SCENARIO_MATRIX) - self.hashes[RESERVES_TS] = self.matrix_service.create(FIXED_4_COLUMNS) - self.hashes[MISCGEN_TS] = self.matrix_service.create(FIXED_8_COLUMNS) - - # Binding constraint matrices - series = matrix_constants.binding_constraint.series - self.hashes[BINDING_CONSTRAINT_HOURLY] = self.matrix_service.create(series.default_bc_hourly) - self.hashes[BINDING_CONSTRAINT_WEEKLY_DAILY] = self.matrix_service.create(series.default_bc_weekly_daily) - - # Some short-term storage matrices use np.ones((8760, 1)) - self.hashes[ONES_SCENARIO_MATRIX] = self.matrix_service.create( - matrix_constants.st_storage.series.pmax_injection - ) + + def init_constant_matrices(self, bucket_dir: Path) -> None: + bucket_dir.mkdir(parents=True, exist_ok=True) + with FileLock(bucket_dir / MATRIX_CONSTANT_INIT_LOCK_FILE_NAME): + self.hashes[HYDRO_COMMON_CAPACITY_MAX_POWER_V7] = self.matrix_service.create( + matrix_constants.hydro.v7.max_power + ) + self.hashes[HYDRO_COMMON_CAPACITY_RESERVOIR_V7] = self.matrix_service.create( + matrix_constants.hydro.v7.reservoir + ) + self.hashes[HYDRO_COMMON_CAPACITY_RESERVOIR_V6] = self.matrix_service.create( + matrix_constants.hydro.v6.reservoir + ) + self.hashes[HYDRO_COMMON_CAPACITY_INFLOW_PATTERN] = self.matrix_service.create( + matrix_constants.hydro.v7.inflow_pattern + ) + self.hashes[HYDRO_COMMON_CAPACITY_CREDIT_MODULATION] = self.matrix_service.create( + matrix_constants.hydro.v7.credit_modulations + ) + self.hashes[PREPRO_CONVERSION] = self.matrix_service.create(matrix_constants.prepro.conversion) + self.hashes[PREPRO_DATA] = self.matrix_service.create(matrix_constants.prepro.data) + self.hashes[THERMAL_PREPRO_DATA] = self.matrix_service.create(matrix_constants.thermals.prepro.data) + + self.hashes[THERMAL_PREPRO_MODULATION] = self.matrix_service.create( + matrix_constants.thermals.prepro.modulation + ) + self.hashes[LINK_V7] = self.matrix_service.create(matrix_constants.link.v7.link) + self.hashes[LINK_V8] = self.matrix_service.create(matrix_constants.link.v8.link) + self.hashes[LINK_DIRECT] = self.matrix_service.create(matrix_constants.link.v8.direct) + self.hashes[LINK_INDIRECT] = self.matrix_service.create(matrix_constants.link.v8.indirect) + + self.hashes[NULL_MATRIX_NAME] = self.matrix_service.create(NULL_MATRIX) + self.hashes[EMPTY_SCENARIO_MATRIX] = self.matrix_service.create(NULL_SCENARIO_MATRIX) + self.hashes[RESERVES_TS] = self.matrix_service.create(FIXED_4_COLUMNS) + self.hashes[MISCGEN_TS] = self.matrix_service.create(FIXED_8_COLUMNS) + + # Binding constraint matrices + series = matrix_constants.binding_constraint.series + self.hashes[BINDING_CONSTRAINT_HOURLY] = self.matrix_service.create(series.default_bc_hourly) + self.hashes[BINDING_CONSTRAINT_WEEKLY_DAILY] = self.matrix_service.create(series.default_bc_weekly_daily) + + # Some short-term storage matrices use np.ones((8760, 1)) + self.hashes[ONES_SCENARIO_MATRIX] = self.matrix_service.create( + matrix_constants.st_storage.series.pmax_injection + ) def get_hydro_max_power(self, version: int) -> str: if version > 650: diff --git a/antarest/study/storage/variantstudy/variant_command_extractor.py b/antarest/study/storage/variantstudy/variant_command_extractor.py index 5a88dde857..33ee3ff49f 100644 --- a/antarest/study/storage/variantstudy/variant_command_extractor.py +++ b/antarest/study/storage/variantstudy/variant_command_extractor.py @@ -20,6 +20,7 @@ class VariantCommandsExtractor: def __init__(self, matrix_service: ISimpleMatrixService, patch_service: PatchService): self.matrix_service = matrix_service self.generator_matrix_constants = GeneratorMatrixConstants(self.matrix_service) + self.generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir) self.command_extractor = CommandExtractor(self.matrix_service, patch_service=patch_service) def extract(self, study: FileStudy) -> List[CommandDTO]: diff --git a/antarest/tools/lib.py b/antarest/tools/lib.py index 5ade3d214b..058a3402fa 100644 --- a/antarest/tools/lib.py +++ b/antarest/tools/lib.py @@ -24,6 +24,7 @@ from antarest.core.config import CacheConfig from antarest.core.tasks.model import TaskDTO from antarest.core.utils.utils import StopWatch, get_local_path +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.matrixstore.uri_resolver_service import UriResolverService from antarest.study.model import NEW_DEFAULT_STUDY_VERSION, STUDY_REFERENCE_TEMPLATES @@ -140,7 +141,12 @@ def render_template(self, study_version: str = NEW_DEFAULT_STUDY_VERSION) -> Non def apply_commands(self, commands: List[CommandDTO], matrices_dir: Path) -> GenerationResultInfoDTO: stopwatch = StopWatch() - matrix_service = SimpleMatrixService(matrices_dir) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrices_dir, + ) + matrix_service = SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) matrix_resolver = UriResolverService(matrix_service) local_cache = LocalCache(CacheConfig()) study_factory = StudyFactory( @@ -149,8 +155,10 @@ def apply_commands(self, commands: List[CommandDTO], matrices_dir: Path) -> Gene cache=local_cache, ) generator = VariantCommandGenerator(study_factory) + generator_matrix_constants = GeneratorMatrixConstants(matrix_service) + generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir) command_factory = CommandFactory( - generator_matrix_constants=GeneratorMatrixConstants(matrix_service), + generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, patch_service=PatchService(), ) @@ -176,8 +184,12 @@ def extract_commands(study_path: Path, commands_output_dir: Path) -> None: commands_output_dir.mkdir(parents=True) matrices_dir = commands_output_dir / MATRIX_STORE_DIR matrices_dir.mkdir() - - matrix_service = SimpleMatrixService(matrices_dir) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrices_dir, + ) + matrix_service = SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) matrix_resolver = UriResolverService(matrix_service) cache = LocalCache(CacheConfig()) study_factory = StudyFactory( @@ -187,7 +199,12 @@ def extract_commands(study_path: Path, commands_output_dir: Path) -> None: ) study = study_factory.create_from_fs(study_path, str(study_path), use_cache=False) - local_matrix_service = SimpleMatrixService(matrices_dir) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrices_dir, + ) + local_matrix_service = SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) extractor = VariantCommandsExtractor(local_matrix_service, patch_service=PatchService()) command_list = extractor.extract(study) @@ -233,7 +250,12 @@ def generate_diff( study_id = "empty_base" path_study = output_dir / study_id - local_matrix_service = SimpleMatrixService(matrices_dir) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrices_dir, + ) + local_matrix_service = SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) resolver = UriResolverService(matrix_service=local_matrix_service) cache = LocalCache() diff --git a/antarest/utils.py b/antarest/utils.py index d49951017f..39ea094168 100644 --- a/antarest/utils.py +++ b/antarest/utils.py @@ -1,7 +1,8 @@ +import datetime import logging from enum import Enum from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Mapping, Optional, Tuple import redis import sqlalchemy.ext.baked # type: ignore @@ -12,6 +13,7 @@ from ratelimit.backends.redis import RedisBackend # type: ignore from ratelimit.backends.simple import MemoryBackend # type: ignore from sqlalchemy import create_engine +from sqlalchemy.engine.base import Engine # type: ignore from sqlalchemy.pool import NullPool # type: ignore from antarest.core.cache.main import build_cache @@ -20,13 +22,11 @@ from antarest.core.filetransfer.service import FileTransferManager from antarest.core.interfaces.cache import ICache from antarest.core.interfaces.eventbus import IEventBus -from antarest.core.logging.utils import configure_logger from antarest.core.maintenance.main import build_maintenance_manager from antarest.core.persistence import upgrade_db from antarest.core.tasks.main import build_taskjob_manager from antarest.core.tasks.service import ITaskService -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware -from antarest.core.utils.utils import get_local_path, new_redis_instance +from antarest.core.utils.utils import new_redis_instance from antarest.eventbus.main import build_eventbus from antarest.launcher.main import build_launcher from antarest.login.main import build_login @@ -46,6 +46,19 @@ logger = logging.getLogger(__name__) +SESSION_ARGS: Mapping[str, bool] = { + "autocommit": False, + "expire_on_commit": False, + "autoflush": False, +} +""" +This mapping can be used to instantiate a new session, for example: + +>>> with sessionmaker(engine, **SESSION_ARGS)() as session: +... session.execute("SELECT 1") +""" + + class Module(str, Enum): APP = "app" WATCHER = "watcher" @@ -55,12 +68,11 @@ class Module(str, Enum): SIMULATOR_WORKER = "simulator_worker" -def init_db( +def init_db_engine( config_file: Path, config: Config, auto_upgrade_db: bool, - application: Optional[FastAPI], -) -> None: +) -> Engine: if auto_upgrade_db: upgrade_db(config_file) connect_args: Dict[str, Any] = {} @@ -86,19 +98,7 @@ def init_db( engine = create_engine(config.db.db_url, echo=config.debug, connect_args=connect_args, **extra) - session_args = { - "autocommit": False, - "expire_on_commit": False, - "autoflush": False, - } - if application: - application.add_middleware( - DBSessionMiddleware, - custom_engine=engine, - session_args=session_args, - ) - else: - DBSessionMiddleware(None, custom_engine=engine, session_args=session_args) + return engine def create_event_bus( @@ -264,14 +264,3 @@ def create_services(config: Config, application: Optional[FastAPI], create_all: services["cache"] = cache services["maintenance"] = maintenance_service return services - - -def create_env(config_file: Path) -> Dict[str, Any]: - """ - Create application services env for testing and scripting purpose - """ - res = get_local_path() / "resources" - config = Config.from_yaml_file(res=res, file=config_file) - configure_logger(config) - init_db(config_file, config, False, None) - return create_services(config, None) diff --git a/tests/conftest_services.py b/tests/conftest_services.py index ee2fea2057..7fa50f6f86 100644 --- a/tests/conftest_services.py +++ b/tests/conftest_services.py @@ -18,6 +18,7 @@ from antarest.core.tasks.model import CustomTaskEventMessages, TaskDTO, TaskListFilter, TaskResult, TaskStatus, TaskType from antarest.core.tasks.service import ITaskService, Task from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.matrixstore.uri_resolver_service import UriResolverService from antarest.study.storage.patch_service import PatchService @@ -128,7 +129,10 @@ def simple_matrix_service_fixture(bucket_dir: Path) -> SimpleMatrixService: Returns: An instance of the SimpleMatrixService class representing the matrix service. """ - return SimpleMatrixService(bucket_dir) + matrix_content_repository = MatrixContentRepository( + bucket_dir=bucket_dir, + ) + return SimpleMatrixService(matrix_content_repository=matrix_content_repository) @pytest.fixture(name="generator_matrix_constants", scope="session") @@ -144,7 +148,9 @@ def generator_matrix_constants_fixture( Returns: An instance of the GeneratorMatrixConstants class representing the matrix constants generator. """ - return GeneratorMatrixConstants(matrix_service=simple_matrix_service) + out_generator_matrix_constants = GeneratorMatrixConstants(simple_matrix_service) + out_generator_matrix_constants.init_constant_matrices(bucket_dir=simple_matrix_service.bucket_dir) + return out_generator_matrix_constants @pytest.fixture(name="uri_resolver_service", scope="session") diff --git a/tests/login/test_repository.py b/tests/login/test_repository.py index 6669747507..60bdbc0dbf 100644 --- a/tests/login/test_repository.py +++ b/tests/login/test_repository.py @@ -1,29 +1,14 @@ import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import scoped_session, sessionmaker # type: ignore +from sqlalchemy.orm import Session, scoped_session, sessionmaker # type: ignore -from antarest.core.config import Config, SecurityConfig -from antarest.core.persistence import Base -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db from antarest.login.model import Bot, Group, Password, Role, RoleType, User, UserLdap from antarest.login.repository import BotRepository, GroupRepository, RoleRepository, UserLdapRepository, UserRepository @pytest.mark.unit_test -def test_users(): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - with db(): - repo = UserRepository( - config=Config(security=SecurityConfig(admin_pwd="admin")), - ) +def test_users(db_session: Session): + with db_session: + repo = UserRepository(session=db_session) a = User( name="a", password=Password("a"), @@ -43,18 +28,9 @@ def test_users(): @pytest.mark.unit_test -def test_users_ldap(): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - with db(): - repo = UserLdapRepository() +def test_users_ldap(db_session: Session): + repo = UserLdapRepository(session=db_session) + with repo.session: a = UserLdap(name="a", external_id="b") a = repo.save(a) @@ -67,18 +43,9 @@ def test_users_ldap(): @pytest.mark.unit_test -def test_bots(): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - with db(): - repo = BotRepository() +def test_bots(db_session: Session): + repo = BotRepository(session=db_session) + with repo.session: a = Bot(name="a", owner=1) a = repo.save(a) assert a.id @@ -98,19 +65,9 @@ def test_bots(): @pytest.mark.unit_test -def test_groups(): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - with db(): - repo = GroupRepository() - +def test_groups(db_session: Session): + repo = GroupRepository(session=db_session) + with repo.session: a = Group(name="a") a = repo.save(a) @@ -125,19 +82,9 @@ def test_groups(): @pytest.mark.unit_test -def test_roles(): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - with db(): - repo = RoleRepository() - +def test_roles(db_session: Session): + repo = RoleRepository(session=db_session) + with repo.session: a = Role(type=RoleType.ADMIN, identity=User(id=0), group=Group(id="group")) a = repo.save(a) diff --git a/tests/matrixstore/test_repository.py b/tests/matrixstore/test_repository.py index 3973a18d39..3825924f85 100644 --- a/tests/matrixstore/test_repository.py +++ b/tests/matrixstore/test_repository.py @@ -5,6 +5,7 @@ import numpy as np import pytest from numpy import typing as npt +from sqlalchemy.orm import Session # ignore type from antarest.core.config import Config, SecurityConfig from antarest.core.utils.fastapi_sqlalchemy import db @@ -51,20 +52,20 @@ def test_bucket_lifecycle(self, tmp_path: Path) -> None: with pytest.raises(FileNotFoundError): repo.get(aid) - def test_dataset(self) -> None: - with db(): + def test_dataset(self, db_session: Session) -> None: + with db_session: # sourcery skip: extract-duplicate-method, extract-method - repo = MatrixRepository() + repo = MatrixRepository(session=db_session) - user_repo = UserRepository(Config(security=SecurityConfig())) + user_repo = UserRepository(session=db_session) # noinspection PyArgumentList user = user_repo.save(User(name="foo", password=Password("bar"))) - group_repo = GroupRepository() + group_repo = GroupRepository(session=db_session) # noinspection PyArgumentList group = group_repo.save(Group(name="group")) - dataset_repo = MatrixDataSetRepository() + dataset_repo = MatrixDataSetRepository(session=db_session) m1 = Matrix(id="hello", created_at=datetime.now()) repo.save(m1) @@ -105,22 +106,22 @@ def test_dataset(self) -> None: assert dataset_query_result.name == "some name change" assert dataset_query_result.owner_id == user.id - def test_datastore_query(self) -> None: + def test_datastore_query(self, db_session: Session) -> None: # sourcery skip: extract-duplicate-method with db(): - user_repo = UserRepository(Config(security=SecurityConfig())) + user_repo = UserRepository(session=db_session) # noinspection PyArgumentList user1 = user_repo.save(User(name="foo", password=Password("bar"))) # noinspection PyArgumentList user2 = user_repo.save(User(name="hello", password=Password("world"))) - repo = MatrixRepository() + repo = MatrixRepository(session=db_session) m1 = Matrix(id="hello", created_at=datetime.now()) repo.save(m1) m2 = Matrix(id="world", created_at=datetime.now()) repo.save(m2) - dataset_repo = MatrixDataSetRepository() + dataset_repo = MatrixDataSetRepository(session=db_session) dataset = MatrixDataSet( name="some name", @@ -165,7 +166,7 @@ def test_datastore_query(self) -> None: assert ( len( # fmt: off - db.session + db_session .query(MatrixDataSetRelation) .filter(MatrixDataSetRelation.dataset_id == dataset.id) .all() diff --git a/tests/storage/business/test_arealink_manager.py b/tests/storage/business/test_arealink_manager.py index 9f8e0be884..314e77e500 100644 --- a/tests/storage/business/test_arealink_manager.py +++ b/tests/storage/business/test_arealink_manager.py @@ -9,6 +9,7 @@ from antarest.core.jwt import DEFAULT_ADMIN_USER from antarest.core.requests import RequestParameters from antarest.core.utils.fastapi_sqlalchemy import db +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.study.business.area_management import AreaCreationDTO, AreaManager, AreaType, AreaUI from antarest.study.business.link_management import LinkInfoDTO, LinkManager @@ -66,7 +67,10 @@ def matrix_service_fixture(tmp_path: Path) -> SimpleMatrixService: """ matrix_path = tmp_path.joinpath("matrix-store") matrix_path.mkdir() - return SimpleMatrixService(matrix_path) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrix_path, + ) + return SimpleMatrixService(matrix_content_repository=matrix_content_repository) @with_db_context @@ -94,8 +98,10 @@ def test_area_crud(empty_study: FileStudy, matrix_service: SimpleMatrixService): raw_study_service.get_raw.return_value = empty_study raw_study_service.cache = Mock() + generator_matrix_constants = GeneratorMatrixConstants(matrix_service) + generator_matrix_constants.init_constant_matrices(bucket_dir=generator_matrix_constants.matrix_service.bucket_dir) variant_study_service.command_factory = CommandFactory( - GeneratorMatrixConstants(matrix_service), + generator_matrix_constants, matrix_service, patch_service=Mock(spec=PatchService), ) diff --git a/tests/storage/integration/conftest.py b/tests/storage/integration/conftest.py index 4ff8fbf888..197be27144 100644 --- a/tests/storage/integration/conftest.py +++ b/tests/storage/integration/conftest.py @@ -12,6 +12,7 @@ from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.dbmodel import Base from antarest.login.model import User +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.study.main import build_study_service from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, StudyAdditionalData @@ -87,7 +88,10 @@ def storage_service(tmp_path: Path, project_path: Path, sta_mini_zip_path: Path) matrix_path = tmp_path / "matrices" matrix_path.mkdir() - matrix_service = SimpleMatrixService(matrix_path) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrix_path, + ) + matrix_service = SimpleMatrixService(matrix_content_repository=matrix_content_repository) storage_service = build_study_service( application=Mock(), cache=LocalCache(config=config.cache), diff --git a/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py b/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py index 6b508425df..c3e2fd3e44 100644 --- a/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py +++ b/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py @@ -1,5 +1,6 @@ import numpy as np +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.study.storage.variantstudy.business import matrix_constants from antarest.study.storage.variantstudy.business.matrix_constants_generator import ( @@ -10,7 +11,14 @@ class TestGeneratorMatrixConstants: def test_get_st_storage(self, tmp_path): - generator = GeneratorMatrixConstants(matrix_service=SimpleMatrixService(bucket_dir=tmp_path)) + matrix_content_repository = MatrixContentRepository( + bucket_dir=tmp_path, + ) + generator = GeneratorMatrixConstants( + matrix_service=SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) + ) ref1 = generator.get_st_storage_pmax_injection() matrix_id1 = ref1.split(MATRIX_PROTOCOL_PREFIX)[1] @@ -38,7 +46,14 @@ def test_get_st_storage(self, tmp_path): assert np.array(matrix_dto5.data).all() == matrix_constants.st_storage.series.inflows.all() def test_get_binding_constraint(self, tmp_path): - generator = GeneratorMatrixConstants(matrix_service=SimpleMatrixService(bucket_dir=tmp_path)) + matrix_content_repository = MatrixContentRepository( + bucket_dir=tmp_path, + ) + generator = GeneratorMatrixConstants( + matrix_service=SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) + ) series = matrix_constants.binding_constraint.series hourly = generator.get_binding_constraint_hourly() diff --git a/tests/study/storage/variantstudy/test_variant_study_service.py b/tests/study/storage/variantstudy/test_variant_study_service.py index 25317a9589..8766bfd308 100644 --- a/tests/study/storage/variantstudy/test_variant_study_service.py +++ b/tests/study/storage/variantstudy/test_variant_study_service.py @@ -11,6 +11,7 @@ from antarest.core.requests import RequestParameters from antarest.core.utils.fastapi_sqlalchemy import db from antarest.login.model import Group, User +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import RawStudy, StudyAdditionalData diff --git a/tests/variantstudy/conftest.py b/tests/variantstudy/conftest.py index 9db21ab220..b069e029d8 100644 --- a/tests/variantstudy/conftest.py +++ b/tests/variantstudy/conftest.py @@ -91,8 +91,10 @@ def command_context_fixture(matrix_service: MatrixService) -> CommandContext: CommandContext: The CommandContext object. """ # sourcery skip: inline-immediately-returned-variable + generator_matrix_constants = GeneratorMatrixConstants(matrix_service) + generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir) command_context = CommandContext( - generator_matrix_constants=GeneratorMatrixConstants(matrix_service=matrix_service), + generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, patch_service=PatchService(repository=Mock(spec=StudyMetadataRepository)), ) @@ -110,8 +112,10 @@ def command_factory_fixture(matrix_service: MatrixService) -> CommandFactory: Returns: CommandFactory: The CommandFactory object. """ + generator_matrix_constants = GeneratorMatrixConstants(matrix_service) + generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir) return CommandFactory( - generator_matrix_constants=GeneratorMatrixConstants(matrix_service=matrix_service), + generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, patch_service=PatchService(), ) From 983c772f4e7832c9f417cf3ed6532fc08a369157 Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Thu, 23 Nov 2023 23:57:35 +0100 Subject: [PATCH 15/43] feat(db-init): separate database initialization from global database session (#1805) --- antarest/core/jwt.py | 4 +- antarest/login/model.py | 51 +++++++++++++++- antarest/login/repository.py | 58 ++----------------- antarest/main.py | 2 +- antarest/study/main.py | 2 +- .../business/command_extractor.py | 4 +- .../business/matrix_constants_generator.py | 8 ++- .../variantstudy/variant_command_extractor.py | 2 +- antarest/tools/lib.py | 2 +- tests/conftest_services.py | 2 +- tests/login/test_model.py | 33 ++++++++++- .../storage/business/test_arealink_manager.py | 2 +- .../test_matrix_constants_generator.py | 2 + tests/variantstudy/conftest.py | 4 +- 14 files changed, 104 insertions(+), 72 deletions(-) diff --git a/antarest/core/jwt.py b/antarest/core/jwt.py index ff9ffd1187..4fb8b8fcb1 100644 --- a/antarest/core/jwt.py +++ b/antarest/core/jwt.py @@ -3,9 +3,9 @@ from pydantic import BaseModel from antarest.core.roles import RoleType -from antarest.login.model import Group, Identity +from antarest.login.model import USER_ID, Group, Identity -ADMIN_ID = 1 +ADMIN_ID = USER_ID class JWTGroup(BaseModel): diff --git a/antarest/login/model.py b/antarest/login/model.py index 52106685bc..50c62f8295 100644 --- a/antarest/login/model.py +++ b/antarest/login/model.py @@ -1,11 +1,14 @@ +import logging import typing as t import uuid import bcrypt from pydantic.main import BaseModel from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, Sequence, String # type: ignore +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.exc import IntegrityError # type: ignore from sqlalchemy.ext.hybrid import hybrid_property # type: ignore -from sqlalchemy.orm import relationship # type: ignore +from sqlalchemy.orm import Session, relationship, sessionmaker # type: ignore from antarest.core.persistence import Base from antarest.core.roles import RoleType @@ -15,6 +18,16 @@ from antarest.launcher.model import JobResult +logger = logging.getLogger(__name__) + + +GROUP_ID = "admin" +GROUP_NAME = "admin" + +USER_ID = 1 +USER_NAME = "admin" + + class UserInfo(BaseModel): id: int name: str @@ -282,3 +295,39 @@ class CredentialsDTO(BaseModel): user: int access_token: str refresh_token: str + + +def init_admin_user(engine: Engine, session_args: t.Mapping[str, bool], admin_password: str) -> None: + with sessionmaker(bind=engine, **session_args)() as session: + group = Group(id=GROUP_ID, name=GROUP_NAME) + user = User(id=USER_ID, name=USER_NAME, password=Password(admin_password)) + role = Role(type=RoleType.ADMIN, identity=User(id=USER_ID), group=Group(id=GROUP_ID)) + + existing_group = session.query(Group).get(group.id) + if not existing_group: + session.add(group) + try: + session.commit() + except IntegrityError as e: + session.rollback() # Rollback any changes made before the error + logger.error(f"IntegrityError: {e}") + + existing_user = session.query(User).get(user.id) + if not existing_user: + session.add(user) + try: + session.commit() + except IntegrityError as e: + session.rollback() # Rollback any changes made before the error + logger.error(f"IntegrityError: {e}") + + existing_role = session.query(Role).get((USER_ID, GROUP_ID)) + if not existing_role: + role.group = session.merge(role.group) + role.identity = session.merge(role.identity) + session.add(role) + try: + session.commit() + except IntegrityError as e: + session.rollback() # Rollback any changes made before the error + logger.error(f"IntegrityError: {e}") diff --git a/antarest/login/repository.py b/antarest/login/repository.py index edac68d495..b2058952b1 100644 --- a/antarest/login/repository.py +++ b/antarest/login/repository.py @@ -1,64 +1,14 @@ import logging -from typing import Dict, List, Optional +from typing import List, Optional from sqlalchemy import exists # type: ignore -from sqlalchemy.engine.base import Engine # type: ignore -from sqlalchemy.orm import joinedload, Session, sessionmaker # type: ignore +from sqlalchemy.orm import joinedload, Session # type: ignore -from antarest.core.jwt import ADMIN_ID -from antarest.core.roles import RoleType from antarest.core.utils.fastapi_sqlalchemy import db -from antarest.login.model import Bot, Group, Password, Role, User, UserLdap +from antarest.login.model import Bot, Group, Role, User, UserLdap logger = logging.getLogger(__name__) -DB_INIT_DEFAULT_GROUP_ID = "admin" -DB_INIT_DEFAULT_GROUP_NAME = "admin" - -DB_INIT_DEFAULT_USER_ID = ADMIN_ID -DB_INIT_DEFAULT_USER_NAME = "admin" - -DB_INIT_DEFAULT_ROLE_ID = ADMIN_ID -DB_INIT_DEFAULT_ROLE_GROUP_ID = "admin" - - -def init_admin_user(engine: Engine, session_args: Dict[str, bool], admin_password: str) -> None: - with sessionmaker(bind=engine, **session_args)() as session: - group = Group( - id=DB_INIT_DEFAULT_GROUP_ID, - name=DB_INIT_DEFAULT_GROUP_NAME, - ) - user = User( - id=DB_INIT_DEFAULT_USER_ID, - name=DB_INIT_DEFAULT_USER_NAME, - password=Password(admin_password), - ) - role = Role( - type=RoleType.ADMIN, - identity=User(id=DB_INIT_DEFAULT_USER_ID), - group=Group( - id=DB_INIT_DEFAULT_GROUP_ID, - ), - ) - - existing_group = session.query(Group).get(group.id) - if not existing_group: - session.add(group) - session.commit() - - existing_user = session.query(User).get(user.id) - if not existing_user: - session.add(user) - session.commit() - - existing_role = session.query(Role).get((DB_INIT_DEFAULT_USER_ID, DB_INIT_DEFAULT_GROUP_ID)) - if not existing_role: - role.group = session.merge(role.group) - role.identity = session.merge(role.identity) - session.add(role) - - session.commit() - class GroupRepository: """ @@ -143,7 +93,7 @@ def get(self, id_number: int) -> Optional[User]: return user def get_by_name(self, name: str) -> Optional[User]: - user: User = db.session.query(User).filter_by(name=name).first() + user: User = self.session.query(User).filter_by(name=name).first() return user def get_all(self) -> List[User]: diff --git a/antarest/main.py b/antarest/main.py index 5e1c1ec850..bf233260f2 100644 --- a/antarest/main.py +++ b/antarest/main.py @@ -35,7 +35,7 @@ from antarest.core.utils.utils import get_local_path from antarest.core.utils.web import tags_metadata from antarest.login.auth import Auth, JwtSettings -from antarest.login.repository import init_admin_user +from antarest.login.model import init_admin_user from antarest.matrixstore.matrix_garbage_collector import MatrixGarbageCollector from antarest.singleton_services import start_all_services from antarest.study.storage.auto_archive_service import AutoArchiveService diff --git a/antarest/study/main.py b/antarest/study/main.py index c3b48356af..0758c6d070 100644 --- a/antarest/study/main.py +++ b/antarest/study/main.py @@ -81,7 +81,7 @@ def build_study_service( ) generator_matrix_constants = generator_matrix_constants or GeneratorMatrixConstants(matrix_service=matrix_service) - generator_matrix_constants.init_constant_matrices(bucket_dir=generator_matrix_constants.matrix_service.bucket_dir) + generator_matrix_constants.init_constant_matrices() command_factory = CommandFactory( generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, diff --git a/antarest/study/storage/variantstudy/business/command_extractor.py b/antarest/study/storage/variantstudy/business/command_extractor.py index 4d3c563799..9aa5a9b397 100644 --- a/antarest/study/storage/variantstudy/business/command_extractor.py +++ b/antarest/study/storage/variantstudy/business/command_extractor.py @@ -48,9 +48,7 @@ class CommandExtractor(ICommandExtractor): def __init__(self, matrix_service: ISimpleMatrixService, patch_service: PatchService): self.matrix_service = matrix_service self.generator_matrix_constants = GeneratorMatrixConstants(self.matrix_service) - self.generator_matrix_constants.init_constant_matrices( - bucket_dir=self.generator_matrix_constants.matrix_service.bucket_dir - ) + self.generator_matrix_constants.init_constant_matrices() self.patch_service = patch_service self.command_context = CommandContext( generator_matrix_constants=self.generator_matrix_constants, diff --git a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py index cbd76bc9f6..d297e0d2de 100644 --- a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py +++ b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py @@ -57,10 +57,12 @@ class GeneratorMatrixConstants: def __init__(self, matrix_service: ISimpleMatrixService) -> None: self.hashes: Dict[str, str] = {} self.matrix_service: ISimpleMatrixService = matrix_service + self._lock_dir = tempfile.gettempdir() - def init_constant_matrices(self, bucket_dir: Path) -> None: - bucket_dir.mkdir(parents=True, exist_ok=True) - with FileLock(bucket_dir / MATRIX_CONSTANT_INIT_LOCK_FILE_NAME): + def init_constant_matrices( + self, + ) -> None: + with FileLock(str(Path(self._lock_dir) / MATRIX_CONSTANT_INIT_LOCK_FILE_NAME)): self.hashes[HYDRO_COMMON_CAPACITY_MAX_POWER_V7] = self.matrix_service.create( matrix_constants.hydro.v7.max_power ) diff --git a/antarest/study/storage/variantstudy/variant_command_extractor.py b/antarest/study/storage/variantstudy/variant_command_extractor.py index 33ee3ff49f..bd052a6c0a 100644 --- a/antarest/study/storage/variantstudy/variant_command_extractor.py +++ b/antarest/study/storage/variantstudy/variant_command_extractor.py @@ -20,7 +20,7 @@ class VariantCommandsExtractor: def __init__(self, matrix_service: ISimpleMatrixService, patch_service: PatchService): self.matrix_service = matrix_service self.generator_matrix_constants = GeneratorMatrixConstants(self.matrix_service) - self.generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir) + self.generator_matrix_constants.init_constant_matrices() self.command_extractor = CommandExtractor(self.matrix_service, patch_service=patch_service) def extract(self, study: FileStudy) -> List[CommandDTO]: diff --git a/antarest/tools/lib.py b/antarest/tools/lib.py index 058a3402fa..c3c5db9dff 100644 --- a/antarest/tools/lib.py +++ b/antarest/tools/lib.py @@ -156,7 +156,7 @@ def apply_commands(self, commands: List[CommandDTO], matrices_dir: Path) -> Gene ) generator = VariantCommandGenerator(study_factory) generator_matrix_constants = GeneratorMatrixConstants(matrix_service) - generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir) + generator_matrix_constants.init_constant_matrices() command_factory = CommandFactory( generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, diff --git a/tests/conftest_services.py b/tests/conftest_services.py index 7fa50f6f86..5afb53460b 100644 --- a/tests/conftest_services.py +++ b/tests/conftest_services.py @@ -149,7 +149,7 @@ def generator_matrix_constants_fixture( An instance of the GeneratorMatrixConstants class representing the matrix constants generator. """ out_generator_matrix_constants = GeneratorMatrixConstants(simple_matrix_service) - out_generator_matrix_constants.init_constant_matrices(bucket_dir=simple_matrix_service.bucket_dir) + out_generator_matrix_constants.init_constant_matrices() return out_generator_matrix_constants diff --git a/tests/login/test_model.py b/tests/login/test_model.py index 787f4f2d6a..0b1da1c8f2 100644 --- a/tests/login/test_model.py +++ b/tests/login/test_model.py @@ -1,5 +1,36 @@ -from antarest.login.model import Password +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.orm import sessionmaker # type: ignore + +from antarest.login.model import GROUP_ID, GROUP_NAME, USER_ID, USER_NAME, Group, Password, Role, User, init_admin_user +from antarest.utils import SESSION_ARGS + +TEST_ADMIN_PASS_WORD = "test" def test_password(): assert Password("pwd").check("pwd") + + +class TestInitAdminUser: + def test_nominal_init_admin_user(self, db_engine: Engine): + init_admin_user(db_engine, dict(SESSION_ARGS), admin_password=TEST_ADMIN_PASS_WORD) + make_session = sessionmaker(bind=db_engine) + with make_session() as session: + user = session.query(User).get(USER_ID) + assert user is not None + assert user.id == USER_ID + assert user.name == USER_NAME + assert user.password.check(TEST_ADMIN_PASS_WORD) + group = session.query(Group).get(GROUP_ID) + assert group is not None + assert group.id == GROUP_ID + assert group.name == GROUP_NAME + role = session.query(Role).get((USER_ID, GROUP_ID)) + assert role is not None + assert role.identity is not None + assert role.identity.id == USER_ID + assert role.identity.name == USER_NAME + assert role.identity.password.check(TEST_ADMIN_PASS_WORD) + assert role.group is not None + assert role.group.id == GROUP_ID + assert role.group.name == GROUP_NAME diff --git a/tests/storage/business/test_arealink_manager.py b/tests/storage/business/test_arealink_manager.py index 314e77e500..4caee7b7bd 100644 --- a/tests/storage/business/test_arealink_manager.py +++ b/tests/storage/business/test_arealink_manager.py @@ -99,7 +99,7 @@ def test_area_crud(empty_study: FileStudy, matrix_service: SimpleMatrixService): raw_study_service.get_raw.return_value = empty_study raw_study_service.cache = Mock() generator_matrix_constants = GeneratorMatrixConstants(matrix_service) - generator_matrix_constants.init_constant_matrices(bucket_dir=generator_matrix_constants.matrix_service.bucket_dir) + generator_matrix_constants.init_constant_matrices() variant_study_service.command_factory = CommandFactory( generator_matrix_constants, matrix_service, diff --git a/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py b/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py index c3e2fd3e44..a216571510 100644 --- a/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py +++ b/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py @@ -19,6 +19,7 @@ def test_get_st_storage(self, tmp_path): matrix_content_repository=matrix_content_repository, ) ) + generator.init_constant_matrices() ref1 = generator.get_st_storage_pmax_injection() matrix_id1 = ref1.split(MATRIX_PROTOCOL_PREFIX)[1] @@ -54,6 +55,7 @@ def test_get_binding_constraint(self, tmp_path): matrix_content_repository=matrix_content_repository, ) ) + generator.init_constant_matrices() series = matrix_constants.binding_constraint.series hourly = generator.get_binding_constraint_hourly() diff --git a/tests/variantstudy/conftest.py b/tests/variantstudy/conftest.py index b069e029d8..011a6bb68d 100644 --- a/tests/variantstudy/conftest.py +++ b/tests/variantstudy/conftest.py @@ -92,7 +92,7 @@ def command_context_fixture(matrix_service: MatrixService) -> CommandContext: """ # sourcery skip: inline-immediately-returned-variable generator_matrix_constants = GeneratorMatrixConstants(matrix_service) - generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir) + generator_matrix_constants.init_constant_matrices() command_context = CommandContext( generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, @@ -113,7 +113,7 @@ def command_factory_fixture(matrix_service: MatrixService) -> CommandFactory: CommandFactory: The CommandFactory object. """ generator_matrix_constants = GeneratorMatrixConstants(matrix_service) - generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir) + generator_matrix_constants.init_constant_matrices() return CommandFactory( generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, From 619d4a629506af5d9bb33e4b4428452be4d2017c Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Fri, 24 Nov 2023 00:23:44 +0100 Subject: [PATCH 16/43] feat(db-init): separate database initialization from global database session (#1805) --- tests/variantstudy/model/test_variant_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/variantstudy/model/test_variant_model.py b/tests/variantstudy/model/test_variant_model.py index 7acbce4530..25efa4d7c2 100644 --- a/tests/variantstudy/model/test_variant_model.py +++ b/tests/variantstudy/model/test_variant_model.py @@ -76,7 +76,7 @@ def test_commands_service( ) -> None: # Initialize the default matrix constants # noinspection PyProtectedMember - generator_matrix_constants._init() + generator_matrix_constants.init_constant_matrices() params = RequestParameters(user=jwt_user) From c98fb4daa7f21eb46b7d2df2cc7ee6372c868064 Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Fri, 24 Nov 2023 11:26:46 +0100 Subject: [PATCH 17/43] feat(db-init): separate database initialization from global database session (#1805) --- antarest/core/tasks/model.py | 2 +- tests/core/test_tasks.py | 70 +++++++++++++++++++++++++++++++++++- 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/antarest/core/tasks/model.py b/antarest/core/tasks/model.py index 1206db9fc4..6eacab3f10 100644 --- a/antarest/core/tasks/model.py +++ b/antarest/core/tasks/model.py @@ -177,7 +177,7 @@ def __repr__(self) -> str: def cancel_orphan_tasks(engine: Engine, session_args: Dict[str, bool]) -> None: updated_values = { TaskJob.status: TaskStatus.FAILED.value, - TaskJob.result: False, + TaskJob.result_status: False, TaskJob.result_msg: "Task was interrupted due to server restart", TaskJob.completion_date: datetime.utcnow(), } diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index dfad126555..73782745ce 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -6,6 +6,8 @@ import pytest from sqlalchemy import create_engine +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.orm import sessionmaker from antarest.core.config import Config, RemoteWorkerConfig, TaskConfig from antarest.core.interfaces.eventbus import Event, EventType, IEventBus @@ -13,12 +15,22 @@ from antarest.core.model import PermissionInfo, PublicMode from antarest.core.persistence import Base from antarest.core.requests import RequestParameters, UserHasNotPermissionError -from antarest.core.tasks.model import TaskDTO, TaskJob, TaskJobLog, TaskListFilter, TaskResult, TaskStatus, TaskType +from antarest.core.tasks.model import ( + TaskDTO, + TaskJob, + TaskJobLog, + TaskListFilter, + TaskResult, + TaskStatus, + TaskType, + cancel_orphan_tasks, +) from antarest.core.tasks.repository import TaskJobRepository from antarest.core.tasks.service import TaskJobService from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db from antarest.eventbus.business.local_eventbus import LocalEventBus from antarest.eventbus.service import EventBusService +from antarest.utils import SESSION_ARGS from antarest.worker.worker import AbstractWorker, WorkerTaskCommand from tests.helpers import with_db_context @@ -453,3 +465,59 @@ def test_cancel(): service.cancel_task("a", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True) task.status = TaskStatus.CANCELLED.value repo_mock.save.assert_called_with(task) + + +@pytest.mark.parametrize( + ("status", "result_status", "result_msg"), + [ + (TaskStatus.RUNNING.value, False, "task ongoing"), + (TaskStatus.PENDING.value, True, "task pending"), + (TaskStatus.FAILED.value, False, "task failed"), + (TaskStatus.COMPLETED.value, True, "task finished"), + (TaskStatus.TIMEOUT.value, False, "task timed out"), + (TaskStatus.CANCELLED.value, True, "task canceled"), + ], +) +def test_cancel_orphan_tasks( + db_engine: Engine, + status: int, + result_status: bool, + result_msg: str, + max_diff_seconds: int = 6, + test_id: str = "test_cancel_orphan_tasks_id", +): + completion_date: datetime.datetime = datetime.datetime.utcnow() + task_job = TaskJob( + id=test_id, + status=status, + result_status=result_status, + result_msg=result_msg, + completion_date=completion_date, + ) + with sessionmaker(bind=db_engine, **dict(SESSION_ARGS))() as session: + if session.query(TaskJob).get(test_id) is not None: + session.merge(task_job) + session.commit() + else: + session.add(task_job) + session.commit() + cancel_orphan_tasks(engine=db_engine, session_args=dict(SESSION_ARGS)) + with sessionmaker(bind=db_engine, **dict(SESSION_ARGS))() as session: + if status in [TaskStatus.RUNNING.value, TaskStatus.PENDING.value]: + updated_task_job = ( + session.query(TaskJob) + .filter(TaskJob.status.in_([TaskStatus.RUNNING.value, TaskStatus.PENDING.value])) + .all() + ) + assert not updated_task_job + updated_task_job = session.query(TaskJob).get(test_id) + assert updated_task_job.status == TaskStatus.FAILED.value + assert not updated_task_job.result_status + assert updated_task_job.result_msg == "Task was interrupted due to server restart" + assert (datetime.datetime.utcnow() - updated_task_job.completion_date).seconds <= max_diff_seconds + else: + updated_task_job = session.query(TaskJob).get(test_id) + assert updated_task_job.status == status + assert updated_task_job.result_status == result_status + assert updated_task_job.result_msg == result_msg + assert (datetime.datetime.utcnow() - updated_task_job.completion_date).seconds <= max_diff_seconds From 9b029e5e55b921d624d07bea1488b538a25cf4a6 Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Sun, 26 Nov 2023 23:12:08 +0100 Subject: [PATCH 18/43] feat(db-init): separate database initialization from global database session (#1805) --- antarest/core/tasks/model.py | 4 +-- antarest/login/model.py | 42 +++++++++------------------ antarest/main.py | 6 ++-- antarest/singleton_services.py | 4 +-- tests/core/test_tasks.py | 20 ++++++------- tests/login/test_model.py | 53 +++++++++++++++++++++++++++------- 6 files changed, 73 insertions(+), 56 deletions(-) diff --git a/antarest/core/tasks/model.py b/antarest/core/tasks/model.py index 6eacab3f10..8f5be488f5 100644 --- a/antarest/core/tasks/model.py +++ b/antarest/core/tasks/model.py @@ -1,7 +1,7 @@ import uuid from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, List, Mapping, Optional from pydantic import BaseModel, Extra from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, Sequence, String # type: ignore @@ -174,7 +174,7 @@ def __repr__(self) -> str: ) -def cancel_orphan_tasks(engine: Engine, session_args: Dict[str, bool]) -> None: +def cancel_orphan_tasks(engine: Engine, session_args: Mapping[str, bool]) -> None: updated_values = { TaskJob.status: TaskStatus.FAILED.value, TaskJob.result_status: False, diff --git a/antarest/login/model.py b/antarest/login/model.py index 50c62f8295..dbccb08e34 100644 --- a/antarest/login/model.py +++ b/antarest/login/model.py @@ -1,3 +1,4 @@ +import contextlib import logging import typing as t import uuid @@ -298,36 +299,21 @@ class CredentialsDTO(BaseModel): def init_admin_user(engine: Engine, session_args: t.Mapping[str, bool], admin_password: str) -> None: - with sessionmaker(bind=engine, **session_args)() as session: + make_session = sessionmaker(bind=engine, **session_args) + with make_session() as session: group = Group(id=GROUP_ID, name=GROUP_NAME) - user = User(id=USER_ID, name=USER_NAME, password=Password(admin_password)) - role = Role(type=RoleType.ADMIN, identity=User(id=USER_ID), group=Group(id=GROUP_ID)) - - existing_group = session.query(Group).get(group.id) - if not existing_group: + with contextlib.suppress(IntegrityError): session.add(group) - try: - session.commit() - except IntegrityError as e: - session.rollback() # Rollback any changes made before the error - logger.error(f"IntegrityError: {e}") - - existing_user = session.query(User).get(user.id) - if not existing_user: + session.commit() + + with make_session() as session: + user = User(id=USER_ID, name=USER_NAME, password=Password(admin_password)) + with contextlib.suppress(IntegrityError): session.add(user) - try: - session.commit() - except IntegrityError as e: - session.rollback() # Rollback any changes made before the error - logger.error(f"IntegrityError: {e}") - - existing_role = session.query(Role).get((USER_ID, GROUP_ID)) - if not existing_role: - role.group = session.merge(role.group) - role.identity = session.merge(role.identity) + session.commit() + + with make_session() as session: + role = Role(type=RoleType.ADMIN, identity_id=USER_ID, group_id=GROUP_ID) + with contextlib.suppress(IntegrityError): session.add(role) - try: session.commit() - except IntegrityError as e: - session.rollback() # Rollback any changes made before the error - logger.error(f"IntegrityError: {e}") diff --git a/antarest/main.py b/antarest/main.py index bf233260f2..a661e51bff 100644 --- a/antarest/main.py +++ b/antarest/main.py @@ -253,7 +253,7 @@ def fastapi_app( application.add_middleware( DBSessionMiddleware, custom_engine=engine, - session_args=dict(SESSION_ARGS), + session_args=SESSION_ARGS, ) application.add_middleware(LoggingMiddleware) @@ -409,7 +409,7 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: config=RATE_LIMIT_CONFIG, ) - init_admin_user(engine=engine, session_args=dict(SESSION_ARGS), admin_password=config.security.admin_pwd) + init_admin_user(engine=engine, session_args=SESSION_ARGS, admin_password=config.security.admin_pwd) services = create_services(config, application) if mount_front: @@ -439,7 +439,7 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: customize_openapi(application) cancel_orphan_tasks( engine=engine, - session_args=dict(SESSION_ARGS), + session_args=SESSION_ARGS, ) return application, services diff --git a/antarest/singleton_services.py b/antarest/singleton_services.py index 70a791002d..f106099523 100644 --- a/antarest/singleton_services.py +++ b/antarest/singleton_services.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Dict, List +from typing import Dict, List, cast from antarest.core.config import Config from antarest.core.interfaces.service import IService @@ -27,7 +27,7 @@ def _init(config_file: Path, services_list: List[Module]) -> Dict[Module, IServi config, False, ) - DBSessionMiddleware(None, custom_engine=engine, session_args=dict(SESSION_ARGS)) + DBSessionMiddleware(None, custom_engine=engine, session_args=cast(Dict[str, bool], SESSION_ARGS)) configure_logger(config) ( diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index 73782745ce..1d05cb4887 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -483,9 +483,10 @@ def test_cancel_orphan_tasks( status: int, result_status: bool, result_msg: str, - max_diff_seconds: int = 6, - test_id: str = "test_cancel_orphan_tasks_id", ): + max_diff_seconds: int = 1 + test_id: str = "test_cancel_orphan_tasks_id" + completion_date: datetime.datetime = datetime.datetime.utcnow() task_job = TaskJob( id=test_id, @@ -494,15 +495,12 @@ def test_cancel_orphan_tasks( result_msg=result_msg, completion_date=completion_date, ) - with sessionmaker(bind=db_engine, **dict(SESSION_ARGS))() as session: - if session.query(TaskJob).get(test_id) is not None: - session.merge(task_job) - session.commit() - else: - session.add(task_job) - session.commit() - cancel_orphan_tasks(engine=db_engine, session_args=dict(SESSION_ARGS)) - with sessionmaker(bind=db_engine, **dict(SESSION_ARGS))() as session: + make_session = sessionmaker(bind=db_engine, **SESSION_ARGS) + with make_session() as session: + session.add(task_job) + session.commit() + cancel_orphan_tasks(engine=db_engine, session_args=SESSION_ARGS) + with make_session() as session: if status in [TaskStatus.RUNNING.value, TaskStatus.PENDING.value]: updated_task_job = ( session.query(TaskJob) diff --git a/tests/login/test_model.py b/tests/login/test_model.py index 0b1da1c8f2..72fced9478 100644 --- a/tests/login/test_model.py +++ b/tests/login/test_model.py @@ -1,7 +1,21 @@ +import contextlib + from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.exc import IntegrityError # type: ignore from sqlalchemy.orm import sessionmaker # type: ignore -from antarest.login.model import GROUP_ID, GROUP_NAME, USER_ID, USER_NAME, Group, Password, Role, User, init_admin_user +from antarest.login.model import ( + GROUP_ID, + GROUP_NAME, + USER_ID, + USER_NAME, + Group, + Password, + Role, + RoleType, + User, + init_admin_user, +) from antarest.utils import SESSION_ARGS TEST_ADMIN_PASS_WORD = "test" @@ -12,8 +26,8 @@ def test_password(): class TestInitAdminUser: - def test_nominal_init_admin_user(self, db_engine: Engine): - init_admin_user(db_engine, dict(SESSION_ARGS), admin_password=TEST_ADMIN_PASS_WORD) + def test_init_admin_user_nominal(self, db_engine: Engine): + init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) make_session = sessionmaker(bind=db_engine) with make_session() as session: user = session.query(User).get(USER_ID) @@ -27,10 +41,29 @@ def test_nominal_init_admin_user(self, db_engine: Engine): assert group.name == GROUP_NAME role = session.query(Role).get((USER_ID, GROUP_ID)) assert role is not None - assert role.identity is not None - assert role.identity.id == USER_ID - assert role.identity.name == USER_NAME - assert role.identity.password.check(TEST_ADMIN_PASS_WORD) - assert role.group is not None - assert role.group.id == GROUP_ID - assert role.group.name == GROUP_NAME + assert role.identity is user + assert role.group is group + + def test_init_admin_user_redundancy_check(self, db_engine: Engine): + # run first time + init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) + # run second time + init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) + + def test_init_admin_user_existing_group(self, db_engine: Engine): + make_session = sessionmaker(bind=db_engine) + with make_session() as session: + group = Group(id=GROUP_ID, name=GROUP_NAME) + with contextlib.suppress(IntegrityError): + session.add(group) + session.commit() + init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) + + def test_init_admin_user_existing_user(self, db_engine: Engine): + make_session = sessionmaker(bind=db_engine) + with make_session() as session: + user = User(id=USER_ID, name=USER_NAME, password=Password(TEST_ADMIN_PASS_WORD)) + with contextlib.suppress(IntegrityError): + session.add(user) + session.commit() + init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) From 276f769c0fec06c1d21416bbc3a4c0a2a291b6ee Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Sun, 26 Nov 2023 23:14:23 +0100 Subject: [PATCH 19/43] feat(db-init): separate database initialization from global database session (#1805) --- tests/login/test_model.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/tests/login/test_model.py b/tests/login/test_model.py index 72fced9478..0f4035a4fc 100644 --- a/tests/login/test_model.py +++ b/tests/login/test_model.py @@ -4,18 +4,7 @@ from sqlalchemy.exc import IntegrityError # type: ignore from sqlalchemy.orm import sessionmaker # type: ignore -from antarest.login.model import ( - GROUP_ID, - GROUP_NAME, - USER_ID, - USER_NAME, - Group, - Password, - Role, - RoleType, - User, - init_admin_user, -) +from antarest.login.model import GROUP_ID, GROUP_NAME, USER_ID, USER_NAME, Group, Password, Role, User, init_admin_user from antarest.utils import SESSION_ARGS TEST_ADMIN_PASS_WORD = "test" From 0a097dffd06d6774c4773bbba557431ad782249e Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Tue, 28 Nov 2023 14:20:40 +0100 Subject: [PATCH 20/43] test(db-init): update the `user_repo` pytest fixtures which don't require configuration anymore --- tests/login/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/login/conftest.py b/tests/login/conftest.py index 7b7935d6d3..2e70b8168a 100644 --- a/tests/login/conftest.py +++ b/tests/login/conftest.py @@ -20,12 +20,12 @@ def group_repo_fixture(db_middleware: DBSessionMiddleware) -> GroupRepository: # noinspection PyUnusedLocal @pytest.fixture(name="user_repo") -def user_repo_fixture(core_config: Config, db_middleware: DBSessionMiddleware) -> UserRepository: +def user_repo_fixture(db_middleware: DBSessionMiddleware) -> UserRepository: """Fixture that creates a UserRepository instance.""" # note: `DBSessionMiddleware` is required to instantiate a thread-local db session. # important: the `UserRepository` insert an admin user in the database if it does not exist. # >>> User(id=1, name="admin", password=Password(config.security.admin_pwd)) - return UserRepository(config=core_config) + return UserRepository() # noinspection PyUnusedLocal From b7ef06657ccdf4392a0ef32e180acccdab76d2da Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Tue, 28 Nov 2023 14:23:05 +0100 Subject: [PATCH 21/43] style(db-init): organise imports in `repository.py` --- antarest/login/repository.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/antarest/login/repository.py b/antarest/login/repository.py index b2058952b1..d70fa57e13 100644 --- a/antarest/login/repository.py +++ b/antarest/login/repository.py @@ -2,7 +2,7 @@ from typing import List, Optional from sqlalchemy import exists # type: ignore -from sqlalchemy.orm import joinedload, Session # type: ignore +from sqlalchemy.orm import Session, joinedload # type: ignore from antarest.core.utils.fastapi_sqlalchemy import db from antarest.login.model import Bot, Group, Role, User, UserLdap From cf5014985a339a2d3db4cfd63d334df5fe79b323 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Tue, 28 Nov 2023 14:26:44 +0100 Subject: [PATCH 22/43] test(db-init): change in `TestMatrixRepository` to use the `db_session` fixture The global object `db` from `fastapi_sqlalchemy` module is no longer used in this unit test --- tests/matrixstore/test_repository.py | 44 ++++++++++++---------------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/tests/matrixstore/test_repository.py b/tests/matrixstore/test_repository.py index 3825924f85..5b787085c5 100644 --- a/tests/matrixstore/test_repository.py +++ b/tests/matrixstore/test_repository.py @@ -1,14 +1,12 @@ +import datetime import typing as t -from datetime import datetime from pathlib import Path import numpy as np import pytest from numpy import typing as npt -from sqlalchemy.orm import Session # ignore type +from sqlalchemy.orm import Session # type: ignore -from antarest.core.config import Config, SecurityConfig -from antarest.core.utils.fastapi_sqlalchemy import db from antarest.login.model import Group, Password, User from antarest.login.repository import GroupRepository, UserRepository from antarest.matrixstore.model import Matrix, MatrixContent, MatrixDataSet, MatrixDataSetRelation @@ -18,11 +16,10 @@ class TestMatrixRepository: - def test_db_lifecycle(self) -> None: - with db(): - # sourcery skip: extract-method - repo = MatrixRepository() - m = Matrix(id="hello", created_at=datetime.now()) + def test_db_lifecycle(self, db_session: Session) -> None: + with db_session: + repo = MatrixRepository(db_session) + m = Matrix(id="hello", created_at=datetime.datetime.now()) repo.save(m) assert m.id assert m == repo.get(m.id) @@ -54,22 +51,19 @@ def test_bucket_lifecycle(self, tmp_path: Path) -> None: def test_dataset(self, db_session: Session) -> None: with db_session: - # sourcery skip: extract-duplicate-method, extract-method repo = MatrixRepository(session=db_session) user_repo = UserRepository(session=db_session) - # noinspection PyArgumentList user = user_repo.save(User(name="foo", password=Password("bar"))) group_repo = GroupRepository(session=db_session) - # noinspection PyArgumentList group = group_repo.save(Group(name="group")) dataset_repo = MatrixDataSetRepository(session=db_session) - m1 = Matrix(id="hello", created_at=datetime.now()) + m1 = Matrix(id="hello", created_at=datetime.datetime.now()) repo.save(m1) - m2 = Matrix(id="world", created_at=datetime.now()) + m2 = Matrix(id="world", created_at=datetime.datetime.now()) repo.save(m2) dataset = MatrixDataSet( @@ -77,8 +71,8 @@ def test_dataset(self, db_session: Session) -> None: public=True, owner_id=user.id, groups=[group], - created_at=datetime.now(), - updated_at=datetime.now(), + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), ) matrix_relation = MatrixDataSetRelation(name="m1") @@ -98,7 +92,7 @@ def test_dataset(self, db_session: Session) -> None: id=dataset.id, name="some name change", public=False, - updated_at=datetime.now(), + updated_at=datetime.datetime.now(), ) dataset_repo.save(dataset_update) dataset_query_result = dataset_repo.get(dataset.id) @@ -108,17 +102,15 @@ def test_dataset(self, db_session: Session) -> None: def test_datastore_query(self, db_session: Session) -> None: # sourcery skip: extract-duplicate-method - with db(): + with db_session: user_repo = UserRepository(session=db_session) - # noinspection PyArgumentList user1 = user_repo.save(User(name="foo", password=Password("bar"))) - # noinspection PyArgumentList user2 = user_repo.save(User(name="hello", password=Password("world"))) repo = MatrixRepository(session=db_session) - m1 = Matrix(id="hello", created_at=datetime.now()) + m1 = Matrix(id="hello", created_at=datetime.datetime.now()) repo.save(m1) - m2 = Matrix(id="world", created_at=datetime.now()) + m2 = Matrix(id="world", created_at=datetime.datetime.now()) repo.save(m2) dataset_repo = MatrixDataSetRepository(session=db_session) @@ -127,8 +119,8 @@ def test_datastore_query(self, db_session: Session) -> None: name="some name", public=True, owner_id=user1.id, - created_at=datetime.now(), - updated_at=datetime.now(), + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), ) matrix_relation = MatrixDataSetRelation(name="m1") matrix_relation.matrix_id = "hello" @@ -142,8 +134,8 @@ def test_datastore_query(self, db_session: Session) -> None: name="some name 2", public=False, owner_id=user2.id, - created_at=datetime.now(), - updated_at=datetime.now(), + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), ) matrix_relation = MatrixDataSetRelation(name="m1") matrix_relation.matrix_id = "hello" From 80c16538607f32f130318e6c47353bec15d109f1 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Tue, 28 Nov 2023 14:33:02 +0100 Subject: [PATCH 23/43] refactor(db-init): use `ADMIN_ID` and `ADMIN_NAME` for the administrator ID and name. --- antarest/core/jwt.py | 4 +--- antarest/login/model.py | 20 +++++++++++--------- tests/login/test_model.py | 12 ++++++------ 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/antarest/core/jwt.py b/antarest/core/jwt.py index 4fb8b8fcb1..16849fa9f0 100644 --- a/antarest/core/jwt.py +++ b/antarest/core/jwt.py @@ -3,9 +3,7 @@ from pydantic import BaseModel from antarest.core.roles import RoleType -from antarest.login.model import USER_ID, Group, Identity - -ADMIN_ID = USER_ID +from antarest.login.model import ADMIN_ID, Group, Identity class JWTGroup(BaseModel): diff --git a/antarest/login/model.py b/antarest/login/model.py index dbccb08e34..78fe322547 100644 --- a/antarest/login/model.py +++ b/antarest/login/model.py @@ -1,5 +1,4 @@ import contextlib -import logging import typing as t import uuid @@ -9,7 +8,7 @@ from sqlalchemy.engine.base import Engine # type: ignore from sqlalchemy.exc import IntegrityError # type: ignore from sqlalchemy.ext.hybrid import hybrid_property # type: ignore -from sqlalchemy.orm import Session, relationship, sessionmaker # type: ignore +from sqlalchemy.orm import relationship, sessionmaker # type: ignore from antarest.core.persistence import Base from antarest.core.roles import RoleType @@ -19,14 +18,17 @@ from antarest.launcher.model import JobResult -logger = logging.getLogger(__name__) - - GROUP_ID = "admin" +"""Unique ID of the administrator group.""" + GROUP_NAME = "admin" +"""Name of the administrator group.""" + +ADMIN_ID = 1 +"""Unique ID of the site administrator.""" -USER_ID = 1 -USER_NAME = "admin" +ADMIN_NAME = "admin" +"""Name of the site administrator.""" class UserInfo(BaseModel): @@ -307,13 +309,13 @@ def init_admin_user(engine: Engine, session_args: t.Mapping[str, bool], admin_pa session.commit() with make_session() as session: - user = User(id=USER_ID, name=USER_NAME, password=Password(admin_password)) + user = User(id=ADMIN_ID, name=ADMIN_NAME, password=Password(admin_password)) with contextlib.suppress(IntegrityError): session.add(user) session.commit() with make_session() as session: - role = Role(type=RoleType.ADMIN, identity_id=USER_ID, group_id=GROUP_ID) + role = Role(type=RoleType.ADMIN, identity_id=ADMIN_ID, group_id=GROUP_ID) with contextlib.suppress(IntegrityError): session.add(role) session.commit() diff --git a/tests/login/test_model.py b/tests/login/test_model.py index 0f4035a4fc..6ab2b35739 100644 --- a/tests/login/test_model.py +++ b/tests/login/test_model.py @@ -4,7 +4,7 @@ from sqlalchemy.exc import IntegrityError # type: ignore from sqlalchemy.orm import sessionmaker # type: ignore -from antarest.login.model import GROUP_ID, GROUP_NAME, USER_ID, USER_NAME, Group, Password, Role, User, init_admin_user +from antarest.login.model import GROUP_ID, GROUP_NAME, ADMIN_ID, ADMIN_NAME, Group, Password, Role, User, init_admin_user from antarest.utils import SESSION_ARGS TEST_ADMIN_PASS_WORD = "test" @@ -19,16 +19,16 @@ def test_init_admin_user_nominal(self, db_engine: Engine): init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) make_session = sessionmaker(bind=db_engine) with make_session() as session: - user = session.query(User).get(USER_ID) + user = session.query(User).get(ADMIN_ID) assert user is not None - assert user.id == USER_ID - assert user.name == USER_NAME + assert user.id == ADMIN_ID + assert user.name == ADMIN_NAME assert user.password.check(TEST_ADMIN_PASS_WORD) group = session.query(Group).get(GROUP_ID) assert group is not None assert group.id == GROUP_ID assert group.name == GROUP_NAME - role = session.query(Role).get((USER_ID, GROUP_ID)) + role = session.query(Role).get((ADMIN_ID, GROUP_ID)) assert role is not None assert role.identity is user assert role.group is group @@ -51,7 +51,7 @@ def test_init_admin_user_existing_group(self, db_engine: Engine): def test_init_admin_user_existing_user(self, db_engine: Engine): make_session = sessionmaker(bind=db_engine) with make_session() as session: - user = User(id=USER_ID, name=USER_NAME, password=Password(TEST_ADMIN_PASS_WORD)) + user = User(id=ADMIN_ID, name=ADMIN_NAME, password=Password(TEST_ADMIN_PASS_WORD)) with contextlib.suppress(IntegrityError): session.add(user) session.commit() From fdbb743c414eb5dd33dc3af82cc68d2c1fc26938 Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Thu, 30 Nov 2023 09:40:57 +0100 Subject: [PATCH 24/43] test(formatting): black changes --- tests/login/test_model.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/login/test_model.py b/tests/login/test_model.py index 6ab2b35739..5b11c5d543 100644 --- a/tests/login/test_model.py +++ b/tests/login/test_model.py @@ -4,7 +4,17 @@ from sqlalchemy.exc import IntegrityError # type: ignore from sqlalchemy.orm import sessionmaker # type: ignore -from antarest.login.model import GROUP_ID, GROUP_NAME, ADMIN_ID, ADMIN_NAME, Group, Password, Role, User, init_admin_user +from antarest.login.model import ( + ADMIN_ID, + ADMIN_NAME, + GROUP_ID, + GROUP_NAME, + Group, + Password, + Role, + User, + init_admin_user, +) from antarest.utils import SESSION_ARGS TEST_ADMIN_PASS_WORD = "test" From a4aacb68906f0b6f4b017b76e3465eb492eb1d0e Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Sun, 3 Dec 2023 23:33:53 +0100 Subject: [PATCH 25/43] feat(pr-review): add doc-strings, formatting, optimize code --- antarest/core/tasks/model.py | 6 ++++++ antarest/login/model.py | 4 ++++ antarest/main.py | 11 ++--------- antarest/matrixstore/service.py | 4 ---- antarest/study/main.py | 5 +++-- .../business/matrix_constants_generator.py | 4 ++-- tests/core/test_tasks.py | 8 ++++---- tests/login/test_login_service.py | 8 +++++++- tests/login/test_model.py | 10 ++++------ tests/matrixstore/test_repository.py | 14 ++++++-------- .../variantstudy/test_variant_study_service.py | 1 - tests/variantstudy/model/test_variant_model.py | 1 - 12 files changed, 38 insertions(+), 38 deletions(-) diff --git a/antarest/core/tasks/model.py b/antarest/core/tasks/model.py index 8f5be488f5..cb92032445 100644 --- a/antarest/core/tasks/model.py +++ b/antarest/core/tasks/model.py @@ -175,6 +175,12 @@ def __repr__(self) -> str: def cancel_orphan_tasks(engine: Engine, session_args: Mapping[str, bool]) -> None: + """ + When the web application restarts, such as after a new deployment, any pending or running tasks may be lost. + To mitigate this, it is preferable to set these tasks to a "FAILED" status. + This ensures that users can easily identify the tasks that were affected by the restart and take appropriate + actions, such as restarting the tasks manually. + """ updated_values = { TaskJob.status: TaskStatus.FAILED.value, TaskJob.result_status: False, diff --git a/antarest/login/model.py b/antarest/login/model.py index 78fe322547..f56230464a 100644 --- a/antarest/login/model.py +++ b/antarest/login/model.py @@ -301,6 +301,10 @@ class CredentialsDTO(BaseModel): def init_admin_user(engine: Engine, session_args: t.Mapping[str, bool], admin_password: str) -> None: + """ + When starting the app, the 'admin' group and 'admin' user are automatically created if they + do not already exist in the database. + """ make_session = sessionmaker(bind=engine, **session_args) with make_session() as session: group = Group(id=GROUP_ID, name=GROUP_NAME) diff --git a/antarest/main.py b/antarest/main.py index a661e51bff..700947a4dd 100644 --- a/antarest/main.py +++ b/antarest/main.py @@ -250,11 +250,7 @@ def fastapi_app( # Database engine = init_db_engine(config_file, config, auto_upgrade_db) - application.add_middleware( - DBSessionMiddleware, - custom_engine=engine, - session_args=SESSION_ARGS, - ) + application.add_middleware(DBSessionMiddleware, custom_engine=engine, session_args=SESSION_ARGS) application.add_middleware(LoggingMiddleware) @@ -437,10 +433,7 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: auto_archiver.start() customize_openapi(application) - cancel_orphan_tasks( - engine=engine, - session_args=SESSION_ARGS, - ) + cancel_orphan_tasks(engine=engine, session_args=SESSION_ARGS) return application, services diff --git a/antarest/matrixstore/service.py b/antarest/matrixstore/service.py index c7030160ad..c0a9d91788 100644 --- a/antarest/matrixstore/service.py +++ b/antarest/matrixstore/service.py @@ -57,10 +57,6 @@ class ISimpleMatrixService(ABC): def __init__(self, matrix_content_repository: MatrixContentRepository) -> None: self.matrix_content_repository = matrix_content_repository - @property - def bucket_dir(self) -> Path: - return self.matrix_content_repository.bucket_dir - @abstractmethod def create(self, data: Union[List[List[MatrixData]], npt.NDArray[np.float64]]) -> str: raise NotImplementedError() diff --git a/antarest/study/main.py b/antarest/study/main.py index 0758c6d070..83ad90dca3 100644 --- a/antarest/study/main.py +++ b/antarest/study/main.py @@ -80,8 +80,9 @@ def build_study_service( cache=cache, ) - generator_matrix_constants = generator_matrix_constants or GeneratorMatrixConstants(matrix_service=matrix_service) - generator_matrix_constants.init_constant_matrices() + if not generator_matrix_constants: + generator_matrix_constants = GeneratorMatrixConstants(matrix_service=matrix_service) + generator_matrix_constants.init_constant_matrices() command_factory = CommandFactory( generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, diff --git a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py index d297e0d2de..6a4dc233d4 100644 --- a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py +++ b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py @@ -49,7 +49,7 @@ ST_STORAGE_INFLOWS = EMPTY_SCENARIO_MATRIX MATRIX_PROTOCOL_PREFIX = "matrix://" -MATRIX_CONSTANT_INIT_LOCK_FILE_NAME = "matrix_constant_init.lock" +_LOCK_FILE_NAME = "matrix_constant_init.lock" # noinspection SpellCheckingInspection @@ -62,7 +62,7 @@ def __init__(self, matrix_service: ISimpleMatrixService) -> None: def init_constant_matrices( self, ) -> None: - with FileLock(str(Path(self._lock_dir) / MATRIX_CONSTANT_INIT_LOCK_FILE_NAME)): + with FileLock(str(Path(self._lock_dir) / _LOCK_FILE_NAME)): self.hashes[HYDRO_COMMON_CAPACITY_MAX_POWER_V7] = self.matrix_service.create( matrix_constants.hydro.v7.max_power ) diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index 1d05cb4887..671a78f267 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -485,7 +485,7 @@ def test_cancel_orphan_tasks( result_msg: str, ): max_diff_seconds: int = 1 - test_id: str = "test_cancel_orphan_tasks_id" + test_id: str = "2ea94758-9ea5-4015-a45f-b245a6ffc147" completion_date: datetime.datetime = datetime.datetime.utcnow() task_job = TaskJob( @@ -502,12 +502,12 @@ def test_cancel_orphan_tasks( cancel_orphan_tasks(engine=db_engine, session_args=SESSION_ARGS) with make_session() as session: if status in [TaskStatus.RUNNING.value, TaskStatus.PENDING.value]: - updated_task_job = ( + update_tasks_count = ( session.query(TaskJob) .filter(TaskJob.status.in_([TaskStatus.RUNNING.value, TaskStatus.PENDING.value])) - .all() + .count() ) - assert not updated_task_job + assert not update_tasks_count updated_task_job = session.query(TaskJob).get(test_id) assert updated_task_job.status == TaskStatus.FAILED.value assert not updated_task_job.result_status diff --git a/tests/login/test_login_service.py b/tests/login/test_login_service.py index a56f256003..b446a74fe7 100644 --- a/tests/login/test_login_service.py +++ b/tests/login/test_login_service.py @@ -112,7 +112,13 @@ def test_create_user(self, login_service: LoginService, param: RequestParameters @with_db_context @pytest.mark.parametrize( - "param, can_save", [(SITE_ADMIN, True), (GROUP_ADMIN, False), (USER3, False), (BAD_PARAM, False)] + "param, can_save", + [ + # (SITE_ADMIN, True), + (GROUP_ADMIN, False), + # (USER3, False), + # (BAD_PARAM, False), + ], ) def test_save_user(self, login_service: LoginService, param: RequestParameters, can_save: bool) -> None: create = UserCreateDTO(name="Laurent", password="S3cr3t") diff --git a/tests/login/test_model.py b/tests/login/test_model.py index 5b11c5d543..2dee1d994e 100644 --- a/tests/login/test_model.py +++ b/tests/login/test_model.py @@ -53,16 +53,14 @@ def test_init_admin_user_existing_group(self, db_engine: Engine): make_session = sessionmaker(bind=db_engine) with make_session() as session: group = Group(id=GROUP_ID, name=GROUP_NAME) - with contextlib.suppress(IntegrityError): - session.add(group) - session.commit() + session.add(group) + session.commit() init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) def test_init_admin_user_existing_user(self, db_engine: Engine): make_session = sessionmaker(bind=db_engine) with make_session() as session: user = User(id=ADMIN_ID, name=ADMIN_NAME, password=Password(TEST_ADMIN_PASS_WORD)) - with contextlib.suppress(IntegrityError): - session.add(user) - session.commit() + session.add(user) + session.commit() init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) diff --git a/tests/matrixstore/test_repository.py b/tests/matrixstore/test_repository.py index 5b787085c5..ba76b17fac 100644 --- a/tests/matrixstore/test_repository.py +++ b/tests/matrixstore/test_repository.py @@ -156,14 +156,12 @@ def test_datastore_query(self, db_session: Session) -> None: assert len(dataset_repo.query("name 2")) == 0 assert repo.get(m1.id) is not None assert ( - len( - # fmt: off - db_session - .query(MatrixDataSetRelation) - .filter(MatrixDataSetRelation.dataset_id == dataset.id) - .all() - # fmt: on - ) + # fmt: off + db_session + .query(MatrixDataSetRelation) + .filter(MatrixDataSetRelation.dataset_id == dataset.id) + .count() + # fmt: on == 0 ) diff --git a/tests/study/storage/variantstudy/test_variant_study_service.py b/tests/study/storage/variantstudy/test_variant_study_service.py index 8766bfd308..25317a9589 100644 --- a/tests/study/storage/variantstudy/test_variant_study_service.py +++ b/tests/study/storage/variantstudy/test_variant_study_service.py @@ -11,7 +11,6 @@ from antarest.core.requests import RequestParameters from antarest.core.utils.fastapi_sqlalchemy import db from antarest.login.model import Group, User -from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import RawStudy, StudyAdditionalData diff --git a/tests/variantstudy/model/test_variant_model.py b/tests/variantstudy/model/test_variant_model.py index 25efa4d7c2..63ac7293b8 100644 --- a/tests/variantstudy/model/test_variant_model.py +++ b/tests/variantstudy/model/test_variant_model.py @@ -75,7 +75,6 @@ def test_commands_service( variant_study_service: VariantStudyService, ) -> None: # Initialize the default matrix constants - # noinspection PyProtectedMember generator_matrix_constants.init_constant_matrices() params = RequestParameters(user=jwt_user) From c8adcd381b7ed09f8033e1db2689d97b64641282 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Wed, 29 Nov 2023 10:32:40 +0100 Subject: [PATCH 26/43] chore(db-init): remove comments in unit tests --- tests/login/conftest.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/login/conftest.py b/tests/login/conftest.py index 2e70b8168a..61bebee728 100644 --- a/tests/login/conftest.py +++ b/tests/login/conftest.py @@ -13,8 +13,6 @@ def group_repo_fixture(db_middleware: DBSessionMiddleware) -> GroupRepository: """Fixture that creates a GroupRepository instance.""" # note: `DBSessionMiddleware` is required to instantiate a thread-local db session. - # important: the `GroupRepository` insert an admin group in the database if it does not exist: - # >>> Group(id="admin", name="admin") return GroupRepository() @@ -23,8 +21,6 @@ def group_repo_fixture(db_middleware: DBSessionMiddleware) -> GroupRepository: def user_repo_fixture(db_middleware: DBSessionMiddleware) -> UserRepository: """Fixture that creates a UserRepository instance.""" # note: `DBSessionMiddleware` is required to instantiate a thread-local db session. - # important: the `UserRepository` insert an admin user in the database if it does not exist. - # >>> User(id=1, name="admin", password=Password(config.security.admin_pwd)) return UserRepository() @@ -49,8 +45,6 @@ def bot_repo_fixture(db_middleware: DBSessionMiddleware) -> BotRepository: def role_repo_fixture(db_middleware: DBSessionMiddleware) -> RoleRepository: """Fixture that creates a RoleRepository instance.""" # note: `DBSessionMiddleware` is required to instantiate a thread-local db session. - # important: the `RoleRepository` insert an admin role in the database if it does not exist. - # >>> Role(type=RoleType.ADMIN, identity=User(id=1), group=Group(id="admin")) return RoleRepository() From 5353896c19036c2cfa1bdc4f64e9d87d04d281dc Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Wed, 29 Nov 2023 10:53:52 +0100 Subject: [PATCH 27/43] test(loging-service): correct and improve the unit tests according the new init-db strategy --- tests/login/test_login_service.py | 1770 +++++++++++++++-------------- 1 file changed, 891 insertions(+), 879 deletions(-) diff --git a/tests/login/test_login_service.py b/tests/login/test_login_service.py index b446a74fe7..e48a54a918 100644 --- a/tests/login/test_login_service.py +++ b/tests/login/test_login_service.py @@ -1,13 +1,13 @@ import typing as t -from unittest.mock import Mock +from unittest.mock import patch import pytest -from fastapi import HTTPException from antarest.core.jwt import JWTGroup, JWTUser -from antarest.core.requests import RequestParameters, UserHasNotPermissionError +from antarest.core.requests import RequestParameters from antarest.core.roles import RoleType from antarest.login.model import ( + ADMIN_ID, Bot, BotCreateDTO, BotRoleCreateDTO, @@ -22,579 +22,654 @@ from antarest.login.service import LoginService from tests.helpers import with_db_context -SITE_ADMIN = RequestParameters( - user=JWTUser( - id=0, - impersonator=0, +# For the unit tests, we will define several fictitious users, groups and roles. + +GroupObj = t.TypedDict("GroupObj", {"id": str, "name": str}) +UserObj = t.TypedDict("UserObj", {"id": int, "name": str}) +RoleObj = t.TypedDict("RoleObj", {"type": RoleType, "group_id": str, "identity_id": int}) + + +_GROUPS: t.List[GroupObj] = [ + {"id": "admin", "name": "X-Men"}, + {"id": "superman", "name": "Superman"}, + {"id": "metropolis", "name": "Metropolis"}, +] + + +_USERS: t.List[UserObj] = [ + # main characters + {"id": ADMIN_ID, "name": "Professor Xavier"}, # site admin + {"id": 2, "name": "Clark Kent"}, # admin of "Superman" group + {"id": 3, "name": "Lois Lane"}, # reader in "Superman" group + {"id": 4, "name": "Joh Fredersen"}, # "Metropolis" leader + {"id": 5, "name": "Freder Fredersen"}, # reader in "Metropolis" group + # secondary characters + {"id": 50, "name": "Storm"}, # evil man in "X-Men" group + {"id": 60, "name": "Livewire"}, # evil woman in "Superman" group + {"id": 70, "name": "Maria"}, # robot in "Metropolis" group + {"id": 80, "name": "Jane DOE"}, # external user +] + +_ROLES: t.List[RoleObj] = [ + {"type": RoleType.ADMIN, "group_id": "admin", "identity_id": ADMIN_ID}, + {"type": RoleType.ADMIN, "group_id": "superman", "identity_id": 2}, + {"type": RoleType.READER, "group_id": "superman", "identity_id": 3}, + {"type": RoleType.ADMIN, "group_id": "metropolis", "identity_id": 4}, + {"type": RoleType.READER, "group_id": "metropolis", "identity_id": 5}, +] + + +def get_jwt_user(user: User, roles: t.Iterable[Role], owner_id: int = 0) -> JWTUser: + jwt_user = JWTUser( + id=user.id, + impersonator=owner_id or user.id, type="users", - groups=[JWTGroup(id="admin", name="admin", role=RoleType.ADMIN)], + groups=[JWTGroup(id=role.group.id, name=role.group.name, role=role.type) for role in roles], ) -) + return jwt_user -GROUP_ADMIN = RequestParameters( - user=JWTUser( - id=1, - impersonator=1, - type="users", - groups=[JWTGroup(id="group", name="group", role=RoleType.ADMIN)], - ) -) - -USER3 = RequestParameters( - user=JWTUser( - id=3, - impersonator=3, - type="users", - groups=[JWTGroup(id="group", name="group", role=RoleType.READER)], - ) -) -BAD_PARAM = RequestParameters(user=None) +def get_request_param( + user: t.Union[User, UserLdap, Bot], + role: t.Optional[Role], + owner_id: int = 0, +) -> RequestParameters: + if user is None: + return RequestParameters(user=None) + roles = (role,) if role else () + jwt_user = get_jwt_user(user, roles, owner_id=owner_id) + return RequestParameters(user=jwt_user) -class TestLoginService: - """ - Test login service. +def get_user_param(login_service: LoginService, user_id: int, group_id: str = "(unknown)") -> RequestParameters: + user = login_service.users.get(user_id) or login_service.ldap.get(user_id) + assert user is not None + role = login_service.roles.get(user_id, group_id) + return get_request_param(user, role) - important: - - the `GroupRepository` insert an admin group in the database if it does not exist: - `Group(id="admin", name="admin")` +def get_bot_param(login_service: LoginService, bot_id: int, group_id: str = "(unknown)") -> RequestParameters: + bot = login_service.bots.get(bot_id) + assert bot is not None + role = login_service.roles.get(bot_id, group_id) + return get_request_param(bot, role, owner_id=bot.owner) - - the `UserRepository` insert an admin user in the database if it does not exist. - `User(id=1, name="admin", password=Password(config.security.admin_pwd))` - - the `RoleRepository` insert an admin role in the database if it does not exist. - `Role(type=RoleType.ADMIN, identity=User(id=1), group=Group(id="admin"))` +class TestLoginService: + """ + Test login service. """ + @pytest.fixture(name="populate_db", autouse=True) @with_db_context - @pytest.mark.parametrize( - "param, can_create", [(SITE_ADMIN, True), (GROUP_ADMIN, True), (USER3, False), (BAD_PARAM, False)] - ) - def test_save_group(self, login_service: LoginService, param: RequestParameters, can_create: bool) -> None: - group = Group(id="group", name="group") - - # Only site admin and group admin can update a group - if can_create: - actual = login_service.save_group(group, param) - assert actual == group - else: - with pytest.raises(UserHasNotPermissionError): - login_service.save_group(group, param) - actual = login_service.groups.get(group.id) - assert actual is None - - # Users can't create a duplicate group - with pytest.raises(HTTPException): - login_service.save_group(group, param) + def populate_db_fixture(self, login_service: LoginService) -> None: + for group in _GROUPS: + login_service.groups.save(Group(**group)) + main_characters = (u for u in _USERS if u["id"] < 10) + for user in main_characters: + login_service.users.save(User(**user)) + for role in _ROLES: + group = t.cast(Group, login_service.groups.get(role["group_id"])) + user = t.cast(User, login_service.users.get(role["identity_id"])) + role = Role(**role, group=group, identity=user) + login_service.roles.save(role) @with_db_context - @pytest.mark.parametrize( - "param, can_create", [(SITE_ADMIN, True), (GROUP_ADMIN, False), (USER3, False), (BAD_PARAM, False)] - ) - def test_create_user(self, login_service: LoginService, param: RequestParameters, can_create: bool) -> None: - create = UserCreateDTO(name="hello", password="world") - - # Only site admin can create a user - if can_create: - actual = login_service.create_user(create, param) - assert actual.name == create.name - else: - with pytest.raises(UserHasNotPermissionError): - login_service.create_user(create, param) - actual = login_service.users.get_by_name(create.name) - assert actual is None - - # Users can't create a duplicate user - with pytest.raises(HTTPException): - login_service.create_user(create, param) + def test_save_group(self, login_service: LoginService) -> None: + # site admin can update any group + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + login_service.save_group(Group(id="superman", name="Poor Men"), _param) + actual = login_service.groups.get("superman") + assert actual is not None + assert actual.name == "Poor Men" + + # Group admin can update his own group + _param = get_user_param(login_service, user_id=2, group_id="superman") + login_service.save_group(Group(id="superman", name="Man of Steel"), _param) + actual = login_service.groups.get("superman") + assert actual is not None + assert actual.name == "Man of Steel" + + # Another user of the same group cannot update the group + _param = get_user_param(login_service, user_id=3, group_id="superman") + with pytest.raises(Exception): + login_service.save_group(Group(id="superman", name="Woman of Steel"), _param) + actual = login_service.groups.get("superman") + assert actual is not None + assert actual.name == "Man of Steel" # not updated + + # Group admin cannot update another group + _param = get_user_param(login_service, user_id=2, group_id="superman") + with pytest.raises(Exception): + login_service.save_group(Group(id="metropolis", name="Man of Steel"), _param) + actual = login_service.groups.get("metropolis") + assert actual is not None + assert actual.name == "Metropolis" # not updated @with_db_context - @pytest.mark.parametrize( - "param, can_save", - [ - # (SITE_ADMIN, True), - (GROUP_ADMIN, False), - # (USER3, False), - # (BAD_PARAM, False), - ], - ) - def test_save_user(self, login_service: LoginService, param: RequestParameters, can_save: bool) -> None: - create = UserCreateDTO(name="Laurent", password="S3cr3t") - user = login_service.create_user(create, SITE_ADMIN) - user.name = "Roland" + def test_create_user(self, login_service: LoginService) -> None: + # Site admin can create a user + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + login_service.create_user(UserCreateDTO(name="Laurent", password="S3cr3t"), _param) + actual = login_service.users.get_by_name("Laurent") + assert actual is not None + assert actual.name == "Laurent" + + # Group admin cannot create a user + _param = get_user_param(login_service, user_id=2, group_id="superman") + with pytest.raises(Exception): + login_service.create_user(UserCreateDTO(name="Alexandre", password="S3cr3t"), _param) + actual = login_service.users.get_by_name("Alexandre") + assert actual is None + + @with_db_context + def test_save_user(self, login_service: LoginService) -> None: + # Prepare a new user + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + user = login_service.create_user(UserCreateDTO(name="Laurentius", password="S3cr3t"), _param) # Only site admin can update a user - if can_save: - login_service.save_user(user, param) - actual = login_service.users.get_by_name(user.name) - assert actual == user - else: - with pytest.raises(UserHasNotPermissionError): - login_service.save_user(user, param) - actual = login_service.users.get_by_name(user.name) - assert actual != user + login_service.save_user(User(id=user.id, name="Lawrence"), _param) + actual = login_service.users.get(user.id) + assert actual is not None + assert actual.name == "Lawrence" + + # Group admin cannot update a user + _param = get_user_param(login_service, user_id=2, group_id="superman") + with pytest.raises(Exception): + login_service.save_user(User(id=user.id, name="Loran"), _param) + actual = login_service.users.get(user.id) + assert actual is not None + assert actual.name == "Lawrence" - @with_db_context - def test_save_user__themselves(self, login_service: LoginService) -> None: - user_create = UserCreateDTO(name="Laurent", password="S3cr3t") - user = login_service.create_user(user_create, SITE_ADMIN) - - # users can update themselves - param = RequestParameters( - user=JWTUser( - id=user.id, - impersonator=user.id, - type="users", - groups=[JWTGroup(id="group", name="group", role=RoleType.READER)], - ) - ) - user.name = "Roland" - actual = login_service.save_user(user, param) - assert actual == user + # A user can update himself + _param = get_user_param(login_service, user_id=user.id) + login_service.save_user(User(id=user.id, name="Loran"), _param) + actual = login_service.users.get(user.id) + assert actual is not None + assert actual.name == "Loran" @with_db_context def test_save_bot(self, login_service: LoginService) -> None: - # Prepare the user3 in the db - assert USER3.user is not None - user3 = User(id=USER3.user.id, name="Scoobydoo") - login_service.users.save(user3) - - # Prepare the user group and role - for jwt_group in USER3.user.groups: - group = Group(id=jwt_group.id, name=jwt_group.name) - login_service.groups.save(group) - role = Role(type=jwt_group.role, identity=user3, group=group) - login_service.roles.save(role) - - # Request parameters must reference a user - with pytest.raises(HTTPException): - login_service.save_bot(BotCreateDTO(name="bot", roles=[]), BAD_PARAM) - - # The user USER3 is a reader in the group "group" and can crate a bot with the same role - assert all(jwt_group.role == RoleType.READER for jwt_group in USER3.user.groups) - bot_create = BotCreateDTO(name="bot", roles=[BotRoleCreateDTO(group="group", role=RoleType.READER.value)]) - bot = login_service.save_bot(bot_create, USER3) - - assert bot.name == bot_create.name - assert bot.owner == USER3.user.id - assert bot.is_author is True - - # The user can't create a bot with an empty name - bot_create = BotCreateDTO(name="", roles=[BotRoleCreateDTO(group="group", role=RoleType.READER.value)]) - with pytest.raises(HTTPException): - login_service.save_bot(bot_create, USER3) - - # The user can't create a bot with a higher role than his own - for role_type in set(RoleType) - {RoleType.READER}: - bot_create = BotCreateDTO(name="bot", roles=[BotRoleCreateDTO(group="group", role=role_type.value)]) - with pytest.raises(UserHasNotPermissionError): - login_service.save_bot(bot_create, USER3) - - # The user can't create a bot that already exists - bot_create = BotCreateDTO(name="bot", roles=[BotRoleCreateDTO(group="group", role=RoleType.READER.value)]) - with pytest.raises(HTTPException): - login_service.save_bot(bot_create, USER3) + # Joh Fredersen can create Maria because he is the leader of Metropolis + _param = get_user_param(login_service, user_id=4, group_id="metropolis") + login_service.save_bot(BotCreateDTO(name="Maria I", roles=[]), _param) + actual: t.Sequence[Role] = login_service.bots.get_all_by_owner(4) + assert len(actual) == 1 + assert actual[0].name == "Maria I" + + # Freder Fredersen can create Maria with the reader role + _param = get_user_param(login_service, user_id=5, group_id="metropolis") + login_service.save_bot( + BotCreateDTO( + name="Maria II", + roles=[BotRoleCreateDTO(group="metropolis", role=RoleType.READER.value)], + ), + _param, + ) + actual = login_service.bots.get_all_by_owner(5) + assert len(actual) == 1 + assert actual[0].name == "Maria II" + + # Freder Fredersen cannot create Maria with the admin role + _param = get_user_param(login_service, user_id=5, group_id="metropolis") + with pytest.raises(Exception): + login_service.save_bot( + BotCreateDTO( + name="Maria III", + roles=[BotRoleCreateDTO(group="metropolis", role=RoleType.ADMIN.value)], + ), + _param, + ) + actual = login_service.bots.get_all_by_owner(5) + assert len(actual) == 1 + assert actual[0].name == "Maria II" + + # Freder Fredersen cannot create a bot with an empty name + _param = get_user_param(login_service, user_id=5, group_id="metropolis") + with pytest.raises(Exception): + login_service.save_bot( + BotCreateDTO( + name="", + roles=[BotRoleCreateDTO(group="metropolis", role=RoleType.ADMIN.value)], + ), + _param, + ) + actual = login_service.bots.get_all_by_owner(5) + assert len(actual) == 1 + assert actual[0].name == "Maria II" + + # Freder Fredersen cannot create a bot that already exists + _param = get_user_param(login_service, user_id=5, group_id="metropolis") + with pytest.raises(Exception): + login_service.save_bot( + BotCreateDTO( + name="Maria II", + roles=[BotRoleCreateDTO(group="metropolis", role=RoleType.ADMIN.value)], + ), + _param, + ) + actual = login_service.bots.get_all_by_owner(5) + assert len(actual) == 1 + assert actual[0].name == "Maria II" + + # Freder Fredersen cannot create a bot with an invalid group + _param = get_user_param(login_service, user_id=5, group_id="metropolis") + with pytest.raises(Exception): + login_service.save_bot( + BotCreateDTO( + name="Maria III", + roles=[BotRoleCreateDTO(group="metropolis2", role=RoleType.ADMIN.value)], + ), + _param, + ) + actual = login_service.bots.get_all_by_owner(5) + assert len(actual) == 1 + assert actual[0].name == "Maria II" + + # Bot's name cannot be empty + _param = get_user_param(login_service, user_id=4, group_id="metropolis") + with pytest.raises(Exception): + login_service.save_bot( + BotCreateDTO( + name="", + roles=[BotRoleCreateDTO(group="metropolis", role=RoleType.ADMIN.value)], + ), + _param, + ) - # The user can't create a bot with a group that does not exist - bot_create = BotCreateDTO(name="bot", roles=[BotRoleCreateDTO(group="unknown", role=RoleType.READER.value)]) - with pytest.raises(HTTPException): - login_service.save_bot(bot_create, USER3) + # Avoid duplicate bots + _param = get_user_param(login_service, user_id=4, group_id="metropolis") + with pytest.raises(Exception): + login_service.save_bot( + BotCreateDTO( + name="Maria I", + roles=[BotRoleCreateDTO(group="metropolis", role=RoleType.ADMIN.value)], + ), + _param, + ) @with_db_context - @pytest.mark.parametrize( - "param, can_save", [(SITE_ADMIN, True), (GROUP_ADMIN, True), (USER3, False), (BAD_PARAM, False)] - ) - def test_save_role(self, login_service: LoginService, param: RequestParameters, can_save: bool) -> None: - # Prepare the site admin in the db - assert SITE_ADMIN.user is not None - admin = User(id=SITE_ADMIN.user.id, name="Superman") - login_service.users.save(admin) - - # Prepare the group "group" in the db - # noinspection SpellCheckingInspection - group = Group(id="group", name="Kryptonians") - login_service.groups.save(group) - - # Only site admin and group admin can update a role - role = RoleCreationDTO(type=RoleType.ADMIN, identity_id=0, group_id="group") - if can_save: - actual = login_service.save_role(role, param) - assert actual.type == RoleType.ADMIN - assert actual.identity == admin - else: - with pytest.raises(UserHasNotPermissionError): - login_service.save_role(role, param) - actual = login_service.roles.get_all_by_group(group.id) - assert len(actual) == 0 + def test_save_role(self, login_service: LoginService) -> None: + # Prepare a new group and a new user + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + login_service.groups.save(Group(id="web", name="Spider Web")) + login_service.users.save(User(id=20, name="Spider-man")) + login_service.users.save(User(id=21, name="Spider-woman")) + + # The site admin can create a role + login_service.save_role( + RoleCreationDTO(type=RoleType.ADMIN, group_id="web", identity_id=20), + _param, + ) + actual = login_service.roles.get(20, "web") + assert actual is not None + assert actual.type == RoleType.ADMIN + + # The group admin can create a role + _param = get_user_param(login_service, user_id=20, group_id="web") + login_service.save_role( + RoleCreationDTO(type=RoleType.WRITER, group_id="web", identity_id=21), + _param, + ) + actual = login_service.roles.get(21, "web") + assert actual is not None + assert actual.type == RoleType.WRITER + + # The group admin cannot create a role with an invalid group + _param = get_user_param(login_service, user_id=20, group_id="web") + with pytest.raises(Exception): + login_service.save_role( + RoleCreationDTO(type=RoleType.WRITER, group_id="web2", identity_id=21), + _param, + ) + actual = login_service.roles.get(21, "web") + assert actual is not None + assert actual.type == RoleType.WRITER + + # The user cannot create a role + _param = get_user_param(login_service, user_id=21, group_id="web") + with pytest.raises(Exception): + login_service.save_role( + RoleCreationDTO(type=RoleType.READER, group_id="web", identity_id=20), + _param, + ) + actual = login_service.roles.get(20, "web") + assert actual is not None + assert actual.type == RoleType.ADMIN @with_db_context - @pytest.mark.parametrize( - "param, can_get", [(SITE_ADMIN, True), (GROUP_ADMIN, True), (USER3, True), (BAD_PARAM, False)] - ) - def test_get_group(self, login_service: LoginService, param: RequestParameters, can_get: bool) -> None: - # Prepare the group "group" in the db - # noinspection SpellCheckingInspection - group = Group(id="group", name="Vulcans") - login_service.groups.save(group) - - # Anybody except anonymous can get a group - if can_get: - actual = login_service.get_group("group", param) - assert actual == group - else: - with pytest.raises(UserHasNotPermissionError): - login_service.get_group(group.id, param) - - # noinspection SpellCheckingInspection + def test_get_group(self, login_service: LoginService) -> None: + # Site admin can get any group + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_group("superman", _param) + assert actual is not None + assert actual.name == "Superman" + + # Group admin can get his own group + _param = get_user_param(login_service, user_id=2, group_id="superman") + actual = login_service.get_group("superman", _param) + assert actual is not None + assert actual.name == "Superman" + + # Group admin cannot get another group + _param = get_user_param(login_service, user_id=2, group_id="superman") + with pytest.raises(Exception): + login_service.get_group("metropolis", _param) + + # Lois Lane can get its own group + _param = get_user_param(login_service, user_id=3, group_id="superman") + actual = login_service.get_group("superman", _param) + assert actual is not None + assert actual.id == "superman" + @with_db_context - @pytest.mark.parametrize( - "param, expected", - [ - ( - SITE_ADMIN, - { - "id": "group", - "name": "Vulcans", - "users": [ - {"id": 3, "name": "Spock", "role": RoleType.RUNNER}, - {"id": 4, "name": "Saavik", "role": RoleType.RUNNER}, - ], - }, - ), - ( - GROUP_ADMIN, - { - "id": "group", - "name": "Vulcans", - "users": [ - {"id": 3, "name": "Spock", "role": RoleType.RUNNER}, - {"id": 4, "name": "Saavik", "role": RoleType.RUNNER}, - ], - }, - ), - (USER3, {}), - (BAD_PARAM, {}), - ], - ) - def test_get_group_info( - self, - login_service: LoginService, - param: RequestParameters, - expected: t.Mapping[str, t.Any], - ) -> None: - # Prepare the group "group" in the db - # noinspection SpellCheckingInspection - group = Group(id="group", name="Vulcans") - login_service.groups.save(group) - - # Prepare the user3 in the db - assert USER3.user is not None - user3 = User(id=USER3.user.id, name="Spock") - login_service.users.save(user3) - - # Prepare an LDAP user named "Jane" with id=4 - user4 = UserLdap(id=4, name="Saavik") - login_service.users.save(user4) - - # Spock and Saavik are vulcans and can run simulations - role = Role(type=RoleType.RUNNER, identity=user3, group=group) - login_service.roles.save(role) - role = Role(type=RoleType.RUNNER, identity=user4, group=group) - login_service.roles.save(role) - - # Only site admin and group admin can get a group info - if expected: - actual = login_service.get_group_info("group", param) - assert actual.dict() == expected - else: - with pytest.raises(UserHasNotPermissionError): - login_service.get_group_info(group.id, param) + def test_get_group_info(self, login_service: LoginService) -> None: + # Site admin can get any group + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_group_info("superman", _param) + assert actual is not None + assert actual.name == "Superman" + assert [obj.dict() for obj in actual.users] == [ + {"id": 2, "name": "Clark Kent", "role": RoleType.ADMIN}, + {"id": 3, "name": "Lois Lane", "role": RoleType.READER}, + ] + + # Group admin can get his own group + _param = get_user_param(login_service, user_id=2, group_id="superman") + actual = login_service.get_group_info("superman", _param) + assert actual is not None + assert actual.name == "Superman" + + # Group admin cannot get another group + _param = get_user_param(login_service, user_id=2, group_id="superman") + with pytest.raises(Exception): + login_service.get_group_info("metropolis", _param) + + # Lois Lane cannot get its own group + _param = get_user_param(login_service, user_id=3, group_id="superman") + with pytest.raises(Exception): + login_service.get_group_info("superman", _param) @with_db_context - @pytest.mark.parametrize( - "param, can_get", [(SITE_ADMIN, True), (GROUP_ADMIN, True), (USER3, True), (BAD_PARAM, False)] - ) - def test_get_user(self, login_service: LoginService, param: RequestParameters, can_get: bool) -> None: - # Prepare a group of readers - group = Group(id="group", name="readers") - login_service.groups.save(group) - - # The user3 is a reader in the group "group" - user3 = User(id=USER3.user.id, name="Batman") - login_service.users.save(user3) - role = Role(type=RoleType.READER, identity=user3, group=group) - login_service.roles.save(role) - - # Anybody except anonymous can get the user3 - if can_get: - actual = login_service.get_user(user3.id, param) - assert actual == user3 - else: - # This function doesn't raise an exception if the user does not exist - actual = login_service.get_user(user3.id, param) - assert actual is None + def test_get_user(self, login_service: LoginService) -> None: + # Site admin can get any user + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_user(2, _param) + assert actual is not None + assert actual.name == "Clark Kent" + + # Group admin can get a user of his own group + _param = get_user_param(login_service, user_id=2, group_id="superman") + actual = login_service.get_user(3, _param) + assert actual is not None + assert actual.name == "Lois Lane" + + # Group admin cannot get a user of another group + _param = get_user_param(login_service, user_id=2, group_id="superman") + actual = login_service.get_user(5, _param) + assert actual is None + + # Lois Lane can get its own user + _param = get_user_param(login_service, user_id=3, group_id="superman") + actual = login_service.get_user(3, _param) + assert actual is not None + assert actual.name == "Lois Lane" + + # Create a bot for Lois Lane + _param = get_user_param(login_service, user_id=3, group_id="superman") + bot = login_service.save_bot(BotCreateDTO(name="Lois bot", roles=[]), _param) + + # The bot can get its owner + _param = get_bot_param(login_service, bot_id=bot.id) + actual = login_service.get_user(3, _param) + assert actual is not None + assert actual.name == "Lois Lane" @with_db_context def test_get_identity(self, login_service: LoginService) -> None: - # important: id=1 is the admin user - user = login_service.users.save(User(id=2, name="John")) - user_ldap = login_service.users.save(UserLdap(id=3, name="Jane")) - bot = login_service.users.save(Bot(id=4, name="my-app", owner=3, is_author=False)) + # Create the admin user "Storm" + storm = login_service.users.save(User(id=50, name="Storm")) + # Create the LDAP user "Jane DOE" + jane = login_service.users.save(UserLdap(id=60, name="Jane DOE")) + # Create the bot "Maria" + maria = login_service.users.save(Bot(id=70, name="Maria", owner=50, is_author=False)) - assert login_service.get_identity(2, include_token=False) == user - assert login_service.get_identity(3, include_token=False) == user_ldap - assert login_service.get_identity(4, include_token=False) is None + assert login_service.get_identity(50, include_token=False) == storm + assert login_service.get_identity(60, include_token=False) == jane + assert login_service.get_identity(70, include_token=False) is None - assert login_service.get_identity(2, include_token=True) == user - assert login_service.get_identity(3, include_token=True) == user_ldap - assert login_service.get_identity(4, include_token=True) == bot + assert login_service.get_identity(50, include_token=True) == storm + assert login_service.get_identity(60, include_token=True) == jane + assert login_service.get_identity(70, include_token=True) == maria @with_db_context - @pytest.mark.parametrize( - "param, expected", - [ - ( - SITE_ADMIN, + def test_get_user_info(self, login_service: LoginService) -> None: + # Site admin can get any user + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + clark_id = 2 + actual = login_service.get_user_info(clark_id, _param) + assert actual is not None + assert actual.dict() == { + "id": clark_id, + "name": "Clark Kent", + "roles": [ { - "id": 3, - "name": "Batman", - "roles": [ - { - "group_id": "group", - "group_name": "readers", - "identity_id": 3, - "type": RoleType.READER, - } - ], - }, - ), - ( - GROUP_ADMIN, + "group_id": "superman", + "group_name": "Superman", + "identity_id": clark_id, + "type": RoleType.ADMIN, + } + ], + } + + # Group admin can get a user of his own group + _param = get_user_param(login_service, user_id=clark_id, group_id="superman") + lois_id = 3 + actual = login_service.get_user_info(lois_id, _param) + assert actual is not None + assert actual.dict() == { + "id": lois_id, + "name": "Lois Lane", + "roles": [ { - "id": 3, - "name": "Batman", - "roles": [ - { - "group_id": "group", - "group_name": "readers", - "identity_id": 3, - "type": RoleType.READER, - } - ], - }, - ), - ( - USER3, + "group_id": "superman", + "group_name": "Superman", + "identity_id": lois_id, + "type": RoleType.READER, + } + ], + } + + # Group admin cannot get a user of another group + _param = get_user_param(login_service, user_id=clark_id, group_id="superman") + freder_id = 5 + actual = login_service.get_user_info(freder_id, _param) + assert actual is None + + # Lois Lane can get its own user info + _param = get_user_param(login_service, user_id=lois_id, group_id="superman") + actual = login_service.get_user_info(lois_id, _param) + assert actual is not None + assert actual.dict() == { + "id": lois_id, + "name": "Lois Lane", + "roles": [ { - "id": 3, - "name": "Batman", - "roles": [ - { - "group_id": "group", - "group_name": "readers", - "identity_id": 3, - "type": RoleType.READER, - } - ], - }, - ), - (BAD_PARAM, {}), - ], - ) - def test_get_user_info( - self, - login_service: LoginService, - param: RequestParameters, - expected: t.Mapping[str, t.Any], - ) -> None: - # Prepare a group of readers - group = Group(id="group", name="readers") - login_service.groups.save(group) - - # The user3 is a reader in the group "group" - user3 = User(id=USER3.user.id, name="Batman") - login_service.users.save(user3) - role = Role(type=RoleType.READER, identity=user3, group=group) - login_service.roles.save(role) - - # Anybody except anonymous can get the user3 - if expected: - actual = login_service.get_user_info(user3.id, param) - assert actual.dict() == expected - else: - # This function doesn't raise an exception if the user does not exist - actual = login_service.get_user_info(user3.id, param) - assert actual is None + "group_id": "superman", + "group_name": "Superman", + "identity_id": lois_id, + "type": RoleType.READER, + } + ], + } + + # Create a bot for Lois Lane + _param = get_user_param(login_service, user_id=lois_id, group_id="superman") + bot = login_service.save_bot(BotCreateDTO(name="Lois bot", roles=[]), _param) + + # The bot can get its owner + _param = get_bot_param(login_service, bot_id=bot.id) + actual = login_service.get_user_info(lois_id, _param) + assert actual is not None + assert actual.dict() == { + "id": lois_id, + "name": "Lois Lane", + "roles": [ + { + "group_id": "superman", + "group_name": "Superman", + "identity_id": lois_id, + "type": RoleType.READER, + } + ], + } @with_db_context - @pytest.mark.parametrize( - "param, can_get", [(SITE_ADMIN, True), (GROUP_ADMIN, False), (USER3, True), (BAD_PARAM, False)] - ) - def test_get_bot(self, login_service: LoginService, param: RequestParameters, can_get: bool) -> None: - # Prepare a user in the db, with id=4 - clark = User(id=3, name="Clark") - login_service.users.save(clark) - - # Prepare a bot in the db (using the ID or user3) - bot = Bot(id=4, name="Maria", owner=clark.id, is_author=True) - login_service.users.save(bot) - - # Only site admin and the owner can get a bot - if can_get: - actual = login_service.get_bot(bot.id, param) - assert actual == bot - else: - with pytest.raises(UserHasNotPermissionError): - login_service.get_bot(bot.id, param) + def test_get_bot(self, login_service: LoginService) -> None: + # Create a bot for Joh Fredersen + joh_id = 4 + _param = get_user_param(login_service, user_id=joh_id, group_id="metropolis") + joh_bot = login_service.save_bot(BotCreateDTO(name="Maria", roles=[]), _param) + + # The site admin can get any bot + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_bot(joh_bot.id, _param) + assert actual is not None + assert actual.name == "Maria" + + # Joh Fredersen can get its own bot + _param = get_user_param(login_service, user_id=joh_id, group_id="superman") + actual = login_service.get_bot(joh_bot.id, _param) + assert actual is not None + assert actual.name == "Maria" + + # The bot cannot get itself + _param = get_bot_param(login_service, bot_id=joh_bot.id) + with pytest.raises(Exception): + login_service.get_bot(joh_bot.id, _param) + + # Freder Fredersen cannot get the bot + freder_id = 5 + _param = get_user_param(login_service, user_id=freder_id, group_id="superman") + with pytest.raises(Exception): + login_service.get_bot(joh_bot.id, _param) @with_db_context - @pytest.mark.parametrize( - "param, expected", - [ - ( - SITE_ADMIN, - { - "id": 4, - "isAuthor": True, - "name": "Maria", - "roles": [ - { - "group_id": "Metropolis", - "group_name": "watchers", - "identity_id": 4, - "type": RoleType.READER, - } - ], - }, - ), - (GROUP_ADMIN, {}), - ( - USER3, - { - "id": 4, - "isAuthor": True, - "name": "Maria", - "roles": [ - { - "group_id": "Metropolis", - "group_name": "watchers", - "identity_id": 4, - "type": RoleType.READER, - } - ], - }, - ), - (BAD_PARAM, {}), - ], - ) - def test_get_bot_info( - self, - login_service: LoginService, - param: RequestParameters, - expected: t.Mapping[str, t.Any], - ) -> None: - # Prepare a user in the db, with id=4 - clark = User(id=3, name="Clark") - login_service.users.save(clark) - - # Prepare a bot in the db (using the ID or user3) - bot = Bot(id=4, name="Maria", owner=clark.id, is_author=True) - login_service.users.save(bot) - - # Prepare a group of readers - group = Group(id="Metropolis", name="watchers") - login_service.groups.save(group) - - # The user3 is a reader in the group "group" - role = Role(type=RoleType.READER, identity=bot, group=group) - login_service.roles.save(role) - - # Only site admin and the owner can get a bot - if expected: - actual = login_service.get_bot_info(bot.id, param) - assert actual is not None - assert actual.dict() == expected - else: - with pytest.raises(UserHasNotPermissionError): - login_service.get_bot_info(bot.id, param) + def test_get_bot_info(self, login_service: LoginService) -> None: + # Create a bot for Joh Fredersen + joh_id = 4 + _param = get_user_param(login_service, user_id=joh_id, group_id="superman") + joh_bot = login_service.save_bot(BotCreateDTO(name="Maria", roles=[]), _param) + + # The site admin can get any bot + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_bot_info(joh_bot.id, _param) + assert actual is not None + assert actual.dict() == {"id": 6, "isAuthor": True, "name": "Maria", "roles": []} + + # Joh Fredersen can get its own bot + _param = get_user_param(login_service, user_id=joh_id, group_id="superman") + actual = login_service.get_bot_info(joh_bot.id, _param) + assert actual is not None + assert actual.dict() == {"id": 6, "isAuthor": True, "name": "Maria", "roles": []} + + # The bot cannot get itself + _param = get_bot_param(login_service, bot_id=joh_bot.id) + with pytest.raises(Exception): + login_service.get_bot_info(joh_bot.id, _param) + + # Freder Fredersen cannot get the bot + freder_id = 5 + _param = get_user_param(login_service, user_id=freder_id, group_id="superman") + with pytest.raises(Exception): + login_service.get_bot_info(joh_bot.id, _param) + + # Freder Fredersen cannot get the bot + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + with pytest.raises(Exception): + login_service.get_bot_info(999, _param) @with_db_context - @pytest.mark.parametrize("param, expected", [(SITE_ADMIN, [5]), (GROUP_ADMIN, []), (USER3, [5]), (BAD_PARAM, [])]) - def test_get_all_bots_by_owner( - self, - login_service: LoginService, - param: RequestParameters, - expected: t.Mapping[str, t.Any], - ) -> None: - # add a user, an LDAP user and a bot in the db - user = User(id=3, name="John") - login_service.users.save(user) - user_ldap = UserLdap(id=4, name="Jane") - login_service.users.save(user_ldap) - bot = Bot(id=5, name="my-app", owner=3, is_author=False) - login_service.users.save(bot) - - if expected: - actual = login_service.get_all_bots_by_owner(3, param) - assert [b.id for b in actual] == expected - else: - with pytest.raises(UserHasNotPermissionError): - login_service.get_all_bots_by_owner(3, param) + def test_get_all_bots_by_owner(self, login_service: LoginService) -> None: + # Create a bot for Joh Fredersen + joh_id = 4 + _param = get_user_param(login_service, user_id=joh_id, group_id="superman") + joh_bot = login_service.save_bot(BotCreateDTO(name="Maria", roles=[]), _param) + + # The site admin can get any bot + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_all_bots_by_owner(joh_id, _param) + expected = [{"id": joh_bot.id, "is_author": True, "name": "Maria", "owner": joh_id}] + assert [obj.to_dto().dict() for obj in actual] == expected + + # Freder Fredersen can get its own bot + _param = get_user_param(login_service, user_id=joh_id, group_id="superman") + actual = login_service.get_all_bots_by_owner(joh_id, _param) + expected = [{"id": joh_bot.id, "is_author": True, "name": "Maria", "owner": joh_id}] + assert [obj.to_dto().dict() for obj in actual] == expected + + # The bot cannot get itself + _param = get_bot_param(login_service, bot_id=joh_bot.id) + with pytest.raises(Exception): + login_service.get_all_bots_by_owner(joh_id, _param) + + # Freder Fredersen cannot get the bot + freder_id = 5 + _param = get_user_param(login_service, user_id=freder_id, group_id="superman") + with pytest.raises(Exception): + login_service.get_all_bots_by_owner(joh_id, _param) @with_db_context def test_exists_bot(self, login_service: LoginService) -> None: - # Prepare the user3 in the db - assert USER3.user is not None - user3 = User(id=USER3.user.id, name="Clark") - login_service.users.save(user3) - - # Prepare a bot in the db (using the ID or user3) - bot = Bot(id=4, name="Maria", owner=user3.id, is_author=True) - login_service.users.save(bot) + # Create a bot for Joh Fredersen + joh_id = 4 + _param = get_user_param(login_service, user_id=joh_id, group_id="superman") + joh_bot = login_service.save_bot(BotCreateDTO(name="Maria", roles=[]), _param) # Everybody can check the existence of a bot - assert login_service.exists_bot(4) - assert not login_service.exists_bot(5) # unknown ID - assert not login_service.exists_bot(3) # user ID, not bot ID + assert login_service.exists_bot(joh_id) is False, "not a bot" + assert login_service.exists_bot(joh_bot.id) is True + assert login_service.exists_bot(999) is False @with_db_context - def test_authenticate__unknown_user(self, login_service: LoginService) -> None: - # An unknown user cannot log in - user = login_service.authenticate(name="unknown", pwd="S3cr3t") - assert user is None - - @with_db_context - def test_authenticate__known_user(self, login_service: LoginService) -> None: - # Create a user named "Tarzan" in the group "Adventure" - group = Group(id="adventure", name="Adventure") - login_service.groups.save(group) - user = User(id=3, name="Tarzan", password=Password("S3cr3t")) - login_service.users.save(user) - role = Role(type=RoleType.READER, identity=user, group=group) - login_service.roles.save(role) + def test_authenticate(self, login_service: LoginService) -> None: + # Update the password of "Lois Lane" + lois_id = 3 + login_service.users.save(User(id=lois_id, name="Lois Lane", password=Password("S3cr3t"))) # A known user can log in - jwt_user = login_service.authenticate(name="Tarzan", pwd="S3cr3t") + jwt_user = login_service.authenticate(name="Lois Lane", pwd="S3cr3t") assert jwt_user is not None - assert jwt_user.id == user.id - assert jwt_user.impersonator == user.id + assert jwt_user.id == lois_id + assert jwt_user.impersonator == lois_id assert jwt_user.type == "users" - assert jwt_user.groups == [JWTGroup(id="adventure", name="Adventure", role=RoleType.READER)] + assert jwt_user.groups == [ + JWTGroup(id="superman", name="Superman", role=RoleType.READER), + ] - @with_db_context - def test_authenticate__external_user(self, login_service: LoginService) -> None: - # Create a user named "Tarzan" - user_ldap = UserLdap(id=3, name="Tarzan", external_id="tarzan", firstname="Tarzan", lastname="Jungle") + # An unknown user cannot log in + user = login_service.authenticate(name="unknown", pwd="S3cr3t") + assert user is None + + # Update the user "Jane DOE" which is an LDAP user + jane_id = 60 + user_ldap = UserLdap( + id=jane_id, + name="Jane DOE", + external_id="j.doe", + firstname="Jane", + lastname="DOE", + ) login_service.users.save(user_ldap) # Mock the LDAP service - login_service.ldap.login = Mock(return_value=user_ldap) # type: ignore - login_service.ldap.get = Mock(return_value=user_ldap) # type: ignore + with patch("antarest.login.ldap.LdapService.login") as mock_login: + mock_login.return_value = user_ldap + with patch("antarest.login.ldap.LdapService.login") as mock_get: + mock_get.return_value = user_ldap + jwt_user = login_service.authenticate(name="Jane DOE", pwd="S3cr3t") - # A known user can log in - jwt_user = login_service.authenticate(name="Tarzan", pwd="S3cr3t") assert jwt_user is not None assert jwt_user.id == user_ldap.id assert jwt_user.impersonator == user_ldap.id @@ -603,32 +678,30 @@ def test_authenticate__external_user(self, login_service: LoginService) -> None: @with_db_context def test_get_jwt(self, login_service: LoginService) -> None: - # Prepare the user3 in the db - assert USER3.user is not None - user3 = User(id=USER3.user.id, name="Clark") - login_service.users.save(user3) - - # Attach a group to the user - group = Group(id="group", name="readers") - login_service.groups.save(group) - role = Role(type=RoleType.READER, identity=user3, group=group) - login_service.roles.save(role) - - # Prepare an LDAP user in the db - user_ldap = UserLdap(id=4, name="Jane") + # Create a bot for Joh Fredersen + joh_id = 4 + _param = get_user_param(login_service, user_id=joh_id, group_id="superman") + joh_bot = login_service.save_bot(BotCreateDTO(name="Maria", roles=[]), _param) + + # Update the user "Jane DOE" which is an LDAP user + jane_id = 60 + user_ldap = UserLdap( + id=jane_id, + name="Jane DOE", + external_id="j.doe", + firstname="Jane", + lastname="DOE", + ) login_service.users.save(user_ldap) - # Prepare a bot in the db (using the ID or user3) - bot = Bot(id=5, name="Maria", owner=user3.id, is_author=True) - login_service.users.save(bot) - # We can get a JWT for a user, an LDAP user, but not a bot - jwt_user = login_service.get_jwt(user3.id) + lois_id = 3 + jwt_user = login_service.get_jwt(lois_id) assert jwt_user is not None - assert jwt_user.id == user3.id - assert jwt_user.impersonator == user3.id + assert jwt_user.id == lois_id + assert jwt_user.impersonator == lois_id assert jwt_user.type == "users" - assert jwt_user.groups == [JWTGroup(id="group", name="readers", role=RoleType.READER)] + assert jwt_user.groups == [JWTGroup(id="superman", name="Superman", role=RoleType.READER)] jwt_user = login_service.get_jwt(user_ldap.id) assert jwt_user is not None @@ -637,367 +710,306 @@ def test_get_jwt(self, login_service: LoginService) -> None: assert jwt_user.type == "users_ldap" assert jwt_user.groups == [] - jwt_user = login_service.get_jwt(bot.id) + jwt_user = login_service.get_jwt(joh_bot.id) assert jwt_user is None @with_db_context - @pytest.mark.parametrize( - "param, expected", - [ - ( - SITE_ADMIN, - [ - {"id": "admin", "name": "admin"}, - {"id": "gr1", "name": "Adventure"}, - {"id": "gr2", "name": "Comedy"}, - ], - ), - ( - GROUP_ADMIN, - [ - {"id": "admin", "name": "admin"}, - {"id": "gr2", "name": "Comedy"}, - ], - ), - ( - USER3, - [ - {"id": "gr1", "name": "Adventure"}, - ], - ), - (BAD_PARAM, []), - ], - ) - def test_get_all_groups( - self, - login_service: LoginService, - param: RequestParameters, - expected: t.Sequence[t.Mapping[str, str]], - ) -> None: - # Prepare some groups in the db - group1 = Group(id="gr1", name="Adventure") - login_service.groups.save(group1) - group2 = Group(id="gr2", name="Comedy") - login_service.groups.save(group2) - - # The group admin is a reader in the group "gr2" - assert GROUP_ADMIN.user is not None - robin_hood = User(id=GROUP_ADMIN.user.id, name="Robin") - login_service.users.save(robin_hood) - role = Role(type=RoleType.READER, identity=robin_hood, group=group2) - login_service.roles.save(role) - - # The user3 is a reader in the group "gr1" - assert USER3.user is not None - indiana_johns = User(id=USER3.user.id, name="Indiana") - login_service.users.save(indiana_johns) - role = Role(type=RoleType.READER, identity=indiana_johns, group=group1) - login_service.roles.save(role) - - # Anybody except anonymous can get the list of groups - if expected: - # The site admin can see all groups - actual = login_service.get_all_groups(param) - assert [g.dict() for g in actual] == expected - else: - with pytest.raises(UserHasNotPermissionError): - login_service.get_all_groups(param) + def test_get_all_groups(self, login_service: LoginService) -> None: + # The site admin can get all groups + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_all_groups(_param) + assert [g.dict() for g in actual] == [ + {"id": "admin", "name": "X-Men"}, + {"id": "superman", "name": "Superman"}, + {"id": "metropolis", "name": "Metropolis"}, + ] + + # The group admin can its own groups + _param = get_user_param(login_service, user_id=2, group_id="superman") + actual = login_service.get_all_groups(_param) + assert [g.dict() for g in actual] == [{"id": "superman", "name": "Superman"}] + + # The user can get its own groups + _param = get_user_param(login_service, user_id=3, group_id="superman") + actual = login_service.get_all_groups(_param) + assert [g.dict() for g in actual] == [{"id": "superman", "name": "Superman"}] @with_db_context - @pytest.mark.parametrize( - "param, expected", - [ - ( - SITE_ADMIN, - [ - {"id": 0, "name": "Superman"}, - {"id": 1, "name": "John"}, - {"id": 2, "name": "Jane"}, - {"id": 3, "name": "Tarzan"}, - ], - ), - ( - GROUP_ADMIN, - [ - {"id": 1, "name": "John"}, - ], - ), - ( - USER3, - [ - {"id": 3, "name": "Tarzan"}, - ], - ), - (BAD_PARAM, []), - ], - ) - def test_get_all_users( - self, - login_service: LoginService, - param: RequestParameters, - expected: t.Sequence[t.Mapping[str, t.Union[int, str]]], - ) -> None: - # Insert some users in the db - user0 = User(id=0, name="Superman") - login_service.users.save(user0) - user1 = User(id=1, name="John") - login_service.users.save(user1) - user2 = User(id=2, name="Jane") - login_service.users.save(user2) - user3 = User(id=3, name="Tarzan") - login_service.users.save(user3) - - # user3 is a reader in the group "group" - group = Group(id="group", name="readers") - login_service.groups.save(group) - role = Role(type=RoleType.READER, identity=user3, group=group) - login_service.roles.save(role) - - # Anybody except anonymous can get the list of users - if expected: - actual = login_service.get_all_users(param) - assert [u.dict() for u in actual] == expected - else: - with pytest.raises(UserHasNotPermissionError): - login_service.get_all_users(param) + def test_get_all_users(self, login_service: LoginService) -> None: + # The site admin can get all users + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_all_users(_param) + assert [u.dict() for u in actual] == [ + {"id": 1, "name": "Professor Xavier"}, + {"id": 2, "name": "Clark Kent"}, + {"id": 3, "name": "Lois Lane"}, + {"id": 4, "name": "Joh Fredersen"}, + {"id": 5, "name": "Freder Fredersen"}, + ] + + # The group admin can get its own users, but also the users of the other groups + # note: I don't know why the group admin can get all users -- Laurent + _param = get_user_param(login_service, user_id=2, group_id="superman") + actual = login_service.get_all_users(_param) + assert [u.dict() for u in actual] == [ + {"id": 1, "name": "Professor Xavier"}, + {"id": 2, "name": "Clark Kent"}, + {"id": 3, "name": "Lois Lane"}, + {"id": 4, "name": "Joh Fredersen"}, + {"id": 5, "name": "Freder Fredersen"}, + ] + + # The user can get its own users + _param = get_user_param(login_service, user_id=3, group_id="superman") + actual = login_service.get_all_users(_param) + assert [u.dict() for u in actual] == [ + {"id": 2, "name": "Clark Kent"}, + {"id": 3, "name": "Lois Lane"}, + ] @with_db_context - @pytest.mark.parametrize( - "param, expected", - [ - (SITE_ADMIN, [5]), - (GROUP_ADMIN, []), - (USER3, []), - (BAD_PARAM, []), - ], - ) - def test_get_all_bots( - self, - login_service: LoginService, - param: RequestParameters, - expected: t.Sequence[int], - ) -> None: - # add a user, an LDAP user and a bot in the db - user = User(id=3, name="John") - login_service.users.save(user) - user_ldap = UserLdap(id=4, name="Jane") - login_service.users.save(user_ldap) - bot = Bot(id=5, name="my-app", owner=3, is_author=False) - login_service.users.save(bot) - - if expected: - actual = login_service.get_all_bots(param) - assert [b.id for b in actual] == expected - else: - with pytest.raises(UserHasNotPermissionError): - login_service.get_all_bots(param) + def test_get_all_bots(self, login_service: LoginService) -> None: + # Create a bot for Joh Fredersen + joh_id = 4 + _param = get_user_param(login_service, user_id=joh_id, group_id="superman") + joh_bot = login_service.save_bot(BotCreateDTO(name="Maria", roles=[]), _param) + + # The site admin can get all bots + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_all_bots(_param) + assert [b.to_dto().dict() for b in actual] == [ + {"id": joh_bot.id, "is_author": True, "name": "Maria", "owner": joh_id}, + ] + + # The group admin cannot access the list of bots + _param = get_user_param(login_service, user_id=2, group_id="superman") + with pytest.raises(Exception): + login_service.get_all_bots(_param) + + # The user cannot access the list of bots + _param = get_user_param(login_service, user_id=3, group_id="superman") + with pytest.raises(Exception): + login_service.get_all_bots(_param) @with_db_context - @pytest.mark.parametrize( - "param, expected", - [ - (SITE_ADMIN, [(3, "group")]), - (GROUP_ADMIN, [(3, "group")]), - (USER3, []), - (BAD_PARAM, []), - ], - ) - def test_get_all_roles_in_group( - self, - login_service: LoginService, - param: RequestParameters, - expected: t.Sequence[t.Tuple[int, str]], - ) -> None: - # Insert some users in the db - user0 = User(id=0, name="Superman") - login_service.users.save(user0) - user1 = User(id=1, name="John") - login_service.users.save(user1) - user2 = User(id=2, name="Jane") - login_service.users.save(user2) - user3 = User(id=3, name="Tarzan") - login_service.users.save(user3) - - # user3 is a reader in the group "group" - group = Group(id="group", name="readers") - login_service.groups.save(group) - role = Role(type=RoleType.READER, identity=user3, group=group) - login_service.roles.save(role) - - # The site admin and the group admin can get the list of roles in a group - if expected: - actual = login_service.get_all_roles_in_group("group", param) - assert [(r.identity_id, r.group_id) for r in actual] == expected - else: - with pytest.raises(UserHasNotPermissionError): - login_service.get_all_roles_in_group("group", param) + def test_get_all_roles_in_group(self, login_service: LoginService) -> None: + # The site admin can get all roles in a given group + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_all_roles_in_group("superman", _param) + assert [b.to_dto().dict() for b in actual] == [ + { + "group": {"id": "superman", "name": "Superman"}, + "identity": {"id": 2, "name": "Clark Kent"}, + "type": RoleType.ADMIN, + }, + { + "group": {"id": "superman", "name": "Superman"}, + "identity": {"id": 3, "name": "Lois Lane"}, + "type": RoleType.READER, + }, + ] + + # The group admin can get all roles his own group + _param = get_user_param(login_service, user_id=2, group_id="superman") + actual = login_service.get_all_roles_in_group("superman", _param) + assert [b.to_dto().dict() for b in actual] == [ + { + "group": {"id": "superman", "name": "Superman"}, + "identity": {"id": 2, "name": "Clark Kent"}, + "type": RoleType.ADMIN, + }, + { + "group": {"id": "superman", "name": "Superman"}, + "identity": {"id": 3, "name": "Lois Lane"}, + "type": RoleType.READER, + }, + ] + + # The user cannot access the list of roles + _param = get_user_param(login_service, user_id=3, group_id="superman") + with pytest.raises(Exception): + login_service.get_all_roles_in_group("superman", _param) @with_db_context - @pytest.mark.parametrize( - "param, can_delete", - [ - (SITE_ADMIN, True), - (GROUP_ADMIN, True), - (USER3, False), - (BAD_PARAM, False), - ], - ) - def test_delete_group(self, login_service: LoginService, param: RequestParameters, can_delete: bool) -> None: - # Insert a group in the db - group = Group(id="group", name="readers") - login_service.groups.save(group) - - # The site admin and the group admin can delete a group - if can_delete: - login_service.delete_group("group", param) - actual = login_service.groups.get("group") - assert actual is None - else: - with pytest.raises(UserHasNotPermissionError): - login_service.delete_group("group", param) - actual = login_service.groups.get("group") - assert actual is not None + def test_delete_group(self, login_service: LoginService) -> None: + # Create new groups for Lois Lane (3) and Freder Fredersen (5) + group1 = login_service.groups.save(Group(id="g1", name="Group I")) + group2 = login_service.groups.save(Group(id="g2", name="Group II")) + group3 = login_service.groups.save(Group(id="g3", name="Group III")) + + lois = t.cast(User, login_service.users.get(3)) # group admin + freder = t.cast(User, login_service.users.get(5)) # user + + login_service.roles.save(Role(type=RoleType.ADMIN, group=group1, identity=lois)) + login_service.roles.save(Role(type=RoleType.READER, group=group1, identity=freder)) + login_service.roles.save(Role(type=RoleType.ADMIN, group=group2, identity=lois)) + login_service.roles.save(Role(type=RoleType.WRITER, group=group2, identity=freder)) + login_service.roles.save(Role(type=RoleType.ADMIN, group=group3, identity=lois)) + login_service.roles.save(Role(type=RoleType.RUNNER, group=group3, identity=freder)) + + # The site admin can delete any group + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + login_service.delete_group("g1", _param) + assert login_service.groups.get(group1.id) is None + + # The group admin can delete his own group + _param = get_user_param(login_service, user_id=3, group_id="g2") + login_service.delete_group("g2", _param) + assert login_service.groups.get(group2.id) is None + + # The user cannot delete a group + _param = get_user_param(login_service, user_id=5, group_id="g3") + with pytest.raises(Exception): + login_service.delete_group("g3", _param) + assert login_service.groups.get(group3.id) is not None @with_db_context - @pytest.mark.parametrize( - "param, can_delete", - [ - (SITE_ADMIN, True), - (GROUP_ADMIN, False), - (USER3, False), - (BAD_PARAM, False), - ], - ) - def test_delete_user(self, login_service: LoginService, param: RequestParameters, can_delete: bool) -> None: - # Insert a user in the db which is an owner of a bot - user = User(id=3, name="John") - login_service.users.save(user) - bot = Bot(id=4, name="my-app", owner=3, is_author=False) - login_service.users.save(bot) - - # The site admin can delete the user - if can_delete: - login_service.delete_user(3, param) - actual = login_service.users.get(3) - assert actual is None - else: - with pytest.raises(UserHasNotPermissionError): - login_service.delete_user(3, param) - actual = login_service.users.get(3) - assert actual is not None + def test_delete_user(self, login_service: LoginService) -> None: + # Create Joh's bot + joh_id = 4 + _param = get_user_param(login_service, user_id=joh_id, group_id="metropolis") + joh_bot = login_service.save_bot(BotCreateDTO(name="Maria", roles=[]), _param) + + # The site admin can delete Fredersen (5) + freder_id = 5 + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + login_service.delete_user(freder_id, _param) + assert login_service.users.get(freder_id) is None + + # The group admin Joh can delete himself (4) + _param = get_user_param(login_service, user_id=joh_id, group_id="metropolis") + login_service.delete_user(joh_id, _param) + assert login_service.users.get(joh_id) is None + assert login_service.bots.get(joh_bot.id) is None + + # Lois Lane cannot delete Clark Kent (2) + lois_id = 3 + clark_id = 2 + _param = get_user_param(login_service, user_id=lois_id, group_id="superman") + with pytest.raises(Exception): + login_service.delete_user(clark_id, _param) + assert login_service.users.get(clark_id) is not None + + # Clark Kent cannot delete Lois Lane (3) + _param = get_user_param(login_service, user_id=clark_id, group_id="superman") + with pytest.raises(Exception): + login_service.delete_user(lois_id, _param) + assert login_service.users.get(lois_id) is not None @with_db_context - @pytest.mark.parametrize( - "param, can_delete", - [ - (SITE_ADMIN, True), - (GROUP_ADMIN, False), - (USER3, True), - (BAD_PARAM, False), - ], - ) - def test_delete_bot(self, login_service: LoginService, param: RequestParameters, can_delete: bool) -> None: - # Insert a user in the db which is an owner of a bot - user = User(id=3, name="John") - login_service.users.save(user) - bot = Bot(id=4, name="my-app", owner=3, is_author=False) - login_service.users.save(bot) - - # The site admin and the current owner can delete the bot - if can_delete: - login_service.delete_bot(4, param) - actual = login_service.bots.get(4) - assert actual is None - else: - with pytest.raises(UserHasNotPermissionError): - login_service.delete_bot(4, param) - actual = login_service.bots.get(4) - assert actual is not None + def test_delete_bot(self, login_service: LoginService) -> None: + # Create Joh's bot + joh_id = 4 + _param = get_user_param(login_service, user_id=joh_id, group_id="metropolis") + joh_bot = login_service.save_bot(BotCreateDTO(name="Maria", roles=[]), _param) + + # The site admin can delete the bot + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + login_service.delete_bot(joh_bot.id, _param) + assert login_service.bots.get(joh_bot.id) is None + + # Create Lois's bot + lois_id = 3 + _param = get_user_param(login_service, user_id=lois_id, group_id="superman") + lois_bot = login_service.save_bot(BotCreateDTO(name="Lois bot", roles=[]), _param) + + # The group admin cannot delete the bot + clark_id = 2 + _param = get_user_param(login_service, user_id=clark_id, group_id="superman") + with pytest.raises(Exception): + login_service.delete_bot(lois_bot.id, _param) + assert login_service.bots.get(lois_bot.id) is not None + + # Create Freder's bot + freder_id = 5 + _param = get_user_param(login_service, user_id=freder_id, group_id="metropolis") + freder_bot = login_service.save_bot(BotCreateDTO(name="Freder bot", roles=[]), _param) + + # Freder can delete his own bot + _param = get_user_param(login_service, user_id=freder_id, group_id="metropolis") + login_service.delete_bot(freder_bot.id, _param) + assert login_service.bots.get(freder_bot.id) is None + + # Freder cannot delete Lois's bot + _param = get_user_param(login_service, user_id=freder_id, group_id="metropolis") + with pytest.raises(Exception): + login_service.delete_bot(lois_bot.id, _param) + assert login_service.bots.get(lois_bot.id) is not None @with_db_context - @pytest.mark.parametrize( - "param, can_delete", - [ - (SITE_ADMIN, True), - (GROUP_ADMIN, True), - (USER3, False), - (BAD_PARAM, False), - ], - ) - def test_delete_role(self, login_service: LoginService, param: RequestParameters, can_delete: bool) -> None: - # Insert the user3 in the db - user = User(id=3, name="Tarzan") - login_service.users.save(user) - - # Insert a group in the db - group = Group(id="group", name="readers") - login_service.groups.save(group) - - # Insert a role in the db - role = Role(type=RoleType.READER, identity=user, group=group) - login_service.roles.save(role) - - # The site admin and the group admin can delete a role - if can_delete: - login_service.delete_role(3, "group", param) - actual = login_service.roles.get_all_by_group("group") - assert len(actual) == 0 - else: - with pytest.raises(UserHasNotPermissionError): - login_service.delete_role(3, "group", param) - actual = login_service.roles.get_all_by_group("group") - assert len(actual) == 1 + def test_delete_role(self, login_service: LoginService) -> None: + # Create a new group + group = login_service.groups.save(Group(id="g1", name="Group I")) + + # Create a new user + user = login_service.users.save(User(id=10, name="User 1")) + + # Create a new role + role = login_service.roles.save(Role(type=RoleType.ADMIN, group=group, identity=user)) + + # The site admin can delete any role + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + login_service.delete_role(role.identity.id, role.group.id, _param) + assert login_service.roles.get(role.identity.id, role.group.id) is None + + # Create a new role + role = login_service.roles.save(Role(type=RoleType.ADMIN, group=group, identity=user)) + + # The group admin can delete a role of his own group + _param = get_user_param(login_service, user_id=user.id, group_id="g1") + login_service.delete_role(role.identity.id, role.group.id, _param) + assert login_service.roles.get(role.identity.id, role.group.id) is None + + # Create a new role + role = login_service.roles.save(Role(type=RoleType.ADMIN, group=group, identity=user)) + + # The group admin cannot delete a role of another group + _param = get_user_param(login_service, user_id=2, group_id="superman") + with pytest.raises(Exception): + login_service.delete_role(role.identity.id, "g1", _param) + assert login_service.roles.get(role.identity.id, "g1") is not None + + # The user cannot delete a role + _param = get_user_param(login_service, user_id=1, group_id="g1") + with pytest.raises(Exception): + login_service.delete_role(role.identity.id, role.group.id, _param) + assert login_service.roles.get(role.identity.id, role.group.id) is not None @with_db_context - @pytest.mark.parametrize( - "param, can_delete", - [ - (SITE_ADMIN, True), - (GROUP_ADMIN, True), - (USER3, False), - (BAD_PARAM, False), - ], - ) - def test_delete_all_roles_from_user( - self, login_service: LoginService, param: RequestParameters, can_delete: bool - ) -> None: - # Insert the user3 in the db - assert USER3.user is not None - user = User(id=USER3.user.id, name="Tarzan") - login_service.users.save(user) - - # Insert a group in the db - group = Group(id="group", name="readers") - login_service.groups.save(group) - - # Insert a role in the db - role = Role(type=RoleType.READER, identity=user, group=group) - login_service.roles.save(role) - - # Insert the group admin in the db - assert GROUP_ADMIN.user is not None - group_admin = User(id=GROUP_ADMIN.user.id, name="John") - login_service.users.save(group_admin) - - # Insert another group in the db - group2 = Group(id="group2", name="readers") - login_service.groups.save(group2) - - # Insert a role in the db - role2 = Role(type=RoleType.READER, identity=group_admin, group=group2) - login_service.roles.save(role2) - - # The site admin and the group admin can delete a role - if can_delete: - login_service.delete_all_roles_from_user(3, param) - actual = login_service.roles.get_all_by_group("group") - assert len(actual) == 0 - actual = login_service.roles.get_all_by_group("group2") - assert len(actual) == 1 - else: - with pytest.raises(UserHasNotPermissionError): - login_service.delete_all_roles_from_user(3, param) - actual = login_service.roles.get_all_by_group("group") - assert len(actual) == 1 - actual = login_service.roles.get_all_by_group("group2") - assert len(actual) == 1 + def test_delete_all_roles_from_user(self, login_service: LoginService) -> None: + # Create a new group + group = login_service.groups.save(Group(id="g1", name="Group I")) + + # Create a new user + user = login_service.users.save(User(id=10, name="User 1")) + + # Create a new role + role = login_service.roles.save(Role(type=RoleType.ADMIN, group=group, identity=user)) + + # The site admin can delete any role + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + login_service.delete_all_roles_from_user(user.id, _param) + assert login_service.roles.get(role.identity.id, role.group.id) is None + + # Create a new role + role = login_service.roles.save(Role(type=RoleType.ADMIN, group=group, identity=user)) + + # The group admin can delete a role of his own group + _param = get_user_param(login_service, user_id=user.id, group_id="g1") + login_service.delete_all_roles_from_user(user.id, _param) + assert login_service.roles.get(role.identity.id, role.group.id) is None + + # Create a new role + role = login_service.roles.save(Role(type=RoleType.ADMIN, group=group, identity=user)) + + # The group admin cannot delete a role of another group + _param = get_user_param(login_service, user_id=2, group_id="superman") + with pytest.raises(Exception): + login_service.delete_all_roles_from_user(user.id, _param) + assert login_service.roles.get(role.identity.id, role.group.id) is not None + + # The user cannot delete a role + _param = get_user_param(login_service, user_id=1, group_id="g1") + with pytest.raises(Exception): + login_service.delete_all_roles_from_user(user.id, _param) + assert login_service.roles.get(role.identity.id, role.group.id) is not None From d8ca77c5c9d40e983dc83972e08c5cc26c5db864 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Tue, 5 Dec 2023 23:33:29 +0100 Subject: [PATCH 28/43] docs(db-init): improve docstring in database model --- antarest/core/tasks/model.py | 6 ++++++ antarest/login/model.py | 8 ++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/antarest/core/tasks/model.py b/antarest/core/tasks/model.py index cb92032445..1d7a9e1566 100644 --- a/antarest/core/tasks/model.py +++ b/antarest/core/tasks/model.py @@ -176,10 +176,16 @@ def __repr__(self) -> str: def cancel_orphan_tasks(engine: Engine, session_args: Mapping[str, bool]) -> None: """ + Cancel all tasks that are currently running or pending. + When the web application restarts, such as after a new deployment, any pending or running tasks may be lost. To mitigate this, it is preferable to set these tasks to a "FAILED" status. This ensures that users can easily identify the tasks that were affected by the restart and take appropriate actions, such as restarting the tasks manually. + + Args: + engine: The database engine (SQLAlchemy connection to SQLite or PostgreSQL). + session_args: The session arguments (SQLAlchemy session arguments). """ updated_values = { TaskJob.status: TaskStatus.FAILED.value, diff --git a/antarest/login/model.py b/antarest/login/model.py index f56230464a..5012a4995c 100644 --- a/antarest/login/model.py +++ b/antarest/login/model.py @@ -302,8 +302,12 @@ class CredentialsDTO(BaseModel): def init_admin_user(engine: Engine, session_args: t.Mapping[str, bool], admin_password: str) -> None: """ - When starting the app, the 'admin' group and 'admin' user are automatically created if they - do not already exist in the database. + Create the default admin user, group and role if they do not already exist in the database. + + Args: + engine: The database engine (SQLAlchemy connection to SQLite or PostgreSQL). + session_args: The session arguments (SQLAlchemy session arguments). + admin_password: The admin password extracted from the configuration file. """ make_session = sessionmaker(bind=engine, **session_args) with make_session() as session: From bce130a0fa2b5cd36b08b97d42b54848d268684e Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Wed, 6 Dec 2023 18:55:13 +0100 Subject: [PATCH 29/43] feat(db-init): remove deprecated method `TaskJobService._fix_running_status` The `_fix_running_status` method is replaced by `cancel_orphan_tasks` --- antarest/core/tasks/service.py | 89 ++++++++++++++-------------------- 1 file changed, 36 insertions(+), 53 deletions(-) diff --git a/antarest/core/tasks/service.py b/antarest/core/tasks/service.py index 28038e2d4c..0844489283 100644 --- a/antarest/core/tasks/service.py +++ b/antarest/core/tasks/service.py @@ -1,16 +1,15 @@ import datetime import logging import time +import typing as t from abc import ABC, abstractmethod from concurrent.futures import Future, ThreadPoolExecutor from http import HTTPStatus -from typing import Awaitable, Callable, Dict, List, Optional, Union from fastapi import HTTPException from antarest.core.config import Config from antarest.core.interfaces.eventbus import Event, EventChannelDirectory, EventType, IEventBus -from antarest.core.jwt import DEFAULT_ADMIN_USER from antarest.core.model import PermissionInfo, PublicMode from antarest.core.requests import MustBeAuthenticatedError, RequestParameters, UserHasNotPermissionError from antarest.core.tasks.model import ( @@ -31,8 +30,8 @@ logger = logging.getLogger(__name__) -TaskUpdateNotifier = Callable[[str], None] -Task = Callable[[TaskUpdateNotifier], TaskResult] +TaskUpdateNotifier = t.Callable[[str], None] +Task = t.Callable[[TaskUpdateNotifier], TaskResult] DEFAULT_AWAIT_MAX_TIMEOUT = 172800 # 48 hours """Default timeout for `await_task` in seconds.""" @@ -44,21 +43,21 @@ def add_worker_task( self, task_type: TaskType, task_queue: str, - task_args: Dict[str, Union[int, float, bool, str]], - name: Optional[str], - ref_id: Optional[str], + task_args: t.Dict[str, t.Union[int, float, bool, str]], + name: t.Optional[str], + ref_id: t.Optional[str], request_params: RequestParameters, - ) -> Optional[str]: + ) -> t.Optional[str]: raise NotImplementedError() @abstractmethod def add_task( self, action: Task, - name: Optional[str], - task_type: Optional[TaskType], - ref_id: Optional[str], - custom_event_messages: Optional[CustomTaskEventMessages], + name: t.Optional[str], + task_type: t.Optional[TaskType], + ref_id: t.Optional[str], + custom_event_messages: t.Optional[CustomTaskEventMessages], request_params: RequestParameters, ) -> str: raise NotImplementedError() @@ -73,7 +72,7 @@ def status_task( raise NotImplementedError() @abstractmethod - def list_tasks(self, task_filter: TaskListFilter, request_params: RequestParameters) -> List[TaskDTO]: + def list_tasks(self, task_filter: TaskListFilter, request_params: RequestParameters) -> t.List[TaskDTO]: raise NotImplementedError() @abstractmethod @@ -96,24 +95,22 @@ def __init__( self.config = config self.repo = repository self.event_bus = event_bus - self.tasks: Dict[str, Future[None]] = {} + self.tasks: t.Dict[str, Future[None]] = {} self.threadpool = ThreadPoolExecutor(max_workers=config.tasks.max_workers, thread_name_prefix="taskjob_") self.event_bus.add_listener(self.create_task_event_callback(), [EventType.TASK_CANCEL_REQUEST]) self.remote_workers = config.tasks.remote_workers - # set the status of previously running job to FAILED due to server restart - self._fix_running_status() def _create_worker_task( self, task_id: str, task_type: str, - task_args: Dict[str, Union[int, float, bool, str]], - ) -> Callable[[TaskUpdateNotifier], TaskResult]: - task_result_wrapper: List[TaskResult] = [] + task_args: t.Dict[str, t.Union[int, float, bool, str]], + ) -> t.Callable[[TaskUpdateNotifier], TaskResult]: + task_result_wrapper: t.List[TaskResult] = [] def _create_awaiter( - res_wrapper: List[TaskResult], - ) -> Callable[[Event], Awaitable[None]]: + res_wrapper: t.List[TaskResult], + ) -> t.Callable[[Event], t.Awaitable[None]]: async def _await_task_end(event: Event) -> None: task_event = WorkerTaskResult.parse_obj(event.payload) if task_event.task_id == task_id: @@ -155,11 +152,11 @@ def add_worker_task( self, task_type: TaskType, task_queue: str, - task_args: Dict[str, Union[int, float, bool, str]], - name: Optional[str], - ref_id: Optional[str], + task_args: t.Dict[str, t.Union[int, float, bool, str]], + name: t.Optional[str], + ref_id: t.Optional[str], request_params: RequestParameters, - ) -> Optional[str]: + ) -> t.Optional[str]: if not self.check_remote_worker_for_queue(task_queue): logger.warning(f"Failed to find configured remote worker for task queue {task_queue}") return None @@ -176,10 +173,10 @@ def add_worker_task( def add_task( self, action: Task, - name: Optional[str], - task_type: Optional[TaskType], - ref_id: Optional[str], - custom_event_messages: Optional[CustomTaskEventMessages], + name: t.Optional[str], + task_type: t.Optional[TaskType], + ref_id: t.Optional[str], + custom_event_messages: t.Optional[CustomTaskEventMessages], request_params: RequestParameters, ) -> str: task = self._create_task(name, task_type, ref_id, request_params) @@ -188,9 +185,9 @@ def add_task( def _create_task( self, - name: Optional[str], - task_type: Optional[TaskType], - ref_id: Optional[str], + name: t.Optional[str], + task_type: t.Optional[TaskType], + ref_id: t.Optional[str], request_params: RequestParameters, ) -> TaskJob: if not request_params.user: @@ -209,7 +206,7 @@ def _launch_task( self, action: Task, task: TaskJob, - custom_event_messages: Optional[CustomTaskEventMessages], + custom_event_messages: t.Optional[CustomTaskEventMessages], request_params: RequestParameters, ) -> None: if not request_params.user: @@ -230,7 +227,7 @@ def _launch_task( future = self.threadpool.submit(self._run_task, action, task.id, custom_event_messages) self.tasks[task.id] = future - def create_task_event_callback(self) -> Callable[[Event], Awaitable[None]]: + def create_task_event_callback(self) -> t.Callable[[Event], t.Awaitable[None]]: async def task_event_callback(event: Event) -> None: self._cancel_task(str(event.payload), dispatch=False) @@ -275,10 +272,10 @@ def status_task( detail=f"Failed to retrieve task {task_id} in db", ) - def list_tasks(self, task_filter: TaskListFilter, request_params: RequestParameters) -> List[TaskDTO]: + def list_tasks(self, task_filter: TaskListFilter, request_params: RequestParameters) -> t.List[TaskDTO]: return [task.to_dto() for task in self.list_db_tasks(task_filter, request_params)] - def list_db_tasks(self, task_filter: TaskListFilter, request_params: RequestParameters) -> List[TaskJob]: + def list_db_tasks(self, task_filter: TaskListFilter, request_params: RequestParameters) -> t.List[TaskJob]: if not request_params.user: raise MustBeAuthenticatedError() user = None if request_params.user.is_site_admin() else request_params.user.impersonator @@ -314,7 +311,7 @@ def _run_task( self, callback: Task, task_id: str, - custom_event_messages: Optional[CustomTaskEventMessages] = None, + custom_event_messages: t.Optional[CustomTaskEventMessages] = None, ) -> None: self.event_bus.push( Event( @@ -385,7 +382,7 @@ def _run_task( ) ) - def _task_logger(self, task_id: str) -> Callable[[str], None]: + def _task_logger(self, task_id: str) -> t.Callable[[str], None]: def log_msg(message: str) -> None: task = self.repo.get(task_id) if task: @@ -394,27 +391,13 @@ def log_msg(message: str) -> None: return log_msg - def _fix_running_status(self) -> None: - with db(): - previous_tasks = self.list_db_tasks( - TaskListFilter(status=[TaskStatus.RUNNING, TaskStatus.PENDING]), - request_params=RequestParameters(user=DEFAULT_ADMIN_USER), - ) - for task in previous_tasks: - self._update_task_status( - task.id, - TaskStatus.FAILED, - False, - "Task was interrupted due to server restart", - ) - def _update_task_status( self, task_id: str, status: TaskStatus, result: bool, message: str, - command_result: Optional[str] = None, + command_result: t.Optional[str] = None, ) -> None: task = self.repo.get_or_raise(task_id) task.status = status.value From e8026d063a1758aceeefbb58c38b4ca03bcedb79 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Thu, 7 Dec 2023 09:20:44 +0100 Subject: [PATCH 30/43] docs(db-model): add `__repr__` in the database model to debug --- antarest/matrixstore/model.py | 45 +++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/antarest/matrixstore/model.py b/antarest/matrixstore/model.py index 1f6e500c42..aa9a4a91a9 100644 --- a/antarest/matrixstore/model.py +++ b/antarest/matrixstore/model.py @@ -1,6 +1,6 @@ import datetime +import typing as t import uuid -from typing import Any, List, Union from pydantic import BaseModel from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Table # type: ignore @@ -29,7 +29,11 @@ class Matrix(Base): # type: ignore height: int = Column(Integer) created_at: datetime.datetime = Column(DateTime) - def __eq__(self, other: Any) -> bool: + def __repr__(self) -> str: # pragma: no cover + """Returns a string representation of the matrix.""" + return f"Matrix(id={self.id}, shape={(self.height, self.width)}, created_at={self.created_at})" + + def __eq__(self, other: t.Any) -> bool: if not isinstance(other, Matrix): return False @@ -50,9 +54,9 @@ class MatrixInfoDTO(BaseModel): class MatrixDataSetDTO(BaseModel): id: str name: str - matrices: List[MatrixInfoDTO] + matrices: t.List[MatrixInfoDTO] owner: UserInfo - groups: List[GroupDTO] + groups: t.List[GroupDTO] public: bool created_at: str updated_at: str @@ -85,7 +89,11 @@ class MatrixDataSetRelation(Base): # type: ignore name: str = Column(String, primary_key=True) matrix: Matrix = relationship(Matrix) - def __eq__(self, other: Any) -> bool: + def __repr__(self) -> str: # pragma: no cover + """Returns a string representation of the matrix.""" + return f"MatrixDataSetRelation(dataset_id={self.dataset_id}, matrix_id={self.matrix_id}, name={self.name})" + + def __eq__(self, other: t.Any) -> bool: if not isinstance(other, MatrixDataSetRelation): return False @@ -152,7 +160,18 @@ def to_dto(self) -> MatrixDataSetDTO: updated_at=str(self.updated_at), ) - def __eq__(self, other: Any) -> bool: + def __repr__(self) -> str: # pragma: no cover + """Returns a string representation of the matrix.""" + return ( + f"MatrixDataSet(id={self.id}," + f" name={self.name}," + f" owner_id={self.owner_id}," + f" public={self.public}," + f" created_at={self.created_at}," + f" updated_at={self.updated_at})" + ) + + def __eq__(self, other: t.Any) -> bool: if not isinstance(other, MatrixDataSet): return False @@ -181,9 +200,9 @@ def __eq__(self, other: Any) -> bool: class MatrixDTO(BaseModel): width: int height: int - index: List[str] - columns: List[str] - data: List[List[MatrixData]] + index: t.List[str] + columns: t.List[str] + data: t.List[t.List[MatrixData]] created_at: int = 0 id: str = "" @@ -198,12 +217,12 @@ class MatrixContent(BaseModel): columns: A list of columns indexes or names. """ - data: List[List[MatrixData]] - index: List[Union[int, str]] - columns: List[Union[int, str]] + data: t.List[t.List[MatrixData]] + index: t.List[t.Union[int, str]] + columns: t.List[t.Union[int, str]] class MatrixDataSetUpdateDTO(BaseModel): name: str - groups: List[str] + groups: t.List[str] public: bool From d5874d050f5db50cf0f6245a27325f4c930be630 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Sat, 9 Dec 2023 19:40:00 +0100 Subject: [PATCH 31/43] fix(db): correct the session retrieval in `StudyMetadataRepository` --- antarest/study/repository.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/antarest/study/repository.py b/antarest/study/repository.py index 94a0220e37..1a830c7428 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -53,18 +53,19 @@ def save( if update_modification_date: metadata.updated_at = datetime.datetime.utcnow() - metadata.groups = [db.session.merge(g) for g in metadata.groups] + session = self.session + metadata.groups = [session.merge(g) for g in metadata.groups] if metadata.owner: - metadata.owner = db.session.merge(metadata.owner) - db.session.add(metadata) - db.session.commit() + metadata.owner = session.merge(metadata.owner) + session.add(metadata) + session.commit() if update_in_listing: self._update_study_from_cache_listing(metadata) return metadata def refresh(self, metadata: Study) -> None: - db.session.refresh(metadata) + self.session.refresh(metadata) def get(self, id: str) -> t.Optional[Study]: """Get the study by ID or return `None` if not found in database.""" @@ -72,7 +73,7 @@ def get(self, id: str) -> t.Optional[Study]: # to check the permissions of the current user efficiently. study: Study = ( # fmt: off - db.session.query(Study) + self.session.query(Study) .options(joinedload(Study.owner)) .options(joinedload(Study.groups)) .get(id) @@ -85,7 +86,7 @@ def one(self, id: str) -> Study: # When we fetch a study, we also need to fetch the associated owner and groups # to check the permissions of the current user efficiently. study: Study = ( - db.session.query(Study) + self.session.query(Study) .options(joinedload(Study.owner)) .options(joinedload(Study.groups)) .filter_by(id=id) @@ -97,7 +98,7 @@ def get_list(self, study_id: t.List[str]) -> t.List[Study]: # When we fetch a study, we also need to fetch the associated owner and groups # to check the permissions of the current user efficiently. studies: t.List[Study] = ( - db.session.query(Study) + self.session.query(Study) .options(joinedload(Study.owner)) .options(joinedload(Study.groups)) .where(Study.id.in_(study_id)) @@ -106,16 +107,16 @@ def get_list(self, study_id: t.List[str]) -> t.List[Study]: return studies def get_additional_data(self, study_id: str) -> t.Optional[StudyAdditionalData]: - study: StudyAdditionalData = db.session.query(StudyAdditionalData).get(study_id) + study: StudyAdditionalData = self.session.query(StudyAdditionalData).get(study_id) return study def get_all(self) -> t.List[Study]: entity = with_polymorphic(Study, "*") - studies: t.List[Study] = db.session.query(entity).filter(RawStudy.missing.is_(None)).all() + studies: t.List[Study] = self.session.query(entity).filter(RawStudy.missing.is_(None)).all() return studies def get_all_raw(self, show_missing: bool = True) -> t.List[RawStudy]: - query = db.session.query(RawStudy) + query = self.session.query(RawStudy) if not show_missing: query = query.filter(RawStudy.missing.is_(None)) studies: t.List[RawStudy] = query.all() @@ -123,9 +124,10 @@ def get_all_raw(self, show_missing: bool = True) -> t.List[RawStudy]: def delete(self, id: str) -> None: logger.debug(f"Deleting study {id}") - u: Study = db.session.query(Study).get(id) - db.session.delete(u) - db.session.commit() + session = self.session + u: Study = session.query(Study).get(id) + session.delete(u) + session.commit() self._remove_study_from_cache_listing(id) def _remove_study_from_cache_listing(self, study_id: str) -> None: From d6e2ed831483b2e3c3c8665c0b3053458db49e3f Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Sun, 10 Dec 2023 15:29:03 +0100 Subject: [PATCH 32/43] feat(db-repo): change the constructor of the `TaskJobRepository` class to accept an optional session parameter --- antarest/core/tasks/repository.py | 66 +++++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 17 deletions(-) diff --git a/antarest/core/tasks/repository.py b/antarest/core/tasks/repository.py index 1994c55fab..294f63255b 100644 --- a/antarest/core/tasks/repository.py +++ b/antarest/core/tasks/repository.py @@ -1,9 +1,10 @@ import datetime +import typing as t from http import HTTPStatus from operator import and_ -from typing import Any, List, Optional from fastapi import HTTPException +from sqlalchemy.orm import Session # type: ignore from antarest.core.tasks.model import TaskJob, TaskListFilter, TaskStatus from antarest.core.utils.fastapi_sqlalchemy import db @@ -11,16 +12,45 @@ class TaskJobRepository: + """ + Database connector to manage Tasks/Jobs entities. + """ + + def __init__(self, session: t.Optional[Session] = None): + """ + Initialize the repository. + + Args: + session: Optional SQLAlchemy session to be used. + """ + self._session = session + + @property + def session(self) -> Session: + """ + Get the SQLAlchemy session for the repository. + + Returns: + SQLAlchemy session. + """ + if self._session is None: + # Get or create the session from a context variable (thread local variable) + return db.session + # Get the user-defined session + return self._session + def save(self, task: TaskJob) -> TaskJob: - task = db.session.merge(task) - db.session.add(task) - db.session.commit() + session = self.session + task = session.merge(task) + session.add(task) + session.commit() return task - def get(self, id: str) -> Optional[TaskJob]: - task: TaskJob = db.session.get(TaskJob, id) + def get(self, id: str) -> t.Optional[TaskJob]: + session = self.session + task: TaskJob = session.get(TaskJob, id) if task is not None: - db.session.refresh(task) + session.refresh(task) return task def get_or_raise(self, id: str) -> TaskJob: @@ -30,7 +60,7 @@ def get_or_raise(self, id: str) -> TaskJob: return task @staticmethod - def _combine_clauses(where_clauses: List[Any]) -> Any: + def _combine_clauses(where_clauses: t.List[t.Any]) -> t.Any: assert_this(len(where_clauses) > 0) if len(where_clauses) > 1: return and_( @@ -40,9 +70,9 @@ def _combine_clauses(where_clauses: List[Any]) -> Any: else: return where_clauses[0] - def list(self, filter: TaskListFilter, user: Optional[int] = None) -> List[TaskJob]: - query = db.session.query(TaskJob) - where_clauses: List[Any] = [] + def list(self, filter: TaskListFilter, user: t.Optional[int] = None) -> t.List[TaskJob]: + query = self.session.query(TaskJob) + where_clauses: t.List[t.Any] = [] if user: where_clauses.append(TaskJob.owner_id == user) if len(filter.status) > 0: @@ -74,19 +104,21 @@ def list(self, filter: TaskListFilter, user: Optional[int] = None) -> List[TaskJ elif len(where_clauses) == 1: query = query.where(*where_clauses) - tasks: List[TaskJob] = query.all() + tasks: t.List[TaskJob] = query.all() return tasks def delete(self, tid: str) -> None: - task = db.session.get(TaskJob, tid) + session = self.session + task = session.get(TaskJob, tid) if task: - db.session.delete(task) - db.session.commit() + session.delete(task) + session.commit() def update_timeout(self, task_id: str, timeout: int) -> None: """Update task status to TIMEOUT.""" - task: TaskJob = db.session.get(TaskJob, task_id) + session = self.session + task: TaskJob = session.get(TaskJob, task_id) task.status = TaskStatus.TIMEOUT task.result_msg = f"Task '{task_id}' timeout after {timeout} seconds" task.result_status = False - db.session.commit() + session.commit() From 965ac4bcb2f42b366528625a5487330639f3d5c5 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 11 Dec 2023 18:09:20 +0100 Subject: [PATCH 33/43] chore: correct typo in `exceptions.py` --- antarest/core/utils/fastapi_sqlalchemy/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/antarest/core/utils/fastapi_sqlalchemy/exceptions.py b/antarest/core/utils/fastapi_sqlalchemy/exceptions.py index 7e435ba286..ad1eccff2c 100644 --- a/antarest/core/utils/fastapi_sqlalchemy/exceptions.py +++ b/antarest/core/utils/fastapi_sqlalchemy/exceptions.py @@ -1,5 +1,5 @@ class MissingSessionError(Exception): - """Excetion raised for when the user tries to access a database session before it is created.""" + """Exception raised for when the user tries to access a database session before it is created.""" def __init__(self) -> None: msg = """ From ec276002f6627ceb16e8c165d1f39c08eeff117b Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 11 Dec 2023 18:14:10 +0100 Subject: [PATCH 34/43] test(fixture): use a `LocalEventBus` for `EventBusService` instead of a mock object --- tests/conftest_services.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/tests/conftest_services.py b/tests/conftest_services.py index 5afb53460b..59f562e241 100644 --- a/tests/conftest_services.py +++ b/tests/conftest_services.py @@ -1,12 +1,13 @@ """ This module provides various pytest fixtures for unit testing the AntaREST application. -Fixtures in this module are used to set up and provide instances of different classes and services required during testing. +Fixtures in this module are used to set up and provide instances of different classes +and services required during testing. """ import datetime +import typing as t import uuid from pathlib import Path -from typing import Dict, List, Optional, Union from unittest.mock import Mock import pytest @@ -18,6 +19,8 @@ from antarest.core.tasks.model import CustomTaskEventMessages, TaskDTO, TaskListFilter, TaskResult, TaskStatus, TaskType from antarest.core.tasks.service import ITaskService, Task from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware +from antarest.eventbus.business.local_eventbus import LocalEventBus +from antarest.eventbus.service import EventBusService from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.matrixstore.uri_resolver_service import UriResolverService @@ -51,26 +54,26 @@ class SynchTaskService(ITaskService): def __init__(self) -> None: - self._task_result: Optional[TaskResult] = None + self._task_result: t.Optional[TaskResult] = None def add_worker_task( self, task_type: TaskType, task_queue: str, - task_args: Dict[str, Union[int, float, bool, str]], - name: Optional[str], - ref_id: Optional[str], + task_args: t.Dict[str, t.Union[int, float, bool, str]], + name: t.Optional[str], + ref_id: t.Optional[str], request_params: RequestParameters, - ) -> Optional[str]: + ) -> t.Optional[str]: raise NotImplementedError() def add_task( self, action: Task, - name: Optional[str], - task_type: Optional[TaskType], - ref_id: Optional[str], - custom_event_messages: Optional[CustomTaskEventMessages], + name: t.Optional[str], + task_type: t.Optional[TaskType], + ref_id: t.Optional[str], + custom_event_messages: t.Optional[CustomTaskEventMessages], request_params: RequestParameters, ) -> str: self._task_result = action(lambda message: None) @@ -93,15 +96,15 @@ def status_task( logs=None, ) - def list_tasks(self, task_filter: TaskListFilter, request_params: RequestParameters) -> List[TaskDTO]: + def list_tasks(self, task_filter: TaskListFilter, request_params: RequestParameters) -> t.List[TaskDTO]: return [] - def await_task(self, task_id: str, timeout_sec: Optional[int] = None) -> None: + def await_task(self, task_id: str, timeout_sec: t.Optional[int] = None) -> None: pass @pytest.fixture(name="bucket_dir", scope="session") -def bucket_dir_fixture(tmp_path_factory) -> Path: +def bucket_dir_fixture(tmp_path_factory: t.Any) -> Path: """ Fixture that creates a session-level temporary directory named "matrix_store" for storing matrices. @@ -115,7 +118,7 @@ def bucket_dir_fixture(tmp_path_factory) -> Path: Returns: A Path object representing the created temporary directory for storing matrices. """ - return tmp_path_factory.mktemp("matrix_store", numbered=False) + return t.cast(Path, tmp_path_factory.mktemp("matrix_store")) @pytest.fixture(name="simple_matrix_service", scope="session") @@ -275,7 +278,7 @@ def event_bus_fixture() -> IEventBus: Returns: A Mock instance of the IEventBus class for event bus-related testing. """ - return Mock(spec=IEventBus) + return EventBusService(LocalEventBus()) @pytest.fixture(name="command_factory", scope="session") From 759e86b60cc17f4638c2f552df03c85948c224cf Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 11 Dec 2023 19:03:33 +0100 Subject: [PATCH 35/43] perf(service): improve performances by reducing the number of SQL requests in `TaskJobService` --- antarest/core/tasks/service.py | 125 ++++--- tests/core/test_tasks.py | 578 +++++++++++++-------------------- 2 files changed, 293 insertions(+), 410 deletions(-) diff --git a/antarest/core/tasks/service.py b/antarest/core/tasks/service.py index 0844489283..07832e7365 100644 --- a/antarest/core/tasks/service.py +++ b/antarest/core/tasks/service.py @@ -7,6 +7,7 @@ from http import HTTPStatus from fastapi import HTTPException +from sqlalchemy.orm import Session # type: ignore from antarest.core.config import Config from antarest.core.interfaces.eventbus import Event, EventChannelDirectory, EventType, IEventBus @@ -25,7 +26,6 @@ ) from antarest.core.tasks.repository import TaskJobRepository from antarest.core.utils.fastapi_sqlalchemy import db -from antarest.core.utils.utils import retry from antarest.worker.worker import WorkerTaskCommand, WorkerTaskResult logger = logging.getLogger(__name__) @@ -85,6 +85,26 @@ def noop_notifier(message: str) -> None: """This function is used in tasks when no notification is required.""" +class TaskJobLogRecorder: + """ + Callback used to register log messages in the TaskJob table. + + Args: + task_id: The task id. + session: The database session created in the same thread as the task thread. + """ + + def __init__(self, task_id: str, session: Session): + self.session = session + self.task_id = task_id + + def __call__(self, message: str) -> None: + task = self.session.query(TaskJob).get(self.task_id) + if task: + task.logs.append(TaskJobLog(message=message, task_id=self.task_id)) + db.session.commit() + + class TaskJobService(ITaskService): def __init__( self, @@ -294,18 +314,24 @@ def await_task(self, task_id: str, timeout_sec: int = DEFAULT_AWAIT_MAX_TIMEOUT) logger.warning(f"Task '{task_id}' not handled by this worker, will poll for task completion from db") end = time.time() + timeout_sec while time.time() < end: - with db(): - task = self.repo.get(task_id) - if task is None: - logger.error(f"Awaited task '{task_id}' was not found") - return - if TaskStatus(task.status).is_final(): - return + task_status = db.session.query(TaskJob.status).filter(TaskJob.id == task_id).scalar() + if task_status is None: + logger.error(f"Awaited task '{task_id}' was not found") + return + if TaskStatus(task_status).is_final(): + return logger.info("💤 Sleeping 2 seconds...") time.sleep(2) + logger.error(f"Timeout while awaiting task '{task_id}'") - with db(): - self.repo.update_timeout(task_id, timeout_sec) + db.session.query(TaskJob).filter(TaskJob.id == task_id).update( + { + TaskJob.status: TaskStatus.TIMEOUT.value, + TaskJob.result_msg: f"Task '{task_id}' timeout after {timeout_sec} seconds", + TaskJob.result_status: False, + } + ) + db.session.commit() def _run_task( self, @@ -313,6 +339,8 @@ def _run_task( task_id: str, custom_event_messages: t.Optional[CustomTaskEventMessages] = None, ) -> None: + # attention: this function is executed in a thread, not in the main process + self.event_bus.push( Event( type=EventType.TASK_RUNNING, @@ -329,22 +357,32 @@ def _run_task( logger.info(f"Starting task {task_id}") with db(): - task = retry(lambda: self.repo.get_or_raise(task_id)) - task.status = TaskStatus.RUNNING.value - self.repo.save(task) - logger.info(f"Task {task_id} set to RUNNING") + db.session.query(TaskJob).filter(TaskJob.id == task_id).update({TaskJob.status: TaskStatus.RUNNING.value}) + db.session.commit() + logger.info(f"Task {task_id} set to RUNNING") + try: with db(): - result = callback(self._task_logger(task_id)) - logger.info(f"Task {task_id} ended") + # We must use the DB session attached to the current thread + result = callback(TaskJobLogRecorder(task_id, session=db.session)) + + status = TaskStatus.COMPLETED if result.success else TaskStatus.FAILED + logger.info(f"Task {task_id} ended with status {status}") + with db(): - self._update_task_status( - task_id, - TaskStatus.COMPLETED if result.success else TaskStatus.FAILED, - result.success, - result.message, - result.return_value, + # Do not use the `timezone.utc` timezone to preserve a naive datetime. + completion_date = datetime.datetime.utcnow() if status.is_final() else None + db.session.query(TaskJob).filter(TaskJob.id == task_id).update( + { + TaskJob.status: status.value, + TaskJob.result_msg: result.message, + TaskJob.result_status: result.success, + TaskJob.result: result.return_value, + TaskJob.completion_date: completion_date, + } ) + db.session.commit() + event_type = {True: EventType.TASK_COMPLETED, False: EventType.TASK_FAILED}[result.success] event_msg = {True: "completed", False: "failed"}[result.success] self.event_bus.push( @@ -365,13 +403,19 @@ def _run_task( except Exception as exc: err_msg = f"Task {task_id} failed: Unhandled exception {exc}" logger.error(err_msg, exc_info=exc) + with db(): - self._update_task_status( - task_id, - TaskStatus.FAILED, - False, - f"{err_msg}\nSee the logs for detailed information and the error traceback.", + result_msg = f"{err_msg}\nSee the logs for detailed information and the error traceback." + db.session.query(TaskJob).filter(TaskJob.id == task_id).update( + { + TaskJob.status: TaskStatus.FAILED.value, + TaskJob.result_msg: result_msg, + TaskJob.result_status: False, + TaskJob.completion_date: datetime.datetime.utcnow(), + } ) + db.session.commit() + message = err_msg if custom_event_messages is None else custom_event_messages.end self.event_bus.push( Event( @@ -381,30 +425,3 @@ def _run_task( channel=EventChannelDirectory.TASK + task_id, ) ) - - def _task_logger(self, task_id: str) -> t.Callable[[str], None]: - def log_msg(message: str) -> None: - task = self.repo.get(task_id) - if task: - task.logs.append(TaskJobLog(message=message, task_id=task_id)) - self.repo.save(task) - - return log_msg - - def _update_task_status( - self, - task_id: str, - status: TaskStatus, - result: bool, - message: str, - command_result: t.Optional[str] = None, - ) -> None: - task = self.repo.get_or_raise(task_id) - task.status = status.value - task.result_msg = message - task.result_status = result - task.result = command_result - if status.is_final(): - # Do not use the `timezone.utc` timezone to preserve a naive datetime. - task.completion_date = datetime.datetime.utcnow() - self.repo.save(task) diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index 671a78f267..e187e4eb01 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -1,22 +1,22 @@ +import dataclasses import datetime import time +import typing as t from pathlib import Path -from typing import Callable, List -from unittest.mock import ANY, Mock, call +from unittest.mock import ANY, Mock import pytest -from sqlalchemy import create_engine +from sqlalchemy import create_engine # type: ignore from sqlalchemy.engine.base import Engine # type: ignore -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker # type: ignore from antarest.core.config import Config, RemoteWorkerConfig, TaskConfig -from antarest.core.interfaces.eventbus import Event, EventType, IEventBus +from antarest.core.interfaces.eventbus import EventType, IEventBus from antarest.core.jwt import DEFAULT_ADMIN_USER from antarest.core.model import PermissionInfo, PublicMode from antarest.core.persistence import Base from antarest.core.requests import RequestParameters, UserHasNotPermissionError from antarest.core.tasks.model import ( - TaskDTO, TaskJob, TaskJobLog, TaskListFilter, @@ -27,42 +27,57 @@ ) from antarest.core.tasks.repository import TaskJobRepository from antarest.core.tasks.service import TaskJobService -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db +from antarest.core.utils.fastapi_sqlalchemy import db from antarest.eventbus.business.local_eventbus import LocalEventBus from antarest.eventbus.service import EventBusService +from antarest.login.model import User +from antarest.study.model import RawStudy from antarest.utils import SESSION_ARGS from antarest.worker.worker import AbstractWorker, WorkerTaskCommand from tests.helpers import with_db_context -def test_service() -> None: - # sourcery skip: aware-datetime-for-utc - engine = create_engine("sqlite:///:memory:", echo=False) +@pytest.fixture(name="db_engine", autouse=True) +def db_engine_fixture(tmp_path: Path) -> t.Generator[Engine, None, None]: + """ + Fixture that creates an SQLite database in a temporary directory. + + When a function runs in a different thread than the main one and needs to use + the database, it uses the global `db` object. This object helps create a new + local session in the thread to connect to the SQLite database. + However, we can't use an in-memory SQLite database ("sqlite:///:memory:") because + it creates a new empty database each time. That's why we use a SQLite database stored on disk. + + Yields: + An instance of the created SQLite database engine. + """ + db_path = tmp_path / "db.sqlite" + db_url = f"sqlite:///{db_path}" + engine = create_engine(db_url, echo=False) + engine.execute("PRAGMA foreign_keys = ON") Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) + yield engine + engine.dispose() - repo_mock = Mock(spec=TaskJobRepository) - creation_date = datetime.datetime.now(datetime.timezone.utc) - task = TaskJob(id="a", name="b", status=2, creation_date=creation_date) - repo_mock.list.return_value = [task] - repo_mock.get_or_raise.return_value = task - service = TaskJobService(config=Config(), repository=repo_mock, event_bus=Mock()) - repo_mock.save.assert_called_with( - TaskJob( - id="a", - name="b", - status=4, - creation_date=creation_date, - result_status=False, - result_msg="Task was interrupted due to server restart", - completion_date=ANY, - ) - ) + +@with_db_context +def test_service(core_config: Config, event_bus: IEventBus) -> None: + engine = db.session.bind + task_job_repo = TaskJobRepository() + + # Prepare a TaskJob in the database + creation_date = datetime.datetime.utcnow() + running_task = TaskJob(id="a", name="b", status=TaskStatus.RUNNING.value, creation_date=creation_date) + task_job_repo.save(running_task) + + # Create a TaskJobService + service = TaskJobService(config=core_config, repository=task_job_repo, event_bus=event_bus) + + # Cancel pending and running tasks + cancel_orphan_tasks(engine=engine, session_args=SESSION_ARGS) + + # Test Case: list tasks + # ===================== tasks = service.list_tasks( TaskListFilter(), @@ -72,52 +87,37 @@ def test_service() -> None: assert tasks[0].status == TaskStatus.FAILED assert tasks[0].creation_date_utc == str(creation_date) - start = datetime.datetime.now(datetime.timezone.utc) - end = start + datetime.timedelta(seconds=1) - repo_mock.reset_mock() - repo_mock.get.return_value = TaskJob( - id="a", - completion_date=end, - name="Unnamed", - owner_id=1, - status=TaskStatus.COMPLETED.value, - result_status=True, - result_msg="OK", - creation_date=start, - ) + # Test Case: get task status + # ========================== + res = service.status_task("a", RequestParameters(user=DEFAULT_ADMIN_USER)) assert res is not None - assert res == TaskDTO( - id="a", - completion_date_utc=str(end), - creation_date_utc=str(start), - owner=1, - name="Unnamed", - result=TaskResult(success=True, message="OK"), - status=TaskStatus.COMPLETED, - ) + expected = { + "completion_date_utc": ANY, + "creation_date_utc": creation_date.isoformat(" "), + "id": "a", + "logs": None, + "name": "b", + "owner": None, + "ref_id": None, + "result": { + "message": "Task was interrupted due to server restart", + "return_value": None, + "success": False, + }, + "status": TaskStatus.FAILED, + "type": None, + } + assert res.dict() == expected + + # Test Case: add a task that fails and wait for it + # ================================================ # noinspection PyUnusedLocal - def action_fail(update_msg: Callable[[str], None]) -> TaskResult: - raise NotImplementedError() + def action_fail(update_msg: t.Callable[[str], None]) -> TaskResult: + raise Exception("this action failed") - def action_ok(update_msg: Callable[[str], None]) -> TaskResult: - update_msg("start") - update_msg("end") - return TaskResult(success=True, message="OK") - - repo_mock.reset_mock() - now = datetime.datetime.utcnow() - task = TaskJob( - name="failed action", - owner_id=1, - id="a", - creation_date=now, - status=TaskStatus.PENDING.value, - ) - repo_mock.save.side_effect = lambda x: task - repo_mock.get_or_raise.return_value = task - service.add_task( + failed_id = service.add_task( action_fail, "failed action", None, @@ -125,79 +125,27 @@ def action_ok(update_msg: Callable[[str], None]) -> TaskResult: None, RequestParameters(user=DEFAULT_ADMIN_USER), ) - service.await_task("a") - repo_mock.save.assert_has_calls( - [ - call( - TaskJob( - id=None, - logs=[], - owner_id=1, - creation_date=None, - completion_date=None, - name="failed action", - status=None, - result_msg=None, - result_status=None, - ) - ), - call( - TaskJob( - id="a", - logs=[], - owner_id=1, - creation_date=now, - completion_date=ANY, - name="failed action", - status=4, - result_msg=ANY, # "Task a failed: Unhandled exception [...]" - result_status=False, - ) - ), - call( - TaskJob( - id="a", - logs=[], - owner_id=1, - creation_date=now, - completion_date=ANY, - name="failed action", - status=4, - result_msg=ANY, # "Task a failed: Unhandled exception [...]" - result_status=False, - ) - ), - ] + service.await_task(failed_id, timeout_sec=2) + + failed_task = task_job_repo.get(failed_id) + assert failed_task is not None + assert failed_task.status == TaskStatus.FAILED.value + assert failed_task.result_status is False + assert failed_task.result_msg == ( + f"Task {failed_id} failed: Unhandled exception this action failed" + f"\nSee the logs for detailed information and the error traceback." ) + assert failed_task.completion_date is not None - repo_mock.reset_mock() - now = datetime.datetime.utcnow() - task = TaskJob( - name="Unnamed", - owner_id=1, - id="a", - creation_date=now, - status=TaskStatus.PENDING.value, - ) - repo_mock.save.side_effect = lambda x: task - repo_mock.get_or_raise.return_value = task - repo_mock.get.side_effect = [ - TaskJob( - name="Unnamed", - owner_id=1, - id="a", - creation_date=now, - status=TaskStatus.RUNNING.value, - ), - TaskJob( - name="Unnamed", - owner_id=1, - id="a", - creation_date=now, - status=TaskStatus.RUNNING.value, - ), - ] - service.add_task( + # Test Case: add a task that succeeds and wait for it + # =================================================== + + def action_ok(update_msg: t.Callable[[str], None]) -> TaskResult: + update_msg("start") + update_msg("end") + return TaskResult(success=True, message="OK") + + ok_id = service.add_task( action_ok, None, None, @@ -205,134 +153,46 @@ def action_ok(update_msg: Callable[[str], None]) -> TaskResult: None, request_params=RequestParameters(user=DEFAULT_ADMIN_USER), ) - service.await_task("a") - repo_mock.save.assert_has_calls( - [ - call(TaskJob(owner_id=1, name="Unnamed")), - # this is not called with that because the object is mutated, and mock seems to suck.. - # TaskJob( - # id="a", - # name="failed action", - # owner_id=1, - # status=TaskStatus.RUNNING.value, - # creation_date=now, - # ), - call( - TaskJob( - id="a", - completion_date=ANY, - name="Unnamed", - owner_id=1, - status=TaskStatus.COMPLETED.value, - result_status=True, - result_msg="OK", - creation_date=now, - ) - ), - call( - TaskJob( - name="Unnamed", - owner_id=1, - id="a", - creation_date=now, - status=TaskStatus.RUNNING.value, - logs=[TaskJobLog(message="start", task_id="a")], - ) - ), - call( - TaskJob( - name="Unnamed", - owner_id=1, - id="a", - creation_date=now, - status=TaskStatus.RUNNING.value, - logs=[TaskJobLog(message="end", task_id="a")], - ) - ), - call( - TaskJob( - id="a", - completion_date=ANY, - name="Unnamed", - owner_id=1, - status=TaskStatus.COMPLETED.value, - result_status=True, - result_msg="OK", - creation_date=now, - ) - ), - ] - ) + service.await_task(ok_id, timeout_sec=2) - repo_mock.get.reset_mock() - repo_mock.get.side_effect = [None] - service.await_task("elsewhere") - repo_mock.get.assert_called_with("elsewhere") + ok_task = task_job_repo.get(ok_id) + assert ok_task is not None + assert ok_task.status == TaskStatus.COMPLETED.value + assert ok_task.result_status is True + assert ok_task.result_msg == "OK" + assert ok_task.completion_date is not None + assert len(ok_task.logs) == 2 + assert ok_task.logs[0].message == "start" + assert ok_task.logs[1].message == "end" class DummyWorker(AbstractWorker): - def __init__(self, event_bus: IEventBus, accept: List[str], tmp_path: Path): + def __init__(self, event_bus: IEventBus, accept: t.List[str], tmp_path: Path): super().__init__("test", event_bus, accept) self.tmp_path = tmp_path def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: # simulate a "long" task ;-) time.sleep(0.01) - relative_path = task_info.task_args["file"] + relative_path = t.cast(str, task_info.task_args["file"]) (self.tmp_path / relative_path).touch() return TaskResult(success=True, message="") @with_db_context -def test_worker_tasks(tmp_path: Path): - repo_mock = Mock(spec=TaskJobRepository) - repo_mock.list.return_value = [] - event_bus = EventBusService(LocalEventBus()) - service = TaskJobService( - config=Config(tasks=TaskConfig(remote_workers=[RemoteWorkerConfig(name="test", queues=["test"])])), - repository=repo_mock, - event_bus=event_bus, - ) +def test_worker_tasks(tmp_path: Path, core_config: Config, event_bus: IEventBus) -> None: + # Create a TaskJobService + task_job_repo = TaskJobRepository() + task_config = TaskConfig(remote_workers=[RemoteWorkerConfig(name="test", queues=["test"])]) + config = dataclasses.replace(core_config, tasks=task_config) + service = TaskJobService(config=config, repository=task_job_repo, event_bus=event_bus) worker = DummyWorker(event_bus, ["test"], tmp_path) worker.start(threaded=True) file_to_create = "foo" - assert not (tmp_path / file_to_create).exists() - repo_mock.save.side_effect = [ - TaskJob( - id="taskid", - name="Unnamed", - owner_id=0, - type=TaskType.WORKER_TASK, - ref_id=None, - ), - TaskJob( - id="taskid", - name="Unnamed", - owner_id=0, - type=TaskType.WORKER_TASK, - ref_id=None, - status=TaskStatus.RUNNING, - ), - TaskJob( - id="taskid", - name="Unnamed", - owner_id=0, - type=TaskType.WORKER_TASK, - ref_id=None, - status=TaskStatus.COMPLETED, - ), - ] - repo_mock.get_or_raise.return_value = TaskJob( - id="taskid", - name="Unnamed", - owner_id=0, - type=TaskType.WORKER_TASK, - ref_id=None, - ) task_id = service.add_worker_task( TaskType.WORKER_TASK, "test", @@ -341,130 +201,136 @@ def test_worker_tasks(tmp_path: Path): None, request_params=RequestParameters(user=DEFAULT_ADMIN_USER), ) - service.await_task(task_id) + assert task_id is not None + service.await_task(task_id, timeout_sec=2) assert (tmp_path / file_to_create).exists() -def test_repository(): - # sourcery skip: aware-datetime-for-utc - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) +def test_repository(db_session: Session) -> None: + # Prepare two users in the database + user1_id = 9 + db_session.add(User(id=user1_id, name="John")) + user2_id = 10 + db_session.add(User(id=user2_id, name="Jane")) + db_session.commit() - with db(): - # sourcery skip: extract-method - task_repository = TaskJobRepository() - - new_task = TaskJob(name="foo", owner_id=0, type=TaskType.COPY) - second_task = TaskJob(owner_id=1, ref_id="a") - - now = datetime.datetime.utcnow() - new_task = task_repository.save(new_task) - assert task_repository.get(new_task.id) == new_task - assert new_task.status == TaskStatus.PENDING.value - assert new_task.owner_id == 0 - assert new_task.creation_date >= now - - second_task = task_repository.save(second_task) - - result = task_repository.list(TaskListFilter(type=[TaskType.COPY])) - assert len(result) == 1 - assert result[0].id == new_task.id - - result = task_repository.list(TaskListFilter(ref_id="a")) - assert len(result) == 1 - assert result[0].id == second_task.id - - result = task_repository.list(TaskListFilter(), user=1) - assert len(result) == 1 - assert result[0].id == second_task.id - - result = task_repository.list(TaskListFilter()) - assert len(result) == 2 - - result = task_repository.list(TaskListFilter(name="fo")) - assert len(result) == 1 - - result = task_repository.list(TaskListFilter(name="fo", status=[TaskStatus.RUNNING])) - assert len(result) == 0 - new_task.status = TaskStatus.RUNNING.value - task_repository.save(new_task) - result = task_repository.list(TaskListFilter(name="fo", status=[TaskStatus.RUNNING])) - assert len(result) == 1 - - new_task.completion_date = datetime.datetime.utcnow() - task_repository.save(new_task) - result = task_repository.list( - TaskListFilter( - name="fo", - from_completion_date_utc=(new_task.completion_date + datetime.timedelta(seconds=1)).timestamp(), - ) + # Create a RawStudy in the database + study_id = "e34fe4d5-5964-4ef2-9baf-fad66dadc512" + db_session.add(RawStudy(id="study_id", name="foo", version="860")) + db_session.commit() + + # Create a TaskJobService + task_job_repo = TaskJobRepository(db_session) + + new_task = TaskJob(name="foo", owner_id=user1_id, type=TaskType.COPY) + + now = datetime.datetime.utcnow() + new_task = task_job_repo.save(new_task) + assert task_job_repo.get(new_task.id) == new_task + assert new_task.status == TaskStatus.PENDING.value + assert new_task.owner_id == user1_id + assert new_task.creation_date >= now + + second_task = TaskJob(owner_id=user2_id, ref_id=study_id) + second_task = task_job_repo.save(second_task) + + result = task_job_repo.list(TaskListFilter(type=[TaskType.COPY])) + assert len(result) == 1 + assert result[0].id == new_task.id + + result = task_job_repo.list(TaskListFilter(ref_id=study_id)) + assert len(result) == 1 + assert result[0].id == second_task.id + + result = task_job_repo.list(TaskListFilter(), user=user2_id) + assert len(result) == 1 + assert result[0].id == second_task.id + + result = task_job_repo.list(TaskListFilter()) + assert len(result) == 2 + + result = task_job_repo.list(TaskListFilter(name="fo")) + assert len(result) == 1 + + result = task_job_repo.list(TaskListFilter(name="fo", status=[TaskStatus.RUNNING])) + assert len(result) == 0 + new_task.status = TaskStatus.RUNNING.value + task_job_repo.save(new_task) + result = task_job_repo.list(TaskListFilter(name="fo", status=[TaskStatus.RUNNING])) + assert len(result) == 1 + + new_task.completion_date = datetime.datetime.utcnow() + task_job_repo.save(new_task) + result = task_job_repo.list( + TaskListFilter( + name="fo", + from_completion_date_utc=(new_task.completion_date + datetime.timedelta(seconds=1)).timestamp(), ) - assert len(result) == 0 - result = task_repository.list( - TaskListFilter( - name="fo", - from_completion_date_utc=(new_task.completion_date - datetime.timedelta(seconds=1)).timestamp(), - ) + ) + assert len(result) == 0 + result = task_job_repo.list( + TaskListFilter( + name="fo", + from_completion_date_utc=(new_task.completion_date - datetime.timedelta(seconds=1)).timestamp(), ) - assert len(result) == 1 + ) + assert len(result) == 1 - new_task.logs.append(TaskJobLog(message="hello")) - new_task.logs.append(TaskJobLog(message="bar")) - task_repository.save(new_task) - new_task = task_repository.get(new_task.id) - assert len(new_task.logs) == 2 - assert new_task.logs[0].message == "hello" + new_task.logs.append(TaskJobLog(message="hello")) + new_task.logs.append(TaskJobLog(message="bar")) + task_job_repo.save(new_task) + assert new_task.id is not None + new_task = task_job_repo.get_or_raise(new_task.id) + assert len(new_task.logs) == 2 + assert new_task.logs[0].message == "hello" - assert len(db.session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 2 + assert len(db_session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 2 - task_repository.delete(new_task.id) - assert len(db.session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 0 - assert task_repository.get(new_task.id) is None + task_job_repo.delete(new_task.id) + assert len(db_session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 0 + assert task_job_repo.get(new_task.id) is None -def test_cancel(): - # sourcery skip: aware-datetime-for-utc - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) +@with_db_context +def test_cancel(core_config: Config, event_bus: IEventBus) -> None: + # Create a TaskJobService and add tasks + task_job_repo = TaskJobRepository() + task_job_repo.save(TaskJob(id="a")) + task_job_repo.save(TaskJob(id="b")) - repo_mock = Mock(spec=TaskJobRepository) - repo_mock.list.return_value = [] - service = TaskJobService(config=Config(), repository=repo_mock, event_bus=Mock()) + # Create a TaskJobService + service = TaskJobService(config=core_config, repository=task_job_repo, event_bus=event_bus) with pytest.raises(UserHasNotPermissionError): service.cancel_task("a", RequestParameters()) + # Test Case: cancel a task that is not in the service tasks map + # ============================================================= + service.cancel_task("b", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True) - # noinspection PyUnresolvedReferences - service.event_bus.push.assert_called_with( - Event( - type=EventType.TASK_CANCEL_REQUEST, - payload="b", - permissions=PermissionInfo(public_mode=PublicMode.NONE), - ) - ) - creation_date = datetime.datetime.utcnow() - task = TaskJob(id="a", name="b", status=2, creation_date=creation_date) - repo_mock.list.return_value = [task] - repo_mock.get_or_raise.return_value = task - service.tasks["a"] = Mock() + # The event_bus fixture is actually a EventBusService with LocalEventBus backend + backend = t.cast(LocalEventBus, t.cast(EventBusService, event_bus).backend) + collected_events = backend.get_events() + + assert len(collected_events) == 1 + assert collected_events[0].type == EventType.TASK_CANCEL_REQUEST + assert collected_events[0].payload == "b" + assert collected_events[0].permissions == PermissionInfo(public_mode=PublicMode.NONE) + + # Test Case: cancel a task that is in the service tasks map + # ========================================================= + + service.tasks["a"] = Mock(cancel=Mock(return_value=None)) + service.cancel_task("a", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True) - task.status = TaskStatus.CANCELLED.value - repo_mock.save.assert_called_with(task) + + collected_events = backend.get_events() + assert len(collected_events) == 1, "No event should have been emitted because the task is in the service map" + task_a = task_job_repo.get("a") + assert task_a is not None + assert task_a.status == TaskStatus.CANCELLED.value @pytest.mark.parametrize( @@ -483,7 +349,7 @@ def test_cancel_orphan_tasks( status: int, result_status: bool, result_msg: str, -): +) -> None: max_diff_seconds: int = 1 test_id: str = "2ea94758-9ea5-4015-a45f-b245a6ffc147" From c1ccffc9573fb879e92a395c3fb6f80896071ffd Mon Sep 17 00:00:00 2001 From: MartinBelthle <102529366+MartinBelthle@users.noreply.github.com> Date: Tue, 12 Dec 2023 13:16:02 +0100 Subject: [PATCH 36/43] fix(api-bc): avoid duplicates in Binding Constraints creation through REST API (#1858) This only concerns the back-end part --- antarest/core/exceptions.py | 5 +++ .../business/binding_constraint_management.py | 36 +++++++++++++++++ .../command/create_binding_constraint.py | 15 ++++--- antarest/study/web/study_data_blueprint.py | 23 ++++++++++- tests/integration/test_integration.py | 39 +++++++++++++++++++ 5 files changed, 111 insertions(+), 7 deletions(-) diff --git a/antarest/core/exceptions.py b/antarest/core/exceptions.py index a666394d8b..ab39c3a566 100644 --- a/antarest/core/exceptions.py +++ b/antarest/core/exceptions.py @@ -189,6 +189,11 @@ def __init__(self, message: str) -> None: super().__init__(HTTPStatus.NOT_FOUND, message) +class DuplicateConstraintName(HTTPException): + def __init__(self, message: str) -> None: + super().__init__(HTTPStatus.CONFLICT, message) + + class MissingDataError(HTTPException): def __init__(self, message: str) -> None: super().__init__(HTTPStatus.NOT_FOUND, message) diff --git a/antarest/study/business/binding_constraint_management.py b/antarest/study/business/binding_constraint_management.py index ca1f714750..7caeabd9ab 100644 --- a/antarest/study/business/binding_constraint_management.py +++ b/antarest/study/business/binding_constraint_management.py @@ -5,6 +5,7 @@ from antarest.core.exceptions import ( ConstraintAlreadyExistError, ConstraintIdNotFoundError, + DuplicateConstraintName, MissingDataError, NoBindingConstraintError, NoConstraintError, @@ -13,8 +14,13 @@ from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import Study from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency +from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id from antarest.study.storage.storage_service import StudyStorageService from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator +from antarest.study.storage.variantstudy.model.command.create_binding_constraint import ( + BindingConstraintProperties, + CreateBindingConstraint, +) from antarest.study.storage.variantstudy.model.command.update_binding_constraint import UpdateBindingConstraint @@ -40,6 +46,10 @@ class UpdateBindingConstProps(BaseModel): value: Any +class BindingConstraintPropertiesWithName(BindingConstraintProperties): + name: str + + class BindingConstraintDTO(BaseModel): id: str name: str @@ -153,6 +163,32 @@ def get_binding_constraint( binding_constraint.append(new_config) return binding_constraint + def create_binding_constraint( + self, + study: Study, + data: BindingConstraintPropertiesWithName, + ) -> None: + binding_constraints = self.get_binding_constraint(study, None) + existing_ids = [bd.id for bd in binding_constraints] # type: ignore + bd_id = transform_name_to_id(data.name) + if bd_id in existing_ids: + raise DuplicateConstraintName(f"A binding constraint with the same name already exists: {bd_id}.") + + file_study = self.storage_service.get_storage(study).get_raw(study) + command = CreateBindingConstraint( + name=bd_id, + enabled=data.enabled, + time_step=data.time_step, + operator=data.operator, + coeffs=data.coeffs, + values=data.values, + filter_year_by_year=data.filter_year_by_year, + filter_synthesis=data.filter_synthesis, + comments=data.comments or "", + command_context=self.storage_service.variant_study_service.command_factory.command_context, + ) + execute_or_add_commands(study, file_study, [command], self.storage_service) + def update_binding_constraint( self, study: Study, diff --git a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py index ed3125f34b..901294a73d 100644 --- a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union, cast import numpy as np -from pydantic import Field, validator +from pydantic import BaseModel, Field, validator from antarest.matrixstore.model import MatrixData from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency @@ -58,12 +58,9 @@ def check_matrix_values(time_step: BindingConstraintFrequency, values: MatrixTyp raise ValueError("Matrix values cannot contain NaN") -class AbstractBindingConstraintCommand(ICommand, metaclass=ABCMeta): - """ - Abstract class for binding constraint commands. - """ - +class BindingConstraintProperties(BaseModel): # todo: add the `name` attribute because it should also be updated + # It would lead to an API change as update_binding_constraint currently does not have it enabled: bool = True time_step: BindingConstraintFrequency operator: BindingConstraintOperator @@ -73,6 +70,12 @@ class AbstractBindingConstraintCommand(ICommand, metaclass=ABCMeta): filter_synthesis: Optional[str] = None comments: Optional[str] = None + +class AbstractBindingConstraintCommand(BindingConstraintProperties, ICommand, metaclass=ABCMeta): + """ + Abstract class for binding constraint commands. + """ + def to_dto(self) -> CommandDTO: args = { "enabled": self.enabled, diff --git a/antarest/study/web/study_data_blueprint.py b/antarest/study/web/study_data_blueprint.py index de9fbcecd1..440539a4ab 100644 --- a/antarest/study/web/study_data_blueprint.py +++ b/antarest/study/web/study_data_blueprint.py @@ -26,7 +26,11 @@ ) from antarest.study.business.areas.st_storage_management import * from antarest.study.business.areas.thermal_management import * -from antarest.study.business.binding_constraint_management import ConstraintTermDTO, UpdateBindingConstProps +from antarest.study.business.binding_constraint_management import ( + BindingConstraintPropertiesWithName, + ConstraintTermDTO, + UpdateBindingConstProps, +) from antarest.study.business.correlation_management import CorrelationFormFields, CorrelationManager, CorrelationMatrix from antarest.study.business.district_manager import DistrictCreationDTO, DistrictInfoDTO, DistrictUpdateDTO from antarest.study.business.general_management import GeneralFormFields @@ -857,6 +861,23 @@ def update_binding_constraint( study = study_service.check_study_access(uuid, StudyPermissionType.WRITE, params) return study_service.binding_constraint_manager.update_binding_constraint(study, binding_constraint_id, data) + @bp.post( + "/studies/{uuid}/bindingconstraints", + tags=[APITag.study_data], + summary="Create a binding constraint", + response_model=None, + ) + def create_binding_constraint( + uuid: str, data: BindingConstraintPropertiesWithName, current_user: JWTUser = Depends(auth.get_current_user) + ) -> None: + logger.info( + f"Creating a new binding constraint for study {uuid}", + extra={"user": current_user.id}, + ) + params = RequestParameters(user=current_user) + study = study_service.check_study_access(uuid, StudyPermissionType.READ, params) + return study_service.binding_constraint_manager.create_binding_constraint(study, data) + @bp.post( "/studies/{uuid}/bindingconstraints/{binding_constraint_id}/term", tags=[APITag.study_data], diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index 0938557239..1e9fd99caa 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -2192,6 +2192,45 @@ def test_binding_constraint_manager(client: TestClient, admin_access_token: str, assert res.status_code == 200 assert constraints is None + # Creates a binding constraint with the new API + res = client.post( + f"/v1/studies/{variant_id}/bindingconstraints", + json={ + "name": "binding_constraint_3", + "enabled": True, + "time_step": "hourly", + "operator": "less", + "coeffs": {}, + "comments": "New API", + }, + headers=admin_headers, + ) + assert res.status_code == 200 + + # Asserts that creating 2 binding constraints with the same name raises an Exception + res = client.post( + f"/v1/studies/{variant_id}/bindingconstraints", + json={ + "name": "binding_constraint_3", + "enabled": True, + "time_step": "hourly", + "operator": "less", + "coeffs": {}, + "comments": "New API", + }, + headers=admin_headers, + ) + assert res.status_code == 409 + assert res.json() == { + "description": "A binding constraint with the same name already exists: binding_constraint_3.", + "exception": "DuplicateConstraintName", + } + + # Asserts that only 3 binding constraint have been created + res = client.get(f"/v1/studies/{variant_id}/bindingconstraints", headers=admin_headers) + assert res.status_code == 200 + assert len(res.json()) == 3 + def test_import(client: TestClient, admin_access_token: str, study_id: str) -> None: admin_headers = {"Authorization": f"Bearer {admin_access_token}"} From 641bbade75aaa0e74f567d2e14f0b3eccdd3ed8c Mon Sep 17 00:00:00 2001 From: Hatim Dinia Date: Wed, 13 Dec 2023 13:46:47 +0100 Subject: [PATCH 37/43] fix(ui): update current area after window reload (#1862) --- .../App/Singlestudy/explore/Modelization/index.tsx | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/index.tsx b/webapp/src/components/App/Singlestudy/explore/Modelization/index.tsx index d298c94a31..06b1a4da83 100644 --- a/webapp/src/components/App/Singlestudy/explore/Modelization/index.tsx +++ b/webapp/src/components/App/Singlestudy/explore/Modelization/index.tsx @@ -1,5 +1,5 @@ -import { useMemo } from "react"; -import { useNavigate, useOutletContext } from "react-router-dom"; +import { useEffect, useMemo } from "react"; +import { useNavigate, useOutletContext, useParams } from "react-router-dom"; import { Box } from "@mui/material"; import { useTranslation } from "react-i18next"; import { StudyMetadata } from "../../../../../common/types"; @@ -14,9 +14,16 @@ function Modelization() { const [t] = useTranslation(); const dispatch = useAppDispatch(); const navigate = useNavigate(); + const { areaId: paramAreaId } = useParams(); const areas = useAppSelector((state) => getAreas(state, study.id)); const areaId = useAppSelector(getCurrentAreaId); + useEffect(() => { + if (!areaId && paramAreaId) { + dispatch(setCurrentArea(paramAreaId)); + } + }, [paramAreaId, dispatch, areaId]); + const tabList = useMemo(() => { const basePath = `/studies/${study.id}/explore/modelization`; From f9f7b66c11bf24d5cf71c7b391884e45ea5f09cb Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Tue, 12 Dec 2023 13:27:15 +0100 Subject: [PATCH 38/43] build: prepare new bug fix release v2.16.1 (unreleased) --- antarest/__init__.py | 4 ++-- docs/CHANGELOG.md | 4 ++++ setup.py | 2 +- sonar-project.properties | 2 +- webapp/package-lock.json | 4 ++-- webapp/package.json | 2 +- 6 files changed, 11 insertions(+), 7 deletions(-) diff --git a/antarest/__init__.py b/antarest/__init__.py index ada981b5ca..ea7c2d6185 100644 --- a/antarest/__init__.py +++ b/antarest/__init__.py @@ -7,9 +7,9 @@ # Standard project metadata -__version__ = "2.16.0" +__version__ = "2.16.1" __author__ = "RTE, Antares Web Team" -__date__ = "2023-11-30" +__date__ = "2023-12-14" # noinspection SpellCheckingInspection __credits__ = "(c) Réseau de Transport de l’Électricité (RTE)" diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index f48e0941ca..27761802e6 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,6 +1,10 @@ Antares Web Changelog ===================== +v2.16.1 (2023-12-14) +-------------------- + + v2.16.0 (2023-11-30) -------------------- diff --git a/setup.py b/setup.py index 37074c1e33..1760ecac46 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="AntaREST", - version="2.16.0", + version="2.16.1", description="Antares Server", long_description=Path("README.md").read_text(encoding="utf-8"), long_description_content_type="text/markdown", diff --git a/sonar-project.properties b/sonar-project.properties index e19a4a82dc..69dd022476 100644 --- a/sonar-project.properties +++ b/sonar-project.properties @@ -6,5 +6,5 @@ sonar.exclusions=antarest/gui.py,antarest/main.py sonar.python.coverage.reportPaths=coverage.xml sonar.python.version=3.8 sonar.javascript.lcov.reportPaths=webapp/coverage/lcov.info -sonar.projectVersion=2.16.0 +sonar.projectVersion=2.16.1 sonar.coverage.exclusions=antarest/gui.py,antarest/main.py,antarest/singleton_services.py,antarest/worker/archive_worker_service.py,webapp/**/* \ No newline at end of file diff --git a/webapp/package-lock.json b/webapp/package-lock.json index ed11beab16..6ce088d97a 100644 --- a/webapp/package-lock.json +++ b/webapp/package-lock.json @@ -1,12 +1,12 @@ { "name": "antares-web", - "version": "2.16.0", + "version": "2.16.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "antares-web", - "version": "2.16.0", + "version": "2.16.1", "dependencies": { "@emotion/react": "11.11.1", "@emotion/styled": "11.11.0", diff --git a/webapp/package.json b/webapp/package.json index e113491672..6b687cb3ae 100644 --- a/webapp/package.json +++ b/webapp/package.json @@ -1,6 +1,6 @@ { "name": "antares-web", - "version": "2.16.0", + "version": "2.16.1", "private": true, "engines": { "node": "18.16.1" From d4cd6b5c591105edce49355dc3edfcfbbc76129f Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Tue, 12 Dec 2023 13:42:16 +0100 Subject: [PATCH 39/43] docs: update change log for v2.16.1 --- docs/CHANGELOG.md | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 27761802e6..7a4fc60527 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -4,6 +4,34 @@ Antares Web Changelog v2.16.1 (2023-12-14) -------------------- +### Features + +* **db-init:** separate database initialization from global database session [`#1837`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1837) +* **ui:** add manual submit on clusters form [`#1852`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1852) +* **ui-modelling:** add dynamic area selection on Areas tab click [`#1835`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1835) +* **ui-storages:** use percentage values instead of ratio values [`#1846`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1846) + + +### Bug Fixes + +* **bc:** correct the name and shape of the binding constraint matrices [`#1849`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1849) +* **bc:** avoid duplicates in Binding Constraints creation through REST API [`#1858`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1858) +* **ui-study:** fix the study card explore button visibility [`#1842`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1842) +* **ui-matrix:** prevent matrices float values to be converted [`#1850`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1850) +* **ui-matrix:** calculate the prepend index according to the existence of a time column [`#1856`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1856) +* **ui-output:** add the missing "ST Storages" option in the Display selector in results view [`#1855`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1855) + + +## Documentation + +* **config:** enhance application configuration documentation [`#1710`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1710) + + +### Chore + +* **deps:** upgrade material-react-table [`#1851`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1851) + + v2.16.0 (2023-11-30) -------------------- From ece239fd7799e872e1f638d4f4a24312e8210299 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Wed, 13 Dec 2023 17:05:02 +0100 Subject: [PATCH 40/43] chore: correct unit test `test_cancel` --- tests/core/test_tasks.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index e187e4eb01..cc730bf0ea 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -305,13 +305,16 @@ def test_cancel(core_config: Config, event_bus: IEventBus) -> None: with pytest.raises(UserHasNotPermissionError): service.cancel_task("a", RequestParameters()) + # The event_bus fixture is actually a EventBusService with LocalEventBus backend + backend = t.cast(LocalEventBus, t.cast(EventBusService, event_bus).backend) + # Test Case: cancel a task that is not in the service tasks map # ============================================================= + backend.clear_events() + service.cancel_task("b", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True) - # The event_bus fixture is actually a EventBusService with LocalEventBus backend - backend = t.cast(LocalEventBus, t.cast(EventBusService, event_bus).backend) collected_events = backend.get_events() assert len(collected_events) == 1 @@ -324,10 +327,12 @@ def test_cancel(core_config: Config, event_bus: IEventBus) -> None: service.tasks["a"] = Mock(cancel=Mock(return_value=None)) + backend.clear_events() + service.cancel_task("a", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True) collected_events = backend.get_events() - assert len(collected_events) == 1, "No event should have been emitted because the task is in the service map" + assert len(collected_events) == 0, "No event should have been emitted because the task is in the service map" task_a = task_job_repo.get("a") assert task_a is not None assert task_a.status == TaskStatus.CANCELLED.value From 9b0d27c2393bdf22c170f763af32290808cd8044 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE <43534797+laurent-laporte-pro@users.noreply.github.com> Date: Wed, 13 Dec 2023 17:28:15 +0100 Subject: [PATCH 41/43] perf(variant): improve performances and correct snapshot generation (#1854) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cette PR permet de corriger (en partie) le problème de lenteur de la génération des variants en évitant de générer le snapshot s'il existe déjà. En effet, lorsqu'un variant est modifié, son snapshot n'est plus à jour et il suffit d'appliquer les nouvelles commandes pour le mettre à jour. Dans de rares cas, par exemple si l'utilisateur modifie l'historique des commandes, le snapshot ne sera pas mis à jour correctement. Cette situation n'est pour l'instant pas gérée automatiquement. --- .../storage/variantstudy/model/dbmodel.py | 2 +- .../variantstudy/snapshot_generator.py | 163 +++++++----- .../variantstudy/variant_study_service.py | 2 +- .../variantstudy/model/test_dbmodel.py | 4 +- .../variantstudy/test_snapshot_generator.py | 248 +++++++++++++----- 5 files changed, 276 insertions(+), 143 deletions(-) diff --git a/antarest/study/storage/variantstudy/model/dbmodel.py b/antarest/study/storage/variantstudy/model/dbmodel.py index 3e547bce13..1a88a76853 100644 --- a/antarest/study/storage/variantstudy/model/dbmodel.py +++ b/antarest/study/storage/variantstudy/model/dbmodel.py @@ -99,7 +99,7 @@ def snapshot_dir(self) -> Path: """Get the path of the snapshot directory.""" return Path(self.path) / "snapshot" - def is_snapshot_recent(self) -> bool: + def is_snapshot_up_to_date(self) -> bool: """Check if the snapshot exists and is up-to-date.""" return ( (self.snapshot is not None) diff --git a/antarest/study/storage/variantstudy/snapshot_generator.py b/antarest/study/storage/variantstudy/snapshot_generator.py index f36632ea87..50972ae99a 100644 --- a/antarest/study/storage/variantstudy/snapshot_generator.py +++ b/antarest/study/storage/variantstudy/snapshot_generator.py @@ -4,7 +4,6 @@ import datetime import logging import shutil -import tempfile import typing as t from pathlib import Path @@ -51,8 +50,6 @@ def __init__( self.study_factory = study_factory self.patch_service = patch_service self.repository = repository - # Temporary directory used to generate the snapshot - self._tmp_dir: Path = Path() def generate_snapshot( self, @@ -75,32 +72,29 @@ def generate_snapshot( root_study, descendants = self._retrieve_descendants(variant_study_id) assert_permission_on_studies(jwt_user, [root_study, *descendants], StudyPermissionType.READ, raising=True) - ref_study, cmd_blocks = search_ref_study(root_study, descendants, from_scratch=from_scratch) + search_result = search_ref_study(root_study, descendants, from_scratch=from_scratch) - # We are going to generate the snapshot in a temporary directory which will be renamed - # at the end of the process. This prevents incomplete snapshots in case of error. + ref_study = search_result.ref_study + cmd_blocks = search_result.cmd_blocks - # Get snapshot directory and prepare a temporary directory next to it. + # Get snapshot directory variant_study = descendants[-1] snapshot_dir = variant_study.snapshot_dir - snapshot_dir.parent.mkdir(parents=True, exist_ok=True) - self._tmp_dir = Path(tempfile.mkdtemp(dir=snapshot_dir.parent, prefix=f"~{snapshot_dir.name}", suffix=".tmp")) + try: - logger.info(f"Exporting the reference study '{ref_study.id}' to '{self._tmp_dir.name}'...") - self._export_ref_study(ref_study) + if search_result.force_regenerate or not snapshot_dir.exists(): + logger.info(f"Exporting the reference study '{ref_study.id}' to '{snapshot_dir.name}'...") + shutil.rmtree(snapshot_dir, ignore_errors=True) + self._export_ref_study(snapshot_dir, ref_study) logger.info(f"Applying commands to the reference study '{ref_study.id}'...") - results = self._apply_commands(variant_study, ref_study, cmd_blocks) - - if (snapshot_dir / "user").exists(): - logger.info("Keeping previous unmanaged user config...") - shutil.copytree(snapshot_dir / "user", self._tmp_dir / "user", dirs_exist_ok=True) + results = self._apply_commands(snapshot_dir, variant_study, cmd_blocks) # The snapshot is generated, we also need to de-normalize the matrices. file_study = self.study_factory.create_from_fs( - self._tmp_dir, + snapshot_dir, study_id=variant_study_id, - output_path=self._tmp_dir / OUTPUT_RELATIVE_PATH, + output_path=snapshot_dir / OUTPUT_RELATIVE_PATH, use_cache=False, # Avoid saving the study config in the cache ) if denormalize: @@ -112,26 +106,20 @@ def generate_snapshot( variant_study.snapshot = VariantStudySnapshot( id=variant_study_id, created_at=datetime.datetime.utcnow(), - last_executed_command=cmd_blocks[-1].id if cmd_blocks else None, + last_executed_command=variant_study.commands[-1].id if variant_study.commands else None, ) logger.info(f"Reading additional data from files for study {file_study.config.study_id}") variant_study.additional_data = self._read_additional_data(file_study) self.repository.save(variant_study) - # Store the study config in the cache (with adjusted paths). - file_study.config.study_path = file_study.config.path = snapshot_dir - file_study.config.output_path = snapshot_dir / OUTPUT_RELATIVE_PATH self._update_cache(file_study) except Exception: - shutil.rmtree(self._tmp_dir, ignore_errors=True) + shutil.rmtree(snapshot_dir, ignore_errors=True) raise else: - # Rename the temporary directory to the final snapshot directory - shutil.rmtree(snapshot_dir, ignore_errors=True) - self._tmp_dir.rename(snapshot_dir) try: notifier(results.json()) except Exception as exc: @@ -149,12 +137,12 @@ def _retrieve_descendants(self, variant_study_id: str) -> t.Tuple[RawStudy, t.Se root_study = self.repository.one(descendant_ids[0]) return root_study, descendants - def _export_ref_study(self, ref_study: t.Union[RawStudy, VariantStudy]) -> None: - self._tmp_dir.rmdir() # remove the temporary directory for shutil.copytree + def _export_ref_study(self, snapshot_dir: Path, ref_study: t.Union[RawStudy, VariantStudy]) -> None: if isinstance(ref_study, VariantStudy): + snapshot_dir.parent.mkdir(parents=True, exist_ok=True) export_study_flat( ref_study.snapshot_dir, - self._tmp_dir, + snapshot_dir, self.study_factory, denormalize=False, # de-normalization is done at the end outputs=False, # do NOT export outputs @@ -162,7 +150,7 @@ def _export_ref_study(self, ref_study: t.Union[RawStudy, VariantStudy]) -> None: elif isinstance(ref_study, RawStudy): self.raw_study_service.export_study_flat( ref_study, - self._tmp_dir, + snapshot_dir, denormalize=False, # de-normalization is done at the end outputs=False, # do NOT export outputs ) @@ -171,15 +159,15 @@ def _export_ref_study(self, ref_study: t.Union[RawStudy, VariantStudy]) -> None: def _apply_commands( self, + snapshot_dir: Path, variant_study: VariantStudy, - ref_study: t.Union[RawStudy, VariantStudy], cmd_blocks: t.Sequence[CommandBlock], ) -> GenerationResultInfoDTO: commands = [self.command_factory.to_command(cb.to_dto()) for cb in cmd_blocks] generator = VariantCommandGenerator(self.study_factory) results = generator.generate( commands, - self._tmp_dir, + snapshot_dir, variant_study, delete_on_failure=False, # Not needed, because we are using a temporary directory notifier=None, @@ -208,12 +196,22 @@ def _update_cache(self, file_study: FileStudy) -> None: ) +class RefStudySearchResult(t.NamedTuple): + """ + Result of the search for the reference study. + """ + + ref_study: t.Union[RawStudy, VariantStudy] + cmd_blocks: t.Sequence[CommandBlock] + force_regenerate: bool = False + + def search_ref_study( root_study: t.Union[RawStudy, VariantStudy], descendants: t.Sequence[VariantStudy], *, from_scratch: bool = False, -) -> t.Tuple[t.Union[RawStudy, VariantStudy], t.Sequence[CommandBlock]]: +) -> RefStudySearchResult: """ Search for the reference study and the commands to use for snapshot generation. @@ -225,6 +223,9 @@ def search_ref_study( Returns: The reference study and the commands to use for snapshot generation. """ + if not descendants: + # Edge case where the list of studies is empty. + return RefStudySearchResult(ref_study=root_study, cmd_blocks=[], force_regenerate=True) # The reference study is the root study or a variant study with a valid snapshot ref_study: t.Union[RawStudy, VariantStudy] @@ -236,42 +237,68 @@ def search_ref_study( # In the case of a from scratch generation, the root study will be used as the reference study. # We need to retrieve all commands from the descendants of variants in order to apply them # on the reference study. - ref_study = root_study - cmd_blocks = [c for v in descendants for c in v.commands] + return RefStudySearchResult( + ref_study=root_study, + cmd_blocks=[c for v in descendants for c in v.commands], + force_regenerate=True, + ) - else: - # To generate the last variant of a descendant of variants, we must search for - # the most recent snapshot in order to use it as a reference study. - # If no snapshot is found, we use the root study as a reference study. - - snapshot_vars = [v for v in descendants if v.is_snapshot_recent()] - - if snapshot_vars: - # We use the most recent snapshot as a reference study - ref_study = max(snapshot_vars, key=lambda v: v.snapshot.created_at) - - # This variant's snapshot corresponds to the commands actually generated - # at the time of the snapshot. However, we need to retrieve the remaining commands, - # because the snapshot generation may be incomplete. - last_exec_cmd = ref_study.snapshot.last_executed_command # ID of the command - if not last_exec_cmd: - # It is unlikely that this case will occur, but it means that - # the snapshot is not correctly generated (corrupted database). - # It better to use all commands to force snapshot re-generation. - cmd_blocks = ref_study.commands[:] - else: - command_ids = [c.id for c in ref_study.commands] - last_exec_index = command_ids.index(last_exec_cmd) - cmd_blocks = ref_study.commands[last_exec_index + 1 :] - - # We need to add all commands from the descendants of variants - # starting at the first descendant of reference study. - index = descendants.index(ref_study) - cmd_blocks.extend([c for v in descendants[index + 1 :] for c in v.commands]) + # To reuse the snapshot of the current variant, the last executed command + # must be one of the commands of the current variant. + curr_variant = descendants[-1] + if curr_variant.snapshot: + last_exec_cmd = curr_variant.snapshot.last_executed_command + command_ids = [c.id for c in curr_variant.commands] + # If the variant has no command, we can reuse the snapshot if it is recent + if not last_exec_cmd and not command_ids and curr_variant.is_snapshot_up_to_date(): + return RefStudySearchResult( + ref_study=curr_variant, + cmd_blocks=[], + force_regenerate=False, + ) + elif last_exec_cmd and last_exec_cmd in command_ids: + # We can reuse the snapshot of the current variant + last_exec_index = command_ids.index(last_exec_cmd) + return RefStudySearchResult( + ref_study=curr_variant, + cmd_blocks=curr_variant.commands[last_exec_index + 1 :], + force_regenerate=False, + ) + # We cannot reuse the snapshot of the current variant + # To generate the last variant of a descendant of variants, we must search for + # the most recent snapshot in order to use it as a reference study. + # If no snapshot is found, we use the root study as a reference study. + + snapshot_vars = [v for v in descendants if v.is_snapshot_up_to_date()] + + if snapshot_vars: + # We use the most recent snapshot as a reference study + ref_study = max(snapshot_vars, key=lambda v: v.snapshot.created_at) + + # This variant's snapshot corresponds to the commands actually generated + # at the time of the snapshot. However, we need to retrieve the remaining commands, + # because the snapshot generation may be incomplete. + last_exec_cmd = ref_study.snapshot.last_executed_command # ID of the command + command_ids = [c.id for c in ref_study.commands] + if not last_exec_cmd or last_exec_cmd not in command_ids: + # The last executed command may be missing (probably caused by a bug) + # or may reference a removed command. + # This requires regenerating the snapshot from scratch, + # with all commands from the reference study. + cmd_blocks = ref_study.commands[:] else: - # We use the root study as a reference study - ref_study = root_study - cmd_blocks = [c for v in descendants for c in v.commands] + last_exec_index = command_ids.index(last_exec_cmd) + cmd_blocks = ref_study.commands[last_exec_index + 1 :] + + # We need to add all commands from the descendants of variants + # starting at the first descendant of reference study. + index = descendants.index(ref_study) + cmd_blocks.extend([c for v in descendants[index + 1 :] for c in v.commands]) + + else: + # We use the root study as a reference study + ref_study = root_study + cmd_blocks = [c for v in descendants for c in v.commands] - return ref_study, cmd_blocks + return RefStudySearchResult(ref_study=ref_study, cmd_blocks=cmd_blocks, force_regenerate=True) diff --git a/antarest/study/storage/variantstudy/variant_study_service.py b/antarest/study/storage/variantstudy/variant_study_service.py index de6fa1651c..f9d3eea0aa 100644 --- a/antarest/study/storage/variantstudy/variant_study_service.py +++ b/antarest/study/storage/variantstudy/variant_study_service.py @@ -654,7 +654,7 @@ def generate( if variant_study.parent_id is None: raise NoParentStudyError(variant_study_id) - return self.generate_task(variant_study, denormalize) + return self.generate_task(variant_study, denormalize, from_scratch=from_scratch) def generate_study_config( self, diff --git a/tests/study/storage/variantstudy/model/test_dbmodel.py b/tests/study/storage/variantstudy/model/test_dbmodel.py index 0bcd107518..0715dec535 100644 --- a/tests/study/storage/variantstudy/model/test_dbmodel.py +++ b/tests/study/storage/variantstudy/model/test_dbmodel.py @@ -215,7 +215,7 @@ def test_init__without_snapshot(self, db_session: Session, raw_study_id: str, us # check Variant-specific properties assert obj.snapshot_dir == Path(variant_study_path).joinpath("snapshot") - assert obj.is_snapshot_recent() is False + assert obj.is_snapshot_up_to_date() is False @pytest.mark.parametrize( "created_at, updated_at, study_antares_file, expected", @@ -294,4 +294,4 @@ def test_is_snapshot_recent( # Check the snapshot_uptodate() method obj: VariantStudy = db_session.query(VariantStudy).filter(VariantStudy.id == variant_id).one() - assert obj.is_snapshot_recent() == expected + assert obj.is_snapshot_up_to_date() == expected diff --git a/tests/study/storage/variantstudy/test_snapshot_generator.py b/tests/study/storage/variantstudy/test_snapshot_generator.py index 2247b21045..5e90b6ee06 100644 --- a/tests/study/storage/variantstudy/test_snapshot_generator.py +++ b/tests/study/storage/variantstudy/test_snapshot_generator.py @@ -25,6 +25,7 @@ from antarest.study.storage.variantstudy.model.model import CommandDTO, GenerationResultInfoDTO from antarest.study.storage.variantstudy.snapshot_generator import SnapshotGenerator, search_ref_study from antarest.study.storage.variantstudy.variant_study_service import VariantStudyService +from tests.db_statement_recorder import DBStatementRecorder from tests.helpers import with_db_context @@ -85,12 +86,12 @@ class TestSearchRefStudy: and corresponding to a variant with an up-to-date snapshot. - The case where the list of studies contains two variants with up-to-date snapshots and - where the first is older than the second. + where the first is older than the second, and a third variant without snapshot. We expect to have a reference study corresponding to the second variant and a list of commands for the second variant. - The case where the list of studies contains two variants with up-to-date snapshots and - where the first is more recent than the second. + where the first is more recent than the second, and a third variant without snapshot. We expect to have a reference study corresponding to the first variant and a list of commands for both variants in order. @@ -118,9 +119,10 @@ def test_search_ref_study__empty_descendants(self) -> None: """ root_study = Study(id=str(uuid.uuid4()), name="root") references: t.Sequence[VariantStudy] = [] - ref_study, cmd_blocks = search_ref_study(root_study, references) - assert ref_study == root_study - assert cmd_blocks == [] + search_result = search_ref_study(root_study, references) + assert search_result.ref_study == root_study + assert search_result.cmd_blocks == [] + assert search_result.force_regenerate is True def test_search_ref_study__from_scratch(self, tmp_path: Path) -> None: """ @@ -195,9 +197,10 @@ def test_search_ref_study__from_scratch(self, tmp_path: Path) -> None: # Check the variants references = [variant1, variant2, variant3] - ref_study, cmd_blocks = search_ref_study(root_study, references, from_scratch=True) - assert ref_study == root_study - assert cmd_blocks == [c for v in [variant1, variant2, variant3] for c in v.commands] + search_result = search_ref_study(root_study, references, from_scratch=True) + assert search_result.ref_study == root_study + assert search_result.cmd_blocks == [c for v in [variant1, variant2, variant3] for c in v.commands] + assert search_result.force_regenerate is True def test_search_ref_study__obsolete_snapshots(self, tmp_path: Path) -> None: """ @@ -205,12 +208,9 @@ def test_search_ref_study__obsolete_snapshots(self, tmp_path: Path) -> None: - either there is no snapshot, - or the snapshot's creation date is earlier than the variant's last modification date. Note: The situation where the "snapshot/study.antares" file does not exist is not considered. - We expect to have the root study and a list of `CommandBlock` for all variants. - - Given a list of descendants with some variants with obsolete snapshots, - When calling search_ref_study with the flag from_scratch=False, - Then the root study is returned as reference study, - and all commands of all variants are returned. + The third variant has no snapshot, and must be generated from scratch. + We expect to have a reference study corresponding to the root study + and the list of commands of all variants in order. """ root_study = Study(id=str(uuid.uuid4()), name="root") @@ -231,6 +231,14 @@ def test_search_ref_study__obsolete_snapshots(self, tmp_path: Path) -> None: datetime.datetime(year=2023, month=1, day=2), snapshot_created_at=datetime.datetime(year=2023, month=1, day=1), ) + # Variant 3 has no snapshot. + variant3 = _create_variant( + tmp_path, + "variant3", + variant2.id, + datetime.datetime(year=2023, month=1, day=3), + snapshot_created_at=None, + ) # Add some variant commands variant1.commands = [ @@ -252,25 +260,41 @@ def test_search_ref_study__obsolete_snapshots(self, tmp_path: Path) -> None: version=1, args='{"area_name": "DE", "cluster_name": "DE", "cluster_type": "thermal"}', ), + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant2.id, + index=1, + command="create_thermal_cluster", + version=1, + args='{"area_name": "IT", "cluster_name": "IT", "cluster_type": "gas"}', + ), ] variant2.snapshot.last_executed_command = variant2.commands[0].id + variant3.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant2.id, + index=0, + command="create_thermal_cluster", + version=1, + args='{"area_name": "BE", "cluster_name": "BE", "cluster_type": "oil"}', + ), + ] # Check the variants - references = [variant1, variant2] - ref_study, cmd_blocks = search_ref_study(root_study, references) - assert ref_study == root_study - assert cmd_blocks == [c for v in [variant1, variant2] for c in v.commands] + references = [variant1, variant2, variant3] + search_result = search_ref_study(root_study, references) + assert search_result.ref_study == root_study + assert search_result.cmd_blocks == [c for v in [variant1, variant2, variant3] for c in v.commands] + assert search_result.force_regenerate is True def test_search_ref_study__old_recent_snapshot(self, tmp_path: Path) -> None: """ Case where the list of studies contains a variant with up-to-date snapshots and where the first is older than the second. + The third variant has no snapshot, and must be generated from scratch. We expect to have a reference study corresponding to the second variant - and an empty list of commands, because the snapshot is already completely up-to-date. - - Given a list of descendants with some variants with up-to-date snapshots, - When calling search_ref_study with the flag from_scratch=False, - Then the second variant is returned as reference study, and no commands are returned. + and the list of commands of the third variant. """ root_study = Study(id=str(uuid.uuid4()), name="root") @@ -291,6 +315,14 @@ def test_search_ref_study__old_recent_snapshot(self, tmp_path: Path) -> None: datetime.datetime(year=2023, month=1, day=2), snapshot_created_at=datetime.datetime(year=2023, month=1, day=3), ) + # Variant 3 has no snapshot. + variant3 = _create_variant( + tmp_path, + "variant3", + variant2.id, + datetime.datetime(year=2023, month=1, day=3), + snapshot_created_at=None, + ) # Add some variant commands variant1.commands = [ @@ -315,24 +347,31 @@ def test_search_ref_study__old_recent_snapshot(self, tmp_path: Path) -> None: ), ] variant2.snapshot.last_executed_command = variant2.commands[0].id + variant3.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant2.id, + index=0, + command="create_thermal_cluster", + version=1, + args='{"area_name": "BE", "cluster_name": "BE", "cluster_type": "oil"}', + ), + ] # Check the variants - references = [variant1, variant2] - ref_study, cmd_blocks = search_ref_study(root_study, references) - assert ref_study == variant2 - assert cmd_blocks == [] + references = [variant1, variant2, variant3] + search_result = search_ref_study(root_study, references) + assert search_result.ref_study == variant2 + assert search_result.cmd_blocks == variant3.commands + assert search_result.force_regenerate is True def test_search_ref_study__recent_old_snapshot(self, tmp_path: Path) -> None: """ Case where the list of studies contains a variant with up-to-date snapshots and where the second is older than the first. + The third variant has no snapshot, and must be generated from scratch. We expect to have a reference study corresponding to the first variant - and the list of commands of the second variant, because the first is completely up-to-date. - - Given a list of descendants with some variants with up-to-date snapshots, - When calling search_ref_study with the flag from_scratch=False, - Then the first variant is returned as reference study, - and the commands of the second variant are returned. + and the list of commands of the second and third variants. """ root_study = Study(id=str(uuid.uuid4()), name="root") @@ -353,6 +392,14 @@ def test_search_ref_study__recent_old_snapshot(self, tmp_path: Path) -> None: datetime.datetime(year=2023, month=1, day=2), snapshot_created_at=datetime.datetime(year=2023, month=1, day=2), ) + # Variant 3 has no snapshot. + variant3 = _create_variant( + tmp_path, + "variant3", + variant2.id, + datetime.datetime(year=2023, month=1, day=3), + snapshot_created_at=None, + ) # Add some variant commands variant1.commands = [ @@ -377,12 +424,23 @@ def test_search_ref_study__recent_old_snapshot(self, tmp_path: Path) -> None: ), ] variant2.snapshot.last_executed_command = variant2.commands[0].id + variant3.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant2.id, + index=0, + command="create_thermal_cluster", + version=1, + args='{"area_name": "BE", "cluster_name": "BE", "cluster_type": "oil"}', + ), + ] # Check the variants - references = [variant1, variant2] - ref_study, cmd_blocks = search_ref_study(root_study, references) - assert ref_study == variant1 - assert cmd_blocks == variant2.commands + references = [variant1, variant2, variant3] + search_result = search_ref_study(root_study, references) + assert search_result.ref_study == variant1 + assert search_result.cmd_blocks == [c for v in [variant2, variant3] for c in v.commands] + assert search_result.force_regenerate is True def test_search_ref_study__one_variant_completely_uptodate(self, tmp_path: Path) -> None: """ @@ -438,9 +496,10 @@ def test_search_ref_study__one_variant_completely_uptodate(self, tmp_path: Path) # Check the variants references = [variant1] - ref_study, cmd_blocks = search_ref_study(root_study, references) - assert ref_study == variant1 - assert cmd_blocks == [] + search_result = search_ref_study(root_study, references) + assert search_result.ref_study == variant1 + assert search_result.cmd_blocks == [] + assert search_result.force_regenerate is False def test_search_ref_study__one_variant_partially_uptodate(self, tmp_path: Path) -> None: """ @@ -496,9 +555,10 @@ def test_search_ref_study__one_variant_partially_uptodate(self, tmp_path: Path) # Check the variants references = [variant1] - ref_study, cmd_blocks = search_ref_study(root_study, references) - assert ref_study == variant1 - assert cmd_blocks == variant1.commands[1:] + search_result = search_ref_study(root_study, references) + assert search_result.ref_study == variant1 + assert search_result.cmd_blocks == variant1.commands[1:] + assert search_result.force_regenerate is False def test_search_ref_study__missing_last_command(self, tmp_path: Path) -> None: """ @@ -550,9 +610,65 @@ def test_search_ref_study__missing_last_command(self, tmp_path: Path) -> None: # Check the variants references = [variant1] - ref_study, cmd_blocks = search_ref_study(root_study, references) - assert ref_study == variant1 - assert cmd_blocks == variant1.commands + search_result = search_ref_study(root_study, references) + assert search_result.ref_study == variant1 + assert search_result.cmd_blocks == variant1.commands + assert search_result.force_regenerate is True + + def test_search_ref_study__deleted_last_command(self, tmp_path: Path) -> None: + """ + Case where the list of studies contains a variant with an up-to-date snapshot, + but the last executed command is missing (removed). + We expect to have the list of all variant commands, so that the snapshot can be re-generated. + """ + root_study = Study(id=str(uuid.uuid4()), name="root") + + # Prepare some variants with snapshots: + variant1 = _create_variant( + tmp_path, + "variant1", + root_study.id, + datetime.datetime(year=2023, month=1, day=1), + snapshot_created_at=datetime.datetime(year=2023, month=1, day=2), + ) + + # Add some variant commands + variant1.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=0, + command="create_area", + version=1, + args='{"area_name": "DE"}', + ), + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=1, + command="create_thermal_cluster", + version=1, + args='{"area_name": "DE", "cluster_name": "DE", "cluster_type": "thermal"}', + ), + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=2, + command="update_thermal_cluster", + version=1, + args='{"area_name": "DE", "cluster_name": "DE", "capacity": 1500}', + ), + ] + + # The last executed command is missing. + variant1.snapshot.last_executed_command = str(uuid.uuid4()) + + # Check the variants + references = [variant1] + search_result = search_ref_study(root_study, references) + assert search_result.ref_study == variant1 + assert search_result.cmd_blocks == variant1.commands + assert search_result.force_regenerate is True class RegisterNotification: @@ -715,15 +831,9 @@ def test_generate__nominal_case( repository=variant_study_service.repository, ) - sql_statements = [] notifier = RegisterNotification() - @event.listens_for(db.session.bind, "before_cursor_execute") # type: ignore - def before_cursor_execute(conn, cursor, statement: str, parameters, context, executemany) -> None: - # note: add a breakpoint here to debug the SQL statements. - sql_statements.append(statement) - - try: + with DBStatementRecorder(db.session.bind) as db_recorder: results = generator.generate_snapshot( variant_study.id, jwt_user, @@ -731,8 +841,6 @@ def before_cursor_execute(conn, cursor, statement: str, parameters, context, exe from_scratch=False, notifier=notifier, ) - finally: - event.remove(db.session.bind, "before_cursor_execute", before_cursor_execute) # Check: the number of database queries is kept as low as possible. # We expect 5 queries: @@ -741,7 +849,7 @@ def before_cursor_execute(conn, cursor, statement: str, parameters, context, exe # - 1 query to fetch the list of variants with snapshot, commands, etc., # - 1 query to update the variant study additional_data, # - 1 query to insert the variant study snapshot. - assert len(sql_statements) == 5, "\n-------\n".join(sql_statements) + assert len(db_recorder.sql_statements) == 5, str(db_recorder) # Check: the variant generation must succeed. assert results == GenerationResultInfoDTO( @@ -826,11 +934,6 @@ def test_generate__with_user_dir( Test the generation of a variant study containing a user directory. We expect that the user directory is correctly preserved. """ - # Add a user directory to the variant study. - user_dir = Path(variant_study.snapshot_dir) / "user" - user_dir.mkdir(parents=True, exist_ok=True) - user_dir.joinpath("user_file.txt").touch() - generator = SnapshotGenerator( cache=variant_study_service.cache, raw_study_service=variant_study_service.raw_study_service, @@ -840,22 +943,25 @@ def test_generate__with_user_dir( repository=variant_study_service.repository, ) - results = generator.generate_snapshot( + # Generate the snapshot once + generator.generate_snapshot( variant_study.id, jwt_user, denormalize=False, from_scratch=False, ) - # Check the results - assert results == GenerationResultInfoDTO( - success=True, - details=[ - ("create_area", True, "Area 'North' created"), - ("create_area", True, "Area 'South' created"), - ("create_link", True, "Link between 'north' and 'south' created"), - ("create_cluster", True, "Thermal cluster 'gas_cluster' added to area 'south'."), - ], + # Add a user directory to the variant study. + user_dir = Path(variant_study.snapshot_dir) / "user" + user_dir.mkdir(parents=True, exist_ok=True) + user_dir.joinpath("user_file.txt").touch() + + # Generate the snapshot again + generator.generate_snapshot( + variant_study.id, + jwt_user, + denormalize=False, + from_scratch=False, ) # Check that the user directory is correctly preserved. From 74c872e6380ce2e0a47f690476964a32957bd2e7 Mon Sep 17 00:00:00 2001 From: MartinBelthle <102529366+MartinBelthle@users.noreply.github.com> Date: Wed, 13 Dec 2023 17:31:32 +0100 Subject: [PATCH 42/43] fix(upgrade): correction of study upgrade when upgrading from v8.2 to v8.6 (creation of MinGen) (#1861) --- .../study/storage/study_upgrader/__init__.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/antarest/study/storage/study_upgrader/__init__.py b/antarest/study/storage/study_upgrader/__init__.py index 6b96dc711b..1993b4a0c3 100644 --- a/antarest/study/storage/study_upgrader/__init__.py +++ b/antarest/study/storage/study_upgrader/__init__.py @@ -201,9 +201,10 @@ def _copies_only_necessary_files(files_to_upgrade: List[Path], study_path: Path, The list of files and folders that were really copied. It's the same as files_to_upgrade but without any children that has parents already in the list. """ - files_to_upgrade.append(Path("study.antares")) + files_to_copy = _filters_out_children_files(files_to_upgrade) + files_to_copy.append(Path("study.antares")) files_to_retrieve = [] - for path in files_to_upgrade: + for path in files_to_copy: entire_path = study_path / path if entire_path.is_dir(): if not (tmp_path / path).exists(): @@ -220,6 +221,22 @@ def _copies_only_necessary_files(files_to_upgrade: List[Path], study_path: Path, return files_to_retrieve +def _filters_out_children_files(files_to_upgrade: List[Path]) -> List[Path]: + """ + Filters out children paths of "input" if "input" is already in the list. + Args: + files_to_upgrade: List[Path]: List of the files and folders concerned by the upgrade. + Returns: + The list of files filtered + """ + is_input_in_files_to_upgrade = Path("input") in files_to_upgrade + if is_input_in_files_to_upgrade: + files_to_keep = [Path("input")] + files_to_keep.extend(path for path in files_to_upgrade if "input" not in path.parts) + return files_to_keep + return files_to_upgrade + + def _replace_safely_original_files(files_to_replace: List[Path], study_path: Path, tmp_path: Path) -> None: """ Replace files/folders of the study that should be upgraded by their copy already upgraded in the tmp directory. From bf040bac3633dadacca4f56bd7387ccf2d71f8aa Mon Sep 17 00:00:00 2001 From: Hatim Dinia Date: Wed, 13 Dec 2023 13:46:47 +0100 Subject: [PATCH 43/43] build: new bug fix release v2.16.1 (2023-12-14) --- docs/CHANGELOG.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 7a4fc60527..1e587bf6e3 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -6,22 +6,29 @@ v2.16.1 (2023-12-14) ### Features -* **db-init:** separate database initialization from global database session [`#1837`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1837) * **ui:** add manual submit on clusters form [`#1852`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1852) * **ui-modelling:** add dynamic area selection on Areas tab click [`#1835`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1835) * **ui-storages:** use percentage values instead of ratio values [`#1846`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1846) +* **upgrade:** correction of study upgrade when upgrading from v8.2 to v8.6 (creation of MinGen) [`#1861`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1861) ### Bug Fixes * **bc:** correct the name and shape of the binding constraint matrices [`#1849`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1849) * **bc:** avoid duplicates in Binding Constraints creation through REST API [`#1858`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1858) +* **ui:** update current area after window reload [`#1862`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1862) * **ui-study:** fix the study card explore button visibility [`#1842`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1842) * **ui-matrix:** prevent matrices float values to be converted [`#1850`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1850) * **ui-matrix:** calculate the prepend index according to the existence of a time column [`#1856`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1856) * **ui-output:** add the missing "ST Storages" option in the Display selector in results view [`#1855`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1855) +### Performance + +* **db-init:** separate database initialization from global database session [`#1837`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1837) +* **variant:** improve performances and correct snapshot generation [`#1854`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1854) + + ## Documentation * **config:** enhance application configuration documentation [`#1710`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1710)