diff --git a/doc/api/index.rst b/doc/api/index.rst index cff460ce2ff..1c80d8d163f 100644 --- a/doc/api/index.rst +++ b/doc/api/index.rst @@ -317,5 +317,6 @@ Low level access (these are mostly used by the :mod:`pygmt.clib` package): clib.Session.get_libgmt_func clib.Session.virtualfile_from_data clib.Session.virtualfile_from_grid + clib.Session.virtualfile_from_stringio clib.Session.virtualfile_from_matrix clib.Session.virtualfile_from_vectors diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index a0c11373d68..cdfc3bc4963 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -7,6 +7,7 @@ import contextlib import ctypes as ctp +import io import pathlib import sys import warnings @@ -60,6 +61,7 @@ "GMT_IS_PLP", # items could be any one of POINT, LINE, or POLY "GMT_IS_SURFACE", # items are 2-D grid "GMT_IS_VOLUME", # items are 3-D grid + "GMT_IS_TEXT", # Text strings which triggers ASCII text reading ] METHODS = [ @@ -70,6 +72,11 @@ DIRECTIONS = ["GMT_IN", "GMT_OUT"] MODES = ["GMT_CONTAINER_ONLY", "GMT_IS_OUTPUT"] +MODE_MODIFIERS = [ + "GMT_GRID_IS_CARTESIAN", + "GMT_GRID_IS_GEO", + "GMT_WITH_STRINGS", +] REGISTRATIONS = ["GMT_GRID_PIXEL_REG", "GMT_GRID_NODE_REG"] @@ -728,7 +735,7 @@ def create_data( mode_int = self._parse_constant( mode, valid=MODES, - valid_modifiers=["GMT_GRID_IS_CARTESIAN", "GMT_GRID_IS_GEO"], + valid_modifiers=MODE_MODIFIERS, ) geometry_int = self._parse_constant(geometry, valid=GEOMETRIES) registration_int = self._parse_constant(registration, valid=REGISTRATIONS) @@ -1603,6 +1610,100 @@ def virtualfile_from_grid(self, grid): with self.open_virtualfile(*args) as vfile: yield vfile + @contextlib.contextmanager + def virtualfile_from_stringio(self, stringio: io.StringIO): + r""" + Store a :class:`io.StringIO` object in a virtual file. + + Store the contents of a :class:`io.StringIO` object in a GMT_DATASET container + and create a virtual file to pass to a GMT module. + + For simplicity, currently we make following assumptions in the StringIO object + + - ``"#"`` indicates a comment line. + - ``">"`` indicates a segment header. + + Parameters + ---------- + stringio + The :class:`io.StringIO` object containing the data to be stored in the + virtual file. + + Yields + ------ + fname + The name of the virtual file. + + Examples + -------- + >>> import io + >>> from pygmt.clib import Session + >>> # A StringIO object containing legend specifications + >>> stringio = io.StringIO( + ... "# Comment\n" + ... "H 24p Legend\n" + ... "N 2\n" + ... "S 0.1i c 0.15i p300/12 0.25p 0.3i My circle\n" + ... ) + >>> with Session() as lib: + ... with lib.virtualfile_from_stringio(stringio) as fin: + ... lib.virtualfile_to_dataset(vfname=fin, output_type="pandas") + 0 + 0 H 24p Legend + 1 N 2 + 2 S 0.1i c 0.15i p300/12 0.25p 0.3i My circle + """ + # Parse the io.StringIO object. + segments = [] + current_segment = {"header": "", "data": []} + for line in stringio.getvalue().splitlines(): + if line.startswith("#"): # Skip comments + continue + if line.startswith(">"): # Segment header + if current_segment["data"]: # If we have data, start a new segment + segments.append(current_segment) + current_segment = {"header": "", "data": []} + current_segment["header"] = line.strip(">").lstrip() + else: + current_segment["data"].append(line) # type: ignore[attr-defined] + if current_segment["data"]: # Add the last segment if it has data + segments.append(current_segment) + + # One table with one or more segments. + # n_rows is the maximum number of rows/records for all segments. + # n_columns is the number of numeric data columns, so it's 0 here. + n_tables = 1 + n_segments = len(segments) + n_rows = max(len(segment["data"]) for segment in segments) + n_columns = 0 + + # Create the GMT_DATASET container + family, geometry = "GMT_IS_DATASET", "GMT_IS_TEXT" + dataset = self.create_data( + family, + geometry, + mode="GMT_CONTAINER_ONLY|GMT_WITH_STRINGS", + dim=[n_tables, n_segments, n_rows, n_columns], + ) + dataset = ctp.cast(dataset, ctp.POINTER(_GMT_DATASET)) + table = dataset.contents.table[0].contents + for i, segment in enumerate(segments): + seg = table.segment[i].contents + if segment["header"]: + seg.header = segment["header"].encode() # type: ignore[attr-defined] + seg.text = strings_to_ctypes_array(segment["data"]) + + with self.open_virtualfile(family, geometry, "GMT_IN", dataset) as vfile: + try: + yield vfile + finally: + # Must set the pointers to None to avoid double freeing the memory. + # Maybe upstream bug. + for i in range(n_segments): + seg = table.segment[i].contents + seg.header = None + seg.text = None + def virtualfile_in( # noqa: PLR0912 self, check_kind=None, @@ -1696,6 +1797,7 @@ def virtualfile_in( # noqa: PLR0912 "geojson": tempfile_from_geojson, "grid": self.virtualfile_from_grid, "image": tempfile_from_image, + "stringio": self.virtualfile_from_stringio, # Note: virtualfile_from_matrix is not used because a matrix can be # converted to vectors instead, and using vectors allows for better # handling of string type inputs (e.g. for datetime data types) @@ -1704,7 +1806,7 @@ def virtualfile_in( # noqa: PLR0912 }[kind] # Ensure the data is an iterable (Python list or tuple) - if kind in {"geojson", "grid", "image", "file", "arg"}: + if kind in {"geojson", "grid", "image", "file", "arg", "stringio"}: if kind == "image" and data.dtype != "uint8": msg = ( f"Input image has dtype: {data.dtype} which is unsupported, " diff --git a/pygmt/helpers/utils.py b/pygmt/helpers/utils.py index 33ef60b4c98..6585bb7566b 100644 --- a/pygmt/helpers/utils.py +++ b/pygmt/helpers/utils.py @@ -2,6 +2,7 @@ Utilities and common tasks for wrapping the GMT modules. """ +import io import os import pathlib import shutil @@ -188,8 +189,10 @@ def _check_encoding( def data_kind( data: Any = None, required: bool = True -) -> Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"]: - """ +) -> Literal[ + "arg", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors" +]: + r""" Check the kind of data that is provided to a module. The ``data`` argument can be in any type, but only following types are supported: @@ -222,6 +225,7 @@ def data_kind( >>> import numpy as np >>> import xarray as xr >>> import pathlib + >>> import io >>> data_kind(data=None) 'vectors' >>> data_kind(data=np.arange(10).reshape((5, 2))) @@ -240,8 +244,12 @@ def data_kind( 'grid' >>> data_kind(data=xr.DataArray(np.random.rand(3, 4, 5))) 'image' + >>> data_kind(data=io.StringIO("TEXT1\nTEXT23\n")) + 'stringio' """ - kind: Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"] + kind: Literal[ + "arg", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors" + ] if isinstance(data, str | pathlib.PurePath) or ( isinstance(data, list | tuple) and all(isinstance(_file, str | pathlib.PurePath) for _file in data) @@ -250,6 +258,8 @@ def data_kind( kind = "file" elif isinstance(data, bool | int | float) or (data is None and not required): kind = "arg" + elif isinstance(data, io.StringIO): + kind = "stringio" elif isinstance(data, xr.DataArray): kind = "image" if len(data.dims) == 3 else "grid" elif hasattr(data, "__geo_interface__"): diff --git a/pygmt/tests/test_clib_virtualfiles.py b/pygmt/tests/test_clib_virtualfiles.py index b8b5ee0500d..2a966de7c05 100644 --- a/pygmt/tests/test_clib_virtualfiles.py +++ b/pygmt/tests/test_clib_virtualfiles.py @@ -2,6 +2,7 @@ Test the C API functions related to virtual files. """ +import io from importlib.util import find_spec from itertools import product from pathlib import Path @@ -407,3 +408,106 @@ def test_inquire_virtualfile(): ]: with lib.open_virtualfile(family, geometry, "GMT_OUT", None) as vfile: assert lib.inquire_virtualfile(vfile) == lib[family] + + +class TestVirtualfileFromStringIO: + """ + Test the virtualfile_from_stringio method. + """ + + def _stringio_to_dataset(self, data: io.StringIO): + """ + A helper function for check the virtualfile_from_stringio method. + + The function does the following: + + 1. Creates a virtual file from the input StringIO object. + 2. Pass the virtual file to the ``read`` module, which reads the virtual file + and writes it to another virtual file. + 3. Reads the output virtual file as a GMT_DATASET object. + 4. Extracts the header and the trailing text from the dataset and returns it as + a string. + """ + with clib.Session() as lib: + with ( + lib.virtualfile_from_stringio(data) as vintbl, + lib.virtualfile_out(kind="dataset") as vouttbl, + ): + lib.call_module("read", args=[vintbl, vouttbl, "-Td"]) + ds = lib.read_virtualfile(vouttbl, kind="dataset").contents + + output = [] + table = ds.table[0].contents + for segment in table.segment[: table.n_segments]: + seg = segment.contents + output.append(f"> {seg.header.decode()}" if seg.header else ">") + output.extend(np.char.decode(seg.text[: seg.n_rows])) + return "\n".join(output) + "\n" + + def test_virtualfile_from_stringio(self): + """ + Test the virtualfile_from_stringio method. + """ + data = io.StringIO( + "# Comment\n" + "H 24p Legend\n" + "N 2\n" + "S 0.1i c 0.15i p300/12 0.25p 0.3i My circle\n" + ) + expected = ( + ">\n" + "H 24p Legend\n" + "N 2\n" + "S 0.1i c 0.15i p300/12 0.25p 0.3i My circle\n" + ) + assert self._stringio_to_dataset(data) == expected + + def test_one_segment(self): + """ + Test the virtualfile_from_stringio method with one segment. + """ + data = io.StringIO( + "# Comment\n" + "> Segment 1\n" + "1 2 3 ABC\n" + "4 5 DE\n" + "6 7 8 9 FGHIJK LMN OPQ\n" + "RSTUVWXYZ\n" + ) + expected = ( + "> Segment 1\n" + "1 2 3 ABC\n" + "4 5 DE\n" + "6 7 8 9 FGHIJK LMN OPQ\n" + "RSTUVWXYZ\n" + ) + assert self._stringio_to_dataset(data) == expected + + def test_multiple_segments(self): + """ + Test the virtualfile_from_stringio method with multiple segments. + """ + data = io.StringIO( + "# Comment line 1\n" + "# Comment line 2\n" + "> Segment 1\n" + "1 2 3 ABC\n" + "4 5 DE\n" + "6 7 8 9 FG\n" + "# Comment line 3\n" + "> Segment 2\n" + "1 2 3 ABC\n" + "4 5 DE\n" + "6 7 8 9 FG\n" + ) + expected = ( + "> Segment 1\n" + "1 2 3 ABC\n" + "4 5 DE\n" + "6 7 8 9 FG\n" + "> Segment 2\n" + "1 2 3 ABC\n" + "4 5 DE\n" + "6 7 8 9 FG\n" + ) + assert self._stringio_to_dataset(data) == expected