Skip to content

Commit

Permalink
Added facilities for saving/loading the solution to/from binary files…
Browse files Browse the repository at this point in the history
…, with corresponding handlers
  • Loading branch information
kubagalecki committed Dec 31, 2023
1 parent 5d7635b commit e02f071
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 160 deletions.
73 changes: 62 additions & 11 deletions src/ArbLattice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,22 +678,73 @@ void ArbLattice::MPIStream_B() {
CudaStreamSynchronize(inStream);
}

/// TODO section
int ArbLattice::loadComp(const std::string& filename, const std::string& comp) {
throw std::runtime_error{"UNIMPLEMENTED"};
return -1;
static int saveImpl(const std::string& filename, const storage_t* device_ptr, size_t size) {
std::pmr::vector<storage_t> tab(size);
CudaMemcpy(tab.data(), device_ptr, size * sizeof(storage_t), CudaMemcpyDeviceToHost);
auto file = fopen(filename.c_str(), "wb");
if (!file) {
const auto err_msg = std::string("Failed to open ") + filename + " for writing";
ERROR(err_msg.c_str());
return EXIT_FAILURE;
}
const auto n_written = fwrite(tab.data(), sizeof(storage_t), size, file);
fclose(file);
if (n_written != size) {
const auto err_msg = std::string("Error writing to ") + filename;
ERROR(err_msg.c_str());
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
}
int ArbLattice::saveComp(const std::string& filename, const std::string& comp) const {
throw std::runtime_error{"UNIMPLEMENTED"};
return -1;

static int loadImpl(const std::string& filename, storage_t* device_ptr, size_t size) {
auto file = fopen(filename.c_str(), "rb");
if (!file) {
const auto err_msg = std::string("Failed to open ") + filename + " for reading";
ERROR(err_msg.c_str());
return EXIT_FAILURE;
}
std::pmr::vector<storage_t> tab(size);
const auto n_read = fread(tab.data(), sizeof(storage_t), size, file);
fclose(file);
if (n_read != size) {
const auto err_msg = std::string("Error reading from ") + filename;
ERROR(err_msg.c_str());
return EXIT_FAILURE;
}
CudaMemcpy(device_ptr, tab.data(), size * sizeof(storage_t), CudaMemcpyHostToDevice);
return EXIT_SUCCESS;
}

void ArbLattice::savePrimal(const std::string& filename, int snap_ind) const {
if (saveImpl(filename, getSnapPtr(snap_ind), sizes.snaps_pitch * NF)) throw std::runtime_error{"savePrimal failed"};
}

/// Fill the snap selected by `snap_ind` (sizes.snaps_pitch * NF elements) from the binary file `filename`.
/// @return the status reported by loadImpl (EXIT_SUCCESS or EXIT_FAILURE)
int ArbLattice::loadPrimal(const std::string& filename, int snap_ind) {
    // The former "UNIMPLEMENTED" throw is gone: it made the actual implementation below unreachable
    return loadImpl(filename, getSnapPtr(snap_ind), sizes.snaps_pitch * NF);
}
void ArbLattice::savePrimal(const std::string& filename, int snap_ind) const {
throw std::runtime_error{"UNIMPLEMENTED"};

int ArbLattice::loadComp(const std::string& filename, const std::string& comp) {
const int comp_ind = Model_m::lookupFieldIndexByName(comp);
if (comp_ind == -1) {
const auto err_msg = std::string("ArbLattice::loadComp called with unknown component: ") + comp;
error(err_msg.c_str());
return EXIT_FAILURE;
}
return loadImpl(filename, std::next(getSnapPtr(Snap), comp_ind * sizes.snaps_pitch), sizes.snaps_pitch);
}

int ArbLattice::saveComp(const std::string& filename, const std::string& comp) const {
const int comp_ind = Model_m::lookupFieldIndexByName(comp);
if (comp_ind == -1) {
const auto err_msg = std::string("ArbLattice::saveComp called with unknown component: ") + comp;
error(err_msg.c_str());
return EXIT_FAILURE;
}
return saveImpl(filename, std::next(getSnapPtr(Snap), comp_ind * sizes.snaps_pitch), sizes.snaps_pitch);
}

/// TODO section
#ifdef ADJOINT
int ArbLattice::loadAdj(const std::string& filename, int asnap_ind) {
throw std::runtime_error{"UNIMPLEMENTED"};
Expand Down
1 change: 1 addition & 0 deletions src/ArbLattice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class ArbLattice : public LatticeBase {
};

storage_t* getSnapPtr(int snap_ind); /// Get device pointer to the specified snap (somewhere within the total snap allocation)
/// Const overload: forwards to the non-const getSnapPtr and hands the result back as a pointer-to-const
const storage_t* getSnapPtr(int snap_ind) const {
    auto* const mutable_self = const_cast<ArbLattice*>(this);
    return mutable_self->getSnapPtr(snap_ind);
}
#ifdef ADJOINT
storage_t* getAdjointSnapPtr(int snap_ind); /// Get device pointer to the specified adjoint snap, snap_ind must be 0 or 1
#endif
Expand Down
31 changes: 13 additions & 18 deletions src/Handlers/acLoadMemoryDump.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,20 @@
std::string acLoadMemoryDump::xmlname = "LoadMemoryDump";
#include "../HandlerFactory.h"

int acLoadMemoryDump::Init () {
Action::Init();
pugi::xml_attribute attr = node.attribute("file");
if (!attr) {
attr = node.attribute("filename");
if (!attr) {
error("No file specified in LoadMemoryDump\n");
return -1;
}
}
pugi::xml_attribute attr2= node.attribute("comp");
if (attr2) {
error("Depreceted API call. Use LoadBinary with comp parameter");
/// Handler initialisation: load the lattice solution from the file named by the `file` attribute
/// (or `filename` when `file` is absent).
/// @return EXIT_FAILURE (after logging) when neither attribute is present, EXIT_SUCCESS otherwise
int acLoadMemoryDump::Init() {
    Action::Init();
    pugi::xml_attribute attr = node.attribute("file");
    if (!attr) {
        attr = node.attribute("filename");
        if (!attr) {
            error("No file specified in LoadMemoryDump\n");
            return EXIT_FAILURE;
        }
    }
    // Single load path for both attribute spellings; the leftover getCartLattice() branch, which returned before
    // the `comp` check below, is removed
    if (node.attribute("comp")) error("Deprecated API call. Use LoadBinary with comp parameter");
    solver->lattice->loadSolution(attr.value());
    return EXIT_SUCCESS;
}


// Register the handler (basing on xmlname) in the Handler Factory
template class HandlerFactory::Register< GenericAsk< acLoadMemoryDump > >;
template class HandlerFactory::Register<GenericAsk<acLoadMemoryDump> >;
52 changes: 23 additions & 29 deletions src/Handlers/cbSaveBinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,30 @@
std::string cbSaveBinary::xmlname = "SaveBinary";
#include "../HandlerFactory.h"

int cbSaveBinary::Init () {
Callback::Init();
pugi::xml_attribute attr = node.attribute("file");
if (!attr) {
attr = node.attribute("filename");
if (!attr) {
fn = solver->outIterFile("Save", "");
} else {
fn = attr.value();
}
} else {
fn = ((std::string) solver->outpath) + "_" + attr.value();
int cbSaveBinary::Init() {
Callback::Init();
auto attr = node.attribute("file");
if (!attr) {
attr = node.attribute("filename");
if (!attr) {
fn = solver->outIterFile("Save", "");
} else {
fn = attr.value();
}
return 0;
}


int cbSaveBinary::DoIt () {
Callback::DoIt();
pugi::xml_attribute attr = node.attribute("comp");
const auto lattice = solver->getCartLattice();
if (attr) {
lattice->saveComp(fn, attr.value());
} else {
lattice->saveSolution(fn);
//error("Missing comp attribute in SaveBinary");
}
return 0;
};
} else {
fn = solver->outpath + "_" + attr.value();
}
return 0;
}

/// Callback body: save either one component (when the `comp` attribute is present) or the whole solution to `fn`.
/// Always returns EXIT_SUCCESS; the results of the save calls are not inspected.
int cbSaveBinary::DoIt() {
    Callback::DoIt();
    const auto comp_attr = node.attribute("comp");
    if (!comp_attr) {
        solver->lattice->saveSolution(fn);
    } else {
        solver->lattice->saveComp(fn, comp_attr.value());
    }
    return EXIT_SUCCESS;
}

// Register the handler (basing on xmlname) in the Handler Factory
template class HandlerFactory::Register< GenericAsk< cbSaveBinary > >;
template class HandlerFactory::Register<GenericAsk<cbSaveBinary> >;
146 changes: 70 additions & 76 deletions src/Handlers/cbSaveCheckpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,102 +2,96 @@
std::string cbSaveCheckpoint::xmlname = "SaveCheckpoint";
#include "../HandlerFactory.h"

int cbSaveCheckpoint::Init () {
Callback::Init();
/*
int cbSaveCheckpoint::Init() {
Callback::Init();
/*
Initialisation of handler checks for keep attribute
inside of SaveCheckpoint. Keep attribute can be used
to save the last "x" number of checkpoints, or specify
"all" to keep all points.
DEFAULT behaviour is to return only the most recent
checkpoint.
*/
pugi::xml_attribute attr = node.attribute("keep");
if (attr) {
if (std::string(attr.value()) == "all"){
// Look for the keyword all and use keep = 0 as flag to store all
keep = 0;
} else {
// If the attr value is not the string all, assume it is an int
keep = attr.as_int();
if ( keep < 0) {
// Check the user hasn't set keep to a negative value
error("Keeping a negative no. of chckpnts not allowed, returning default behaviour.");
keep = 1;
}
}
} else{
keep = 1;
}

return 0;
}
pugi::xml_attribute attr = node.attribute("keep");
if (attr) {
if (std::string(attr.value()) == "all") {
// Look for the keyword all and use keep = 0 as flag to store all
keep = 0;
} else {
// If the attr value is not the string all, assume it is an int
keep = attr.as_int();
if (keep < 0) {
// Check the user hasn't set keep to a negative value
error("Keeping a negative no. of chckpnts not allowed, returning default behaviour.");
keep = 1;
}
}
} else {
keep = 1;
}

return 0;
}

int cbSaveCheckpoint::DoIt () {
Callback::DoIt();
/*
int cbSaveCheckpoint::DoIt() {
Callback::DoIt();
/*
Here we saveSolution to a _x.pri file where x is the MPI rank.
If keep == 0, then we save all solutions. Otherwise, we check
the size of the queue; less than keep then save file, else
delete the first set into the queue
*/
output("writing checkpoint");
const auto lattice = solver->getCartLattice();
const auto filename = solver->outIterCollectiveFile("checkpoint", "");
const auto restartFile = solver->outIterCollectiveFile("restart", ".xml");
auto fileStr = lattice->saveSolution(filename);
std::string restStr;
if (D_MPI_RANK == 0 ) {
writeRestartFile(filename.c_str(), restartFile.c_str());
restStr = restartFile;
}
if (keep != 0){
myqueue.push( fileStr );
myqueue_rst.push( restStr );
if (myqueue.size() > (size_t) keep) {
// myqueue should only ever reach the size of keep
fileStr = myqueue.front();
int rm_result = remove( fileStr.c_str() ); //Takes char
if (rm_result != 0) error("Checkpoint file was not deleted: %s",fileStr.c_str());
myqueue.pop();

if (D_MPI_RANK == 0 ) {
restStr = myqueue_rst.front();
rm_result = remove( restStr.c_str() );
if (rm_result != 0) error("Restart file was not deleted: %s",restStr.c_str());
myqueue_rst.pop();
}
}
}
output("writing checkpoint");
const auto filename = solver->outIterCollectiveFile("checkpoint", "");
const auto restartFile = solver->outIterCollectiveFile("restart", ".xml");
auto fileStr = solver->lattice->saveSolution(filename);
std::string restStr;
if (D_MPI_RANK == 0) {
writeRestartFile(filename.c_str(), restartFile.c_str());
restStr = restartFile;
}
if (keep != 0) {
myqueue.push(fileStr);
myqueue_rst.push(restStr);
if (myqueue.size() > (size_t)keep) {
// myqueue should only ever reach the size of keep
fileStr = myqueue.front();
int rm_result = remove(fileStr.c_str()); //Takes char
if (rm_result != 0) error("Checkpoint file was not deleted: %s", fileStr.c_str());
myqueue.pop();

return 0;
};
if (D_MPI_RANK == 0) {
restStr = myqueue_rst.front();
rm_result = remove(restStr.c_str());
if (rm_result != 0) error("Restart file was not deleted: %s", restStr.c_str());
myqueue_rst.pop();
}
}
}

int cbSaveCheckpoint::writeRestartFile( const char * fn, const char * rf ) {
return 0;
};

pugi::xml_document restartfile;
for (pugi::xml_node n = solver->configfile.first_child(); n; n = n.next_sibling()){
restartfile.append_copy(n);
}
int cbSaveCheckpoint::writeRestartFile(const char* fn, const char* rf) {
pugi::xml_document restartfile;
for (pugi::xml_node n = solver->configfile.first_child(); n; n = n.next_sibling()) { restartfile.append_copy(n); }

pugi::xml_node n1 = restartfile.child("CLBConfig").child("LoadBinary");
if (!n1){
// If it doesn't exist, create it before solve
n1 = restartfile.child("CLBConfig").child("Solve");
pugi::xml_node n2 = restartfile.child("CLBConfig").insert_child_before("LoadBinary", n1);
n2.append_attribute("file").set_value(fn);
} else {
// If it does exist, remove it and replace it with up to date file string
n1.remove_attribute(n1.attribute("file"));
n1.append_attribute("file").set_value(fn);
}
pugi::xml_node n1 = restartfile.child("CLBConfig").child("LoadBinary");
if (!n1) {
// If it doesn't exist, create it before solve
n1 = restartfile.child("CLBConfig").child("Solve");
pugi::xml_node n2 = restartfile.child("CLBConfig").insert_child_before("LoadBinary", n1);
n2.append_attribute("file").set_value(fn);
} else {
// If it does exist, remove it and replace it with up to date file string
n1.remove_attribute(n1.attribute("file"));
n1.append_attribute("file").set_value(fn);
}

restartfile.save_file( rf );
restartfile.save_file(rf);


return 0;
return 0;
}

// Register the handler (basing on xmlname) in the Handler Factory
template class HandlerFactory::Register< GenericAsk< cbSaveCheckpoint > >;
template class HandlerFactory::Register<GenericAsk<cbSaveCheckpoint> >;
Loading

0 comments on commit e02f071

Please sign in to comment.