diff --git a/.github/workflows/ci_tests.yaml b/.github/workflows/ci_tests.yaml index 02c200c77ec..4fe9dba07ea 100644 --- a/.github/workflows/ci_tests.yaml +++ b/.github/workflows/ci_tests.yaml @@ -166,7 +166,7 @@ jobs: # Run the regular tests - name: Run tests - run: make test PYTEST_EXTRA="-r P -n auto --reruns 2" + run: make test PYTEST_EXTRA="-r P --reruns 2" # Upload diff images on test failure - name: Upload diff images if any test fails diff --git a/pygmt/__init__.py b/pygmt/__init__.py index f6d1040851f..79d887d50a1 100644 --- a/pygmt/__init__.py +++ b/pygmt/__init__.py @@ -26,7 +26,6 @@ from pygmt.accessors import GMTDataArrayAccessor from pygmt.figure import Figure, set_display from pygmt.io import load_dataarray -from pygmt.session_management import begin as _begin from pygmt.session_management import end as _end from pygmt.src import ( binstats, @@ -66,7 +65,5 @@ xyz2grd, ) -# Start our global modern mode session -_begin() # Tell Python to run _end when shutting down _atexit.register(_end) diff --git a/pygmt/_state.py b/pygmt/_state.py new file mode 100644 index 00000000000..3586d759a54 --- /dev/null +++ b/pygmt/_state.py @@ -0,0 +1,9 @@ +""" +Private dictionary to keep tracking of current PyGMT state. + +The feature is only meant for internal use by PyGMT and is experimental! +""" + +_STATE = { + "session_name": None, +} diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 376d441746a..865c87932be 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -8,6 +8,7 @@ import contextlib import ctypes as ctp import io +import os import pathlib import sys import warnings @@ -17,6 +18,7 @@ import numpy as np import pandas as pd import xarray as xr +from pygmt._state import _STATE from pygmt.clib.conversion import ( array_to_datetime, as_c_contiguous, @@ -208,7 +210,20 @@ def __enter__(self): Calls :meth:`pygmt.clib.Session.create`. """ + _init_cli_session = False + # This is the first time a Session object is created. + if _STATE["session_name"] is None: + # Set GMT_SESSION_NAME to the current process id. + _STATE["session_name"] = os.environ["GMT_SESSION_NAME"] = str(os.getpid()) + # Need to initialize the GMT CLI session. + _init_cli_session = True self.create("pygmt-session") + + if _init_cli_session: + self.call_module("begin", args=["pygmt-session"]) + self.call_module(module="set", args=["GMT_COMPATIBILITY=6"]) + del _init_cli_session + return self def __exit__(self, exc_type, exc_value, traceback): diff --git a/pygmt/session_management.py b/pygmt/session_management.py index 87055bb44e8..d3234cbda63 100644 --- a/pygmt/session_management.py +++ b/pygmt/session_management.py @@ -2,11 +2,8 @@ Modern mode session management modules. """ -import os -import sys - +from pygmt._state import _STATE from pygmt.clib import Session -from pygmt.helpers import unique_name def begin(): @@ -17,10 +14,6 @@ def begin(): Only meant to be used once for creating the global session. """ - # On Windows, need to set GMT_SESSION_NAME to a unique value - if sys.platform == "win32": - os.environ["GMT_SESSION_NAME"] = unique_name() - prefix = "pygmt-session" with Session() as lib: lib.call_module(module="begin", args=[prefix]) @@ -39,3 +32,5 @@ def end(): """ with Session() as lib: lib.call_module(module="end", args=[]) + + _STATE["session_name"] = None # Reset the sesion name to None diff --git a/pygmt/tests/test_multiprocessing.py b/pygmt/tests/test_multiprocessing.py new file mode 100644 index 00000000000..bfb3a3d3983 --- /dev/null +++ b/pygmt/tests/test_multiprocessing.py @@ -0,0 +1,85 @@ +""" +Test multiprocessing support. +""" + +import multiprocessing as mp +from importlib import reload +from pathlib import Path + +import numpy.testing as npt +import pygmt + + +def _func(figname): + """ + A wrapper function for testing multiprocessing support. + """ + fig = pygmt.Figure() + fig.basemap(region=[10, 70, -3, 8], projection="X8c/6c", frame="afg") + fig.savefig(figname) + + +def test_multiprocessing(): + """ + Test multiprocessing support for plotting figures. + """ + prefix = "test_session_multiprocessing" + with mp.Pool(2) as p: + p.map(_func, [f"{prefix}-1.png", f"{prefix}-2.png"]) + Path(f"{prefix}-1.png").unlink() + Path(f"{prefix}-2.png").unlink() + + +def _func_datacut(dataset): + """ + A wrapper function for testing multiprocessing support. + """ + xrgrid = pygmt.grdcut(dataset, region=[-10, 10, -5, 5]) + return xrgrid + + +def test_multiprocessing_data_processing(): + """ + Test multiprocessing support for data processing. + """ + with mp.Pool(2) as p: + grids = p.map(_func_datacut, ["@earth_relief_01d_g", "@moon_relief_01d_g"]) + assert len(grids) == 2 + # The Earth relief dataset + assert grids[0].shape == (11, 21) + npt.assert_allclose(grids[0].min(), -5118.0, atol=0.5) + npt.assert_allclose(grids[0].max(), 680.5, atol=0.5) + # The Moon relief dataset + assert grids[1].shape == (11, 21) + npt.assert_allclose(grids[1].min(), -1122.0, atol=0.5) + npt.assert_allclose(grids[1].max(), 943.0, atol=0.5) + + +def _func_reload(figname): + """ + A wrapper for running PyGMT scripts with multiprocessing. + + Before the official multiprocessing support in PyGMT, we have to reload the + PyGMT library. Workaround from + https://github.com/GenericMappingTools/pygmt/issues/217#issuecomment-754774875. + + This test makes sure that the old workaround still works. + """ + import pygmt + + reload(pygmt) + fig = pygmt.Figure() + fig.basemap(region=[10, 70, -3, 8], projection="X8c/6c", frame="afg") + fig.savefig(figname) + + +def test_multiprocessing_reload(): + """ + Make sure that multiprocessing is supported if pygmt is re-imported. + """ + + prefix = "test_session_multiprocessing" + with mp.Pool(2) as p: + p.map(_func_reload, [f"{prefix}-1.png", f"{prefix}-2.png"]) + Path(f"{prefix}-1.png").unlink() + Path(f"{prefix}-2.png").unlink() diff --git a/pygmt/tests/test_session_management.py b/pygmt/tests/test_session_management.py index d949f1a51c0..1a811343b90 100644 --- a/pygmt/tests/test_session_management.py +++ b/pygmt/tests/test_session_management.py @@ -2,8 +2,6 @@ Test the session management modules. """ -import multiprocessing as mp -from importlib import reload from pathlib import Path import pytest @@ -36,10 +34,8 @@ def test_gmt_compat_6_is_applied(capsys): """ end() # Kill the global session try: - # Generate a gmt.conf file in the current directory - # with GMT_COMPATIBILITY = 5 - with Session() as lib: - lib.call_module("gmtset", ["GMT_COMPATIBILITY=5"]) + # Generate a gmt.conf file in the current directory with GMT_COMPATIBILITY = 5 + Path("gmt.conf").write_text("GMT_COMPATIBILITY = 5", encoding="utf-8") begin() with Session() as lib: lib.call_module("basemap", ["-R10/70/-3/8", "-JX4i/3i", "-Ba"]) @@ -60,29 +56,3 @@ def test_gmt_compat_6_is_applied(capsys): # Make sure no global "gmt.conf" in the current directory assert not Path("gmt.conf").exists() begin() # Restart the global session - - -def _gmt_func_wrapper(figname): - """ - A wrapper for running PyGMT scripts with multiprocessing. - - Currently, we have to import pygmt and reload it in each process. Workaround from - https://github.com/GenericMappingTools/pygmt/issues/217#issuecomment-754774875. - """ - import pygmt - - reload(pygmt) - fig = pygmt.Figure() - fig.basemap(region=[10, 70, -3, 8], projection="X8c/6c", frame="afg") - fig.savefig(figname) - - -def test_session_multiprocessing(): - """ - Make sure that multiprocessing is supported if pygmt is re-imported. - """ - prefix = "test_session_multiprocessing" - with mp.Pool(2) as p: - p.map(_gmt_func_wrapper, [f"{prefix}-1.png", f"{prefix}-2.png"]) - Path(f"{prefix}-1.png").unlink() - Path(f"{prefix}-2.png").unlink()