Skip to content

Commit

Permalink
Merge pull request #2 from peterrrock2/dev
Browse files Browse the repository at this point in the history
Add validation to ogr2ogr queries to prevent command line overflow
  • Loading branch information
peterrrock2 authored Sep 26, 2024
2 parents 239f05a + 40d78ae commit a977a8b
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 23 deletions.
2 changes: 1 addition & 1 deletion gerrydb_meta/api/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from gerrydb_meta.render import view_to_gpkg
from gerrydb_meta.scopes import ScopeManager

log = logging.getLogger()
log = logging.getLogger("uvicorn")

router = APIRouter()
CHUNK_SIZE = 32 * 1024 * 1024 # for gzipping rendered views
Expand Down
27 changes: 27 additions & 0 deletions gerrydb_meta/api/view_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
from gerrydb_meta.api.deps import get_db, get_obj_meta, get_scopes
from gerrydb_meta.scopes import ScopeManager

import logging


log = logging.getLogger("uvicorn")


MAX_VIEW_TEMPLATE_COLUMNS = 200


class ViewTemplateApi(NamespacedObjectApi):
def _check_public(self, scopes: ScopeManager) -> None:
Expand Down Expand Up @@ -67,6 +75,25 @@ def create_route(
resolved_objs = from_resource_paths(
paths=obj_in.members, db=db, scopes=scopes, follow_refs=False
)
total_objs = 0
for item in resolved_objs:
if isinstance(item, models.ColumnSet):
total_objs += len(item.columns)
else:
total_objs += 1

if total_objs > MAX_VIEW_TEMPLATE_COLUMNS:
log.error(
f"Cannot create view template with more than {MAX_VIEW_TEMPLATE_COLUMNS} columns."
)
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=(
f"Maximum view template column count exceeded. Found {total_objs} columns. "
f"Maximum is {MAX_VIEW_TEMPLATE_COLUMNS}."
),
)

template_obj, etag = self.crud.create(
db=db,
obj_in=obj_in,
Expand Down
4 changes: 4 additions & 0 deletions gerrydb_meta/crud/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
PatchSchemaType = TypeVar("PatchSchemaType", bound=BaseModel)

# Substrings that are rejected wherever they appear in a path. They are
# most likely to show up in the resource_id part of a path (typically the
# last segment). Excluding them prevents ogr2ogr failures and helps protect
# against malicious code injection.
INVALID_PATH_SUBSTRINGS = {
    "..",
    " ",
    ";",
}

Expand Down
26 changes: 26 additions & 0 deletions gerrydb_meta/crud/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class ViewRenderContext:
graph_edges: Sequence | None
geo_meta: dict[int, models.ObjectMeta]
geo_meta_ids: dict[str, int] # by path
geo_valid_from_dates: dict[str, datetime]

# Bulk queries for `ogr2ogr`.
geo_query: str
Expand Down Expand Up @@ -440,6 +441,7 @@ def render(self, db: Session, *, view: models.View) -> ViewRenderContext:

plans, plan_labels, plan_assignments = self._plans(db, view)
geo_meta_ids, geo_meta = self._geo_meta(db, view)
geo_valid_from_dates = self._geo_valid_dates(db, view)

return ViewRenderContext(
view=view,
Expand All @@ -451,6 +453,7 @@ def render(self, db: Session, *, view: models.View) -> ViewRenderContext:
graph_edges=self._graph_edges(db, view),
geo_meta=geo_meta,
geo_meta_ids=geo_meta_ids,
geo_valid_from_dates=geo_valid_from_dates,
# Query generation: substitute in literals and remove the
# ST_AsBinary() calls added by GeoAlchemy2.
geo_query=re.sub(
Expand Down Expand Up @@ -506,6 +509,29 @@ def _geo_meta(

return geo_meta_ids, distinct_meta

def _geo_valid_dates(self, db: Session, view: models.View) -> dict[str, datetime]:
    """Look up a `valid_from` value for every geography in the view's geo set.

    Args:
        db: Session used to execute the lookup query.
        view: View whose `set_version_id` selects the geo set members.

    Returns:
        A dictionary keyed by geography path whose values are the
        `valid_from` column of a joined `GeoVersion` row.
    """
    geography = models.Geography
    member = models.GeoSetMember
    version = models.GeoVersion

    # Geography -> GeoSetMember narrows to the view's set version;
    # Geography -> GeoVersion supplies the `valid_from` column.
    stmt = select(geography.path, version.valid_from)
    stmt = stmt.join(member, geography.geo_id == member.geo_id)
    stmt = stmt.join(version, geography.geo_id == version.geo_id)
    stmt = stmt.where(member.set_version_id == view.set_version_id)

    # NOTE(review): if a geography matches several GeoVersion rows, the last
    # row returned wins for that path -- confirm whether that can happen.
    valid_from_by_path = {}
    for row in db.execute(stmt):
        valid_from_by_path[row.path] = row.valid_from
    return valid_from_by_path

def _plans(
self, db: Session, view: models.View
) -> tuple[list[models.Plan], list[str], Sequence | None]:
Expand Down
115 changes: 93 additions & 22 deletions gerrydb_meta/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tempfile
import uuid
from pathlib import Path
import os, sys, shlex

import orjson as json

Expand Down Expand Up @@ -205,6 +206,51 @@ def _init_gpkg_plans_extension(
conn.commit()


def __get_arg_max() -> int:
    """Retrieve the system's ARG_MAX value.

    Queries ``os.sysconf("SC_ARG_MAX")`` when the platform exposes it and the
    reported value is positive. Otherwise a common default is returned on
    non-Windows platforms.

    Returns:
        The maximum length, in bytes, of the arguments to the exec functions.

    Raises:
        ValueError, OSError: If ``os.sysconf`` fails while reading SC_ARG_MAX.
        RuntimeError: If SC_ARG_MAX is unavailable and the platform is Windows.
    """
    if hasattr(os, "sysconf") and "SC_ARG_MAX" in os.sysconf_names:
        try:
            arg_max = os.sysconf("SC_ARG_MAX")
        except (ValueError, OSError) as e:
            # Diagnostics belong on stderr, not stdout; the bare `raise`
            # re-raises the active exception with its original traceback.
            print(
                f"Warning: Unable to retrieve ARG_MAX using os.sysconf: {e}",
                file=sys.stderr,
            )
            raise
        # A non-positive value means the limit is indeterminate, so fall
        # through to the default below.
        if arg_max > 0:
            return arg_max

    if sys.platform.startswith("win"):
        raise RuntimeError("This function cannot be run in a Windows environment.")

    # Fallback Unix-like systems where SC_ARG_MAX is not available.
    # Uses common default value (Linux typically has 2,097,152 bytes).
    return 2097152


def __validate_query(query: str) -> None:
    """Ensure a command line does not exceed the system's exec size limit.

    The limit is the system's ARG_MAX value, which bounds the combined size
    of the argument strings and the environment passed to a new process.

    Args:
        query: The full command line to be validated.

    Raises:
        RuntimeError: If the command line, together with the current
            environment, is too long.
    """
    max_total_bytes = __get_arg_max()
    command_bytes = len(query.encode("utf-8"))

    # ARG_MAX covers the environment as well as the arguments. Each variable
    # is stored as "KEY=VALUE\0", hence the two extra bytes per entry.
    environ_bytes = sum(
        len(os.fsencode(key)) + len(os.fsencode(value)) + 2
        for key, value in os.environ.items()
    )

    # NOTE(review): Linux also caps each individual argument string
    # (MAX_ARG_STRLEN), which this whole-command check cannot detect --
    # confirm whether a per-argument check is needed by the callers.
    if command_bytes + environ_bytes > max_total_bytes:
        raise RuntimeError("The length of the geoquery passed to ogr2ogr is too long. ")


def view_to_gpkg(context: ViewRenderContext, db_config: str) -> tuple[uuid.UUID, Path]:
"""Renders a view (with metadata) to a GeoPackage."""
render_uuid = uuid.uuid4()
Expand All @@ -229,16 +275,22 @@ def view_to_gpkg(context: ViewRenderContext, db_config: str) -> tuple[uuid.UUID,
*proj_args,
]

subprocess_command_list = [
"ogr2ogr",
*base_args,
"-sql",
context.geo_query,
"-nln",
geo_layer_name,
]

subprocess_command = shlex.join(subprocess_command_list)

__validate_query(subprocess_command)

try:
subprocess.run(
[
"ogr2ogr",
*base_args,
"-sql",
context.geo_query,
"-nln",
geo_layer_name,
],
subprocess_command_list,
check=True,
capture_output=True,
)
Expand All @@ -262,26 +314,33 @@ def view_to_gpkg(context: ViewRenderContext, db_config: str) -> tuple[uuid.UUID,
raise RenderError(
"Failed to render view: geographic layer not found in GeoPackage.",
) from ex

if geo_row_count != context.view.num_geos:
# Validate inner joins.
raise RenderError(
f"Failed to render view: expected {context.view.num_geos} geographies "
f"in layer, got {geo_row_count} geographies."
)

subprocess_command_list = [
"ogr2ogr",
*base_args,
"-update",
"-sql",
context.internal_point_query,
"-nln",
internal_point_layer_name,
"-nlt",
"POINT",
]

subprocess_command = shlex.join(subprocess_command_list)

__validate_query(subprocess_command)

try:
subprocess.run(
[
"ogr2ogr",
*base_args,
"-update",
"-sql",
context.internal_point_query,
"-nln",
internal_point_layer_name,
"-nlt",
"POINT",
],
subprocess_command_list,
check=True,
capture_output=True,
)
Expand Down Expand Up @@ -349,11 +408,23 @@ def view_to_gpkg(context: ViewRenderContext, db_config: str) -> tuple[uuid.UUID,
)
db_meta_id_to_gpkg_meta_id[db_id] = cur.lastrowid

geo_attrs_dict = {}

assert (
context.geo_meta_ids.keys() == context.geo_valid_from_dates.keys()
), "Geographic metadata IDs and valid dates must be aligned."

for path in context.geo_meta_ids.keys():
geo_attrs_dict[path] = (
context.geo_meta_ids[path],
context.geo_valid_from_dates[path],
)

conn.executemany(
"INSERT INTO gerrydb_geo_attrs (path, meta_id) VALUES (?, ?)",
"INSERT INTO gerrydb_geo_attrs (path, meta_id, valid_from) VALUES (?, ?, ?)",
(
(path, db_meta_id_to_gpkg_meta_id[db_id])
for path, db_id in context.geo_meta_ids.items()
(path, db_meta_id_to_gpkg_meta_id[db_id], valid_from)
for path, (db_id, valid_from) in geo_attrs_dict.items()
),
)

Expand Down

0 comments on commit a977a8b

Please sign in to comment.