|
| 1 | +# Copyright 2025 Camptocamp SA |
| 2 | +# License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl) |
| 3 | + |
| 4 | +import json |
| 5 | +import logging |
| 6 | + |
| 7 | +from psycopg2 import sql |
| 8 | +from psycopg2.extras import execute_values |
| 9 | + |
| 10 | +_logger = logging.getLogger(__name__) |
| 11 | + |
| 12 | + |
| 13 | +def _get_impacted_table_names(cr): |
| 14 | + query = """ |
| 15 | + SELECT TABLE_NAME |
| 16 | + FROM INFORMATION_SCHEMA.COLUMNS |
| 17 | + WHERE COLUMN_NAME = 'server_env_defaults'; |
| 18 | + """ |
| 19 | + cr.execute(query) |
| 20 | + return [row[0] for row in cr.fetchall()] |
| 21 | + |
| 22 | + |
| 23 | +def _get_values_to_fix(cr, table): |
| 24 | + query = sql.SQL( |
| 25 | + """ |
| 26 | + SELECT id, server_env_defaults |
| 27 | + FROM {table} |
| 28 | + WHERE server_env_defaults |
| 29 | + IS NOT NULL; |
| 30 | + """ |
| 31 | + ) |
| 32 | + formatted_query = query.format(table=sql.Identifier(table)) |
| 33 | + cr.execute(formatted_query) |
| 34 | + return cr.fetchall() |
| 35 | + |
| 36 | + |
| 37 | +def _get_fixed_values(cr, values): |
| 38 | + new_values = [] |
| 39 | + for _id, defaults in values: |
| 40 | + # defaults are string dicts |
| 41 | + defaults = json.loads(defaults) |
| 42 | + new_defaults = {} |
| 43 | + for key, value in defaults.items(): |
| 44 | + # Only fix keys that weren't already fixed |
| 45 | + # Makes this idempotent. |
| 46 | + if key.endswith("_env_default") and not key.startswith("x_"): |
| 47 | + new_defaults[f"x_{key}"] = value |
| 48 | + else: |
| 49 | + new_defaults[key] = value |
| 50 | + # dump dict in a string |
| 51 | + new_values.append((_id, json.dumps(new_defaults))) |
| 52 | + return new_values |
| 53 | + |
| 54 | + |
| 55 | +def _apply_new_values(cr, table, values): |
| 56 | + # Doing the formatting in 2 steps, this is on purpose. |
| 57 | + query = f""" |
| 58 | + UPDATE {table} |
| 59 | + SET server_env_defaults = c.server_env_defaults |
| 60 | + FROM (VALUES %s) |
| 61 | + AS c(id, server_env_defaults) |
| 62 | + WHERE {table}.id = c.id |
| 63 | + """ |
| 64 | + execute_values(cr, query, values) |
| 65 | + |
| 66 | + |
| 67 | +def fix_server_env_defaults(cr): |
| 68 | + for table_name in _get_impacted_table_names(cr): |
| 69 | + _logger.info(f"Fixing server_env_defaults on '{table_name}'") |
| 70 | + old_values = _get_values_to_fix(cr, table_name) |
| 71 | + new_values = _get_fixed_values(cr, old_values) |
| 72 | + _apply_new_values(cr, table_name, new_values) |
| 73 | + |
| 74 | + |
| 75 | +def migrate(cr, version): |
| 76 | + fix_server_env_defaults(cr) |
0 commit comments