Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add-ml-model-shape #205

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions examples/ml_pipeline.ipynb

Large diffs are not rendered by default.

72 changes: 62 additions & 10 deletions geoengine/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from pathlib import Path
import tempfile
from dataclasses import dataclass
import geoengine_openapi_client.models
from onnx import TypeProto, TensorProto, ModelProto
from onnx.helper import tensor_dtype_to_string
from geoengine_openapi_client.models import MlModelMetadata, MlModel, RasterDataType
from geoengine_openapi_client.models import MlModelMetadata, MlModel, RasterDataType, TensorShape3D
import geoengine_openapi_client
from geoengine.auth import get_session
from geoengine.datasets import UploadId
Expand All @@ -33,7 +34,8 @@ def register_ml_model(onnx_model: ModelProto,
onnx_model,
input_type=model_config.metadata.input_type,
output_type=model_config.metadata.output_type,
num_input_bands=model_config.metadata.num_input_bands,
input_shape=model_config.metadata.input_shape,
out_shape=model_config.metadata.output_shape
)

session = get_session()
Expand All @@ -58,10 +60,12 @@ def register_ml_model(onnx_model: ModelProto,
ml_api.add_ml_model(model, _request_timeout=register_timeout)


# pylint: disable=too-many-branches,too-many-statements
def validate_model_config(onnx_model: ModelProto, *,
input_type: RasterDataType,
output_type: RasterDataType,
num_input_bands: int):
input_shape: TensorShape3D,
out_shape: TensorShape3D):
'''Validates the model config. Raises an exception if the model config is invalid'''

def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix: 'str'):
Expand All @@ -80,18 +84,66 @@ def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix:
raise InputException('Models with multiple inputs are not supported')
check_data_type(model_inputs[0].type, input_type, 'input')

dims = model_inputs[0].type.tensor_type.shape.dim
if len(dims) != 2:
raise InputException('Only 2D input tensors are supported')
if not dims[1].dim_value:
raise InputException('Dimension 1 of the input tensor must have a length')
if dims[1].dim_value != num_input_bands:
raise InputException(f'Model input has {dims[1].dim_value} bands, but {num_input_bands} bands are expected')
dim = model_inputs[0].type.tensor_type.shape.dim

if len(dim) == 2:
if not dim[1].dim_value:
raise InputException('Dimension 1 of a 1D input tensor must have a length')
if dim[1].dim_value != input_shape.attributes:
raise InputException(f'Model input has {dim[1].dim_value} bands, but {input_shape.attributes} are expected')
elif len(dim) == 4:
if not dim[1].dim_value:
raise InputException('Dimension 1 of the a 3D input tensor must have a length')
if not dim[2].dim_value:
raise InputException('Dimension 2 of the a 3D input tensor must have a length')
if not dim[3].dim_value:
raise InputException('Dimension 3 of the a 3D input tensor must have a length')
if dim[1].dim_value != input_shape.y:
raise InputException(f'Model input has {dim[1].dim_value} y size, but {input_shape.y} are expected')
if dim[2].dim_value != input_shape.x:
raise InputException(f'Model input has {dim[2].dim_value} x size, but {input_shape.x} are expected')
if dim[3].dim_value != input_shape.attributes:
raise InputException(f'Model input has {dim[3].dim_value} bands, but {input_shape.attributes} are expected')
else:
raise InputException('Only 1D and 3D input tensors are supported')

if len(model_outputs) < 1:
raise InputException('Models with no outputs are not supported')
check_data_type(model_outputs[0].type, output_type, 'output')

dim = model_outputs[0].type.tensor_type.shape.dim
if len(dim) == 1:
pass # this is a happens if there is only a single out? so shape would be [-1]
elif len(dim) == 2:
if not dim[1].dim_value:
raise InputException('Dimension 1 of a 1D input tensor must have a length')
if dim[1].dim_value != 1:
raise InputException(f'Model output has {dim[1].dim_value} bands, but {out_shape.attributes} are expected')
elif len(dim) == 3:
if not dim[1].dim_value:
raise InputException('Dimension 1 of a 3D input tensor must have a length')
if not dim[2].dim_value:
raise InputException('Dimension 2 of a 3D input tensor must have a length')
if dim[1].dim_value != out_shape.y:
raise InputException(f'Model output has {dim[1].dim_value} y size, but {out_shape.y} are expected')
if dim[2].dim_value != out_shape.x:
raise InputException(f'Model output has {dim[2].dim_value} x size, but {out_shape.x} are expected')
elif len(dim) == 4:
if not dim[1].dim_value:
raise InputException('Dimension 1 of the a 3D input tensor must have a length')
if not dim[2].dim_value:
raise InputException('Dimension 2 of the a 3D input tensor must have a length')
if not dim[3].dim_value:
raise InputException('Dimension 3 of the a 3D input tensor must have a length')
if dim[1].dim_value != out_shape.y:
raise InputException(f'Model output has {dim[1].dim_value} y size, but {out_shape.y} are expected')
if dim[2].dim_value != out_shape.x:
raise InputException(f'Model output has {dim[2].dim_value} x size, but {out_shape.x} are expected')
if dim[3].dim_value != out_shape.attributes:
raise InputException(f'Model output has {dim[3].dim_value} bands, but {out_shape.attributes} are expected')
else:
raise InputException('Only 1D and 3D output tensors are supported')


RASTER_TYPE_TO_ONNX_TYPE = {
RasterDataType.F32: TensorProto.FLOAT,
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package_dir =
packages = find:
python_requires = >=3.9
install_requires =
geoengine-openapi-client == 0.0.17
geoengine-openapi-client @ git+https://github.com/geo-engine/openapi-client@ml-model-input-output-shape#subdirectory=python
geopandas >=0.9,<0.15
matplotlib >=3.5,<3.8
numpy >=1.21,<2
Expand Down
29 changes: 21 additions & 8 deletions tests/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sklearn.ensemble import RandomForestClassifier
from skl2onnx import to_onnx
import numpy as np
from geoengine_openapi_client.models import MlModelMetadata, RasterDataType
from geoengine_openapi_client.models import MlModelMetadata, RasterDataType, TensorShape3D
import geoengine as ge
from . import UrllibMocker

Expand Down Expand Up @@ -47,8 +47,17 @@ def test_uploading_onnx_model(self):
"metadata": {
"fileName": "model.onnx",
"inputType": "F32",
"numInputBands": 2,
"outputType": "I64"
"outputType": "I64",
"inputShape": {
"y": 1,
"x": 1,
"attributes": 2
},
"outputShape": {
"y": 1,
"x": 1,
"attributes": 1
}
},
"name": "foo",
"upload": upload_id
Expand All @@ -64,8 +73,9 @@ def test_uploading_onnx_model(self):
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F32,
num_input_bands=2,
output_type=RasterDataType.I64,
input_shape=TensorShape3D(y=1, x=1, attributes=2),
output_shape=TensorShape3D(y=1, x=1, attributes=1)
),
display_name="Decision Tree",
description="A simple decision tree model",
Expand All @@ -80,16 +90,17 @@ def test_uploading_onnx_model(self):
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F32,
num_input_bands=4,
output_type=RasterDataType.I64,
input_shape=TensorShape3D(y=1, x=1, attributes=4),
output_shape=TensorShape3D(y=1, x=1, attributes=1)
),
display_name="Decision Tree",
description="A simple decision tree model",
)
)
self.assertEqual(
str(exception.exception),
'Model input has 2 bands, but 4 bands are expected'
'Model input has 2 bands, but 4 are expected'
)

with self.assertRaises(ge.InputException) as exception:
Expand All @@ -100,8 +111,9 @@ def test_uploading_onnx_model(self):
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F64,
num_input_bands=2,
output_type=RasterDataType.I64,
input_shape=TensorShape3D(y=1, x=1, attributes=2),
output_shape=TensorShape3D(y=1, x=1, attributes=1)
),
display_name="Decision Tree",
description="A simple decision tree model",
Expand All @@ -120,8 +132,9 @@ def test_uploading_onnx_model(self):
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F32,
num_input_bands=2,
output_type=RasterDataType.I32,
input_shape=TensorShape3D(y=1, x=1, attributes=2),
output_shape=TensorShape3D(y=1, x=1, attributes=1)
),
display_name="Decision Tree",
description="A simple decision tree model",
Expand Down
Loading