From 367f2cab0d98ff5bbf005a19a17f1698d428eb79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ante=20Pu=C5=A1i=C4=87?= Date: Tue, 18 Apr 2023 18:22:47 +0200 Subject: [PATCH] Make large dependencies optional (#235) * Initial commit * Add [all] option * Add extras to workflow file * Fix workflow * Test workflows on the original state * Test workflows (poetry.lock consistent with pyproject.toml) * Readd changes with fixed lock file * Update lock * Fix poetry.lock * Add review suggestions * Make spelling consistent * Mark tests requiring extras & modify imports if extras not installed * FIx formatting * Comply with flake8/black * Remove torch as independent extra * Rename extra * Re-add torch (for users who will manually install dgl) * Add detailed test grouping * Fix import * Fix import in utilities.py * Switch to optionally importing torch (in graph translators) * Fix typo * Add import checks * Improve import/export tests * Remove unnecessary test mark * Fix optional dependency tests * Rename optional import parameters & torch extra --- .github/workflows/build-and-test.yml | 4 +- gqlalchemy/exceptions.py | 9 +++ .../export/graph_transporter.py | 20 ++++-- .../importing/graph_importer.py | 19 +++++- .../transformations/importing/loaders.py | 20 +++++- .../translators/dgl_translator.py | 2 +- .../transformations/translators/translator.py | 9 ++- gqlalchemy/utilities.py | 30 +++++++-- poetry.lock | 61 ++++++------------- pyproject.toml | 28 ++++++--- pytest.ini | 6 ++ tests/transformations/export/test_export.py | 34 ++++++++--- .../transformations/importing/test_import.py | 34 ++++++++--- tests/transformations/loaders/test_loaders.py | 20 ++++++ .../translators/test_dgl_transformations.py | 16 +++-- .../translators/test_nx_transformations.py | 3 +- .../translators/test_pyg_transformations.py | 17 ++++-- 17 files changed, 230 insertions(+), 102 deletions(-) diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index f5086edd..f888958f 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -49,7 +49,7 @@ jobs: docker run -p 7474:7474 -p 7688:7687 -d -v $HOME/neo4j/data:/data -v $HOME/neo4j/logs:/logs -v $HOME/neo4j/import:/var/lib/neo4j/import -v $HOME/neo4j/plugins:/plugins --env NEO4J_AUTH=neo4j/test neo4j:4.4.7 - name: Test project run: | - poetry install + poetry install --all-extras poe install-pyg-cpu poetry run pytest -vvv -m "not slow and not ubuntu and not docker" - name: Use the Upload Artifact GitHub Action @@ -114,7 +114,7 @@ jobs: poetry-version: ${{ env.POETRY_VERSION }} - name: Test project run: | - poetry install + poetry install --all-extras poe install-pyg-cpu poetry run pytest -vvv -m "not slow and not ubuntu and not docker" - name: Save Memgraph Logs diff --git a/gqlalchemy/exceptions.py b/gqlalchemy/exceptions.py index e77f8001..0b8b6724 100644 --- a/gqlalchemy/exceptions.py +++ b/gqlalchemy/exceptions.py @@ -43,6 +43,10 @@ SQLitePropertyDatabase("path-to-sqlite-db", db) """ +MISSING_OPTIONAL_DEPENDENCY = """ +No module named '{dependency_name}' +""" + MISSING_ORDER = """ The second argument of the tuple must be order: ASC, ASCENDING, DESC or DESCENDING. """ @@ -199,6 +203,11 @@ def __init__(self, path): self.message = FILE_NOT_FOUND.format(path=path) +def raise_if_not_imported(dependency, dependency_name): + if not dependency: + raise ModuleNotFoundError(MISSING_OPTIONAL_DEPENDENCY.format(dependency_name=dependency_name)) + + def database_error_handler(func): def inner_function(*args, **kwargs): try: diff --git a/gqlalchemy/transformations/export/graph_transporter.py b/gqlalchemy/transformations/export/graph_transporter.py index 7b0626ea..73c13e6e 100644 --- a/gqlalchemy/transformations/export/graph_transporter.py +++ b/gqlalchemy/transformations/export/graph_transporter.py @@ -12,12 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from gqlalchemy.exceptions import raise_if_not_imported +import gqlalchemy.memgraph_constants as mg_consts from gqlalchemy.transformations.export.transporter import Transporter -from gqlalchemy.transformations.translators.dgl_translator import DGLTranslator -from gqlalchemy.transformations.translators.nx_translator import NxTranslator -from gqlalchemy.transformations.translators.pyg_translator import PyGTranslator from gqlalchemy.transformations.graph_type import GraphType -import gqlalchemy.memgraph_constants as mg_consts + +try: + from gqlalchemy.transformations.translators.dgl_translator import DGLTranslator +except ModuleNotFoundError: + DGLTranslator = None + +from gqlalchemy.transformations.translators.nx_translator import NxTranslator + +try: + from gqlalchemy.transformations.translators.pyg_translator import PyGTranslator +except ModuleNotFoundError: + PyGTranslator = None class GraphTransporter(Transporter): @@ -47,8 +57,10 @@ def __init__( super().__init__() self.graph_type = graph_type.upper() if self.graph_type == GraphType.DGL.name: + raise_if_not_imported(dependency=DGLTranslator, dependency_name="dgl") self.translator = DGLTranslator(host, port, username, password, encrypted, client_name, lazy) elif self.graph_type == GraphType.PYG.name: + raise_if_not_imported(dependency=PyGTranslator, dependency_name="torch_geometric") self.translator = PyGTranslator(host, port, username, password, encrypted, client_name, lazy) elif self.graph_type == GraphType.NX.name: self.translator = NxTranslator(host, port, username, password, encrypted, client_name, lazy) diff --git a/gqlalchemy/transformations/importing/graph_importer.py b/gqlalchemy/transformations/importing/graph_importer.py index 733c87a7..12123f47 100644 --- a/gqlalchemy/transformations/importing/graph_importer.py +++ b/gqlalchemy/transformations/importing/graph_importer.py @@ -15,11 +15,22 @@ from gqlalchemy import Memgraph from gqlalchemy.transformations.graph_type import GraphType from gqlalchemy.transformations.importing.importer import Importer -from gqlalchemy.transformations.translators.dgl_translator import DGLTranslator -from gqlalchemy.transformations.translators.nx_translator import NxTranslator -from gqlalchemy.transformations.translators.pyg_translator import PyGTranslator + +from gqlalchemy.exceptions import raise_if_not_imported import gqlalchemy.memgraph_constants as mg_consts +try: + from gqlalchemy.transformations.translators.dgl_translator import DGLTranslator +except ModuleNotFoundError: + DGLTranslator = None + +from gqlalchemy.transformations.translators.nx_translator import NxTranslator + +try: + from gqlalchemy.transformations.translators.pyg_translator import PyGTranslator +except ModuleNotFoundError: + PyGTranslator = None + class GraphImporter(Importer): """Imports dgl, pyg or networkx graph representations to Memgraph. @@ -45,8 +56,10 @@ def __init__( super().__init__() self.graph_type = graph_type.upper() if self.graph_type == GraphType.DGL.name: + raise_if_not_imported(dependency=DGLTranslator, dependency_name="dgl") self.translator = DGLTranslator(host, port, username, password, encrypted, client_name, lazy) elif self.graph_type == GraphType.PYG.name: + raise_if_not_imported(dependency=PyGTranslator, dependency_name="torch_geometric") self.translator = PyGTranslator(host, port, username, password, encrypted, client_name, lazy) elif self.graph_type == GraphType.NX.name: self.translator = NxTranslator(host, port, username, password, encrypted, client_name, lazy) diff --git a/gqlalchemy/transformations/importing/loaders.py b/gqlalchemy/transformations/importing/loaders.py index 82d9d042..f4a4f0ff 100644 --- a/gqlalchemy/transformations/importing/loaders.py +++ b/gqlalchemy/transformations/importing/loaders.py @@ -20,8 +20,15 @@ from typing import List, Dict, Any, Optional, Union import adlfs -import pyarrow.dataset as ds -from pyarrow import fs + +try: + import pyarrow.dataset as ds +except ModuleNotFoundError: + ds = None +try: + from pyarrow import fs +except ModuleNotFoundError: + fs = None from dacite import from_dict from gqlalchemy import Memgraph @@ -223,6 +230,9 @@ def __init__(self, bucket_name: str, **kwargs): if S3_SECRET_KEY not in kwargs: raise KeyError(f"{S3_SECRET_KEY} is needed to connect to S3 storage") + if fs is None: + raise ModuleNotFoundError("No module named 'pyarrow'") + super().__init__(fs=fs.S3FileSystem(**kwargs)) self._bucket_name = bucket_name @@ -278,6 +288,9 @@ def __init__(self, path: str) -> None: Args: path: path to the local storage location. """ + if fs is None: + raise ModuleNotFoundError("No module named 'pyarrow'") + super().__init__(fs=fs.LocalFileSystem()) self._path = path @@ -361,6 +374,9 @@ def load_data( source = self._file_system_handler.get_path(f"{collection_name}.{self._file_extension}") print("Loading data from " + ("cross " if is_cross_table else "") + f"table {source}...") + if ds is None: + raise ModuleNotFoundError("No module named 'pyarrow'") + dataset = ds.dataset(source=source, format=self._file_extension, filesystem=self._file_system_handler.fs) for batch in dataset.to_batches( diff --git a/gqlalchemy/transformations/translators/dgl_translator.py b/gqlalchemy/transformations/translators/dgl_translator.py index c9440beb..c88da98e 100644 --- a/gqlalchemy/transformations/translators/dgl_translator.py +++ b/gqlalchemy/transformations/translators/dgl_translator.py @@ -35,7 +35,7 @@ class DGLTranslator(Translator): """Performs conversion from cypher queries to the DGL graph representation. DGL assigns to each edge a unique integer, called the edge ID, based on the order in which it was added to the graph. In DGL, all the edges are directed, and an edge (u,v) indicates that the direction goes from node u to node v. Only features of numerical types (e.g., float, double, and int) are allowed. They can be scalars, vectors or multi-dimensional - tensors (DQL requirement). Each node feature has a unique name and each edge feature has a unique name. The features of nodes and edges can have + tensors (DGL requirement). Each node feature has a unique name and each edge feature has a unique name. The features of nodes and edges can have the same name. A feature is created via tensor assignment, which assigns a feature to each node/edge in the graph. The leading dimension of that tensor must be equal to the number of nodes/edges in the graph. You cannot assign a feature to a subset of the nodes/edges in the graph. Features of the same name must have the same dimensionality and data type. diff --git a/gqlalchemy/transformations/translators/translator.py b/gqlalchemy/transformations/translators/translator.py index c3dba4ee..8bb43227 100644 --- a/gqlalchemy/transformations/translators/translator.py +++ b/gqlalchemy/transformations/translators/translator.py @@ -17,8 +17,12 @@ from collections import defaultdict from numbers import Number -import torch +try: + import torch +except ModuleNotFoundError: + torch = None +from gqlalchemy.exceptions import raise_if_not_imported from gqlalchemy.transformations.constants import LABELS_CONCAT, DEFAULT_NODE_LABEL, DEFAULT_EDGE_TYPE from gqlalchemy.memgraph_constants import ( MG_HOST, @@ -35,7 +39,6 @@ class Translator(ABC): - # Lambda function to concat list of labels merge_labels: Callable[[Set[str]], str] = ( lambda labels, default_node_label: LABELS_CONCAT.join([label for label in sorted(labels)]) @@ -104,6 +107,8 @@ def validate_features(cls, features: List, expected_num: int): Returns: None if features cannot be set or tensor of same features. """ + raise_if_not_imported(dependency=torch, dependency_name="torch") + if len(features) != expected_num: return None try: diff --git a/gqlalchemy/utilities.py b/gqlalchemy/utilities.py index 2826f01a..b58e1891 100644 --- a/gqlalchemy/utilities.py +++ b/gqlalchemy/utilities.py @@ -13,14 +13,21 @@ # limitations under the License. from abc import ABC, abstractmethod -import math -import numpy as np -import torch - from datetime import datetime, date, time, timedelta from enum import Enum +import inspect +import math from typing import Any, Dict, List, Optional, Tuple, Union +import numpy as np + +try: + import torch +except ModuleNotFoundError: + torch = None + +from gqlalchemy.exceptions import raise_if_not_imported + class DatetimeKeywords(Enum): DURATION = "duration" @@ -67,13 +74,26 @@ def _format_timedelta(duration: timedelta) -> str: return f"P{days}DT{hours}H{minutes}M{remainder_sec}S" +def _is_torch_tensor(value): + for cls in inspect.getmro(type(value)): + try: + if cls.__module__ == "torch" and cls.__name__ == "Tensor": + return True + except Exception: + pass + return False + + def to_cypher_value(value: Any, config: NetworkXCypherConfig = None) -> str: """Converts value to a valid Cypher type.""" if config is None: config = NetworkXCypherConfig() value_type = type(value) - if isinstance(value, torch.Tensor): + + if _is_torch_tensor(value): + raise_if_not_imported(dependency=torch, dependency_name="torch") + if value.squeeze().size() == 1: return value.squeeze().item() else: diff --git a/poetry.lock b/poetry.lock index f309e2cd..a60498be 100644 --- a/poetry.lock +++ b/poetry.lock @@ -566,7 +566,7 @@ name = "dgl" version = "0.9.1" description = "Deep Graph Library" category = "main" -optional = false +optional = true python-versions = "*" files = [ {file = "dgl-0.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7b9848c2c643d7b4d4c1969d05e511aff776b101c69573defd3382107b25a5c9"}, @@ -1065,7 +1065,7 @@ name = "nvidia-cublas-cu11" version = "11.10.3.66" description = "CUBLAS native runtime libraries" category = "main" -optional = false +optional = true python-versions = ">=3" files = [ {file = "nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl", hash = "sha256:d32e4d75f94ddfb93ea0a5dda08389bcc65d8916a25cb9f37ac89edaeed3bded"}, @@ -1081,7 +1081,7 @@ name = "nvidia-cuda-nvrtc-cu11" version = "11.7.99" description = "NVRTC native runtime libraries" category = "main" -optional = false +optional = true python-versions = ">=3" files = [ {file = "nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:9f1562822ea264b7e34ed5930567e89242d266448e936b85bc97a3370feabb03"}, @@ -1098,7 +1098,7 @@ name = "nvidia-cuda-runtime-cu11" version = "11.7.99" description = "CUDA Runtime native Libraries" category = "main" -optional = false +optional = true python-versions = ">=3" files = [ {file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:cc768314ae58d2641f07eac350f40f99dcb35719c4faff4bc458a7cd2b119e31"}, @@ -1114,7 +1114,7 @@ name = "nvidia-cudnn-cu11" version = "8.5.0.96" description = "cuDNN runtime libraries" category = "main" -optional = false +optional = true python-versions = ">=3" files = [ {file = "nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:402f40adfc6f418f9dae9ab402e773cfed9beae52333f6d86ae3107a1b9527e7"}, @@ -1157,18 +1157,6 @@ files = [ [package.dependencies] pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" -[[package]] -name = "pastel" -version = "0.2.1" -description = "Bring colors to your terminal." -category = "main" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -files = [ - {file = "pastel-0.2.1-py2.py3-none-any.whl", hash = "sha256:4349225fcdf6c2bb34d483e523475de5bb04a5c10ef711263452cb37d7dd4364"}, - {file = "pastel-0.2.1.tar.gz", hash = "sha256:e6581ac04e973cac858828c6202c1e1e81fee1dc7de7683f3e1ffe0bfd8a573d"}, -] - [[package]] name = "pathspec" version = "0.10.1" @@ -1213,25 +1201,6 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] -[[package]] -name = "poethepoet" -version = "0.18.1" -description = "A task runner that works well with poetry." -category = "main" -optional = false -python-versions = ">=3.7" -files = [ - {file = "poethepoet-0.18.1-py3-none-any.whl", hash = "sha256:e85727bf6f4a10bf6c1a43026bdeb40df689bea3c4682d03cbe531cabc8f2ba6"}, - {file = "poethepoet-0.18.1.tar.gz", hash = "sha256:5f3566b14c2f5dccdfbc3bb26f0096006b38dc0b9c74bd4f8dd1eba7b0e29f6a"}, -] - -[package.dependencies] -pastel = ">=0.2.1,<0.3.0" -tomli = ">=1.2.2" - -[package.extras] -poetry-plugin = ["poetry (>=1.0,<2.0)"] - [[package]] name = "portalocker" version = "2.5.1" @@ -1334,7 +1303,7 @@ name = "pyarrow" version = "9.0.0" description = "Python library for Apache Arrow" category = "main" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "pyarrow-9.0.0-cp310-cp310-macosx_10_13_universal2.whl", hash = "sha256:767cafb14278165ad539a2918c14c1b73cf20689747c21375c38e3fe62884902"}, @@ -1751,7 +1720,7 @@ name = "scipy" version = "1.9.3" description = "Fundamental algorithms for scientific computing in Python" category = "main" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "scipy-1.9.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1884b66a54887e21addf9c16fb588720a8309a57b2e258ae1c7986d4444d3bc0"}, @@ -1830,7 +1799,7 @@ files = [ name = "tomli" version = "2.0.1" description = "A lil' TOML parser" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1843,7 +1812,7 @@ name = "torch" version = "1.13.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" category = "main" -optional = false +optional = true python-versions = ">=3.7.0" files = [ {file = "torch-1.13.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:fd12043868a34a8da7d490bf6db66991108b00ffbeecb034228bfcbbd4197143"}, @@ -1884,7 +1853,7 @@ name = "tqdm" version = "4.64.1" description = "Fast, Extensible Progress Meter" category = "main" -optional = false +optional = true python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" files = [ {file = "tqdm-4.64.1-py2.py3-none-any.whl", hash = "sha256:6fee160d6ffcd1b1c68c65f14c829c22832bc401726335ce92c52d395944a6a1"}, @@ -1972,7 +1941,7 @@ name = "wheel" version = "0.38.4" description = "A built-package format for Python" category = "main" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "wheel-0.38.4-py3-none-any.whl", hash = "sha256:b60533f3f5d530e971d6737ca6d58681ee434818fab630c83a734bb10c083ce8"}, @@ -2055,7 +2024,13 @@ files = [ idna = ">=2.0" multidict = ">=4.0" +[extras] +all = ["pyarrow", "torch", "dgl"] +arrow = ["pyarrow"] +dgl = ["torch", "dgl"] +torch-pyg = ["torch"] + [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "12af775519a25ecce98b684891f8ddcafbc7363ce0397129e9735fc40a50f440" +content-hash = "8a9ee84456c0584f19fee82239169cfa4abb9f7b8d3c20b2de78d8f89751b712" diff --git a/pyproject.toml b/pyproject.toml index 49143846..3dc6a1be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,14 +1,15 @@ [tool.poetry] name = "GQLAlchemy" -version = "1.4.0" +version = "1.4.1" -description = "GQLAlchemy is library developed with purpose of assisting writing and running queries on Memgraph." +description = "GQLAlchemy is a library developed to assist with writing and running queries in Memgraph." repository = "https://github.com/memgraph/gqlalchemy" authors = [ - "Bruno Sacaric ", - "Josip Mrden ", - "Katarina Supe ", - "Andi Skrgat ", + "Bruno Sacaric ", + "Josip Mrden ", + "Katarina Supe ", + "Andi Skrgat ", + "Ante Pusic ", ] license = "Apache-2.0" readme = "README.md" @@ -37,14 +38,21 @@ pymgclient = "1.3.1" networkx = "^2.5.1" pydantic = "^1.8.2" psutil = "^5.9.0" -pyarrow = "^9.0.0" dacite = "^1.6.0" adlfs = "^2022.2.0" neo4j = "^4.4.3" -docker = "^5.0.3" -torch = "^1.13.1" numpy = "^1.24.1" -dgl = "^0.9.1" +docker = "^5.0.3" + +pyarrow = { version = "^9.0.0", optional = true } +torch = { version = "^1.13.1", optional = true } +dgl = { version = "^0.9.1", optional = true } + +[tool.poetry.extras] +arrow = ["pyarrow"] +dgl = ["torch", "dgl"] +all = ["pyarrow", "torch", "dgl"] +torch_pyg = ["torch"] [tool.poetry.group.dev.dependencies] black = "^22.3.0" diff --git a/pytest.ini b/pytest.ini index cf89760f..19332480 100644 --- a/pytest.ini +++ b/pytest.ini @@ -9,3 +9,9 @@ python_files = test_*.py testpaths = tests gqlalchemy markers = slow: slow tests + extras: tests using the optional dependencies + pyg: tests using the PyG (PyTorch Geometric) extra + dgl: tests using the DGL (Deep Graph Library) extra + arrow: tests using the arrow (PyArrow) extra + ubuntu: slow tests + docker: slow tests diff --git a/tests/transformations/export/test_export.py b/tests/transformations/export/test_export.py index 97f1b3ee..fc81f7fc 100644 --- a/tests/transformations/export/test_export.py +++ b/tests/transformations/export/test_export.py @@ -1,15 +1,31 @@ import pytest from gqlalchemy.transformations.export.graph_transporter import GraphTransporter -from gqlalchemy.transformations.translators.dgl_translator import DGLTranslator -from gqlalchemy.transformations.translators.pyg_translator import PyGTranslator -from gqlalchemy.transformations.translators.nx_translator import NxTranslator -@pytest.mark.parametrize( - "graph_type, translator_type", [("DGL", DGLTranslator), ("pyG", PyGTranslator), ("Nx", NxTranslator)] -) -def test_export_selection_strategy(graph_type, translator_type): - transporter = GraphTransporter(graph_type) - assert isinstance(transporter.translator, translator_type) +@pytest.mark.extras +@pytest.mark.dgl +def test_export_dgl(): + DGLTranslator = pytest.importorskip("gqlalchemy.transformations.translators.dgl_translator.DGLTranslator") + + transporter = GraphTransporter(graph_type="DGL") + assert isinstance(transporter.translator, DGLTranslator) + transporter.export() # even with empty graph we should check that something doesn't fail + + +@pytest.mark.extras +@pytest.mark.pyg +def test_export_pyg(): + PyGTranslator = pytest.importorskip("gqlalchemy.transformations.translators.pyg_translator.PyGTranslator") + + transporter = GraphTransporter(graph_type="pyG") + assert isinstance(transporter.translator, PyGTranslator) + transporter.export() # even with empty graph we should check that something doesn't fail + + +def test_export_nx(): + from gqlalchemy.transformations.translators.nx_translator import NxTranslator + + transporter = GraphTransporter(graph_type="Nx") + assert isinstance(transporter.translator, NxTranslator) transporter.export() # even with empty graph we should check that something doesn't fail diff --git a/tests/transformations/importing/test_import.py b/tests/transformations/importing/test_import.py index d717d7bb..3ae1f5da 100644 --- a/tests/transformations/importing/test_import.py +++ b/tests/transformations/importing/test_import.py @@ -15,15 +15,31 @@ import pytest from gqlalchemy.transformations.importing.graph_importer import GraphImporter -from gqlalchemy.transformations.translators.dgl_translator import DGLTranslator -from gqlalchemy.transformations.translators.pyg_translator import PyGTranslator -from gqlalchemy.transformations.translators.nx_translator import NxTranslator -@pytest.mark.parametrize( - "graph_type, translator_type", [("DGL", DGLTranslator), ("pyG", PyGTranslator), ("Nx", NxTranslator)] -) -def test_import_selection_strategy(graph_type, translator_type): - importer = GraphImporter(graph_type) - assert isinstance(importer.translator, translator_type) +@pytest.mark.extras +@pytest.mark.dgl +def test_import_dgl(): + DGLTranslator = pytest.importorskip("gqlalchemy.transformations.translators.dgl_translator.DGLTranslator") + + importer = GraphImporter(graph_type="DGL") + assert isinstance(importer.translator, DGLTranslator) + importer.translate(None) # it should fail safely no matter what + + +@pytest.mark.extras +@pytest.mark.pyg +def test_import_pyg(): + PyGTranslator = pytest.importorskip("gqlalchemy.transformations.translators.pyg_translator.PyGTranslator") + + importer = GraphImporter(graph_type="pyG") + assert isinstance(importer.translator, PyGTranslator) + importer.translate(None) # it should fail safely no matter what + + +def test_import_nx(): + from gqlalchemy.transformations.translators.nx_translator import NxTranslator + + importer = GraphImporter(graph_type="Nx") + assert isinstance(importer.translator, NxTranslator) importer.translate(None) # it should fail safely no matter what diff --git a/tests/transformations/loaders/test_loaders.py b/tests/transformations/loaders/test_loaders.py index 873a42bf..6d8e64cc 100644 --- a/tests/transformations/loaders/test_loaders.py +++ b/tests/transformations/loaders/test_loaders.py @@ -75,8 +75,13 @@ def test_custom_data_loader(dummy_loader): assert dummy_loader.num == 42 +@pytest.mark.extras +@pytest.mark.arrow def test_local_table_to_graph_importer_parquet(memgraph): """e2e test, using Local File System to import into memgraph, tests available file extensions""" + + _ = pytest.importorskip("pyarrow") + my_configuration = { "indices": {"example": ["name"]}, "name_mappings": {"example": {"label": "PERSON"}}, @@ -86,8 +91,13 @@ def test_local_table_to_graph_importer_parquet(memgraph): importer.translate(drop_database_on_start=True) +@pytest.mark.extras +@pytest.mark.arrow def test_local_table_to_graph_importer_csv(memgraph): """e2e test, using Local File System to import into memgraph, tests available file extensions""" + + _ = pytest.importorskip("pyarrow") + my_configuration = { "indices": {"example": ["name"]}, "name_mappings": {"example": {"label": "PERSON"}}, @@ -97,8 +107,13 @@ def test_local_table_to_graph_importer_csv(memgraph): importer.translate(drop_database_on_start=True) +@pytest.mark.extras +@pytest.mark.arrow def test_local_table_to_graph_importer_orc(memgraph): """e2e test, using Local File System to import into memgraph, tests available file extensions""" + + _ = pytest.importorskip("pyarrow") + if platform.system() == "Windows": with pytest.raises(ValueError): ORCLocalFileSystemImporter(path="", data_configuration=None) @@ -112,8 +127,13 @@ def test_local_table_to_graph_importer_orc(memgraph): importer.translate(drop_database_on_start=True) +@pytest.mark.extras +@pytest.mark.arrow def test_local_table_to_graph_importer_feather(memgraph): """e2e test, using Local File System to import into memgraph, tests available file extensions""" + + _ = pytest.importorskip("pyarrow") + my_configuration = { "indices": {"example": ["name"]}, "name_mappings": {"example": {"label": "PERSON"}}, diff --git a/tests/transformations/translators/test_dgl_transformations.py b/tests/transformations/translators/test_dgl_transformations.py index d9367946..11daa625 100644 --- a/tests/transformations/translators/test_dgl_transformations.py +++ b/tests/transformations/translators/test_dgl_transformations.py @@ -12,22 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Any, Set from numbers import Number +import pytest +from typing import Dict, Any, Set import numpy as np -import dgl -from dgl.data import TUDataset -import torch + from gqlalchemy import Match from gqlalchemy.models import Node, Relationship -from gqlalchemy.transformations.translators.dgl_translator import DGLTranslator + from gqlalchemy.transformations.translators.translator import Translator from gqlalchemy.transformations.constants import DGL_ID, DEFAULT_NODE_LABEL, DEFAULT_EDGE_TYPE from gqlalchemy.utilities import to_cypher_value from tests.transformations.common import execute_queries +dgl = pytest.importorskip("dgl") +TUDataset = pytest.importorskip("dgl.data.TUDataset") +torch = pytest.importorskip("torch") +DGLTranslator = pytest.importorskip("gqlalchemy.transformations.translators.dgl_translator.DGLTranslator") + +pytestmark = [pytest.mark.extras, pytest.mark.dgl] + ########## # UTILS ########## diff --git a/tests/transformations/translators/test_nx_transformations.py b/tests/transformations/translators/test_nx_transformations.py index 3f81a105..9cb112a8 100644 --- a/tests/transformations/translators/test_nx_transformations.py +++ b/tests/transformations/translators/test_nx_transformations.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import networkx as nx import pytest +import networkx as nx + from gqlalchemy.transformations.translators.nx_translator import ( NxTranslator, NoNetworkXConfigException, diff --git a/tests/transformations/translators/test_pyg_transformations.py b/tests/transformations/translators/test_pyg_transformations.py index eb81e9a5..40cf7359 100644 --- a/tests/transformations/translators/test_pyg_transformations.py +++ b/tests/transformations/translators/test_pyg_transformations.py @@ -12,21 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Any, Set from numbers import Number - -from torch_geometric.data import Data, HeteroData -from torch_geometric.datasets import FakeDataset, FakeHeteroDataset -import torch +import pytest +from typing import Dict, Any, Set from gqlalchemy import Match from gqlalchemy.models import Node, Relationship -from gqlalchemy.transformations.translators.pyg_translator import PyGTranslator from gqlalchemy.transformations.translators.translator import Translator from gqlalchemy.transformations.constants import PYG_ID, DEFAULT_NODE_LABEL, DEFAULT_EDGE_TYPE from gqlalchemy.utilities import to_cypher_value from tests.transformations.common import execute_queries +PyGTranslator = pytest.importorskip("gqlalchemy.transformations.translators.pyg_translator.PyGTranslator") +Data = pytest.importorskip("torch_geometric.data.Data") +HeteroData = pytest.importorskip("torch_geometric.data.HeteroData") +FakeDataset = pytest.importorskip("torch_geometric.datasets.FakeDataset") +FakeHeteroDataset = pytest.importorskip("torch_geometric.datasets.FakeHeteroDataset") +torch = pytest.importorskip("torch") + +pytestmark = [pytest.mark.extras, pytest.mark.pyg] + # TODO: test number of properties that were converted ##########