Skip to content

Commit

Permalink
add thread scope setter
Browse files Browse the repository at this point in the history
  • Loading branch information
tonghaining committed Oct 2, 2024
1 parent a02dbd2 commit 7821c1d
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import com.dat3m.dartagnan.expression.Type;
import com.dat3m.dartagnan.expression.integers.IntLiteral;
import com.dat3m.dartagnan.expression.type.FunctionType;
import com.dat3m.dartagnan.expression.type.IntegerType;
import com.dat3m.dartagnan.expression.type.TypeFactory;
import com.dat3m.dartagnan.program.*;
import com.dat3m.dartagnan.program.Thread;
Expand Down Expand Up @@ -46,10 +45,6 @@ public class ProgramBuilder {
private final Map<Integer, Map<String, Label>> fid2LabelsMap = new HashMap<>();
private final Map<String, MemoryObject> locations = new HashMap<>();
private final Map<Register, MemoryObject> reg2LocMap = new HashMap<>();
private final Map<Integer, Map<String, IntegerType>> id2RegTypeMap = new HashMap<>();
private final Map<Integer, Map<String, Expression>> id2RegConstMap = new HashMap<>();
private final Map<Integer, Map<String, String>> id2RegLocPtrMap = new HashMap<>();
private final Map<Integer, Map<String, String>> id2RegLocValMap = new HashMap<>();

private final Program program;

Expand Down Expand Up @@ -119,19 +114,13 @@ public void setAssertFilter(Expression ass) {

// This method creates a "default" thread that has no parameters, no return value, and runs unconditionally.
// It is only useful for creating threads of Litmus code.
public Thread newThread(int tid, Thread thread) {
public Thread newThread(String name, int tid) {
if(id2FunctionsMap.containsKey(tid)) {
throw new MalformedProgramException("Function or thread with id " + tid + " already exists.");
}
final Thread thread = new Thread(name, DEFAULT_THREAD_TYPE, List.of(), tid, EventFactory.newThreadStart(null));
id2FunctionsMap.put(tid, thread);
program.addThread(thread);
if (id2RegConstMap.containsKey(tid)) {
id2RegConstMap.get(tid).forEach((regName, value) ->
initRegEqConst(tid, regName, value));
} else if (id2RegLocPtrMap.containsKey(tid)) {
id2RegLocPtrMap.get(tid).forEach((regName, value) ->
initRegEqLocPtr(tid, regName, value, getRegType(tid, regName)));
} else if (id2RegLocValMap.containsKey(tid)) {
id2RegLocValMap.get(tid).forEach((regName, value) ->
initRegEqLocVal(tid, regName, value, getRegType(tid, regName)));
}
return thread;
}

Expand All @@ -146,12 +135,8 @@ public Function newFunction(String name, int fid, FunctionType type, List<String
}

public Thread newThread(int tid) {
if(id2FunctionsMap.containsKey(tid)) {
throw new MalformedProgramException("Function or thread with id " + tid + " already exists.");
}
final String threadName = (program.getFormat() == LITMUS ? "P" : "__thread_") + tid;
final Thread thread = new Thread(threadName, DEFAULT_THREAD_TYPE, List.of(), tid, EventFactory.newThreadStart(null));
return newThread(tid, thread);
return newThread(threadName, tid);
}

public Thread getOrNewThread(int tid) {
Expand Down Expand Up @@ -250,29 +235,6 @@ public void initRegEqConst(int regThread, String regName, Expression value){
addChild(regThread, EventFactory.newLocal(getOrNewRegister(regThread, regName, value.getType()), value));
}

public void addRegType(int tid, String regName, IntegerType type) {
id2RegTypeMap.computeIfAbsent(tid, k -> new HashMap<>()).put(regName, type);
}

public IntegerType getRegType(int tid, String regName) {
if (id2RegTypeMap.containsKey(tid) && id2RegTypeMap.get(tid).containsKey(regName)) {
return id2RegTypeMap.get(tid).get(regName);
}
throw new IllegalStateException("Register " + tid + ":" + regName + " is not initialised");
}

public void addRegToConstMap(int tid, String regName, Expression value) {
id2RegConstMap.computeIfAbsent(tid, k -> new HashMap<>()).put(regName, value);
}

public void addRegToLocPtrMap(int tid, String regName, String locName) {
id2RegLocPtrMap.computeIfAbsent(tid, k -> new HashMap<>()).put(regName, locName);
}

public void addRegToLocValMap(int tid, String regName, String locName) {
id2RegLocValMap.computeIfAbsent(tid, k -> new HashMap<>()).put(regName, locName);
}

private Expression getInitialValue(String name) {
return getOrNewMemoryObject(name).getInitialValue(0);
}
Expand Down Expand Up @@ -314,26 +276,28 @@ public Label getEndOfThreadLabel(int tid) {

// ----------------------------------------------------------------------------------------------------------------
// GPU
public void newScopedThread(Arch arch, String name, int id, int ...scopeIds) {
if(id2FunctionsMap.containsKey(id)) {
throw new MalformedProgramException("Function or thread with id " + id + " already exists.");
}
// Litmus threads run unconditionally (have no creator) and have no parameters/return types.
ThreadStart threadEntry = EventFactory.newThreadStart(null);
Thread scopedThread = switch (arch) {
case PTX -> new Thread(name, DEFAULT_THREAD_TYPE, List.of(), id, threadEntry,
ScopeHierarchy.ScopeHierarchyForPTX(scopeIds[0], scopeIds[1]), new HashSet<>());
case VULKAN -> new Thread(name, DEFAULT_THREAD_TYPE, List.of(), id, threadEntry,
ScopeHierarchy.ScopeHierarchyForVulkan(scopeIds[0], scopeIds[1], scopeIds[2]), new HashSet<>());
case OPENCL -> new Thread(name, DEFAULT_THREAD_TYPE, List.of(), id, threadEntry,
ScopeHierarchy.ScopeHierarchyForOpenCL(scopeIds[0], scopeIds[1]), new HashSet<>());
public void setOrCreateScopedThread(Arch arch, String name, int id, int ...scopeIds) {
ScopeHierarchy scopeHierarchy = switch (arch) {
case PTX -> ScopeHierarchy.ScopeHierarchyForPTX(scopeIds[0], scopeIds[1]);
case VULKAN -> ScopeHierarchy.ScopeHierarchyForVulkan(scopeIds[0], scopeIds[1], scopeIds[2]);
case OPENCL -> ScopeHierarchy.ScopeHierarchyForOpenCL(scopeIds[0], scopeIds[1]);
default -> throw new UnsupportedOperationException("Unsupported architecture: " + arch);
};
newThread(id, scopedThread);

if(id2FunctionsMap.containsKey(id)) {
Thread thread = (Thread) id2FunctionsMap.get(id);
thread.setScopeHierarchy(scopeHierarchy);
} else {
// Litmus threads run unconditionally (have no creator) and have no parameters/return types.
ThreadStart threadEntry = EventFactory.newThreadStart(null);
Thread scopedThread = new Thread(name, DEFAULT_THREAD_TYPE, List.of(), id, threadEntry, scopeHierarchy, new HashSet<>());
id2FunctionsMap.put(id, scopedThread);
program.addThread(scopedThread);
}
}

public void newScopedThread(Arch arch, int id, int ...ids) {
newScopedThread(arch, String.valueOf(id), id, ids);
public void setOrCreateScopedThread(Arch arch, int id, int ...ids) {
setOrCreateScopedThread(arch, String.valueOf(id), id, ids);
}

// ----------------------------------------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import com.dat3m.dartagnan.configuration.Arch;
import com.dat3m.dartagnan.exception.ParsingException;
import com.dat3m.dartagnan.expression.BinaryExpression;
import com.dat3m.dartagnan.expression.Expression;
import com.dat3m.dartagnan.expression.ExpressionFactory;
import com.dat3m.dartagnan.expression.integers.IntLiteral;
Expand Down Expand Up @@ -46,6 +45,8 @@ public VisitorLitmusC(){

@Override
public Program visitMain(LitmusCParser.MainContext ctx) {
//FIXME: We should visit thread declarations before variable declarations
// because variable declaration refer to threads.
visitVariableDeclaratorList(ctx.variableDeclaratorList());
visitProgram(ctx.program());
VisitorLitmusAssertions.parseAssertions(programBuilder, ctx.assertionList(), ctx.assertionFilter());
Expand All @@ -68,9 +69,10 @@ public Object visitGlobalDeclaratorLocation(LitmusCParser.GlobalDeclaratorLocati
@Override
public Object visitGlobalDeclaratorRegister(LitmusCParser.GlobalDeclaratorRegisterContext ctx) {
if (ctx.initConstantValue() != null) {
// FIXME: We visit declarators before threads, so we need to create threads early
programBuilder.getOrNewThread(ctx.threadId().id);
IntLiteral value = expressions.parseValue(ctx.initConstantValue().constant().getText(), archType);
programBuilder.addRegType(ctx.threadId().id, ctx.varName().getText(), archType);
programBuilder.addRegToConstMap(ctx.threadId().id, ctx.varName().getText(), value);
programBuilder.initRegEqConst(ctx.threadId().id,ctx.varName().getText(), value);
}
return null;
}
Expand All @@ -93,19 +95,17 @@ public Object visitGlobalDeclaratorLocationLocation(LitmusCParser.GlobalDeclarat

@Override
public Object visitGlobalDeclaratorRegisterLocation(LitmusCParser.GlobalDeclaratorRegisterLocationContext ctx) {
int threadId = ctx.threadId().id;
String regName = ctx.varName(0).getText();
String locName = ctx.varName(1).getText();
programBuilder.addRegType(threadId, regName, archType);
// FIXME: We visit declarators before threads, so we need to create threads early
programBuilder.getOrNewThread(ctx.threadId().id);
if(ctx.Ast() == null){
programBuilder.addRegToLocPtrMap(threadId, regName, locName);
programBuilder.initRegEqLocPtr(ctx.threadId().id, ctx.varName(0).getText(), ctx.varName(1).getText(), archType);
} else {
String rightName = ctx.varName(1).getText();
MemoryObject object = programBuilder.getMemoryObject(rightName);
if(object != null){
programBuilder.addRegToConstMap(threadId, regName, object);
programBuilder.initRegEqConst(ctx.threadId().id, ctx.varName(0).getText(), object);
} else {
programBuilder.addRegToLocValMap(threadId, regName, locName);
programBuilder.initRegEqLocVal(ctx.threadId().id, ctx.varName(0).getText(), ctx.varName(1).getText(), archType);
}
}
return null;
Expand Down Expand Up @@ -158,7 +158,7 @@ public Object visitThread(LitmusCParser.ThreadContext ctx) {
// Declarations in the preamble may have created the thread already
if (ctx.threadScope() == null) {
// Set dummy scope for C11 threads
programBuilder.newScopedThread(Arch.OPENCL, currentThread, 0, 0);
programBuilder.setOrCreateScopedThread(Arch.OPENCL, currentThread, 0, 0);
} else {
ctx.threadScope().accept(this);
this.isOpenCL = true;
Expand All @@ -176,7 +176,7 @@ public Object visitThread(LitmusCParser.ThreadContext ctx) {
public Object visitOpenCLThreadScope(LitmusCParser.OpenCLThreadScopeContext ctx) {
int wgID = ctx.scopeID(0).id;
int devID = ctx.scopeID(1).id;
programBuilder.newScopedThread(Arch.OPENCL, currentThread, devID, wgID);
programBuilder.setOrCreateScopedThread(Arch.OPENCL, currentThread, devID, wgID);
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public Object visitThreadDeclaratorList(LitmusPTXParser.ThreadDeclaratorListCont
int ctaID = threadScopeContext.scopeID().ctaID().id;
int gpuID = threadScopeContext.scopeID().gpuID().id;
// NB: the order of scopeIDs is important
programBuilder.newScopedThread(Arch.PTX, threadScopeContext.threadId().id, gpuID, ctaID);
programBuilder.setOrCreateScopedThread(Arch.PTX, threadScopeContext.threadId().id, gpuID, ctaID);
threadCount++;
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public Object visitThreadDeclaratorList(LitmusVulkanParser.ThreadDeclaratorListC
int workgroupID = threadScopeContext.workgroupScope().scopeID().id;
int queuefamilyID = threadScopeContext.queuefamilyScope().scopeID().id;
// NB: the order of scopeIDs is important
programBuilder.newScopedThread(Arch.VULKAN, threadScopeContext.threadId().id,
programBuilder.setOrCreateScopedThread(Arch.VULKAN, threadScopeContext.threadId().id,
queuefamilyID, workgroupID, subgroupID);
threadCount++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
public class Thread extends Function {

// Scope hierarchy of the thread
private final Optional<ScopeHierarchy> scopeHierarchy;
private Optional<ScopeHierarchy> scopeHierarchy;

// Threads that are system-synchronized-with this thread
private final Optional<Set<Thread>> syncSet;
Expand Down Expand Up @@ -56,6 +56,10 @@ public Set<Thread> getSyncSet() {
return syncSet.get();
}

public void setScopeHierarchy(ScopeHierarchy scopeHierarchy) {
this.scopeHierarchy = Optional.of(scopeHierarchy);
}

@Override
public ThreadStart getEntry() {
return (ThreadStart) entry;
Expand Down

0 comments on commit 7821c1d

Please sign in to comment.