Skip to content

Commit

Permalink
Pass MPI communicator for flexible parallelization
Browse files Browse the repository at this point in the history
Co-authored-by: Mario Malcolms de Oliveira <[email protected]>
  • Loading branch information
hmenke and mmalcolms committed Aug 1, 2022
1 parent 3c5e366 commit 694a733
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 56 deletions.
24 changes: 12 additions & 12 deletions c++/triqs_tprf/lattice/fourier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace triqs_tprf {
}

template <typename Gf_type>
auto fourier_wr_to_tr_general_target(Gf_type g_wr, int n_tau = -1) {
auto fourier_wr_to_tr_general_target(Gf_type g_wr, int n_tau = -1, mpi::communicator const &c = {}) {

auto _ = all_t{};
// Get rid of structured binding declarations in this file due to issue #11
Expand All @@ -48,7 +48,7 @@ auto fourier_wr_to_tr_general_target(Gf_type g_wr, int n_tau = -1) {
auto r0 = *rmesh.begin();
auto p = _fourier_plan<0>(gf_const_view(g_wr[_, r0]), gf_view(g_tr[_, r0]));

auto r_arr = mpi_view(rmesh);
auto r_arr = mpi_view(rmesh, c);

#pragma omp parallel for
for (unsigned int idx = 0; idx < r_arr.size(); idx++) {
Expand All @@ -63,12 +63,12 @@ auto fourier_wr_to_tr_general_target(Gf_type g_wr, int n_tau = -1) {

g_tr[_, r] = g_t;
}
g_tr = mpi::all_reduce(g_tr);
g_tr = mpi::all_reduce(g_tr, c);
return g_tr;
}

template <typename Gf_type>
auto fourier_tr_to_wr_general_target(Gf_type g_tr, int n_w = -1) {
auto fourier_tr_to_wr_general_target(Gf_type g_tr, int n_w = -1, mpi::communicator const &c = {}) {

auto _ = all_t{};
//auto [tmesh, rmesh] = g_tr.mesh();
Expand All @@ -81,7 +81,7 @@ auto fourier_tr_to_wr_general_target(Gf_type g_tr, int n_w = -1) {
auto r0 = *rmesh.begin();
auto p = _fourier_plan<0>(gf_const_view(g_tr[_, r0]), gf_view(g_wr[_, r0]));

auto r_arr = mpi_view(rmesh);
auto r_arr = mpi_view(rmesh, c);

#pragma omp parallel for
for (unsigned int idx = 0; idx < r_arr.size(); idx++) {
Expand All @@ -96,12 +96,12 @@ auto fourier_tr_to_wr_general_target(Gf_type g_tr, int n_w = -1) {

g_wr[_, r] = g_w;
}
g_wr = mpi::all_reduce(g_wr);
g_wr = mpi::all_reduce(g_wr, c);
return g_wr;
}

template <typename Gf_type>
auto fourier_wk_to_wr_general_target(Gf_type g_wk) {
auto fourier_wk_to_wr_general_target(Gf_type g_wk, mpi::communicator const &c = {}) {

auto _ = all_t{};

Expand All @@ -116,7 +116,7 @@ auto fourier_wk_to_wr_general_target(Gf_type g_wk) {
auto w0 = *wmesh.begin();
auto p = _fourier_plan<0>(gf_const_view(g_wk[w0, _]), gf_view(g_wr[w0, _]));

auto w_arr = mpi_view(wmesh);
auto w_arr = mpi_view(wmesh, c);

#pragma omp parallel for
for (unsigned int idx = 0; idx < w_arr.size(); idx++) {
Expand All @@ -131,12 +131,12 @@ auto fourier_wk_to_wr_general_target(Gf_type g_wk) {

g_wr[w, _] = g_r;
}
g_wr = mpi::all_reduce(g_wr);
g_wr = mpi::all_reduce(g_wr, c);
return g_wr;
}

template <typename Gf_type>
auto fourier_wr_to_wk_general_target(Gf_type g_wr) {
auto fourier_wr_to_wk_general_target(Gf_type g_wr, mpi::communicator const &c = {}) {

auto _ = all_t{};

Expand All @@ -150,7 +150,7 @@ auto fourier_wr_to_wk_general_target(Gf_type g_wr) {
auto w0 = *wmesh.begin();
auto p = _fourier_plan<0>(gf_const_view(g_wr[w0, _]), gf_view(g_wk[w0, _]));

auto w_arr = mpi_view(wmesh);
auto w_arr = mpi_view(wmesh, c);

#pragma omp parallel for
for (unsigned int idx = 0; idx < w_arr.size(); idx++) {
Expand All @@ -165,7 +165,7 @@ auto fourier_wr_to_wk_general_target(Gf_type g_wr) {

g_wk[w, _] = g_k;
}
g_wk = mpi::all_reduce(g_wk);
g_wk = mpi::all_reduce(g_wk, c);
return g_wk;
}

Expand Down
44 changes: 22 additions & 22 deletions c++/triqs_tprf/lattice/gf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,28 +38,28 @@ namespace triqs_tprf {
// ----------------------------------------------------
// g

g_wk_t lattice_dyson_g0_wk(double mu, e_k_cvt e_k, gf_mesh<imfreq> mesh) {
g_wk_t lattice_dyson_g0_wk(double mu, e_k_cvt e_k, gf_mesh<imfreq> mesh, mpi::communicator const &c) {

auto I = nda::eye<ek_vt::scalar_t>(e_k.target_shape()[0]);
g_wk_t g0_wk({mesh, e_k.mesh()}, e_k.target_shape());
g0_wk() = 0.0;

auto arr = mpi_view(g0_wk.mesh());
auto arr = mpi_view(g0_wk.mesh(), c);

#pragma omp parallel for
for (unsigned int idx = 0; idx < arr.size(); idx++) {
auto &[w, k] = arr(idx);
g0_wk[w, k] = inverse((w + mu)*I - e_k(k));
}

g0_wk = mpi::all_reduce(g0_wk);
g0_wk = mpi::all_reduce(g0_wk, c);
return g0_wk;
}

// ----------------------------------------------------

template<typename sigma_t>
auto lattice_dyson_g_generic(double mu, e_k_cvt e_k, sigma_t sigma){
auto lattice_dyson_g_generic(double mu, e_k_cvt e_k, sigma_t sigma, mpi::communicator const &c){

auto &freqmesh = [&sigma]() -> auto & {
if constexpr (sigma_t::arity == 1) return sigma.mesh();
Expand All @@ -72,7 +72,7 @@ auto lattice_dyson_g_generic(double mu, e_k_cvt e_k, sigma_t sigma){
g_wk_t g_wk({freqmesh, e_k.mesh()}, e_k.target_shape());
g_wk() = 0.0;

auto arr = mpi_view(g_wk.mesh());
auto arr = mpi_view(g_wk.mesh(), c);
#pragma omp parallel for
for (unsigned int idx = 0; idx < arr.size(); idx++) {
auto &[w, k] = arr(idx);
Expand All @@ -84,60 +84,60 @@ auto lattice_dyson_g_generic(double mu, e_k_cvt e_k, sigma_t sigma){
g_wk[w, k] = inverse((w + mu)*I - e_k(k) - sigmaterm);
}

g_wk = mpi::all_reduce(g_wk);
g_wk = mpi::all_reduce(g_wk, c);
return g_wk;
}


g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_wk_cvt sigma_wk) {
return lattice_dyson_g_generic(mu, e_k, sigma_wk);
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_wk_cvt sigma_wk, mpi::communicator const &c) {
return lattice_dyson_g_generic(mu, e_k, sigma_wk, c);
}


g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_w_cvt sigma_w) {
return lattice_dyson_g_generic(mu, e_k, sigma_w);
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_w_cvt sigma_w, mpi::communicator const &c) {
return lattice_dyson_g_generic(mu, e_k, sigma_w, c);
}


g_w_t lattice_dyson_g_w(double mu, e_k_cvt e_k, g_w_cvt sigma_w) {
g_w_t lattice_dyson_g_w(double mu, e_k_cvt e_k, g_w_cvt sigma_w, mpi::communicator const &c) {

auto g_wk = lattice_dyson_g_generic(mu, e_k, sigma_w);
auto g_wk = lattice_dyson_g_generic(mu, e_k, sigma_w, c);
auto &[wmesh, kmesh] = g_wk.mesh();

g_w_t g_w(wmesh, e_k.target_shape());
g_w() = 0.0;

for (auto const &[w, k] : mpi_view(g_wk.mesh()))
for (auto const &[w, k] : mpi_view(g_wk.mesh(), c))
g_w[w] += g_wk[w, k];

g_w = mpi::all_reduce(g_w);
g_w = mpi::all_reduce(g_w, c);
g_w /= kmesh.size();
return g_w;
}

// ----------------------------------------------------
// Transformations: real space <-> reciprocal space

g_wr_t fourier_wk_to_wr(g_wk_cvt g_wk) {
auto g_wr = fourier_wk_to_wr_general_target(g_wk);
g_wr_t fourier_wk_to_wr(g_wk_cvt g_wk, mpi::communicator const &c) {
auto g_wr = fourier_wk_to_wr_general_target(g_wk, c);
return g_wr;
}

g_wk_t fourier_wr_to_wk(g_wr_cvt g_wr) {
auto g_wk = fourier_wr_to_wk_general_target(g_wr);
g_wk_t fourier_wr_to_wk(g_wr_cvt g_wr, mpi::communicator const &c) {
auto g_wk = fourier_wr_to_wk_general_target(g_wr, c);
return g_wk;
}

// ----------------------------------------------------
// Transformations: Matsubara frequency <-> imaginary time

g_wr_t fourier_tr_to_wr(g_tr_cvt g_tr, int nw) {
auto g_wr = fourier_tr_to_wr_general_target(g_tr, nw);
g_wr_t fourier_tr_to_wr(g_tr_cvt g_tr, int nw, mpi::communicator const &c) {
auto g_wr = fourier_tr_to_wr_general_target(g_tr, nw, c);
return g_wr;
}

g_tr_t fourier_wr_to_tr(g_wr_cvt g_wr, int nt) {
auto g_tr = fourier_wr_to_tr_general_target(g_wr, nt);
g_tr_t fourier_wr_to_tr(g_wr_cvt g_wr, int nt, mpi::communicator const &c) {
auto g_tr = fourier_wr_to_tr_general_target(g_wr, nt, c);
return g_tr;
}

Expand Down
16 changes: 8 additions & 8 deletions c++/triqs_tprf/lattice/gf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ namespace triqs_tprf {
@param mesh imaginary frequency mesh
@return Matsubara frequency lattice Green's function $G^{(0)}_{a\bar{b}}(i\omega_n, \mathbf{k})$
*/
g_wk_t lattice_dyson_g0_wk(double mu, e_k_cvt e_k, mesh::imfreq mesh);
g_wk_t lattice_dyson_g0_wk(double mu, e_k_cvt e_k, mesh::imfreq mesh, mpi::communicator const &c = {});

/** Construct an interacting Matsubara frequency lattice Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{k})`
Expand All @@ -61,7 +61,7 @@ g_wk_t lattice_dyson_g0_wk(double mu, e_k_cvt e_k, mesh::imfreq mesh);
@param sigma_w imaginary frequency self-energy :math:`\Sigma_{\bar{a}b}(i\omega_n)`
@return Matsubara frequency lattice Green's function $G_{a\bar{b}}(i\omega_n, \mathbf{k})$
*/
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_w_cvt sigma_w);
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_w_cvt sigma_w, mpi::communicator const &c = {});

/** Construct an interacting Matsubara frequency lattice Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{k})`
Expand All @@ -81,7 +81,7 @@ g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_w_cvt sigma_w);
@param sigma_wk imaginary frequency self-energy :math:`\Sigma_{\bar{a}b}(i\omega_n, \mathbf{k})`
@return Matsubara frequency lattice Green's function $G_{a\bar{b}}(i\omega_n, \mathbf{k})$
*/
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_wk_cvt sigma_wk);
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_wk_cvt sigma_wk, mpi::communicator const &c = {});

/** Construct an interacting Matsubara frequency local (:math:`\mathbf{r}=\mathbf{0}`) lattice Green's function :math:`G_{a\bar{b}}(i\omega_n)`
Expand All @@ -101,7 +101,7 @@ g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_wk_cvt sigma_wk);
@param sigma_w imaginary frequency self-energy :math:`\Sigma_{\bar{a}b}(i\omega_n)`
@return Matsubara frequency lattice Green's function $G_{a\bar{b}}(i\omega_n, \mathbf{k})$
*/
g_w_t lattice_dyson_g_w(double mu, e_k_cvt e_k, g_w_cvt sigma_w);
g_w_t lattice_dyson_g_w(double mu, e_k_cvt e_k, g_w_cvt sigma_w, mpi::communicator const &c = {});

/** Inverse fast fourier transform of imaginary frequency Green's function from k-space to real space
Expand All @@ -110,7 +110,7 @@ g_w_t lattice_dyson_g_w(double mu, e_k_cvt e_k, g_w_cvt sigma_w);
@param g_wk k-space imaginary frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{k})`
@return real-space imaginary frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{r})`
*/
g_wr_t fourier_wk_to_wr(g_wk_cvt g_wk);
g_wr_t fourier_wk_to_wr(g_wk_cvt g_wk, mpi::communicator const &c = {});

/** Fast fourier transform of imaginary frequency Green's function from real-space to k-space
Expand All @@ -119,7 +119,7 @@ g_wr_t fourier_wk_to_wr(g_wk_cvt g_wk);
@param g_wr real-space imaginary frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{r})`
@return k-space imaginary frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{k})`
*/
g_wk_t fourier_wr_to_wk(g_wr_cvt g_wr);
g_wk_t fourier_wr_to_wk(g_wr_cvt g_wr, mpi::communicator const &c = {});

/** Fast fourier transform of real-space Green's function from Matsubara frequency to imaginary time
Expand All @@ -128,7 +128,7 @@ g_wk_t fourier_wr_to_wk(g_wr_cvt g_wr);
@param g_wr real-space imaginary frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{r})`
@return real-space imaginary time Green's function :math:`G_{a\bar{b}}(\tau, \mathbf{r})`
*/
g_tr_t fourier_wr_to_tr(g_wr_cvt g_wr, int nt=-1);
g_tr_t fourier_wr_to_tr(g_wr_cvt g_wr, int nt=-1, mpi::communicator const &c = {});

/** Fast fourier transform of real-space Green's function from imaginary time to Matsubara frequency
Expand All @@ -137,6 +137,6 @@ g_tr_t fourier_wr_to_tr(g_wr_cvt g_wr, int nt=-1);
@param g_tr real-space imaginary time Green's function :math:`G_{a\bar{b}}(\tau, \mathbf{r})`
@return real-space Matsubara frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{r})`
*/
g_wr_t fourier_tr_to_wr(g_tr_cvt g_tr, int nw=-1);
g_wr_t fourier_tr_to_wr(g_tr_cvt g_tr, int nw=-1, mpi::communicator const &c = {});

} // namespace triqs_tprf
16 changes: 2 additions & 14 deletions c++/triqs_tprf/mpi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
namespace triqs_tprf {

template<class T>
auto mpi_view(const array<T, 1> &arr, mpi::communicator const & c) {
auto mpi_view(const array<T, 1> &arr, mpi::communicator const &c = {}) {

auto slice = itertools::chunk_range(0, arr.shape()[0], c.size(), c.rank());

Expand All @@ -42,13 +42,7 @@ auto mpi_view(const array<T, 1> &arr, mpi::communicator const & c) {
}

template<class T>
auto mpi_view(const array<T, 1> &arr) {
mpi::communicator c;
return mpi_view(arr, c);
}

template<class T>
auto mpi_view(const T &mesh, mpi::communicator const & c) {
auto mpi_view(const T &mesh, mpi::communicator const &c = {}) {

auto slice = itertools::chunk_range(0, mesh.size(), c.size(), c.rank());
int size = slice.second - slice.first;
Expand Down Expand Up @@ -78,11 +72,5 @@ auto mpi_view(const T &mesh, mpi::communicator const & c) {

return arr;
}

template<class T>
auto mpi_view(const T &mesh) {
mpi::communicator c;
return mpi_view(mesh, c);
}

} // namespace triqs_tprf

0 comments on commit 694a733

Please sign in to comment.