Skip to content

Commit

Permalink
Dynamic Pthread join (#496)
Browse files Browse the repository at this point in the history
* Generalized ThreadCreation to allow dynmaic joining.

* Cleaned up ThreadCreation's handling of pthread_join.
Removed unsound propagation code.
Removed must-rf edges related to pthread_join (hard to establish and exploit with dynamic thread joins).
  • Loading branch information
ThomasHaas authored Aug 8, 2023
1 parent 16d4e0e commit a89dd7e
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import com.dat3m.dartagnan.exception.MalformedProgramException;
import com.dat3m.dartagnan.expression.Expression;
import com.dat3m.dartagnan.expression.ExpressionFactory;
import com.dat3m.dartagnan.expression.IExprUn;
import com.dat3m.dartagnan.expression.op.IOpUn;
import com.dat3m.dartagnan.expression.IConst;
import com.dat3m.dartagnan.expression.IValue;
import com.dat3m.dartagnan.expression.processing.ExprTransformer;
import com.dat3m.dartagnan.expression.processing.ExpressionVisitor;
import com.dat3m.dartagnan.expression.type.IntegerType;
Expand All @@ -19,7 +19,6 @@
import com.dat3m.dartagnan.program.event.Tag;
import com.dat3m.dartagnan.program.event.core.Event;
import com.dat3m.dartagnan.program.event.core.Label;
import com.dat3m.dartagnan.program.event.core.Load;
import com.dat3m.dartagnan.program.event.core.Local;
import com.dat3m.dartagnan.program.event.core.threading.ThreadCreate;
import com.dat3m.dartagnan.program.event.core.threading.ThreadStart;
Expand All @@ -44,7 +43,6 @@

import java.math.BigInteger;
import java.util.*;
import java.util.stream.Stream;

import static com.dat3m.dartagnan.configuration.OptionNames.THREAD_CREATE_ALWAYS_SUCCEEDS;
import static com.dat3m.dartagnan.program.event.EventFactory.*;
Expand Down Expand Up @@ -99,10 +97,10 @@ public void run(Program program) {
throw new MalformedProgramException("Program contains no main function");
}

final int maxId = Stream.concat(program.getThreads().stream(), program.getFunctions().stream())
.mapToInt(Function::getId)
.max().orElse(0);
int nextTid = maxId + 1;
// NOTE: We start from id = 0 which overlaps with existing function ids.
// However, we reassign ids after thread creation so that functions get higher ids.
// TODO: Do we even need ids for functions?
int nextTid = 0;

final Queue<Thread> workingQueue = new ArrayDeque<>();
workingQueue.add(createThreadFromFunction(main.get(), nextTid++, null, null));
Expand All @@ -111,7 +109,9 @@ public void run(Program program) {
final Thread thread = workingQueue.remove();
program.addThread(thread);

final Map<Expression, Expression> tid2ComAddrMap = new HashMap<>();
// We collect the communication addresses we use for each thread id.
// These are used later to lower pthread_join.
final Map<IValue, Expression> tid2ComAddrMap = new LinkedHashMap<>();
for (DirectFunctionCall call : thread.getEvents(DirectFunctionCall.class)) {
final List<Expression> arguments = call.getArguments();
switch (call.getCallTarget().getName()) {
Expand All @@ -126,7 +126,7 @@ public void run(Program program) {
assert resultRegister.getType() instanceof IntegerType;

final ThreadCreate createEvent = newThreadCreate(List.of(argument));
final Expression tidExpr = expressions.makeValue(BigInteger.valueOf(nextTid), archType);
final IValue tidExpr = expressions.makeValue(BigInteger.valueOf(nextTid), archType);
final MemoryObject comAddress = program.getMemory().allocate(1, true);
comAddress.setCVar("__com" + nextTid + "__" + targetFunction.getName());

Expand All @@ -143,36 +143,10 @@ public void run(Program program) {
final Thread spawnedThread = createThreadFromFunction(targetFunction, nextTid, createEvent, comAddress);
createEvent.setSpawnedThread(spawnedThread);
workingQueue.add(spawnedThread);

tid2ComAddrMap.put(tidExpr, comAddress);
propagateThreadIds(tidExpr, pidResultAddress, createEvent);

nextTid++;
}
case "pthread_join", "__pthread_join" -> {
assert arguments.size() == 2;
final Expression tidExpr = arguments.get(0);
// TODO: support return values for threads
// final Expression returnAddr = arguments.get(1);

final Register resultRegister = getResultRegister(call);
assert resultRegister.getType() instanceof IntegerType;
final Expression comAddrOfThreadToJoinWith = tid2ComAddrMap.get(tidExpr);
if (comAddrOfThreadToJoinWith == null) {
throw new UnsupportedOperationException(
"Cannot handle pthread_join with dynamic thread parameter.");
}
final int tid = tidExpr.reduce().getValueAsInt();
final Register joinDummyReg = thread.getOrNewRegister("__joinT" + tid, types.getBooleanType());
final List<Event> replacement = eventSequence(
newAcquireLoad(joinDummyReg, comAddrOfThreadToJoinWith),
newJump(joinDummyReg, (Label)thread.getExit()),
// Note: In our modelling, pthread_join always succeeds if it returns
newLocal(resultRegister, expressions.makeZero((IntegerType) resultRegister.getType()))
);
replacement.forEach(e -> e.copyAllMetadataFrom(call));
call.replaceBy(replacement);
}
case "get_my_tid" -> {
final Register resultRegister = getResultRegister(call);
assert resultRegister.getType() instanceof IntegerType;
Expand All @@ -185,58 +159,93 @@ public void run(Program program) {
}
}
}

// FIXME: This only allows joining with child threads that were created by this thread.
handlePthreadJoins(thread, tid2ComAddrMap);
}

IdReassignment.newInstance().run(program);
logger.info("Number of threads (including main): " + program.getThreads().size());
}

/*
This method replaces in <thread> all pthread_join calls by a switch over all possible tids.
Each candidate thread gets a switch-case which tries to synchronize with that thread.
*/
private void handlePthreadJoins(Thread thread, Map<IValue, Expression> tid2ComAddrMap) {
final TypeFactory types = TypeFactory.getInstance();
final ExpressionFactory expressions = ExpressionFactory.getInstance();
int joinCounter = 0;

// Helper code to do constant propagation of generated tid's to pthread_join calls
//TODO: Ideally, this kind of propagation shouldn't be done here.
// Also, it is currently unsound if the code is not in SSA, for example, after unrolling.
private static void propagateThreadIds(Expression tidExpr, Expression tidResultAddress, ThreadCreate createEvent) {
Set<Expression> tidValues = new HashSet<>();
Set<Expression> tidPtrs = new HashSet<>();
tidPtrs.add(tidResultAddress);

// Backpropagation of pointers:
// "p1 <- pExpr; pthread_create(pExpr, ...)" => p1 also points to an address holding a tid
for (Event pred : createEvent.getPredecessors()) {
if (pred instanceof Local local && tidPtrs.contains(local.getResultRegister())) {
tidPtrs.add(local.getExpr());
for (DirectFunctionCall call : thread.getEvents(DirectFunctionCall.class)) {
final String targetName = call.getCallTarget().getName();
if (!(targetName.equals("pthread_join") || targetName.equals("__pthread_join"))) {
continue;
}
}
// Forward propagation
// (1) p1 <- pExpr => p1 points to tid if pExpr does
// (2) r1 <- expr => r1 holds tid if expr does
// (3) r <- load(pExpr) => r holds tid if pExpr points to tid
for (Event succ : createEvent.getSuccessors()) {
if (succ instanceof Load load && tidPtrs.contains(load.getAddress())) {
// Do tid propagation over loads
tidValues.add(load.getResultRegister());
}
if (succ instanceof Local local) {
Expression rhs = local.getExpr();
while (rhs instanceof IExprUn unExpr && unExpr.getOp() == IOpUn.CAST_UNSIGNED) {
// Try to skip cast expressions
rhs = unExpr.getInner();
}
if (tidPtrs.contains(rhs)) {
tidPtrs.add(local.getResultRegister());
} else if (tidValues.contains(rhs)) {
tidValues.add(local.getResultRegister());

final List<Expression> arguments = call.getArguments();
assert arguments.size() == 2;
final Expression tidExpr = arguments.get(0);
// TODO: support return values for threads
// final Expression returnAddr = arguments.get(1);

final Register resultRegister = getResultRegister(call);
assert resultRegister.getType() instanceof IntegerType;

// This register will hold the value "false" IFF the join succeeds.
final Register joinDummyReg = thread.getOrNewRegister("__joinFail#" + joinCounter, types.getBooleanType());
final Label joinEnd = EventFactory.newLabel("__joinEnd#" + joinCounter);

// ----- Construct a switch case for each possible tid -----
final Map<Expression, List<Event>> tid2joinCases = new LinkedHashMap<>();
for (IValue tidCandidate : tid2ComAddrMap.keySet()) {
final int tid = tidCandidate.getValueAsInt();
final Expression comAddrOfThreadToJoinWith = tid2ComAddrMap.get(tidCandidate);

if (tidExpr instanceof IConst iConst && iConst.getValueAsInt() != tid) {
// Little optimization if we join with a constant address
continue;
}

final Label joinCase = EventFactory.newLabel("__joinWithT" + tid + "#" + joinCounter);
final List<Event> caseBody = eventSequence(
joinCase,
newAcquireLoad(joinDummyReg, comAddrOfThreadToJoinWith),
EventFactory.newGoto(joinEnd)
);
tid2joinCases.put(tidCandidate, caseBody);
}

// Here we actually change pthread_join's first argument to directly hold the target tid
if (succ instanceof DirectFunctionCall call && call.getCallTarget().getName().contains("pthread_join")) {
if (tidValues.contains(call.getArguments().get(0))) {
// TODO: Direct access to call's argument list is fishy.
// Calls should have a setArgument function instead.
call.getArguments().set(0, tidExpr);
}
// ----- Construct the actual switch (a simple jump table) -----
final List<Event> switchJumpTable = new ArrayList<>();
for (Expression tid : tid2joinCases.keySet()) {
switchJumpTable.add(EventFactory.newJump(
expressions.makeEQ(tidExpr, tid), (Label)tid2joinCases.get(tid).get(0))
);
}
// Add default case for when no tid matches. We make the join just fail here as if it
// was waiting for a never-terminating thread.
// FIXME: This does not align with the correct pthread_join semantics.
switchJumpTable.add(EventFactory.newLocal(joinDummyReg, expressions.makeTrue()));
switchJumpTable.add(EventFactory.newGoto(joinEnd));

// ----- Generate actual replacement for the pthread_join call -----
final List<Event> replacement = new ArrayList<>();
replacement.add(EventFactory.newFunctionCallMarker(call.getCallTarget().getName()));
replacement.addAll(switchJumpTable);
tid2joinCases.values().forEach(replacement::addAll);
replacement.addAll(Arrays.asList(
joinEnd,
newJump(joinDummyReg, (Label)thread.getExit()),
// Note: In our modelling, pthread_join always succeeds if it returns
newLocal(resultRegister, expressions.makeZero((IntegerType) resultRegister.getType())),
EventFactory.newFunctionReturnMarker(call.getCallTarget().getName())
));

replacement.forEach(e -> e.copyAllMetadataFrom(call));
call.replaceBy(replacement);

joinCounter++;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -820,22 +820,6 @@ public Knowledge visitReadFrom(Relation rel) {
if (eq.isImplied(startLoad, startStore)) {
may.removeIf(t -> t.getSecond() == startLoad && t.getFirst() != startStore);
}

// Must-rf edge for thread joining
cur = thread.getExit();
while (!(cur instanceof Store endStore)) { cur = cur.getPredecessor(); }
cur = start.getCreator();
while (cur != null && !(cur instanceof Load joinLoad && joinLoad.getAddress().equals(endStore.getAddress()))) {
cur = cur.getSuccessor();
}

if (cur instanceof Load joinLoad) {
must.add(new Tuple(endStore, joinLoad));
if (eq.isImplied(joinLoad, endStore)) {
// NOTE: The above condition is likely never satisfied in practice
may.removeIf(t -> t.getSecond() == joinLoad && t.getFirst() != endStore);
}
}
}

if (wmmAnalysis.isLocallyConsistent()) {
Expand Down

0 comments on commit a89dd7e

Please sign in to comment.