Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Lens to allow for easy transition to Scala3 #3970

Merged
merged 5 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import com.wavesplatform.lang.v1.compiler.Types.*
import com.wavesplatform.lang.v1.evaluator.ctx.FunctionTypeSignature
import com.wavesplatform.lang.v1.parser.Expressions.Pos
import com.wavesplatform.lang.v1.parser.Expressions.Pos.AnyPos
import shapeless.*

case class CompilerContext(
predefTypes: Map[String, FINAL],
Expand Down Expand Up @@ -60,9 +59,5 @@ object CompilerContext {
y.provideRuntimeTypeOnCastError
)

val types: Lens[CompilerContext, Map[String, FINAL]] = lens[CompilerContext] >> Symbol("predefTypes")
val vars: Lens[CompilerContext, VariableTypes] = lens[CompilerContext] >> Symbol("varDefs")
val functions: Lens[CompilerContext, FunctionTypes] = lens[CompilerContext] >> Symbol("functionDefs")

val empty = CompilerContext(Map(), Map(), Map(), true)
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import com.wavesplatform.lang.contract.DApp.*
import com.wavesplatform.lang.contract.meta.{MetaMapper, V1 as MetaV1, V2 as MetaV2}
import com.wavesplatform.lang.directives.values.{StdLibVersion, V3, V6}
import com.wavesplatform.lang.v1.compiler.CompilationError.{AlreadyDefined, Generic, UnionNotAllowedForCallableArgs, WrongArgumentType}
import com.wavesplatform.lang.v1.compiler.CompilerContext.{VariableInfo, vars}
import com.wavesplatform.lang.v1.compiler.CompilerContext.VariableInfo
import com.wavesplatform.lang.v1.compiler.ContractCompiler.*
import com.wavesplatform.lang.v1.compiler.ScriptResultSource.FreeCall
import com.wavesplatform.lang.v1.compiler.Terms.EXPR
Expand Down Expand Up @@ -87,7 +87,7 @@ class ContractCompiler(version: StdLibVersion) extends ExpressionCompiler(versio
.getOrElse(List.empty)
unionInCallableErrs <- checkCallableUnions(af, annotationsWithErr._1.toList.flatten)
compiledBody <- local {
modify[Id, CompilerContext, CompilationError](vars.modify(_)(_ ++ annotationBindings)).flatMap(_ =>
modify[Id, CompilerContext, CompilationError](ctx => ctx.copy(varDefs = ctx.varDefs ++ annotationBindings)).flatMap(_ =>
compileFunc(af.f.position, af.f, saveExprContext, annotationBindings.map(_._1), allowIllFormedStrings)
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ class ExpressionCompiler(val version: StdLibVersion) {
.handleError()
compiledFuncBody <- local {
val newArgs: VariableTypes = argTypesWithErr._1.getOrElse(List.empty).toMap
modify[Id, CompilerContext, CompilationError](vars.modify(_)(_ ++ newArgs))
modify[Id, CompilerContext, CompilationError](ctx1 => ctx1.copy(varDefs = ctx1.varDefs ++ newArgs))
.flatMap(_ => compileExprWithCtx(func.expr, saveExprContext, allowIllFormedStrings))
}

Expand All @@ -368,10 +368,10 @@ class ExpressionCompiler(val version: StdLibVersion) {
}

protected def updateCtx(letName: String, letType: Types.FINAL, p: Pos): CompileM[Unit] =
modify[Id, CompilerContext, CompilationError](vars.modify(_)(_ + (letName -> VariableInfo(p, letType))))
modify[Id, CompilerContext, CompilationError](ctx => ctx.copy(varDefs = ctx.varDefs + (letName -> VariableInfo(p, letType))))

protected def updateCtx(funcName: String, typeSig: FunctionTypeSignature, p: Pos): CompileM[Unit] =
modify[Id, CompilerContext, CompilationError](functions.modify(_)(_ + (funcName -> FunctionInfo(p, List(typeSig)))))
modify[Id, CompilerContext, CompilationError](ctx => ctx.copy(functionDefs = ctx.functionDefs + (funcName -> FunctionInfo(p, List(typeSig)))))

private def compileLetBlock(
p: Pos,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,18 @@
package com.wavesplatform.lang.v1.estimator.v2


import com.wavesplatform.lang.v1.FunctionHeader
import com.wavesplatform.lang.v1.compiler.Terms.FUNC
import com.wavesplatform.lang.v1.estimator.EstimationError
import com.wavesplatform.lang.v1.estimator.v2.EstimatorContext.EvalM
import com.wavesplatform.lang.v1.task.TaskM
import shapeless.{Lens, lens}

private[v2] case class EstimatorContext(
letDefs: Map[String, (Boolean, EvalM[Long])],
predefFuncs: Map[FunctionHeader, Long],
userFuncs: Map[FunctionHeader, FUNC] = Map.empty,
overlappedRefs: Map[String, (Boolean, EvalM[Long])] = Map.empty
letDefs: Map[String, (Boolean, EvalM[Long])],
predefFuncs: Map[FunctionHeader, Long],
userFuncs: Map[FunctionHeader, FUNC] = Map.empty,
overlappedRefs: Map[String, (Boolean, EvalM[Long])] = Map.empty
)

private[v2] object EstimatorContext {
type EvalM[A] = TaskM[EstimatorContext, EstimationError, A]

object Lenses {
val lets: Lens[EstimatorContext, Map[String, (Boolean, EvalM[Long])]] = lens[EstimatorContext] >> Symbol("letDefs")
val userFuncs: Lens[EstimatorContext, Map[FunctionHeader, FUNC]] = lens[EstimatorContext] >> Symbol("userFuncs")
val predefFuncs: Lens[EstimatorContext, Map[FunctionHeader, Long]] = lens[EstimatorContext] >> Symbol("predefFuncs")
val overlappedRefs: Lens[EstimatorContext, Map[String, (Boolean, EvalM[Long])]] = lens[EstimatorContext] >> Symbol("overlappedRefs")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import com.wavesplatform.lang.v1.FunctionHeader
import com.wavesplatform.lang.v1.compiler.Terms._
import com.wavesplatform.lang.v1.estimator.{EstimationError, ScriptEstimator}
import com.wavesplatform.lang.v1.estimator.v2.EstimatorContext.EvalM
import com.wavesplatform.lang.v1.estimator.v2.EstimatorContext.Lenses._
import com.wavesplatform.lang.v1.task.imports._
import monix.eval.Coeval

Expand Down Expand Up @@ -45,7 +44,7 @@ object ScriptEstimatorV2 extends ScriptEstimator {
local {
val letResult = (false, evalExpr(let.value))
for {
_ <- update(lets.modify(_)(_.updated(let.name, letResult)))
_ <- update(ctx => ctx.copy(letDefs = ctx.letDefs.updated(let.name, letResult)))
r <- evalExpr(inner)
} yield r + 5
}
Expand All @@ -61,31 +60,31 @@ object ScriptEstimatorV2 extends ScriptEstimator {
local {
for {
_ <- checkFuncCtx(func)
_ <- update(userFuncs.modify(_)(_ + (FunctionHeader.User(func.name) -> func)))
_ <- update(ctx => ctx.copy(userFuncs = ctx.userFuncs + (FunctionHeader.User(func.name) -> func)))
r <- evalExpr(inner)
} yield r + 5
}

private def checkFuncCtx(func: FUNC): EvalM[Unit] =
local {
for {
_ <- update(lets.modify(_)(_ ++ func.args.map((_, (true, const(0)))).toMap))
_ <- update(ctx => ctx.copy(letDefs = ctx.letDefs ++ func.args.map((_, (true, const(0)))).toMap))
_ <- evalExpr(func.body)
} yield ()
}

private def evalRef(key: String): EvalM[Long] =
for {
ctx <- get[Id, EstimatorContext, EstimationError]
r <- lets.get(ctx).get(key) match {
r <- ctx.letDefs.get(key) match {
case Some((false, lzy)) => setRefEvaluated(key, lzy)
case Some((true, _)) => const(0)
case None => raiseError[Id, EstimatorContext, EstimationError, Long](s"A definition of '$key' not found")
}
} yield r + 2

private def setRefEvaluated(key: String, lzy: EvalM[Long]): EvalM[Long] =
update(lets.modify(_)(_.updated(key, (true, lzy))))
update(ctx => ctx.copy(letDefs = ctx.letDefs.updated(key, (true, lzy))))
.flatMap(_ => lzy)

private def evalGetter(expr: EXPR): EvalM[Long] =
Expand All @@ -94,26 +93,36 @@ object ScriptEstimatorV2 extends ScriptEstimator {
private def evalFuncCall(header: FunctionHeader, args: List[EXPR]): EvalM[Long] =
for {
ctx <- get[Id, EstimatorContext, EstimationError]
bodyComplexity <- predefFuncs
.get(ctx)
bodyComplexity <- ctx.predefFuncs
.get(header)
.map(bodyComplexity => evalFuncArgs(args).map(_ + bodyComplexity))
.orElse(userFuncs.get(ctx).get(header).map(evalUserFuncCall(_, args)))
.orElse(ctx.userFuncs.get(header).map(evalUserFuncCall(_, args)))
.getOrElse(raiseError[Id, EstimatorContext, EstimationError, Long](s"function '$header' not found"))
} yield bodyComplexity

private def evalUserFuncCall(func: FUNC, args: List[EXPR]): EvalM[Long] =
for {
argsComplexity <- evalFuncArgs(args)
ctx <- get[Id, EstimatorContext, EstimationError]
_ <- update(lets.modify(_)(_ ++ ctx.overlappedRefs))
_ <- update(ctx1 => ctx1.copy(letDefs = ctx1.letDefs ++ ctx.overlappedRefs))
overlapped = func.args.flatMap(arg => ctx.letDefs.get(arg).map((arg, _))).toMap
ctxArgs = func.args.map((_, (false, const(1)))).toMap
_ <- update((lets ~ overlappedRefs).modify(_) { case (l, or) => (l ++ ctxArgs, or ++ overlapped) })
_ <- update(ctx1 =>
ctx1.copy(
letDefs = ctx1.letDefs ++ ctxArgs,
overlappedRefs = ctx1.overlappedRefs ++ overlapped
)
)

bodyComplexity <- evalExpr(func.body).map(_ + func.args.size * 5)
evaluatedCtx <- get[Id, EstimatorContext, EstimationError]
overlappedChanges = overlapped.map { case ref @ (name, _) => evaluatedCtx.letDefs.get(name).map((name, _)).getOrElse(ref) }
_ <- update((lets ~ overlappedRefs).modify(_) { case (l, or) => (l -- ctxArgs.keys ++ overlapped, or ++ overlappedChanges) })
_ <- update(ctx1 =>
ctx1.copy(
letDefs = ctx1.letDefs -- ctxArgs.keys ++ overlapped,
overlappedRefs = ctx1.overlappedRefs ++ overlappedChanges
)
)
} yield bodyComplexity + argsComplexity

private def evalFuncArgs(args: List[EXPR]): EvalM[Long] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import com.wavesplatform.lang.v1.estimator.EstimationError
import com.wavesplatform.lang.v1.estimator.v3.EstimatorContext.EvalM
import com.wavesplatform.lang.v1.task.TaskM
import monix.eval.Coeval
import shapeless.{Lens, lens}

private[v3] case class EstimatorContext(
funcs: Map[FunctionHeader, (Coeval[Long], Set[String])],
Expand All @@ -18,9 +17,4 @@ private[v3] case class EstimatorContext(

private[v3] object EstimatorContext {
type EvalM[A] = TaskM[EstimatorContext, EstimationError, A]

object Lenses {
val funcs: Lens[EstimatorContext, Map[FunctionHeader, (Coeval[Long], Set[String])]] = lens[EstimatorContext] >> Symbol("funcs")
val usedRefs: Lens[EstimatorContext, Set[String]] = lens[EstimatorContext] >> Symbol("usedRefs")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import com.wavesplatform.lang.v1.FunctionHeader
import com.wavesplatform.lang.v1.FunctionHeader.User
import com.wavesplatform.lang.v1.compiler.Terms.*
import com.wavesplatform.lang.v1.estimator.v3.EstimatorContext.EvalM
import com.wavesplatform.lang.v1.estimator.v3.EstimatorContext.Lenses.*
import com.wavesplatform.lang.v1.estimator.{EstimationError, ScriptEstimator}
import com.wavesplatform.lang.v1.task.imports.*
import monix.eval.Coeval
Expand Down Expand Up @@ -82,7 +81,7 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
letCosts <- usedRefs.toSeq.traverse { ref =>
local {
for {
_ <- update(funcs.set(_)(startCtx.funcs))
_ <- update(ctx1 => ctx1.copy(funcs = startCtx.funcs))
cost <- ctx.globalLetEvals.getOrElse(ref, zero)
} yield cost
}
Expand All @@ -100,22 +99,18 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
}

private def beforeNextExprEval(let: LET, eval: EvalM[Long]): EvalM[Unit] =
update(ctx =>
usedRefs
.modify(ctx)(_ - let.name)
.copy(refsCosts = ctx.refsCosts + (let.name -> local(eval)))
)
update(ctx => ctx.copy(usedRefs = ctx.usedRefs - let.name, refsCosts = ctx.refsCosts + (let.name -> local(eval))))

private def afterNextExprEval(let: LET, startCtx: EstimatorContext): EvalM[Unit] =
update(ctx =>
usedRefs
.modify(ctx)(r => if (startCtx.usedRefs.contains(let.name)) r + let.name else r - let.name)
.copy(refsCosts =
ctx.copy(
usedRefs = if (startCtx.usedRefs.contains(let.name)) ctx.usedRefs + let.name else ctx.usedRefs - let.name,
refsCosts =
if (startCtx.refsCosts.contains(let.name))
ctx.refsCosts + (let.name -> startCtx.refsCosts(let.name))
else
ctx.refsCosts - let.name
)
)
)

private def evalFuncBlock(func: FUNC, nextExpr: EXPR, activeFuncArgs: Set[String], globalDeclarationsMode: Boolean): EvalM[Long] =
Expand All @@ -142,14 +137,12 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
_ <- set[Id, EstimatorContext, EstimationError](ctx.copy(globalFunctionsCosts = ctx.globalFunctionsCosts + (name -> totalCost)))
} yield ()

private def handleUsedRefs(name: String, cost: Long, ctx: EstimatorContext, refsUsedInBody: Set[String]): EvalM[Unit] =
update(
(funcs ~ usedRefs).modify(_) { case (funcs, _) =>
(
funcs + (User(name) -> (Coeval.now(cost), refsUsedInBody)),
ctx.usedRefs
)
}
private def handleUsedRefs(name: String, cost: Long, startCtx: EstimatorContext, refsUsedInBody: Set[String]): EvalM[Unit] =
update(ctx =>
ctx.copy(
funcs = ctx.funcs + (User(name) -> (Coeval.now(cost), refsUsedInBody)),
usedRefs = startCtx.usedRefs
)
)

private def evalIF(cond: EXPR, ifTrue: EXPR, ifFalse: EXPR, activeFuncArgs: Set[String]): EvalM[Long] =
Expand All @@ -165,7 +158,7 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
if (activeFuncArgs.contains(key) && letFixes)
const(overheadCost)
else
update(usedRefs.modify(_)(_ + key)).map(_ => overheadCost)
update(ctx => ctx.copy(usedRefs = ctx.usedRefs + key)).map(_ => overheadCost)

private def evalGetter(expr: EXPR, activeFuncArgs: Set[String]): EvalM[Long] =
evalExpr(expr, activeFuncArgs).flatMap(sum(_, overheadCost))
Expand All @@ -187,18 +180,15 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
} yield result

private def setFuncToCtx(header: FunctionHeader, bodyCost: Coeval[Long], bodyUsedRefs: Set[EstimationError]): EvalM[Unit] =
update(
(funcs ~ usedRefs).modify(_) { case (funcs, usedRefs) =>
(
funcs + (header -> (bodyCost, Set())),
usedRefs ++ bodyUsedRefs
)
}
update(ctx =>
ctx.copy(
funcs = ctx.funcs + (header -> (bodyCost, Set())),
usedRefs = ctx.usedRefs ++ bodyUsedRefs
)
)

private def getFuncCost(header: FunctionHeader, ctx: EstimatorContext): EvalM[(Coeval[Long], Set[EstimationError])] =
funcs
.get(ctx)
ctx.funcs
.get(header)
.map(const)
.getOrElse(
Expand All @@ -217,9 +207,9 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
): EvalM[Long] =
for {
startCtx <- get[Id, EstimatorContext, EstimationError]
_ <- ctxFuncsOpt.fold(doNothing.void)(ctxFuncs => update(funcs.set(_)(ctxFuncs)))
_ <- ctxFuncsOpt.fold(doNothing.void)(ctxFuncs => update(ctx => ctx.copy(funcs = ctxFuncs)))
cost <- evalExpr(expr, activeFuncArgs)
_ <- update(funcs.set(_)(startCtx.funcs))
_ <- update(ctx => ctx.copy(funcs = startCtx.funcs))
} yield cost

private def withUsedRefs[A](eval: EvalM[A]): EvalM[(A, Set[String])] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import com.wavesplatform.lang.v1.compiler.Terms.*
import com.wavesplatform.lang.v1.compiler.Types.{CASETYPEREF, NOTHING}
import com.wavesplatform.lang.v1.evaluator.ContextfulNativeFunction.{Extended, Simple}
import com.wavesplatform.lang.v1.evaluator.ctx.*
import com.wavesplatform.lang.v1.evaluator.ctx.EnabledLogEvaluationContext.Lenses
import com.wavesplatform.lang.v1.task.imports.*
import com.wavesplatform.lang.v1.traits.Environment
import com.wavesplatform.lang.*
Expand All @@ -30,16 +29,16 @@ object EvaluatorV1 {
}

class EvaluatorV1[F[_]: Monad, C[_[_]]](implicit ev: Monad[EvalF[F, *]], ev2: Monad[CoevalF[F, *]]) {
private val lenses = new Lenses[F, C]
import lenses.*

private def evalLetBlock(let: LET, inner: EXPR): EvalM[F, C, (EvaluationContext[C, F], EVALUATED)] =
for {
ctx <- get[F, EnabledLogEvaluationContext[C, F], ExecutionError]
blockEvaluation = evalExpr(let.value)
lazyBlock = LazyVal(blockEvaluation.ter(ctx), ctx.l(let.name))
result <- local {
modify[F, EnabledLogEvaluationContext[C, F], ExecutionError](lets.modify(_)(_.updated(let.name, lazyBlock)))
modify[F, EnabledLogEvaluationContext[C, F], ExecutionError](ctx1 =>
ctx1.copy(ec = ctx1.ec.copy(letDefs = ctx1.ec.letDefs.updated(let.name, lazyBlock)))
)
.flatMap(_ => evalExprWithCtx(inner))
}
} yield result
Expand All @@ -49,15 +48,17 @@ class EvaluatorV1[F[_]: Monad, C[_[_]]](implicit ev: Monad[EvalF[F, *]], ev2: Mo
val function = UserFunction(func.name, 0, NOTHING, func.args.map(n => (n, NOTHING))*)(func.body)
.asInstanceOf[UserFunction[C]]
local {
modify[F, EnabledLogEvaluationContext[C, F], ExecutionError](funcs.modify(_)(_.updated(funcHeader, function)))
modify[F, EnabledLogEvaluationContext[C, F], ExecutionError](ctx =>
ctx.copy(ec = ctx.ec.copy(functions = ctx.ec.functions.updated(funcHeader, function)))
)
.flatMap(_ => evalExprWithCtx(inner))
}
}

private def evalRef(key: String): EvalM[F, C, (EvaluationContext[C, F], EVALUATED)] =
for {
ctx <- get[F, EnabledLogEvaluationContext[C, F], ExecutionError]
r <- lets.get(ctx).get(key) match {
r <- ctx.ec.letDefs.get(key) match {
case Some(lzy) => liftTER[F, C, EVALUATED](lzy.value)
case None => raiseError[F, EnabledLogEvaluationContext[C, F], ExecutionError, EVALUATED](s"A definition of '$key' not found")
}
Expand All @@ -83,8 +84,7 @@ class EvaluatorV1[F[_]: Monad, C[_[_]]](implicit ev: Monad[EvalF[F, *]], ev2: Mo
private def evalFunctionCall(header: FunctionHeader, args: List[EXPR]): EvalM[F, C, (EvaluationContext[C, F], EVALUATED)] =
for {
ctx <- get[F, EnabledLogEvaluationContext[C, F], ExecutionError]
result <- funcs
.get(ctx)
result <- ctx.ec.functions
.get(header)
.map {
case func: UserFunction[C] =>
Expand All @@ -94,7 +94,7 @@ class EvaluatorV1[F[_]: Monad, C[_[_]]](implicit ev: Monad[EvalF[F, *]], ev2: Mo
}
local {
val newState: EvalM[F, C, Unit] =
set[F, EnabledLogEvaluationContext[C, F], ExecutionError](lets.set(ctx)(letDefsWithArgs)).map(_.pure[F])
set[F, EnabledLogEvaluationContext[C, F], ExecutionError](ctx.copy(ec = ctx.ec.copy(letDefs = letDefsWithArgs))).map(_.pure[F])
Monad[EvalM[F, C, *]].flatMap(newState)(_ => evalExpr(func.ev(ctx.ec.environment, args)))
}
}: EvalM[F, C, EVALUATED]
Expand All @@ -118,7 +118,7 @@ class EvaluatorV1[F[_]: Monad, C[_[_]]](implicit ev: Monad[EvalF[F, *]], ev2: Mo
// no such function, try data constructor
header match {
case FunctionHeader.User(typeName, _) =>
types.get(ctx).get(typeName).collect { case t @ CASETYPEREF(_, fields, _) =>
ctx.ec.typeDefs.get(typeName).collect { case t @ CASETYPEREF(_, fields, _) =>
args
.traverse[EvalM[F, C, *], EVALUATED](evalExpr)
.map(values => CaseObj(t, fields.map(_._1).zip(values).toMap): EVALUATED)
Expand Down
Loading