From b3d29bc6e23ca6efc9c2febccdad99d6625224a0 Mon Sep 17 00:00:00 2001 From: lkct Date: Fri, 27 Sep 2024 17:28:23 +0100 Subject: [PATCH 1/3] minor fixes to poly and diff --- cirkit/backend/torch/parameters/nodes.py | 10 ++++-- cirkit/backend/torch/semiring.py | 2 +- cirkit/symbolic/functional.py | 17 +++++---- cirkit/symbolic/layers.py | 44 +++++++++++++++++++----- cirkit/symbolic/operators.py | 3 +- cirkit/symbolic/parameters.py | 10 ++++-- 6 files changed, 63 insertions(+), 23 deletions(-) diff --git a/cirkit/backend/torch/parameters/nodes.py b/cirkit/backend/torch/parameters/nodes.py index c87c366e..ff1bae9c 100644 --- a/cirkit/backend/torch/parameters/nodes.py +++ b/cirkit/backend/torch/parameters/nodes.py @@ -811,10 +811,16 @@ def __init__(self, in_shape: tuple[int, ...], *, num_folds: int = 1, order: int def shape(self) -> tuple[int, ...]: # if dp1>order, i.e., deg>=order, then diff, else const 0. return ( - self.in_shapes[0][0], - self.in_shapes[0][1] - self.order if self.in_shapes[0][1] > self.order else 1, + self.in_shapes[0][0], # dim Ko + self.in_shapes[0][1] - self.order + if self.in_shapes[0][1] > self.order + else 1, # dim dp1 ) + @property + def config(self) -> dict[str, Any]: + return {**super().config, "order": self.order} + @classmethod def _diff_once(cls, x: Tensor) -> Tensor: degp1 = x.shape[-1] # x shape (F, K, dp1). diff --git a/cirkit/backend/torch/semiring.py b/cirkit/backend/torch/semiring.py index 58fa904b..11d37c4f 100644 --- a/cirkit/backend/torch/semiring.py +++ b/cirkit/backend/torch/semiring.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Sequence from typing import ClassVar, TypeVar, cast -from typing_extensions import TypeVarTuple, Unpack, final +from typing_extensions import TypeVarTuple, Unpack, final # FUTURE: in typing from 3.11 import torch from torch import Tensor diff --git a/cirkit/symbolic/functional.py b/cirkit/symbolic/functional.py index b193732d..1b0efef1 100644 --- a/cirkit/symbolic/functional.py +++ b/cirkit/symbolic/functional.py @@ -306,19 +306,19 @@ def differentiate( if order <= 0: raise ValueError("The order of differentiation must be positive.") - # Use the registry in the current context, if not specified otherwise + # Use the registry in the current context, if not specified otherwise. if registry is None: registry = OPERATOR_REGISTRY.get() - # Mapping the symbolic circuit layers with blocks of circuit layers + # Mapping the symbolic circuit layers with blocks of circuit layers. layers_to_blocks: dict[Layer, list[CircuitBlock]] = {} - # For each new circuit block, keep track of its inputs + # For each new circuit block, keep track of its inputs. in_blocks: dict[CircuitBlock, Sequence[CircuitBlock]] = {} for sl in sc.topological_ordering(): # "diff_blocks: List[CircuitBlock]" is the diff of sl wrt each variable and channel in order - # and then at the end we append a copy of sl + # and then at the end we append a copy of sl. sl_params = {name: p.ref() for name, p in sl.params.items()} if isinstance(sl, InputLayer): @@ -348,6 +348,7 @@ def differentiate( # The layers are the same for all diffs of a SumLayer. We retrieve (num_vars * num_chs) # from the length of one input blocks. var_ch = len(layers_to_blocks[sc.layer_inputs(sl)[0]][:-1]) + # TODO: make a shortcut for the copy idiom? 
diff_blocks = [ CircuitBlock.from_layer(type(sl)(**sl.config, **sl_params)) for _ in range(var_ch) ] @@ -433,13 +434,15 @@ def differentiate( # Save all the blocks including a copy of sl at [-1] as the diff layers of sl. layers_to_blocks[sl] = diff_blocks - # Construct the integral symbolic circuit and set the integration operation metadata + # Construct the differential symbolic circuit and set the differentiation operation metadata. return Circuit.from_operation( sc.scope, sc.num_channels, - sum(layers_to_blocks.values(), []), + list(itertools.chain.from_iterable(layers_to_blocks.values())), in_blocks, # TODO: in_blocks uses Sequence, and Sequence should work. - sum((layers_to_blocks[sl] for sl in sc.outputs), []), + itertools.chain.from_iterable( + layers_to_blocks[sl] for sl in sc.outputs + ), # TODO: Iterable should work operation=CircuitOperation( operator=CircuitOperator.DIFFERENTIATION, operands=(sc,), diff --git a/cirkit/symbolic/layers.py b/cirkit/symbolic/layers.py index d9cd7135..a76a1abc 100644 --- a/cirkit/symbolic/layers.py +++ b/cirkit/symbolic/layers.py @@ -2,7 +2,7 @@ from enum import IntEnum, auto from typing import Any, cast -from cirkit.symbolic.initializers import NormalInitializer +from cirkit.symbolic.initializers import Initializer, NormalInitializer from cirkit.symbolic.parameters import ( Parameter, ParameterFactory, @@ -86,6 +86,37 @@ def params(self) -> dict[str, Parameter]: """ return {} + # TODO: apply to other layers + @staticmethod + def _make_param( + param: Parameter | None, + param_factory: ParameterFactory | None, + shape: tuple[int, ...], + default_factory: ParameterFactory = lambda shape: Parameter.from_leaf( + TensorParameter(*shape, initializer=NormalInitializer()) + ), + ) -> Parameter: + """Make a parameter from the optional parameter object or factory or default. + + Args: + param (Optional[Parameter]): The optional parameter provided to a layer. + param_factory (Optional[ParameterFactory]): The optional parameter factory provided to + a layer. + shape (Tuple[int, ...]): The shape of the parameter. + default_factory (ParameterFactory, optional): The factory to use when falling back to + default. Defaults to a new leaf TensorParameter with NormalInitializer. + + Returns: + Parameter: The parameter. 
+ """ + if param is not None: + return param + + if param_factory is not None: + return param_factory(shape) + + return default_factory(shape) + class InputLayer(Layer): """The symbolic input layer class.""" @@ -318,13 +349,7 @@ def __init__( raise ValueError("The Polynomial layer encodes a univariate distribution") super().__init__(scope, num_output_units, num_channels) self.degree = degree - if coeff is None: - if coeff_factory is None: - coeff = Parameter.from_leaf( - TensorParameter(*self._coeff_shape, initializer=NormalInitializer()) - ) - else: - coeff = coeff_factory(self._coeff_shape) + coeff = self._make_param(coeff, coeff_factory, self._coeff_shape) if coeff.shape != self._coeff_shape: raise ValueError(f"Expected parameter shape {self._coeff_shape}, found {coeff.shape}") self.coeff = coeff @@ -334,7 +359,8 @@ def _coeff_shape(self) -> tuple[int, ...]: return self.num_output_units, self.degree + 1 @property - def config(self) -> dict: + def config(self) -> dict[str, Any]: + # FUTURE: use | operator in 3.9 return {**super().config, "degree": self.degree} @property diff --git a/cirkit/symbolic/operators.py b/cirkit/symbolic/operators.py index 0ef72569..8d19649d 100644 --- a/cirkit/symbolic/operators.py +++ b/cirkit/symbolic/operators.py @@ -177,9 +177,8 @@ def multiply_polynomial_layers(sl1: PolynomialLayer, sl2: PolynomialLayer) -> Ci f"but found '{sl1.num_channels}' and '{sl2.num_channels}'" ) - shape1, shape2 = sl1.coeff.shape, sl2.coeff.shape coeff = Parameter.from_binary( - PolynomialProduct(shape1, shape2), + PolynomialProduct(sl1.coeff.shape, sl2.coeff.shape), sl1.coeff.ref(), sl2.coeff.ref(), ) diff --git a/cirkit/symbolic/parameters.py b/cirkit/symbolic/parameters.py index c8a25105..d57de174 100644 --- a/cirkit/symbolic/parameters.py +++ b/cirkit/symbolic/parameters.py @@ -572,6 +572,12 @@ def __init__(self, in_shape: tuple[int, ...], *, order: int = 1): def shape(self) -> tuple[int, ...]: # if dp1>order, i.e., deg>=order, then diff, else const 0. return ( - self.in_shapes[0][0], - self.in_shapes[0][1] - self.order if self.in_shapes[0][1] > self.order else 1, + self.in_shapes[0][0], # dim Ko + self.in_shapes[0][1] - self.order + if self.in_shapes[0][1] > self.order + else 1, # dim dp1 ) + + @property + def config(self) -> dict[str, Any]: + return {**super().config, "order": self.order} From ce1ccaf6ddd502d92ddb6224b2c847da33fb77d0 Mon Sep 17 00:00:00 2001 From: lkct Date: Wed, 18 Sep 2024 18:53:26 +0100 Subject: [PATCH 2/3] notebook for adding poly layer --- notebooks/how-to-add-an-input-layer.ipynb | 606 ++++++++++++++++++++++ 1 file changed, 606 insertions(+) create mode 100644 notebooks/how-to-add-an-input-layer.ipynb diff --git a/notebooks/how-to-add-an-input-layer.ipynb b/notebooks/how-to-add-an-input-layer.ipynb new file mode 100644 index 00000000..c16a7601 --- /dev/null +++ b/notebooks/how-to-add-an-input-layer.ipynb @@ -0,0 +1,606 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# How-To: Add an Input Layer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook introduces how to add a custom input layer. (To add sum/product layers, please open an issue to discuss).\n", + "\n", + "- For users: The code may be added to anywhere in your project, just make sure you have proper imports.\n", + "- For developers: Please look at comments for each code block to decide where to add the code pieces." 
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A new layer requires a symbolic definition and the implementation(s) corresponding to the backend(s)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In this notebook, we will illustrate the process with `MyPolynomialLayer` and its `torch` backend, which replicates `PolynomialLayer` in the library."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Symbolic"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In the symbolic part, we will have to:\n",
+    "- Add the definition of the layer;\n",
+    "- Decide the operators supported by this layer;\n",
+    "  - Identify the parameter operations required by the operators.\n",
+    "\n",
+    "None of the above involves any actual tensors, just the configs and shapes."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can add as many operators as the layer supports, but for illustrative purposes we only cover multiplication here.\n",
+    "\n",
+    "For operators the layer does not support, simply leave them out and they will be properly handled."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Layer"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For the layer definition, we should include any configs it needs, along with any parameters it holds. A parameter can be constructed from an optionally provided parameter object or a factory, falling back to a default: a new parameter with normal initialization (the initializer may be changed via additional args to `_make_param`)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The basic set of methods to define for a layer consists of:\n",
+    "- `__init__`: A must in most cases. Defines how to instantiate this layer.\n",
+    "- `_{param}_shape`: One for each parameter, if any. Specifies the shape of the parameter.\n",
+    "- `config`: Must-have if `__init__` accepts any args other than `scope`, `num_output_units`, `num_channels` and the params. Should append any configs of this layer to `super().config`.\n",
+    "- `params`: Must-have if the layer has any parameters. Includes all params in a dict."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[cirkit/symbolic/layers.py](../cirkit/symbolic/layers.py)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import Any\n",
+    "\n",
+    "from cirkit.symbolic.layers import InputLayer\n",
+    "from cirkit.symbolic.parameters import Parameter, ParameterFactory\n",
+    "from cirkit.utils.scope import Scope\n",
+    "\n",
+    "\n",
+    "class MyPolynomialLayer(InputLayer):\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        scope: Scope,\n",
+    "        num_output_units: int,\n",
+    "        num_channels: int,\n",
+    "        *,\n",
+    "        degree: int,\n",
+    "        coeff: Parameter | None = None,\n",
+    "        coeff_factory: ParameterFactory | None = None,\n",
+    "    ):\n",
+    "        if len(scope) != 1:\n",
+    "            raise ValueError(\"The Polynomial layer encodes a univariate distribution\")\n",
+    "        if num_channels != 1:\n",
+    "            raise ValueError(\"The Polynomial layer encodes a univariate distribution\")\n",
+    "        super().__init__(scope, num_output_units, num_channels)\n",
+    "        self.degree = degree\n",
+    "        coeff = self._make_param(coeff, coeff_factory, self._coeff_shape)\n",
+    "        if coeff.shape != self._coeff_shape:\n",
+    "            raise ValueError(f\"Expected parameter shape {self._coeff_shape}, found {coeff.shape}\")\n",
+    "        self.coeff = coeff\n",
+    "\n",
+    "    @property\n",
+    "    def _coeff_shape(self) -> tuple[int, ...]:\n",
+    "        return self.num_output_units, self.degree + 1\n",
+    "\n",
+    "    @property\n",
+    "    def config(self) -> dict[str, Any]:\n",
+    "        return {**super().config, \"degree\": self.degree}\n",
+    "\n",
+    "    @property\n",
+    "    def params(self) -> dict[str, Parameter]:\n",
+    "        return {\"coeff\": self.coeff}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Parameter Operation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "After deciding which operator(s) we want to support, we must define the parameter operations the operator(s) need(s).\n",
+    "\n",
+    "Since we are only looking at multiplication here, and the layer only has one parameter `coeff`, we only need to define one parameter operation."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As multiplication is a binary operator, we can inherit from `BinaryParameterOp` to make the best use of existing infrastructure. Alternatively, the more general `ParameterOp` class may be inherited.\n",
+    "\n",
+    "The minimum definition should include the `shape` property, which defines the output shape of this parameter operation.\n",
+    "\n",
+    "Optionally, `__init__` can be redefined with customized instantiation behaviour, and `config` should include any additional args of `__init__`."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[cirkit/symbolic/parameters.py](../cirkit/symbolic/parameters.py)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from cirkit.symbolic.parameters import BinaryParameterOp\n",
+    "\n",
+    "\n",
+    "class MyPolynomialProduct(BinaryParameterOp):\n",
+    "    @property\n",
+    "    def shape(self) -> tuple[int, ...]:\n",
+    "        return (\n",
+    "            self.in_shapes[0][0] * self.in_shapes[1][0],  # dim Ko\n",
+    "            self.in_shapes[0][1] + self.in_shapes[1][1] - 1,  # dim deg+1\n",
+    "        )\n",
+    "\n",
+    "    # -------- unnecessary in this case, directly use inherited --------\n",
+    "\n",
+    "    # def __init__(self, in_shape1: tuple[int, ...], in_shape2: tuple[int, ...]):\n",
+    "    #     super().__init__(in_shape1, in_shape2)\n",
+    "\n",
+    "    # @property\n",
+    "    # def config(self) -> dict[str, Any]:\n",
+    "    #     return {}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Layer Operator"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "After the layer and param op have been defined, we can then define how an operator acts on the layer, by defining a rule function and registering it in the rules registry."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In order to share the underlying parameters across the operations, `param.ref()` should be passed to build the new parameter from the operators.\n",
+    "\n",
+    "The resulting new layer (or layers, if needed) should then be wrapped in a `CircuitBlock` for return."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[cirkit/symbolic/operators.py](../cirkit/symbolic/operators.py)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from cirkit.symbolic.circuit import CircuitBlock\n",
+    "from cirkit.symbolic.layers import LayerOperator\n",
+    "from cirkit.symbolic.operators import DEFAULT_OPERATOR_RULES\n",
+    "from cirkit.symbolic.parameters import Parameter\n",
+    "\n",
+    "\n",
+    "def multiply_mypolynomial_layers(sl1: MyPolynomialLayer, sl2: MyPolynomialLayer) -> CircuitBlock:\n",
+    "    if sl1.scope != sl2.scope:\n",
+    "        raise ValueError(\n",
+    "            f\"Expected Polynomial layers to have the same scope,\"\n",
+    "            f\" but found '{sl1.scope}' and '{sl2.scope}'\"\n",
+    "        )\n",
+    "    if sl1.num_channels != sl2.num_channels:\n",
+    "        raise ValueError(\n",
+    "            f\"Expected Polynomial layers to have the same number of channels,\"\n",
+    "            f\" but found '{sl1.num_channels}' and '{sl2.num_channels}'\"\n",
+    "        )\n",
+    "\n",
+    "    coeff = Parameter.from_binary(\n",
+    "        MyPolynomialProduct(sl1.coeff.shape, sl2.coeff.shape),\n",
+    "        sl1.coeff.ref(),\n",
+    "        sl2.coeff.ref(),\n",
+    "    )\n",
+    "\n",
+    "    sl = MyPolynomialLayer(\n",
+    "        sl1.scope,\n",
+    "        sl1.num_output_units * sl2.num_output_units,\n",
+    "        num_channels=sl1.num_channels,\n",
+    "        degree=sl1.degree + sl2.degree,\n",
+    "        coeff=coeff,\n",
+    "    )\n",
+    "    return CircuitBlock.from_layer(sl)\n",
+    "\n",
+    "\n",
+    "DEFAULT_OPERATOR_RULES[LayerOperator.MULTIPLICATION].append(multiply_mypolynomial_layers)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Implementation with Backend"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In the backend implementation, we will have to:\n",
+    "- Implement the actual computation for the layer and operator(s);\n",
+    "- Specify the rule that maps the implementation above to the symbolic layer/operator(s)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "What has been provided in the symbolic part should have a corresponding implementation in the backend, although it is the rules that actually decide whether and how the symbolic representation is translated."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### `torch` Implementation - Layer"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The layer will take in the actual `Tensor` for parameters and input, and should calculate the output `Tensor` in its `forward` function, as is common practice in `torch`."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The basic set of methods to implement for a layer consists of:\n",
+    "- `__init__`: A must in most cases. Defines how to instantiate this layer. Note that `num_folds` is not specified here but handled automatically in the pipeline, while `num_variables` is implicitly provided as `scope_idx.shape[-1]`.\n",
+    "- `_valid_{param}_shape`: Not required but recommended. Checks if the parameter has the correct shape and folding.\n",
+    "- `fold_settings`: A must in most cases. Contains a shape that helps to decide which layers can be folded (same shape can be stacked). Should append any extra shape (from non-default args, as in `config`) that may affect folding, without duplicating existing entries.\n",
+    "- `config`: Must-have if `__init__` accepts any args other than `scope_idx`, `num_output_units`, `num_channels`, `semiring` and the params. Should append any configs of the layer to `super().config`.\n",
+    "- `params`: Must-have if the layer has any parameters. Includes all params in a dict.\n",
+    "- `forward`: Must-have in all cases. Defines the actual computation of this layer. It must follow the protocol:\n",
+    "  - The input is the value that the circuit receives, sliced to the corresponding scope, with shape `(fold, channel, batch, variable)`. Note that for simplicity layers always accept only one batch dimension, while multiple batch dims are handled at the circuit level;\n",
+    "  - The output is the value in the space defined by the specified semiring, with shape `(fold, batch, output_unit)`.\n",
+    "- `integrate`: TODO: why we need it? what's the protocol?"
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[cirkit/backend/torch/layers/input.py](../cirkit/backend/torch/layers/input.py)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from torch import Tensor\n", + "\n", + "from cirkit.backend.torch.layers.input import TorchInputLayer, polyval\n", + "from cirkit.backend.torch.parameters.parameter import TorchParameter\n", + "from cirkit.backend.torch.semiring import Semiring, SumProductSemiring\n", + "\n", + "# The same implementation of the imported polyval.\n", + "\n", + "# def polyval(coeff: Tensor, x: Tensor) -> Tensor:\n", + "# \"\"\"Evaluate polynomial given coefficients and point, with the shape for PolynomialLayer.\n", + "\n", + "# Args:\n", + "# coeff (Tensor): The coefficients of the polynomial, shape (F, Ko, deg+1).\n", + "# x (Tensor): The point of the variable, shape (F, H, B, Ki), where H=Ki=1.\n", + "\n", + "# Returns:\n", + "# Tensor: The value of the polymonial, shape (F, B, Ko).\n", + "# \"\"\"\n", + "# x = x.squeeze(dim=1) # shape (F, H=1, B, Ki=1) -> (F, B, 1).\n", + "# y = x.new_zeros(*x.shape[:-1], coeff.shape[-2]) # shape (F, B, Ko).\n", + "\n", + "# for a_n in reversed(coeff.unbind(dim=2)): # Reverse iterator of the degree axis, shape (F, Ko).\n", + "# # a_n shape (F, Ko) -> (F, 1, Ko).\n", + "# y = torch.addcmul(a_n.unsqueeze(dim=1), x, y) # y = a_n + x * y, by Horner's method.\n", + "# return y # shape (F, B, Ko).\n", + "\n", + "\n", + "class TorchMyPolynomialLayer(TorchInputLayer):\n", + " def __init__(\n", + " self,\n", + " scope_idx: Tensor,\n", + " num_output_units: int,\n", + " *,\n", + " num_channels: int = 1,\n", + " degree: int,\n", + " coeff: TorchParameter,\n", + " semiring: Semiring | None = None,\n", + " ) -> None:\n", + " num_variables = scope_idx.shape[-1]\n", + " if num_variables != 1:\n", + " raise ValueError(\"The Polynomial layer encodes a univariate distribution\")\n", + " if num_channels != 1:\n", + " raise ValueError(\"The Polynomial layer encodes a univariate distribution\")\n", + " super().__init__(\n", + " scope_idx,\n", + " num_output_units,\n", + " num_channels=num_channels,\n", + " semiring=semiring,\n", + " )\n", + " self.degree = degree\n", + " if not self._valid_parameters_shape(coeff):\n", + " raise ValueError(\"The number of folds and shape of 'coeff' must match the layer's\")\n", + " self.coeff = coeff\n", + "\n", + " def _valid_parameters_shape(self, p: TorchParameter) -> bool:\n", + " if p.num_folds != self.num_folds:\n", + " return False\n", + " return p.shape == (self.num_output_units, self.degree + 1)\n", + "\n", + " @property\n", + " def fold_settings(self) -> tuple[Any, ...]:\n", + " return *super().fold_settings, self.degree + 1\n", + "\n", + " @property\n", + " def config(self) -> dict[str, Any]:\n", + " return {**super().config, \"degree\": self.degree}\n", + "\n", + " @property\n", + " def params(self) -> dict[str, TorchParameter]:\n", + " return {\"coeff\": self.coeff}\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " \"\"\"Run forward pass.\n", + "\n", + " Args:\n", + " x (Tensor): The input to this layer, shape (F, H=C, B, Ki=D).\n", + "\n", + " Returns:\n", + " Tensor: The output of this layer, shape (F, B, Ko).\n", + " \"\"\"\n", + " coeff = self.coeff() # shape (F, Ko, dp1)\n", + " return self.semiring.map_from(polyval(coeff, x), SumProductSemiring)\n", + "\n", + " def integrate(self) -> Tensor:\n", + " raise TypeError(\"Cannot integrate a PolynomialLayer\")" + ] + }, 
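+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check (this cell is illustrative only and not part of the library), we can call the `polyval` imported above and verify both the shape protocol and the Horner evaluation: `p(x) = 1 + 2x + 3x^2` at `x = 2` should give `17`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "\n",
+    "coeff = torch.tensor([[[1.0, 2.0, 3.0]]])  # Coefficients of 1 + 2x + 3x^2, shape (F=1, Ko=1, deg+1=3).\n",
+    "x = torch.full((1, 1, 1, 1), 2.0)  # The point x = 2, shape (F=1, H=1, B=1, Ki=1).\n",
+    "y = polyval(coeff, x)  # shape (F=1, B=1, Ko=1).\n",
+    "assert y.shape == (1, 1, 1)\n",
+    "assert torch.allclose(y, torch.tensor(17.0))  # 1 + 2*2 + 3*4 = 17."
+   ]
+  },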
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### `torch` Implementation - Operator"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The `torch` implementation of operators also provides a `TorchBinaryParameterOp` for easier implementation, and a more general `TorchParameterOp` for more customization.\n",
+    "\n",
+    "The minimal implementation can include only the `shape` of the output parameter, and the `forward` that transforms the input parameter(s) to the output.\n",
+    "\n",
+    "Optionally, `__init__` may be defined to customize instantiation, with `config` containing additional args and `fold_settings` containing any additional shapes."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[cirkit/backend/torch/parameters/nodes.py](../cirkit/backend/torch/parameters/nodes.py)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "\n",
+    "from cirkit.backend.torch.parameters.nodes import TorchBinaryParameterOp\n",
+    "\n",
+    "\n",
+    "class TorchMyPolynomialProduct(TorchBinaryParameterOp):\n",
+    "    @property\n",
+    "    def shape(self) -> tuple[int, ...]:\n",
+    "        return (\n",
+    "            self.in_shapes[0][0] * self.in_shapes[1][0],  # dim K\n",
+    "            self.in_shapes[0][1] + self.in_shapes[1][1] - 1,  # dim dp1\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, coeff1: Tensor, coeff2: Tensor) -> Tensor:\n",
+    "        if coeff1.is_complex() or coeff2.is_complex():\n",
+    "            fft = torch.fft.fft\n",
+    "            ifft = torch.fft.ifft\n",
+    "        else:\n",
+    "            fft = torch.fft.rfft\n",
+    "            ifft = torch.fft.irfft\n",
+    "\n",
+    "        degp1 = coeff1.shape[-1] + coeff2.shape[-1] - 1  # deg1p1 + deg2p1 - 1 = (deg1 + deg2) + 1.\n",
+    "\n",
+    "        spec1 = fft(coeff1, n=degp1, dim=-1)  # shape (F, K1, dp1).\n",
+    "        spec2 = fft(coeff2, n=degp1, dim=-1)  # shape (F, K2, dp1).\n",
+    "\n",
+    "        # shape (F, K1, 1, dp1), (F, 1, K2, dp1) -> (F, K1, K2, dp1) -> (F, K1*K2, dp1).\n",
+    "        spec = torch.flatten(\n",
+    "            spec1.unsqueeze(dim=2) * spec2.unsqueeze(dim=1), start_dim=1, end_dim=2\n",
+    "        )\n",
+    "\n",
+    "        return ifft(spec, n=degp1, dim=-1)  # shape (F, K1*K2, dp1).\n",
+    "\n",
+    "    # -------- unnecessary in this case, directly use inherited --------\n",
+    "\n",
+    "    # def __init__(\n",
+    "    #     self, in_shape1: tuple[int, ...], in_shape2: tuple[int, ...], *, num_folds: int = 1\n",
+    "    # ) -> None:\n",
+    "    #     super().__init__(in_shape1, in_shape2, num_folds=num_folds)\n",
+    "\n",
+    "    # @property\n",
+    "    # def fold_settings(self) -> tuple[Any, ...]:\n",
+    "    #     return super().fold_settings\n",
+    "\n",
+    "    # @property\n",
+    "    # def config(self) -> dict[str, Any]:\n",
+    "    #     return {}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Rules"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we need to register the mapping from the torch implementations to their symbolic counterparts. This should be simple to define in most cases.\n",
+    "\n",
+    "Note that each backend has its own registry instead of one large dict for everything."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[cirkit/backend/torch/rules/layers.py](../cirkit/backend/torch/rules/layers.py)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from cirkit.backend.torch.compiler import TorchCompiler\n",
+    "from cirkit.backend.torch.rules.layers import DEFAULT_LAYER_COMPILATION_RULES\n",
+    "\n",
+    "\n",
+    "def compile_polynomial_layer(\n",
+    "    compiler: TorchCompiler, sl: MyPolynomialLayer\n",
+    ") -> TorchMyPolynomialLayer:\n",
+    "    coeff = compiler.compile_parameter(sl.coeff)\n",
+    "    return TorchMyPolynomialLayer(\n",
+    "        torch.tensor(tuple(sl.scope)),\n",
+    "        sl.num_output_units,\n",
+    "        num_channels=sl.num_channels,\n",
+    "        degree=sl.degree,\n",
+    "        coeff=coeff,\n",
+    "        semiring=compiler.semiring,\n",
+    "    )\n",
+    "\n",
+    "\n",
+    "DEFAULT_LAYER_COMPILATION_RULES.update({MyPolynomialLayer: compile_polynomial_layer})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[cirkit/backend/torch/rules/parameters.py](../cirkit/backend/torch/rules/parameters.py)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from cirkit.backend.torch.rules.parameters import DEFAULT_PARAMETER_COMPILATION_RULES\n",
+    "\n",
+    "\n",
+    "def compile_polynomial_product(\n",
+    "    compiler: TorchCompiler, p: MyPolynomialProduct\n",
+    ") -> TorchMyPolynomialProduct:\n",
+    "    return TorchMyPolynomialProduct(*p.in_shapes)\n",
+    "\n",
+    "\n",
+    "DEFAULT_PARAMETER_COMPILATION_RULES.update({MyPolynomialProduct: compile_polynomial_product})"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "cirkit",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.14"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

From 00c07f061325293da37a2d93b93db5b02f49d1bb Mon Sep 17 00:00:00 2001
From: lkct
Date: Fri, 27 Sep 2024 18:57:20 +0100
Subject: [PATCH 3/3] notebook for adding diff op

---
 notebooks/how-to-add-an-operator.ipynb | 589 +++++++++++++++++++++++++
 1 file changed, 589 insertions(+)
 create mode 100644 notebooks/how-to-add-an-operator.ipynb

diff --git a/notebooks/how-to-add-an-operator.ipynb b/notebooks/how-to-add-an-operator.ipynb
new file mode 100644
index 00000000..52a6e68c
--- /dev/null
+++ b/notebooks/how-to-add-an-operator.ipynb
@@ -0,0 +1,589 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# How-To: Add an Operator"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This notebook introduces how to add a custom operator.\n",
+    "\n",
+    "- For users: This cannot be done purely as an end-user; by design, modifications to the library code are required.\n",
+    "- For developers: Please look at the comments for each code block to decide where to add/modify the code pieces."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A new operator requires symbolic definition(s) for each layer supporting it and the implementation(s) corresponding to the backend(s)."
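+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a reference point for the end result, the existing `DIFFERENTIATION` operator is exposed as `cirkit.symbolic.functional.differentiate`. The sketch below is illustrative only (`second_derivative` is a hypothetical helper, not part of the library) and shows the kind of call we will be replicating:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import cirkit.symbolic.functional as SF\n",
+    "from cirkit.symbolic.circuit import Circuit\n",
+    "\n",
+    "\n",
+    "def second_derivative(sc: Circuit) -> Circuit:\n",
+    "    # `differentiate` requires a smooth and decomposable symbolic circuit.\n",
+    "    return SF.differentiate(sc, order=2)"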
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In this notebook, we will illustrate the process with `MYDIFFERENTIATION` and its `torch` backend, which replicates `DIFFERENTIATION` in the library."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Enums"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "First of all, we need to register the new operator in two `Enum` classes: `CircuitOperator` and `LayerOperator`.\n",
+    "\n",
+    "This can only be done by modifying the corresponding classes in the library; it is not possible from user code."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In the rare case that an operator on circuits does not involve any transformation of the layers (e.g. `CircuitOperator.CONCATENATE`), it may skip the `LayerOperator` entry and the implementation of the operator on the layers."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[cirkit/symbolic/circuit.py](../cirkit/symbolic/circuit.py)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from enum import IntEnum, auto\n",
+    "\n",
+    "\n",
+    "class CircuitOperator(IntEnum):\n",
+    "    ...  # Any existing enum values.\n",
+    "    MYDIFFERENTIATION = auto()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[cirkit/symbolic/layers.py](../cirkit/symbolic/layers.py)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class LayerOperator(IntEnum):\n",
+    "    ...  # Any existing enum values.\n",
+    "    MYDIFFERENTIATION = auto()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Symbolic"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In the symbolic part, we will have to:\n",
+    "- Decide the layers that support this operator;\n",
+    "  - Identify the parameter operations required by the operator;\n",
+    "- Define the process to operate on a circuit via its symbolic representation.\n",
+    "\n",
+    "None of the above involves any actual tensors, just the configs and shapes."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can add as many layers as the operator supports, but for illustrative purposes we only cover the library's `PolynomialLayer` here.\n",
+    "\n",
+    "For layers that do not support the operator, simply leave them out and they will be properly handled."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Parameter Operation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "After deciding which layer(s) we want to support, we must define the parameter operations the layer(s) need(s).\n",
+    "\n",
+    "Since we are only looking at `PolynomialLayer` here, and the layer only has one parameter `coeff`, we only need to define one parameter operation."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As differentiation is a unary operator, we can inherit from `UnaryParameterOp` to make the best use of existing infrastructure. Alternatively, the more general `ParameterOp` class may be inherited.\n",
+    "\n",
+    "The minimum definition should include the `shape` property, which defines the output shape of this parameter operation.\n",
+    "\n",
+    "In this case, we also need `__init__`, because differentiation takes an additional argument `order`; `config` should also be overridden to include `order`."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[cirkit/symbolic/parameters.py](../cirkit/symbolic/parameters.py)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import Any\n",
+    "\n",
+    "from cirkit.symbolic.parameters import UnaryParameterOp\n",
+    "\n",
+    "\n",
+    "class PolynomialMyDifferential(UnaryParameterOp):\n",
+    "    def __init__(self, in_shape: tuple[int, ...], *, order: int = 1):\n",
+    "        if order <= 0:\n",
+    "            raise ValueError(\"The order of differentiation must be positive.\")\n",
+    "        super().__init__(in_shape)\n",
+    "        self.order = order\n",
+    "\n",
+    "    @property\n",
+    "    def shape(self) -> tuple[int, ...]:\n",
+    "        # if dp1>order, i.e., deg>=order, then diff, else const 0.\n",
+    "        return (\n",
+    "            self.in_shapes[0][0],  # dim Ko\n",
+    "            self.in_shapes[0][1] - self.order\n",
+    "            if self.in_shapes[0][1] > self.order\n",
+    "            else 1,  # dim dp1\n",
+    "        )\n",
+    "\n",
+    "    @property\n",
+    "    def config(self) -> dict[str, Any]:\n",
+    "        return {**super().config, \"order\": self.order}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Layer Operator"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "After the param op has been defined, we can then define how the operator acts on the layer, by defining a rule function and registering it in the rules registry."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In order to share the underlying parameters across the operations, `param.ref()` should be passed to build the new parameter from the operators.\n",
+    "\n",
+    "The resulting new layer (or layers, if needed) should then be wrapped in a `CircuitBlock` for return."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[cirkit/symbolic/operators.py](../cirkit/symbolic/operators.py)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from cirkit.symbolic.circuit import CircuitBlock\n",
+    "from cirkit.symbolic.layers import PolynomialLayer\n",
+    "from cirkit.symbolic.operators import DEFAULT_OPERATOR_RULES\n",
+    "from cirkit.symbolic.parameters import Parameter\n",
+    "\n",
+    "\n",
+    "def my_differentiate_polynomial_layer(\n",
+    "    sl: PolynomialLayer, *, var_idx: int, ch_idx: int, order: int = 1\n",
+    ") -> CircuitBlock:\n",
+    "    # PolynomialLayer is constructed univariate, but we still take the two indices for a unified interface.\n",
+    "    assert (var_idx, ch_idx) == (0, 0), \"This should not happen\"\n",
+    "    if order <= 0:\n",
+    "        raise ValueError(\"The order of differentiation must be positive.\")\n",
+    "    coeff = Parameter.from_unary(\n",
+    "        PolynomialMyDifferential(sl.coeff.shape, order=order), sl.coeff.ref()\n",
+    "    )\n",
+    "    sl = PolynomialLayer(\n",
+    "        sl.scope, sl.num_output_units, sl.num_channels, degree=coeff.shape[-1] - 1, coeff=coeff\n",
+    "    )\n",
+    "    return CircuitBlock.from_layer(sl)\n",
+    "\n",
+    "\n",
+    "DEFAULT_OPERATOR_RULES[LayerOperator.MYDIFFERENTIATION].append(my_differentiate_polynomial_layer)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Operation on Symbolic Circuit"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To implement the operator in a symbolic way, we need to define a function that takes in the symbolic circuit(s) and an optional custom registry, along with any other args the operator needs. The function should return the resulting circuit after applying the operator, with proper parameter sharing."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In the following code, we omit many algorithmic details and focus on the coding points to note. Please read the comments for how each piece is expected to be defined."
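+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "One recurring idiom in the function below is copying a layer while sharing its parameters via `ref()`. A small helper capturing it (illustrative only; `copy_layer_with_shared_params` is not part of the library) would look like this:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from cirkit.symbolic.layers import Layer\n",
+    "\n",
+    "\n",
+    "def copy_layer_with_shared_params(sl: Layer) -> Layer:\n",
+    "    # Rebuild the layer from its own config; p.ref() makes the copy share the\n",
+    "    # underlying parameters instead of duplicating them.\n",
+    "    return type(sl)(**sl.config, **{name: p.ref() for name, p in sl.params.items()})"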
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[cirkit/symbolic/functional.py](../cirkit/symbolic/functional.py)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import itertools\n", + "from collections.abc import Sequence\n", + "\n", + "from cirkit.symbolic.circuit import Circuit, CircuitOperation, StructuralPropertyError\n", + "from cirkit.symbolic.layers import InputLayer, Layer, SumLayer\n", + "from cirkit.symbolic.registry import OPERATOR_REGISTRY, OperatorRegistry\n", + "\n", + "\n", + "# sc (or multiple if n-ary op) and registry is common interface, order is specific to diff.\n", + "def my_differentiate(\n", + " sc: Circuit, registry: OperatorRegistry | None = None, *, order: int = 1\n", + ") -> Circuit:\n", + " # Sanity checks of args.\n", + " if not sc.is_smooth or not sc.is_decomposable:\n", + " raise StructuralPropertyError(\n", + " \"Only smooth and decomposable circuits can be efficiently differentiated.\"\n", + " )\n", + " if order <= 0:\n", + " raise ValueError(\"The order of differentiation must be positive.\")\n", + "\n", + " # Use the registry in the current context, if not specified otherwise.\n", + " if registry is None:\n", + " registry = OPERATOR_REGISTRY.get()\n", + "\n", + " # Keep a mapping from the layers in the input circuit to the blocks in the output circuit.\n", + " # Depending on the algorithm, another form of mapping may be used.\n", + " layers_to_blocks: dict[Layer, list[CircuitBlock]] = {}\n", + "\n", + " # The directed edges connecting the blocks in the output circuit, as a mapping from each block\n", + " # to its inputs. This must be defined in this form to construct the output circuit.\n", + " in_blocks: dict[CircuitBlock, Sequence[CircuitBlock]] = {}\n", + "\n", + " # Iterate all the symbolic layers in the input circuit in the topological order.\n", + " for sl in sc.topological_ordering():\n", + " # For an InputLayer, e.g. 
PolynomialLayer, the rule should exist in the registry.\n",
+    "        if isinstance(sl, InputLayer):\n",
+    "            # Retrieve the differentiation rule from the registry.\n",
+    "            func = registry.retrieve_rule(LayerOperator.MYDIFFERENTIATION, type(sl))\n",
+    "            # Get the differential using the rule.\n",
+    "            diff_blocks = [func(sl, var_idx=0, ch_idx=0, order=order)]\n",
+    "\n",
+    "            # Save the blocks as corresponding to the current symbolic layer.\n",
+    "            layers_to_blocks[sl] = diff_blocks\n",
+    "\n",
+    "            # The update to in_blocks can be omitted: blocks that are absent are treated as having no inputs.\n",
+    "            # in_blocks[diff_blocks[0]] = []\n",
+    "\n",
+    "        # For a SumLayer, the original connectivity and params are copied in the differential.\n",
+    "        elif isinstance(sl, SumLayer):\n",
+    "            # An idiom to make a copy of a layer and keep the params shared.\n",
+    "            diff_blocks = [\n",
+    "                CircuitBlock.from_layer(\n",
+    "                    type(sl)(**sl.config, **{name: p.ref() for name, p in sl.params.items()})\n",
+    "                )\n",
+    "            ]\n",
+    "            # Each item corresponds to the respective item in diff_blocks, meaning its inputs are the blocks\n",
+    "            # corresponding to the inputs of the current layer.\n",
+    "            diff_in_blocks = [\n",
+    "                [layers_to_blocks[sl_in][0] for sl_in in sc.layer_inputs(sl)],\n",
+    "            ]\n",
+    "\n",
+    "            # Save the blocks as corresponding to the current symbolic layer.\n",
+    "            layers_to_blocks[sl] = diff_blocks\n",
+    "\n",
+    "            # Update in_blocks with the blocks and the corresponding inputs.\n",
+    "            in_blocks.update(zip(diff_blocks, diff_in_blocks))\n",
+    "\n",
+    "        # There can be other cases, processed as needed.\n",
+    "        else:\n",
+    "            pass\n",
+    "\n",
+    "    # End of `for sl in sc.topological_ordering():`\n",
+    "\n",
+    "    # Construct the differential symbolic circuit and set the differentiation operation metadata.\n",
+    "    return Circuit.from_operation(\n",
+    "        sc.scope,  # The scope is the same, or may change based on the algorithm.\n",
+    "        sc.num_channels,  # Channels shall not change in most cases.\n",
+    "        itertools.chain.from_iterable(layers_to_blocks.values()),  # All the blocks constructed.\n",
+    "        in_blocks,  # The edges as recorded.\n",
+    "        itertools.chain.from_iterable(layers_to_blocks[sl] for sl in sc.outputs),  # Idiom.\n",
+    "        operation=CircuitOperation(  # Metadata of the operation.\n",
+    "            operator=CircuitOperator.MYDIFFERENTIATION,  # The Enum value for the op.\n",
+    "            operands=(sc,),  # The operands: anything passed into this operator function.\n",
+    "            metadata=dict(order=order),  # Any additional args of the operator.\n",
+    "        ),\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Implementation with Backend"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In the backend implementation, we will have to:\n",
+    "- Implement the actual computation for the layer and operator(s);\n",
+    "- Specify the rule that maps the implementation above to the symbolic layer/operator(s)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "What has been provided in the symbolic part should have a corresponding implementation in the backend, although it is the rules that actually decide whether and how the symbolic representation is translated."
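+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before writing the backend node, it may help to sanity-check the coefficient rule it must compute: differentiation drops `a_0` and scales the remaining coefficients, e.g. the derivative of `1 + 2x + 3x^2` is `2 + 6x`. A standalone sketch (illustrative only, not part of the library):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "\n",
+    "coeff = torch.tensor([[[1.0, 2.0, 3.0]]])  # p(x) = 1 + 2x + 3x^2, shape (F, K, deg+1).\n",
+    "arange = torch.arange(1, coeff.shape[-1]).to(coeff)  # (1, 2): scales for a_n x^n -> n a_n x^(n-1).\n",
+    "dcoeff = coeff[..., 1:] * arange  # Drop a_0 and scale: p'(x) = 2 + 6x.\n",
+    "assert torch.equal(dcoeff, torch.tensor([[[2.0, 6.0]]]))"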
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### `torch` Implementation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The `torch` implementation of operators also provides a `TorchUnaryParameterOp` for easier implementation, and a more general `TorchParameterOp` for more customization.\n",
+    "\n",
+    "The minimal implementation can include only the `shape` of the output parameter, and the `forward` that transforms the input parameter(s) to the output.\n",
+    "\n",
+    "In this case, we also need `__init__`, because differentiation takes an additional argument `order`; `config` should also be overridden to include `order`.\n",
+    "\n",
+    "Optionally, `fold_settings` can be provided to contain any additional shapes that affect folding (here the default is enough, as the input shape decides everything)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[cirkit/backend/torch/parameters/nodes.py](../cirkit/backend/torch/parameters/nodes.py)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from torch import Tensor\n",
+    "\n",
+    "from cirkit.backend.torch.parameters.nodes import TorchUnaryParameterOp\n",
+    "\n",
+    "\n",
+    "class TorchPolynomialMyDifferential(TorchUnaryParameterOp):\n",
+    "    def __init__(self, in_shape: tuple[int, ...], *, num_folds: int = 1, order: int = 1) -> None:\n",
+    "        if order <= 0:\n",
+    "            raise ValueError(\"The order of differentiation must be positive.\")\n",
+    "        super().__init__(in_shape, num_folds=num_folds)\n",
+    "        self.order = order\n",
+    "\n",
+    "    @property\n",
+    "    def shape(self) -> tuple[int, ...]:\n",
+    "        # if dp1>order, i.e., deg>=order, then diff, else const 0.\n",
+    "        return (\n",
+    "            self.in_shapes[0][0],  # dim Ko\n",
+    "            self.in_shapes[0][1] - self.order\n",
+    "            if self.in_shapes[0][1] > self.order\n",
+    "            else 1,  # dim dp1\n",
+    "        )\n",
+    "\n",
+    "    @property\n",
+    "    def config(self) -> dict[str, Any]:\n",
+    "        return {**super().config, \"order\": self.order}\n",
+    "\n",
+    "    def forward(self, coeff: Tensor) -> Tensor:\n",
+    "        if coeff.shape[-1] <= self.order:\n",
+    "            return torch.zeros_like(coeff[..., :1])  # shape (F, K, 1).\n",
+    "\n",
+    "        for _ in range(self.order):\n",
+    "            degp1 = coeff.shape[-1]  # shape (F, K, dp1).\n",
+    "            arange = torch.arange(1, degp1).to(coeff)  # shape (deg,).\n",
+    "            coeff = coeff[..., 1:] * arange  # a_n x^n -> n a_n x^(n-1), with a_0 dropped.\n",
+    "\n",
+    "        return coeff  # shape (F, K, dp1-ord).\n",
+    "\n",
+    "    # -------- unnecessary in this case, directly use inherited --------\n",
+    "\n",
+    "    # @property\n",
+    "    # def fold_settings(self) -> tuple[Any, ...]:\n",
+    "    #     return super().fold_settings"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Rules"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we need to register the mapping from the torch implementations to their symbolic counterparts. This should be simple to define in most cases.\n",
+    "\n",
+    "Note that each backend has its own registry instead of one large dict for everything."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[cirkit/backend/torch/rules/parameters.py](../cirkit/backend/torch/rules/parameters.py)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from cirkit.backend.torch.compiler import TorchCompiler\n",
+    "from cirkit.backend.torch.rules.parameters import DEFAULT_PARAMETER_COMPILATION_RULES\n",
+    "\n",
+    "\n",
+    "def compile_polynomial_my_differential(\n",
+    "    compiler: TorchCompiler, p: PolynomialMyDifferential\n",
+    ") -> TorchPolynomialMyDifferential:\n",
+    "    return TorchPolynomialMyDifferential(*p.in_shapes, order=p.order)\n",
+    "\n",
+    "\n",
+    "DEFAULT_PARAMETER_COMPILATION_RULES.update(\n",
+    "    {PolynomialMyDifferential: compile_polynomial_my_differential}\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Pipeline-level Convenience Method"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "After the above definitions, the new operator is available as `cirkit.symbolic.functional.my_differentiate`. However, for convenience, we can also add the following to the `PipelineContext` class.\n",
+    "\n",
+    "Note that this must be done by modifying the class in the library instead of in the user code."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[cirkit/pipeline.py](../cirkit/pipeline.py)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from contextlib import AbstractContextManager\n",
+    "\n",
+    "import cirkit.symbolic.functional as SF\n",
+    "from cirkit.backend.compiler import CompiledCircuit\n",
+    "\n",
+    "\n",
+    "class PipelineContext(AbstractContextManager):\n",
+    "    ...  # Any existing defs.\n",
+    "\n",
+    "    def differentiate(self, cc: CompiledCircuit, *, order: int = 1) -> CompiledCircuit:\n",
+    "        if not self._compiler.has_symbolic(cc):\n",
+    "            raise ValueError(\"The given compiled circuit is not known in this pipeline\")\n",
+    "        if order <= 0:\n",
+    "            raise ValueError(\"The order of differentiation must be positive.\")\n",
+    "        sc = self._compiler.get_symbolic_circuit(cc)\n",
+    "        diff_sc = SF.differentiate(sc, registry=self._op_registry, order=order)\n",
+    "        return self.compile(diff_sc)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "cirkit",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.14"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}