From 0921609a30f431f414d2a000f75b7d920e53280f Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Fri, 27 Jun 2025 14:09:11 +0800 Subject: [PATCH 1/8] sampler --- map2loop/mapdata.py | 57 ------------------------------ map2loop/project.py | 32 +++++++---------- map2loop/sampler.py | 22 ++++++++---- map2loop/thickness_calculator.py | 7 ++-- map2loop/utils.py | 60 ++++++++++++++++++++++++++++++++ 5 files changed, 93 insertions(+), 85 deletions(-) diff --git a/map2loop/mapdata.py b/map2loop/mapdata.py index 4137af27..df5f6804 100644 --- a/map2loop/mapdata.py +++ b/map2loop/mapdata.py @@ -1448,63 +1448,6 @@ def get_value_from_raster(self, datatype: Datatype, x, y): val = data.ReadAsArray(px, py, 1, 1)[0][0] return val - @beartype.beartype - def __value_from_raster(self, inv_geotransform, data, x: float, y: float): - """ - Get the value from a raster dataset at the specified point - - Args: - inv_geotransform (gdal.GeoTransform): - The inverse of the data's geotransform - data (numpy.array): - The raster data - x (float): - The easting coordinate of the value - y (float): - The northing coordinate of the value - - Returns: - float or int: The value at the point specified - """ - px = int(inv_geotransform[0] + inv_geotransform[1] * x + inv_geotransform[2] * y) - py = int(inv_geotransform[3] + inv_geotransform[4] * x + inv_geotransform[5] * y) - # Clamp values to the edges of raster if past boundary, similiar to GL_CLIP - px = max(px, 0) - px = min(px, data.shape[0] - 1) - py = max(py, 0) - py = min(py, data.shape[1] - 1) - return data[px][py] - - @beartype.beartype - def get_value_from_raster_df(self, datatype: Datatype, df: pandas.DataFrame): - """ - Add a 'Z' column to a dataframe with the heights from the 'X' and 'Y' coordinates - - Args: - datatype (Datatype): - The datatype of the raster map to retrieve from - df (pandas.DataFrame): - The original dataframe with 'X' and 'Y' columns - - Returns: - pandas.DataFrame: The modified dataframe - """ - if len(df) <= 0: - df["Z"] = [] - return df - data = self.get_map_data(datatype) - if data is None: - logger.warning("Cannot get value from data as data is not loaded") - return None - - inv_geotransform = gdal.InvGeoTransform(data.GetGeoTransform()) - data_array = numpy.array(data.GetRasterBand(1).ReadAsArray().T) - - df["Z"] = df.apply( - lambda row: self.__value_from_raster(inv_geotransform, data_array, row["X"], row["Y"]), - axis=1, - ) - return df @beartype.beartype def extract_all_contacts(self, save_contacts=True): diff --git a/map2loop/project.py b/map2loop/project.py index d9cfbb83..ec7260e5 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -1,6 +1,6 @@ # internal imports from map2loop.fault_orientation import FaultOrientationNearest -from .utils import hex_to_rgb +from .utils import hex_to_rgb, set_z_values_from_raster_df from .m2l_enums import VerboseLevel, ErrorState, Datatype from .mapdata import MapData from .sampler import Sampler, SamplerDecimator, SamplerSpacing @@ -503,26 +503,20 @@ def sample_map_data(self): """ Use the samplers to extract points along polylines or unit boundaries """ - logger.info( - f"Sampling geology map data using {self.samplers[Datatype.GEOLOGY].sampler_label}" - ) - self.geology_samples = self.samplers[Datatype.GEOLOGY].sample( - self.map_data.get_map_data(Datatype.GEOLOGY), self.map_data - ) - logger.info( - f"Sampling structure map data using {self.samplers[Datatype.STRUCTURE].sampler_label}" - ) - self.structure_samples = self.samplers[Datatype.STRUCTURE].sample( - self.map_data.get_map_data(Datatype.STRUCTURE), self.map_data - ) + geology_data = self.map_data.get_map_data(Datatype.GEOLOGY) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + + logger.info(f"Sampling geology map data using {self.samplers[Datatype.GEOLOGY].sampler_label}") + self.geology_samples = self.samplers[Datatype.GEOLOGY].sample(geology_data) + + logger.info(f"Sampling structure map data using {self.samplers[Datatype.STRUCTURE].sampler_label}") + self.structure_samples = self.samplers[Datatype.STRUCTURE].sample(self.map_data.get_map_data(Datatype.STRUCTURE), dtm_data, geology_data) + logger.info(f"Sampling fault map data using {self.samplers[Datatype.FAULT].sampler_label}") - self.fault_samples = self.samplers[Datatype.FAULT].sample( - self.map_data.get_map_data(Datatype.FAULT), self.map_data - ) + self.fault_samples = self.samplers[Datatype.FAULT].sample(self.map_data.get_map_data(Datatype.FAULT)) + logger.info(f"Sampling fold map data using {self.samplers[Datatype.FOLD].sampler_label}") - self.fold_samples = self.samplers[Datatype.FOLD].sample( - self.map_data.get_map_data(Datatype.FOLD), self.map_data - ) + self.fold_samples = self.samplers[Datatype.FOLD].sample(self.map_data.get_map_data(Datatype.FOLD)) def extract_geology_contacts(self): """ diff --git a/map2loop/sampler.py b/map2loop/sampler.py index 01600566..10aa51b9 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -1,6 +1,7 @@ # internal imports from .m2l_enums import Datatype from .mapdata import MapData +from .utils import set_z_values_from_raster_df # external imports from abc import ABC, abstractmethod @@ -10,6 +11,7 @@ import shapely import numpy from typing import Optional +from .utils import set_z_values_from_raster_df class Sampler(ABC): @@ -38,7 +40,7 @@ def type(self): @beartype.beartype @abstractmethod def sample( - self, spatial_data: geopandas.GeoDataFrame, map_data: Optional[MapData] = None + self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[geopandas.GeoDataFrame] = None, geology_data: Optional[geopandas.GeoDataFrame] = None ) -> pandas.DataFrame: """ Execute sampling method (abstract method) @@ -73,7 +75,7 @@ def __init__(self, decimation: int = 1): @beartype.beartype def sample( - self, spatial_data: geopandas.GeoDataFrame, map_data: Optional[MapData] = None + self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None ) -> pandas.DataFrame: """ Execute sample method takes full point data, samples the data and returns the decimated points @@ -87,10 +89,16 @@ def sample( data = spatial_data.copy() data["X"] = data.geometry.x data["Y"] = data.geometry.y - data["Z"] = map_data.get_value_from_raster_df(Datatype.DTM, data)["Z"] - data["layerID"] = geopandas.sjoin( - data, map_data.get_map_data(Datatype.GEOLOGY), how='left' - )['index_right'] + if dtm_data is not None: + data["Z"] = set_z_values_from_raster_df(dtm_data, data)["Z"] + else: + data["Z"] = None + if geology_data is not None: + data["layerID"] = geopandas.sjoin( + data, geology_data, how='left' + )['index_right'] + else: + data["layerID"] = None data.reset_index(drop=True, inplace=True) return pandas.DataFrame(data[:: self.decimation].drop(columns="geometry")) @@ -118,7 +126,7 @@ def __init__(self, spacing: float = 50.0): @beartype.beartype def sample( - self, spatial_data: geopandas.GeoDataFrame, map_data: Optional[MapData] = None + self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[geopandas.GeoDataFrame] = None, geology_data: Optional[geopandas.GeoDataFrame] = None ) -> pandas.DataFrame: """ Execute sample method takes full point data, samples the data and returns the sampled points diff --git a/map2loop/thickness_calculator.py b/map2loop/thickness_calculator.py index d7a9aad1..3da0ad40 100644 --- a/map2loop/thickness_calculator.py +++ b/map2loop/thickness_calculator.py @@ -5,6 +5,7 @@ calculate_endpoints, multiline_to_line, find_segment_strike_from_pt, + set_z_values_from_raster_df ) from .m2l_enums import Datatype from .interpolators import DipDipDirectionInterpolator @@ -271,7 +272,8 @@ def compute( # set the crs of the contacts to the crs of the units contacts = contacts.set_crs(crs=basal_contacts.crs) # get the elevation Z of the contacts - contacts = map_data.get_value_from_raster_df(Datatype.DTM, contacts) + dtm_data = map_data.get_map_data(Datatype.DTM) + contacts = set_z_values_from_raster_df(dtm_data, contacts) # update the geometry of the contact points to include the Z value contacts["geometry"] = contacts.apply( lambda row: shapely.geometry.Point(row.geometry.x, row.geometry.y, row["Z"]), axis=1 @@ -299,7 +301,8 @@ def compute( # set the crs of the interpolated orientations to the crs of the units interpolated_orientations = interpolated_orientations.set_crs(crs=basal_contacts.crs) # get the elevation Z of the interpolated points - interpolated = map_data.get_value_from_raster_df(Datatype.DTM, interpolated_orientations) + dtm_data = map_data.get_map_data(Datatype.DTM) + interpolated = set_z_values_from_raster_df(dtm_data, interpolated_orientations) # update the geometry of the interpolated points to include the Z value interpolated["geometry"] = interpolated.apply( lambda row: shapely.geometry.Point(row.geometry.x, row.geometry.y, row["Z"]), axis=1 diff --git a/map2loop/utils.py b/map2loop/utils.py index c3ed7795..55e2e7b2 100644 --- a/map2loop/utils.py +++ b/map2loop/utils.py @@ -7,6 +7,7 @@ import pandas import re import json +from osgeo import gdal from .logging import getLogger logger = getLogger(__name__) @@ -528,3 +529,62 @@ def update_from_legacy_file( json.dump(parsed_data, f, indent=4) return file_map + +@beartype.beartype +def value_from_raster(inv_geotransform, data, x: float, y: float): + """ + Get the value from a raster dataset at the specified point + + Args: + inv_geotransform (gdal.GeoTransform): + The inverse of the data's geotransform + data (numpy.array): + The raster data + x (float): + The easting coordinate of the value + y (float): + The northing coordinate of the value + + Returns: + float or int: The value at the point specified + """ + px = int(inv_geotransform[0] + inv_geotransform[1] * x + inv_geotransform[2] * y) + py = int(inv_geotransform[3] + inv_geotransform[4] * x + inv_geotransform[5] * y) + # Clamp values to the edges of raster if past boundary, similiar to GL_CLIP + px = max(px, 0) + px = min(px, data.shape[0] - 1) + py = max(py, 0) + py = min(py, data.shape[1] - 1) + return data[px][py] + +@beartype.beartype +def set_z_values_from_raster_df(dtm_data: gdal.Dataset, df: pandas.DataFrame): + """ + Add a 'Z' column to a dataframe with the heights from the 'X' and 'Y' coordinates + + Args: + dtm_data (gdal.Dataset): + Dtm data from raster map + df (pandas.DataFrame): + The original dataframe with 'X' and 'Y' columns + + Returns: + pandas.DataFrame: The modified dataframe + """ + if len(df) <= 0: + df["Z"] = [] + return df + + if dtm_data is None: + logger.warning("Cannot get value from data as data is not loaded") + return None + + inv_geotransform = gdal.InvGeoTransform(dtm_data.GetGeoTransform()) + data_array = numpy.array(dtm_data.GetRasterBand(1).ReadAsArray().T) + + df["Z"] = df.apply( + lambda row: value_from_raster(inv_geotransform, data_array, row["X"], row["Y"]), + axis=1, + ) + + return df \ No newline at end of file From 6b0249d3bc9091bd6a1e2461636ce80e7db4b77f Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Fri, 27 Jun 2025 14:18:46 +0800 Subject: [PATCH 2/8] fix extract_geology_contacts --- map2loop/project.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/map2loop/project.py b/map2loop/project.py index ec7260e5..9e8189cd 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -526,11 +526,9 @@ def extract_geology_contacts(self): self.map_data.extract_basal_contacts(self.stratigraphic_column.column) # sample the contacts - self.map_data.sampled_contacts = self.samplers[Datatype.GEOLOGY].sample( - self.map_data.basal_contacts - ) - - self.map_data.get_value_from_raster_df(Datatype.DTM, self.map_data.sampled_contacts) + self.map_data.sampled_contacts = self.samplers[Datatype.GEOLOGY].sample(self.map_data.basal_contacts) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + set_z_values_from_raster_df(dtm_data, self.map_data.sampled_contacts) def calculate_stratigraphic_order(self, take_best=False): """ From e3ae1c3996b2dcd93ef566dfa1dd5ba0b76247e1 Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Fri, 27 Jun 2025 14:37:19 +0800 Subject: [PATCH 3/8] fix calculate_fault_orientations and summarise_fault_data --- map2loop/project.py | 6 ++++-- map2loop/sampler.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/map2loop/project.py b/map2loop/project.py index 9e8189cd..7e39bced 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -706,7 +706,8 @@ def calculate_fault_orientations(self): self.map_data.get_map_data(Datatype.FAULT_ORIENTATION), self.map_data, ) - self.map_data.get_value_from_raster_df(Datatype.DTM, self.fault_orientations) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + set_z_values_from_raster_df(dtm_data, self.fault_orientations) else: logger.warning( "No fault orientation data found, skipping fault orientation calculation" @@ -731,7 +732,8 @@ def summarise_fault_data(self): """ Use the fault shapefile to make a summary of each fault by name """ - self.map_data.get_value_from_raster_df(Datatype.DTM, self.fault_samples) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + set_z_values_from_raster_df(dtm_data, self.fault_samples) self.deformation_history.summarise_data(self.fault_samples) self.deformation_history.faults = self.throw_calculator.compute( diff --git a/map2loop/sampler.py b/map2loop/sampler.py index 10aa51b9..43db952e 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -11,7 +11,7 @@ import shapely import numpy from typing import Optional -from .utils import set_z_values_from_raster_df +from osgeo import gdal class Sampler(ABC): From cc315d7e5d5191a64ab07a2477a66b0b9fd47569 Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Fri, 27 Jun 2025 15:38:57 +0800 Subject: [PATCH 4/8] fix dtm data type --- map2loop/sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/map2loop/sampler.py b/map2loop/sampler.py index 43db952e..e8e1fe51 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -40,7 +40,7 @@ def type(self): @beartype.beartype @abstractmethod def sample( - self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[geopandas.GeoDataFrame] = None, geology_data: Optional[geopandas.GeoDataFrame] = None + self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None ) -> pandas.DataFrame: """ Execute sampling method (abstract method) @@ -126,7 +126,7 @@ def __init__(self, spacing: float = 50.0): @beartype.beartype def sample( - self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[geopandas.GeoDataFrame] = None, geology_data: Optional[geopandas.GeoDataFrame] = None + self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None ) -> pandas.DataFrame: """ Execute sample method takes full point data, samples the data and returns the sampled points From abd8b2c791dcacbfcb612b8120ecfb4cd446a6c4 Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Fri, 27 Jun 2025 15:49:20 +0800 Subject: [PATCH 5/8] fix get_value_from_raster import --- map2loop/sorter.py | 6 +++--- map2loop/thickness_calculator.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/map2loop/sorter.py b/map2loop/sorter.py index da4dab76..656cc4c9 100644 --- a/map2loop/sorter.py +++ b/map2loop/sorter.py @@ -3,7 +3,7 @@ import pandas import numpy as np import math -from .mapdata import MapData +from .mapdata import MapData, get_value_from_raster from typing import Union from .logging import getLogger @@ -434,9 +434,9 @@ def sort( continue # Get heights for intersection point and start of ray - height = map_data.get_value_from_raster(Datatype.DTM, start.x, start.y) + height = get_value_from_raster(Datatype.DTM, start.x, start.y) first_intersect_point = Point(start.x, start.y, height) - height = map_data.get_value_from_raster( + height = get_value_from_raster( Datatype.DTM, second_intersect_point.x, second_intersect_point.y ) second_intersect_point = Point(second_intersect_point.x, start.y, height) diff --git a/map2loop/thickness_calculator.py b/map2loop/thickness_calculator.py index 3da0ad40..d6992b90 100644 --- a/map2loop/thickness_calculator.py +++ b/map2loop/thickness_calculator.py @@ -9,7 +9,7 @@ ) from .m2l_enums import Datatype from .interpolators import DipDipDirectionInterpolator -from .mapdata import MapData +from .mapdata import MapData, get_value_from_raster from .logging import getLogger logger = getLogger(__name__) @@ -358,13 +358,13 @@ def compute( p1[0] = numpy.asarray(short_line[0].coords[0][0]) p1[1] = numpy.asarray(short_line[0].coords[0][1]) # get the elevation Z of the end point p1 - p1[2] = map_data.get_value_from_raster(Datatype.DTM, p1[0], p1[1]) + p1[2] = get_value_from_raster(Datatype.DTM, p1[0], p1[1]) # create array to store xyz coordinates of the end point p2 p2 = numpy.zeros(3) p2[0] = numpy.asarray(short_line[0].coords[-1][0]) p2[1] = numpy.asarray(short_line[0].coords[-1][1]) # get the elevation Z of the end point p2 - p2[2] = map_data.get_value_from_raster(Datatype.DTM, p2[0], p2[1]) + p2[2] = get_value_from_raster(Datatype.DTM, p2[0], p2[1]) # calculate the length of the shortest line line_length = scipy.spatial.distance.euclidean(p1, p2) # find the indices of the points that are within 5% of the length of the shortest line From 3a908df183190db14033bd2916d523652f20a321 Mon Sep 17 00:00:00 2001 From: noellehmcheng <143368485+noellehmcheng@users.noreply.github.com> Date: Fri, 27 Jun 2025 08:01:45 +0000 Subject: [PATCH 6/8] style: style fixes by ruff and autoformatting by black --- map2loop/sampler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/map2loop/sampler.py b/map2loop/sampler.py index e8e1fe51..04baed4a 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -1,6 +1,4 @@ # internal imports -from .m2l_enums import Datatype -from .mapdata import MapData from .utils import set_z_values_from_raster_df # external imports From aa6f0e246cfac1a690bb115b89acfe5d881dd922 Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Fri, 27 Jun 2025 16:35:41 +0800 Subject: [PATCH 7/8] revert get_value_from_raster --- map2loop/sorter.py | 6 +++--- map2loop/thickness_calculator.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/map2loop/sorter.py b/map2loop/sorter.py index 656cc4c9..da4dab76 100644 --- a/map2loop/sorter.py +++ b/map2loop/sorter.py @@ -3,7 +3,7 @@ import pandas import numpy as np import math -from .mapdata import MapData, get_value_from_raster +from .mapdata import MapData from typing import Union from .logging import getLogger @@ -434,9 +434,9 @@ def sort( continue # Get heights for intersection point and start of ray - height = get_value_from_raster(Datatype.DTM, start.x, start.y) + height = map_data.get_value_from_raster(Datatype.DTM, start.x, start.y) first_intersect_point = Point(start.x, start.y, height) - height = get_value_from_raster( + height = map_data.get_value_from_raster( Datatype.DTM, second_intersect_point.x, second_intersect_point.y ) second_intersect_point = Point(second_intersect_point.x, start.y, height) diff --git a/map2loop/thickness_calculator.py b/map2loop/thickness_calculator.py index d6992b90..3da0ad40 100644 --- a/map2loop/thickness_calculator.py +++ b/map2loop/thickness_calculator.py @@ -9,7 +9,7 @@ ) from .m2l_enums import Datatype from .interpolators import DipDipDirectionInterpolator -from .mapdata import MapData, get_value_from_raster +from .mapdata import MapData from .logging import getLogger logger = getLogger(__name__) @@ -358,13 +358,13 @@ def compute( p1[0] = numpy.asarray(short_line[0].coords[0][0]) p1[1] = numpy.asarray(short_line[0].coords[0][1]) # get the elevation Z of the end point p1 - p1[2] = get_value_from_raster(Datatype.DTM, p1[0], p1[1]) + p1[2] = map_data.get_value_from_raster(Datatype.DTM, p1[0], p1[1]) # create array to store xyz coordinates of the end point p2 p2 = numpy.zeros(3) p2[0] = numpy.asarray(short_line[0].coords[-1][0]) p2[1] = numpy.asarray(short_line[0].coords[-1][1]) # get the elevation Z of the end point p2 - p2[2] = get_value_from_raster(Datatype.DTM, p2[0], p2[1]) + p2[2] = map_data.get_value_from_raster(Datatype.DTM, p2[0], p2[1]) # calculate the length of the shortest line line_length = scipy.spatial.distance.euclidean(p1, p2) # find the indices of the points that are within 5% of the length of the shortest line From c9a93feb31f526a990f9853ad7b3cbad5db0bbb9 Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Tue, 1 Jul 2025 12:43:00 +0800 Subject: [PATCH 8/8] move dtm and geology data parameters from sample function to constructor --- map2loop/project.py | 4 +++- map2loop/sampler.py | 34 ++++++++++++++++++++++++---------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/map2loop/project.py b/map2loop/project.py index 7e39bced..632d41cd 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -510,7 +510,9 @@ def sample_map_data(self): self.geology_samples = self.samplers[Datatype.GEOLOGY].sample(geology_data) logger.info(f"Sampling structure map data using {self.samplers[Datatype.STRUCTURE].sampler_label}") - self.structure_samples = self.samplers[Datatype.STRUCTURE].sample(self.map_data.get_map_data(Datatype.STRUCTURE), dtm_data, geology_data) + self.samplers[Datatype.STRUCTURE].dtm_data = dtm_data + self.samplers[Datatype.STRUCTURE].geology_data = geology_data + self.structure_samples = self.samplers[Datatype.STRUCTURE].sample(self.map_data.get_map_data(Datatype.STRUCTURE)) logger.info(f"Sampling fault map data using {self.samplers[Datatype.FAULT].sampler_label}") self.fault_samples = self.samplers[Datatype.FAULT].sample(self.map_data.get_map_data(Datatype.FAULT)) diff --git a/map2loop/sampler.py b/map2loop/sampler.py index 04baed4a..b4c7835c 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -20,11 +20,13 @@ class Sampler(ABC): ABC (ABC): Derived from Abstract Base Class """ - def __init__(self): + def __init__(self, dtm_data=None, geology_data=None): """ Initialiser of for Sampler """ self.sampler_label = "SamplerBaseClass" + self.dtm_data = dtm_data + self.geology_data = geology_data def type(self): """ @@ -38,7 +40,7 @@ def type(self): @beartype.beartype @abstractmethod def sample( - self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None + self, spatial_data: geopandas.GeoDataFrame ) -> pandas.DataFrame: """ Execute sampling method (abstract method) @@ -60,20 +62,24 @@ class SamplerDecimator(Sampler): """ @beartype.beartype - def __init__(self, decimation: int = 1): + def __init__(self, decimation: int = 1, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None): """ Initialiser for decimator sampler Args: decimation (int, optional): stride of the points to sample. Defaults to 1. + dtm_data (Optional[gdal.Dataset], optional): digital terrain map data. Defaults to None. + geology_data (Optional[geopandas.GeoDataFrame], optional): geology data. Defaults to None. """ + super().__init__(dtm_data, geology_data) self.sampler_label = "SamplerDecimator" decimation = max(decimation, 1) self.decimation = decimation + @beartype.beartype def sample( - self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None + self, spatial_data: geopandas.GeoDataFrame ) -> pandas.DataFrame: """ Execute sample method takes full point data, samples the data and returns the decimated points @@ -87,13 +93,17 @@ def sample( data = spatial_data.copy() data["X"] = data.geometry.x data["Y"] = data.geometry.y - if dtm_data is not None: - data["Z"] = set_z_values_from_raster_df(dtm_data, data)["Z"] + if self.dtm_data is not None: + result = set_z_values_from_raster_df(self.dtm_data, data) + if result is not None: + data["Z"] = result["Z"] + else: + data["Z"] = None else: data["Z"] = None - if geology_data is not None: + if self.geology_data is not None: data["layerID"] = geopandas.sjoin( - data, geology_data, how='left' + data, self.geology_data, how='left' )['index_right'] else: data["layerID"] = None @@ -111,20 +121,24 @@ class SamplerSpacing(Sampler): """ @beartype.beartype - def __init__(self, spacing: float = 50.0): + def __init__(self, spacing: float = 50.0, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None): """ Initialiser for spacing sampler Args: spacing (float, optional): The distance between samples. Defaults to 50.0. + dtm_data (Optional[gdal.Dataset], optional): digital terrain map data. Defaults to None. + geology_data (Optional[geopandas.GeoDataFrame], optional): geology data. Defaults to None. """ + super().__init__(dtm_data, geology_data) self.sampler_label = "SamplerSpacing" spacing = max(spacing, 1.0) self.spacing = spacing + @beartype.beartype def sample( - self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None + self, spatial_data: geopandas.GeoDataFrame ) -> pandas.DataFrame: """ Execute sample method takes full point data, samples the data and returns the sampled points