diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/ThreadCreation.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/ThreadCreation.java index add6a1db5e..bbe7fbe787 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/ThreadCreation.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/ThreadCreation.java @@ -4,8 +4,8 @@ import com.dat3m.dartagnan.exception.MalformedProgramException; import com.dat3m.dartagnan.expression.Expression; import com.dat3m.dartagnan.expression.ExpressionFactory; -import com.dat3m.dartagnan.expression.IExprUn; -import com.dat3m.dartagnan.expression.op.IOpUn; +import com.dat3m.dartagnan.expression.IConst; +import com.dat3m.dartagnan.expression.IValue; import com.dat3m.dartagnan.expression.processing.ExprTransformer; import com.dat3m.dartagnan.expression.processing.ExpressionVisitor; import com.dat3m.dartagnan.expression.type.IntegerType; @@ -19,7 +19,6 @@ import com.dat3m.dartagnan.program.event.Tag; import com.dat3m.dartagnan.program.event.core.Event; import com.dat3m.dartagnan.program.event.core.Label; -import com.dat3m.dartagnan.program.event.core.Load; import com.dat3m.dartagnan.program.event.core.Local; import com.dat3m.dartagnan.program.event.core.threading.ThreadCreate; import com.dat3m.dartagnan.program.event.core.threading.ThreadStart; @@ -44,7 +43,6 @@ import java.math.BigInteger; import java.util.*; -import java.util.stream.Stream; import static com.dat3m.dartagnan.configuration.OptionNames.THREAD_CREATE_ALWAYS_SUCCEEDS; import static com.dat3m.dartagnan.program.event.EventFactory.*; @@ -99,10 +97,10 @@ public void run(Program program) { throw new MalformedProgramException("Program contains no main function"); } - final int maxId = Stream.concat(program.getThreads().stream(), program.getFunctions().stream()) - .mapToInt(Function::getId) - .max().orElse(0); - int nextTid = maxId + 1; + // NOTE: We start from id = 0 which overlaps with existing function ids. + // However, we reassign ids after thread creation so that functions get higher ids. + // TODO: Do we even need ids for functions? + int nextTid = 0; final Queue workingQueue = new ArrayDeque<>(); workingQueue.add(createThreadFromFunction(main.get(), nextTid++, null, null)); @@ -111,7 +109,9 @@ public void run(Program program) { final Thread thread = workingQueue.remove(); program.addThread(thread); - final Map tid2ComAddrMap = new HashMap<>(); + // We collect the communication addresses we use for each thread id. + // These are used later to lower pthread_join. + final Map tid2ComAddrMap = new LinkedHashMap<>(); for (DirectFunctionCall call : thread.getEvents(DirectFunctionCall.class)) { final List arguments = call.getArguments(); switch (call.getCallTarget().getName()) { @@ -126,7 +126,7 @@ public void run(Program program) { assert resultRegister.getType() instanceof IntegerType; final ThreadCreate createEvent = newThreadCreate(List.of(argument)); - final Expression tidExpr = expressions.makeValue(BigInteger.valueOf(nextTid), archType); + final IValue tidExpr = expressions.makeValue(BigInteger.valueOf(nextTid), archType); final MemoryObject comAddress = program.getMemory().allocate(1, true); comAddress.setCVar("__com" + nextTid + "__" + targetFunction.getName()); @@ -143,36 +143,10 @@ public void run(Program program) { final Thread spawnedThread = createThreadFromFunction(targetFunction, nextTid, createEvent, comAddress); createEvent.setSpawnedThread(spawnedThread); workingQueue.add(spawnedThread); - tid2ComAddrMap.put(tidExpr, comAddress); - propagateThreadIds(tidExpr, pidResultAddress, createEvent); nextTid++; } - case "pthread_join", "__pthread_join" -> { - assert arguments.size() == 2; - final Expression tidExpr = arguments.get(0); - // TODO: support return values for threads - // final Expression returnAddr = arguments.get(1); - - final Register resultRegister = getResultRegister(call); - assert resultRegister.getType() instanceof IntegerType; - final Expression comAddrOfThreadToJoinWith = tid2ComAddrMap.get(tidExpr); - if (comAddrOfThreadToJoinWith == null) { - throw new UnsupportedOperationException( - "Cannot handle pthread_join with dynamic thread parameter."); - } - final int tid = tidExpr.reduce().getValueAsInt(); - final Register joinDummyReg = thread.getOrNewRegister("__joinT" + tid, types.getBooleanType()); - final List replacement = eventSequence( - newAcquireLoad(joinDummyReg, comAddrOfThreadToJoinWith), - newJump(joinDummyReg, (Label)thread.getExit()), - // Note: In our modelling, pthread_join always succeeds if it returns - newLocal(resultRegister, expressions.makeZero((IntegerType) resultRegister.getType())) - ); - replacement.forEach(e -> e.copyAllMetadataFrom(call)); - call.replaceBy(replacement); - } case "get_my_tid" -> { final Register resultRegister = getResultRegister(call); assert resultRegister.getType() instanceof IntegerType; @@ -185,58 +159,93 @@ public void run(Program program) { } } } + + // FIXME: This only allows joining with child threads that were created by this thread. + handlePthreadJoins(thread, tid2ComAddrMap); } IdReassignment.newInstance().run(program); logger.info("Number of threads (including main): " + program.getThreads().size()); } + /* + This method replaces in all pthread_join calls by a switch over all possible tids. + Each candidate thread gets a switch-case which tries to synchronize with that thread. + */ + private void handlePthreadJoins(Thread thread, Map tid2ComAddrMap) { + final TypeFactory types = TypeFactory.getInstance(); + final ExpressionFactory expressions = ExpressionFactory.getInstance(); + int joinCounter = 0; - // Helper code to do constant propagation of generated tid's to pthread_join calls - //TODO: Ideally, this kind of propagation shouldn't be done here. - // Also, it is currently unsound if the code is not in SSA, for example, after unrolling. - private static void propagateThreadIds(Expression tidExpr, Expression tidResultAddress, ThreadCreate createEvent) { - Set tidValues = new HashSet<>(); - Set tidPtrs = new HashSet<>(); - tidPtrs.add(tidResultAddress); - - // Backpropagation of pointers: - // "p1 <- pExpr; pthread_create(pExpr, ...)" => p1 also points to an address holding a tid - for (Event pred : createEvent.getPredecessors()) { - if (pred instanceof Local local && tidPtrs.contains(local.getResultRegister())) { - tidPtrs.add(local.getExpr()); + for (DirectFunctionCall call : thread.getEvents(DirectFunctionCall.class)) { + final String targetName = call.getCallTarget().getName(); + if (!(targetName.equals("pthread_join") || targetName.equals("__pthread_join"))) { + continue; } - } - // Forward propagation - // (1) p1 <- pExpr => p1 points to tid if pExpr does - // (2) r1 <- expr => r1 holds tid if expr does - // (3) r <- load(pExpr) => r holds tid if pExpr points to tid - for (Event succ : createEvent.getSuccessors()) { - if (succ instanceof Load load && tidPtrs.contains(load.getAddress())) { - // Do tid propagation over loads - tidValues.add(load.getResultRegister()); - } - if (succ instanceof Local local) { - Expression rhs = local.getExpr(); - while (rhs instanceof IExprUn unExpr && unExpr.getOp() == IOpUn.CAST_UNSIGNED) { - // Try to skip cast expressions - rhs = unExpr.getInner(); - } - if (tidPtrs.contains(rhs)) { - tidPtrs.add(local.getResultRegister()); - } else if (tidValues.contains(rhs)) { - tidValues.add(local.getResultRegister()); + + final List arguments = call.getArguments(); + assert arguments.size() == 2; + final Expression tidExpr = arguments.get(0); + // TODO: support return values for threads + // final Expression returnAddr = arguments.get(1); + + final Register resultRegister = getResultRegister(call); + assert resultRegister.getType() instanceof IntegerType; + + // This register will hold the value "false" IFF the join succeeds. + final Register joinDummyReg = thread.getOrNewRegister("__joinFail#" + joinCounter, types.getBooleanType()); + final Label joinEnd = EventFactory.newLabel("__joinEnd#" + joinCounter); + + // ----- Construct a switch case for each possible tid ----- + final Map> tid2joinCases = new LinkedHashMap<>(); + for (IValue tidCandidate : tid2ComAddrMap.keySet()) { + final int tid = tidCandidate.getValueAsInt(); + final Expression comAddrOfThreadToJoinWith = tid2ComAddrMap.get(tidCandidate); + + if (tidExpr instanceof IConst iConst && iConst.getValueAsInt() != tid) { + // Little optimization if we join with a constant address + continue; } + + final Label joinCase = EventFactory.newLabel("__joinWithT" + tid + "#" + joinCounter); + final List caseBody = eventSequence( + joinCase, + newAcquireLoad(joinDummyReg, comAddrOfThreadToJoinWith), + EventFactory.newGoto(joinEnd) + ); + tid2joinCases.put(tidCandidate, caseBody); } - // Here we actually change pthread_join's first argument to directly hold the target tid - if (succ instanceof DirectFunctionCall call && call.getCallTarget().getName().contains("pthread_join")) { - if (tidValues.contains(call.getArguments().get(0))) { - // TODO: Direct access to call's argument list is fishy. - // Calls should have a setArgument function instead. - call.getArguments().set(0, tidExpr); - } + // ----- Construct the actual switch (a simple jump table) ----- + final List switchJumpTable = new ArrayList<>(); + for (Expression tid : tid2joinCases.keySet()) { + switchJumpTable.add(EventFactory.newJump( + expressions.makeEQ(tidExpr, tid), (Label)tid2joinCases.get(tid).get(0)) + ); } + // Add default case for when no tid matches. We make the join just fail here as if it + // was waiting for a never-terminating thread. + // FIXME: This does not align with the correct pthread_join semantics. + switchJumpTable.add(EventFactory.newLocal(joinDummyReg, expressions.makeTrue())); + switchJumpTable.add(EventFactory.newGoto(joinEnd)); + + // ----- Generate actual replacement for the pthread_join call ----- + final List replacement = new ArrayList<>(); + replacement.add(EventFactory.newFunctionCallMarker(call.getCallTarget().getName())); + replacement.addAll(switchJumpTable); + tid2joinCases.values().forEach(replacement::addAll); + replacement.addAll(Arrays.asList( + joinEnd, + newJump(joinDummyReg, (Label)thread.getExit()), + // Note: In our modelling, pthread_join always succeeds if it returns + newLocal(resultRegister, expressions.makeZero((IntegerType) resultRegister.getType())), + EventFactory.newFunctionReturnMarker(call.getCallTarget().getName()) + )); + + replacement.forEach(e -> e.copyAllMetadataFrom(call)); + call.replaceBy(replacement); + + joinCounter++; } } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/wmm/analysis/RelationAnalysis.java b/dartagnan/src/main/java/com/dat3m/dartagnan/wmm/analysis/RelationAnalysis.java index 902facfcf2..b817dd729b 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/wmm/analysis/RelationAnalysis.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/wmm/analysis/RelationAnalysis.java @@ -820,22 +820,6 @@ public Knowledge visitReadFrom(Relation rel) { if (eq.isImplied(startLoad, startStore)) { may.removeIf(t -> t.getSecond() == startLoad && t.getFirst() != startStore); } - - // Must-rf edge for thread joining - cur = thread.getExit(); - while (!(cur instanceof Store endStore)) { cur = cur.getPredecessor(); } - cur = start.getCreator(); - while (cur != null && !(cur instanceof Load joinLoad && joinLoad.getAddress().equals(endStore.getAddress()))) { - cur = cur.getSuccessor(); - } - - if (cur instanceof Load joinLoad) { - must.add(new Tuple(endStore, joinLoad)); - if (eq.isImplied(joinLoad, endStore)) { - // NOTE: The above condition is likely never satisfied in practice - may.removeIf(t -> t.getSecond() == joinLoad && t.getFirst() != endStore); - } - } } if (wmmAnalysis.isLocallyConsistent()) {