diff --git a/pygmt/base_plotting.py b/pygmt/base_plotting.py index a3f682bebf9..6f97a9803c5 100644 --- a/pygmt/base_plotting.py +++ b/pygmt/base_plotting.py @@ -10,11 +10,12 @@ from .exceptions import GMTError, GMTInvalidInput from .helpers import ( build_arg_string, - dummy_context, data_kind, + dummy_context, fmt_docstring, - use_alias, kwargs_to_strings, + tempfile_from_buffer, + use_alias, ) @@ -799,10 +800,11 @@ def legend(self, spec=None, position="JTR+jTR+o0.2c", box="+gwhite+p1p", **kwarg Parameters ---------- - spec : None or str - Either None (default) for using the automatically generated legend - specification file, or a filename pointing to the legend - specification file. + spec : None or str or io.StringIO + Set to None (default) for using the automatically generated legend + specification file. Alternatively, pass in a filename or an + io.StringIO in-memory stream buffer pointing to the legend + specification text. {J} {R} position : str @@ -827,13 +829,17 @@ def legend(self, spec=None, position="JTR+jTR+o0.2c", box="+gwhite+p1p", **kwarg with Session() as lib: if spec is None: - specfile = "" + file_context = dummy_context("") elif data_kind(spec) == "file": - specfile = spec + file_context = dummy_context(spec) + elif data_kind(spec) == "buffer": + file_context = tempfile_from_buffer(spec) else: - raise GMTInvalidInput("Unrecognized data type: {}".format(type(spec))) - arg_str = " ".join([specfile, build_arg_string(kwargs)]) - lib.call_module("legend", arg_str) + raise GMTInvalidInput(f"Unrecognized data type: {type(spec)}") + + with file_context as fname: + arg_str = " ".join([fname, build_arg_string(kwargs)]) + lib.call_module("legend", arg_str) @fmt_docstring @use_alias( diff --git a/pygmt/helpers/__init__.py b/pygmt/helpers/__init__.py index b8a6958816d..95dc48078bb 100644 --- a/pygmt/helpers/__init__.py +++ b/pygmt/helpers/__init__.py @@ -2,7 +2,7 @@ Functions, classes, decorators, and context managers to help wrap GMT modules. """ from .decorators import fmt_docstring, use_alias, kwargs_to_strings -from .tempfile import GMTTempFile, unique_name +from .tempfile import GMTTempFile, tempfile_from_buffer, unique_name from .utils import ( data_kind, dummy_context, diff --git a/pygmt/helpers/tempfile.py b/pygmt/helpers/tempfile.py index a17293eb460..df8e85112e8 100644 --- a/pygmt/helpers/tempfile.py +++ b/pygmt/helpers/tempfile.py @@ -2,7 +2,9 @@ Utilities for dealing with temporary file management. """ import os +import shutil import uuid +from contextlib import contextmanager from tempfile import NamedTemporaryFile import numpy as np @@ -105,3 +107,49 @@ def loadtxt(self, **kwargs): """ return np.loadtxt(self.name, **kwargs) + + +@contextmanager +def tempfile_from_buffer(buf): + """ + Store an io.StringIO buffer stream inside a temporary text file. + + Use the temporary file name to pass in data in your string buffer to a GMT + module. + + Context manager (use in a ``with`` block). Yields the temporary file name + that you can pass as an argument to a GMT module call. Closes the + temporary file upon exit of the ``with`` block. + + Parameters + ---------- + buf : io.StringIO + The in-memory text stream buffer that will be included in the temporary + file. + + Yields + ------ + fname : str + The name of temporary file. Pass this as a file name argument to a GMT + module. + + Examples + -------- + + >>> import io + >>> from pygmt.helpers import tempfile_from_buffer + >>> from pygmt import info + >>> data = np.arange(0, 6, 0.5).reshape((4, 3)) + >>> buf = io.StringIO() + >>> np.savetxt(fname=buf, X=data, fmt="%.1f") + >>> with tempfile_from_buffer(buf=buf) as fname: + ... result = info(fname, per_column=True) + ... print(result.strip()) + 0 4.5 0.5 5 1 5.5 + """ + with GMTTempFile() as tmpfile: + buf.seek(0) # Change stream position back to start + with open(file=tmpfile.name, mode="w") as fdst: + shutil.copyfileobj(fsrc=buf, fdst=fdst) + + yield tmpfile.name diff --git a/pygmt/helpers/utils.py b/pygmt/helpers/utils.py index 5b4bb731baa..5217a7e15fe 100644 --- a/pygmt/helpers/utils.py +++ b/pygmt/helpers/utils.py @@ -1,9 +1,10 @@ """ Utilities and common tasks for wrapping the GMT modules. """ -import sys +import io import shutil import subprocess +import sys import webbrowser from collections.abc import Iterable from contextlib import contextmanager @@ -20,6 +21,7 @@ def data_kind(data, x=None, y=None, z=None): Possible types: * a file name provided as 'data' + * an io.StringIO in-memory stream provided as 'data' * an xarray.DataArray provided as 'data' * a matrix provided as 'data' * 1D arrays x and y (and z, optionally) @@ -29,8 +31,8 @@ def data_kind(data, x=None, y=None, z=None): Parameters ---------- - data : str, xarray.DataArray, 2d array, or None - Data file name, xarray.DataArray or numpy array. + data : str, io.StringIO, xarray.DataArray, 2d array, or None + Data file name, io.StringIO, xarray.DataArray or numpy array. x/y : 1d arrays or None x and y columns as numpy arrays. z : 1d array or None @@ -40,11 +42,13 @@ def data_kind(data, x=None, y=None, z=None): Returns ------- kind : str - One of: ``'file'``, ``'grid'``, ``'matrix'``, ``'vectors'``. + One of: ``'file'``, ``'buffer'``, ``'grid'``, ``'matrix'``, + ``'vectors'``. Examples -------- + >>> import io >>> import numpy as np >>> import xarray as xr >>> data_kind(data=None, x=np.array([1, 2, 3]), y=np.array([4, 5, 6])) @@ -53,6 +57,8 @@ def data_kind(data, x=None, y=None, z=None): 'matrix' >>> data_kind(data='my-data-file.txt', x=None, y=None) 'file' + >>> data_kind(data=io.StringIO("sometext"), x=None, y=None) + 'buffer' >>> data_kind(data=xr.DataArray(np.random.rand(4, 3))) 'grid' @@ -62,10 +68,12 @@ def data_kind(data, x=None, y=None, z=None): if data is not None and (x is not None or y is not None or z is not None): raise GMTInvalidInput("Too much data. Use either data or x and y.") if data is None and (x is None or y is None): - raise GMTInvalidInput("Must provided both x and y.") + raise GMTInvalidInput("Must provide both x and y.") if isinstance(data, str): kind = "file" + elif isinstance(data, io.StringIO): + kind = "buffer" elif isinstance(data, xr.DataArray): kind = "grid" elif data is not None: diff --git a/pygmt/tests/test_legend.py b/pygmt/tests/test_legend.py index 1fa98d6733a..5f33c618302 100644 --- a/pygmt/tests/test_legend.py +++ b/pygmt/tests/test_legend.py @@ -1,6 +1,8 @@ """ Tests for legend """ +import io + import pytest from .. import Figure @@ -45,7 +47,7 @@ def test_legend_default_position(): @pytest.mark.xfail( - reason="Baseline image not updated to use earth relief grid in GMT 6.1.0", + reason="Baseline image not updated to use earth relief grid in GMT 6.1.0" ) @pytest.mark.mpl_image_compare def test_legend_entries(): @@ -72,8 +74,9 @@ def test_legend_entries(): return fig -@pytest.mark.mpl_image_compare -def test_legend_specfile(): +@pytest.mark.parametrize("usebuffer", [True, False]) +@pytest.mark.mpl_image_compare(filename="test_legend_specfile.png") +def test_legend_specfile(usebuffer): """ Test specfile functionality. """ @@ -113,7 +116,10 @@ def test_legend_specfile(): fig = Figure() fig.basemap(projection="x6i", region=[0, 1, 0, 1], frame=True) - fig.legend(specfile.name, position="JTM+jCM+w5i") + + spec = io.StringIO(specfile_contents) if usebuffer else specfile.name + + fig.legend(spec=spec, position="JTM+jCM+w5i") return fig