Skip to content

Commit aaa0af9

Browse files
committed
feat: pass the DB connection to the global soil functions as a parameter
Otherwise, API clients would need to set special environment variables for this library.
1 parent f7c48f4 commit aaa0af9

File tree

4 files changed

+77
-95
lines changed

4 files changed

+77
-95
lines changed

soil_id/db.py

Lines changed: 43 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,7 @@ def save_soilgrids_output(plot_id, model_version, soilgrids_blob):
160160
conn.close()
161161

162162

163-
def get_hwsd2_profile_data(conn, hwsd2_mu_select):
163+
def get_hwsd2_profile_data(connection, hwsd2_mu_select):
164164
"""
165165
Retrieve HWSD v2 data based on selected hwsd2 (map unit) values.
166166
This version reuses an existing connection.
@@ -177,7 +177,7 @@ def get_hwsd2_profile_data(conn, hwsd2_mu_select):
177177
return pd.DataFrame()
178178

179179
try:
180-
with conn.cursor() as cur:
180+
with connection.cursor() as cur:
181181
# Create placeholders for the SQL IN clause
182182
placeholders = ", ".join(["%s"] * len(hwsd2_mu_select))
183183
sql_query = f"""
@@ -220,7 +220,7 @@ def get_hwsd2_profile_data(conn, hwsd2_mu_select):
220220
return pd.DataFrame()
221221

222222

223-
def extract_hwsd2_data(lon, lat, buffer_dist, table_name):
223+
def extract_hwsd2_data(connection, lon, lat, buffer_dist, table_name):
224224
"""
225225
Fetches HWSD soil data from a PostGIS table within a given buffer around a point,
226226
performing distance and intersection calculations directly on geographic coordinates.
@@ -234,26 +234,24 @@ def extract_hwsd2_data(lon, lat, buffer_dist, table_name):
234234
Returns:
235235
DataFrame: Merged data from hwsdv2 and hwsdv2_data.
236236
"""
237-
# Use a single connection for both queries.
238-
with get_datastore_connection() as conn:
239-
# Compute the buffer polygon (in WKT) around the problem point.
240-
# Here, we use the geography type to compute a buffer in meters,
241-
# then cast it back to geometry in EPSG:4326.
242-
buffer_query = """
243-
WITH buffer AS (
244-
SELECT ST_AsText(
245-
ST_Buffer(
246-
ST_SetSRID(ST_Point(%s, %s), 4326)::geography,
247-
%s
248-
)::geometry
249-
) AS wkt
250-
)
251-
SELECT wkt FROM buffer;
252-
"""
253-
with conn.cursor() as cur:
254-
cur.execute(buffer_query, (lon, lat, buffer_dist))
255-
buffer_wkt = cur.fetchone()[0]
256-
print("Buffer WKT:", buffer_wkt)
237+
# Compute the buffer polygon (in WKT) around the problem point.
238+
# Here, we use the geography type to compute a buffer in meters,
239+
# then cast it back to geometry in EPSG:4326.
240+
buffer_query = """
241+
WITH buffer AS (
242+
SELECT ST_AsText(
243+
ST_Buffer(
244+
ST_SetSRID(ST_Point(%s, %s), 4326)::geography,
245+
%s
246+
)::geometry
247+
) AS wkt
248+
)
249+
SELECT wkt FROM buffer;
250+
"""
251+
with connection.cursor() as cur:
252+
cur.execute(buffer_query, (lon, lat, buffer_dist))
253+
buffer_wkt = cur.fetchone()[0]
254+
print("Buffer WKT:", buffer_wkt)
257255

258256
# Build the main query that uses the computed buffer.
259257
# Distance is computed by casting geometries to geography,
@@ -344,7 +342,7 @@ def extract_hwsd2_data(lon, lat, buffer_dist, table_name):
344342
# """
345343

346344
# Use GeoPandas to execute the main query and load results into a GeoDataFrame.
347-
hwsd = gpd.read_postgis(main_query, conn, geom_col="geom")
345+
hwsd = gpd.read_postgis(main_query, connection, geom_col="geom")
348346
print("Main query returned", len(hwsd), "rows.")
349347

350348
# Remove the geometry column (if not needed) from this dataset.
@@ -354,7 +352,7 @@ def extract_hwsd2_data(lon, lat, buffer_dist, table_name):
354352
hwsd2_mu_select = hwsd["hwsd2"].tolist()
355353

356354
# Call get_hwsd2_profile_data using the same connection.
357-
hwsd_data = get_hwsd2_profile_data(conn, hwsd2_mu_select)
355+
hwsd_data = get_hwsd2_profile_data(connection, hwsd2_mu_select)
358356

359357
# Merge the two datasets.
360358
merged = pd.merge(hwsd_data, hwsd, on="hwsd2", how="left").drop_duplicates()
@@ -365,46 +363,37 @@ def extract_hwsd2_data(lon, lat, buffer_dist, table_name):
365363

366364

367365
# Function to fetch data from a PostgreSQL table
368-
def fetch_table_from_db(table_name):
369-
conn = None
366+
def fetch_table_from_db(connection, table_name):
370367
try:
371-
conn = get_datastore_connection()
372-
cur = conn.cursor()
373-
374-
query = f"SELECT * FROM {table_name} ORDER BY id ASC;"
375-
cur.execute(query)
376-
rows = cur.fetchall()
368+
with connection.cursor() as cur:
369+
query = f"SELECT * FROM {table_name} ORDER BY id ASC;"
370+
cur.execute(query)
371+
rows = cur.fetchall()
377372

378-
return rows
373+
return rows
379374

380375
except Exception as err:
381376
logging.error(f"Error querying PostgreSQL: {err}")
382377
return None
383378

384-
finally:
385-
if conn:
386-
conn.close()
387-
388379

389-
def get_WRB_descriptions(WRB_Comp_List):
380+
def get_WRB_descriptions(connection, WRB_Comp_List):
390381
"""
391382
Retrieve WRB descriptions based on provided WRB component list.
392383
"""
393-
conn = None
394384
try:
395-
conn = get_datastore_connection()
396-
cur = conn.cursor()
397-
398-
# Create placeholders for the SQL IN clause
399-
placeholders = ", ".join(["%s"] * len(WRB_Comp_List))
400-
sql = f"""SELECT WRB_tax, Description_en, Management_en, Description_es, Management_es,
401-
Description_ks, Management_ks, Description_fr, Management_fr
402-
FROM wrb_fao90_desc
403-
WHERE WRB_tax IN ({placeholders})"""
385+
with connection.cursor() as cur:
404386

405-
# Execute the query with the parameters
406-
cur.execute(sql, tuple(WRB_Comp_List))
407-
results = cur.fetchall()
387+
# Create placeholders for the SQL IN clause
388+
placeholders = ", ".join(["%s"] * len(WRB_Comp_List))
389+
sql = f"""SELECT WRB_tax, Description_en, Management_en, Description_es, Management_es,
390+
Description_ks, Management_ks, Description_fr, Management_fr
391+
FROM wrb_fao90_desc
392+
WHERE WRB_tax IN ({placeholders})"""
393+
394+
# Execute the query with the parameters
395+
cur.execute(sql, tuple(WRB_Comp_List))
396+
results = cur.fetchall()
408397

409398
# Convert the results to a pandas DataFrame
410399
data = pd.DataFrame(
@@ -423,18 +412,13 @@ def get_WRB_descriptions(WRB_Comp_List):
423412
)
424413

425414
return data
426-
427415
except Exception as err:
428416
logging.error(f"Error querying PostgreSQL: {err}")
429417
return None
430418

431-
finally:
432-
if conn:
433-
conn.close()
434-
435419

436420
# global only
437-
def getSG_descriptions(WRB_Comp_List):
421+
def getSG_descriptions(connection, WRB_Comp_List):
438422
"""
439423
Fetch WRB descriptions from a PostgreSQL database using wrb2006_to_fao90
440424
and wrb_fao90_desc tables. Returns a pandas DataFrame with columns:
@@ -447,13 +431,10 @@ def getSG_descriptions(WRB_Comp_List):
447431
pandas.DataFrame or None if an error occurs.
448432
"""
449433

450-
conn = None
451434
try:
452-
# 1. Get a connection to your datastore (replace with your actual function):
453-
conn = get_datastore_connection()
454435

455436
def execute_query(query, params):
456-
with conn.cursor() as cur:
437+
with connection.cursor() as cur:
457438
# Execute the query with the parameters
458439
cur.execute(query, params)
459440
return cur.fetchall()
@@ -524,7 +505,3 @@ def execute_query(query, params):
524505
except Exception as err:
525506
logging.error(f"Error querying PostgreSQL: {err}")
526507
return None
527-
528-
finally:
529-
if conn:
530-
conn.close()

soil_id/global_soil.py

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -27,12 +27,7 @@
2727
from scipy.stats import norm
2828

2929
from .color import calculate_deltaE2000
30-
from .db import (
31-
extract_hwsd2_data,
32-
fetch_table_from_db,
33-
get_WRB_descriptions,
34-
getSG_descriptions,
35-
)
30+
from .db import extract_hwsd2_data, fetch_table_from_db, get_WRB_descriptions, getSG_descriptions
3631
from .services import get_soilgrids_classification_data, get_soilgrids_property_data
3732
from .utils import (
3833
adjust_depth_interval,
@@ -77,10 +72,11 @@ class SoilListOutputData:
7772
##################################################################################################
7873
# getSoilLocationBasedGlobal #
7974
##################################################################################################
80-
def list_soils_global(lon, lat):
75+
def list_soils_global(connection, lon, lat):
8176
# Extract HWSD2 Data
8277
try:
8378
hwsd2_data = extract_hwsd2_data(
79+
connection,
8480
lon,
8581
lat,
8682
table_name="hwsdv2",
@@ -444,6 +440,7 @@ def list_soils_global(lon, lat):
444440

445441
# Merge component descriptions
446442
WRB_Comp_Desc = get_WRB_descriptions(
443+
connection,
447444
mucompdata_cond_prob["compname_grp"].drop_duplicates().tolist()
448445
)
449446

@@ -582,6 +579,7 @@ def convert_to_serializable(obj):
582579
# rankPredictionGlobal #
583580
##############################################################################################
584581
def rank_soils_global(
582+
connection,
585583
lon,
586584
lat,
587585
list_output_data: SoilListOutputData,
@@ -949,7 +947,7 @@ def rank_soils_global(
949947
ysf = []
950948

951949
# Load color distribution data from NormDist2 (FAO90) table
952-
rows = fetch_table_from_db("NormDist2")
950+
rows = fetch_table_from_db(connection, "NormDist2")
953951
row_id = 0
954952
for row in rows:
955953
# row is a tuple; iterate over its values.

soil_id/tests/global/generate_bulk_test_results.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
import traceback
2121

2222
import pandas
23-
23+
from soil_id.db import get_datastore_connection
2424
from soil_id.global_soil import list_soils_global, rank_soils_global
2525

2626
test_data_df = pandas.read_csv(
@@ -36,7 +36,10 @@
3636

3737
print(f"logging results to {result_file_name}")
3838
# buffering=1 is line buffering
39-
with open(result_file_name, "w", buffering=1) as result_file:
39+
with (
40+
open(result_file_name, "w", buffering=1) as result_file,
41+
get_datastore_connection() as connection,
42+
):
4043
result_agg = {}
4144

4245
for pedon_key, pedon in pedons:
@@ -54,9 +57,11 @@
5457
else:
5558
start_time = time.perf_counter()
5659
try:
57-
list_result = list_soils_global(lat=lat, lon=lon)
60+
61+
list_result = list_soils_global(connection=connection, lat=lat, lon=lon)
5862

5963
result_record["rank_result"] = rank_soils_global(
64+
connection=connection,
6065
lat=lat,
6166
lon=lon,
6267
list_output_data=list_result,

soil_id/tests/global/test_global.py

Lines changed: 20 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import time
1818
import pytest
1919

20+
from soil_id.db import get_datastore_connection
2021
from soil_id.global_soil import list_soils_global, rank_soils_global, sg_list
2122

2223
test_locations = [
@@ -28,21 +29,22 @@
2829

2930
@pytest.mark.skip
3031
def test_soil_location():
31-
for item in test_locations:
32-
logging.info(f"Testing {item['lon']}, {item['lat']}")
33-
start_time = time.perf_counter()
34-
list_soils_result = list_soils_global(item["lon"], item["lat"])
35-
logging.info(f"...time: {(time.perf_counter() - start_time):.2f}s")
36-
rank_soils_global(
37-
item["lon"],
38-
item["lat"],
39-
list_output_data=list_soils_result,
40-
soilHorizon=["Loam"],
41-
topDepth=[15],
42-
bottomDepth=[45],
43-
rfvDepth=[20],
44-
lab_Color=[[41.23035939, 3.623018224, 13.27654356]],
45-
bedrock=None,
46-
cracks=None,
47-
)
48-
sg_list(item["lon"], item["lat"])
32+
with get_datastore_connection() as connection:
33+
for item in test_locations:
34+
logging.info(f"Testing {item['lon']}, {item['lat']}")
35+
start_time = time.perf_counter()
36+
list_soils_result = list_soils_global(connection, item["lon"], item["lat"])
37+
logging.info(f"...time: {(time.perf_counter()-start_time):.2f}s")
38+
rank_soils_result = rank_soils_global(
39+
connection,
40+
item["lon"],
41+
item["lat"],
42+
list_output_data=list_soils_result,
43+
soilHorizon=["Loam"],
44+
horizonDepth=[15],
45+
rfvDepth=[20],
46+
lab_Color=[[41.23035939, 3.623018224, 13.27654356]],
47+
bedrock=None,
48+
cracks=None,
49+
)
50+
sg_soils_result = sg_list(item["lon"], item["lat"])

0 commit comments

Comments
 (0)