diff -Naur a/include/El/blas_like/level3.hpp b/include/El/blas_like/level3.hpp --- a/include/El/blas_like/level3.hpp 2017-06-08 07:30:43.180249917 -0700 +++ b/include/El/blas_like/level3.hpp 2017-06-08 07:35:27.325434602 -0700 @@ -31,6 +31,10 @@ } using namespace GemmAlgorithmNS; +void GemmUseGPU(int min_M, int min_N, int min_K); + +void GemmUseCPU(); + template void Gemm ( Orientation orientA, Orientation orientB, diff -Naur a/include/El/core/imports/blas.hpp b/include/El/core/imports/blas.hpp --- a/include/El/core/imports/blas.hpp 2017-06-08 07:30:43.522016908 -0700 +++ b/include/El/core/imports/blas.hpp 2017-06-08 07:35:06.834030908 -0700 @@ -916,4 +916,63 @@ } // namespace blas } // namespace El + +#if defined(EL_USE_CUBLAS) + +namespace El { + +#ifdef EL_USE_64BIT_BLAS_INTS +typedef long long int BlasInt; +#else +typedef int BlasInt; +#endif + +namespace cublas { + +// NOTE: templated routines are custom and not wrappers + +// Level 3 BLAS +// ============ +template +void Gemm +( char transA, char transB, BlasInt m, BlasInt n, BlasInt k, + const T& alpha, + const T* A, BlasInt ALDim, + const T* B, BlasInt BLDim, + const T& beta, + T* C, BlasInt CLDim ); + +void Gemm +( char transA, char transB, BlasInt m, BlasInt n, BlasInt k, + const float& alpha, + const float* A, BlasInt ALDim, + const float* B, BlasInt BLDim, + const float& beta, + float* C, BlasInt CLDim ); +void Gemm +( char transA, char transB, BlasInt m, BlasInt n, BlasInt k, + const double& alpha, + const double* A, BlasInt ALDim, + const double* B, BlasInt BLDim, + const double& beta, + double* C, BlasInt CLDim ); +void Gemm +( char transA, char transB, BlasInt m, BlasInt n, BlasInt k, + const scomplex& alpha, + const scomplex* A, BlasInt ALDim, + const scomplex* B, BlasInt BLDim, + const scomplex& beta, + scomplex* C, BlasInt CLDim ); +void Gemm +( char transA, char transB, BlasInt m, BlasInt n, BlasInt k, + const dcomplex& alpha, + const dcomplex* A, BlasInt ALDim, + const dcomplex* B, BlasInt BLDim, + const dcomplex& beta, + dcomplex* C, BlasInt CLDim ); + +} // namespace cublas +} // namespace El +#endif + #endif // ifndef EL_IMPORTS_BLAS_DECL_HPP diff -Naur a/src/blas_like/level3/Gemm.cpp b/src/blas_like/level3/Gemm.cpp --- a/src/blas_like/level3/Gemm.cpp 2017-06-08 07:30:44.307096427 -0700 +++ b/src/blas_like/level3/Gemm.cpp 2017-06-08 07:34:23.062863489 -0700 @@ -16,6 +16,20 @@ namespace El { +char gemm_cpu_gpu_switch = 'c'; +int min_M = 0, min_N = 0, min_K = 0; + +void GemmUseGPU(int _min_M, int _min_N, int _min_K) { + gemm_cpu_gpu_switch = 'g'; + min_M = _min_M; + min_N = _min_N; + min_K = _min_K; +} + +void GemmUseCPU() { + gemm_cpu_gpu_switch = 'c'; +} + template void Gemm ( Orientation orientA, Orientation orientB, @@ -59,11 +73,30 @@ const Int k = ( orientA == NORMAL ? A.Width() : A.Height() ); if( k != 0 ) { +#if defined(EL_USE_CUBLAS) + if (gemm_cpu_gpu_switch == 'g' && + m >= min_M && + n >= min_N && + k >= min_K) { + cublas::Gemm + ( transA, transB, m, n, k, + alpha, A.LockedBuffer(), A.LDim(), + B.LockedBuffer(), B.LDim(), + beta, C.Buffer(), C.LDim() ); + } else { + blas::Gemm + ( transA, transB, m, n, k, + alpha, A.LockedBuffer(), A.LDim(), + B.LockedBuffer(), B.LDim(), + beta, C.Buffer(), C.LDim() ); + } +#else blas::Gemm ( transA, transB, m, n, k, alpha, A.LockedBuffer(), A.LDim(), B.LockedBuffer(), B.LDim(), beta, C.Buffer(), C.LDim() ); +#endif } else { diff -Naur a/src/core/imports/blas/Gemm.hpp b/src/core/imports/blas/Gemm.hpp --- a/src/core/imports/blas/Gemm.hpp 2017-06-08 07:30:45.090529967 -0700 +++ b/src/core/imports/blas/Gemm.hpp 2017-06-08 07:34:46.503009958 -0700 @@ -41,6 +41,12 @@ } // extern "C" + +#if defined(EL_USE_CUBLAS) +#include +#include +#endif + namespace El { namespace blas { @@ -515,3 +521,515 @@ } // namespace blas } // namespace El + + +#if EL_USE_CUBLAS + +#define USE_CUB 1 + +namespace El { +namespace cublas { + +#if USE_CUB +cub::CachingDeviceAllocator g_allocator(true); // Caching allocator for device memory +#endif + +template +void Gemm +( char transA, char transB, + BlasInt m, BlasInt n, BlasInt k, + const T& alpha, + const T* A, BlasInt ALDim, + const T* B, BlasInt BLDim, + const T& beta, + T* C, BlasInt CLDim ) +{ + // put something here + printf("integer version \n"); +} +template void Gemm +( char transA, char transB, + BlasInt m, BlasInt n, BlasInt k, + const Int& alpha, + const Int* A, BlasInt ALDim, + const Int* B, BlasInt BLDim, + const Int& beta, + Int* C, BlasInt CLDim ); +#ifdef EL_HAVE_QD +template void Gemm +( char transA, char transB, + BlasInt m, BlasInt n, BlasInt k, + const DoubleDouble& alpha, + const DoubleDouble* A, BlasInt ALDim, + const DoubleDouble* B, BlasInt BLDim, + const DoubleDouble& beta, + DoubleDouble* C, BlasInt CLDim ); +template void Gemm +( char transA, char transB, + BlasInt m, BlasInt n, BlasInt k, + const QuadDouble& alpha, + const QuadDouble* A, BlasInt ALDim, + const QuadDouble* B, BlasInt BLDim, + const QuadDouble& beta, + QuadDouble* C, BlasInt CLDim ); +template void Gemm +( char transA, char transB, + BlasInt m, BlasInt n, BlasInt k, + const Complex& alpha, + const Complex* A, BlasInt ALDim, + const Complex* B, BlasInt BLDim, + const Complex& beta, + Complex* C, BlasInt CLDim ); +template void Gemm +( char transA, char transB, + BlasInt m, BlasInt n, BlasInt k, + const Complex& alpha, + const Complex* A, BlasInt ALDim, + const Complex* B, BlasInt BLDim, + const Complex& beta, + Complex* C, BlasInt CLDim ); +#endif +#ifdef EL_HAVE_QUAD +template void Gemm +( char transA, char transB, + BlasInt m, BlasInt n, BlasInt k, + const Quad& alpha, + const Quad* A, BlasInt ALDim, + const Quad* B, BlasInt BLDim, + const Quad& beta, + Quad* C, BlasInt CLDim ); +template void Gemm +( char transA, char transB, + BlasInt m, BlasInt n, BlasInt k, + const Complex& alpha, + const Complex* A, BlasInt ALDim, + const Complex* B, BlasInt BLDim, + const Complex& beta, + Complex* C, BlasInt CLDim ); +#endif +#ifdef EL_HAVE_MPC +template void Gemm +( char transA, char transB, + BlasInt m, BlasInt n, BlasInt k, + const BigInt& alpha, + const BigInt* A, BlasInt ALDim, + const BigInt* B, BlasInt BLDim, + const BigInt& beta, + BigInt* C, BlasInt CLDim ); +template void Gemm +( char transA, char transB, + BlasInt m, BlasInt n, BlasInt k, + const BigFloat& alpha, + const BigFloat* A, BlasInt ALDim, + const BigFloat* B, BlasInt BLDim, + const BigFloat& beta, + BigFloat* C, BlasInt CLDim ); +template void Gemm +( char transA, char transB, + BlasInt m, BlasInt n, BlasInt k, + const Complex& alpha, + const Complex* A, BlasInt ALDim, + const Complex* B, BlasInt BLDim, + const Complex& beta, + Complex* C, BlasInt CLDim ); +#endif + +void Gemm +( char transA, char transB, + BlasInt m, BlasInt n, BlasInt k, + const float& alpha, + const float* A, BlasInt ALDim, + const float* B, BlasInt BLDim, + const float& beta, + float* C, BlasInt CLDim ) +{ + EL_DEBUG_CSE + EL_DEBUG_ONLY( + if( std::toupper(transA) == 'N' ) + { + if( ALDim < Max(m,1) ) + LogicError("ALDim was too small: ALDim=",ALDim,",m=",m); + } + else + { + if( ALDim < Max(k,1) ) + LogicError("ALDim was too small: ALDim=",ALDim,",k=",k); + } + + if( std::toupper(transB) == 'N' ) + { + if( BLDim < Max(k,1) ) + LogicError("BLDim was too small: BLDim=",BLDim,",k=",k); + } + else + { + if( BLDim < Max(n,1) ) + LogicError("BLDim was too small: BLDim=",BLDim,",n=",n); + } + + if( CLDim < Max(m,1) ) + LogicError("CLDim was too small: CLDim=",CLDim,",m=",m); + ) + const char fixedTransA = ( std::toupper(transA) == 'C' ? 'T' : transA ); + const char fixedTransB = ( std::toupper(transB) == 'C' ? 'T' : transB ); + + const mpi::Comm comm; + const Int commRank = mpi::Rank( comm ); + if (commRank == 0) { + //printf("calling cublas Sgemm: m %d n %d k %d\n", m, n, k); + } + + BlasInt rowA, colA, rowB, colB, rowC, colC; + // device memory size for A, B and C + BlasInt sizeA, sizeB, sizeC; + float *devA=NULL, *devB=NULL, *devC=NULL; + + rowA = fixedTransA == 'T' ? k : m; + colA = fixedTransA == 'T' ? m : k; + rowB = fixedTransB == 'T' ? n : k; + colB = fixedTransB == 'T' ? k : n; + rowC = m; + colC = n; + sizeA = rowA * colA; + sizeB = rowB * colB; + sizeC = rowC * colC; + + cublasStatus stat; + +#if USE_CUB + CubDebugExit(g_allocator.DeviceAllocate((void**)&devA, + sizeof(float) * (sizeA+sizeB+sizeC) )); +#else + stat = cublasAlloc(sizeA+sizeB+sizeC, sizeof(float), (void **) &devA); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("Alloc A,B,C error\n"); } +#endif + + devB = devA + sizeA; + devC = devB + sizeB; + + // copy matrix A, B and C to device + stat = cublasSetMatrix(rowA, colA, sizeof(float), A, ALDim, devA, rowA); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("SetMatrix A error\n"); } + + stat = cublasSetMatrix(rowB, colB, sizeof(float), B, BLDim, devB, rowB); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("SetMatrix B error\n"); } + + if (beta != 0.0) + { + stat = cublasSetMatrix(rowC, colC, sizeof(float), C, CLDim, devC, rowC); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("SetMatrix C error\n"); } + } + + // cublasgemm + cublasSgemm + ( fixedTransA, fixedTransB, m, n, k, + alpha, devA, rowA, devB, rowB, beta, devC, rowC ); + + // copy matrix C to host + stat = cublasGetMatrix(rowC, colC, sizeof(float), devC, rowC, C, CLDim); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("GetMatrix C error\n"); } + + // free +#if USE_CUB + CubDebugExit(g_allocator.DeviceFree(devA)); +#else + cublasFree(devA); +#endif + //printf("CUBLAS float done ...\n"); +} + +void Gemm +( char transA, char transB, + BlasInt m, BlasInt n, BlasInt k, + const double& alpha, + const double* A, BlasInt ALDim, + const double* B, BlasInt BLDim, + const double& beta, + double* C, BlasInt CLDim ) +{ + EL_DEBUG_CSE + EL_DEBUG_ONLY( + if( std::toupper(transA) == 'N' ) + { + if( ALDim < Max(m,1) ) + LogicError("ALDim was too small: ALDim=",ALDim,",m=",m); + } + else + { + if( ALDim < Max(k,1) ) + LogicError("ALDim was too small: ALDim=",ALDim,",k=",k); + } + + if( std::toupper(transB) == 'N' ) + { + if( BLDim < Max(k,1) ) + LogicError("BLDim was too small: BLDim=",BLDim,",k=",k); + } + else + { + if( BLDim < Max(n,1) ) + LogicError("BLDim was too small: BLDim=",BLDim,",n=",n); + } + + if( CLDim < Max(m,1) ) + LogicError("CLDim was too small: CLDim=",CLDim,",m=",m); + ) + const char fixedTransA = ( std::toupper(transA) == 'C' ? 'T' : transA ); + const char fixedTransB = ( std::toupper(transB) == 'C' ? 'T' : transB ); + + const mpi::Comm comm; + const Int commRank = mpi::Rank( comm ); + if (commRank == 0) { + //printf("calling cublas Dgemm: m %d n %d k %d\n", m, n, k); + } + + BlasInt rowA, colA, rowB, colB, rowC, colC; + // device memory size for A, B and C + BlasInt sizeA, sizeB, sizeC; + double *devA=NULL, *devB=NULL, *devC=NULL; + + rowA = fixedTransA == 'T' ? k : m; + colA = fixedTransA == 'T' ? m : k; + rowB = fixedTransB == 'T' ? n : k; + colB = fixedTransB == 'T' ? k : n; + rowC = m; + colC = n; + sizeA = rowA * colA; + sizeB = rowB * colB; + sizeC = rowC * colC; + + cublasStatus stat; + +#if USE_CUB + CubDebugExit(g_allocator.DeviceAllocate((void**)&devA, + sizeof(double) * (sizeA+sizeB+sizeC) )); +#else + stat = cublasAlloc(sizeA+sizeB+sizeC, sizeof(double), (void **) &devA); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("Alloc A,B,C error\n"); } +#endif + + devB = devA + sizeA; + devC = devB + sizeB; + + // copy matrix A, B and C to device + stat = cublasSetMatrix(rowA, colA, sizeof(double), A, ALDim, devA, rowA); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("SetMatrix A error\n"); } + + stat = cublasSetMatrix(rowB, colB, sizeof(double), B, BLDim, devB, rowB); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("SetMatrix B error\n"); } + + if (beta != 0.0) + { + stat = cublasSetMatrix(rowC, colC, sizeof(double), C, CLDim, devC, rowC); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("SetMatrix C error\n"); } + } + + // cublasgemm + cublasDgemm + ( fixedTransA, fixedTransB, m, n, k, + alpha, devA, rowA, devB, rowB, beta, devC, rowC ); + + // copy matrix C to host + stat = cublasGetMatrix(rowC, colC, sizeof(double), devC, rowC, C, CLDim); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("GetMatrix C error\n"); } + + // free +#if USE_CUB + CubDebugExit(g_allocator.DeviceFree(devA)); +#else + cublasFree(devA); +#endif + //printf("CUBLAS double done ...\n"); +} + +void Gemm +( char transA, char transB, BlasInt m, BlasInt n, BlasInt k, + const scomplex& alpha, + const scomplex* A, BlasInt ALDim, + const scomplex* B, BlasInt BLDim, + const scomplex& beta, + scomplex* C, BlasInt CLDim ) +{ + EL_DEBUG_CSE + EL_DEBUG_ONLY( + if( std::toupper(transA) == 'N' ) + { + if( ALDim < Max(m,1) ) + LogicError("ALDim was too small: ALDim=",ALDim,",m=",m); + } + else + { + if( ALDim < Max(k,1) ) + LogicError("ALDim was too small: ALDim=",ALDim,",k=",k); + } + + if( std::toupper(transB) == 'N' ) + { + if( BLDim < Max(k,1) ) + LogicError("BLDim was too small: BLDim=",BLDim,",k=",k); + } + else + { + if( BLDim < Max(n,1) ) + LogicError("BLDim was too small: BLDim=",BLDim,",n=",n); + } + + if( CLDim < Max(m,1) ) + LogicError("CLDim was too small: CLDim=",CLDim,",m=",m); + ) + + const char fixedTransA = transA; + const char fixedTransB = transB; + + const mpi::Comm comm; + const Int commRank = mpi::Rank( comm ); + if (commRank == 0) { + //printf("calling cublas Cgemm: m %d n %d k %d\n", m, n, k); + } + + BlasInt rowA, colA, rowB, colB, rowC, colC; + // device memory size for A, B and C + BlasInt sizeA, sizeB, sizeC; + cuComplex *devA=NULL, *devB=NULL, *devC=NULL; + + rowA = fixedTransA == 'T' ? k : m; + colA = fixedTransA == 'T' ? m : k; + rowB = fixedTransB == 'T' ? n : k; + colB = fixedTransB == 'T' ? k : n; + rowC = m; + colC = n; + sizeA = rowA * colA; + sizeB = rowB * colB; + sizeC = rowC * colC; + + cublasStatus stat; + stat = cublasAlloc(sizeA+sizeB+sizeC, sizeof(cuComplex), (void **) &devA); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("Alloc A,B,C error\n"); } + + devB = devA + sizeA; + devC = devB + sizeB; + + // copy matrix A, B and C to device + stat = cublasSetMatrix(rowA, colA, sizeof(cuComplex), A, ALDim, devA, rowA); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("SetMatrix A error\n"); } + + stat = cublasSetMatrix(rowB, colB, sizeof(cuComplex), B, BLDim, devB, rowB); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("SetMatrix B error\n"); } + + if (beta.real() != 0.0 || beta.imag() != 0.0) + { + stat = cublasSetMatrix(rowC, colC, sizeof(cuComplex), C, CLDim, devC, rowC); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("SetMatrix C error\n"); } + } + + // cublasgemm + cublasCgemm + ( fixedTransA, fixedTransB, m, n, k, + *((cuComplex*) &alpha), devA, rowA, devB, rowB, *((cuComplex*) &beta), devC, rowC ); + + // copy matrix C to host + stat = cublasGetMatrix(rowC, colC, sizeof(cuComplex), devC, rowC, C, CLDim); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("GetMatrix C error\n"); } + + // free + cublasFree(devA); +} + +void Gemm +( char transA, char transB, BlasInt m, BlasInt n, BlasInt k, + const dcomplex& alpha, + const dcomplex* A, BlasInt ALDim, + const dcomplex* B, BlasInt BLDim, + const dcomplex& beta, + dcomplex* C, BlasInt CLDim ) +{ + EL_DEBUG_CSE + EL_DEBUG_ONLY( + if( std::toupper(transA) == 'N' ) + { + if( ALDim < Max(m,1) ) + LogicError("ALDim was too small: ALDim=",ALDim,",m=",m); + } + else + { + if( ALDim < Max(k,1) ) + LogicError("ALDim was too small: ALDim=",ALDim,",k=",k); + } + + if( std::toupper(transB) == 'N' ) + { + if( BLDim < Max(k,1) ) + LogicError("BLDim was too small: BLDim=",BLDim,",k=",k); + } + else + { + if( BLDim < Max(n,1) ) + LogicError("BLDim was too small: BLDim=",BLDim,",n=",n); + } + + if( CLDim < Max(m,1) ) + LogicError("CLDim was too small: CLDim=",CLDim,",m=",m); + ) + + const char fixedTransA = transA; + const char fixedTransB = transB; + + const mpi::Comm comm; + const Int commRank = mpi::Rank( comm ); + if (commRank == 0) { + //printf("calling cublas Zgemm: m %d n %d k %d\n", m, n, k); + } + + BlasInt rowA, colA, rowB, colB, rowC, colC; + // device memory size for A, B and C + BlasInt sizeA, sizeB, sizeC; + cuDoubleComplex *devA=NULL, *devB=NULL, *devC=NULL; + + rowA = fixedTransA == 'T' ? k : m; + colA = fixedTransA == 'T' ? m : k; + rowB = fixedTransB == 'T' ? n : k; + colB = fixedTransB == 'T' ? k : n; + rowC = m; + colC = n; + sizeA = rowA * colA; + sizeB = rowB * colB; + sizeC = rowC * colC; + + cublasStatus stat; + stat = cublasAlloc(sizeA+sizeB+sizeC, sizeof(cuDoubleComplex), (void **) &devA); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("Alloc A,B,C error\n"); } + + devB = devA + sizeA; + devC = devB + sizeB; + + // copy matrix A, B and C to device + stat = cublasSetMatrix(rowA, colA, sizeof(cuDoubleComplex), A, ALDim, devA, rowA); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("SetMatrix A error\n"); } + + stat = cublasSetMatrix(rowB, colB, sizeof(cuDoubleComplex), B, BLDim, devB, rowB); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("SetMatrix B error\n"); } + + if (beta.real() != 0.0 || beta.imag() != 0.0) + { + stat = cublasSetMatrix(rowC, colC, sizeof(cuDoubleComplex), C, CLDim, devC, rowC); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("SetMatrix C error\n"); } + } + + cublasZgemm + ( fixedTransA, fixedTransB, m, n, k, + *((cuDoubleComplex*) &alpha), devA, rowA, devB, rowB, *((cuDoubleComplex*) &beta), + devC, rowC ); + + // copy matrix C to host + stat = cublasGetMatrix(rowC, colC, sizeof(cuDoubleComplex), devC, rowC, C, CLDim); + if (stat != CUBLAS_STATUS_SUCCESS) { RuntimeError("GetMatrix C error\n"); } + + // free + cublasFree(devA); +} + +} // namespace cublas +} // namespace El + +#endif +