Skip to content

Commit

Permalink
Enable construction of non-strict SequencePass (#1337)
Browse files Browse the repository at this point in the history
  • Loading branch information
cqc-alec authored Apr 11, 2024
1 parent 6110d7b commit 46961c7
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 16 deletions.
13 changes: 9 additions & 4 deletions pytket/binders/passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,15 @@ PYBIND11_MODULE(passes, m) {
py::class_<SequencePass, std::shared_ptr<SequencePass>, BasePass>(
m, "SequencePass", "A sequence of compilation passes.")
.def(
py::init<const py::tket_custom::SequenceVec<PassPtr> &>(),
"Construct from a list of compilation passes arranged in "
"order of application.",
py::arg("pass_list"))
py::init<const py::tket_custom::SequenceVec<PassPtr> &, bool>(),
"Construct from a list of compilation passes arranged in order of "
"application."
"\n\n:param pass_list: sequence of passes"
"\n:param strict: if True (the default), check that all "
"postconditions and preconditions of the passes in the sequence are "
"compatible and raise an exception if not."
"\n:return: a pass that applies the sequence",
py::arg("pass_list"), py::arg("strict") = true)
.def("__str__", [](const BasePass &) { return "<tket::SequencePass>"; })
.def(
"get_sequence", &SequencePass::get_sequence,
Expand Down
2 changes: 1 addition & 1 deletion pytket/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def package(self):
cmake.install()

def requirements(self):
self.requires("tket/1.2.113@tket/stable")
self.requires("tket/1.2.114@tket/stable")
self.requires("tklog/0.3.3@tket/stable")
self.requires("tkrng/0.3.3@tket/stable")
self.requires("tkassert/0.3.4@tket/stable")
Expand Down
2 changes: 2 additions & 0 deletions pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ Features:
``CliffordPushThroughMeasures`` that optimises Clifford subcircuits
before end of circuit measurement gates.
* Add ``OpType.GPI``, ``OpType.GPI2`` and ``OpType.AAMS``.
* Allow construction of ``SequencePass`` without predicate checks, by means of
new ``strict`` argument to the constructor (defaulting to ``True``).

Fixes:

Expand Down
6 changes: 5 additions & 1 deletion pytket/pytket/_tket/passes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,13 @@ class SequencePass(BasePass):
"""
A sequence of compilation passes.
"""
def __init__(self, pass_list: typing.Sequence[BasePass]) -> None:
def __init__(self, pass_list: typing.Sequence[BasePass], strict: bool = True) -> None:
"""
Construct from a list of compilation passes arranged in order of application.
:param pass_list: sequence of passes
:param strict: if True (the default), check that all postconditions and preconditions of the passes in the sequence are compatible and raise an exception if not.
:return: a pass that applies the sequence
"""
def __str__(self: BasePass) -> str:
...
Expand Down
13 changes: 12 additions & 1 deletion pytket/tests/predicates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sympy

import pytest
from pytket import logging
from pytket.circuit import (
Circuit,
Expand Down Expand Up @@ -73,6 +73,7 @@
CliffordPushThroughMeasures,
CliffordSimp,
SynthesiseOQC,
ZXGraphlikeOptimisation,
)
from pytket.predicates import (
GateSetPredicate,
Expand Down Expand Up @@ -142,6 +143,16 @@ def test_compilerpass_seq() -> None:
assert seq.apply(cu2)


def test_compilerpass_seq_nonstrict() -> None:
passlist = [RebaseTket(), ZXGraphlikeOptimisation()]
with pytest.raises(RuntimeError):
_ = SequencePass(passlist)
seq = SequencePass(passlist, strict=False)
circ = Circuit(2)
seq.apply(circ)
assert np.allclose(circ.get_unitary(), np.eye(4, 4, dtype=complex))


def test_rebase_pass_generation() -> None:
cx = Circuit(2)
cx.CX(0, 1)
Expand Down
2 changes: 1 addition & 1 deletion tket/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class TketConan(ConanFile):
name = "tket"
version = "1.2.113"
version = "1.2.114"
package_type = "library"
license = "Apache 2"
homepage = "https://github.com/CQCL/tket"
Expand Down
4 changes: 2 additions & 2 deletions tket/include/tket/Predicates/CompilerPass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class BasePass {
void update_cache(const CompilationUnit& c_unit, SafetyMode safe_mode) const;
static PassConditions match_passes(const PassPtr& lhs, const PassPtr& rhs);
static PassConditions match_passes(
const PassConditions& lhs, const PassConditions& rhs);
const PassConditions& lhs, const PassConditions& rhs, bool strict = true);
};

/* Basic Pass that all combinators can be used on */
Expand Down Expand Up @@ -192,7 +192,7 @@ class StandardPass : public BasePass {
class SequencePass : public BasePass {
public:
SequencePass() {}
explicit SequencePass(const std::vector<PassPtr>& ptvec);
explicit SequencePass(const std::vector<PassPtr>& ptvec, bool strict = true);
bool apply(
CompilationUnit& c_unit, SafetyMode safe_mode = SafetyMode::Default,
const PassCallback& before_apply = trivial_callback,
Expand Down
10 changes: 5 additions & 5 deletions tket/src/Predicates/CompilerPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,13 @@ static PredicateClassGuarantees match_class_guarantees(
}

PassConditions BasePass::match_passes(
const PassConditions& lhs, const PassConditions& rhs) {
const PassConditions& lhs, const PassConditions& rhs, bool strict) {
PredicatePtrMap new_precons = lhs.first;
for (const TypePredicatePair& precon : rhs.first) {
PredicatePtrMap::const_iterator data_guar_iter =
lhs.second.specific_postcons_.find(precon.first);
if (data_guar_iter == lhs.second.specific_postcons_.end()) {
if (get_guarantee(precon.first, lhs) == Guarantee::Clear) {
if (strict && get_guarantee(precon.first, lhs) == Guarantee::Clear) {
throw IncompatibleCompilerPasses(precon.first);
} else {
PredicatePtrMap::iterator new_pre_it = new_precons.find(precon.first);
Expand All @@ -156,7 +156,7 @@ PassConditions BasePass::match_passes(
}
}
} else {
if (!data_guar_iter->second->implies(*precon.second)) {
if (strict && !data_guar_iter->second->implies(*precon.second)) {
throw IncompatibleCompilerPasses(precon.first);
}
}
Expand Down Expand Up @@ -227,14 +227,14 @@ PassPtr operator>>(const PassPtr& lhs, const PassPtr& rhs) {
return sequence;
}

SequencePass::SequencePass(const std::vector<PassPtr>& ptvec) {
SequencePass::SequencePass(const std::vector<PassPtr>& ptvec, bool strict) {
if (ptvec.size() == 0)
throw std::logic_error("Cannot generate CompilerPass from empty list");
std::vector<PassPtr>::const_iterator iter = ptvec.begin();
PassConditions conditions = (*iter)->get_conditions();
for (++iter; iter != ptvec.end(); ++iter) {
const PassConditions next_cons = (*iter)->get_conditions();
conditions = match_passes(conditions, next_cons);
conditions = match_passes(conditions, next_cons, strict);
}
this->precons_ = conditions.first;
this->postcons_ = conditions.second;
Expand Down
8 changes: 7 additions & 1 deletion tket/test/src/test_CompilerPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ SCENARIO("Construct sequence pass") {
}
}

SCENARIO("Construct invalid sequence passes from vector") {
SCENARIO("Construct sequence pass that is invalid in strict mode") {
std::vector<PassPtr> invalid_pass_to_combo{
SynthesiseOQC(), SynthesiseUMD(), SynthesiseTK()};
for (const PassPtr& pass : invalid_pass_to_combo) {
Expand All @@ -459,6 +459,12 @@ SCENARIO("Construct invalid sequence passes from vector") {
ppm, Transforms::id, pc, nlohmann::json{});
passes.push_back(compass);
REQUIRE_THROWS_AS((void)SequencePass(passes), IncompatibleCompilerPasses);
GIVEN("A circuit compilable with non-strict SequencePass") {
PassPtr sequence = std::make_shared<SequencePass>(passes, false);
Circuit circ(2);
CompilationUnit cu(circ);
REQUIRE_NOTHROW(sequence->apply(cu));
}
}
}

Expand Down

0 comments on commit 46961c7

Please sign in to comment.