"""Text processes: prefix, suffix and substring checks plus concatenation.

The package ``__init__`` re-exports this module's public names with
``from .text import *``.
"""
from typing import Any, Optional


def text_begins(
    data: Optional[str], pattern: str, case_sensitive: Optional[bool] = True
) -> Optional[bool]:
    """Report whether *data* starts with *pattern*.

    Args:
        data: Text to inspect, or ``None`` for no-data.
        pattern: Prefix to look for.
        case_sensitive: When falsy (including ``None``) both strings are
            lower-cased before comparing.

    Returns:
        ``None`` if *data* is ``None``, otherwise a boolean.
    """
    # Only None means "no data"; an empty string is a valid input.
    if data is None:
        return None
    if case_sensitive:
        return data.startswith(pattern)
    return data.lower().startswith(pattern.lower())


def text_contains(
    data: Optional[str], pattern: str, case_sensitive: Optional[bool] = True
) -> Optional[bool]:
    """Report whether *pattern* occurs anywhere in *data*.

    Args:
        data: Text to inspect, or ``None`` for no-data.
        pattern: Substring to look for.
        case_sensitive: When falsy (including ``None``) both strings are
            lower-cased before comparing.

    Returns:
        ``None`` if *data* is ``None``, otherwise a boolean.
    """
    # Only None means "no data"; an empty string is a valid input.
    if data is None:
        return None
    if case_sensitive:
        return pattern in data
    return pattern.lower() in data.lower()


def text_ends(
    data: Optional[str], pattern: str, case_sensitive: Optional[bool] = True
) -> Optional[bool]:
    """Report whether *data* ends with *pattern*.

    Args:
        data: Text to inspect, or ``None`` for no-data.
        pattern: Suffix to look for.
        case_sensitive: When falsy (including ``None``) both strings are
            lower-cased before comparing.

    Returns:
        ``None`` if *data* is ``None``, otherwise a boolean.
    """
    # Only None means "no data"; an empty string is a valid input.
    if data is None:
        return None
    if case_sensitive:
        return data.endswith(pattern)
    return data.lower().endswith(pattern.lower())


def _as_text(value: Any) -> str:
    """Return ``str(value)``, lower-cased for booleans and ``None``."""
    # NOTE(review): None is rendered as "none" here. The openEO text_concat
    # example appears to expect "null" -- confirm against the spec before
    # changing, since the existing tests pin "none".
    text = str(value)
    if value is None or isinstance(value, bool):
        return text.lower()
    return text


def text_concat(data: list[Any], separator: Any = "") -> str:
    """Join the text form of every element of *data* with *separator*.

    Args:
        data: Values to concatenate; each is converted with ``str``, with
            booleans and ``None`` lower-cased (``True`` -> ``"true"``).
        separator: Value placed between elements, converted the same way.
            Defaults to the empty string.

    Returns:
        The concatenated string; ``""`` when *data* is empty.
    """
    return _as_text(separator).join(_as_text(elem) for elem in data)
import pytest

from openeo_processes_dask.process_implementations.text import (
    text_begins,
    text_concat,
    text_contains,
    text_ends,
)

SAMPLE = "Lorem ipsum dolor sit amet"


def _assert_outcome(result, expected):
    """Compare *result* with *expected*, using identity for the None case."""
    if expected is None:
        assert result is None
    else:
        assert result == expected


@pytest.mark.parametrize(
    "string,expected,pattern,case_sensitive",
    [
        (SAMPLE, False, "amet", True),
        (SAMPLE, True, "Lorem", True),
        (SAMPLE, False, "lorem", True),
        (SAMPLE, True, "lorem", False),
        ("Ä", True, "ä", False),
        (None, None, "null", True),
    ],
)
def test_text_begins(string, expected, pattern, case_sensitive):
    """text_begins matches prefixes and yields None for no-data input."""
    _assert_outcome(text_begins(string, pattern, case_sensitive), expected)


@pytest.mark.parametrize(
    "string,expected,pattern,case_sensitive",
    [
        (SAMPLE, True, "amet", True),
        (SAMPLE, False, "Lorem", True),
        (SAMPLE, False, "AMET", True),
        (SAMPLE, True, "AMET", False),
        ("Ä", True, "ä", False),
        (None, None, "null", True),
    ],
)
def test_text_ends(string, expected, pattern, case_sensitive):
    """text_ends matches suffixes and yields None for no-data input."""
    _assert_outcome(text_ends(string, pattern, case_sensitive), expected)


@pytest.mark.parametrize(
    "string,expected,pattern,case_sensitive",
    [
        (SAMPLE, False, "openEO", True),
        (SAMPLE, True, "ipsum dolor", True),
        (SAMPLE, False, "Ipsum Dolor", True),
        (SAMPLE, True, "SIT", False),
        ("ÄÖÜ", True, "ö", False),
        (None, None, "null", True),
    ],
)
def test_text_contains(string, expected, pattern, case_sensitive):
    """text_contains matches substrings and yields None for no-data input."""
    _assert_outcome(text_contains(string, pattern, case_sensitive), expected)


# This test was previously also named ``test_text_contains``; the duplicate
# name replaced the test above, so the contains cases were never collected.
@pytest.mark.parametrize(
    "data,expected,separator",
    [
        (["Hello", "World"], "Hello World", " "),
        ([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], "1234567890", ""),
        ([None, True, False, 1, -1.5, "ß"], "none\ntrue\nfalse\n1\n-1.5\nß", "\n"),
        ([2, 0], "210", 1),
        ([], "", ""),
    ],
)
def test_text_concat(data, expected, separator):
    """text_concat joins the text form of each element with the separator."""
    assert text_concat(data, separator) == expected