Actual source code: fnsqrt.c

slepc-3.18.2 2023-01-26
Report Typos and Errors
  1: /*
  2:    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  3:    SLEPc - Scalable Library for Eigenvalue Problem Computations
  4:    Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain

  6:    This file is part of SLEPc.
  7:    SLEPc is distributed under a 2-clause BSD license (see LICENSE).
  8:    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  9: */
 10: /*
 11:    Square root function  sqrt(x)
 12: */

 14: #include <slepc/private/fnimpl.h>
 15: #include <slepcblaslapack.h>

 17: PetscErrorCode FNEvaluateFunction_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
 18: {
 19: #if !defined(PETSC_USE_COMPLEX)
 21: #endif
 22:   *y = PetscSqrtScalar(x);
 23:   return 0;
 24: }

 26: PetscErrorCode FNEvaluateDerivative_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
 27: {
 29: #if !defined(PETSC_USE_COMPLEX)
 31: #endif
 32:   *y = 1.0/(2.0*PetscSqrtScalar(x));
 33:   return 0;
 34: }

 36: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Schur(FN fn,Mat A,Mat B)
 37: {
 38:   PetscBLASInt   n=0;
 39:   PetscScalar    *T;
 40:   PetscInt       m;

 42:   if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
 43:   MatDenseGetArray(B,&T);
 44:   MatGetSize(A,&m,NULL);
 45:   PetscBLASIntCast(m,&n);
 46:   FNSqrtmSchur(fn,n,T,n,PETSC_FALSE);
 47:   MatDenseRestoreArray(B,&T);
 48:   return 0;
 49: }

 51: PetscErrorCode FNEvaluateFunctionMatVec_Sqrt_Schur(FN fn,Mat A,Vec v)
 52: {
 53:   PetscBLASInt   n=0;
 54:   PetscScalar    *T;
 55:   PetscInt       m;
 56:   Mat            B;

 58:   FN_AllocateWorkMat(fn,A,&B);
 59:   MatDenseGetArray(B,&T);
 60:   MatGetSize(A,&m,NULL);
 61:   PetscBLASIntCast(m,&n);
 62:   FNSqrtmSchur(fn,n,T,n,PETSC_TRUE);
 63:   MatDenseRestoreArray(B,&T);
 64:   MatGetColumnVector(B,v,0);
 65:   FN_FreeWorkMat(fn,&B);
 66:   return 0;
 67: }

 69: PetscErrorCode FNEvaluateFunctionMat_Sqrt_DBP(FN fn,Mat A,Mat B)
 70: {
 71:   PetscBLASInt   n=0;
 72:   PetscScalar    *T;
 73:   PetscInt       m;

 75:   if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
 76:   MatDenseGetArray(B,&T);
 77:   MatGetSize(A,&m,NULL);
 78:   PetscBLASIntCast(m,&n);
 79:   FNSqrtmDenmanBeavers(fn,n,T,n,PETSC_FALSE);
 80:   MatDenseRestoreArray(B,&T);
 81:   return 0;
 82: }

 84: PetscErrorCode FNEvaluateFunctionMat_Sqrt_NS(FN fn,Mat A,Mat B)
 85: {
 86:   PetscBLASInt   n=0;
 87:   PetscScalar    *Ba;
 88:   PetscInt       m;

 90:   if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
 91:   MatDenseGetArray(B,&Ba);
 92:   MatGetSize(A,&m,NULL);
 93:   PetscBLASIntCast(m,&n);
 94:   FNSqrtmNewtonSchulz(fn,n,Ba,n,PETSC_FALSE);
 95:   MatDenseRestoreArray(B,&Ba);
 96:   return 0;
 97: }

 99: #define MAXIT 50

101: /*
102:    Computes the principal square root of the matrix A using the
103:    Sadeghi iteration. A is overwritten with sqrtm(A).
104:  */
105: PetscErrorCode FNSqrtmSadeghi(FN fn,PetscBLASInt n,PetscScalar *A,PetscBLASInt ld)
106: {
107:   PetscScalar    *M,*M2,*G,*X=A,*work,work1,sqrtnrm;
108:   PetscScalar    szero=0.0,sone=1.0,smfive=-5.0,s1d16=1.0/16.0;
109:   PetscReal      tol,Mres=0.0,nrm,rwork[1],done=1.0;
110:   PetscInt       i,it;
111:   PetscBLASInt   N,*piv=NULL,info,lwork=0,query=-1,one=1,zero=0;
112:   PetscBool      converged=PETSC_FALSE;
113:   unsigned int   ftz;

115:   N = n*n;
116:   tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
117:   SlepcSetFlushToZero(&ftz);

119:   /* query work size */
120:   PetscCallBLAS("LAPACKgetri",LAPACKgetri_(&n,A,&ld,piv,&work1,&query,&info));
121:   PetscBLASIntCast((PetscInt)PetscRealPart(work1),&lwork);

123:   PetscMalloc5(N,&M,N,&M2,N,&G,lwork,&work,n,&piv);
124:   PetscArraycpy(M,A,N);

126:   /* scale M */
127:   nrm = LAPACKlange_("fro",&n,&n,M,&n,rwork);
128:   if (nrm>1.0) {
129:     sqrtnrm = PetscSqrtReal(nrm);
130:     PetscCallBLAS("LAPACKlascl",LAPACKlascl_("G",&zero,&zero,&nrm,&done,&N,&one,M,&N,&info));
131:     SlepcCheckLapackInfo("lascl",info);
132:     tol *= nrm;
133:   }
134:   PetscInfo(fn,"||A||_F = %g, new tol: %g\n",(double)nrm,(double)tol);

136:   /* X = I */
137:   PetscArrayzero(X,N);
138:   for (i=0;i<n;i++) X[i+i*ld] = 1.0;

140:   for (it=0;it<MAXIT && !converged;it++) {

142:     /* G = (5/16)*I + (1/16)*M*(15*I-5*M+M*M) */
143:     PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M,&ld,M,&ld,&szero,M2,&ld));
144:     PetscCallBLAS("BLASaxpy",BLASaxpy_(&N,&smfive,M,&one,M2,&one));
145:     for (i=0;i<n;i++) M2[i+i*ld] += 15.0;
146:     PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&s1d16,M,&ld,M2,&ld,&szero,G,&ld));
147:     for (i=0;i<n;i++) G[i+i*ld] += 5.0/16.0;

149:     /* X = X*G */
150:     PetscArraycpy(M2,X,N);
151:     PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M2,&ld,G,&ld,&szero,X,&ld));

153:     /* M = M*inv(G*G) */
154:     PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,G,&ld,&szero,M2,&ld));
155:     PetscCallBLAS("LAPACKgetrf",LAPACKgetrf_(&n,&n,M2,&ld,piv,&info));
156:     SlepcCheckLapackInfo("getrf",info);
157:     PetscCallBLAS("LAPACKgetri",LAPACKgetri_(&n,M2,&ld,piv,work,&lwork,&info));
158:     SlepcCheckLapackInfo("getri",info);

160:     PetscArraycpy(G,M,N);
161:     PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,M2,&ld,&szero,M,&ld));

163:     /* check ||I-M|| */
164:     PetscArraycpy(M2,M,N);
165:     for (i=0;i<n;i++) M2[i+i*ld] -= 1.0;
166:     Mres = LAPACKlange_("fro",&n,&n,M2,&n,rwork);
168:     if (Mres<=tol) converged = PETSC_TRUE;
169:     PetscInfo(fn,"it: %" PetscInt_FMT " res: %g\n",it,(double)Mres);
170:     PetscLogFlops(8.0*n*n*n+2.0*n*n+2.0*n*n*n/3.0+4.0*n*n*n/3.0+2.0*n*n*n+2.0*n*n);
171:   }


175:   /* undo scaling */
176:   if (nrm>1.0) PetscCallBLAS("BLASscal",BLASscal_(&N,&sqrtnrm,A,&one));

178:   PetscFree5(M,M2,G,work,piv);
179:   SlepcResetFlushToZero(&ftz);
180:   return 0;
181: }

183: #if defined(PETSC_HAVE_CUDA)
184: #include "../src/sys/classes/fn/impls/cuda/fnutilcuda.h"
185: #include <slepccublas.h>

187: #if defined(PETSC_HAVE_MAGMA)
188: #include <slepcmagma.h>

190: /*
191:  * Matrix square root by Sadeghi iteration. CUDA version.
192:  * Computes the principal square root of the matrix A using the
193:  * Sadeghi iteration. A is overwritten with sqrtm(A).
194:  */
195: PetscErrorCode FNSqrtmSadeghi_CUDAm(FN fn,PetscBLASInt n,PetscScalar *d_A,PetscBLASInt ld)
196: {
197:   PetscScalar        *d_M,*d_M2,*d_G,*d_work,alpha;
198:   const PetscScalar  szero=0.0,sone=1.0,smfive=-5.0,s15=15.0,s1d16=1.0/16.0;
199:   PetscReal          tol,Mres=0.0,nrm,sqrtnrm=1.0;
200:   PetscInt           it,nb,lwork;
201:   PetscBLASInt       *piv,N;
202:   const PetscBLASInt one=1;
203:   PetscBool          converged=PETSC_FALSE;
204:   cublasHandle_t     cublasv2handle;

206:   PetscDeviceInitialize(PETSC_DEVICE_CUDA); /* For CUDA event timers */
207:   PetscCUBLASGetHandle(&cublasv2handle);
208:   SlepcMagmaInit();
209:   N = n*n;
210:   tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;

212:   PetscMalloc1(n,&piv);
213:   cudaMalloc((void **)&d_M,sizeof(PetscScalar)*N);
214:   cudaMalloc((void **)&d_M2,sizeof(PetscScalar)*N);
215:   cudaMalloc((void **)&d_G,sizeof(PetscScalar)*N);

217:   nb = magma_get_xgetri_nb(n);
218:   lwork = nb*n;
219:   cudaMalloc((void **)&d_work,sizeof(PetscScalar)*lwork);
220:   PetscLogGpuTimeBegin();

222:   /* M = A */
223:   cudaMemcpy(d_M,d_A,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice);

225:   /* scale M */
226:   cublasXnrm2(cublasv2handle,N,d_M,one,&nrm);
227:   if (nrm>1.0) {
228:     sqrtnrm = PetscSqrtReal(nrm);
229:     alpha = 1.0/nrm;
230:     cublasXscal(cublasv2handle,N,&alpha,d_M,one);
231:     tol *= nrm;
232:   }
233:   PetscInfo(fn,"||A||_F = %g, new tol: %g\n",(double)nrm,(double)tol);

235:   /* X = I */
236:   cudaMemset(d_A,0,sizeof(PetscScalar)*N);
237:   set_diagonal(n,d_A,ld,sone);

239:   for (it=0;it<MAXIT && !converged;it++) {

241:     /* G = (5/16)*I + (1/16)*M*(15*I-5*M+M*M) */
242:     cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_M,ld,d_M,ld,&szero,d_M2,ld);
243:     cublasXaxpy(cublasv2handle,N,&smfive,d_M,one,d_M2,one);
244:     shift_diagonal(n,d_M2,ld,s15);
245:     cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&s1d16,d_M,ld,d_M2,ld,&szero,d_G,ld);
246:     shift_diagonal(n,d_G,ld,5.0/16.0);

248:     /* X = X*G */
249:     cudaMemcpy(d_M2,d_A,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice);
250:     cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_M2,ld,d_G,ld,&szero,d_A,ld);

252:     /* M = M*inv(G*G) */
253:     cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_G,ld,d_G,ld,&szero,d_M2,ld);
254:     /* magma */
255:     PetscCallMAGMA(magma_xgetrf_gpu,n,n,d_M2,ld,piv);
256:     PetscCallMAGMA(magma_xgetri_gpu,n,d_M2,ld,piv,d_work,lwork);
257:     /* magma */
258:     cudaMemcpy(d_G,d_M,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice);
259:     cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_G,ld,d_M2,ld,&szero,d_M,ld);

261:     /* check ||I-M|| */
262:     cudaMemcpy(d_M2,d_M,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice);
263:     shift_diagonal(n,d_M2,ld,-1.0);
264:     cublasXnrm2(cublasv2handle,N,d_M2,one,&Mres);
266:     if (Mres<=tol) converged = PETSC_TRUE;
267:     PetscInfo(fn,"it: %" PetscInt_FMT " res: %g\n",it,(double)Mres);
268:     PetscLogGpuFlops(8.0*n*n*n+2.0*n*n+2.0*n*n*n/3.0+4.0*n*n*n/3.0+2.0*n*n*n+2.0*n*n);
269:   }


273:   if (nrm>1.0) {
274:     alpha = sqrtnrm;
275:     cublasXscal(cublasv2handle,N,&alpha,d_A,one);
276:   }
277:   PetscLogGpuTimeEnd();

279:   cudaFree(d_M);
280:   cudaFree(d_M2);
281:   cudaFree(d_G);
282:   cudaFree(d_work);
283:   PetscFree(piv);
284:   return 0;
285: }
286: #endif /* PETSC_HAVE_MAGMA */
287: #endif /* PETSC_HAVE_CUDA */

289: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Sadeghi(FN fn,Mat A,Mat B)
290: {
291:   PetscBLASInt   n=0;
292:   PetscScalar    *Ba;
293:   PetscInt       m;

295:   if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
296:   MatDenseGetArray(B,&Ba);
297:   MatGetSize(A,&m,NULL);
298:   PetscBLASIntCast(m,&n);
299:   FNSqrtmSadeghi(fn,n,Ba,n);
300:   MatDenseRestoreArray(B,&Ba);
301:   return 0;
302: }

304: #if defined(PETSC_HAVE_CUDA)
305: PetscErrorCode FNEvaluateFunctionMat_Sqrt_NS_CUDA(FN fn,Mat A,Mat B)
306: {
307:   PetscBLASInt   n=0;
308:   PetscScalar    *Ba;
309:   PetscInt       m;

311:   if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
312:   MatDenseCUDAGetArray(B,&Ba);
313:   MatGetSize(A,&m,NULL);
314:   PetscBLASIntCast(m,&n);
315:   FNSqrtmNewtonSchulz_CUDA(fn,n,Ba,n,PETSC_FALSE);
316:   MatDenseCUDARestoreArray(B,&Ba);
317:   return 0;
318: }

320: #if defined(PETSC_HAVE_MAGMA)
321: PetscErrorCode FNEvaluateFunctionMat_Sqrt_DBP_CUDAm(FN fn,Mat A,Mat B)
322: {
323:   PetscBLASInt   n=0;
324:   PetscScalar    *T;
325:   PetscInt       m;

327:   if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
328:   MatDenseCUDAGetArray(B,&T);
329:   MatGetSize(A,&m,NULL);
330:   PetscBLASIntCast(m,&n);
331:   FNSqrtmDenmanBeavers_CUDAm(fn,n,T,n,PETSC_FALSE);
332:   MatDenseCUDARestoreArray(B,&T);
333:   return 0;
334: }

336: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Sadeghi_CUDAm(FN fn,Mat A,Mat B)
337: {
338:   PetscBLASInt   n=0;
339:   PetscScalar    *Ba;
340:   PetscInt       m;

342:   if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
343:   MatDenseCUDAGetArray(B,&Ba);
344:   MatGetSize(A,&m,NULL);
345:   PetscBLASIntCast(m,&n);
346:   FNSqrtmSadeghi_CUDAm(fn,n,Ba,n);
347:   MatDenseCUDARestoreArray(B,&Ba);
348:   return 0;
349: }
350: #endif /* PETSC_HAVE_MAGMA */
351: #endif /* PETSC_HAVE_CUDA */

353: PetscErrorCode FNView_Sqrt(FN fn,PetscViewer viewer)
354: {
355:   PetscBool      isascii;
356:   char           str[50];
357:   const char     *methodname[] = {
358:                   "Schur method for the square root",
359:                   "Denman-Beavers (product form)",
360:                   "Newton-Schulz iteration",
361:                   "Sadeghi iteration"
362:   };
363:   const int      nmeth=PETSC_STATIC_ARRAY_LENGTH(methodname);

365:   PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&isascii);
366:   if (isascii) {
367:     if (fn->beta==(PetscScalar)1.0) {
368:       if (fn->alpha==(PetscScalar)1.0) PetscViewerASCIIPrintf(viewer,"  square root: sqrt(x)\n");
369:       else {
370:         SlepcSNPrintfScalar(str,sizeof(str),fn->alpha,PETSC_TRUE);
371:         PetscViewerASCIIPrintf(viewer,"  square root: sqrt(%s*x)\n",str);
372:       }
373:     } else {
374:       SlepcSNPrintfScalar(str,sizeof(str),fn->beta,PETSC_TRUE);
375:       if (fn->alpha==(PetscScalar)1.0) PetscViewerASCIIPrintf(viewer,"  square root: %s*sqrt(x)\n",str);
376:       else {
377:         PetscViewerASCIIPrintf(viewer,"  square root: %s",str);
378:         PetscViewerASCIIUseTabs(viewer,PETSC_FALSE);
379:         SlepcSNPrintfScalar(str,sizeof(str),fn->alpha,PETSC_TRUE);
380:         PetscViewerASCIIPrintf(viewer,"*sqrt(%s*x)\n",str);
381:         PetscViewerASCIIUseTabs(viewer,PETSC_TRUE);
382:       }
383:     }
384:     if (fn->method<nmeth) PetscViewerASCIIPrintf(viewer,"  computing matrix functions with: %s\n",methodname[fn->method]);
385:   }
386:   return 0;
387: }

389: SLEPC_EXTERN PetscErrorCode FNCreate_Sqrt(FN fn)
390: {
391:   fn->ops->evaluatefunction          = FNEvaluateFunction_Sqrt;
392:   fn->ops->evaluatederivative        = FNEvaluateDerivative_Sqrt;
393:   fn->ops->evaluatefunctionmat[0]    = FNEvaluateFunctionMat_Sqrt_Schur;
394:   fn->ops->evaluatefunctionmat[1]    = FNEvaluateFunctionMat_Sqrt_DBP;
395:   fn->ops->evaluatefunctionmat[2]    = FNEvaluateFunctionMat_Sqrt_NS;
396:   fn->ops->evaluatefunctionmat[3]    = FNEvaluateFunctionMat_Sqrt_Sadeghi;
397: #if defined(PETSC_HAVE_CUDA)
398:   fn->ops->evaluatefunctionmatcuda[2] = FNEvaluateFunctionMat_Sqrt_NS_CUDA;
399: #if defined(PETSC_HAVE_MAGMA)
400:   fn->ops->evaluatefunctionmatcuda[1] = FNEvaluateFunctionMat_Sqrt_DBP_CUDAm;
401:   fn->ops->evaluatefunctionmatcuda[3] = FNEvaluateFunctionMat_Sqrt_Sadeghi_CUDAm;
402: #endif /* PETSC_HAVE_MAGMA */
403: #endif /* PETSC_HAVE_CUDA */
404:   fn->ops->evaluatefunctionmatvec[0] = FNEvaluateFunctionMatVec_Sqrt_Schur;
405:   fn->ops->view                      = FNView_Sqrt;
406:   return 0;
407: }