Skip to content

Commit

Permalink
Fix LeastSquares for functions with more than two arguments (#1016)
Browse files Browse the repository at this point in the history
Closes #974 

- Fixes a bug in cost.LeastSquares which prevented the use with data
that has more than two dimensions.
- New unit test checks that cost.LeastSquares works with 3-dimensional
data

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
HDembinski and pre-commit-ci[bot] authored Aug 1, 2024
1 parent cbd136b commit 8022a33
Show file tree
Hide file tree
Showing 10 changed files with 160 additions and 52 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- run: uv pip install --system -v . pytest ${{ matrix.installs }}
# python -m pip install .[test] is not used here to test minimum (faster),
# the cov workflow runs all tests.
- run: python -m pytest
# pip install .[test] is not used here to test minimum (faster)
# cov workflow runs all tests
- run: uv pip install --system . pytest pytest-xdist ${{ matrix.installs }}
- run: python -m pytest -n 3
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@ bench/*.svg
.project
.pydevproject
.settings
.coverage
.coverage*
.ipynb_checkpoints
.eggs
.pytest_cache
.mypy_cache
.ruff_cache
.nox

Untitled*.ipynb
Untitled*.py
Expand Down
49 changes: 28 additions & 21 deletions noxfile.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""
Noxfile for iMinuit.
Noxfile for iminuit.
Use `-R` to instantly reuse an existing environment and
to avoid rebuilding the binary.
Pass extra arguments to pytest after --
"""

import nox
import sys

sys.path.append(".")
import python_releases

nox.needs_version = ">=2024.3.2"
nox.options.default_venv_backend = "uv|virtualenv"
Expand All @@ -15,46 +18,50 @@
"COVERAGE_CORE": "sysmon", # faster coverage on Python 3.12
}

PYPROJECT = nox.project.load_toml("pyproject.toml")
MINIMUM_PYTHON = PYPROJECT["project"]["requires-python"].strip(">=")
LATEST_PYTHON = str(python_releases.latest())

nox.options.sessions = ["test", "mintest", "maxtest"]


@nox.session(reuse_venv=True)
@nox.session()
def test(session: nox.Session) -> None:
"""Run the unit and regular tests."""
"""Run all tests."""
session.install("-e.[test]")
session.run("pytest", *session.posargs, env=ENV)
session.run("pytest", "-n=auto", *session.posargs, env=ENV)


@nox.session(python="3.12", reuse_venv=True)
def maxtest(session: nox.Session) -> None:
"""Run the unit and regular tests."""
session.install("-e.", "scipy", "matplotlib", "pytest", "--pre")
session.run("pytest", *session.posargs, env=ENV)
@nox.session(python=MINIMUM_PYTHON, venv_backend="uv")
def mintest(session: nox.Session) -> None:
"""Run tests on the minimum python version."""
session.install("-e.", "--resolution=lowest-direct")
session.install("pytest", "pytest-xdist")
session.run("pytest", "-n=auto", *session.posargs)


@nox.session(python="3.9", venv_backend="uv")
def mintest(session: nox.Session) -> None:
@nox.session(python=LATEST_PYTHON)
def maxtest(session: nox.Session) -> None:
"""Run the unit and regular tests."""
session.install("-e.", "--resolution=lowest-direct")
session.install("pytest")
session.run("pytest", *session.posargs)
session.install("-e.", "scipy", "matplotlib", "pytest", "pytest-xdist", "--pre")
session.run("pytest", "-n=auto", *session.posargs, env=ENV)


@nox.session(python="pypy3.9", venv_backend="uv")
@nox.session(python="pypy3.9")
def pypy(session: nox.Session) -> None:
"""Run the unit and regular tests."""
session.install("-e.")
session.install("pytest")
session.run("pytest", *session.posargs)
session.install("pytest", "pytest-xdist")
session.run("pytest", "-n=auto", *session.posargs)


# Python-3.12 provides coverage info faster
@nox.session(python="3.12", reuse_venv=True)
@nox.session(python="3.12", venv_backend="uv")
def cov(session: nox.Session) -> None:
"""Run covage and place in 'htmlcov' directory."""
session.install("-e.[test,doc]")
session.run("coverage", "run", "-m", "pytest", env=ENV)
session.run("coverage", "html", "-d", "htmlcov")
session.run("coverage", "html", "-d", "build/htmlcov")
session.run("coverage", "report", "-m")


Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ test = [
"numba; platform_python_implementation=='CPython'",
"numba-stats; platform_python_implementation=='CPython'",
"pytest",
"pytest-xdist",
"scipy",
"tabulate",
"boost_histogram",
Expand Down
63 changes: 63 additions & 0 deletions python_releases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Get the latest Python release which is online."""

import urllib.request
import re
from html.parser import HTMLParser
import gzip
from packaging.version import Version


class PythonVersionParser(HTMLParser):
"""Specialized HTMLParser to get Python version number."""

def __init__(self):
super().__init__()
self.versions = set()
self.found_version = False

def handle_starttag(self, tag, attrs):
"""Look for the right tag and store result in an attribute."""
if tag == "a":
for attr in attrs:
if attr[0] == "href" and "/downloads/release/python-" in attr[1]:
self.found_version = True
return

def handle_data(self, data):
"""Extract Python version from entry."""
if self.found_version:
self.found_version = False
match = re.search(r"Python (\d+\.\d+)", data)
if match:
self.versions.add(Version(match.group(1)))


def versions():
"""Get all Python release versions."""
req = urllib.request.Request("https://www.python.org/downloads/")
req.add_header("Accept-Encoding", "gzip")

with urllib.request.urlopen(req) as response:
raw = response.read()
if response.info().get("Content-Encoding") == "gzip":
raw = gzip.decompress(raw)
html = raw.decode("utf-8")

parser = PythonVersionParser()
parser.feed(html)

return parser.versions


def latest():
"""Return version of latest Python release."""
return max(versions())


def main():
"""Print all discovered release versions."""
print(" ".join(str(x) for x in sorted(versions())))


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion src/iminuit/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -2195,7 +2195,7 @@ def __init__(
y = _norm(y)
assert x.ndim >= 1 # guaranteed by _norm

self._ndim = x.ndim
self._ndim = x.shape[0] if x.ndim > 1 else 1
self._model = model
self._model_grad = grad
self.loss = loss
Expand Down
16 changes: 9 additions & 7 deletions src/iminuit/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
These are used by mypy and similar tools.
"""

from typing import Protocol, Optional, List, Union, runtime_checkable, NamedTuple
from typing import (
Protocol,
Optional,
List,
Union,
runtime_checkable,
NamedTuple,
Annotated,
)
from numpy.typing import NDArray
import numpy as np
import dataclasses
import sys

if sys.version_info < (3, 9):
from typing_extensions import Annotated # noqa pragma: no cover
else:
from typing import Annotated # noqa pragma: no cover


# Key for ValueView, ErrorView, etc.
Expand Down
8 changes: 3 additions & 5 deletions src/iminuit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,15 @@
Collection,
Sequence,
TypeVar,
Annotated,
get_args,
get_origin,
)
import abc
from time import monotonic
import warnings
import sys

if sys.version_info < (3, 9):
from typing_extensions import Annotated, get_args, get_origin # pragma: no cover
else:
from typing import Annotated, get_args, get_origin # pragma: no cover

T = TypeVar("T")

__all__ = (
Expand Down
61 changes: 48 additions & 13 deletions tests/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,34 +1066,69 @@ def model(x, a, b):


def test_LeastSquares_2D():
x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 5.0, 6.0])
z = 1.5 * x + 0.2 * y
ze = 1.5

def model(xy, a, b):
x, y = xy
return a * x + b * y

c = LeastSquares((x, y), z, ze, model, grad=numerical_model_gradient(model))
x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 5.0, 6.0])
f = model((x, y), 1.5, 0.2)
fe = 1.5

c = LeastSquares((x, y), f, fe, model, grad=numerical_model_gradient(model))
assert c.ndata == 3

ref = numerical_cost_gradient(c)
assert_allclose(c.grad(1, 2), ref(1, 2))

assert_equal(c.x, (x, y))
assert_equal(c.y, z)
assert_equal(c.yerror, ze)
assert_equal(c.y, f)
assert_equal(c.yerror, fe)
assert_allclose(c(1.5, 0.2), 0.0)
assert_allclose(c(2.5, 0.2), np.sum(((z - 2.5 * x - 0.2 * y) / ze) ** 2))
assert_allclose(c(1.5, 1.2), np.sum(((z - 1.5 * x - 1.2 * y) / ze) ** 2))
assert_allclose(c(2.5, 0.2), np.sum(((f - 2.5 * x - 0.2 * y) / fe) ** 2))
assert_allclose(c(1.5, 1.2), np.sum(((f - 1.5 * x - 1.2 * y) / fe) ** 2))

c.y = 2 * z
assert_equal(c.y, 2 * z)
c.y = 2 * f
assert_equal(c.y, 2 * f)
c.x = (y, x)
assert_equal(c.x, (y, x))


def test_LeastSquares_3D():
def model(xyz, a, b):
x, y, z = xyz
return a * x + b * y + a * b * z

x = np.array([1.0, 2.0, 3.0, 4.0])
y = np.array([4.0, 5.0, 6.0, 7.0])
z = np.array([7.0, 8.0, 9.0, 10.0])

f = model((x, y, z), 1.5, 0.2)
fe = 1.5

c = LeastSquares((x, y, z), f, fe, model, grad=numerical_model_gradient(model))
assert c.ndata == 4

ref = numerical_cost_gradient(c)
assert_allclose(c.grad(1, 2), ref(1, 2))

assert_equal(c.x, (x, y, z))
assert_equal(c.y, f)
assert_equal(c.yerror, fe)
assert_allclose(c(1.5, 0.2), 0.0)
assert_allclose(
c(2.5, 0.2), np.sum(((f - 2.5 * x - 0.2 * y - 2.5 * 0.2 * z) / fe) ** 2)
)
assert_allclose(
c(1.5, 1.2), np.sum(((f - 1.5 * x - 1.2 * y - 1.5 * 1.2 * z) / fe) ** 2)
)

c.y = 2 * z
assert_equal(c.y, 2 * z)
c.x = (y, x, z)
assert_equal(c.x, (y, x, z))


def test_LeastSquares_bad_input():
with pytest.raises(ValueError, match="shape mismatch"):
LeastSquares([1, 2], [], [1], lambda x, a: 0)
Expand Down Expand Up @@ -1208,7 +1243,7 @@ def line(x, par):
def test_LeastSquares_visualize_2D():
pytest.importorskip("matplotlib")

c = LeastSquares([[1, 2]], [[2, 3]], 0.1, line)
c = LeastSquares([[1, 2], [2, 3]], [1, 2], 0.1, line)

with pytest.raises(ValueError, match="not implemented for multi-dimensional"):
c.visualize((1, 2))
Expand Down
1 change: 1 addition & 0 deletions tests/test_without_ipywidgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


def test_interactive():
pytest.importorskip("matplotlib")
import iminuit

cost = LeastSquares([1.1, 2.2], [3.3, 4.4], 1, lambda x, a: a * x)
Expand Down

0 comments on commit 8022a33

Please sign in to comment.