diff --git a/experiment-newton.ipynb b/experiment-newton.ipynb
new file mode 100644
index 0000000000..863b6e97b2
--- /dev/null
+++ b/experiment-newton.ipynb
@@ -0,0 +1,181 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "ff134dc2-ad8c-41b9-a8da-8cc7b5352b9d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import Callable\n",
+    "import pytensor\n",
+    "import pytensor.tensor as pt\n",
+    "from scipy import linalg\n",
+    "from pytensor.scan.utils import until\n",
+    "from functools import partial"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "759d75fe-6b86-42a5-a9d3-96af6de75053",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def _newton_step(func, x, args):\n",
+    "    f_x = func(x, *args)\n",
+    "    jac = pt.jacobian(f_x, x)\n",
+    "\n",
+    "    # TODO It would be nice to return the factored matrix for the pullback\n",
+    "    # TODO Handle errors of the factorization\n",
+    "    grad = pt.linalg.solve(jac, f_x, assume_a=\"sym\")\n",
+    "\n",
+    "    return f_x, x - grad, grad, jac\n",
+    "\n",
+    "def _check_convergence(f_x, x, new_x, grad, tol):\n",
+    "    # TODO What convergence criterion? Norm of grad etc...\n",
+    "    converged = pt.lt(pt.linalg.norm(f_x, ord=1), tol)\n",
+    "    return converged\n",
+    "\n",
+    "def _scan_step(x, n_steps, *args, func, tol):\n",
+    "    f_x, new_x, grad, jac = _newton_step(func, x, args)\n",
+    "    is_converged = _check_convergence(f_x, x, new_x, grad, tol)\n",
+    "    return (new_x, n_steps + 1, jac), until(is_converged)\n",
+    "\n",
+    "def root(\n",
+    "    func: Callable,\n",
+    "    x0: pt.TensorVariable, # rank 1\n",
+    "    args: tuple[pt.Variable, ...],\n",
+    "    max_iter: int = 113,\n",
+    "    tol: float = 1e-8,\n",
+    ") -> tuple[\n",
+    "    pt.TensorVariable, dict,\n",
+    "]:\n",
+    "    root_func = partial(\n",
+    "        _scan_step,\n",
+    "        func=func,\n",
+    "        tol=tol,\n",
+    "    )\n",
+    "\n",
+    "    outputs, updates = pytensor.scan(\n",
+    "        root_func,\n",
+    "        outputs_info=[x0, pt.constant(0, dtype=\"int64\"), None],\n",
+    "        non_sequences=args,\n",
+    "        n_steps=max_iter,\n",
+    "        strict=True,\n",
+    "    )\n",
+    "\n",
+    "    x_trace, n_steps_trace, jac_trace = outputs\n",
+    "    assert not updates\n",
+    "\n",
+    "    return x_trace[-1], {\"n_steps\": n_steps_trace[-1], \"jac\": jac_trace[-1]}\n",
+    "\n",
+    "\n",
+    "def minimize(cost: Callable, x0: pt.TensorVariable, args):\n",
+    "    def func(x, *args):\n",
+    "        return pt.grad(cost(x, *args), x)\n",
+    "\n",
+    "    return root(func, x0, args)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "21304789-4eab-49de-9db7-a5bb327712b2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "b031e81a-c615-4af5-b2d9-897ee46f15dc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x0 = pt.tensor(\"x0\", shape=(3,))\n",
+    "#x0 = pt.full((3,), [2., 2., 2.])\n",
+    "#x0 = x0.copy()\n",
+    "\n",
+    "mu = pt.tensor(\"mu\", shape=())\n",
+    "\n",
+    "def func(x, mu):\n",
+    "    cost = pt.sum((x ** 2 - mu) ** 2)\n",
+    "    return pt.grad(cost, x)\n",
+    "\n",
+    "\n",
+    "x_root, stats = root(func, x0, args=[mu], tol=1e-8)\n",
+    "\n",
+    "(x_root_dmu,) = pt.grad(x_root[0], [mu])\n",
+    "\n",
+    "f_x = func(x_root, mu)\n",
+    "dfunc_dmu = pt.jacobian(f_x, mu, consider_constant=[x_root])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "0d54e9a4-89ed-4670-b069-ea58bb4e85e5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "func = pytensor.function([x0, mu], [x_root, stats[\"n_steps\"], stats[\"jac\"], dfunc_dmu])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "07747b6d-71ca-4bc3-9546-45e3122890d4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x_root, n_steps, jac, dfunc_dmu_val = func(np.ones(3) * 3, np.full((), 5.))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "2bf94004-465e-4c04-a23a-971c43b637a7",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array([0.2236068, 0.2236068, 0.2236068])"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Derivative of x_root with respect to mu\n",
+    "-linalg.solve(jac, dfunc_dmu_val, assume_a=\"sym\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "dev-cuda",
+   "language": "python",
+   "name": "dev-cuda"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}