Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: implement quantum problem set filter #278

Closed
wants to merge 13 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ oryx-build-commands.txt
prof/
tags
TAGS
*.swp
redeboer marked this conversation as resolved.
Show resolved Hide resolved

# Virtual environments
*venv/
Expand Down
32 changes: 32 additions & 0 deletions src/qrules/quantum_numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,29 @@ class EdgeQuantumNumbers:
EdgeQuantumNumbers.g_parity,
]

# for accessing the keys of the dicts in EdgeSettings
EdgeQuantumNumberTypes = Union[
type[EdgeQuantumNumbers.pid],
type[EdgeQuantumNumbers.mass],
type[EdgeQuantumNumbers.width],
type[EdgeQuantumNumbers.spin_magnitude],
type[EdgeQuantumNumbers.spin_projection],
type[EdgeQuantumNumbers.charge],
type[EdgeQuantumNumbers.isospin_magnitude],
type[EdgeQuantumNumbers.isospin_projection],
type[EdgeQuantumNumbers.strangeness],
type[EdgeQuantumNumbers.charmness],
type[EdgeQuantumNumbers.bottomness],
type[EdgeQuantumNumbers.topness],
type[EdgeQuantumNumbers.baryon_number],
type[EdgeQuantumNumbers.electron_lepton_number],
type[EdgeQuantumNumbers.muon_lepton_number],
type[EdgeQuantumNumbers.tau_lepton_number],
type[EdgeQuantumNumbers.parity],
type[EdgeQuantumNumbers.c_parity],
type[EdgeQuantumNumbers.g_parity],
]
redeboer marked this conversation as resolved.
Show resolved Hide resolved


@frozen(init=False)
class NodeQuantumNumbers:
Expand Down Expand Up @@ -158,6 +181,15 @@ class NodeQuantumNumbers:
]
"""Type hint for quantum numbers of interaction nodes."""

# for accessing the keys of the dicts in NodeSettings
NodeQuantumNumberTypes = Union[
type[NodeQuantumNumbers.l_magnitude],
type[NodeQuantumNumbers.l_projection],
type[NodeQuantumNumbers.s_magnitude],
type[NodeQuantumNumbers.s_projection],
type[NodeQuantumNumbers.parity_prefactor],
]


def _to_optional_float(optional_float: float | None) -> float | None:
if optional_float is None:
Expand Down
6 changes: 4 additions & 2 deletions src/qrules/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
from qrules.quantum_numbers import (
EdgeQuantumNumber,
EdgeQuantumNumbers,
EdgeQuantumNumberTypes,
NodeQuantumNumber,
NodeQuantumNumberTypes,
)
from qrules.topology import MutableTransition, Topology

Expand All @@ -47,7 +49,7 @@ class EdgeSettings:

conservation_rules: set[GraphElementRule] = field(factory=set)
rule_priorities: dict[GraphElementRule, int] = field(factory=dict)
qn_domains: dict[Any, list] = field(factory=dict)
qn_domains: dict[EdgeQuantumNumberTypes, list] = field(factory=dict)


@implement_pretty_repr
Expand All @@ -67,7 +69,7 @@ class NodeSettings:

conservation_rules: set[Rule] = field(factory=set)
rule_priorities: dict[Rule, int] = field(factory=dict)
qn_domains: dict[Any, list] = field(factory=dict)
qn_domains: dict[NodeQuantumNumberTypes, list] = field(factory=dict)
interaction_strength: float = 1.0


Expand Down
160 changes: 160 additions & 0 deletions tests/unit/test_solving.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Iterable

import attrs
import pytest

import qrules.particle
import qrules.quantum_numbers
import qrules.system_control
import qrules.transition
from qrules.conservation_rules import (
GraphElementRule,
c_parity_conservation,
parity_conservation,
spin_magnitude_conservation,
spin_validity,
)
from qrules.quantum_numbers import (
EdgeQuantumNumbers,
EdgeQuantumNumberTypes,
NodeQuantumNumbers,
NodeQuantumNumberTypes,
)
from qrules.solving import CSPSolver, EdgeSettings, NodeSettings, QNProblemSet
from qrules.topology import MutableTransition

if TYPE_CHECKING:
from qrules.argument_handling import Rule


def test_solve(
all_particles: qrules.particle.ParticleCollection,
quantum_number_problem_set: QNProblemSet,
) -> None:
solver = CSPSolver(all_particles)
result = solver.find_solutions(quantum_number_problem_set)
assert len(result.solutions) == 19


def test_solve_with_filtered_quantum_number_problem_set(
all_particles: qrules.particle.ParticleCollection,
quantum_number_problem_set: QNProblemSet,
) -> None:
solver = CSPSolver(all_particles)
redeboer marked this conversation as resolved.
Show resolved Hide resolved
new_quantum_number_problem_set = filter_quantum_number_problem_set(
quantum_number_problem_set,
edge_rules={spin_validity},
node_rules={
spin_magnitude_conservation,
parity_conservation,
c_parity_conservation,
},
edge_properties_and_domains={
EdgeQuantumNumbers.pid, # had to be added for c_parity_conservation to work
EdgeQuantumNumbers.spin_magnitude,
# EdgeQuantumNumbers.spin_projection, # can be left out to reduce the number of solutions
EdgeQuantumNumbers.parity,
EdgeQuantumNumbers.c_parity,
},
node_properties_and_domains=(
NodeQuantumNumbers.l_magnitude,
NodeQuantumNumbers.s_magnitude,
),
)
result = solver.find_solutions(new_quantum_number_problem_set)

assert len(result.solutions) != 0
redeboer marked this conversation as resolved.
Show resolved Hide resolved


def filter_quantum_number_problem_set(
quantum_number_problem_set: QNProblemSet,
edge_rules: set[GraphElementRule],
node_rules: set[Rule],
edge_properties_and_domains: Iterable[EdgeQuantumNumberTypes],
node_properties_and_domains: Iterable[NodeQuantumNumberTypes],
) -> QNProblemSet:
old_edge_settings = quantum_number_problem_set.solving_settings.states
old_node_settings = quantum_number_problem_set.solving_settings.interactions
old_edge_properties = quantum_number_problem_set.initial_facts.states
old_node_properties = quantum_number_problem_set.initial_facts.interactions
new_edge_settings = {
edge_id: EdgeSettings(
conservation_rules=edge_rules,
rule_priorities=edge_setting.rule_priorities,
qn_domains=({
key: val
for key, val in edge_setting.qn_domains.items()
if key in set(edge_properties_and_domains)
}),
)
for edge_id, edge_setting in old_edge_settings.items()
}
new_node_settings = {
node_id: NodeSettings(
conservation_rules=node_rules,
rule_priorities=node_setting.rule_priorities,
qn_domains=({
key: val
for key, val in node_setting.qn_domains.items()
if key in set(node_properties_and_domains)
}),
)
for node_id, node_setting in old_node_settings.items()
}
new_combined_settings = MutableTransition(
topology=quantum_number_problem_set.solving_settings.topology,
states=new_edge_settings,
interactions=new_node_settings,
)
new_edge_properties = {
edge_id: {
edge_quantum_number: scalar
for edge_quantum_number, scalar in graph_edge_property_map.items()
if edge_quantum_number in edge_properties_and_domains
}
for edge_id, graph_edge_property_map in old_edge_properties.items()
}
new_node_properties = {
node_id: {
node_quantum_number: scalar
for node_quantum_number, scalar in graph_node_property_map.items()
if node_quantum_number in node_properties_and_domains
}
for node_id, graph_node_property_map in old_node_properties.items()
}
new_combined_properties = MutableTransition(
topology=quantum_number_problem_set.initial_facts.topology,
states=new_edge_properties,
interactions=new_node_properties,
)
return attrs.evolve(
quantum_number_problem_set,
solving_settings=new_combined_settings,
initial_facts=new_combined_properties,
)


@pytest.fixture(scope="session")
def all_particles():
return [
qrules.system_control.create_edge_properties(part)
for part in qrules.particle.load_pdg()
]


@pytest.fixture(scope="session")
def quantum_number_problem_set() -> QNProblemSet:
stm = qrules.StateTransitionManager(
initial_state=["psi(2S)"],
final_state=["gamma", "eta", "eta"],
formalism="helicity",
)
problem_sets = stm.create_problem_sets()
qn_problem_sets = [
p.to_qn_problem_set()
for strength in sorted(problem_sets)
for p in problem_sets[strength]
redeboer marked this conversation as resolved.
Show resolved Hide resolved
]
return qn_problem_sets[0]
Loading