Skip to content

Commit

Permalink
Merge remote-tracking branch 'kubagalecki/arbitrary-grid' into unstab…
Browse files Browse the repository at this point in the history
…le-merge
  • Loading branch information
llaniewski committed Jan 3, 2024
2 parents b38b8dd + e02f071 commit 6016957
Show file tree
Hide file tree
Showing 14 changed files with 518 additions and 326 deletions.
18 changes: 9 additions & 9 deletions src/ArbConnectivity.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@ struct ArbLatticeConnectivity {
zones.reserve(getLocalSize());
}

void dump(std::string filename) {
void dump(std::string filename) const {
FILE* f;
f = fopen(filename.c_str(),"w");
fprintf(f,"idx_og,idx");
for (size_t q=0;q<Q;q++) fprintf(f,",nbr%ld",q);
fprintf(f,"\n");
f = fopen(filename.c_str(), "w");
fprintf(f, "idx_og,idx");
for (size_t q = 0; q < Q; q++) fprintf(f, ",nbr%ld", q);
fprintf(f, "\n");
size_t n = chunk_end - chunk_begin;
for (size_t lid=0; lid<n; lid++) {
fprintf(f,"%ld,%ld",(size_t) og_index[lid],(size_t) lid + chunk_begin);
for (size_t q=0;q<Q;q++) fprintf(f,",%ld",(signed long int) neighbor(q, lid));
fprintf(f,"\n");
for (size_t lid = 0; lid < n; lid++) {
fprintf(f, "%ld,%ld", (size_t)og_index[lid], (size_t)lid + chunk_begin);
for (size_t q = 0; q < Q; q++) fprintf(f, ",%ld", (signed long int)neighbor(q, lid));
fprintf(f, "\n");
}
fclose(f);
}
Expand Down
352 changes: 229 additions & 123 deletions src/ArbLattice.cpp

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions src/ArbLattice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class ArbLattice : public LatticeBase {
void getQuantity(int quant, real_t* host_tab, real_t scale); /// Write GPU data to \p host_tab
const ArbVTUGeom& getVTUGeom() const { return vtu_geom; }
Span<const flag_t> getNodeTypes() const { return {node_types_host.data(), node_types_host.size()}; } /// Get host view of node types (permuted)
const ArbLatticeConnectivity& getConnectivity() const { return connect; } /// Get read-only access to the connectivity info held by this lattice
const std::vector<unsigned>& getLocalPermutation() const { return local_permutation; } /// Get read-only access to the local permutation vector

protected:
ArbLatticeLauncher launcher; /// Launcher responsible for running CUDA kernels on the lattice
Expand All @@ -105,6 +107,7 @@ class ArbLattice : public LatticeBase {
};

storage_t* getSnapPtr(int snap_ind); /// Get device pointer to the specified snap (somewhere within the total snap allocation)
const storage_t* getSnapPtr(int snap_ind) const { return const_cast<ArbLattice*>(this)->getSnapPtr(snap_ind); } /// Const overload: forwards to the non-const getSnapPtr (hence the const_cast) and hands the result back as pointer-to-const
#ifdef ADJOINT
storage_t* getAdjointSnapPtr(int snap_ind); /// Get device pointer to the specified adjoint snap, snap_ind must be 0 or 1
#endif
Expand All @@ -122,11 +125,12 @@ class ArbLattice : public LatticeBase {
void initialize(size_t num_snaps_, const std::map<std::string, int>& setting_zones, pugi::xml_node arb_node); /// Init based on args
void readFromCxn(const std::string& cxn_path); /// Read the lattice info from a .cxn file
void partition(); /// Repartition the lattice, if ParMETIS is not present this is a noop
void computeLocalPermutation(); /// Compute the local permutation, see comment at the top
std::function<bool(int, int)> makePermCompare(pugi::xml_node arb_node, const std::map<std::string, int>& setting_zones); /// Make type-erased comparison operator for computing the local permutation, according to the strategy specified in the xml file
void computeLocalPermutation(pugi::xml_node arb_node, const std::map<std::string, int>& setting_zones); /// Compute the local permutation, see comment at the top
void computeGhostNodes(); /// Retrieve GIDs of ghost nodes from the connectivity info structure
void allocDeviceMemory(); /// Allocate required device memory
std::vector<NodeTypeBrush> parseBrushFromXml(pugi::xml_node arb_node, const std::map<std::string, int>& setting_zones) const; /// Parse the arbitrary lattice XML to determine the brush sequence to be applied to each node
void computeNodeTypesOnHost(pugi::xml_node arb_node, const std::map<std::string, int>& setting_zones); /// Compute the node types to be stored on the device
void computeNodeTypesOnHost(pugi::xml_node arb_node, const std::map<std::string, int>& setting_zones, bool permute); /// Compute the node types to be stored on the device, `permute` enables better code reuse
std::pmr::vector<real_t> computeCoords() const; /// Compute the coordinates 2D array to be stored on the device
std::pmr::vector<unsigned> computeNeighbors() const; /// Compute the neighbors 2D array to be stored on the device
void initDeviceData(pugi::xml_node arb_node, const std::map<std::string, int>& setting_zones); /// Initialize data residing in device memory
Expand All @@ -137,6 +141,8 @@ class ArbLattice : public LatticeBase {
ArbVTUGeom makeVTUGeom() const; /// Compute VTU geometry
void communicateBorder(); /// Send and receive border values in snap (overlapped with interior computation)
unsigned lookupLocalGhostIndex(ArbLatticeConnectivity::Index gid) const; /// For a given ghost gid, look up its local id
void debugDumpConnect(const std::string& name) const; /// Dump connectivity info for debug purposes
void debugDumpVTU() const; /// Dump VTU info for debug purposes
};

#endif // ARBLATTICE_HPP
3 changes: 1 addition & 2 deletions src/Handlers/acLoadMemoryDump.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ int acLoadMemoryDump::Init () {
if (attr2) {
error("Depreceted API call. Use LoadBinary with comp parameter");
}
const auto lattice = solver->getCartLattice();
lattice->loadSolution(attr.value());
solver->lattice->loadSolution(attr.value());
return 0;
}

Expand Down
12 changes: 5 additions & 7 deletions src/Handlers/cbBIN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,11 @@ int cbBIN::Init () {
return 0;
}


/// Run the base Callback::DoIt(), then pass the solver's CartLattice to binWriteLattice
/// together with the file name obtained from solver->outIterFile(nm, "").
/// Returns whatever binWriteLattice returns.
int cbBIN::DoIt () {
    Callback::DoIt();
    const auto output_path = solver->outIterFile(nm, "");
    auto& cart_lattice = *solver->getCartLattice();
    return binWriteLattice(output_path, cart_lattice, solver->units);
}

/// Run the base Callback::DoIt(), then write the lattice with binWriteLattice.
/// The lattice is obtained by visiting solver->getLatticeVariant() (presumably a
/// std::variant of lattice pointers), so the same code path serves every alternative.
/// Returns whatever binWriteLattice returns.
int cbBIN::DoIt() {
    Callback::DoIt();
    const auto output_path = solver->outIterFile(nm, "");
    // Generic visitor: dereference whichever lattice pointer the variant currently holds
    const auto write_lattice = [&](const auto lattice) { return binWriteLattice(output_path, *lattice, solver->units); };
    return std::visit(write_lattice, solver->getLatticeVariant());
}

// Register the handler (basing on xmlname) in the Handler Factory
template class HandlerFactory::Register< GenericAsk< cbBIN > >;
168 changes: 73 additions & 95 deletions src/Handlers/cbFailcheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,102 +2,80 @@

std::string cbFailcheck::xmlname = "Failcheck";

/// Initialise the Failcheck callback.
/// The dx/dy/dz fields of `reg` default to those of the CartLattice's local region;
/// any of the XML attributes dx, dy, dz, nx, ny, nz present on the node then overrides
/// the corresponding field of `reg` (value converted through solver->units.alt).
/// Always returns 0.
int cbFailcheck::Init () {
    Callback::Init();
    currentlyactive = false;

    // Defaults taken from the lattice's local region
    const auto lattice = solver->getCartLattice();
    reg.dx = lattice->getLocalRegion().dx;
    reg.dy = lattice->getLocalRegion().dy;
    reg.dz = lattice->getLocalRegion().dz;

    // Overwrite `field` only when the attribute `name` exists on the XML node
    const auto override_if_present = [&](const char* name, auto& field) {
        pugi::xml_attribute attr = node.attribute(name);
        if (attr) {
            field = solver->units.alt(attr.value());
        }
    };
    override_if_present("dx", reg.dx);
    override_if_present("dy", reg.dy);
    override_if_present("dz", reg.dz);
    override_if_present("nx", reg.nx);
    override_if_present("ny", reg.ny);
    override_if_present("nz", reg.nz);

    return 0;
}


int cbFailcheck::DoIt () {
Callback::DoIt();
if (currentlyactive) return 0;
currentlyactive = true;
int ret = 0;
int fin;
fin = false;

pugi::xml_attribute comp = node.attribute("what");

name_set components;
if(comp){
components.add_from_string(comp.value(),',');
} else {
components.add_from_string("all",',');
}

for (const Model::Quantity& it : solver->lattice->model->quantities) {
if (it.isAdjoint) continue;
if (components.in(it.name)) {
int comp = 1;
if (it.isVector) comp = 3;
real_t* tmp = new real_t[reg.size()*comp];
solver->getCartLattice()->GetQuantity(it.id, reg, tmp, 1);
bool cond = false;
for (int k = 0; k < reg.size()*comp; k++){
cond = cond || (std::isnan(tmp[k]));
}
delete[] tmp;
MPI_Allreduce(&cond,&fin,1,MPI_INT,MPI_LOR,MPMD.local);

if(fin ){
notice("Checking %s discovered NaN", it.name.c_str());
break;
}
}

/// Initialise the Failcheck callback, dispatching on the lattice alternative held by the solver.
/// For a CartLattice, the dx/dy/dz fields of `reg` default to those of the lattice's local
/// region, and any of the XML attributes dx, dy, dz, nx, ny, nz present on the node overrides
/// the matching field. For an ArbLattice there is nothing to set up.
/// Returns EXIT_SUCCESS in both cases.
int cbFailcheck::Init() {
    Callback::Init();
    currentlyactive = false;

    // Arbitrary lattice: no region to configure
    const auto setup_arb = [&](const Lattice<ArbLattice>*) { return EXIT_SUCCESS; };

    // Cartesian lattice: start from the local region, then apply XML overrides
    const auto setup_cart = [&](const Lattice<CartLattice>* cart) {
        reg.dx = cart->getLocalRegion().dx;
        reg.dy = cart->getLocalRegion().dy;
        reg.dz = cart->getLocalRegion().dz;

        // Overwrite `target` only when the attribute `attr_name` exists on the XML node
        const auto override_if_present = [&](const char* attr_name, auto& target) {
            const auto attr = node.attribute(attr_name);
            if (!attr) return;
            target = solver->units.alt(attr.value());
        };
        override_if_present("dx", reg.dx);
        override_if_present("dy", reg.dy);
        override_if_present("dz", reg.dz);
        override_if_present("nx", reg.nx);
        override_if_present("ny", reg.ny);
        override_if_present("nz", reg.nz);
        return EXIT_SUCCESS;
    };

    return std::visit(OverloadSet{setup_cart, setup_arb}, solver->getLatticeVariant());
}

int cbFailcheck::DoIt() {
Callback::DoIt();
if (currentlyactive) return EXIT_SUCCESS;
currentlyactive = true;

name_set components;
const auto comp = node.attribute("what");
components.add_from_string(comp ? comp.value() : "all", ',');

const auto check_for_nans = [&](const Model::Quantity& quantity) -> bool {
if (!components.in(quantity.name) || quantity.isAdjoint) return false;

const auto get_quantity_vec = [&](int quant_id, int n_comps) -> std::vector<real_t> {
const auto get_from_cart = [&](Lattice<CartLattice>* lattice) {
std::vector<real_t> retval(reg.size() * n_comps);
lattice->GetQuantity(quant_id, reg, retval.data(), 1.);
return retval;
};
const auto get_from_arb = [&](Lattice<ArbLattice>* lattice) {
std::vector<real_t> retval(lattice->getLocalSize() * n_comps);
lattice->getQuantity(quant_id, retval.data(), 1.);
return retval;
};
return std::visit(OverloadSet{get_from_cart, get_from_arb}, solver->getLatticeVariant());
};

const int n_comps = quantity.isVector ? 3 : 1;
const auto values = get_quantity_vec(quantity.id, n_comps);
int has_nans = std::any_of(values.begin(), values.end(), [](auto v) { return std::isnan(v); });
MPI_Allreduce(MPI_IN_PLACE, &has_nans, 1, MPI_INT, MPI_LOR, MPMD.local);
if (has_nans) notice("Discovered NaN values in %s", quantity.name.c_str());
return has_nans;
};

// Note: std::any_of would break early, we want to print all quantities which have NaN values, hence std::transform_reduce
const auto& quants = solver->lattice->model->quantities;
if (std::transform_reduce(quants.begin(), quants.end(), false, std::logical_or{}, check_for_nans)) {
notice("NaN value discovered. Executing final actions from the Failcheck element before full stop...\n");
for (pugi::xml_node par = node.first_child(); par; par = par.next_sibling()) {
Handler hand(par, solver);
if (hand) hand.DoIt();
}
if (fin) {
notice("NaN value discovered. Executing final actions from the Failcheck element before full stop...\n");
for (pugi::xml_node par = node.first_child(); par; par = par.next_sibling()) {
Handler hand(par, solver);
if (hand) hand.DoIt();
}
notice("Stopping due to Nan value\n");
ret = ITERATION_STOP;
}
return ret;
}


/// Finish the callback: no Failcheck-specific work, simply delegates to Callback::Finish() and returns its result.
int cbFailcheck::Finish () {
return Callback::Finish();
}
notice("Stopping due to NaN value\n");
return ITERATION_STOP;
}
return EXIT_SUCCESS;
}

/// Finish the callback: no Failcheck-specific work, simply delegates to Callback::Finish() and returns its result.
int cbFailcheck::Finish() {
return Callback::Finish();
}

// Register the handler (basing on xmlname) in the Handler Factory
template class HandlerFactory::Register< GenericAsk< cbFailcheck > >;
template class HandlerFactory::Register<GenericAsk<cbFailcheck> >;
5 changes: 2 additions & 3 deletions src/Handlers/cbSaveBinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ int cbSaveBinary::Init () {
int cbSaveBinary::DoIt () {
Callback::DoIt();
pugi::xml_attribute attr = node.attribute("comp");
const auto lattice = solver->getCartLattice();
if (attr) {
lattice->saveComp(fn, attr.value());
solver->lattice->saveComp(fn, attr.value());
} else {
lattice->saveSolution(fn);
solver->lattice->saveSolution(fn);
//error("Missing comp attribute in SaveBinary");
}
return 0;
Expand Down
3 changes: 1 addition & 2 deletions src/Handlers/cbSaveCheckpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,9 @@ int cbSaveCheckpoint::DoIt () {
delete the first set into the queue
*/
output("writing checkpoint");
const auto lattice = solver->getCartLattice();
const auto filename = solver->outIterCollectiveFile("checkpoint", "");
const auto restartFile = solver->outIterCollectiveFile("restart", ".xml");
auto fileStr = lattice->saveSolution(filename);
auto fileStr = solver->lattice->saveSolution(filename);
std::string restStr;
if (D_MPI_RANK == 0 ) {
writeRestartFile(filename.c_str(), restartFile.c_str());
Expand Down
2 changes: 1 addition & 1 deletion src/Handlers/cbSaveMemoryDump.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ int cbSaveMemoryDump::DoIt () {
if (attr) {
error("Depreceted API call. Use SaveBinary with comp parameter");
}
solver->getCartLattice()->saveSolution(fn);
solver->lattice->saveSolution(fn);
return 0;
};

Expand Down
13 changes: 5 additions & 8 deletions src/Handlers/cbTXT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,11 @@ int cbTXT::Init () {
return 0;
}


/// Run the base Callback::DoIt(), then pass the solver's CartLattice to txtWriteLattice
/// together with the file name obtained from solver->outIterFile(nm, ""), the solver units
/// and the members `s` and `txt_type`. Returns whatever txtWriteLattice returns.
int cbTXT::DoIt () {
    Callback::DoIt();
    const auto output_path = solver->outIterFile(nm, "");
    auto& cart_lattice = *solver->getCartLattice();
    const int status = txtWriteLattice(output_path.c_str(), cart_lattice, solver->units, s, txt_type);
    return status;
}

/// Run the base Callback::DoIt(), then write the lattice with txtWriteLattice.
/// The lattice is obtained by visiting solver->getLatticeVariant() (presumably a
/// std::variant of lattice pointers), so the same code path serves every alternative.
/// Returns whatever txtWriteLattice returns.
int cbTXT::DoIt() {
    Callback::DoIt();
    const auto output_path = solver->outIterFile(nm, "");
    // Generic visitor: dereference whichever lattice pointer the variant currently holds
    const auto write_lattice = [&](const auto lattice) {
        return txtWriteLattice(output_path, *lattice, solver->units, s, txt_type);
    };
    return std::visit(write_lattice, solver->getLatticeVariant());
}

// Register the handler (basing on xmlname) in the Handler Factory
template class HandlerFactory::Register< GenericAsk< cbTXT > >;
4 changes: 1 addition & 3 deletions src/configure.ac
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,9 @@ AC_ARG_WITH([cpp-flags],
[privide additionals flags for the compiler]),
[CONF_CPPFLAGS="$withval"])

NVFLAGS=""

if test "x${COMPILER_BINDIR}" != "x"
then
NVFLAGS="-ccbin=${COMPILER_BINDIR} "
NVFLAGS="${NVFLAGS} -ccbin=${COMPILER_BINDIR} "
fi

if test -z "$CXX"
Expand Down
Loading

0 comments on commit 6016957

Please sign in to comment.