Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Figure.plot & Figure.plot3d: Move common codes into _common.py #3461

Merged
merged 3 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions pygmt/src/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Common functions used in multiple PyGMT functions/methods.
"""

from pathlib import Path
from typing import Any

from pygmt.src.which import which


def _data_geometry_is_point(data: Any, kind: str) -> bool:
"""
Check if the geometry of the input data is Point or MultiPoint.

The inptu data can be a GeoJSON object or a OGR_GMT file.

This function is used in ``Figure.plot`` and ``Figure.plot3d``.

Parameters
----------
data
The data being plotted.
kind
The data kind.

Returns
-------
bool
``True`` if the geometry is Point/MultiPoint, ``False`` otherwise.
"""
if kind == "geojson" and data.geom_type.isin(["Point", "MultiPoint"]).all():
return True
if kind == "file" and str(data).endswith(".gmt"): # OGR_GMT file
try:
with Path(which(data)).open(encoding="utf-8") as file:
line = file.readline()
if "@GMULTIPOINT" in line or "@GPOINT" in line:
return True
except FileNotFoundError:
pass

Check warning on line 40 in pygmt/src/_common.py

View check run for this annotation

Codecov / codecov/patch

pygmt/src/_common.py#L39-L40

Added lines #L39 - L40 were not covered by tests
return False
21 changes: 4 additions & 17 deletions pygmt/src/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
plot - Plot in two dimensions.
"""

from pathlib import Path

from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
Expand All @@ -14,7 +12,7 @@
kwargs_to_strings,
use_alias,
)
from pygmt.src.which import which
from pygmt.src._common import _data_geometry_is_point


@fmt_docstring
Expand Down Expand Up @@ -50,9 +48,7 @@
w="wrap",
)
@kwargs_to_strings(R="sequence", c="sequence_comma", i="sequence_comma", p="sequence")
def plot( # noqa: PLR0912
self, data=None, x=None, y=None, size=None, direction=None, **kwargs
):
def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs):
r"""
Plot lines, polygons, and symbols in 2-D.

Expand Down Expand Up @@ -242,17 +238,8 @@ def plot( # noqa: PLR0912
raise GMTInvalidInput(f"'{name}' can't be 1-D array if 'data' is used.")

# Set the default style if data has a geometry of Point or MultiPoint
if kwargs.get("S") is None:
if kind == "geojson" and data.geom_type.isin(["Point", "MultiPoint"]).all():
kwargs["S"] = "s0.2c"
elif kind == "file" and str(data).endswith(".gmt"): # OGR_GMT file
try:
with Path(which(data)).open(encoding="utf-8") as file:
line = file.readline()
if "@GMULTIPOINT" in line or "@GPOINT" in line:
kwargs["S"] = "s0.2c"
except FileNotFoundError:
pass
if kwargs.get("S") is None and _data_geometry_is_point(data, kind):
kwargs["S"] = "s0.2c"

with Session() as lib:
with lib.virtualfile_in(
Expand Down
19 changes: 4 additions & 15 deletions pygmt/src/plot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
plot3d - Plot in three dimensions.
"""

from pathlib import Path

from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
Expand All @@ -14,7 +12,7 @@
kwargs_to_strings,
use_alias,
)
from pygmt.src.which import which
from pygmt.src._common import _data_geometry_is_point


@fmt_docstring
Expand Down Expand Up @@ -51,7 +49,7 @@
w="wrap",
)
@kwargs_to_strings(R="sequence", c="sequence_comma", i="sequence_comma", p="sequence")
def plot3d( # noqa: PLR0912
def plot3d(
self, data=None, x=None, y=None, z=None, size=None, direction=None, **kwargs
):
r"""
Expand Down Expand Up @@ -218,17 +216,8 @@ def plot3d( # noqa: PLR0912
raise GMTInvalidInput(f"'{name}' can't be 1-D array if 'data' is used.")

# Set the default style if data has a geometry of Point or MultiPoint
if kwargs.get("S") is None:
if kind == "geojson" and data.geom_type.isin(["Point", "MultiPoint"]).all():
kwargs["S"] = "u0.2c"
elif kind == "file" and str(data).endswith(".gmt"): # OGR_GMT file
try:
with Path(which(data)).open(encoding="utf-8") as file:
line = file.readline()
if "@GMULTIPOINT" in line or "@GPOINT" in line:
kwargs["S"] = "u0.2c"
except FileNotFoundError:
pass
if kwargs.get("S") is None and _data_geometry_is_point(data, kind):
kwargs["S"] = "u0.2c"

with Session() as lib:
with lib.virtualfile_in(
Expand Down
Loading