Skip to content

Commit

Permalink
Notebook with experimental newton implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Jul 18, 2024
1 parent ad27dc7 commit 767820e
Showing 1 changed file with 181 additions and 0 deletions.
181 changes: 181 additions & 0 deletions experiment-newton.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "ff134dc2-ad8c-41b9-a8da-8cc7b5352b9d",
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable\n",
"import pytensor\n",
"import pytensor.tensor as pt\n",
"from scipy import linalg\n",
"from pytensor.scan.utils import until\n",
"from functools import partial"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "759d75fe-6b86-42a5-a9d3-96af6de75053",
"metadata": {},
"outputs": [],
"source": [
"def _newton_step(func, x, args):\n",
"    \"\"\"Build the symbolic graph for a single Newton iteration.\n",
"\n",
"    Evaluates ``func(x, *args)``, takes its Jacobian with respect to ``x`` and\n",
"    solves ``jac @ step = f_x`` for the Newton step.\n",
"\n",
"    Returns the tuple ``(f_x, x - step, step, jac)``.\n",
"    \"\"\"\n",
"    f_x = func(x, *args)\n",
"    jac = pt.jacobian(f_x, x)\n",
"\n",
"    # TODO It would be nice to return the factored matrix for the pullback\n",
"    # TODO Handle errors of the factorization\n",
"    # Use the general solver: the Jacobian of an arbitrary residual function\n",
"    # is not symmetric. Symmetry only holds when func is itself a gradient\n",
"    # (the Jacobian is then a Hessian), which root() does not require.\n",
"    step = pt.linalg.solve(jac, f_x, assume_a=\"gen\")\n",
"\n",
"    return f_x, x - step, step, jac\n",
"\n",
"def _check_convergence(f_x, x, new_x, grad, tol):\n",
"    \"\"\"Return a symbolic flag that is true once the residual is small enough.\n",
"\n",
"    Only ``f_x`` and ``tol`` are used: the flag is true when the 1-norm of\n",
"    ``f_x`` is below ``tol``. The other arguments are accepted but ignored.\n",
"    \"\"\"\n",
"    # TODO What convergence criterion? Norm of grad etc...\n",
"    residual_l1 = pt.linalg.norm(f_x, ord=1)\n",
"    return pt.lt(residual_l1, tol)\n",
"\n",
"def _scan_step(x, n_steps, *args, func, tol):\n",
"    \"\"\"Inner function for ``pytensor.scan``: perform one Newton update.\n",
"\n",
"    ``x`` and ``n_steps`` are the carried state, ``args`` are the non-sequence\n",
"    inputs forwarded to ``func``. Returns the new state together with the\n",
"    Jacobian of this step, plus an ``until`` condition that stops the scan.\n",
"    \"\"\"\n",
"    residual, x_next, step, jacobian = _newton_step(func, x, args)\n",
"    stop_condition = until(_check_convergence(residual, x, x_next, step, tol))\n",
"    next_outputs = (x_next, n_steps + 1, jacobian)\n",
"    return next_outputs, stop_condition\n",
"\n",
"def root(\n",
"    func: Callable,\n",
"    x0: pt.TensorVariable,  # rank 1\n",
"    args: tuple[pt.Variable, ...],\n",
"    max_iter: int = 113,\n",
"    tol: float = 1e-8,\n",
") -> tuple[\n",
"    pt.TensorVariable, dict,\n",
"]:\n",
"    \"\"\"Symbolically search for a root of ``func(x, *args)`` with Newton steps.\n",
"\n",
"    The iteration runs inside ``pytensor.scan`` for at most ``max_iter`` steps\n",
"    and stops early when the convergence check of a step is satisfied.\n",
"\n",
"    Returns the last iterate and a dict with the entries ``\"n_steps\"`` (last\n",
"    value of the step counter) and ``\"jac\"`` (Jacobian computed in the last\n",
"    executed step).\n",
"    \"\"\"\n",
"    step_fn = partial(_scan_step, func=func, tol=tol)\n",
"\n",
"    # Carried state is the iterate and the step counter; the Jacobian is an\n",
"    # output only, hence the None placeholder.\n",
"    initial_state = [x0, pt.constant(0, dtype=\"int64\"), None]\n",
"\n",
"    traces, updates = pytensor.scan(\n",
"        step_fn,\n",
"        outputs_info=initial_state,\n",
"        non_sequences=args,\n",
"        n_steps=max_iter,\n",
"        strict=True,\n",
"    )\n",
"\n",
"    x_trace, n_steps_trace, jac_trace = traces\n",
"    assert not updates\n",
"\n",
"    stats = {\"n_steps\": n_steps_trace[-1], \"jac\": jac_trace[-1]}\n",
"    return x_trace[-1], stats\n",
"\n",
"\n",
"def minimize(cost: Callable, x0: pt.TensorVariable, args):\n",
"    \"\"\"Find a stationary point of ``cost`` by root-finding on its gradient.\n",
"\n",
"    ``cost`` is called as ``cost(x, *args)`` and must return a scalar. The\n",
"    return value is whatever ``root`` returns for the gradient function.\n",
"    \"\"\"\n",
"    # root() calls its function as func(x, *args), so the gradient wrapper\n",
"    # must accept the extra arguments and pass them on to cost.\n",
"    def grad_func(x, *extra_args):\n",
"        return pt.grad(cost(x, *extra_args), x)\n",
"\n",
"    return root(grad_func, x0, args)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "21304789-4eab-49de-9db7-a5bb327712b2",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b031e81a-c615-4af5-b2d9-897ee46f15dc",
"metadata": {},
"outputs": [],
"source": [
"# Symbolic inputs: a length-3 starting point and a scalar parameter.\n",
"x0 = pt.tensor(\"x0\", shape=(3,))\n",
"# Alternative (disabled) definitions of x0:\n",
"#x0 = pt.full((3,), [2., 2., 2.])\n",
"#x0 = x0.copy()\n",
"\n",
"mu = pt.tensor(\"mu\", shape=())\n",
"\n",
"\n",
"def func(x, mu):\n",
"    \"\"\"Gradient with respect to x of sum((x**2 - mu)**2).\"\"\"\n",
"    residual = x ** 2 - mu\n",
"    objective = pt.sum(residual ** 2)\n",
"    return pt.grad(objective, x)\n",
"\n",
"\n",
"x_root, stats = root(func, x0, args=[mu], tol=1e-8)\n",
"\n",
"# Derivative of the first root component with respect to mu, taken through\n",
"# the scan graph.\n",
"[x_root_dmu] = pt.grad(x_root[0], [mu])\n",
"\n",
"# Partial derivative of func with respect to mu, holding x_root constant.\n",
"f_x = func(x_root, mu)\n",
"dfunc_dmu = pt.jacobian(f_x, mu, consider_constant=[x_root])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "0d54e9a4-89ed-4670-b069-ea58bb4e85e5",
"metadata": {},
"outputs": [],
"source": [
"# Compile the graph; note that this rebinds the name func.\n",
"func = pytensor.function(\n",
"    inputs=[x0, mu],\n",
"    outputs=[x_root, stats[\"n_steps\"], stats[\"jac\"], dfunc_dmu],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "07747b6d-71ca-4bc3-9546-45e3122890d4",
"metadata": {},
"outputs": [],
"source": [
"x_root, n_steps, jac, dfunc_dmu_val = func(\n",
"    np.full(3, 3.0),  # x0\n",
"    np.array(5.0),  # mu\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2bf94004-465e-4c04-a23a-971c43b637a7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.2236068, 0.2236068, 0.2236068])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Derivative of x_root with respect to mu\n",
"-linalg.solve(jac, dfunc_dmu_val, assume_a=\"sym\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "dev-cuda",
"language": "python",
"name": "dev-cuda"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 767820e

Please sign in to comment.