Skip to content

Commit

Permalink
lib/monads: improve style of nondet and trace
Browse files Browse the repository at this point in the history
Signed-off-by: Corey Lewis <[email protected]>
  • Loading branch information
corlewis committed Oct 5, 2023
1 parent 293b97c commit ca50a02
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 80 deletions.
49 changes: 20 additions & 29 deletions lib/Monads/nondet/Nondet_Monad.thy
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ definition bind ::
"bind f g \<equiv> \<lambda>s. (\<Union>(fst ` case_prod g ` fst (f s)),
True \<in> snd ` case_prod g ` fst (f s) \<or> snd (f s))"

text \<open>
Sometimes it is convenient to write @{text bind} in reverse order.\<close>
text \<open>Sometimes it is convenient to write @{text bind} in reverse order.\<close>
abbreviation (input) bind_rev ::
"('c \<Rightarrow> ('a, 'b) nondet_monad) \<Rightarrow> ('a, 'c) nondet_monad \<Rightarrow> ('a, 'b) nondet_monad"
(infixl "=<<" 60) where
Expand Down Expand Up @@ -125,20 +124,21 @@ text \<open>
definition state_select :: "('s \<times> 's) set \<Rightarrow> ('s, unit) nondet_monad" where
"state_select r \<equiv> \<lambda>s. ((\<lambda>x. ((), x)) ` {s'. (s, s') \<in> r}, \<not> (\<exists>s'. (s, s') \<in> r))"


subsection "Failure"

text \<open>
The monad function that always fails. Returns an empty set of results and sets the failure flag.\<close>
definition fail :: "('s, 'a) nondet_monad" where
"fail \<equiv> \<lambda>s. ({}, True)"
"fail \<equiv> \<lambda>s. ({}, True)"

text \<open>Assertions: fail if the property @{text P} is not true\<close>
definition assert :: "bool \<Rightarrow> ('a, unit) nondet_monad" where
"assert P \<equiv> if P then return () else fail"
"assert P \<equiv> if P then return () else fail"

text \<open>Fail if the value is @{const None}, return result @{text v} for @{term "Some v"}\<close>
definition assert_opt :: "'a option \<Rightarrow> ('b, 'a) nondet_monad" where
"assert_opt v \<equiv> case v of None \<Rightarrow> fail | Some v \<Rightarrow> return v"
"assert_opt v \<equiv> case v of None \<Rightarrow> fail | Some v \<Rightarrow> return v"

text \<open>An assertion that also can introspect the current state.\<close>
definition state_assert :: "('s \<Rightarrow> bool) \<Rightarrow> ('s, unit) nondet_monad" where
Expand All @@ -148,11 +148,11 @@ subsection "Generic functions on top of the state monad"

text \<open>Apply a function to the current state and return the result without changing the state.\<close>
definition gets :: "('s \<Rightarrow> 'a) \<Rightarrow> ('s, 'a) nondet_monad" where
"gets f \<equiv> get >>= (\<lambda>s. return (f s))"
"gets f \<equiv> get >>= (\<lambda>s. return (f s))"

text \<open>Modify the current state using the function passed in.\<close>
definition modify :: "('s \<Rightarrow> 's) \<Rightarrow> ('s, unit) nondet_monad" where
"modify f \<equiv> get >>= (\<lambda>s. put (f s))"
"modify f \<equiv> get >>= (\<lambda>s. put (f s))"

lemma simpler_gets_def:
"gets f = (\<lambda>s. ({(f s, s)}, False))"
Expand Down Expand Up @@ -196,7 +196,7 @@ definition

subsection \<open>The Monad Laws\<close>

text \<open>A more expanded definition of @{text bind}\<close>
text \<open>An alternative definition of @{term bind}, sometimes more convenient.\<close>
lemma bind_def':
"(f >>= g) \<equiv>
\<lambda>s. ({(r'', s''). \<exists>(r', s') \<in> fst (f s). (r'', s'') \<in> fst (g r' s') },
Expand All @@ -212,7 +212,8 @@ lemma return_bind[simp]:
by (simp add: return_def bind_def)

text \<open>@{term return} is absorbed on the right of a @{term bind}\<close>
lemma bind_return[simp]: "(m >>= return) = m"
lemma bind_return[simp]:
"(m >>= return) = m"
by (simp add: bind_def return_def split_def)

text \<open>@{term bind} is associative\<close>
Expand Down Expand Up @@ -264,15 +265,13 @@ definition bindE ::
(infixl ">>=E" 60) where
"f >>=E g \<equiv> f >>= lift g"


text \<open>
Lifting a normal nondeterministic monad into the
exception monad is achieved by always returning its
result as normal result and never throwing an exception.\<close>
definition liftE :: "('s,'a) nondet_monad \<Rightarrow> ('s, 'e+'a) nondet_monad" where
"liftE f \<equiv> f >>= (\<lambda>r. return (Inr r))"


text \<open>
Since the underlying type and @{text return} function changed,
we need new definitions for when and unless:\<close>
Expand All @@ -282,13 +281,11 @@ definition whenE :: "bool \<Rightarrow> ('s, 'e + unit) nondet_monad \<Rightarro
definition unlessE :: "bool \<Rightarrow> ('s, 'e + unit) nondet_monad \<Rightarrow> ('s, 'e + unit) nondet_monad" where
"unlessE P f \<equiv> if P then returnOk () else f"


text \<open>
Throwing an exception when the parameter is @{term None}, otherwise
returning @{term "v"} for @{term "Some v"}.\<close>
definition throw_opt :: "'e \<Rightarrow> 'a option \<Rightarrow> ('s, 'e + 'a) nondet_monad" where
"throw_opt ex x \<equiv> case x of None \<Rightarrow> throwError ex | Some v \<Rightarrow> returnOk v"

"throw_opt ex x \<equiv> case x of None \<Rightarrow> throwError ex | Some v \<Rightarrow> returnOk v"

text \<open>
Failure in the exception monad is redefined in the same way
Expand All @@ -297,6 +294,7 @@ text \<open>
definition assertE :: "bool \<Rightarrow> ('a, 'e + unit) nondet_monad" where
"assertE P \<equiv> if P then returnOk () else fail"


subsection "Monad Laws for the Exception Monad"

text \<open>More direct definition of @{const liftE}:\<close>
Expand Down Expand Up @@ -415,9 +413,7 @@ lemma "doE x \<leftarrow> returnOk 1;
by simp



section "Library of Monadic Functions and Combinators"

section "Library of additional Monadic Functions and Combinators"

text \<open>Lifting a normal function into the monad type:\<close>
definition liftM :: "('a \<Rightarrow> 'b) \<Rightarrow> ('s,'a) nondet_monad \<Rightarrow> ('s, 'b) nondet_monad" where
Expand All @@ -427,12 +423,11 @@ text \<open>The same for the exception monad:\<close>
definition liftME :: "('a \<Rightarrow> 'b) \<Rightarrow> ('s,'e+'a) nondet_monad \<Rightarrow> ('s,'e+'b) nondet_monad" where
"liftME f m \<equiv> doE x \<leftarrow> m; returnOk (f x) odE"

text \<open> Execute @{term f} for @{term "Some x"}, otherwise do nothing. \<close>
text \<open>Execute @{term f} for @{term "Some x"}, otherwise do nothing.\<close>
definition maybeM :: "('a \<Rightarrow> ('s, unit) nondet_monad) \<Rightarrow> 'a option \<Rightarrow> ('s, unit) nondet_monad" where
"maybeM f y \<equiv> case y of Some x \<Rightarrow> f x | None \<Rightarrow> return ()"

text \<open>
Run a sequence of monads from left to right, ignoring return values.\<close>
text \<open>Run a sequence of monads from left to right, ignoring return values.\<close>
definition sequence_x :: "('s, 'a) nondet_monad list \<Rightarrow> ('s, unit) nondet_monad" where
"sequence_x xs \<equiv> foldr (\<lambda>x y. x >>= (\<lambda>_. y)) xs (return ())"

Expand All @@ -450,7 +445,6 @@ definition zipWithM_x ::
"('a \<Rightarrow> 'b \<Rightarrow> ('s,'c) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> 'b list \<Rightarrow> ('s, unit) nondet_monad" where
"zipWithM_x f xs ys \<equiv> sequence_x (zipWith f xs ys)"


text \<open>
The same three functions as above, but returning a list of
return values instead of @{text unit}\<close>
Expand All @@ -465,8 +459,8 @@ definition zipWithM ::
"('a \<Rightarrow> 'b \<Rightarrow> ('s,'c) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> 'b list \<Rightarrow> ('s, 'c list) nondet_monad" where
"zipWithM f xs ys \<equiv> sequence (zipWith f xs ys)"

definition foldM :: "('b \<Rightarrow> 'a \<Rightarrow> ('s, 'a) nondet_monad) \<Rightarrow> 'b list \<Rightarrow> 'a \<Rightarrow> ('s, 'a) nondet_monad"
where
definition foldM ::
"('b \<Rightarrow> 'a \<Rightarrow> ('s, 'a) nondet_monad) \<Rightarrow> 'b list \<Rightarrow> 'a \<Rightarrow> ('s, 'a) nondet_monad" where
"foldM m xs a \<equiv> foldr (\<lambda>p q. q >>= m p) xs (return a) "

definition foldME ::
Expand All @@ -486,11 +480,10 @@ definition sequenceE :: "('s, 'e+'a) nondet_monad list \<Rightarrow> ('s, 'e+'a
"sequenceE xs \<equiv> let mcons = (\<lambda>p q. p >>=E (\<lambda>x. q >>=E (\<lambda>y. returnOk (x#y))))
in foldr mcons xs (returnOk [])"

definition mapME :: "('a \<Rightarrow> ('s,'e+'b) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> ('s,'e+'b list) nondet_monad"
where
definition mapME ::
"('a \<Rightarrow> ('s,'e+'b) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> ('s,'e+'b list) nondet_monad" where
"mapME f xs \<equiv> sequenceE (map f xs)"


text \<open>Filtering a list using a monadic function as predicate:\<close>
primrec filterM :: "('a \<Rightarrow> ('s, bool) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> ('s, 'a list) nondet_monad" where
"filterM P [] = return []"
Expand Down Expand Up @@ -556,12 +549,10 @@ definition handleE ::
(infix "<handle>" 10) where
"handleE \<equiv> handleE'"


text \<open>
Handling exceptions, and additionally providing a continuation
if the left-hand side throws no exception:\<close>
definition
handle_elseE ::
definition handle_elseE ::
"('s, 'e + 'a) nondet_monad \<Rightarrow> ('e \<Rightarrow> ('s, 'ee + 'b) nondet_monad) \<Rightarrow>
('a \<Rightarrow> ('s, 'ee + 'b) nondet_monad) \<Rightarrow> ('s, 'ee + 'b) nondet_monad"
("_ <handle> _ <else> _" 10) where
Expand Down
4 changes: 2 additions & 2 deletions lib/Monads/nondet/Nondet_No_Fail.thy
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,15 @@ lemma no_fail_spec:

lemma no_fail_assertE[wp]:
"no_fail (\<lambda>_. P) (assertE P)"
by (simp add: assertE_def split: if_split)
by (simp add: assertE_def)

lemma no_fail_spec_pre:
"\<lbrakk> no_fail (((=) s) and P') f; \<And>s. P s \<Longrightarrow> P' s \<rbrakk> \<Longrightarrow> no_fail (((=) s) and P) f"
by (erule no_fail_pre, simp)

lemma no_fail_whenE[wp]:
"\<lbrakk> G \<Longrightarrow> no_fail P f \<rbrakk> \<Longrightarrow> no_fail (\<lambda>s. G \<longrightarrow> P s) (whenE G f)"
by (simp add: whenE_def split: if_split)
by (simp add: whenE_def)

lemma no_fail_unlessE[wp]:
"\<lbrakk> \<not> G \<Longrightarrow> no_fail P f \<rbrakk> \<Longrightarrow> no_fail (\<lambda>s. \<not> G \<longrightarrow> P s) (unlessE G f)"
Expand Down
2 changes: 1 addition & 1 deletion lib/Monads/nondet/Nondet_While_Loop_Rules.thy
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ lemma fst_whileLoop_cond_false:
lemma snd_whileLoop:
assumes init_I: "I r s"
and cond_I: "C r s"
and non_term: "\<And>r. \<lbrace> \<lambda>s. I r s \<and> C r s \<and> \<not> snd (B r s) \<rbrace> B r \<exists>\<lbrace> \<lambda>r' s'. C r' s' \<and> I r' s' \<rbrace>"
and non_term: "\<And>r. \<lbrace> \<lambda>s. I r s \<and> C r s \<and> \<not> snd (B r s) \<rbrace> B r \<exists>\<lbrace> \<lambda>r' s'. C r' s' \<and> I r' s' \<rbrace>"
shows "snd (whileLoop C B r s)"
apply (clarsimp simp: whileLoop_def)
apply (rotate_tac)
Expand Down
33 changes: 16 additions & 17 deletions lib/Monads/trace/Trace_Monad.thy
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,13 @@ datatype ('s, 'a) tmres = Failed | Incomplete | Result "('a \<times> 's)"
abbreviation map_tmres_rv :: "('a \<Rightarrow> 'b) \<Rightarrow> ('s, 'a) tmres \<Rightarrow> ('s, 'b) tmres" where
"map_tmres_rv f \<equiv> map_tmres id f"

section "The Monad"

text \<open>
tmonad returns a set of non-deterministic computations, including
a trace as a list of "thread identifier" \<times> state, and an optional
pair of result and state when the computation did not fail.\<close>
type_synonym ('s, 'a) tmonad = "'s \<Rightarrow> ((tmid \<times> 's) list \<times> ('s, 'a) tmres) set"


text \<open>
Print the type @{typ "('s,'a) tmonad"} instead of its unwieldy expansion.
Needs an AST translation in code, because it needs to check that the state variable
Expand All @@ -64,6 +63,7 @@ print_ast_translation \<open>
else raise Match
in [(@{type_syntax "fun"}, tmonad_tr)] end\<close>


text \<open>Returns monad results, ignoring failures and traces.\<close>
definition mres :: "((tmid \<times> 's) list \<times> ('s, 'a) tmres) set \<Rightarrow> ('a \<times> 's) set" where
"mres r = Result -` (snd ` r)"
Expand Down Expand Up @@ -98,9 +98,8 @@ definition bind ::
| Result (rv, s) \<Rightarrow> fst_upd (\<lambda>ys. ys @ xs) ` g rv s"

text \<open>Sometimes it is convenient to write @{text bind} in reverse order.\<close>
abbreviation(input) bind_rev ::
"('c \<Rightarrow> ('a, 'b) tmonad) \<Rightarrow> ('a, 'c) tmonad \<Rightarrow> ('a, 'b) tmonad" (infixl "=<<" 60)
where
abbreviation (input) bind_rev ::
"('c \<Rightarrow> ('a, 'b) tmonad) \<Rightarrow> ('a, 'c) tmonad \<Rightarrow> ('a, 'b) tmonad" (infixl "=<<" 60) where
"g =<< f \<equiv> f >>= g"

text \<open>
Expand All @@ -123,6 +122,7 @@ primrec put_trace :: "(tmid \<times> 's) list \<Rightarrow> ('s, unit) tmonad" w
"put_trace [] = return ()"
| "put_trace (x # xs) = (put_trace xs >>= (\<lambda>_. put_trace_elem x))"


subsection "Nondeterminism"

text \<open>
Expand Down Expand Up @@ -153,6 +153,7 @@ definition state_select :: "('s \<times> 's) set \<Rightarrow> ('s, unit) tmonad
"state_select r \<equiv>
\<lambda>s. (Pair [] ` default_elem Failed (Result ` (\<lambda>x. ((), x)) ` {s'. (s, s') \<in> r}))"


subsection "Failure"

text \<open>
Expand Down Expand Up @@ -224,8 +225,8 @@ definition

subsection \<open>The Monad Laws\<close>

text \<open>An alternative definition of bind, sometimes more convenient.\<close>
lemma bind_def2:
text \<open>An alternative definition of @{term bind}, sometimes more convenient.\<close>
lemma bind_def':
"bind f g \<equiv>
\<lambda>s. ((\<lambda>xs. (xs, Failed)) ` {xs. (xs, Failed) \<in> f s})
\<union> ((\<lambda>xs. (xs, Incomplete)) ` {xs. (xs, Incomplete) \<in> f s})
Expand All @@ -242,7 +243,7 @@ lemma elem_bindE:
\<lbrakk>res = Incomplete \<or> res = Failed; (tr, map_tmres undefined undefined res) \<in> f s\<rbrakk> \<Longrightarrow> P;
\<And>tr' tr'' x s'. \<lbrakk>(tr', Result (x, s')) \<in> f s; (tr'', res) \<in> g x s'; tr = tr'' @ tr'\<rbrakk> \<Longrightarrow> P\<rbrakk>
\<Longrightarrow> P"
by (auto simp: bind_def2)
by (auto simp: bind_def')

text \<open>Each monad satisfies at least the following three laws.\<close>

Expand Down Expand Up @@ -277,6 +278,7 @@ lemma bind_assoc:
apply (simp add: image_image)
done


section \<open>Adding Exceptions\<close>

text \<open>
Expand Down Expand Up @@ -345,6 +347,7 @@ text \<open>
definition assertE :: "bool \<Rightarrow> ('a, 'e + unit) tmonad" where
"assertE P \<equiv> if P then returnOk () else fail"


subsection "Monad Laws for the Exception Monad"

text \<open>More direct definition of @{const liftE}:\<close>
Expand Down Expand Up @@ -585,7 +588,6 @@ definition sequenceE :: "('s, 'e+'a) tmonad list \<Rightarrow> ('s, 'e+'a list)
definition mapME :: "('a \<Rightarrow> ('s,'e+'b) tmonad) \<Rightarrow> 'a list \<Rightarrow> ('s,'e+'b list) tmonad" where
"mapME f xs \<equiv> sequenceE (map f xs)"


text \<open>Filtering a list using a monadic function as predicate:\<close>
primrec filterM :: "('a \<Rightarrow> ('s, bool) tmonad) \<Rightarrow> 'a list \<Rightarrow> ('s, 'a list) tmonad" where
"filterM P [] = return []"
Expand Down Expand Up @@ -751,20 +753,17 @@ definition ifME ::
if c then t else f
odE"

definition whenM ::
"('s, bool) tmonad \<Rightarrow> ('s, unit) tmonad \<Rightarrow> ('s, unit) tmonad" where
definition whenM :: "('s, bool) tmonad \<Rightarrow> ('s, unit) tmonad \<Rightarrow> ('s, unit) tmonad" where
"whenM t m = ifM t m (return ())"

definition orM ::
"('s, bool) tmonad \<Rightarrow> ('s, bool) tmonad \<Rightarrow> ('s, bool) tmonad" where
definition orM :: "('s, bool) tmonad \<Rightarrow> ('s, bool) tmonad \<Rightarrow> ('s, bool) tmonad" where
"orM a b = ifM a (return True) b"

definition
andM :: "('s, bool) tmonad \<Rightarrow> ('s, bool) tmonad \<Rightarrow> ('s, bool) tmonad" where
definition andM :: "('s, bool) tmonad \<Rightarrow> ('s, bool) tmonad \<Rightarrow> ('s, bool) tmonad" where
"andM a b = ifM a b (return False)"


subsection "Await command"
section "Await command"

text \<open>@{term "Await c f"} blocks the execution until @{term "c"} is true,
and then atomically executes @{term "f"}.\<close>
Expand All @@ -782,7 +781,7 @@ definition Await :: "('s \<Rightarrow> bool) \<Rightarrow> ('s,unit) tmonad" whe
od"


section "Trace monad Parallel"
section "Parallel combinator"

definition parallel :: "('s,'a) tmonad \<Rightarrow> ('s,'a) tmonad \<Rightarrow> ('s,'a) tmonad" where
"parallel f g = (\<lambda>s. {(xs, rv). \<exists>f_steps. length f_steps = length xs
Expand Down
8 changes: 8 additions & 0 deletions lib/Monads/trace/Trace_Monad_Equations.thy
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,14 @@ lemma gets_fold_into_modify:
by (simp_all add: fun_eq_iff modify_def bind_assoc exec_gets
exec_get exec_put)

lemma gets_return_gets_eq:
"gets f >>= (\<lambda>g. return (h g)) = gets (\<lambda>s. h (f s))"
by (simp add: simpler_gets_def bind_def return_def)

lemma gets_prod_comp:
"gets (case x of (a, b) \<Rightarrow> f a b) = (case x of (a, b) \<Rightarrow> gets (f a b))"
by (auto simp: split_def)

lemma bind_assoc2:
"(do x \<leftarrow> a; _ \<leftarrow> b; c x od) = (do x \<leftarrow> (do x' \<leftarrow> a; _ \<leftarrow> b; return x' od); c x od)"
by (simp add: bind_assoc)
Expand Down
Loading

0 comments on commit ca50a02

Please sign in to comment.