From d7d43fdae5a3e97d2ec312affe01349c15210ca8 Mon Sep 17 00:00:00 2001 From: pragu Date: Thu, 24 Jul 2025 12:39:43 +1000 Subject: [PATCH 1/6] Add analysis for interprodeural copy-constant propagation using IDE framework, with some relevant tests --- src/main/scala/analysis/InterCopyConst.scala | 134 +++++ src/test/scala/InterCopyConstTests.scala | 503 +++++++++++++++++++ 2 files changed, 637 insertions(+) create mode 100644 src/main/scala/analysis/InterCopyConst.scala create mode 100644 src/test/scala/InterCopyConstTests.scala diff --git a/src/main/scala/analysis/InterCopyConst.scala b/src/main/scala/analysis/InterCopyConst.scala new file mode 100644 index 000000000..ced4dbb90 --- /dev/null +++ b/src/main/scala/analysis/InterCopyConst.scala @@ -0,0 +1,134 @@ +package analysis + +import analysis.solvers.ForwardIDESolver +import ir.* + +trait CopyConstAnalysisFunctions(parameterForm: Boolean) extends ForwardIDEAnalysis[Variable, FlatElement[BitVecLiteral], ConstantPropagationLattice] { + + val valuelattice: ConstantPropagationLattice = ConstantPropagationLattice() + val edgelattice: EdgeFunctionLattice[FlatElement[BitVecLiteral], ConstantPropagationLattice] = EdgeFunctionLattice(valuelattice) + import edgelattice.{IdEdge, ConstEdge} + + def edgesCallToEntry(call: DirectCall, entry: Procedure)(d: DL): Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]] = { + + //print("actual params: " + call.actualParams) + + + // below only for param form no?, otherwise just give everything + if !parameterForm then Map(d->IdEdge()) + else + d match { + case Left(a) => + call.actualParams.toList.foldLeft(Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]]()) { + case (m, (inVar, expression)) => expression match + case LocalVar(_, _, _) | Register(_, _) if expression == a => m ++ Map(Left(inVar) -> IdEdge(), d -> IdEdge()) // idk if this actually checks properly + case LocalVar(_, _, _) | Register(_, _) if expression != a && inVar != a => m ++ Map(d -> IdEdge()) + case _ => m ++ Map() + + } + case Right(a) => + val lambdaToLambda: Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]] = Map(d -> IdEdge()) + call.actualParams.toList.foldLeft(Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]]()) { + case (m, (inVar, expression)) => expression match { + case LocalVar(_,_,_) | Register(_,_) => m ++ lambdaToLambda //not add anything else from lambda + case BitVecLiteral(value, size) => m ++ lambdaToLambda ++ Map(Left(inVar)->ConstEdge(valuelattice.bv(BitVecLiteral(value, size)))) + case _ => m ++ lambdaToLambda ++ Map(Left(inVar) -> ConstEdge(valuelattice.top)) + // direct call? + + } + + } + } + + + + + } + + def edgesExitToAfterCall(exit: Return, aftercall: Command)(d: DL): Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]] = { + + //print(d) + + if !parameterForm then Map(d->IdEdge()) + else + val call: DirectCall = aftercall match { + case aftercall: Statement => aftercall.parent.statements.getPrev(aftercall).asInstanceOf[DirectCall] + case _: Jump => aftercall.parent.statements.last.asInstanceOf[DirectCall] + } + + d match { + case Left(a) => + exit.outParams.toList.foldLeft(Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]]()) { + case (m, (retVar, expression)) => expression match + case LocalVar(_, _, _) | Register(_, _) if expression == a => m ++ Map(Left(call.outParams(retVar)) -> IdEdge()) + //case LocalVar(_, _, _) | Register(_, _) if expression != a => m ++ Map(d -> IdEdge()) // lol need to fix up here in params shld always just be Map() so add case + case _ => m ++ Map() //ignore other kind of expr, including local vars / in params of the procedure + + } + case Right(a) => + val lambdaToLambda: Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]] = Map(d -> IdEdge()) + + exit.outParams.toList.foldLeft(Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]]()) { + case (m, (retVar, expression)) => expression match { + case LocalVar(_, _, _) | Register(_, _) => m ++ lambdaToLambda //not add anything else from lambda + case BitVecLiteral(value, size) => m ++ lambdaToLambda ++ Map(Left(call.outParams(retVar)) -> ConstEdge(valuelattice.bv(BitVecLiteral(value, size)))) + case _ => m ++ lambdaToLambda ++ Map(Left(call.outParams(retVar)) -> ConstEdge(valuelattice.top)) + + } + + } + + + } + } + + def edgesCallToAfterCall(call: DirectCall, aftercall: Command)(d: DL): Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]] = { + + if !parameterForm then Map() + else // unused locals in function and lambda need identity + d match { + case Left(v) if (call.outParams.exists(_._2 == v) || call.actualParams.exists(_._2 == v)) => Map() + case _ => Map(d->IdEdge()) + + } + // currently every global going into each procedure regardless if going to be modified or not - way to check?? + } + + def edgesOther(n: CFGPosition)(d: DL): Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]] = { + + + n match { + case LocalAssign(variable, expression, _) => + // shld make this function icl --> figure out if can just put all under one + d match { + case Right(_) => + val lambdaToLambda : Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]] = Map(d -> IdEdge()) + expression match { + case LocalVar(_,_,_) | Register(_,_) => lambdaToLambda //not add anything else from lambda + case BitVecLiteral(value, size) => lambdaToLambda ++ Map(Left(variable)->ConstEdge(valuelattice.bv(BitVecLiteral(value, size)))) + case _ => lambdaToLambda ++ Map(Left(variable) -> ConstEdge(valuelattice.top)) + } + + case Left(a) => + expression match { + case LocalVar(_, _, _) | Register(_,_) if expression == a => Map(Left(variable)->IdEdge(), d ->IdEdge()) // idk if this actually checks properly + //case LocalVar(_, _, _) | Register(_,_) if expression != a => Map(d->IdEdge()) + case BitVecLiteral(_, _) => Map() //remove old value + case _ => Map(d->IdEdge()) + //case _ => Map() //ignore other kind of expr, should this be top or nothing <-- or d is identity, var is top?? + + } + + } + case MemoryLoad(variable, _, _, _, _, _)=> //might have ti fix this up + d match { + case Left(_) => Map(d -> IdEdge()) + case Right(_) => Map(Left(variable) -> ConstEdge(valuelattice.top), d -> IdEdge()) + } + + case _ => Map(d->IdEdge()) + } + } +} + +class InterCopyConst(program:Program, parameterForm: Boolean) extends ForwardIDESolver[Variable, FlatElement[BitVecLiteral], ConstantPropagationLattice](program), CopyConstAnalysisFunctions(parameterForm) diff --git a/src/test/scala/InterCopyConstTests.scala b/src/test/scala/InterCopyConstTests.scala new file mode 100644 index 000000000..f2bba7648 --- /dev/null +++ b/src/test/scala/InterCopyConstTests.scala @@ -0,0 +1,503 @@ +import analysis.* +import ir.* +import org.scalatest.funsuite.AnyFunSuite +import test_util.CaptureOutput +import cilvisitor.* +import ir.dsl.* + +@test_util.tags.UnitTest +class InterCopyConstTests extends AnyFunSuite, CaptureOutput { + + + def getInterCopyConstResults(program: Program, paramaterForm: Boolean): Unit = { + print(InterCopyConst(program, paramaterForm).analyze()) + } + + test("intraCopyAssign") { + // all good + val program = prog( + proc("main", block("main", LocalAssign(R0, bv64(3)), LocalAssign(R1, R0), goto("mainRet")), block("mainRet", ret)), + ) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + + getInterCopyConstResults(program, false) + } + + test("intraOldTransform") { + val program = prog( + proc("main", + Seq(), + Seq("R0_out" -> BitVecType(64), "R1_out" -> BitVecType(64)), + block("main", + LocalAssign(LocalVar("R0", BitVecType(64), 0), bv64(3)), + LocalAssign(LocalVar("R1", BitVecType(64), 0), LocalVar("R0", BitVecType(64), 0)), + goto("mainRet") + ), + block("mainRet", + ret( + "R0_out" -> LocalVar("R0", BitVecType(64), 0), + "R1_out" -> LocalVar("R1", BitVecType(64), 0) + ) + ) + ) + ) + transforms.copyPropParamFixedPoint(program, Map()) + print(program) + } + + test("intraNewTransform") { + val program = prog( + proc("main", + Seq(), + Seq("R0_out" -> BitVecType(64), "R1_out" -> BitVecType(64)), + block("main", + LocalAssign(LocalVar("R0", BitVecType(64), 0), bv64(3)), + LocalAssign(LocalVar("R1", BitVecType(64), 0), LocalVar("R0", BitVecType(64), 0)), + goto("mainRet") + ), + block("mainRet", + ret( + "R0_out" -> LocalVar("R0", BitVecType(64), 0), + "R1_out" -> LocalVar("R1", BitVecType(64), 0) + ) + ) + ) + ) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + + + print(program) + visit_prog(transforms.ConstCopyPropTransform(program), program) + print(program) + } + + test("intraAssignOverride") { + // all good + val program = prog( + proc("main", block("main", LocalAssign(R0, bv64(3)), LocalAssign(R0, bv64(5)), goto("mainRet")), block("mainRet", ret)), + ) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + + println(program) + + getInterCopyConstResults(program, false) + } + + test("intraNonCopyAssignOverride") { + // all good + val program = prog( + proc("main", block("main", LocalAssign(R0, bv64(3)), LocalAssign(R0, BinaryExpr(BVADD, R0, bv64(1))), LocalAssign(R1, R0), goto("mainRet")), block("mainRet", ret)), + ) + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + print(program) + + getInterCopyConstResults(program, false) + + visit_prog(transforms.ConstCopyPropTransform(program), program) + print(program) + } + + + test("intraNonCopyAssign") { + val program = prog( + proc("main", block("main", LocalAssign(R1, bv64(3)), LocalAssign(R0, BinaryExpr(BVADD, R0, bv64(1))), goto("mainRet")), block("mainRet", ret)), + ) + println(program) + + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + + getInterCopyConstResults(program, false) + } + + + test("interAssign") { + + val program = prog( + proc("main", + block("main", directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("in" -> bv64(7))), goto("mainRet")), + block("mainRet", ret) + ), + proc("f", + Seq("in" -> BitVecType(64)), Seq("p_out" -> BitVecType(64)), + block("okay", goto("f_ret")), + block("f_ret", ret("p_out" -> bv64(5))), + ) + + ) + + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + + println(program) + + getInterCopyConstResults(program, true) + + //visit_prog(transforms.ConstCopyPropTransform(program), program) + + transforms.copyPropParamFixedPoint(program, Map()) + print(program) + } + + test("interMultipleOutParams") { + + val program = prog( + proc("main", + block("main", directCall(Seq("p_out_1" -> LocalVar("v", BitVecType(64)), "p_out_2" -> LocalVar("a", BitVecType(64))), "f", Seq("in" -> bv64(7))), goto("mainRet")), + block("mainRet", ret) + ), + proc("f", + Seq("in" -> BitVecType(64)), Seq("p_out_1" -> BitVecType(64), "p_out_2" -> BitVecType(64)), + block("okay", goto("f_ret")), + block("f_ret", ret("p_out_1" -> bv64(5), "p_out_2" -> bv64(1))), + ) + + ) + + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + + println(program) + + getInterCopyConstResults(program, true) + + + transforms.copyPropParamFixedPoint(program, Map()) + + //visit_prog(transforms.ConstCopyPropTransform(program), program) + print(program) + } + + + test("interCallProcNoOutParam") { + + val program = prog( + proc("main", + block("main", directCall("f"), goto("mainRet")), + block("mainRet", ret) + ), + proc("f", + block("okay", LocalAssign(LocalVar("R0", bv64), bv64(3)) ,goto("f_ret")), + block("f_ret", ret()), + ) + + ) + + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + + println(program) + + getInterCopyConstResults(program, true) + + visit_prog(transforms.ConstCopyPropTransform(program), program) + print(program) + } + + + test("interReturnNonConstant") { + + val program = prog( + proc("main", + block("main", directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("in" -> bv64(7))), goto("mainRet")), + block("mainRet", ret) + ), + proc("f", + Seq("in" -> BitVecType(64)), Seq("p_out" -> BitVecType(64)), + block("okay", goto("f_ret")), + block("f_ret", ret("p_out" -> BinaryExpr(BVADD, bv64(2), bv64(1)))), + ) + + ) + + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + + println(program) + + + getInterCopyConstResults(program, true) + + //cilvisitor.visit_prog(transforms.ConstCopyPropTransform(program), program) + transforms.copyPropParamFixedPoint(program, Map()) + + print(program) + } + + test("interReturnOneNonConstantOneConstant") { + + + + val program = prog( + proc("main", + block("main", directCall(Seq("p_out_1" -> LocalVar("a", BitVecType(64)), "p_out_2" -> LocalVar("b", BitVecType(64))), "f", Seq("in" -> bv64(7))), goto("mainRet")), + block("mainRet", ret) + ), + proc("f", + Seq("in" -> BitVecType(64)), Seq("p_out_1" -> BitVecType(64), "p_out_2" -> BitVecType(64)), + block("okay", goto("f_ret")), + block("f_ret", ret("p_out_1" -> BinaryExpr(BVADD, LocalVar("in", BitVecType(64)), bv64(1)) , "p_out_2" -> bv64(1))), + ) + + ) + + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + + println(program) + + + getInterCopyConstResults(program, true) + + //cilvisitor.visit_prog(transforms.ConstCopyPropTransform(program), program) + transforms.copyPropParamFixedPoint(program, Map()) + print(program) + } + + test("interAssignOverride") { + val program = prog( + proc("main", + block("main", LocalAssign(LocalVar("v", BitVecType(64)), bv64(2)), directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("in" -> bv64(7))), goto("mainRet")), + block("mainRet", ret) + ), + proc("f", + Seq("in" -> BitVecType(64)), Seq("p_out" -> BitVecType(64)), + block("okay", goto("f_ret")), + block("f_ret", ret("p_out" -> bv64(5))), + ) + + ) + + + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + + + println(program) + + + getInterCopyConstResults(program, true) + } + + + test("interReturnLocalVar") { + val program = prog( + proc("main", + block("main", directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("p_in" -> bv64(7))), goto("mainRet")), + block("mainRet", ret) + ), + proc("f", + Seq("p_in" -> BitVecType(64)), Seq("p_out" -> BitVecType(64)), + block("okay", LocalAssign(LocalVar("bruh", BitVecType(64)), bv64(6)), goto("f_ret")), + block("f_ret", ret("p_out" -> LocalVar("bruh", BitVecType(64)))) + ) + + ) + + + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + + + println(program) + + + getInterCopyConstResults(program, true) + } + + test("interReturnInParam") { + val program = prog( + proc("main", + block("main", directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("p_in" -> bv64(7))), goto("mainRet")), + block("mainRet", ret) + ), + proc("f", + Seq("p_in" -> BitVecType(64)), Seq("p_out" -> BitVecType(64)), + block("okay", goto("f_ret")), + block("f_ret", ret("p_out" -> LocalVar("p_in", BitVecType(64)))) + ) + + ) + + + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + + + println(program) + + + getInterCopyConstResults(program, true) + } + + + + + test("unusedLocalVar") { + val program = prog( + proc("main", + block("main", LocalAssign(LocalVar("unused", BitVecType(64)), bv64(2)), directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("in" -> bv64(7))), goto("mainRet")), + block("mainRet", ret) + ), + proc("f", + Seq("in" -> BitVecType(64)), Seq("p_out" -> BitVecType(64)), + block("okay", goto("f_ret")), + block("f_ret", ret("p_out" -> bv64(5))), + ) + + ) + + + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + + println(program) + + getInterCopyConstResults(program, true) + } + + test("callProcTwiceDiffContext") { + val program = prog( + proc("main", + block("main", directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("p_in" -> bv64(3))), goto("main2")), + block("main2", directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("p_in" -> bv64(7))), goto("mainRet")), + block("mainRet", ret) + ), + proc("f", + Seq("p_in" -> BitVecType(64)), Seq("p_out" -> BitVecType(64)), + block("okay", goto("f_ret")), + block("f_ret", ret("p_out" -> LocalVar("p_in", BitVecType(64)))) + ) + + ) + + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + + println(program) + + getInterCopyConstResults(program, true) + } + + test("callProcWithinProc") { + val program = prog( + proc("main", + block("main", directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("in" -> bv64(7))), goto("mainRet")), + block("mainRet", ret) + ), + proc("f", + Seq("in" -> BitVecType(64)), Seq("p_out" -> BitVecType(64)), + block("okay", goto("f_ret")), // add another procedure here !!! + block("f_ret", ret("p_out" -> bv64(5))), + ) + + ) + + + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + + + println(program) + + + getInterCopyConstResults(program, true) + } + + test ("waht") { + val program = prog( + proc("main", + Seq( + "R0_in" -> BitVecType(64), + "R10_in" -> BitVecType(64), + "R11_in" -> BitVecType(64), + "R12_in" -> BitVecType(64), + "R13_in" -> BitVecType(64), + "R14_in" -> BitVecType(64), + "R15_in" -> BitVecType(64), + "R16_in" -> BitVecType(64), + "R17_in" -> BitVecType(64), + "R18_in" -> BitVecType(64), + "R1_in" -> BitVecType(64), + "R29_in" -> BitVecType(64), + "R2_in" -> BitVecType(64), + "R30_in" -> BitVecType(64), + "R31_in" -> BitVecType(64), + "R3_in" -> BitVecType(64), + "R4_in" -> BitVecType(64), + "R5_in" -> BitVecType(64), + "R6_in" -> BitVecType(64), + "R7_in" -> BitVecType(64), + "R8_in" -> BitVecType(64), + "R9_in" -> BitVecType(64), + "_PC_in" -> BitVecType(64) + ), + Seq( + "R0_out" -> BitVecType(64), + "_PC_out" -> BitVecType(64) + ), + block("main_entry", + LocalAssign(LocalVar("R0", BitVecType(64), 0), LocalVar("R0_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R10", BitVecType(64), 0), LocalVar("R10_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R11", BitVecType(64), 0), LocalVar("R11_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R12", BitVecType(64), 0), LocalVar("R12_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R13", BitVecType(64), 0), LocalVar("R13_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R14", BitVecType(64), 0), LocalVar("R14_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R15", BitVecType(64), 0), LocalVar("R15_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R16", BitVecType(64), 0), LocalVar("R16_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R17", BitVecType(64), 0), LocalVar("R17_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R18", BitVecType(64), 0), LocalVar("R18_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R1", BitVecType(64), 0), LocalVar("R1_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R29", BitVecType(64), 0), LocalVar("R29_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R2", BitVecType(64), 0), LocalVar("R2_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R30", BitVecType(64), 0), LocalVar("R30_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R31", BitVecType(64), 0), LocalVar("R31_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R3", BitVecType(64), 0), LocalVar("R3_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R4", BitVecType(64), 0), LocalVar("R4_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R5", BitVecType(64), 0), LocalVar("R5_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R6", BitVecType(64), 0), LocalVar("R6_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R7", BitVecType(64), 0), LocalVar("R7_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R8", BitVecType(64), 0), LocalVar("R8_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R9", BitVecType(64), 0), LocalVar("R9_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("_PC", BitVecType(64), 0), LocalVar("_PC_in", BitVecType(64), 0), None), + LocalAssign(LocalVar("R0", BitVecType(64), 0), BitVecLiteral(BigInt("2"), 64), Some("4195968_0")), + goto("main_basil_return_1") + ), + block("main_basil_return_1", + ret( + "R0_out" -> LocalVar("R0", BitVecType(64), 0), + "_PC_out" -> LocalVar("_PC", BitVecType(64), 0) + ) + ) + ) + ) + + cilvisitor.visit_prog(transforms.ReplaceReturns(), program) + transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + + println(program) + getInterCopyConstResults(program, true) + } + +} \ No newline at end of file From b666a80e34ed17abdf72f3a6d2ab3c9eaa101c46 Mon Sep 17 00:00:00 2001 From: pragu Date: Thu, 24 Jul 2025 12:42:14 +1000 Subject: [PATCH 2/6] Add transform of IR/IL program using results of constant-copy propagation - does not handle keeping function call for non-constant out params yet --- .../transforms/ConstCopyPropTransform.scala | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 src/main/scala/ir/transforms/ConstCopyPropTransform.scala diff --git a/src/main/scala/ir/transforms/ConstCopyPropTransform.scala b/src/main/scala/ir/transforms/ConstCopyPropTransform.scala new file mode 100644 index 000000000..de2764664 --- /dev/null +++ b/src/main/scala/ir/transforms/ConstCopyPropTransform.scala @@ -0,0 +1,55 @@ +package ir.transforms + +import ir.* +import ir.cilvisitor.* +import analysis.* + +class ConstCopyPropTransform(p: Program) extends CILVisitor{ + val results: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]] = InterCopyConst(p, true).analyze() + + override def vstmt(e: Statement): VisitAction[List[Statement]] = { + + + e match { + case l: LocalAssign => + val absState: FlatElement[BitVecLiteral] = results.get(e.successor).get(l.lhs) + + l.rhs match { + case LocalVar(_,_,_) | Register(_,_) if absState != Top | absState != Bottom => + ChangeTo(List(LocalAssign(l.lhs, get_bv(absState)))) + + case _ => SkipChildren() + } + + case d: DirectCall if d.outParams.nonEmpty => // change so chekc if there are any acc outparams + print(d) + + val vars: List[Variable] = d.outParams.values.toList + val changed: List[Statement] = vars.foldLeft(List[Statement]()) { + case (l, lhs) => + val absState: FlatElement[BitVecLiteral] = results.get(e.successor).get(lhs) + if absState != Top | absState != Bottom then l ++ List(LocalAssign(lhs, get_bv(absState))) + else l + // need t ohandle when some top some not!! <-- unchanged and changed list!! + } + + //val untransformedParams = d.outParams.values.toList.foldLeft(List[]) + + //val transformed = changed ++ List(d) + + ChangeTo(changed) // need better name than cchangfed LMAO + + case _ => SkipChildren() + } + } +} + + + +def get_bv(a: FlatElement[BitVecLiteral]): BitVecLiteral = + a match + case FlatEl(x) => x + case _ => BitVecLiteral(0,0) // shldnt get here, idk what default to iuse + + + From 43a37e0033855cf6c80cdac66148cbbda029735b Mon Sep 17 00:00:00 2001 From: pragu Date: Fri, 25 Jul 2025 15:54:37 +1000 Subject: [PATCH 3/6] Modify transform to remove redundant outparams for procedure calls and procedures themselves if constant --- .../transforms/ConstCopyPropTransform.scala | 74 ++++++++++++++----- 1 file changed, 56 insertions(+), 18 deletions(-) diff --git a/src/main/scala/ir/transforms/ConstCopyPropTransform.scala b/src/main/scala/ir/transforms/ConstCopyPropTransform.scala index de2764664..cf3282960 100644 --- a/src/main/scala/ir/transforms/ConstCopyPropTransform.scala +++ b/src/main/scala/ir/transforms/ConstCopyPropTransform.scala @@ -4,52 +4,90 @@ import ir.* import ir.cilvisitor.* import analysis.* +import scala.collection.immutable.SortedMap + + +/** + * Transforms program by modifying assignments to local variables and procedure calls to constants if possible, as + * determined by copy-constant analysis (using the IDE framework). Procedure calls are modified to remove redundant + * out parameters if they always return a constant value. + */ class ConstCopyPropTransform(p: Program) extends CILVisitor{ val results: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]] = InterCopyConst(p, true).analyze() + private var removedFormalOutParams: Set[LocalVar] = Set() override def vstmt(e: Statement): VisitAction[List[Statement]] = { e match { case l: LocalAssign => - val absState: FlatElement[BitVecLiteral] = results.get(e.successor).get(l.lhs) + val absState: FlatElement[BitVecLiteral] = results(e.successor)(l.lhs) l.rhs match { case LocalVar(_,_,_) | Register(_,_) if absState != Top | absState != Bottom => - ChangeTo(List(LocalAssign(l.lhs, get_bv(absState)))) + ChangeTo(List(LocalAssign(l.lhs, get_bv(absState).get))) //replace rhs with constant case _ => SkipChildren() } - case d: DirectCall if d.outParams.nonEmpty => // change so chekc if there are any acc outparams - print(d) - val vars: List[Variable] = d.outParams.values.toList + case d: DirectCall if d.outParams.nonEmpty => + + val vars: List[LocalVar] = d.outParams.keys.toList val changed: List[Statement] = vars.foldLeft(List[Statement]()) { - case (l, lhs) => - val absState: FlatElement[BitVecLiteral] = results.get(e.successor).get(lhs) - if absState != Top | absState != Bottom then l ++ List(LocalAssign(lhs, get_bv(absState))) - else l - // need t ohandle when some top some not!! <-- unchanged and changed list!! + case (l, formalOutParam) => + val actualOutParam = d.outParams.getOrElse(formalOutParam, LocalVar("placeholder", BitVecType(64))) + val absState: FlatElement[BitVecLiteral] = results(d.successor)(actualOutParam) + + + + if results(d.target.returnBlock.get.jump)(formalOutParam) != Top then //outParam from procedure always constant + + d.outParams = d.outParams.removed(formalOutParam) + d.target.formalOutParam.remove(formalOutParam) //remove from called procedure + removedFormalOutParams = removedFormalOutParams + formalOutParam + l ++ List(LocalAssign(actualOutParam, get_bv(absState).get)) // add assignment + + else if absState != Top & absState != Bottom then //outParam from procedure constant for this call + + d.outParams = d.outParams.removed(formalOutParam) //remove assignment of x = f(y) --> f(y) alone + l ++ List(LocalAssign(actualOutParam, get_bv(absState).get)) // add assignment without changing function + + else l + } - - //val untransformedParams = d.outParams.values.toList.foldLeft(List[]) - //val transformed = changed ++ List(d) + val transformed = changed ++ List(d) - ChangeTo(changed) // need better name than cchangfed LMAO + ChangeTo(transformed) case _ => SkipChildren() } } -} + override def vjump(j: Jump): VisitAction[Jump] = { + j match { + case r: Return => + r.outParams = r.outParams.foldLeft(SortedMap[LocalVar, Expr]()) { + case (m, (l, e)) => + if removedFormalOutParams.contains(l) then m else m ++ Map(l->e) + // remove return params which are no longer needed + } + case _ => + } + SkipChildren() + } +} -def get_bv(a: FlatElement[BitVecLiteral]): BitVecLiteral = +/** + * Extract actual BitVecLiteral from given FlatElement of lattice. Do not use unless it is known that the FlatElement + * contains a BitVecLiteral and not Top/Bottom + */ +def get_bv(a: FlatElement[BitVecLiteral]): Option[BitVecLiteral] = a match - case FlatEl(x) => x - case _ => BitVecLiteral(0,0) // shldnt get here, idk what default to iuse + case FlatEl(x) => Some(x) + case _ => None // SHOULD BE UNREACHABLE From 6472372f183e3d2d03708cf3492babf6aadf88c3 Mon Sep 17 00:00:00 2001 From: pragu Date: Fri, 25 Jul 2025 15:55:51 +1000 Subject: [PATCH 4/6] Map formal out params of procedures to abstract states --- src/main/scala/analysis/InterCopyConst.scala | 84 ++++++++++++++------ 1 file changed, 61 insertions(+), 23 deletions(-) diff --git a/src/main/scala/analysis/InterCopyConst.scala b/src/main/scala/analysis/InterCopyConst.scala index ced4dbb90..e4277a397 100644 --- a/src/main/scala/analysis/InterCopyConst.scala +++ b/src/main/scala/analysis/InterCopyConst.scala @@ -11,29 +11,26 @@ trait CopyConstAnalysisFunctions(parameterForm: Boolean) extends ForwardIDEAnaly def edgesCallToEntry(call: DirectCall, entry: Procedure)(d: DL): Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]] = { - //print("actual params: " + call.actualParams) - - // below only for param form no?, otherwise just give everything if !parameterForm then Map(d->IdEdge()) else d match { - case Left(a) => + case Left(a) => // already existing variables + call.actualParams.toList.foldLeft(Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]]()) { case (m, (inVar, expression)) => expression match - case LocalVar(_, _, _) | Register(_, _) if expression == a => m ++ Map(Left(inVar) -> IdEdge(), d -> IdEdge()) // idk if this actually checks properly + case LocalVar(_, _, _) | Register(_, _) if expression == a => m ++ Map(Left(inVar) -> IdEdge(), d -> IdEdge()) case LocalVar(_, _, _) | Register(_, _) if expression != a && inVar != a => m ++ Map(d -> IdEdge()) case _ => m ++ Map() } case Right(a) => val lambdaToLambda: Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]] = Map(d -> IdEdge()) - call.actualParams.toList.foldLeft(Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]]()) { + call.actualParams.toList.foldLeft(Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]]()) { case (m, (inVar, expression)) => expression match { case LocalVar(_,_,_) | Register(_,_) => m ++ lambdaToLambda //not add anything else from lambda - case BitVecLiteral(value, size) => m ++ lambdaToLambda ++ Map(Left(inVar)->ConstEdge(valuelattice.bv(BitVecLiteral(value, size)))) + case BitVecLiteral(value, size) => m ++ lambdaToLambda ++ Map(Left(inVar)->ConstEdge(valuelattice.bv(BitVecLiteral(value, size)))) //assign val to in param case _ => m ++ lambdaToLambda ++ Map(Left(inVar) -> ConstEdge(valuelattice.top)) - // direct call? } @@ -47,7 +44,6 @@ trait CopyConstAnalysisFunctions(parameterForm: Boolean) extends ForwardIDEAnaly def edgesExitToAfterCall(exit: Return, aftercall: Command)(d: DL): Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]] = { - //print(d) if !parameterForm then Map(d->IdEdge()) else @@ -61,7 +57,6 @@ trait CopyConstAnalysisFunctions(parameterForm: Boolean) extends ForwardIDEAnaly exit.outParams.toList.foldLeft(Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]]()) { case (m, (retVar, expression)) => expression match case LocalVar(_, _, _) | Register(_, _) if expression == a => m ++ Map(Left(call.outParams(retVar)) -> IdEdge()) - //case LocalVar(_, _, _) | Register(_, _) if expression != a => m ++ Map(d -> IdEdge()) // lol need to fix up here in params shld always just be Map() so add case case _ => m ++ Map() //ignore other kind of expr, including local vars / in params of the procedure } @@ -74,6 +69,7 @@ trait CopyConstAnalysisFunctions(parameterForm: Boolean) extends ForwardIDEAnaly case BitVecLiteral(value, size) => m ++ lambdaToLambda ++ Map(Left(call.outParams(retVar)) -> ConstEdge(valuelattice.bv(BitVecLiteral(value, size)))) case _ => m ++ lambdaToLambda ++ Map(Left(call.outParams(retVar)) -> ConstEdge(valuelattice.top)) + } } @@ -85,13 +81,12 @@ trait CopyConstAnalysisFunctions(parameterForm: Boolean) extends ForwardIDEAnaly def edgesCallToAfterCall(call: DirectCall, aftercall: Command)(d: DL): Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]] = { if !parameterForm then Map() - else // unused locals in function and lambda need identity + else d match { - case Left(v) if (call.outParams.exists(_._2 == v) || call.actualParams.exists(_._2 == v)) => Map() - case _ => Map(d->IdEdge()) + case Left(v) if call.outParams.exists(_._1 == v) || call.outParams.exists(_._2 == v) || call.actualParams.exists(_._2 == v) => Map() + case _ => Map(d->IdEdge()) // unused locals in function ignore proc call } - // currently every global going into each procedure regardless if going to be modified or not - way to check?? } def edgesOther(n: CFGPosition)(d: DL): Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]] = { @@ -99,36 +94,79 @@ trait CopyConstAnalysisFunctions(parameterForm: Boolean) extends ForwardIDEAnaly n match { case LocalAssign(variable, expression, _) => - // shld make this function icl --> figure out if can just put all under one d match { case Right(_) => val lambdaToLambda : Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]] = Map(d -> IdEdge()) expression match { case LocalVar(_,_,_) | Register(_,_) => lambdaToLambda //not add anything else from lambda case BitVecLiteral(value, size) => lambdaToLambda ++ Map(Left(variable)->ConstEdge(valuelattice.bv(BitVecLiteral(value, size)))) + case BinaryExpr(op, arg1, arg2) => (arg1, arg2) match { + case (BitVecLiteral(value1,size1), BitVecLiteral(value2, size2)) => op match { + //started to add evaluations of simple expressions before assignment e.g. x = 1+2 will be evaluated as x->3 instead of x->Top + case BVADD => lambdaToLambda ++ Map(Left(variable)->ConstEdge(valuelattice.bvadd(valuelattice.bv(BitVecLiteral(value1, size1)), valuelattice.bv(BitVecLiteral(value2, size2))))) + case BVSUB => lambdaToLambda ++ Map(Left(variable)->ConstEdge(valuelattice.bvsub(valuelattice.bv(BitVecLiteral(value1, size1)), valuelattice.bv(BitVecLiteral(value2, size2))))) + case BVMUL => lambdaToLambda ++ Map(Left(variable)->ConstEdge(valuelattice.bvmul(valuelattice.bv(BitVecLiteral(value1, size1)), valuelattice.bv(BitVecLiteral(value2, size2))))) + case _ => lambdaToLambda ++ Map(Left(variable) -> ConstEdge(valuelattice.top)) + + } + case _ => lambdaToLambda ++ Map(Left(variable) -> ConstEdge(valuelattice.top)) + } case _ => lambdaToLambda ++ Map(Left(variable) -> ConstEdge(valuelattice.top)) - } + } + case Left(a) => expression match { - case LocalVar(_, _, _) | Register(_,_) if expression == a => Map(Left(variable)->IdEdge(), d ->IdEdge()) // idk if this actually checks properly - //case LocalVar(_, _, _) | Register(_,_) if expression != a => Map(d->IdEdge()) - case BitVecLiteral(_, _) => Map() //remove old value - case _ => Map(d->IdEdge()) - //case _ => Map() //ignore other kind of expr, should this be top or nothing <-- or d is identity, var is top?? + case LocalVar(_, _, _) | Register(_,_) if expression == a => Map(Left(variable)->IdEdge(), d ->IdEdge()) + case _ => Map() //ignore other kind of expr } } - case MemoryLoad(variable, _, _, _, _, _)=> //might have ti fix this up + + case MemoryLoad(variable, _, _, _, _, _)=> d match { case Left(_) => Map(d -> IdEdge()) case Right(_) => Map(Left(variable) -> ConstEdge(valuelattice.top), d -> IdEdge()) } + case exit: Return => + // needed to map abstract states of formal in and out parameters and whether they constant in all calls + + Map(d->IdEdge()) + d match { + case Left(a) => + exit.outParams.toList.foldLeft(Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]]()) { + case (m, (retVar, expression)) => expression match + case LocalVar(_, _, _) | Register(_, _) if expression == a => m ++ Map(Left(retVar) -> IdEdge()) + case _ => m ++ Map() + + } + case Right(a) => + val lambdaToLambda: Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]] = Map(d -> IdEdge()) + + exit.outParams.toList.foldLeft(Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]]()) { + case (m, (retVar, expression)) => expression match { + case LocalVar(_, _, _) | Register(_, _) => m ++ lambdaToLambda //not add anything else from lambda + case BitVecLiteral(value, size) => m ++ lambdaToLambda ++ Map(Left(retVar) -> ConstEdge(valuelattice.bv(BitVecLiteral(value, size)))) + case _ => m ++ lambdaToLambda ++ Map(Left(retVar) -> ConstEdge(valuelattice.top)) + + } + + } + } + case _ => Map(d->IdEdge()) } + } } -} + +/** + * Performs copy-constant propagation analysis on a program. Determines the variables with a constant value, thus + * providing information for relevant transforms to replace function calls and assignments to variables as assignments + * to constants. Note that only information for copy assignments is determined, to allow for distributivity and use of + * IDE solver. 'parameterForm' may be unnecessary as this analysis and corresponding transform may only occur after + * parameter simplifications have been done. + */ class InterCopyConst(program:Program, parameterForm: Boolean) extends ForwardIDESolver[Variable, FlatElement[BitVecLiteral], ConstantPropagationLattice](program), CopyConstAnalysisFunctions(parameterForm) From 329438d79740119e5837f98941cc01d23d16cfc9 Mon Sep 17 00:00:00 2001 From: pragu Date: Fri, 25 Jul 2025 17:35:25 +1000 Subject: [PATCH 5/6] Add assertations for unit tests --- src/test/scala/InterCopyConstTests.scala | 396 ++++------------------- 1 file changed, 60 insertions(+), 336 deletions(-) diff --git a/src/test/scala/InterCopyConstTests.scala b/src/test/scala/InterCopyConstTests.scala index f2bba7648..712b49066 100644 --- a/src/test/scala/InterCopyConstTests.scala +++ b/src/test/scala/InterCopyConstTests.scala @@ -4,17 +4,31 @@ import org.scalatest.funsuite.AnyFunSuite import test_util.CaptureOutput import cilvisitor.* import ir.dsl.* +import ir.transforms.ConstCopyPropTransform + +/** + * Unit tests for copy-constant propagation. Note: tests of the relevant transform have not been done (idk how to test + * it, but it seems to work) + * */ @test_util.tags.UnitTest class InterCopyConstTests extends AnyFunSuite, CaptureOutput { + /** + * Extract actual BitVecLiteral from given FlatElement of lattice. Do not use unless it is known that the FlatElement + * contains a BitVecLiteral and not Top/Bottom + */ + def get_bv(a: FlatElement[BitVecLiteral]): Option[BitVecLiteral] = + a match + case FlatEl(x) => Some(x) + case _ => None // SHOULD BE UNREACHABLE + - def getInterCopyConstResults(program: Program, paramaterForm: Boolean): Unit = { - print(InterCopyConst(program, paramaterForm).analyze()) + def getInterCopyConstResults(program: Program, paramaterForm: Boolean): Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]] = { + InterCopyConst(program, paramaterForm).analyze() } test("intraCopyAssign") { - // all good val program = prog( proc("main", block("main", LocalAssign(R0, bv64(3)), LocalAssign(R1, R0), goto("mainRet")), block("mainRet", ret)), ) @@ -22,59 +36,17 @@ class InterCopyConstTests extends AnyFunSuite, CaptureOutput { transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) - getInterCopyConstResults(program, false) - } + val f = program.nameToProcedure("main") + val results = getInterCopyConstResults(program, false) - test("intraOldTransform") { - val program = prog( - proc("main", - Seq(), - Seq("R0_out" -> BitVecType(64), "R1_out" -> BitVecType(64)), - block("main", - LocalAssign(LocalVar("R0", BitVecType(64), 0), bv64(3)), - LocalAssign(LocalVar("R1", BitVecType(64), 0), LocalVar("R0", BitVecType(64), 0)), - goto("mainRet") - ), - block("mainRet", - ret( - "R0_out" -> LocalVar("R0", BitVecType(64), 0), - "R1_out" -> LocalVar("R1", BitVecType(64), 0) - ) - ) - ) - ) - transforms.copyPropParamFixedPoint(program, Map()) - print(program) - } - - test("intraNewTransform") { - val program = prog( - proc("main", - Seq(), - Seq("R0_out" -> BitVecType(64), "R1_out" -> BitVecType(64)), - block("main", - LocalAssign(LocalVar("R0", BitVecType(64), 0), bv64(3)), - LocalAssign(LocalVar("R1", BitVecType(64), 0), LocalVar("R0", BitVecType(64), 0)), - goto("mainRet") - ), - block("mainRet", - ret( - "R0_out" -> LocalVar("R0", BitVecType(64), 0), - "R1_out" -> LocalVar("R1", BitVecType(64), 0) - ) - ) - ) - ) - cilvisitor.visit_prog(transforms.ReplaceReturns(), program) - transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it - cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + assert(get_bv(results(IRWalk.lastInProc(f).get)(R0)).get == bv64(3)) + assert(get_bv(results(IRWalk.lastInProc(f).get)(R1)).get == bv64(3)) - print(program) - visit_prog(transforms.ConstCopyPropTransform(program), program) - print(program) } + + test("intraAssignOverride") { // all good val program = prog( @@ -84,40 +56,28 @@ class InterCopyConstTests extends AnyFunSuite, CaptureOutput { transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) - println(program) + val f = program.nameToProcedure("main") + val results = getInterCopyConstResults(program, false) + assert(get_bv(results(IRWalk.lastInProc(f).get)(R0)).get == bv64(5)) + - getInterCopyConstResults(program, false) } test("intraNonCopyAssignOverride") { - // all good val program = prog( proc("main", block("main", LocalAssign(R0, bv64(3)), LocalAssign(R0, BinaryExpr(BVADD, R0, bv64(1))), LocalAssign(R1, R0), goto("mainRet")), block("mainRet", ret)), ) cilvisitor.visit_prog(transforms.ReplaceReturns(), program) transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) - print(program) - getInterCopyConstResults(program, false) + val f = program.nameToProcedure("main") + val results = getInterCopyConstResults(program, false) + assert((results(IRWalk.lastInProc(f).get)(R1)) == Top) - visit_prog(transforms.ConstCopyPropTransform(program), program) - print(program) } - test("intraNonCopyAssign") { - val program = prog( - proc("main", block("main", LocalAssign(R1, bv64(3)), LocalAssign(R0, BinaryExpr(BVADD, R0, bv64(1))), goto("mainRet")), block("mainRet", ret)), - ) - println(program) - - cilvisitor.visit_prog(transforms.ReplaceReturns(), program) - transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it - cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) - - getInterCopyConstResults(program, false) - } test("interAssign") { @@ -139,85 +99,27 @@ class InterCopyConstTests extends AnyFunSuite, CaptureOutput { transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) - println(program) - - getInterCopyConstResults(program, true) - - //visit_prog(transforms.ConstCopyPropTransform(program), program) - - transforms.copyPropParamFixedPoint(program, Map()) - print(program) - } - - test("interMultipleOutParams") { - - val program = prog( - proc("main", - block("main", directCall(Seq("p_out_1" -> LocalVar("v", BitVecType(64)), "p_out_2" -> LocalVar("a", BitVecType(64))), "f", Seq("in" -> bv64(7))), goto("mainRet")), - block("mainRet", ret) - ), - proc("f", - Seq("in" -> BitVecType(64)), Seq("p_out_1" -> BitVecType(64), "p_out_2" -> BitVecType(64)), - block("okay", goto("f_ret")), - block("f_ret", ret("p_out_1" -> bv64(5), "p_out_2" -> bv64(1))), - ) - - ) - - cilvisitor.visit_prog(transforms.ReplaceReturns(), program) - transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it - cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) - - println(program) - - getInterCopyConstResults(program, true) + val f = program.nameToProcedure("main") + val results = getInterCopyConstResults(program, true) - transforms.copyPropParamFixedPoint(program, Map()) - - //visit_prog(transforms.ConstCopyPropTransform(program), program) - print(program) - } + assert(get_bv(results(IRWalk.lastInProc(f).get)(LocalVar("v", BitVecType(64)))).get == bv64(5)) - test("interCallProcNoOutParam") { - val program = prog( - proc("main", - block("main", directCall("f"), goto("mainRet")), - block("mainRet", ret) - ), - proc("f", - block("okay", LocalAssign(LocalVar("R0", bv64), bv64(3)) ,goto("f_ret")), - block("f_ret", ret()), - ) - - ) - - cilvisitor.visit_prog(transforms.ReplaceReturns(), program) - transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it - cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) - - println(program) - - getInterCopyConstResults(program, true) - - visit_prog(transforms.ConstCopyPropTransform(program), program) - print(program) } - - test("interReturnNonConstant") { + test("interAssignVariableParameter") { val program = prog( proc("main", - block("main", directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("in" -> bv64(7))), goto("mainRet")), + block("main", LocalAssign(R0, bv64(3)), directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("in" -> R0)), goto("mainRet")), block("mainRet", ret) ), proc("f", Seq("in" -> BitVecType(64)), Seq("p_out" -> BitVecType(64)), block("okay", goto("f_ret")), - block("f_ret", ret("p_out" -> BinaryExpr(BVADD, bv64(2), bv64(1)))), + block("f_ret", ret("p_out" -> LocalVar("in", BitVecType(64)))), ) ) @@ -226,159 +128,77 @@ class InterCopyConstTests extends AnyFunSuite, CaptureOutput { transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) - println(program) - - getInterCopyConstResults(program, true) + val f = program.nameToProcedure("main") + val results = getInterCopyConstResults(program, true) + assert(get_bv(results(IRWalk.lastInProc(f).get)(LocalVar("v", BitVecType(64)))).get == bv64(3)) - //cilvisitor.visit_prog(transforms.ConstCopyPropTransform(program), program) - transforms.copyPropParamFixedPoint(program, Map()) - - print(program) } - test("interReturnOneNonConstantOneConstant") { - - + test("interMultipleOutParams") { val program = prog( proc("main", - block("main", directCall(Seq("p_out_1" -> LocalVar("a", BitVecType(64)), "p_out_2" -> LocalVar("b", BitVecType(64))), "f", Seq("in" -> bv64(7))), goto("mainRet")), + block("main", directCall(Seq("p_out_1" -> LocalVar("v", BitVecType(64)), "p_out_2" -> LocalVar("a", BitVecType(64))), "f", Seq("in" -> bv64(7))), goto("mainRet")), block("mainRet", ret) ), proc("f", Seq("in" -> BitVecType(64)), Seq("p_out_1" -> BitVecType(64), "p_out_2" -> BitVecType(64)), block("okay", goto("f_ret")), - block("f_ret", ret("p_out_1" -> BinaryExpr(BVADD, LocalVar("in", BitVecType(64)), bv64(1)) , "p_out_2" -> bv64(1))), - ) - - ) - - cilvisitor.visit_prog(transforms.ReplaceReturns(), program) - transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it - cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) - - println(program) - - - getInterCopyConstResults(program, true) - - //cilvisitor.visit_prog(transforms.ConstCopyPropTransform(program), program) - transforms.copyPropParamFixedPoint(program, Map()) - print(program) - } - - test("interAssignOverride") { - val program = prog( - proc("main", - block("main", LocalAssign(LocalVar("v", BitVecType(64)), bv64(2)), directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("in" -> bv64(7))), goto("mainRet")), - block("mainRet", ret) - ), - proc("f", - Seq("in" -> BitVecType(64)), Seq("p_out" -> BitVecType(64)), - block("okay", goto("f_ret")), - block("f_ret", ret("p_out" -> bv64(5))), + block("f_ret", ret("p_out_1" -> bv64(5), "p_out_2" -> bv64(1))), ) ) - cilvisitor.visit_prog(transforms.ReplaceReturns(), program) transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) + val f = program.nameToProcedure("main") + val results = getInterCopyConstResults(program, true) + assert(get_bv(results(IRWalk.lastInProc(f).get)(LocalVar("v", BitVecType(64)))).get == bv64(5)) + assert(get_bv(results(IRWalk.lastInProc(f).get)(LocalVar("a", BitVecType(64)))).get == bv64(1)) - println(program) - - - getInterCopyConstResults(program, true) } - test("interReturnLocalVar") { - val program = prog( - proc("main", - block("main", directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("p_in" -> bv64(7))), goto("mainRet")), - block("mainRet", ret) - ), - proc("f", - Seq("p_in" -> BitVecType(64)), Seq("p_out" -> BitVecType(64)), - block("okay", LocalAssign(LocalVar("bruh", BitVecType(64)), bv64(6)), goto("f_ret")), - block("f_ret", ret("p_out" -> LocalVar("bruh", BitVecType(64)))) - ) - ) - - - cilvisitor.visit_prog(transforms.ReplaceReturns(), program) - transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it - cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) - println(program) - + test("interReturnOneNonConstantOneConstant") { - getInterCopyConstResults(program, true) - } - test("interReturnInParam") { val program = prog( proc("main", - block("main", directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("p_in" -> bv64(7))), goto("mainRet")), + block("main", directCall(Seq("p_out_1" -> LocalVar("a", BitVecType(64)), "p_out_2" -> LocalVar("b", BitVecType(64))), "f", Seq("in" -> bv64(7))), goto("mainRet")), block("mainRet", ret) ), proc("f", - Seq("p_in" -> BitVecType(64)), Seq("p_out" -> BitVecType(64)), + Seq("in" -> BitVecType(64)), Seq("p_out_1" -> BitVecType(64), "p_out_2" -> BitVecType(64)), block("okay", goto("f_ret")), - block("f_ret", ret("p_out" -> LocalVar("p_in", BitVecType(64)))) + block("f_ret", ret("p_out_1" -> bv64(2), "p_out_2" -> BinaryExpr(BVAND, LocalVar("in", BitVecType(64)), bv64(1)))) ) ) - cilvisitor.visit_prog(transforms.ReplaceReturns(), program) transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) - - println(program) - - - getInterCopyConstResults(program, true) + val f = program.nameToProcedure("main") + val results = getInterCopyConstResults(program, true) + assert(get_bv(results(IRWalk.lastInProc(f).get)(LocalVar("a", BitVecType(64)))).get == bv64(2)) + assert(results(IRWalk.lastInProc(f).get)(LocalVar("b", BitVecType(64))) == Top) } - test("unusedLocalVar") { - val program = prog( - proc("main", - block("main", LocalAssign(LocalVar("unused", BitVecType(64)), bv64(2)), directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("in" -> bv64(7))), goto("mainRet")), - block("mainRet", ret) - ), - proc("f", - Seq("in" -> BitVecType(64)), Seq("p_out" -> BitVecType(64)), - block("okay", goto("f_ret")), - block("f_ret", ret("p_out" -> bv64(5))), - ) - - ) - - - cilvisitor.visit_prog(transforms.ReplaceReturns(), program) - transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it - cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) - - println(program) - - getInterCopyConstResults(program, true) - } - test("callProcTwiceDiffContext") { val program = prog( proc("main", block("main", directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("p_in" -> bv64(3))), goto("main2")), - block("main2", directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("p_in" -> bv64(7))), goto("mainRet")), + block("main2", directCall(Seq("p_out" -> LocalVar("a", BitVecType(64))), "f", Seq("p_in" -> BinaryExpr(BVAND, bv64(1), bv64(7)))), goto("mainRet")), block("mainRet", ret) ), proc("f", @@ -393,111 +213,15 @@ class InterCopyConstTests extends AnyFunSuite, CaptureOutput { transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) - println(program) - getInterCopyConstResults(program, true) - } + val f = program.nameToProcedure("main") + val results = getInterCopyConstResults(program, true) + assert(get_bv(results(IRWalk.lastInProc(f).get)(LocalVar("v", BitVecType(64)))).get == bv64(3)) + assert(results(IRWalk.lastInProc(f).get)(LocalVar("a", BitVecType(64))) == Top) - test("callProcWithinProc") { - val program = prog( - proc("main", - block("main", directCall(Seq("p_out" -> LocalVar("v", BitVecType(64))), "f", Seq("in" -> bv64(7))), goto("mainRet")), - block("mainRet", ret) - ), - proc("f", - Seq("in" -> BitVecType(64)), Seq("p_out" -> BitVecType(64)), - block("okay", goto("f_ret")), // add another procedure here !!! - block("f_ret", ret("p_out" -> bv64(5))), - ) - - ) - - - cilvisitor.visit_prog(transforms.ReplaceReturns(), program) - transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it - cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) - - - println(program) - - getInterCopyConstResults(program, true) } - test ("waht") { - val program = prog( - proc("main", - Seq( - "R0_in" -> BitVecType(64), - "R10_in" -> BitVecType(64), - "R11_in" -> BitVecType(64), - "R12_in" -> BitVecType(64), - "R13_in" -> BitVecType(64), - "R14_in" -> BitVecType(64), - "R15_in" -> BitVecType(64), - "R16_in" -> BitVecType(64), - "R17_in" -> BitVecType(64), - "R18_in" -> BitVecType(64), - "R1_in" -> BitVecType(64), - "R29_in" -> BitVecType(64), - "R2_in" -> BitVecType(64), - "R30_in" -> BitVecType(64), - "R31_in" -> BitVecType(64), - "R3_in" -> BitVecType(64), - "R4_in" -> BitVecType(64), - "R5_in" -> BitVecType(64), - "R6_in" -> BitVecType(64), - "R7_in" -> BitVecType(64), - "R8_in" -> BitVecType(64), - "R9_in" -> BitVecType(64), - "_PC_in" -> BitVecType(64) - ), - Seq( - "R0_out" -> BitVecType(64), - "_PC_out" -> BitVecType(64) - ), - block("main_entry", - LocalAssign(LocalVar("R0", BitVecType(64), 0), LocalVar("R0_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R10", BitVecType(64), 0), LocalVar("R10_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R11", BitVecType(64), 0), LocalVar("R11_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R12", BitVecType(64), 0), LocalVar("R12_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R13", BitVecType(64), 0), LocalVar("R13_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R14", BitVecType(64), 0), LocalVar("R14_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R15", BitVecType(64), 0), LocalVar("R15_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R16", BitVecType(64), 0), LocalVar("R16_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R17", BitVecType(64), 0), LocalVar("R17_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R18", BitVecType(64), 0), LocalVar("R18_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R1", BitVecType(64), 0), LocalVar("R1_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R29", BitVecType(64), 0), LocalVar("R29_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R2", BitVecType(64), 0), LocalVar("R2_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R30", BitVecType(64), 0), LocalVar("R30_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R31", BitVecType(64), 0), LocalVar("R31_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R3", BitVecType(64), 0), LocalVar("R3_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R4", BitVecType(64), 0), LocalVar("R4_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R5", BitVecType(64), 0), LocalVar("R5_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R6", BitVecType(64), 0), LocalVar("R6_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R7", BitVecType(64), 0), LocalVar("R7_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R8", BitVecType(64), 0), LocalVar("R8_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R9", BitVecType(64), 0), LocalVar("R9_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("_PC", BitVecType(64), 0), LocalVar("_PC_in", BitVecType(64), 0), None), - LocalAssign(LocalVar("R0", BitVecType(64), 0), BitVecLiteral(BigInt("2"), 64), Some("4195968_0")), - goto("main_basil_return_1") - ), - block("main_basil_return_1", - ret( - "R0_out" -> LocalVar("R0", BitVecType(64), 0), - "_PC_out" -> LocalVar("_PC", BitVecType(64), 0) - ) - ) - ) - ) - - cilvisitor.visit_prog(transforms.ReplaceReturns(), program) - transforms.addReturnBlocks(program, true) // add return to all blocks because IDE solver expects it - cilvisitor.visit_prog(transforms.ConvertSingleReturn(), program) - - println(program) - getInterCopyConstResults(program, true) - } + } \ No newline at end of file From c55563ca9612aa64617935d977d27df2d014fd1e Mon Sep 17 00:00:00 2001 From: pragu Date: Fri, 25 Jul 2025 17:36:15 +1000 Subject: [PATCH 6/6] Finalise formatting and comments --- src/main/scala/analysis/InterCopyConst.scala | 2 +- .../transforms/ConstCopyPropTransform.scala | 21 ++++++++++++------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/main/scala/analysis/InterCopyConst.scala b/src/main/scala/analysis/InterCopyConst.scala index e4277a397..9275fc5c1 100644 --- a/src/main/scala/analysis/InterCopyConst.scala +++ b/src/main/scala/analysis/InterCopyConst.scala @@ -57,7 +57,7 @@ trait CopyConstAnalysisFunctions(parameterForm: Boolean) extends ForwardIDEAnaly exit.outParams.toList.foldLeft(Map[DL, EdgeFunction[FlatElement[BitVecLiteral]]]()) { case (m, (retVar, expression)) => expression match case LocalVar(_, _, _) | Register(_, _) if expression == a => m ++ Map(Left(call.outParams(retVar)) -> IdEdge()) - case _ => m ++ Map() //ignore other kind of expr, including local vars / in params of the procedure + case _ => m ++ Map() //ignore other kind of expr, including local vars / in params of the procedure } case Right(a) => diff --git a/src/main/scala/ir/transforms/ConstCopyPropTransform.scala b/src/main/scala/ir/transforms/ConstCopyPropTransform.scala index cf3282960..56ef8ca44 100644 --- a/src/main/scala/ir/transforms/ConstCopyPropTransform.scala +++ b/src/main/scala/ir/transforms/ConstCopyPropTransform.scala @@ -10,7 +10,8 @@ import scala.collection.immutable.SortedMap /** * Transforms program by modifying assignments to local variables and procedure calls to constants if possible, as * determined by copy-constant analysis (using the IDE framework). Procedure calls are modified to remove redundant - * out parameters if they always return a constant value. + * out parameters if they always return a constant value. This could be extended to integrate removal of empty blocks + * and dead input variables. */ class ConstCopyPropTransform(p: Program) extends CILVisitor{ val results: Map[CFGPosition, Map[Variable, FlatElement[BitVecLiteral]]] = InterCopyConst(p, true).analyze() @@ -21,14 +22,17 @@ class ConstCopyPropTransform(p: Program) extends CILVisitor{ e match { case l: LocalAssign => - val absState: FlatElement[BitVecLiteral] = results(e.successor)(l.lhs) - - l.rhs match { - case LocalVar(_,_,_) | Register(_,_) if absState != Top | absState != Bottom => - ChangeTo(List(LocalAssign(l.lhs, get_bv(absState).get))) //replace rhs with constant - - case _ => SkipChildren() + if results.contains(e.successor) then { + val absState: FlatElement[BitVecLiteral] = results(e.successor)(l.lhs) + + l.rhs match { + case LocalVar(_, _, _) | Register(_, _) if absState != Top & absState != Bottom => + ChangeTo(List(LocalAssign(l.lhs, get_bv(absState).get))) //replace rhs with constant + + case _ => SkipChildren() + } } + else SkipChildren() case d: DirectCall if d.outParams.nonEmpty => @@ -80,6 +84,7 @@ class ConstCopyPropTransform(p: Program) extends CILVisitor{ } } + /** * Extract actual BitVecLiteral from given FlatElement of lattice. Do not use unless it is known that the FlatElement * contains a BitVecLiteral and not Top/Bottom