Skip to content

Commit 8bcc844

Browse files
committed
feat: pass db connection for global as param
otherwise API clients need to set special environment variables for this library
1 parent ac0a581 commit 8bcc844

File tree

4 files changed

+117
-138
lines changed

4 files changed

+117
-138
lines changed

soil_id/db.py

+85-108
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121

2222
# Third-party libraries
2323
import psycopg
24-
from shapely import wkb
25-
from shapely.geometry import Point
2624

2725
# local libraries
2826
import soil_id.config
27+
from shapely import wkb
28+
from shapely.geometry import Point
2929

3030

3131
def get_datastore_connection():
@@ -162,7 +162,7 @@ def save_soilgrids_output(plot_id, model_version, soilgrids_blob):
162162
conn.close()
163163

164164

165-
def get_hwsd2_profile_data(conn, hwsd2_mu_select):
165+
def get_hwsd2_profile_data(connection, hwsd2_mu_select):
166166
"""
167167
Retrieve HWSD v2 data based on selected hwsd2 (map unit) values.
168168
This version reuses an existing connection.
@@ -179,7 +179,7 @@ def get_hwsd2_profile_data(conn, hwsd2_mu_select):
179179
return pd.DataFrame()
180180

181181
try:
182-
with conn.cursor() as cur:
182+
with connection.cursor() as cur:
183183
# Create placeholders for the SQL IN clause
184184
placeholders = ", ".join(["%s"] * len(hwsd2_mu_select))
185185
sql_query = f"""
@@ -222,7 +222,7 @@ def get_hwsd2_profile_data(conn, hwsd2_mu_select):
222222
return pd.DataFrame()
223223

224224

225-
def extract_hwsd2_data(lon, lat, buffer_dist, table_name):
225+
def extract_hwsd2_data(connection, lon, lat, buffer_dist, table_name):
226226
"""
227227
Fetches HWSD soil data from a PostGIS table within a given buffer around a point,
228228
performing distance and intersection calculations directly on geographic coordinates.
@@ -236,119 +236,108 @@ def extract_hwsd2_data(lon, lat, buffer_dist, table_name):
236236
Returns:
237237
DataFrame: Merged data from hwsdv2 and hwsdv2_data.
238238
"""
239-
# Use a single connection for both queries.
240-
with get_datastore_connection() as conn:
241-
# Compute the buffer polygon (in WKT) around the problem point.
242-
# Here, we use the geography type to compute a buffer in meters,
243-
# then cast it back to geometry in EPSG:4326.
244-
buffer_query = """
245-
WITH buffer AS (
246-
SELECT ST_AsText(
247-
ST_Buffer(
248-
ST_SetSRID(ST_Point(%s, %s), 4326)::geography,
249-
%s
250-
)::geometry
251-
) AS wkt
252-
)
253-
SELECT wkt FROM buffer;
254-
"""
255-
with conn.cursor() as cur:
256-
cur.execute(buffer_query, (lon, lat, buffer_dist))
257-
buffer_wkt = cur.fetchone()[0]
258-
print("Buffer WKT:", buffer_wkt)
259-
260-
# Build the main query that uses the computed buffer.
261-
# Distance is computed by casting geometries to geography,
262-
# which returns the geodesic distance in meters.
263-
main_query = f"""
264-
WITH valid_geom AS (
265-
SELECT
266-
hwsd2,
267-
ST_MakeValid(geom) AS geom
268-
FROM {table_name}
269-
WHERE geom && ST_GeomFromText('{buffer_wkt}', 4326)
270-
AND ST_Intersects(geom, ST_GeomFromText('{buffer_wkt}', 4326))
271-
)
272-
SELECT
273-
hwsd2,
274-
ST_AsEWKB(geom) AS geom,
275-
ST_Distance(
276-
geom::geography,
277-
ST_SetSRID(ST_Point({lon}, {lat}), 4326)::geography
278-
) AS distance,
279-
ST_Intersects(
280-
geom,
281-
ST_SetSRID(ST_Point({lon}, {lat}), 4326)
282-
) AS pt_intersect
283-
FROM valid_geom
284-
WHERE ST_Intersects(
285-
geom,
286-
ST_SetSRID(ST_Point({lon}, {lat}), 4326)
287-
);
288-
"""
239+
# Compute the buffer polygon (in WKT) around the problem point.
240+
# Here, we use the geography type to compute a buffer in meters,
241+
# then cast it back to geometry in EPSG:4326.
242+
buffer_query = """
243+
WITH buffer AS (
244+
SELECT ST_AsText(
245+
ST_Buffer(
246+
ST_SetSRID(ST_Point(%s, %s), 4326)::geography,
247+
%s
248+
)::geometry
249+
) AS wkt
250+
)
251+
SELECT wkt FROM buffer;
252+
"""
253+
with connection.cursor() as cur:
254+
cur.execute(buffer_query, (lon, lat, buffer_dist))
255+
buffer_wkt = cur.fetchone()[0]
256+
print("Buffer WKT:", buffer_wkt)
257+
258+
# Build the main query that uses the computed buffer.
259+
# Distance is computed by casting geometries to geography,
260+
# which returns the geodesic distance in meters.
261+
main_query = f"""
262+
WITH valid_geom AS (
263+
SELECT
264+
hwsd2,
265+
ST_MakeValid(geom) AS geom
266+
FROM {table_name}
267+
WHERE geom && ST_GeomFromText('{buffer_wkt}', 4326)
268+
AND ST_Intersects(geom, ST_GeomFromText('{buffer_wkt}', 4326))
269+
)
270+
SELECT
271+
hwsd2,
272+
ST_AsEWKB(geom) AS geom,
273+
ST_Distance(
274+
geom::geography,
275+
ST_SetSRID(ST_Point({lon}, {lat}), 4326)::geography
276+
) AS distance,
277+
ST_Intersects(
278+
geom,
279+
ST_SetSRID(ST_Point({lon}, {lat}), 4326)
280+
) AS pt_intersect
281+
FROM valid_geom
282+
WHERE ST_Intersects(
283+
geom,
284+
ST_SetSRID(ST_Point({lon}, {lat}), 4326)
285+
);
286+
"""
289287

290-
# Use GeoPandas to execute the main query and load results into a GeoDataFrame.
291-
hwsd = gpd.read_postgis(main_query, conn, geom_col="geom")
292-
print("Main query returned", len(hwsd), "rows.")
288+
# Use GeoPandas to execute the main query and load results into a GeoDataFrame.
289+
hwsd = gpd.read_postgis(main_query, connection, geom_col="geom")
290+
print("Main query returned", len(hwsd), "rows.")
293291

294-
# Remove the geometry column (if not needed) from this dataset.
295-
hwsd = hwsd.drop(columns=["geom"])
292+
# Remove the geometry column (if not needed) from this dataset.
293+
hwsd = hwsd.drop(columns=["geom"])
296294

297-
# Get the list of hwsd2 identifiers.
298-
hwsd2_mu_select = hwsd["hwsd2"].tolist()
295+
# Get the list of hwsd2 identifiers.
296+
hwsd2_mu_select = hwsd["hwsd2"].tolist()
299297

300-
# Call get_hwsd2_profile_data using the same connection.
301-
hwsd_data = get_hwsd2_profile_data(conn, hwsd2_mu_select)
298+
# Call get_hwsd2_profile_data using the same connection.
299+
hwsd_data = get_hwsd2_profile_data(connection, hwsd2_mu_select)
302300

303-
# Merge the two datasets.
304-
merged = pd.merge(hwsd_data, hwsd, on="hwsd2", how="left").drop_duplicates()
305-
return merged
301+
# Merge the two datasets.
302+
merged = pd.merge(hwsd_data, hwsd, on="hwsd2", how="left").drop_duplicates()
303+
return merged
306304

307305

308306
# global
309307

310308

311309
# Function to fetch data from a PostgreSQL table
312-
def fetch_table_from_db(table_name):
313-
conn = None
310+
def fetch_table_from_db(connection, table_name):
314311
try:
315-
conn = get_datastore_connection()
316-
cur = conn.cursor()
317-
318-
query = f"SELECT * FROM {table_name} ORDER BY id ASC;"
319-
cur.execute(query)
320-
rows = cur.fetchall()
312+
with connection.cursor() as cur:
313+
query = f"SELECT * FROM {table_name} ORDER BY id ASC;"
314+
cur.execute(query)
315+
rows = cur.fetchall()
321316

322-
return rows
317+
return rows
323318

324319
except Exception as err:
325320
logging.error(f"Error querying PostgreSQL: {err}")
326321
return None
327322

328-
finally:
329-
if conn:
330-
conn.close()
331-
332323

333-
def get_WRB_descriptions(WRB_Comp_List):
324+
def get_WRB_descriptions(connection, WRB_Comp_List):
334325
"""
335326
Retrieve WRB descriptions based on provided WRB component list.
336327
"""
337-
conn = None
338328
try:
339-
conn = get_datastore_connection()
340-
cur = conn.cursor()
341-
342-
# Create placeholders for the SQL IN clause
343-
placeholders = ", ".join(["%s"] * len(WRB_Comp_List))
344-
sql = f"""SELECT WRB_tax, Description_en, Management_en, Description_es, Management_es,
345-
Description_ks, Management_ks, Description_fr, Management_fr
346-
FROM wrb_fao90_desc
347-
WHERE WRB_tax IN ({placeholders})"""
329+
with connection.cursor() as cur:
348330

349-
# Execute the query with the parameters
350-
cur.execute(sql, tuple(WRB_Comp_List))
351-
results = cur.fetchall()
331+
# Create placeholders for the SQL IN clause
332+
placeholders = ", ".join(["%s"] * len(WRB_Comp_List))
333+
sql = f"""SELECT WRB_tax, Description_en, Management_en, Description_es, Management_es,
334+
Description_ks, Management_ks, Description_fr, Management_fr
335+
FROM wrb_fao90_desc
336+
WHERE WRB_tax IN ({placeholders})"""
337+
338+
# Execute the query with the parameters
339+
cur.execute(sql, tuple(WRB_Comp_List))
340+
results = cur.fetchall()
352341

353342
# Convert the results to a pandas DataFrame
354343
data = pd.DataFrame(
@@ -367,18 +356,13 @@ def get_WRB_descriptions(WRB_Comp_List):
367356
)
368357

369358
return data
370-
371359
except Exception as err:
372360
logging.error(f"Error querying PostgreSQL: {err}")
373361
return None
374362

375-
finally:
376-
if conn:
377-
conn.close()
378-
379363

380364
# global only
381-
def getSG_descriptions(WRB_Comp_List):
365+
def getSG_descriptions(connection, WRB_Comp_List):
382366
"""
383367
Fetch WRB descriptions from a PostgreSQL database using wrb2006_to_fao90
384368
and wrb_fao90_desc tables. Returns a pandas DataFrame with columns:
@@ -391,13 +375,10 @@ def getSG_descriptions(WRB_Comp_List):
391375
pandas.DataFrame or None if an error occurs.
392376
"""
393377

394-
conn = None
395378
try:
396-
# 1. Get a connection to your datastore (replace with your actual function):
397-
conn = get_datastore_connection()
398379

399380
def execute_query(query, params):
400-
with conn.cursor() as cur:
381+
with connection.cursor() as cur:
401382
# Execute the query with the parameters
402383
cur.execute(query, params)
403384
return cur.fetchall()
@@ -468,7 +449,3 @@ def execute_query(query, params):
468449
except Exception as err:
469450
logging.error(f"Error querying PostgreSQL: {err}")
470451
return None
471-
472-
finally:
473-
if conn:
474-
conn.close()

soil_id/global_soil.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,7 @@
2727
from scipy.stats import norm
2828

2929
from .color import calculate_deltaE2000
30-
from .db import (
31-
extract_hwsd2_data,
32-
fetch_table_from_db,
33-
get_WRB_descriptions,
34-
getSG_descriptions,
35-
)
30+
from .db import extract_hwsd2_data, fetch_table_from_db, get_WRB_descriptions, getSG_descriptions
3631
from .services import get_soilgrids_classification_data, get_soilgrids_property_data
3732
from .utils import (
3833
adjust_depth_interval,
@@ -77,10 +72,11 @@ class SoilListOutputData:
7772
##################################################################################################
7873
# getSoilLocationBasedGlobal #
7974
##################################################################################################
80-
def list_soils_global(lon, lat):
75+
def list_soils_global(connection, lon, lat):
8176
# Extract HWSD2 Data
8277
try:
8378
hwsd2_data = extract_hwsd2_data(
79+
connection,
8480
lon,
8581
lat,
8682
table_name="hwsdv2",
@@ -558,6 +554,7 @@ def convert_to_serializable(obj):
558554
# rankPredictionGlobal #
559555
##############################################################################################
560556
def rank_soils_global(
557+
connection,
561558
lon,
562559
lat,
563560
list_output_data: SoilListOutputData,
@@ -927,7 +924,7 @@ def rank_soils_global(
927924
ysf = []
928925

929926
# Load color distribution data from NormDist2 (FAO90) table
930-
rows = fetch_table_from_db("NormDist2")
927+
rows = fetch_table_from_db(connection, "NormDist2")
931928
row_id = 0
932929
for row in rows:
933930
# row is a tuple; iterate over its values.

soil_id/tests/global/generate_bulk_test_results.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import traceback
2121

2222
import pandas
23-
23+
from soil_id.db import get_datastore_connection
2424
from soil_id.global_soil import list_soils_global, rank_soils_global
2525

2626
test_data_df = pandas.read_csv(
@@ -36,7 +36,10 @@
3636

3737
print(f"logging results to {result_file_name}")
3838
# buffering=1 is line buffering
39-
with open(result_file_name, "w", buffering=1) as result_file:
39+
with (
40+
open(result_file_name, "w", buffering=1) as result_file,
41+
get_datastore_connection() as connection,
42+
):
4043
result_agg = {}
4144

4245
for pedon_key, pedon in pedons:
@@ -55,9 +58,10 @@
5558
start_time = time.perf_counter()
5659
try:
5760

58-
list_result = list_soils_global(lat=lat, lon=lon)
61+
list_result = list_soils_global(connection=connection, lat=lat, lon=lon)
5962

6063
result_record["rank_result"] = rank_soils_global(
64+
connection=connection,
6165
lat=lat,
6266
lon=lon,
6367
list_output_data=list_result,

0 commit comments

Comments
 (0)