Make large dependencies optional (#235)
* Initial commit

* Add [all] option

* Add extras to workflow file

* Fix workflow

* Test workflows on the original state

* Test workflows (poetry.lock consistent with pyproject.toml)

* Readd changes with fixed lock file

* Update lock

* Fix poetry.lock

* Add review suggestions

* Make spelling consistent

* Mark tests requiring extras & modify imports if extras not installed

* Fix formatting

* Comply with flake8/black

* Remove torch as independent extra

* Rename extra

* Re-add torch (for users who will manually install dgl)

* Add detailed test grouping

* Fix import

* Fix import in utilities.py

* Switch to optionally importing torch (in graph translators)

* Fix typo

* Add import checks

* Improve import/export tests

* Remove unnecessary test mark

* Fix optional dependency tests

* Rename optional import parameters & torch extra
antepusic authored Apr 18, 2023
1 parent cde5a5c commit 367f2ca
Showing 17 changed files with 230 additions and 102 deletions.
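Note: with this change the heavy packages wrapped below (dgl, torch, torch_geometric, pyarrow) are pulled in through extras instead of being hard requirements. A rough sketch of the two install paths — "poetry install --all-extras" is taken verbatim from the updated workflow, while the [all] extra name is inferred from the "Add [all] option" commit message, so check pyproject.toml for the exact extra names:

    # development / CI install with every extra (as in the workflow below)
    poetry install --all-extras

    # end-user install with all optional dependencies (extra name assumed from the commit message)
    pip install 'gqlalchemy[all]'
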
4 changes: 2 additions & 2 deletions .github/workflows/build-and-test.yml
@@ -49,7 +49,7 @@ jobs:
docker run -p 7474:7474 -p 7688:7687 -d -v $HOME/neo4j/data:/data -v $HOME/neo4j/logs:/logs -v $HOME/neo4j/import:/var/lib/neo4j/import -v $HOME/neo4j/plugins:/plugins --env NEO4J_AUTH=neo4j/test neo4j:4.4.7
- name: Test project
run: |
poetry install
poetry install --all-extras
poe install-pyg-cpu
poetry run pytest -vvv -m "not slow and not ubuntu and not docker"
- name: Use the Upload Artifact GitHub Action
@@ -114,7 +114,7 @@ jobs:
poetry-version: ${{ env.POETRY_VERSION }}
- name: Test project
run: |
poetry install
poetry install --all-extras
poe install-pyg-cpu
poetry run pytest -vvv -m "not slow and not ubuntu and not docker"
- name: Save Memgraph Logs
9 changes: 9 additions & 0 deletions gqlalchemy/exceptions.py
@@ -43,6 +43,10 @@
SQLitePropertyDatabase("path-to-sqlite-db", db)
"""

MISSING_OPTIONAL_DEPENDENCY = """
No module named '{dependency_name}'
"""

MISSING_ORDER = """
The second argument of the tuple must be order: ASC, ASCENDING, DESC or DESCENDING.
"""
@@ -199,6 +203,11 @@ def __init__(self, path):
self.message = FILE_NOT_FOUND.format(path=path)


def raise_if_not_imported(dependency, dependency_name):
if not dependency:
raise ModuleNotFoundError(MISSING_OPTIONAL_DEPENDENCY.format(dependency_name=dependency_name))


def database_error_handler(func):
def inner_function(*args, **kwargs):
try:
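The raise_if_not_imported helper added above underpins the pattern used throughout this commit: a heavy dependency is imported inside try/except at module load, bound to None when missing, and the failure is deferred to the moment the dependency is actually needed. A minimal sketch of that pattern, with a hypothetical feature function for illustration:

    try:
        import dgl  # optional heavy dependency
    except ModuleNotFoundError:
        dgl = None  # postpone the failure until the feature is actually used

    def export_to_dgl():  # hypothetical function name, not part of this commit
        raise_if_not_imported(dependency=dgl, dependency_name="dgl")
        ...  # safe to use dgl from here on
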
20 changes: 16 additions & 4 deletions gqlalchemy/transformations/export/graph_transporter.py
@@ -12,12 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from gqlalchemy.exceptions import raise_if_not_imported
import gqlalchemy.memgraph_constants as mg_consts
from gqlalchemy.transformations.export.transporter import Transporter
from gqlalchemy.transformations.translators.dgl_translator import DGLTranslator
from gqlalchemy.transformations.translators.nx_translator import NxTranslator
from gqlalchemy.transformations.translators.pyg_translator import PyGTranslator
from gqlalchemy.transformations.graph_type import GraphType
import gqlalchemy.memgraph_constants as mg_consts

try:
from gqlalchemy.transformations.translators.dgl_translator import DGLTranslator
except ModuleNotFoundError:
DGLTranslator = None

from gqlalchemy.transformations.translators.nx_translator import NxTranslator

try:
from gqlalchemy.transformations.translators.pyg_translator import PyGTranslator
except ModuleNotFoundError:
PyGTranslator = None


class GraphTransporter(Transporter):
@@ -47,8 +57,10 @@ def __init__(
super().__init__()
self.graph_type = graph_type.upper()
if self.graph_type == GraphType.DGL.name:
raise_if_not_imported(dependency=DGLTranslator, dependency_name="dgl")
self.translator = DGLTranslator(host, port, username, password, encrypted, client_name, lazy)
elif self.graph_type == GraphType.PYG.name:
raise_if_not_imported(dependency=PyGTranslator, dependency_name="torch_geometric")
self.translator = PyGTranslator(host, port, username, password, encrypted, client_name, lazy)
elif self.graph_type == GraphType.NX.name:
self.translator = NxTranslator(host, port, username, password, encrypted, client_name, lazy)
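From the caller's perspective, the NetworkX path keeps working on a base install, while requesting a DGL or PyG export without the matching extra now fails early with a ModuleNotFoundError naming the missing package. A hedged usage sketch (it assumes the connection arguments keep their defaults, as before this change):

    from gqlalchemy.transformations.export.graph_transporter import GraphTransporter

    GraphTransporter(graph_type="nx")   # works with the base install
    GraphTransporter(graph_type="dgl")  # without the dgl extra, raises
                                        # ModuleNotFoundError: No module named 'dgl'
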
19 changes: 16 additions & 3 deletions gqlalchemy/transformations/importing/graph_importer.py
@@ -15,11 +15,22 @@
from gqlalchemy import Memgraph
from gqlalchemy.transformations.graph_type import GraphType
from gqlalchemy.transformations.importing.importer import Importer
from gqlalchemy.transformations.translators.dgl_translator import DGLTranslator
from gqlalchemy.transformations.translators.nx_translator import NxTranslator
from gqlalchemy.transformations.translators.pyg_translator import PyGTranslator

from gqlalchemy.exceptions import raise_if_not_imported
import gqlalchemy.memgraph_constants as mg_consts

try:
from gqlalchemy.transformations.translators.dgl_translator import DGLTranslator
except ModuleNotFoundError:
DGLTranslator = None

from gqlalchemy.transformations.translators.nx_translator import NxTranslator

try:
from gqlalchemy.transformations.translators.pyg_translator import PyGTranslator
except ModuleNotFoundError:
PyGTranslator = None


class GraphImporter(Importer):
"""Imports dgl, pyg or networkx graph representations to Memgraph.
@@ -45,8 +56,10 @@ def __init__(
super().__init__()
self.graph_type = graph_type.upper()
if self.graph_type == GraphType.DGL.name:
raise_if_not_imported(dependency=DGLTranslator, dependency_name="dgl")
self.translator = DGLTranslator(host, port, username, password, encrypted, client_name, lazy)
elif self.graph_type == GraphType.PYG.name:
raise_if_not_imported(dependency=PyGTranslator, dependency_name="torch_geometric")
self.translator = PyGTranslator(host, port, username, password, encrypted, client_name, lazy)
elif self.graph_type == GraphType.NX.name:
self.translator = NxTranslator(host, port, username, password, encrypted, client_name, lazy)
20 changes: 18 additions & 2 deletions gqlalchemy/transformations/importing/loaders.py
@@ -20,8 +20,15 @@
from typing import List, Dict, Any, Optional, Union

import adlfs
import pyarrow.dataset as ds
from pyarrow import fs

try:
import pyarrow.dataset as ds
except ModuleNotFoundError:
ds = None
try:
from pyarrow import fs
except ModuleNotFoundError:
fs = None
from dacite import from_dict

from gqlalchemy import Memgraph
@@ -223,6 +230,9 @@ def __init__(self, bucket_name: str, **kwargs):
if S3_SECRET_KEY not in kwargs:
raise KeyError(f"{S3_SECRET_KEY} is needed to connect to S3 storage")

if fs is None:
raise ModuleNotFoundError("No module named 'pyarrow'")

super().__init__(fs=fs.S3FileSystem(**kwargs))
self._bucket_name = bucket_name

@@ -278,6 +288,9 @@ def __init__(self, path: str) -> None:
Args:
path: path to the local storage location.
"""
if fs is None:
raise ModuleNotFoundError("No module named 'pyarrow'")

super().__init__(fs=fs.LocalFileSystem())
self._path = path

@@ -361,6 +374,9 @@ def load_data(
source = self._file_system_handler.get_path(f"{collection_name}.{self._file_extension}")
print("Loading data from " + ("cross " if is_cross_table else "") + f"table {source}...")

if ds is None:
raise ModuleNotFoundError("No module named 'pyarrow'")

dataset = ds.dataset(source=source, format=self._file_extension, filesystem=self._file_system_handler.fs)

for batch in dataset.to_batches(
2 changes: 1 addition & 1 deletion gqlalchemy/transformations/translators/dgl_translator.py
@@ -35,7 +35,7 @@ class DGLTranslator(Translator):
"""Performs conversion from cypher queries to the DGL graph representation. DGL assigns to each edge a unique integer, called the edge ID,
based on the order in which it was added to the graph. In DGL, all the edges are directed, and an edge (u,v) indicates that the direction goes
from node u to node v. Only features of numerical types (e.g., float, double, and int) are allowed. They can be scalars, vectors or multi-dimensional
tensors (DQL requirement). Each node feature has a unique name and each edge feature has a unique name. The features of nodes and edges can have
tensors (DGL requirement). Each node feature has a unique name and each edge feature has a unique name. The features of nodes and edges can have
the same name. A feature is created via tensor assignment, which assigns a feature to each node/edge in the graph. The leading dimension of that
tensor must be equal to the number of nodes/edges in the graph. You cannot assign a feature to a subset of the nodes/edges in the graph. Features of the
same name must have the same dimensionality and data type.
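To make the DGL requirement that the corrected docstring describes concrete, here is a small illustration, assuming dgl and torch are installed: the leading dimension of each feature tensor must equal the number of nodes or edges it is assigned to.

    import dgl
    import torch

    g = dgl.graph(([0, 1], [1, 2]))        # 3 nodes, 2 directed edges
    g.ndata["feat"] = torch.randn(3, 4)    # leading dimension == number of nodes
    g.edata["weight"] = torch.randn(2, 1)  # leading dimension == number of edges
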
9 changes: 7 additions & 2 deletions gqlalchemy/transformations/translators/translator.py
@@ -17,8 +17,12 @@
from collections import defaultdict
from numbers import Number

import torch
try:
import torch
except ModuleNotFoundError:
torch = None

from gqlalchemy.exceptions import raise_if_not_imported
from gqlalchemy.transformations.constants import LABELS_CONCAT, DEFAULT_NODE_LABEL, DEFAULT_EDGE_TYPE
from gqlalchemy.memgraph_constants import (
MG_HOST,
@@ -35,7 +39,6 @@


class Translator(ABC):

# Lambda function to concat list of labels
merge_labels: Callable[[Set[str]], str] = (
lambda labels, default_node_label: LABELS_CONCAT.join([label for label in sorted(labels)])
@@ -104,6 +107,8 @@ def validate_features(cls, features: List, expected_num: int):
Returns:
None if features cannot be set or tensor of same features.
"""
raise_if_not_imported(dependency=torch, dependency_name="torch")

if len(features) != expected_num:
return None
try:
30 changes: 25 additions & 5 deletions gqlalchemy/utilities.py
@@ -13,14 +13,21 @@
# limitations under the License.

from abc import ABC, abstractmethod
import math
import numpy as np
import torch

from datetime import datetime, date, time, timedelta
from enum import Enum
import inspect
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

try:
import torch
except ModuleNotFoundError:
torch = None

from gqlalchemy.exceptions import raise_if_not_imported


class DatetimeKeywords(Enum):
DURATION = "duration"
@@ -67,13 +74,26 @@ def _format_timedelta(duration: timedelta) -> str:
return f"P{days}DT{hours}H{minutes}M{remainder_sec}S"


def _is_torch_tensor(value):
for cls in inspect.getmro(type(value)):
try:
if cls.__module__ == "torch" and cls.__name__ == "Tensor":
return True
except Exception:
pass
return False


def to_cypher_value(value: Any, config: NetworkXCypherConfig = None) -> str:
"""Converts value to a valid Cypher type."""
if config is None:
config = NetworkXCypherConfig()

value_type = type(value)
if isinstance(value, torch.Tensor):

if _is_torch_tensor(value):
raise_if_not_imported(dependency=torch, dependency_name="torch")

if value.squeeze().size() == 1:
return value.squeeze().item()
else:
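The _is_torch_tensor check above walks the MRO and matches on module and class names, so to_cypher_value can recognize tensors without importing torch up front; torch itself is only required once a real tensor reaches the conversion. A small sketch of the distinction, assuming torch is installed:

    import torch
    from gqlalchemy.utilities import _is_torch_tensor

    _is_torch_tensor(torch.tensor([1.0]))  # True: a class named Tensor from the torch module is in the MRO
    _is_torch_tensor([1.0])                # False: plain Python values take the ordinary conversion path
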