Skip to content

Commit d36ca72

Browse files
committed
feat: pass db connection for global as param
Otherwise, API clients would need to set special environment variables just to use this library.
1 parent e86c5df commit d36ca72

File tree

4 files changed

+129
-151
lines changed

4 files changed

+129
-151
lines changed

soil_id/db.py

+97-121
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121

2222
# Third-party libraries
2323
import psycopg
24-
from shapely import wkb
25-
from shapely.geometry import Point
2624

2725
# local libraries
2826
import soil_id.config
27+
from shapely import wkb
28+
from shapely.geometry import Point
2929

3030

3131
def get_datastore_connection():
@@ -162,7 +162,7 @@ def save_soilgrids_output(plot_id, model_version, soilgrids_blob):
162162
conn.close()
163163

164164

165-
def get_hwsd2_profile_data(conn, hwsd2_mu_select):
165+
def get_hwsd2_profile_data(connection, hwsd2_mu_select):
166166
"""
167167
Retrieve HWSD v2 data based on selected hwsd2 (map unit) values.
168168
This version reuses an existing connection.
@@ -179,7 +179,7 @@ def get_hwsd2_profile_data(conn, hwsd2_mu_select):
179179
return pd.DataFrame()
180180

181181
try:
182-
with conn.cursor() as cur:
182+
with connection.cursor() as cur:
183183
# Create placeholders for the SQL IN clause
184184
placeholders = ", ".join(["%s"] * len(hwsd2_mu_select))
185185
sql_query = f"""
@@ -222,7 +222,7 @@ def get_hwsd2_profile_data(conn, hwsd2_mu_select):
222222
return pd.DataFrame()
223223

224224

225-
def extract_hwsd2_data(lon, lat, buffer_dist, table_name):
225+
def extract_hwsd2_data(connection, lon, lat, buffer_dist, table_name):
226226
"""
227227
Fetches HWSD soil data from a PostGIS table within a given buffer around a point,
228228
performing distance and intersection calculations directly on geographic coordinates.
@@ -236,133 +236,121 @@ def extract_hwsd2_data(lon, lat, buffer_dist, table_name):
236236
Returns:
237237
DataFrame: Merged data from hwsdv2 and hwsdv2_data.
238238
"""
239-
# Use a single connection for both queries.
240-
with get_datastore_connection() as conn:
241-
# Compute the buffer polygon (in WKT) around the problem point.
242-
# Here, we use the geography type to compute a buffer in meters,
243-
# then cast it back to geometry in EPSG:4326.
244-
buffer_query = """
245-
WITH buffer AS (
246-
SELECT ST_AsText(
247-
ST_Buffer(
248-
ST_SetSRID(ST_Point(%s, %s), 4326)::geography,
249-
%s
250-
)::geometry
251-
) AS wkt
252-
)
253-
SELECT wkt FROM buffer;
254-
"""
255-
with conn.cursor() as cur:
256-
cur.execute(buffer_query, (lon, lat, buffer_dist))
257-
buffer_wkt = cur.fetchone()[0]
258-
print("Buffer WKT:", buffer_wkt)
259-
260-
# Build the main query that uses the computed buffer.
261-
# Distance is computed by casting geometries to geography,
262-
# which returns the geodesic distance in meters.
263-
main_query = f"""
264-
WITH
265-
-- Step 1: Get the polygon that contains the point
266-
point_poly AS (
267-
SELECT ST_MakeValid(geom) AS geom
268-
FROM {table_name}
269-
WHERE ST_Intersects(
270-
geom,
271-
ST_SetSRID(ST_Point({lon}, {lat}), 4326)
272-
)
273-
),
274-
275-
-- Step 2: Get polygons that intersect the buffer
276-
valid_geom AS (
277-
SELECT
278-
hwsd2,
279-
ST_MakeValid(geom) AS geom
280-
FROM {table_name}
281-
WHERE geom && ST_GeomFromText('{buffer_wkt}', 4326)
282-
AND ST_Intersects(geom, ST_GeomFromText('{buffer_wkt}', 4326))
283-
)
284-
285-
-- Step 3: Filter to those that either contain the point or border the point's polygon
286-
SELECT
287-
vg.hwsd2,
288-
ST_AsEWKB(vg.geom) AS geom,
289-
ST_Distance(
290-
vg.geom::geography,
291-
ST_SetSRID(ST_Point({lon}, {lat}), 4326)::geography
292-
) AS distance,
293-
ST_Intersects(
294-
vg.geom,
295-
ST_SetSRID(ST_Point({lon}, {lat}), 4326)
296-
) AS pt_intersect
297-
FROM valid_geom vg, point_poly pp
298-
WHERE
299-
ST_Intersects(vg.geom, ST_SetSRID(ST_Point({lon}, {lat}), 4326))
300-
OR ST_Intersects(vg.geom, pp.geom);
301-
"""
239+
# Compute the buffer polygon (in WKT) around the problem point.
240+
# Here, we use the geography type to compute a buffer in meters,
241+
# then cast it back to geometry in EPSG:4326.
242+
buffer_query = """
243+
WITH buffer AS (
244+
SELECT ST_AsText(
245+
ST_Buffer(
246+
ST_SetSRID(ST_Point(%s, %s), 4326)::geography,
247+
%s
248+
)::geometry
249+
) AS wkt
250+
)
251+
SELECT wkt FROM buffer;
252+
"""
253+
with connection.cursor() as cur:
254+
cur.execute(buffer_query, (lon, lat, buffer_dist))
255+
buffer_wkt = cur.fetchone()[0]
256+
print("Buffer WKT:", buffer_wkt)
257+
258+
# Build the main query that uses the computed buffer.
259+
# Distance is computed by casting geometries to geography,
260+
# which returns the geodesic distance in meters.
261+
main_query = f"""
262+
WITH
263+
-- Step 1: Get the polygon that contains the point
264+
point_poly AS (
265+
SELECT ST_MakeValid(geom) AS geom
266+
FROM {table_name}
267+
WHERE ST_Intersects(
268+
geom,
269+
ST_SetSRID(ST_Point({lon}, {lat}), 4326)
270+
)
271+
),
272+
273+
-- Step 2: Get polygons that intersect the buffer
274+
valid_geom AS (
275+
SELECT
276+
hwsd2,
277+
ST_MakeValid(geom) AS geom
278+
FROM {table_name}
279+
WHERE geom && ST_GeomFromText('{buffer_wkt}', 4326)
280+
AND ST_Intersects(geom, ST_GeomFromText('{buffer_wkt}', 4326))
281+
)
302282
283+
-- Step 3: Filter to those that either contain the point or border the point's polygon
284+
SELECT
285+
vg.hwsd2,
286+
ST_AsEWKB(vg.geom) AS geom,
287+
ST_Distance(
288+
vg.geom::geography,
289+
ST_SetSRID(ST_Point({lon}, {lat}), 4326)::geography
290+
) AS distance,
291+
ST_Intersects(
292+
vg.geom,
293+
ST_SetSRID(ST_Point({lon}, {lat}), 4326)
294+
) AS pt_intersect
295+
FROM valid_geom vg, point_poly pp
296+
WHERE
297+
ST_Intersects(vg.geom, ST_SetSRID(ST_Point({lon}, {lat}), 4326))
298+
OR ST_Intersects(vg.geom, pp.geom);
299+
"""
303300

304-
# Use GeoPandas to execute the main query and load results into a GeoDataFrame.
305-
hwsd = gpd.read_postgis(main_query, conn, geom_col="geom")
306-
print("Main query returned", len(hwsd), "rows.")
301+
# Use GeoPandas to execute the main query and load results into a GeoDataFrame.
302+
hwsd = gpd.read_postgis(main_query, connection, geom_col="geom")
303+
print("Main query returned", len(hwsd), "rows.")
307304

308-
# Remove the geometry column (if not needed) from this dataset.
309-
hwsd = hwsd.drop(columns=["geom"])
305+
# Remove the geometry column (if not needed) from this dataset.
306+
hwsd = hwsd.drop(columns=["geom"])
310307

311-
# Get the list of hwsd2 identifiers.
312-
hwsd2_mu_select = hwsd["hwsd2"].tolist()
308+
# Get the list of hwsd2 identifiers.
309+
hwsd2_mu_select = hwsd["hwsd2"].tolist()
313310

314-
# Call get_hwsd2_profile_data using the same connection.
315-
hwsd_data = get_hwsd2_profile_data(conn, hwsd2_mu_select)
311+
# Call get_hwsd2_profile_data using the same connection.
312+
hwsd_data = get_hwsd2_profile_data(connection, hwsd2_mu_select)
316313

317-
# Merge the two datasets.
318-
merged = pd.merge(hwsd_data, hwsd, on="hwsd2", how="left").drop_duplicates()
319-
return merged
314+
# Merge the two datasets.
315+
merged = pd.merge(hwsd_data, hwsd, on="hwsd2", how="left").drop_duplicates()
316+
return merged
320317

321318

322319
# global
323320

324321

325322
# Function to fetch data from a PostgreSQL table
326-
def fetch_table_from_db(table_name):
327-
conn = None
323+
def fetch_table_from_db(connection, table_name):
328324
try:
329-
conn = get_datastore_connection()
330-
cur = conn.cursor()
331-
332-
query = f"SELECT * FROM {table_name} ORDER BY id ASC;"
333-
cur.execute(query)
334-
rows = cur.fetchall()
325+
with connection.cursor() as cur:
326+
query = f"SELECT * FROM {table_name} ORDER BY id ASC;"
327+
cur.execute(query)
328+
rows = cur.fetchall()
335329

336-
return rows
330+
return rows
337331

338332
except Exception as err:
339333
logging.error(f"Error querying PostgreSQL: {err}")
340334
return None
341335

342-
finally:
343-
if conn:
344-
conn.close()
345336

346-
347-
def get_WRB_descriptions(WRB_Comp_List):
337+
def get_WRB_descriptions(connection, WRB_Comp_List):
348338
"""
349339
Retrieve WRB descriptions based on provided WRB component list.
350340
"""
351-
conn = None
352341
try:
353-
conn = get_datastore_connection()
354-
cur = conn.cursor()
342+
with connection.cursor() as cur:
355343

356-
# Create placeholders for the SQL IN clause
357-
placeholders = ", ".join(["%s"] * len(WRB_Comp_List))
358-
sql = f"""SELECT WRB_tax, Description_en, Management_en, Description_es, Management_es,
359-
Description_ks, Management_ks, Description_fr, Management_fr
360-
FROM wrb_fao90_desc
361-
WHERE WRB_tax IN ({placeholders})"""
362-
363-
# Execute the query with the parameters
364-
cur.execute(sql, tuple(WRB_Comp_List))
365-
results = cur.fetchall()
344+
# Create placeholders for the SQL IN clause
345+
placeholders = ", ".join(["%s"] * len(WRB_Comp_List))
346+
sql = f"""SELECT WRB_tax, Description_en, Management_en, Description_es, Management_es,
347+
Description_ks, Management_ks, Description_fr, Management_fr
348+
FROM wrb_fao90_desc
349+
WHERE WRB_tax IN ({placeholders})"""
350+
351+
# Execute the query with the parameters
352+
cur.execute(sql, tuple(WRB_Comp_List))
353+
results = cur.fetchall()
366354

367355
# Convert the results to a pandas DataFrame
368356
data = pd.DataFrame(
@@ -381,18 +369,13 @@ def get_WRB_descriptions(WRB_Comp_List):
381369
)
382370

383371
return data
384-
385372
except Exception as err:
386373
logging.error(f"Error querying PostgreSQL: {err}")
387374
return None
388375

389-
finally:
390-
if conn:
391-
conn.close()
392-
393376

394377
# global only
395-
def getSG_descriptions(WRB_Comp_List):
378+
def getSG_descriptions(connection, WRB_Comp_List):
396379
"""
397380
Fetch WRB descriptions from a PostgreSQL database using wrb2006_to_fao90
398381
and wrb_fao90_desc tables. Returns a pandas DataFrame with columns:
@@ -405,13 +388,10 @@ def getSG_descriptions(WRB_Comp_List):
405388
pandas.DataFrame or None if an error occurs.
406389
"""
407390

408-
conn = None
409391
try:
410-
# 1. Get a connection to your datastore (replace with your actual function):
411-
conn = get_datastore_connection()
412392

413393
def execute_query(query, params):
414-
with conn.cursor() as cur:
394+
with connection.cursor() as cur:
415395
# Execute the query with the parameters
416396
cur.execute(query, params)
417397
return cur.fetchall()
@@ -482,7 +462,3 @@ def execute_query(query, params):
482462
except Exception as err:
483463
logging.error(f"Error querying PostgreSQL: {err}")
484464
return None
485-
486-
finally:
487-
if conn:
488-
conn.close()

soil_id/global_soil.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,7 @@
2727
from scipy.stats import norm
2828

2929
from .color import calculate_deltaE2000
30-
from .db import (
31-
extract_hwsd2_data,
32-
fetch_table_from_db,
33-
get_WRB_descriptions,
34-
getSG_descriptions,
35-
)
30+
from .db import extract_hwsd2_data, fetch_table_from_db, get_WRB_descriptions, getSG_descriptions
3631
from .services import get_soilgrids_classification_data, get_soilgrids_property_data
3732
from .utils import (
3833
adjust_depth_interval,
@@ -77,10 +72,11 @@ class SoilListOutputData:
7772
##################################################################################################
7873
# getSoilLocationBasedGlobal #
7974
##################################################################################################
80-
def list_soils_global(lon, lat):
75+
def list_soils_global(connection, lon, lat):
8176
# Extract HWSD2 Data
8277
try:
8378
hwsd2_data = extract_hwsd2_data(
79+
connection,
8480
lon,
8581
lat,
8682
table_name="hwsdv2",
@@ -580,6 +576,7 @@ def convert_to_serializable(obj):
580576
# rankPredictionGlobal #
581577
##############################################################################################
582578
def rank_soils_global(
579+
connection,
583580
lon,
584581
lat,
585582
list_output_data: SoilListOutputData,
@@ -949,7 +946,7 @@ def rank_soils_global(
949946
ysf = []
950947

951948
# Load color distribution data from NormDist2 (FAO90) table
952-
rows = fetch_table_from_db("NormDist2")
949+
rows = fetch_table_from_db(connection, "NormDist2")
953950
row_id = 0
954951
for row in rows:
955952
# row is a tuple; iterate over its values.

soil_id/tests/global/generate_bulk_test_results.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import traceback
2121

2222
import pandas
23-
23+
from soil_id.db import get_datastore_connection
2424
from soil_id.global_soil import list_soils_global, rank_soils_global
2525

2626
test_data_df = pandas.read_csv(
@@ -36,7 +36,10 @@
3636

3737
print(f"logging results to {result_file_name}")
3838
# buffering=1 is line buffering
39-
with open(result_file_name, "w", buffering=1) as result_file:
39+
with (
40+
open(result_file_name, "w", buffering=1) as result_file,
41+
get_datastore_connection() as connection,
42+
):
4043
result_agg = {}
4144

4245
for pedon_key, pedon in pedons:
@@ -55,9 +58,10 @@
5558
start_time = time.perf_counter()
5659
try:
5760

58-
list_result = list_soils_global(lat=lat, lon=lon)
61+
list_result = list_soils_global(connection=connection, lat=lat, lon=lon)
5962

6063
result_record["rank_result"] = rank_soils_global(
64+
connection=connection,
6165
lat=lat,
6266
lon=lon,
6367
list_output_data=list_result,

0 commit comments

Comments
 (0)