Feature: cusolver support for LCAO kpar #6139

Open
wants to merge 12 commits into base: develop
13 changes: 13 additions & 0 deletions source/module_base/parallel_common.cpp
@@ -21,44 +21,55 @@ void Parallel_Common::bcast_string(std::string& object) // Peize Lin fix bug 201
}

MPI_Bcast(&object[0], size, MPI_CHAR, 0, MPI_COMM_WORLD);
MPI_Barrier(MPI_COMM_WORLD);
return;
}

void Parallel_Common::bcast_string(std::string* object, const int n) // Peize Lin fix bug 2019-03-18
{
for (int i = 0; i < n; i++)
{
bcast_string(object[i]);
}
MPI_Barrier(MPI_COMM_WORLD);
return;
}

void Parallel_Common::bcast_complex_double(std::complex<double>& object)
{
MPI_Bcast(&object, 1, MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD);
MPI_Barrier(MPI_COMM_WORLD);

}

void Parallel_Common::bcast_complex_double(std::complex<double>* object, const int n)
{
MPI_Bcast(object, n, MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD);
MPI_Barrier(MPI_COMM_WORLD);
}

void Parallel_Common::bcast_double(double& object)
{
MPI_Bcast(&object, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD);
MPI_Barrier(MPI_COMM_WORLD);
}

void Parallel_Common::bcast_double(double* object, const int n)
{
MPI_Bcast(object, n, MPI_DOUBLE, 0, MPI_COMM_WORLD);
MPI_Barrier(MPI_COMM_WORLD);
}

void Parallel_Common::bcast_int(int& object)
{
MPI_Bcast(&object, 1, MPI_INT, 0, MPI_COMM_WORLD);
MPI_Barrier(MPI_COMM_WORLD);
}

void Parallel_Common::bcast_int(int* object, const int n)
{
MPI_Bcast(object, n, MPI_INT, 0, MPI_COMM_WORLD);
MPI_Barrier(MPI_COMM_WORLD);
}

void Parallel_Common::bcast_bool(bool& object)
@@ -69,13 +80,15 @@ void Parallel_Common::bcast_bool(bool& object)
if (my_rank == 0)
swap = object;
MPI_Bcast(&swap, 1, MPI_INT, 0, MPI_COMM_WORLD);
MPI_Barrier(MPI_COMM_WORLD);
if (my_rank != 0)
object = static_cast<bool>(swap);
}

void Parallel_Common::bcast_char(char* object, const int n)
{
MPI_Bcast(object, n, MPI_CHAR, 0, MPI_COMM_WORLD);
MPI_Barrier(MPI_COMM_WORLD);
}

#endif
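
For context, the change in this file appends an MPI_Barrier after every broadcast so that all ranks leave the call together (MPI_Bcast alone does not guarantee a synchronization point). A minimal standalone sketch of the pattern, not taken from the ABACUS sources:

```cpp
// Broadcast a value from rank 0, then hit an explicit barrier so every rank
// leaves this point together, mirroring the MPI_Bcast + MPI_Barrier pairs above.
#include <mpi.h>
#include <cstdio>

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    int rank = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    double value = (rank == 0) ? 3.14 : 0.0;
    MPI_Bcast(&value, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD); // root 0 sends to all ranks
    MPI_Barrier(MPI_COMM_WORLD);                         // explicit synchronization point

    std::printf("rank %d has value %f\n", rank, value);
    MPI_Finalize();
    return 0;
}
```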
18 changes: 18 additions & 0 deletions source/module_hsolver/diago_cusolver.cpp
@@ -111,6 +111,24 @@ static void distributePsi(const int* desc_psi, T* psi, T* psi_g)
Cpxgemr2d(nrows, ncols, psi_g, 1, 1, descg, psi, 1, 1, descl, ctxt);
}

template <typename T>
void DiagoCusolver<T>::diag_pool(hamilt::MatrixBlock<T>& h_mat,
hamilt::MatrixBlock<T>& s_mat,
psi::Psi<T>& psi,
Real* eigenvalue_in,
MPI_Comm& comm)
{
ModuleBase::TITLE("DiagoCusolver", "diag_pool");
ModuleBase::timer::tick("DiagoCusolver", "diag_pool");
std::vector<double> eigen(PARAM.globalv.nlocal, 0.0);
std::vector<T> eigenvectors(h_mat.row * h_mat.col);
this->dc.Dngvd(h_mat.row, h_mat.col, h_mat.p, s_mat.p, eigen.data(), eigenvectors.data());
const int size = psi.get_nbands() * psi.get_nbasis();
BlasConnector::copy(size, eigenvectors.data(), 1, psi.get_pointer(), 1);
BlasConnector::copy(PARAM.inp.nbands, eigen.data(), 1, eigenvalue_in, 1);
ModuleBase::timer::tick("DiagoCusolver", "diag_pool");
}

// Diagonalization function
template <typename T>
void DiagoCusolver<T>::diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in)
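
A hedged reading of the new diag_pool above: in the cusolver k-parallel path each pool is a single process, so h_mat and s_mat already hold the full matrices and the dense solver's first nbands eigenvectors can be copied straight into psi without ScaLAPACK redistribution. The helper below only illustrates that final copy step; the names are illustrative and not ABACUS API:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// eigenvectors: full nbasis x nbasis solver output, column-major, one eigenvector
// per column; psi: flat nbands x nbasis buffer that receives the lowest bands.
void copy_lowest_bands(const std::vector<double>& eigenvectors,
                       const int nbasis, const int nbands,
                       std::vector<double>& psi)
{
    psi.resize(static_cast<std::size_t>(nbands) * nbasis);
    // the first nbands columns are contiguous in column-major storage, which is
    // why a single flat copy (BlasConnector::copy in diag_pool) is sufficient
    std::copy(eigenvectors.begin(),
              eigenvectors.begin() + static_cast<std::size_t>(nbands) * nbasis,
              psi.begin());
}
```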
7 changes: 7 additions & 0 deletions source/module_hsolver/diago_cusolver.h
@@ -24,6 +24,13 @@ class DiagoCusolver
DiagoCusolver(const Parallel_Orbitals* ParaV = nullptr);
~DiagoCusolver();

void diag_pool(
hamilt::MatrixBlock<T>& h_mat,
hamilt::MatrixBlock<T>& s_mat,
psi::Psi<T>& psi,
Real* eigenvalue_in,
MPI_Comm& comm);

// Override the diag function for CUSOLVER diagonalization
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in);

45 changes: 38 additions & 7 deletions source/module_hsolver/hsolver_lcao.cpp
@@ -49,7 +49,8 @@ void HSolverLCAO<T, Device>::solve(hamilt::Hamilt<T>* pHamilt,
if (this->method != "pexsi")
{
if (PARAM.globalv.kpar_lcao > 1
&& (this->method == "genelpa" || this->method == "elpa" || this->method == "scalapack_gvx"))
&& (this->method == "genelpa" || this->method == "elpa"
|| this->method == "scalapack_gvx" || this->method == "cusolver"))
{
#ifdef __MPI
this->parakSolve(pHamilt, psi, pes, PARAM.globalv.kpar_lcao);
@@ -191,11 +192,20 @@ void HSolverLCAO<T, Device>::parakSolve(hamilt::Hamilt<T>* pHamilt,
int nks = psi.get_nk();
int nrow = this->ParaV->get_global_row_size();
int nb2d = this->ParaV->get_block_size();
k2d.set_para_env(psi.get_nk(), nrow, nb2d, GlobalV::NPROC, GlobalV::MY_RANK, PARAM.inp.nspin);
if(this->method == "cusolver")
{
k2d.set_para_env_cusolver(psi.get_nk(), nrow, nb2d, GlobalV::NPROC, GlobalV::MY_RANK, PARAM.inp.nspin);
} else
{
k2d.set_para_env(psi.get_nk(), nrow, nb2d, GlobalV::NPROC, GlobalV::MY_RANK, PARAM.inp.nspin);
}
/// set psi_pool
const int zero = 0;
int ncol_bands_pool
= numroc_(&(nbands), &(nb2d), &(k2d.get_p2D_pool()->coord[1]), &zero, &(k2d.get_p2D_pool()->dim1));
int ncol_bands_pool = 0;
if(k2d.is_in_pool())
{
ncol_bands_pool = numroc_(&(nbands), &(nb2d), &(k2d.get_p2D_pool()->coord[1]), &zero, &(k2d.get_p2D_pool()->dim1));
}
/// Loop over k points: solve the Hamiltonian and obtain the charge density
for (int ik = 0; ik < k2d.get_pKpoints()->get_max_nks_pool(); ++ik)
{
@@ -222,9 +232,12 @@ void HSolverLCAO<T, Device>::parakSolve(hamilt::Hamilt<T>* pHamilt,
}
}
k2d.distribute_hsk(pHamilt, ik_kpar, nrow);
auto psi_pool = psi::Psi<T>();
if(k2d.is_in_pool())
{
/// global index of k point
int ik_global = ik + k2d.get_pKpoints()->startk_pool[k2d.get_my_pool()];
auto psi_pool = psi::Psi<T>(1, ncol_bands_pool, k2d.get_p2D_pool()->nrow, k2d.get_p2D_pool()->nrow, true);
psi_pool = psi::Psi<T>(1, ncol_bands_pool, k2d.get_p2D_pool()->nrow, k2d.get_p2D_pool()->nrow, true);
ModuleBase::Memory::record("HSolverLCAO::psi_pool", nrow * ncol_bands_pool * sizeof(T));
if (ik_global < psi.get_nk() && ik < k2d.get_pKpoints()->nks_pool[k2d.get_my_pool()])
{
@@ -255,21 +268,39 @@ void HSolverLCAO<T, Device>::parakSolve(hamilt::Hamilt<T>* pHamilt,
DiagoElpaNative<T> el;
el.diag_pool(hk_pool, sk_pool, psi_pool, &(pes->ekb(ik_global, 0)), k2d.POOL_WORLD_K2D);
}
#endif
#ifdef __CUDA
else if (this->method == "cusolver")
{
DiagoCusolver<T> cs(nullptr);
cs.diag_pool(hk_pool, sk_pool, psi_pool, &(pes->ekb(ik_global, 0)), k2d.POOL_WORLD_K2D);
}
#endif
else
{
ModuleBase::WARNING_QUIT("HSolverLCAO::solve",
"This type of eigensolver for k-parallelism diagnolization is not supported!");
}
}
}
MPI_Barrier(MPI_COMM_WORLD);
ModuleBase::timer::tick("HSolverLCAO", "collect_psi");
for (int ipool = 0; ipool < ik_kpar.size(); ++ipool)
{
int source = k2d.get_pKpoints()->get_startpro_pool(ipool);
int source = 0;
if(this->method != "cusolver")
{
source = k2d.get_pKpoints()->get_startpro_pool(ipool);
} else
{
source = ipool;
}
MPI_Bcast(&(pes->ekb(ik_kpar[ipool], 0)), nbands, MPI_DOUBLE, source, MPI_COMM_WORLD);
int desc_pool[9];
std::copy(k2d.get_p2D_pool()->desc, k2d.get_p2D_pool()->desc + 9, desc_pool);
if(k2d.is_in_pool())
{
std::copy(k2d.get_p2D_pool()->desc, k2d.get_p2D_pool()->desc + 9, desc_pool);
}
if (k2d.get_my_pool() != ipool)
{
desc_pool[1] = -1;
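
A note for reviewers on the broadcast root chosen above: in the cusolver path each pool consists of a single process whose global rank equals its pool index (see set_para_env_cusolver in parallel_k2d.cpp below), so the eigenvalues are broadcast from rank ipool itself, while the ScaLAPACK-based solvers keep using the first rank of each pool. A small sketch of that selection, with illustrative names only:

```cpp
#include <string>

// Pick the MPI broadcast root for the eigenvalues produced by pool `ipool`.
// startpro_of_pool is assumed to be the first global rank of that pool.
int select_bcast_root(const std::string& method, const int ipool,
                      const int startpro_of_pool)
{
    // cusolver pools are single-rank and created in rank order (rank == pool id)
    return (method == "cusolver") ? ipool : startpro_of_pool;
}
```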
4 changes: 4 additions & 0 deletions source/module_hsolver/kernels/cuda/diag_cusolver.cu
@@ -1,9 +1,13 @@
#include <assert.h>
#include "diag_cusolver.cuh"
#include "helper_cuda.h"
#include "module_base/module_device/device.h"

Diag_Cusolver_gvd::Diag_Cusolver_gvd(){
// step 1: create cusolver/cublas handle
#if defined(__MPI) && defined(__CUDA)
base_device::information::set_device_by_rank();
#endif
cusolverH = NULL;
checkCudaErrors( cusolverDnCreate(&cusolverH) );

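
The added set_device_by_rank call binds each MPI rank to a GPU before the cuSOLVER handle is created, presumably so that the single-rank pools end up on different devices. A standalone sketch of that idea, assuming a simple round-robin rank-to-device mapping (the actual helper in module_base/module_device may behave differently):

```cpp
#include <mpi.h>
#include <cuda_runtime.h>
#include <cusolverDn.h>

// Create a cuSOLVER handle on the GPU assigned to this MPI rank.
// Assumes MPI has already been initialized by the caller.
cusolverDnHandle_t create_handle_on_local_gpu()
{
    int rank = 0;
    int ndev = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    cudaGetDeviceCount(&ndev);
    if (ndev > 0)
    {
        cudaSetDevice(rank % ndev); // round-robin mapping of ranks to devices
    }
    cusolverDnHandle_t handle = nullptr;
    cusolverDnCreate(&handle);      // handle is bound to the device selected above
    return handle;
}
```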
2 changes: 1 addition & 1 deletion source/module_hsolver/kernels/cuda/diag_cusolver.cuh
@@ -22,7 +22,7 @@ class Diag_Cusolver_gvd{
//-------------------

cusolverDnHandle_t cusolverH = nullptr;

cusolverEigType_t itype = CUSOLVER_EIG_TYPE_1; //problem type: A*x = (lambda)*B*x
cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR; // compute eigenvalues and eigenvectors.
cublasFillMode_t uplo = CUBLAS_FILL_MODE_LOWER;
56 changes: 51 additions & 5 deletions source/module_hsolver/parallel_k2d.cpp
@@ -34,6 +34,47 @@ void Parallel_K2D<TK>::set_para_env(int nks,
this->P2D_pool->init(nw, nw, nb2d, this->POOL_WORLD_K2D);
}

template <typename TK>
void Parallel_K2D<TK>::set_para_env_cusolver(int nks,
const int& nw,
const int& nb2d,
const int& nproc,
const int& my_rank,
const int& nspin) {
const int kpar = this->get_kpar();
if(kpar <= 0 || kpar > nproc)
{
ModuleBase::WARNING_QUIT("Parallel_K2D::set_para_env_cusolver",
"kpar must be greater than 0 and less than nproc.");
}
const int pool_id = (my_rank < kpar) ? my_rank : MPI_UNDEFINED;
MPI_Comm_split(MPI_COMM_WORLD,
pool_id,
0,
&this->POOL_WORLD_K2D);
this->P2D_global = new Parallel_2D;
this->P2D_global->init(nw, nw, nb2d, MPI_COMM_WORLD);
if(this->POOL_WORLD_K2D != MPI_COMM_NULL)
{
this->MY_POOL = pool_id;
this->RANK_IN_POOL = 0;
this->NPROC_IN_POOL = 1;
this->in_pool = true;
this->P2D_pool = new Parallel_2D;
this->P2D_pool->init(nw, nw, nb2d, this->POOL_WORLD_K2D);
}
else
{
this->in_pool = false;
this->MY_POOL = -1;
this->RANK_IN_POOL = -1;
this->NPROC_IN_POOL = 0;
}
this->Pkpoints = new Parallel_Kpoints;
this->Pkpoints
->kinfo(nks, kpar, this->MY_POOL, this->RANK_IN_POOL, nproc, nspin);
}

template <typename TK>
void Parallel_K2D<TK>::distribute_hsk(hamilt::Hamilt<TK>* pHamilt,
const std::vector<int>& ik_kpar,
@@ -45,13 +86,16 @@ void Parallel_K2D<TK>::distribute_hsk(hamilt::Hamilt<TK>* pHamilt,
pHamilt->updateHk(ik_kpar[ipool]);
hamilt::MatrixBlock<TK> HK_global, SK_global;
pHamilt->matrix(HK_global, SK_global);
if (this->MY_POOL == this->Pkpoints->whichpool[ik_kpar[ipool]]) {
if (this->in_pool && this->MY_POOL == this->Pkpoints->whichpool[ik_kpar[ipool]]) {
this->hk_pool.resize(this->P2D_pool->get_local_size(), 0.0);
this->sk_pool.resize(this->P2D_pool->get_local_size(), 0.0);
}
int desc_pool[9];
std::copy(this->P2D_pool->desc, this->P2D_pool->desc + 9, desc_pool);
if (this->MY_POOL != this->Pkpoints->whichpool[ik_kpar[ipool]]) {
if(this->in_pool)
{
std::copy(this->P2D_pool->desc, this->P2D_pool->desc + 9, desc_pool);
}
if ( !this->in_pool || this->MY_POOL != this->Pkpoints->whichpool[ik_kpar[ipool]]) {
desc_pool[1] = -1;
}
Cpxgemr2d(nw,
@@ -77,7 +121,6 @@ void Parallel_K2D<TK>::distribute_hsk(hamilt::Hamilt<TK>* pHamilt,
desc_pool,
this->P2D_global->blacs_ctxt);
}
ModuleBase::Memory::record("Parallel_K2D::hsk_pool", this->P2D_pool->get_local_size() * 2 * sizeof(TK));
ModuleBase::timer::tick("Parallel_K2D", "distribute_hsk");
MPI_Barrier(MPI_COMM_WORLD);
#endif
@@ -97,7 +140,10 @@ void Parallel_K2D<TK>::unset_para_env() {
delete this->P2D_pool;
this->P2D_pool = nullptr;
}
MPI_Comm_free(&this->POOL_WORLD_K2D);
if(this->POOL_WORLD_K2D != MPI_COMM_NULL)
{
MPI_Comm_free(&this->POOL_WORLD_K2D);
}
}

template <typename TK>
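
The communicator layout built by set_para_env_cusolver can be reproduced with plain MPI: only the first kpar ranks obtain a single-rank pool communicator, while the remaining ranks receive MPI_COMM_NULL and must skip both the pool work and the final MPI_Comm_free, which is exactly what the new guards above check. A minimal standalone sketch, not ABACUS code:

```cpp
#include <mpi.h>

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    int rank = 0, nproc = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &nproc);

    const int kpar = 2;                               // example value
    const int color = (rank < kpar) ? rank : MPI_UNDEFINED;
    MPI_Comm pool = MPI_COMM_NULL;
    MPI_Comm_split(MPI_COMM_WORLD, color, 0, &pool);  // ranks >= kpar get MPI_COMM_NULL

    if (pool != MPI_COMM_NULL)                        // only pool members do pool work...
    {
        // ... diagonalize the k-points assigned to this pool here ...
        MPI_Comm_free(&pool);                         // ...and free their communicator
    }
    MPI_Finalize();
    return 0;
}
```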
34 changes: 22 additions & 12 deletions source/module_hsolver/parallel_k2d.h
@@ -18,20 +18,27 @@ template <typename TK>
class Parallel_K2D {
public:
/// constructor
Parallel_K2D() {}
Parallel_K2D() {};
/// destructor
~Parallel_K2D() {}
~Parallel_K2D() {};
/**
* Public member functions
*/
/// this function sets the parallel environment for k-points parallelism
/// including the global and pool 2D parallel distribution
void set_para_env(int nks,
const int& nw,
const int& nb2d,
const int& nproc,
const int& my_rank,
const int& nspin);
const int& nw,
const int& nb2d,
const int& nproc,
const int& my_rank,
const int& nspin);

void set_para_env_cusolver(int nks,
const int& nw,
const int& nb2d,
const int& nproc,
const int& my_rank,
const int& nspin);

/// this function distributes the Hk and Sk matrices to hk_pool and sk_pool
void distribute_hsk(hamilt::Hamilt<TK>* pHamilt,
@@ -44,15 +51,17 @@ class Parallel_K2D {
/// set the number of k-points
void set_kpar(int kpar);
/// get the number of k-points
int get_kpar() { return this->kpar_; }
int get_kpar() { return this->kpar_; };
/// get my pool
int get_my_pool() { return this->MY_POOL; }
int get_my_pool() { return this->MY_POOL; };
/// check if this proc is in any pool
bool is_in_pool() { return this->in_pool; };
/// get pKpoints
Parallel_Kpoints* get_pKpoints() { return this->Pkpoints; }
Parallel_Kpoints* get_pKpoints() { return this->Pkpoints; };
/// get p2D_global
Parallel_2D* get_p2D_global() { return this->P2D_global; }
Parallel_2D* get_p2D_global() { return this->P2D_global; };
/// get p2D_pool
Parallel_2D* get_p2D_pool() { return this->P2D_pool; }
Parallel_2D* get_p2D_pool() { return this->P2D_pool; };

/**
* the local Hk, Sk matrices in POOL_WORLD_K2D
@@ -69,6 +78,7 @@
* Private member variables
*/
int kpar_ = 0;
bool in_pool = true;

/**
* mpi info