Skip to content

Commit

Permalink
switch to np.narray types
Browse files Browse the repository at this point in the history
  • Loading branch information
amystamile-usgs committed Apr 23, 2024
1 parent 9eccae5 commit 4f8c09b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 48 deletions.
51 changes: 14 additions & 37 deletions knoten/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,18 @@
from typing import NamedTuple

class Point(NamedTuple):
x: np.double
y: np.double
z: np.double
x: np.ndarray
y: np.ndarray
z: np.ndarray

class LatLon(NamedTuple):
lat: np.double
lon: np.double

lat: np.ndarray
lon: np.ndarray
# np.narray
class Sphere(NamedTuple):
lat: np.double
lon: np.double
radius: np.double

class Matrix(NamedTuple):
vec_a: Point
vec_b: Point
vec_c: Point
lat: np.ndarray
lon: np.ndarray
radius: np.ndarray

def sep_angle(a_vec, b_vec):
"""
Expand All @@ -31,7 +26,7 @@ def sep_angle(a_vec, b_vec):
Returns
-------
: float
: np.ndarray
"""
dot_prod = a_vec.x * b_vec.x + a_vec.y * b_vec.y + a_vec.z * b_vec.z
dot_prod /= magnitude(a_vec) * magnitude(b_vec)
Expand All @@ -49,7 +44,7 @@ def magnitude(vec):
Returns
-------
: float
: np.ndarray
"""
return np.sqrt(vec.x * vec.x + vec.y * vec.y + vec.z * vec.z)

Expand All @@ -63,7 +58,7 @@ def distance(start, stop):
Returns
-------
: float
: np.ndarray
"""
diff = Point(stop.x - start.x, stop.y - start.y, stop.z - start.z)

Expand Down Expand Up @@ -133,7 +128,7 @@ def ground_azimuth(ground_pt, sub_pt):
Returns
-------
: float
: np.ndarray
"""
if (ground_pt.lat >= 0.0):
a = (90.0 - sub_pt.lat) * np.pi / 180.0
Expand Down Expand Up @@ -266,32 +261,14 @@ def scale_vector(vec, scalar):
----------
vec : Point object (x, y, z)
scalar : float
scalar : np.ndarray
Returns
-------
: Point object (x, y, z)
"""
return Point(vec.x * scalar, vec.y * scalar, vec.z * scalar)

def matrix_vec_product(mat, vec):
"""
Parameters
----------
mat : Matrix object (vec_a, vec_b, vec_c)
vec : Point object (x, y, z)
Returns
-------
: Point object (x, y, z)
"""
x = mat.vec_a.x * vec.x + mat.vec_a.y * vec.y + mat.vec_a.z * vec.z
y = mat.vec_b.x * vec.x + mat.vec_b.y * vec.y + mat.vec_b.z * vec.z
z = mat.vec_c.x * vec.x + mat.vec_c.y * vec.y + mat.vec_c.z * vec.z

return Point(x, y, z)

def reproject(record, semi_major, semi_minor, source_proj, dest_proj, **kwargs):
"""
Thin wrapper around PyProj's Transform() function to transform 1 or more three-dimensional
Expand Down
12 changes: 1 addition & 11 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,4 @@ def test_scale_vector():
vec = utils.Point(1.0, 2.0, -3.0)
scalar = 3.0
result = utils.Point(3.0, 6.0, -9.0)
np.testing.assert_array_equal(utils.scale_vector(vec, scalar), result)

def test_matrix_vec_product():
vec_a = utils.Point(0.0, 1.0, 0.0)
vec_b = utils.Point(-1.0, 0.0, 0.0)
vec_c = utils.Point(0.0, 0.0, 1.0)
mat = utils.Matrix(vec_a, vec_b, vec_c)
vec = utils.Point(1.0, 2.0, 3.0)

result = utils.Point(2.0, -1.0, 3.0)
np.testing.assert_array_equal(result, utils.matrix_vec_product(mat, vec))
np.testing.assert_array_equal(utils.scale_vector(vec, scalar), result)

0 comments on commit 4f8c09b

Please sign in to comment.