From 46961c786bbbb454e3d1c328d102d68d36674e44 Mon Sep 17 00:00:00 2001 From: Alec Edgington <54802828+cqc-alec@users.noreply.github.com> Date: Thu, 11 Apr 2024 15:14:29 +0100 Subject: [PATCH] Enable construction of non-strict `SequencePass` (#1337) --- pytket/binders/passes.cpp | 13 +++++++++---- pytket/conanfile.py | 2 +- pytket/docs/changelog.rst | 2 ++ pytket/pytket/_tket/passes.pyi | 6 +++++- pytket/tests/predicates_test.py | 13 ++++++++++++- tket/conanfile.py | 2 +- tket/include/tket/Predicates/CompilerPass.hpp | 4 ++-- tket/src/Predicates/CompilerPass.cpp | 10 +++++----- tket/test/src/test_CompilerPass.cpp | 8 +++++++- 9 files changed, 44 insertions(+), 16 deletions(-) diff --git a/pytket/binders/passes.cpp b/pytket/binders/passes.cpp index d298f2c54c..22f07ff06f 100644 --- a/pytket/binders/passes.cpp +++ b/pytket/binders/passes.cpp @@ -288,10 +288,15 @@ PYBIND11_MODULE(passes, m) { py::class_, BasePass>( m, "SequencePass", "A sequence of compilation passes.") .def( - py::init &>(), - "Construct from a list of compilation passes arranged in " - "order of application.", - py::arg("pass_list")) + py::init &, bool>(), + "Construct from a list of compilation passes arranged in order of " + "application." + "\n\n:param pass_list: sequence of passes" + "\n:param strict: if True (the default), check that all " + "postconditions and preconditions of the passes in the sequence are " + "compatible and raise an exception if not." + "\n:return: a pass that applies the sequence", + py::arg("pass_list"), py::arg("strict") = true) .def("__str__", [](const BasePass &) { return ""; }) .def( "get_sequence", &SequencePass::get_sequence, diff --git a/pytket/conanfile.py b/pytket/conanfile.py index 4dfa604439..aa2270098e 100644 --- a/pytket/conanfile.py +++ b/pytket/conanfile.py @@ -32,7 +32,7 @@ def package(self): cmake.install() def requirements(self): - self.requires("tket/1.2.113@tket/stable") + self.requires("tket/1.2.114@tket/stable") self.requires("tklog/0.3.3@tket/stable") self.requires("tkrng/0.3.3@tket/stable") self.requires("tkassert/0.3.4@tket/stable") diff --git a/pytket/docs/changelog.rst b/pytket/docs/changelog.rst index 94d7b45066..479f9d21dc 100644 --- a/pytket/docs/changelog.rst +++ b/pytket/docs/changelog.rst @@ -16,6 +16,8 @@ Features: ``CliffordPushThroughMeasures`` that optimises Clifford subcircuits before end of circuit measurement gates. * Add ``OpType.GPI``, ``OpType.GPI2`` and ``OpType.AAMS``. +* Allow construction of ``SequencePass`` without predicate checks, by means of + new ``strict`` argument to the constructor (defaulting to ``True``). Fixes: diff --git a/pytket/pytket/_tket/passes.pyi b/pytket/pytket/_tket/passes.pyi index edb47a34fa..22bc51873e 100644 --- a/pytket/pytket/_tket/passes.pyi +++ b/pytket/pytket/_tket/passes.pyi @@ -192,9 +192,13 @@ class SequencePass(BasePass): """ A sequence of compilation passes. """ - def __init__(self, pass_list: typing.Sequence[BasePass]) -> None: + def __init__(self, pass_list: typing.Sequence[BasePass], strict: bool = True) -> None: """ Construct from a list of compilation passes arranged in order of application. + + :param pass_list: sequence of passes + :param strict: if True (the default), check that all postconditions and preconditions of the passes in the sequence are compatible and raise an exception if not. + :return: a pass that applies the sequence """ def __str__(self: BasePass) -> str: ... diff --git a/pytket/tests/predicates_test.py b/pytket/tests/predicates_test.py index 70a18ec681..2afc1d625b 100644 --- a/pytket/tests/predicates_test.py +++ b/pytket/tests/predicates_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import sympy - +import pytest from pytket import logging from pytket.circuit import ( Circuit, @@ -73,6 +73,7 @@ CliffordPushThroughMeasures, CliffordSimp, SynthesiseOQC, + ZXGraphlikeOptimisation, ) from pytket.predicates import ( GateSetPredicate, @@ -142,6 +143,16 @@ def test_compilerpass_seq() -> None: assert seq.apply(cu2) +def test_compilerpass_seq_nonstrict() -> None: + passlist = [RebaseTket(), ZXGraphlikeOptimisation()] + with pytest.raises(RuntimeError): + _ = SequencePass(passlist) + seq = SequencePass(passlist, strict=False) + circ = Circuit(2) + seq.apply(circ) + assert np.allclose(circ.get_unitary(), np.eye(4, 4, dtype=complex)) + + def test_rebase_pass_generation() -> None: cx = Circuit(2) cx.CX(0, 1) diff --git a/tket/conanfile.py b/tket/conanfile.py index 8ace2c499c..fdec96f461 100644 --- a/tket/conanfile.py +++ b/tket/conanfile.py @@ -23,7 +23,7 @@ class TketConan(ConanFile): name = "tket" - version = "1.2.113" + version = "1.2.114" package_type = "library" license = "Apache 2" homepage = "https://github.com/CQCL/tket" diff --git a/tket/include/tket/Predicates/CompilerPass.hpp b/tket/include/tket/Predicates/CompilerPass.hpp index 0b53496938..26ac7def6c 100644 --- a/tket/include/tket/Predicates/CompilerPass.hpp +++ b/tket/include/tket/Predicates/CompilerPass.hpp @@ -156,7 +156,7 @@ class BasePass { void update_cache(const CompilationUnit& c_unit, SafetyMode safe_mode) const; static PassConditions match_passes(const PassPtr& lhs, const PassPtr& rhs); static PassConditions match_passes( - const PassConditions& lhs, const PassConditions& rhs); + const PassConditions& lhs, const PassConditions& rhs, bool strict = true); }; /* Basic Pass that all combinators can be used on */ @@ -192,7 +192,7 @@ class StandardPass : public BasePass { class SequencePass : public BasePass { public: SequencePass() {} - explicit SequencePass(const std::vector& ptvec); + explicit SequencePass(const std::vector& ptvec, bool strict = true); bool apply( CompilationUnit& c_unit, SafetyMode safe_mode = SafetyMode::Default, const PassCallback& before_apply = trivial_callback, diff --git a/tket/src/Predicates/CompilerPass.cpp b/tket/src/Predicates/CompilerPass.cpp index eefcd71888..e6d3522f9d 100644 --- a/tket/src/Predicates/CompilerPass.cpp +++ b/tket/src/Predicates/CompilerPass.cpp @@ -137,13 +137,13 @@ static PredicateClassGuarantees match_class_guarantees( } PassConditions BasePass::match_passes( - const PassConditions& lhs, const PassConditions& rhs) { + const PassConditions& lhs, const PassConditions& rhs, bool strict) { PredicatePtrMap new_precons = lhs.first; for (const TypePredicatePair& precon : rhs.first) { PredicatePtrMap::const_iterator data_guar_iter = lhs.second.specific_postcons_.find(precon.first); if (data_guar_iter == lhs.second.specific_postcons_.end()) { - if (get_guarantee(precon.first, lhs) == Guarantee::Clear) { + if (strict && get_guarantee(precon.first, lhs) == Guarantee::Clear) { throw IncompatibleCompilerPasses(precon.first); } else { PredicatePtrMap::iterator new_pre_it = new_precons.find(precon.first); @@ -156,7 +156,7 @@ PassConditions BasePass::match_passes( } } } else { - if (!data_guar_iter->second->implies(*precon.second)) { + if (strict && !data_guar_iter->second->implies(*precon.second)) { throw IncompatibleCompilerPasses(precon.first); } } @@ -227,14 +227,14 @@ PassPtr operator>>(const PassPtr& lhs, const PassPtr& rhs) { return sequence; } -SequencePass::SequencePass(const std::vector& ptvec) { +SequencePass::SequencePass(const std::vector& ptvec, bool strict) { if (ptvec.size() == 0) throw std::logic_error("Cannot generate CompilerPass from empty list"); std::vector::const_iterator iter = ptvec.begin(); PassConditions conditions = (*iter)->get_conditions(); for (++iter; iter != ptvec.end(); ++iter) { const PassConditions next_cons = (*iter)->get_conditions(); - conditions = match_passes(conditions, next_cons); + conditions = match_passes(conditions, next_cons, strict); } this->precons_ = conditions.first; this->postcons_ = conditions.second; diff --git a/tket/test/src/test_CompilerPass.cpp b/tket/test/src/test_CompilerPass.cpp index ab2ec20f4e..cb07c2e5fb 100644 --- a/tket/test/src/test_CompilerPass.cpp +++ b/tket/test/src/test_CompilerPass.cpp @@ -446,7 +446,7 @@ SCENARIO("Construct sequence pass") { } } -SCENARIO("Construct invalid sequence passes from vector") { +SCENARIO("Construct sequence pass that is invalid in strict mode") { std::vector invalid_pass_to_combo{ SynthesiseOQC(), SynthesiseUMD(), SynthesiseTK()}; for (const PassPtr& pass : invalid_pass_to_combo) { @@ -459,6 +459,12 @@ SCENARIO("Construct invalid sequence passes from vector") { ppm, Transforms::id, pc, nlohmann::json{}); passes.push_back(compass); REQUIRE_THROWS_AS((void)SequencePass(passes), IncompatibleCompilerPasses); + GIVEN("A circuit compilable with non-strict SequencePass") { + PassPtr sequence = std::make_shared(passes, false); + Circuit circ(2); + CompilationUnit cu(circ); + REQUIRE_NOTHROW(sequence->apply(cu)); + } } }