// NOTE(review): this listing appears to be a documentation scrape of the
// header: every line carries its original line number as a prefix and many
// lines (closing braces, #else/#endif, macro continuation lines) are absent,
// so it does not compile as shown. Comments describe only what is visible.
// The whole BLAS interface is only compiled when gmm is configured with
// GMM_USES_BLAS or GMM_USES_LAPACK.
38 #if defined(GMM_USES_BLAS) || defined(GMM_USES_LAPACK)
// Include guard.
40 #ifndef GMM_BLAS_INTERFACE_H
41 #define GMM_BLAS_INTERFACE_H
// Tracing hook invoked at the top of every wrapper below. The definition
// visible here expands to nothing; original lines 42-50 (not shown) may
// contain an alternative definition.
51 #define GMMLAPACK_TRACE(f)
// Branch selected when WeirdNEC or GMM_USE_BLAS64_INTERFACE is defined;
// its body (original lines 55-149) is not shown — presumably it chooses the
// integer type BLAS_INT and the symbol naming; confirm.
54 #if defined(WeirdNEC) || defined(GMM_USE_BLAS64_INTERFACE)
// Shorthands mapping the BLAS precision prefixes s/d/c/z to C++ scalar types.
150 # define BLAS_S float
151 # define BLAS_D double
152 # define BLAS_C std::complex<float>
153 # define BLAS_Z std::complex<double>
// Calls a complex-valued BLAS function and stores its result in `res`.
// With GMM_BLAS_RETURN_COMPLEX_AS_ARGUMENT defined, the result is written
// through a pointer passed as the first argument; otherwise it is taken from
// the function's return value. (The #else between the two definitions,
// original line 158, is not shown.)
156 #if defined(GMM_BLAS_RETURN_COMPLEX_AS_ARGUMENT)
157 # define BLAS_CPLX_FUNC_CALL(blasname, res, ...) blasname(&res, __VA_ARGS__)
159 # define BLAS_CPLX_FUNC_CALL(blasname, res, ...) res = blasname(__VA_ARGS__)
// Declarations of the Fortran-style BLAS routines used below (trailing
// underscore; every argument, including scalars, passed by pointer). The
// enclosing linkage block (original lines 160-165) is not shown —
// presumably extern "C"; confirm.
166 void daxpy_(
const BLAS_INT *n,
const double *alpha,
const double *x,
167 const BLAS_INT *incx,
double *y,
const BLAS_INT *incy);
168 void dgemm_(
const char *tA,
const char *tB,
const BLAS_INT *m,
169 const BLAS_INT *n,
const BLAS_INT *k,
const double *alpha,
170 const double *A,
const BLAS_INT *ldA,
const double *B,
171 const BLAS_INT *ldB,
const double *beta,
double *C,
172 const BLAS_INT *ldC);
// NOTE(review): unlike daxpy_/dgemm_ above, the remaining routines are
// declared with C variadic "(...)" parameter lists. The compiler therefore
// performs no type or arity checking on their arguments, and passing a
// const pointer where BLAS writes output compiles silently. Consider full
// prototypes.
173 void sgemm_(...);
void cgemm_(...);
void zgemm_(...);
174 void sgemv_(...);
void dgemv_(...);
void cgemv_(...);
void zgemv_(...);
175 void strsv_(...);
void dtrsv_(...);
void ctrsv_(...);
void ztrsv_(...);
176 void saxpy_(...);
void caxpy_(...);
void zaxpy_(...);
// Scalar-returning level-1 routines. The complex dot products
// (cdotu_/zdotu_/cdotc_/zdotc_) are invoked via BLAS_CPLX_FUNC_CALL, because
// some BLAS builds return complex values through an extra pointer argument.
177 BLAS_S sdot_ (...); BLAS_D ddot_ (...);
178 BLAS_C cdotu_(...); BLAS_Z zdotu_(...);
180 BLAS_C cdotc_(...); BLAS_Z zdotc_(...);
181 BLAS_S snrm2_(...); BLAS_D dnrm2_(...);
182 BLAS_S scnrm2_(...); BLAS_D dznrm2_(...);
// Rank-one updates (?ger for real, ?gerc for complex types).
183 void sger_(...);
void dger_(...);
void cgerc_(...);
void zgerc_(...);
// nrm2_interface(blas_name, base_type) defines vect_norm2(x) for a
// std::vector of the given BLAS scalar type. It forwards to the BLAS ?nrm2
// routine with unit stride and length vect_size(x). The return type is
// number_traits<base_type>::magnitude_type, so complex vectors (scnrm2_,
// dznrm2_) yield a real norm.
// NOTE(review): &x[0] is undefined behavior when x is empty. Consider
// returning 0 for n == 0, as the axpy wrapper short-circuits, or passing
// x.data().
// The macro's closing brace (original lines 197-198) is not shown.
191 # define nrm2_interface(blas_name, base_type) \
192 inline number_traits<base_type>::magnitude_type \
193 vect_norm2(const std::vector<base_type> &x) { \
194 GMMLAPACK_TRACE("nrm2_interface"); \
195 BLAS_INT inc(1), n(BLAS_INT(vect_size(x))); \
196 return blas_name(&n, &x[0], &inc); \
199 nrm2_interface(snrm2_, BLAS_S)
200 nrm2_interface(dnrm2_, BLAS_D)
201 nrm2_interface(scnrm2_, BLAS_C)
202 nrm2_interface(dznrm2_, BLAS_Z)
// dot_interface(funcname, msg, blas_name, base_type) defines four overloads
// of funcname(x, y) for real std::vector operands:
//   (plain, plain), (scaled, plain), (plain, scaled), (scaled, scaled).
// A scaled_vector_const_ref argument is unwrapped to its underlying vector
// via linalg_origin. Its scale factor is multiplied into the BLAS result
// rather than applied element-wise. The length passed to BLAS is always
// vect_size(y), with unit stride.
// The declarations of `a` and `b` in the single-scaled overloads (original
// lines 220 and 229) and the closing braces are not shown — presumably
// a(x_.r) and b(y_.r), as in the doubly-scaled overload.
// Instantiations: vect_sp and vect_hp both map to sdot_/ddot_. For real
// scalars conjugation is the identity, so the presumably-Hermitian vect_hp
// reduces to the plain dot product.
// NOTE(review): &x[0] / &y[0] are undefined behavior on empty vectors.
208 # define dot_interface(funcname, msg, blas_name, base_type) \
209 inline base_type funcname(const std::vector<base_type> &x, \
210 const std::vector<base_type> &y) { \
211 GMMLAPACK_TRACE(msg); \
212 BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
213 return blas_name(&n, &x[0], &inc, &y[0], &inc); \
215 inline base_type funcname \
216 (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
217 const std::vector<base_type> &y) { \
218 GMMLAPACK_TRACE(msg); \
219 const std::vector<base_type> &x = *(linalg_origin(x_)); \
221 BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
222 return a * blas_name(&n, &x[0], &inc, &y[0], &inc); \
224 inline base_type funcname \
225 (const std::vector<base_type> &x, \
226 const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
227 GMMLAPACK_TRACE(msg); \
228 const std::vector<base_type> &y = *(linalg_origin(y_)); \
230 BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
231 return b * blas_name(&n, &x[0], &inc, &y[0], &inc); \
233 inline base_type funcname \
234 (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
235 const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
236 GMMLAPACK_TRACE(msg); \
237 const std::vector<base_type> &x = *(linalg_origin(x_)); \
238 const std::vector<base_type> &y = *(linalg_origin(y_)); \
239 base_type a(x_.r), b(y_.r); \
240 BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
241 return a*b * blas_name(&n, &x[0], &inc, &y[0], &inc); \
244 dot_interface(vect_sp,
"dot_interface", sdot_, BLAS_S)
245 dot_interface(vect_sp,
"dot_interface", ddot_, BLAS_D)
246 dot_interface(vect_hp,
"dotc_interface", sdot_, BLAS_S)
247 dot_interface(vect_hp,
"dotc_interface", ddot_, BLAS_D)
// dot_interface_cplx(funcname, msg, blas_name, base_type, bdef) defines the
// same four plain/scaled overloads of funcname(x, y) as the real version,
// for complex std::vector operands. The result comes back through
// BLAS_CPLX_FUNC_CALL into `res`.
// Every overload passes y as the FIRST vector and x as the second. Per the
// BLAS spec, ?dotc conjugates its first vector argument, so the dotc
// instantiations compute sum_i x_i * conj(y_i). The ?dotu instantiations
// compute the unconjugated sum_i x_i * y_i.
// bdef is the factor contributed by a scaled y:
//   - y_.r for vect_sp;
//   - gmm::conj(y_.r) for vect_hp, because y itself is conjugated.
// Not shown: the declaration of `res` in the first overload (original line
// 259), the return statements, and the closing braces.
// NOTE(review): &x[0] / &y[0] are undefined behavior on empty vectors.
255 # define dot_interface_cplx(funcname, msg, blas_name, base_type, bdef) \
256 inline base_type funcname(const std::vector<base_type> &x, \
257 const std::vector<base_type> &y) { \
258 GMMLAPACK_TRACE(msg); \
260 BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
261 BLAS_CPLX_FUNC_CALL(blas_name, res, &n, &y[0], &inc, &x[0], &inc); \
264 inline base_type funcname \
265 (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
266 const std::vector<base_type> &y) { \
267 GMMLAPACK_TRACE(msg); \
268 const std::vector<base_type> &x = *(linalg_origin(x_)); \
269 base_type res, a(x_.r); \
270 BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
271 BLAS_CPLX_FUNC_CALL(blas_name, res, &n, &y[0], &inc, &x[0], &inc); \
274 inline base_type funcname \
275 (const std::vector<base_type> &x, \
276 const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
277 GMMLAPACK_TRACE(msg); \
278 const std::vector<base_type> &y = *(linalg_origin(y_)); \
279 base_type res, b(bdef); \
280 BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
281 BLAS_CPLX_FUNC_CALL(blas_name, res, &n, &y[0], &inc, &x[0], &inc); \
284 inline base_type funcname \
285 (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
286 const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
287 GMMLAPACK_TRACE(msg); \
288 const std::vector<base_type> &x = *(linalg_origin(x_)); \
289 const std::vector<base_type> &y = *(linalg_origin(y_)); \
290 base_type res, a(x_.r), b(bdef); \
291 BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
292 BLAS_CPLX_FUNC_CALL(blas_name, res, &n, &y[0], &inc, &x[0], &inc); \
296 dot_interface_cplx(vect_sp,
"dot_interface", cdotu_, BLAS_C, y_.r)
297 dot_interface_cplx(vect_sp, "dot_interface", zdotu_, BLAS_Z, y_.r)
298 dot_interface_cplx(vect_hp, "dotc_interface", cdotc_, BLAS_C, gmm::conj(y_.r))
299 dot_interface_cplx(vect_hp, "dotc_interface", zdotc_, BLAS_Z, gmm::conj(y_.r))
// add_fixed<N>(x, y): adds x into y element-wise over indices [0, N), i.e.
// y[i] += x[i]. N is a compile-time constant, presumably so the compiler can
// fully unroll the loop for short vectors. No size check is performed;
// both x and y must have at least N entries.
// (Original lines 307 and 309 are not shown.)
305 template<
size_type N, class V1, class V2>
306 inline
void add_fixed(const V1 &x, V2 &y)
308 for(
size_type i = 0; i != N; ++i) y[i] += x[i];
// add_for_short_vectors(x, y, n): y += x for a runtime length n. It
// dispatches to add_fixed<n> for 1 <= n <= 24, giving each short length a
// fixed-trip-count loop. Other values of n reach
// GMM_ASSERT2(false, "add_for_short_vectors used with unsupported size"),
// presumably via a default label.
// NOTE(review): the switch header and default label (original lines 313-315
// and 340) are not shown. Whether an unsupported n is diagnosed at runtime
// depends on how GMM_ASSERT2 is configured (definition not shown); if it is
// compiled out, y is silently left unchanged.
311 template<
class V1,
class V2>
312 inline void add_for_short_vectors(
const V1 &x, V2 &y,
size_type n)
316 case 1: add_fixed<1>(x, y);
break;
317 case 2: add_fixed<2>(x, y);
break;
318 case 3: add_fixed<3>(x, y);
break;
319 case 4: add_fixed<4>(x, y);
break;
320 case 5: add_fixed<5>(x, y);
break;
321 case 6: add_fixed<6>(x, y);
break;
322 case 7: add_fixed<7>(x, y);
break;
323 case 8: add_fixed<8>(x, y);
break;
324 case 9: add_fixed<9>(x, y);
break;
325 case 10: add_fixed<10>(x, y);
break;
326 case 11: add_fixed<11>(x, y);
break;
327 case 12: add_fixed<12>(x, y);
break;
328 case 13: add_fixed<13>(x, y);
break;
329 case 14: add_fixed<14>(x, y);
break;
330 case 15: add_fixed<15>(x, y);
break;
331 case 16: add_fixed<16>(x, y);
break;
332 case 17: add_fixed<17>(x, y);
break;
333 case 18: add_fixed<18>(x, y);
break;
334 case 19: add_fixed<19>(x, y);
break;
335 case 20: add_fixed<20>(x, y);
break;
336 case 21: add_fixed<21>(x, y);
break;
337 case 22: add_fixed<22>(x, y);
break;
338 case 23: add_fixed<23>(x, y);
break;
339 case 24: add_fixed<24>(x, y);
break;
341 GMM_ASSERT2(
false,
"add_for_short_vectors used with unsupported size");
// add_fixed<N>(x, y, a): scaled variant; y[i] += a * x[i] for i in [0, N).
// N is a compile-time constant, presumably to allow full unrolling. No size
// check is performed. (Original line 348 is not shown.)
346 template<
size_type N,
class V1,
class V2,
class T>
347 inline void add_fixed(
const V1 &x, V2 &y,
const T &a)
349 for(
size_type i = 0; i != N; ++i) y[i] += a*x[i];
// add_for_short_vectors(x, y, a, n): y += a * x for a runtime length n. It
// dispatches to add_fixed<n>(x, y, a) for 1 <= n <= 24. Other values reach
// GMM_ASSERT2(false, "add_for_short_vectors used with unsupported size"),
// presumably via a default label.
// NOTE(review): the switch header and default label (original lines 354-356
// and 381) are not shown. If GMM_ASSERT2 is compiled out, an unsupported n
// leaves y unchanged without any diagnostic.
352 template<
class V1,
class V2,
class T>
353 inline void add_for_short_vectors(
const V1 &x, V2 &y,
const T &a,
size_type n)
357 case 1: add_fixed<1>(x, y, a);
break;
358 case 2: add_fixed<2>(x, y, a);
break;
359 case 3: add_fixed<3>(x, y, a);
break;
360 case 4: add_fixed<4>(x, y, a);
break;
361 case 5: add_fixed<5>(x, y, a);
break;
362 case 6: add_fixed<6>(x, y, a);
break;
363 case 7: add_fixed<7>(x, y, a);
break;
364 case 8: add_fixed<8>(x, y, a);
break;
365 case 9: add_fixed<9>(x, y, a);
break;
366 case 10: add_fixed<10>(x, y, a);
break;
367 case 11: add_fixed<11>(x, y, a);
break;
368 case 12: add_fixed<12>(x, y, a);
break;
369 case 13: add_fixed<13>(x, y, a);
break;
370 case 14: add_fixed<14>(x, y, a);
break;
371 case 15: add_fixed<15>(x, y, a);
break;
372 case 16: add_fixed<16>(x, y, a);
break;
373 case 17: add_fixed<17>(x, y, a);
break;
374 case 18: add_fixed<18>(x, y, a);
break;
375 case 19: add_fixed<19>(x, y, a);
break;
376 case 20: add_fixed<20>(x, y, a);
break;
377 case 21: add_fixed<21>(x, y, a);
break;
378 case 22: add_fixed<22>(x, y, a);
break;
379 case 23: add_fixed<23>(x, y, a);
break;
380 case 24: add_fixed<24>(x, y, a);
break;
382 GMM_ASSERT2(
false,
"add_for_short_vectors used with unsupported size");
// axpy_interface(blas_name, base_type) defines add(x, y), i.e. y += x, for
// std::vector of each BLAS scalar type. The length is taken from y.
//   - n < 25: the unrolled add_for_short_vectors path is used.
//   - longer: BLAS ?axpy is called with alpha = 1 and unit stride.
// Original line 393, which the "else if" continues, is not shown —
// presumably an early return for n == 0. That would explain why the BLAS
// branch is never reached with an empty vector. The closing brace is also
// not shown.
388 # define axpy_interface(blas_name, base_type) \
389 inline void add(const std::vector<base_type> &x, \
390 std::vector<base_type> &y) { \
391 GMMLAPACK_TRACE("axpy_interface"); \
392 BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); base_type a(1); \
394 else if(n < 25) add_for_short_vectors(x, y, n); \
395 else blas_name(&n, &a, &x[0], &inc, &y[0], &inc); \
398 axpy_interface(saxpy_, BLAS_S)
399 axpy_interface(daxpy_, BLAS_D)
400 axpy_interface(caxpy_, BLAS_C)
401 axpy_interface(zaxpy_, BLAS_Z)
// axpy2_interface(blas_name, base_type) defines the scaled overload
// y += a * x, taking a scaled_vector_const_ref x_ over a std::vector. x is
// bound to the underlying vector via linalg_origin and the length is taken
// from y.
//   - n < 25: the scaled add_for_short_vectors path is used.
//   - longer: BLAS ?axpy is called with unit stride.
// Not shown:
//   - original line 405, the return type and function name (presumably
//     "inline void add", matching axpy_interface);
//   - lines 411-412, presumably the declaration of the scale `a`
//     (from x_.r) and the branch the "else if" continues;
//   - the closing brace.
404 # define axpy2_interface(blas_name, base_type) \
406 (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
407 std::vector<base_type> &y) { \
408 GMMLAPACK_TRACE("axpy_interface"); \
409 BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
410 const std::vector<base_type>& x = *(linalg_origin(x_)); \
413 else if(n < 25) add_for_short_vectors(x, y, a, n); \
414 else blas_name(&n, &a, &x[0], &inc, &y[0], &inc); \
417 axpy2_interface(saxpy_, BLAS_S)
418 axpy2_interface(daxpy_, BLAS_D)
419 axpy2_interface(caxpy_, BLAS_C)
420 axpy2_interface(zaxpy_, BLAS_Z)
// gemv_interface defines mult_add_spec(A, x, z, orien): the accumulating
// matrix-vector product z += alpha * op(A) * x, on top of BLAS ?gemv with
// beta = 1.
//   - param1/trans1 choose how A is received and set the transpose flag t.
//   - param2/trans2 choose whether x is plain or scaled, and set alpha.
//   - m, n and lda = m come from the underlying (column-major) dense origin
//     matrix.
// NOTE(review): when m or n is zero, the else-branch calls gmm::clear(z).
// In that case op(A) * x is a zero vector (or z is empty), so an
// accumulating product should leave z unchanged rather than zero it. Verify
// that callers never reach this path with a non-empty z; otherwise drop the
// clear.
// Not shown: original line 428 (the macro's remaining parameters,
// presumably "base_type, orien)") and the closing brace.
//
// Parameter/setup macros used by the gemv and trsv wrappers:
//   gem_p1_n   / gem_trans1_n  : A is a plain dense_matrix; t = 'N'.
//   gem_p1_t   / gem_trans1_t  : A_ is a transposed_col_ref to a
//                                dense_matrix; A is bound to its origin. The
//                                line setting t (original 447) is not shown
//                                — presumably 'T'.
//   gem_p1_tc                  : same as gem_p1_t for a pointer to a const
//                                dense_matrix; used with gem_trans1_t.
//   gem_p1_c   / gem_trans1_c  : A_ is a conjugated_col_matrix_const_ref; A
//                                is bound to its origin. The line setting t
//                                (original 454) is not shown — presumably
//                                'C'.
//   gemv_p2_n  / gemv_trans2_n : x is a plain std::vector; alpha = 1.
//   gemv_p2_s  / gemv_trans2_s : x_ is a scaled_vector_const_ref; x is bound
//                                to the underlying vector and alpha = x_.r.
427 # define gemv_interface(param1, trans1, param2, trans2, blas_name, \
429 inline void mult_add_spec(param1(base_type), param2(base_type), \
430 std::vector<base_type> &z, orien) { \
431 GMMLAPACK_TRACE("gemv_interface"); \
432 trans1(base_type); trans2(base_type); base_type beta(1); \
433 BLAS_INT m(BLAS_INT(mat_nrows(A))), lda(m); \
434 BLAS_INT n(BLAS_INT(mat_ncols(A))), inc(1); \
435 if (m && n) blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc, \
436 &beta, &z[0], &inc); \
437 else gmm::clear(z); \
441 # define gem_p1_n(base_type) const dense_matrix<base_type> &A
442 # define gem_trans1_n(base_type) const char t = 'N'
443 # define gem_p1_t(base_type) \
444 const transposed_col_ref<dense_matrix<base_type> *> &A_
445 # define gem_trans1_t(base_type) const dense_matrix<base_type> &A = \
446 *(linalg_origin(A_)); \
448 # define gem_p1_tc(base_type) \
449 const transposed_col_ref<const dense_matrix<base_type> *> &A_
450 # define gem_p1_c(base_type) \
451 const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_
452 # define gem_trans1_c(base_type) const dense_matrix<base_type> &A = \
453 *(linalg_origin(A_)); \
457 # define gemv_p2_n(base_type) const std::vector<base_type> &x
458 # define gemv_trans2_n(base_type) base_type alpha(1)
459 # define gemv_p2_s(base_type) \
460 const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_
461 # define gemv_trans2_s(base_type) const std::vector<base_type> &x = \
462 (*(linalg_origin(x_))); \
463 base_type alpha(x_.r)
// mult_add_spec instantiations, one per BLAS scalar type (S, D, C, Z) in
// each group. Group: plain A, plain x (col_major dispatch).
466 gemv_interface(gem_p1_n, gem_trans1_n,
467 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, col_major)
468 gemv_interface(gem_p1_n, gem_trans1_n,
469 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, col_major)
470 gemv_interface(gem_p1_n, gem_trans1_n,
471 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, col_major)
472 gemv_interface(gem_p1_n, gem_trans1_n,
473 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, col_major)
// Group: transposed A, plain x (row_major dispatch).
476 gemv_interface(gem_p1_t, gem_trans1_t,
477 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
478 gemv_interface(gem_p1_t, gem_trans1_t,
479 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
480 gemv_interface(gem_p1_t, gem_trans1_t,
481 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
482 gemv_interface(gem_p1_t, gem_trans1_t,
483 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
// Group: transposed A over a const dense_matrix, plain x.
486 gemv_interface(gem_p1_tc, gem_trans1_t,
487 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
488 gemv_interface(gem_p1_tc, gem_trans1_t,
489 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
490 gemv_interface(gem_p1_tc, gem_trans1_t,
491 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
492 gemv_interface(gem_p1_tc, gem_trans1_t,
493 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
// Group: conjugated A, plain x. Per the BLAS spec, 'C' behaves like 'T' for
// the real types.
496 gemv_interface(gem_p1_c, gem_trans1_c,
497 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
498 gemv_interface(gem_p1_c, gem_trans1_c,
499 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
500 gemv_interface(gem_p1_c, gem_trans1_c,
501 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
502 gemv_interface(gem_p1_c, gem_trans1_c,
503 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
// Group: plain A, scaled x (alpha = x_.r).
506 gemv_interface(gem_p1_n, gem_trans1_n,
507 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, col_major)
508 gemv_interface(gem_p1_n, gem_trans1_n,
509 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, col_major)
510 gemv_interface(gem_p1_n, gem_trans1_n,
511 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, col_major)
512 gemv_interface(gem_p1_n, gem_trans1_n,
513 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, col_major)
// Group: transposed A, scaled x.
516 gemv_interface(gem_p1_t, gem_trans1_t,
517 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
518 gemv_interface(gem_p1_t, gem_trans1_t,
519 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
520 gemv_interface(gem_p1_t, gem_trans1_t,
521 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
522 gemv_interface(gem_p1_t, gem_trans1_t,
523 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
// Group: transposed A over a const dense_matrix, scaled x.
526 gemv_interface(gem_p1_tc, gem_trans1_t,
527 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
528 gemv_interface(gem_p1_tc, gem_trans1_t,
529 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
530 gemv_interface(gem_p1_tc, gem_trans1_t,
531 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
532 gemv_interface(gem_p1_tc, gem_trans1_t,
533 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
// Group: conjugated A, scaled x.
536 gemv_interface(gem_p1_c, gem_trans1_c,
537 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
538 gemv_interface(gem_p1_c, gem_trans1_c,
539 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
540 gemv_interface(gem_p1_c, gem_trans1_c,
541 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
542 gemv_interface(gem_p1_c, gem_trans1_c,
543 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
// gemv_interface2 defines mult_spec(A, x, z, orien): the overwriting product
// z = alpha * op(A) * x, via BLAS ?gemv with beta = 0. It uses the same
// gem_p1_* / gem_trans1_* / gemv_p2_* / gemv_trans2_* parameter macros as
// the accumulating mult_add_spec. When the call is skipped, z is cleared.
// Because beta = 0 means the old z is discarded, zeroing z is the correct
// result for an empty product.
// Not shown:
//   - original line 551, the remaining macro parameters;
//   - line 558, presumably the "if (m && n)" guard that the else refers to;
//   - line 560, the end of the BLAS argument list;
//   - the closing brace.
// The instantiations that follow are grouped four at a time (S, D, C, Z), in
// this order:
//   1. plain A,            plain x   (col_major)
//   2. transposed A,       plain x   (row_major)
//   3. transposed const A, plain x   (row_major)
//   4. conjugated A,       plain x   (row_major)
//   5. plain A,            scaled x  (col_major)
//   6. transposed A,       scaled x  (row_major)
//   7. transposed const A, scaled x  (row_major)
//   8. conjugated A,       scaled x  (row_major)
550 # define gemv_interface2(param1, trans1, param2, trans2, blas_name, \
552 inline void mult_spec(param1(base_type), param2(base_type), \
553 std::vector<base_type> &z, orien) { \
554 GMMLAPACK_TRACE("gemv_interface2"); \
555 trans1(base_type); trans2(base_type); base_type beta(0); \
556 BLAS_INT m(BLAS_INT(mat_nrows(A))), lda(m); \
557 BLAS_INT n(BLAS_INT(mat_ncols(A))), inc(1); \
559 blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc, &beta, \
561 else gmm::clear(z); \
565 gemv_interface2(gem_p1_n, gem_trans1_n,
566 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, col_major)
567 gemv_interface2(gem_p1_n, gem_trans1_n,
568 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, col_major)
569 gemv_interface2(gem_p1_n, gem_trans1_n,
570 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, col_major)
571 gemv_interface2(gem_p1_n, gem_trans1_n,
572 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, col_major)
575 gemv_interface2(gem_p1_t, gem_trans1_t,
576 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
577 gemv_interface2(gem_p1_t, gem_trans1_t,
578 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
579 gemv_interface2(gem_p1_t, gem_trans1_t,
580 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
581 gemv_interface2(gem_p1_t, gem_trans1_t,
582 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
585 gemv_interface2(gem_p1_tc, gem_trans1_t,
586 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
587 gemv_interface2(gem_p1_tc, gem_trans1_t,
588 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
589 gemv_interface2(gem_p1_tc, gem_trans1_t,
590 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
591 gemv_interface2(gem_p1_tc, gem_trans1_t,
592 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
595 gemv_interface2(gem_p1_c, gem_trans1_c,
596 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
597 gemv_interface2(gem_p1_c, gem_trans1_c,
598 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
599 gemv_interface2(gem_p1_c, gem_trans1_c,
600 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
601 gemv_interface2(gem_p1_c, gem_trans1_c,
602 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
605 gemv_interface2(gem_p1_n, gem_trans1_n,
606 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, col_major)
607 gemv_interface2(gem_p1_n, gem_trans1_n,
608 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, col_major)
609 gemv_interface2(gem_p1_n, gem_trans1_n,
610 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, col_major)
611 gemv_interface2(gem_p1_n, gem_trans1_n,
612 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, col_major)
615 gemv_interface2(gem_p1_t, gem_trans1_t,
616 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
617 gemv_interface2(gem_p1_t, gem_trans1_t,
618 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
619 gemv_interface2(gem_p1_t, gem_trans1_t,
620 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
621 gemv_interface2(gem_p1_t, gem_trans1_t,
622 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
625 gemv_interface2(gem_p1_tc, gem_trans1_t,
626 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
627 gemv_interface2(gem_p1_tc, gem_trans1_t,
628 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
629 gemv_interface2(gem_p1_tc, gem_trans1_t,
630 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
631 gemv_interface2(gem_p1_tc, gem_trans1_t,
632 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
635 gemv_interface2(gem_p1_c, gem_trans1_c,
636 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
637 gemv_interface2(gem_p1_c, gem_trans1_c,
638 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
639 gemv_interface2(gem_p1_c, gem_trans1_c,
640 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
641 gemv_interface2(gem_p1_c, gem_trans1_c,
642 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
// ger_interface(blas_name, base_type) defines rank_one_update(A, V, W) for a
// column-major dense_matrix. It calls BLAS ?ger (real) or ?gerc (complex)
// with alpha = 1, unit strides and lda = mat_nrows(A). Per the BLAS spec
// this computes A += V * W^H, which is W^T for real types.
// NOTE(review): A is received by const reference, yet BLAS writes through
// &A(0,0); the variadic BLAS prototypes hide this const violation. It is
// only well-defined when the referenced matrix is not itself a const
// object.
// Original line 658 and the closing brace are not shown.
649 # define ger_interface(blas_name, base_type) \
650 inline void rank_one_update(const dense_matrix<base_type> &A, \
651 const std::vector<base_type> &V, \
652 const std::vector<base_type> &W) { \
653 GMMLAPACK_TRACE("ger_interface"); \
654 BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
655 BLAS_INT n(BLAS_INT(mat_ncols(A))); \
656 BLAS_INT incx = 1, incy = 1; \
657 base_type alpha(1); \
659 blas_name(&m, &n, &alpha, &V[0], &incx, &W[0], &incy, &A(0,0), &lda);\
662 ger_interface(sger_, BLAS_S)
663 ger_interface(dger_, BLAS_D)
664 ger_interface(cgerc_, BLAS_C)
665 ger_interface(zgerc_, BLAS_Z)
// ger_interface_sn(blas_name, base_type) defines rank_one_update(A, x_, W)
// where the first vector is a scaled_vector_const_ref. gemv_trans2_s binds x
// to the underlying vector and sets alpha = x_.r, so per the BLAS ?ger/?gerc
// spec this computes A += alpha * x * W^H.
// NOTE(review): A is received by const reference but written by BLAS
// through &A(0,0). This is only well-defined when the referenced matrix is
// not itself a const object.
// Original line 676 and the closing brace are not shown.
667 # define ger_interface_sn(blas_name, base_type) \
668 inline void rank_one_update(const dense_matrix<base_type> &A, \
669 gemv_p2_s(base_type), \
670 const std::vector<base_type> &W) { \
671 GMMLAPACK_TRACE("ger_interface"); \
672 gemv_trans2_s(base_type); \
673 BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
674 BLAS_INT n(BLAS_INT(mat_ncols(A))); \
675 BLAS_INT incx = 1, incy = 1; \
677 blas_name(&m, &n, &alpha, &x[0], &incx, &W[0], &incy, &A(0,0), &lda);\
680 ger_interface_sn(sger_, BLAS_S)
681 ger_interface_sn(dger_, BLAS_D)
682 ger_interface_sn(cgerc_, BLAS_C)
683 ger_interface_sn(zgerc_, BLAS_Z)
// ger_interface_ns(blas_name, base_type) defines rank_one_update(A, V, x_)
// where the second vector is a scaled_vector_const_ref, W = alpha * x
// (alpha = x_.r). Per the BLAS spec, ?gerc conjugates its second vector, so
// the scale enters as conj(alpha). Hence al2 = gmm::conj(alpha) and
// A += al2 * V * x^H. For real types conj is the identity.
// NOTE(review): A is received by const reference but written by BLAS
// through &A(0,0). This is only well-defined when the referenced matrix is
// not itself a const object.
// Original line 695 and the closing brace are not shown.
685 # define ger_interface_ns(blas_name, base_type) \
686 inline void rank_one_update(const dense_matrix<base_type> &A, \
687 const std::vector<base_type> &V, \
688 gemv_p2_s(base_type)) { \
689 GMMLAPACK_TRACE("ger_interface"); \
690 gemv_trans2_s(base_type); \
691 BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
692 BLAS_INT n(BLAS_INT(mat_ncols(A))); \
693 BLAS_INT incx = 1, incy = 1; \
694 base_type al2 = gmm::conj(alpha); \
696 blas_name(&m, &n, &al2, &V[0], &incx, &x[0], &incy, &A(0,0), &lda); \
699 ger_interface_ns(sger_, BLAS_S)
700 ger_interface_ns(dger_, BLAS_D)
701 ger_interface_ns(cgerc_, BLAS_C)
702 ger_interface_ns(zgerc_, BLAS_Z)
// gemm_interface_nn(blas_name, base_type) defines mult_spec(A, B, C, c_mult),
// computing C = A * B for column-major dense matrices via BLAS ?gemm with
// alpha = 1 and beta = 0.
// Shapes: A is m x k, B is k x n, C is m x n. Each leading dimension equals
// the row count of the corresponding matrix.
// The guard the else refers to (original line 719) and the closing brace
// are not shown. When the guard fails, C is cleared, which is the correct
// result for an empty product since beta = 0.
708 # define gemm_interface_nn(blas_name, base_type) \
709 inline void mult_spec(const dense_matrix<base_type> &A, \
710 const dense_matrix<base_type> &B, \
711 dense_matrix<base_type> &C, c_mult) { \
712 GMMLAPACK_TRACE("gemm_interface_nn"); \
713 const char t = 'N'; \
714 BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
715 BLAS_INT k(BLAS_INT(mat_ncols(A))); \
716 BLAS_INT n(BLAS_INT(mat_ncols(B))); \
717 BLAS_INT ldb = k, ldc = m; \
718 base_type alpha(1), beta(0); \
720 blas_name(&t, &t, &m, &n, &k, &alpha, \
721 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
722 else gmm::clear(C); \
725 gemm_interface_nn(sgemm_, BLAS_S)
726 gemm_interface_nn(dgemm_, BLAS_D)
727 gemm_interface_nn(cgemm_, BLAS_C)
728 gemm_interface_nn(zgemm_, BLAS_Z)
// gemm_interface_tn(blas_name, base_type, is_const) defines
// mult_spec(A_, B, C, rcmult), computing C = A^T * B via BLAS ?gemm with
// flags ('T', 'N'), alpha = 1 and beta = 0.
// A_ is a transposed_col_ref; A is bound to its origin, which is k x m
// (lda = k). B is k x n and C is m x n.
// is_const selects whether the origin is dense_matrix or const dense_matrix;
// both variants are instantiated for every scalar type.
// The guard the else refers to (original line 746) and the closing brace
// are not shown; C is cleared when the guard fails.
734 # define gemm_interface_tn(blas_name, base_type, is_const) \
735 inline void mult_spec( \
736 const transposed_col_ref<is_const<base_type> *> &A_, \
737 const dense_matrix<base_type> &B, \
738 dense_matrix<base_type> &C, rcmult) { \
739 GMMLAPACK_TRACE("gemm_interface_tn"); \
740 const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
741 const char t = 'T', u = 'N'; \
742 BLAS_INT m(BLAS_INT(mat_ncols(A))), k(BLAS_INT(mat_nrows(A))); \
743 BLAS_INT n(BLAS_INT(mat_ncols(B))); \
744 BLAS_INT lda = k, ldb = k, ldc = m; \
745 base_type alpha(1), beta(0); \
747 blas_name(&t, &u, &m, &n, &k, &alpha, \
748 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
749 else gmm::clear(C); \
752 gemm_interface_tn(sgemm_, BLAS_S, dense_matrix)
753 gemm_interface_tn(dgemm_, BLAS_D, dense_matrix)
754 gemm_interface_tn(cgemm_, BLAS_C, dense_matrix)
755 gemm_interface_tn(zgemm_, BLAS_Z, dense_matrix)
756 gemm_interface_tn(sgemm_, BLAS_S,
const dense_matrix)
757 gemm_interface_tn(dgemm_, BLAS_D,
const dense_matrix)
758 gemm_interface_tn(cgemm_, BLAS_C,
const dense_matrix)
759 gemm_interface_tn(zgemm_, BLAS_Z,
const dense_matrix)
// gemm_interface_nt(blas_name, base_type, is_const) defines
// mult_spec(A, B_, C, r_mult), computing C = A * B^T via BLAS ?gemm with
// flags ('N', 'T'), alpha = 1 and beta = 0.
// A is m x k (lda = m). B_ is a transposed_col_ref; B is bound to its
// origin, which is n x k (ldb = n). C is m x n.
// is_const selects a mutable or const origin for B; both are instantiated.
// The guard the else refers to (original line 777) and the closing brace
// are not shown; C is cleared when the guard fails.
765 # define gemm_interface_nt(blas_name, base_type, is_const) \
766 inline void mult_spec(const dense_matrix<base_type> &A, \
767 const transposed_col_ref<is_const<base_type> *> &B_, \
768 dense_matrix<base_type> &C, r_mult) { \
769 GMMLAPACK_TRACE("gemm_interface_nt"); \
770 const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
771 const char t = 'N', u = 'T'; \
772 BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
773 BLAS_INT k(BLAS_INT(mat_ncols(A))); \
774 BLAS_INT n(BLAS_INT(mat_nrows(B))); \
775 BLAS_INT ldb = n, ldc = m; \
776 base_type alpha(1), beta(0); \
778 blas_name(&t, &u, &m, &n, &k, &alpha, \
779 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
780 else gmm::clear(C); \
783 gemm_interface_nt(sgemm_, BLAS_S, dense_matrix)
784 gemm_interface_nt(dgemm_, BLAS_D, dense_matrix)
785 gemm_interface_nt(cgemm_, BLAS_C, dense_matrix)
786 gemm_interface_nt(zgemm_, BLAS_Z, dense_matrix)
787 gemm_interface_nt(sgemm_, BLAS_S,
const dense_matrix)
788 gemm_interface_nt(dgemm_, BLAS_D,
const dense_matrix)
789 gemm_interface_nt(cgemm_, BLAS_C,
const dense_matrix)
790 gemm_interface_nt(zgemm_, BLAS_Z,
const dense_matrix)
// gemm_interface_tt(blas_name, base_type, isA_const, isB_const) defines
// mult_spec(A_, B_, C, r_mult), computing C = A^T * B^T via BLAS ?gemm with
// flags ('T', 'T'), alpha = 1 and beta = 0.
// Both operands are transposed_col_refs bound to their origins:
//   - A's origin is k x m (lda = k);
//   - B's origin is n x k (ldb = n);
//   - C is m x n.
// All four mutable/const combinations of the two origins are instantiated
// for every scalar type.
// The guard the else refers to (original line 809) and the closing brace
// are not shown; C is cleared when the guard fails.
796 # define gemm_interface_tt(blas_name, base_type, isA_const, isB_const) \
797 inline void mult_spec( \
798 const transposed_col_ref<isA_const <base_type> *> &A_, \
799 const transposed_col_ref<isB_const <base_type> *> &B_, \
800 dense_matrix<base_type> &C, r_mult) { \
801 GMMLAPACK_TRACE("gemm_interface_tt"); \
802 const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
803 const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
804 const char t = 'T', u = 'T'; \
805 BLAS_INT m(BLAS_INT(mat_ncols(A))), k(BLAS_INT(mat_nrows(A))); \
806 BLAS_INT n(BLAS_INT(mat_nrows(B))); \
807 BLAS_INT lda = k, ldb = n, ldc = m; \
808 base_type alpha(1), beta(0); \
810 blas_name(&t, &u, &m, &n, &k, &alpha, \
811 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
812 else gmm::clear(C); \
815 gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, dense_matrix)
816 gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, dense_matrix)
817 gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, dense_matrix)
818 gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, dense_matrix)
819 gemm_interface_tt(sgemm_, BLAS_S,
const dense_matrix, dense_matrix)
820 gemm_interface_tt(dgemm_, BLAS_D,
const dense_matrix, dense_matrix)
821 gemm_interface_tt(cgemm_, BLAS_C,
const dense_matrix, dense_matrix)
822 gemm_interface_tt(zgemm_, BLAS_Z,
const dense_matrix, dense_matrix)
823 gemm_interface_tt(sgemm_, BLAS_S, dense_matrix,
const dense_matrix)
824 gemm_interface_tt(dgemm_, BLAS_D, dense_matrix,
const dense_matrix)
825 gemm_interface_tt(cgemm_, BLAS_C, dense_matrix,
const dense_matrix)
826 gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix,
const dense_matrix)
827 gemm_interface_tt(sgemm_, BLAS_S,
const dense_matrix,
const dense_matrix)
828 gemm_interface_tt(dgemm_, BLAS_D,
const dense_matrix,
const dense_matrix)
829 gemm_interface_tt(cgemm_, BLAS_C,
const dense_matrix,
const dense_matrix)
830 gemm_interface_tt(zgemm_, BLAS_Z,
const dense_matrix,
const dense_matrix)
// gemm_interface_cn(blas_name, base_type) defines mult_spec(A_, B, C, rcmult),
// computing C = A^H * B via BLAS ?gemm with flags ('C', 'N'), alpha = 1 and
// beta = 0. Per the BLAS spec, 'C' acts like 'T' for the real types.
// A_ is a conjugated_col_matrix_const_ref; A is bound to its origin, which
// is k x m (lda = k). B is k x n and C is m x n.
// The guard the else refers to (original line 849) and the closing brace
// are not shown; C is cleared when the guard fails.
837 # define gemm_interface_cn(blas_name, base_type) \
838 inline void mult_spec( \
839 const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_, \
840 const dense_matrix<base_type> &B, \
841 dense_matrix<base_type> &C, rcmult) { \
842 GMMLAPACK_TRACE("gemm_interface_cn"); \
843 const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
844 const char t = 'C', u = 'N'; \
845 BLAS_INT m(BLAS_INT(mat_ncols(A))), k(BLAS_INT(mat_nrows(A))); \
846 BLAS_INT n(BLAS_INT(mat_ncols(B))); \
847 BLAS_INT lda = k, ldb = k, ldc = m; \
848 base_type alpha(1), beta(0); \
850 blas_name(&t, &u, &m, &n, &k, &alpha, \
851 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
852 else gmm::clear(C); \
855 gemm_interface_cn(sgemm_, BLAS_S)
856 gemm_interface_cn(dgemm_, BLAS_D)
857 gemm_interface_cn(cgemm_, BLAS_C)
858 gemm_interface_cn(zgemm_, BLAS_Z)
// gemm_interface_nc(blas_name, base_type) defines
// mult_spec(A, B_, C, c_mult, row_major), computing C = A * B^H via BLAS
// ?gemm with flags ('N', 'C'), alpha = 1 and beta = 0.
// A is m x k (lda = m). B_ is a conjugated_col_matrix_const_ref whose
// origin is n x k (ldb = n). C is m x n.
// NOTE(review): unlike the other gemm wrappers in this header, which take a
// single tag, this signature takes two tag parameters (c_mult, row_major).
// Confirm that the generic dispatcher actually calls mult_spec with both
// tags; otherwise this overload is never selected.
// The guard the else refers to (original line 875) and the closing brace
// are not shown; C is cleared when the guard fails.
864 # define gemm_interface_nc(blas_name, base_type) \
865 inline void mult_spec(const dense_matrix<base_type> &A, \
866 const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &B_, \
867 dense_matrix<base_type> &C, c_mult, row_major) { \
868 GMMLAPACK_TRACE("gemm_interface_nc"); \
869 const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
870 const char t = 'N', u = 'C'; \
871 BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
872 BLAS_INT k(BLAS_INT(mat_ncols(A))); \
873 BLAS_INT n(BLAS_INT(mat_nrows(B))), ldb = n, ldc = m; \
874 base_type alpha(1), beta(0); \
876 blas_name(&t, &u, &m, &n, &k, &alpha, \
877 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
878 else gmm::clear(C); \
881 gemm_interface_nc(sgemm_, BLAS_S)
882 gemm_interface_nc(dgemm_, BLAS_D)
883 gemm_interface_nc(cgemm_, BLAS_C)
884 gemm_interface_nc(zgemm_, BLAS_Z)
// gemm_interface_cc(blas_name, base_type) defines
// mult_spec(A_, B_, C, r_mult), computing C = A^H * B^H via BLAS ?gemm with
// flags ('C', 'C'), alpha = 1 and beta = 0.
// Both operands are conjugated_col_matrix_const_refs bound to their origins:
//   - A's origin is k x m (lda = k);
//   - B's origin is n x k (ldb = n);
//   - C is m x n.
// The guard the else refers to (original line 902) and the closing brace
// are not shown; C is cleared when the guard fails.
890 # define gemm_interface_cc(blas_name, base_type) \
891 inline void mult_spec( \
892 const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_, \
893 const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &B_, \
894 dense_matrix<base_type> &C, r_mult) { \
895 GMMLAPACK_TRACE("gemm_interface_cc"); \
896 const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
897 const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
898 const char t = 'C', u = 'C'; \
899 BLAS_INT m(BLAS_INT(mat_ncols(A))), k(BLAS_INT(mat_nrows(A))); \
900 BLAS_INT lda = k, n(BLAS_INT(mat_nrows(B))), ldb = n, ldc = m; \
901 base_type alpha(1), beta(0); \
903 blas_name(&t, &u, &m, &n, &k, &alpha, \
904 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
905 else gmm::clear(C); \
908 gemm_interface_cc(sgemm_, BLAS_S)
909 gemm_interface_cc(dgemm_, BLAS_D)
910 gemm_interface_cc(cgemm_, BLAS_C)
911 gemm_interface_cc(zgemm_, BLAS_Z)
// trsv_interface(f_name, loru, param1, trans1, blas_name, base_type) defines
// f_name(A, x, k, is_unit). It solves the triangular system op(A) * y = x in
// place in x using BLAS ?trsv. Parameters:
//   - loru expands to the triangle flag l: trsv_upper -> 'U',
//     trsv_lower -> 'L'.
//   - param1/trans1 are the gem_p1_* / gem_trans1_* macros; they choose how
//     A is received and set the transpose flag t.
//   - is_unit selects BLAS diag 'U' (unit diagonal, diagonal entries not
//     read) or 'N'.
// The system size passed to BLAS is k, and lda is the row count of the
// origin matrix, so only the leading k x k part is used. Nothing is solved
// when the matrix has no rows.
// Original lines 924-925 (presumably the closing brace) are not shown.
917 # define trsv_interface(f_name, loru, param1, trans1, blas_name, base_type)\
918 inline void f_name(param1(base_type), std::vector<base_type> &x, \
919 size_type k, bool is_unit) { \
920 GMMLAPACK_TRACE("trsv_interface"); \
921 loru; trans1(base_type); char d = is_unit ? 'U' : 'N'; \
922 BLAS_INT lda(BLAS_INT(mat_nrows(A))), inc(1), n = BLAS_INT(k); \
923 if (lda) blas_name(&l, &t, &d, &n, &A(0,0), &lda, &x[0], &inc); \
926 # define trsv_upper const char l = 'U'
927 # define trsv_lower const char l = 'L'
// Triangular-solve instantiations, grouped four at a time (presumably S, D,
// C, Z), in this order:
//   1. lower, plain A
//   2. upper, plain A
//   3. lower, transposed A
//   4. upper, transposed A
//   5. lower, transposed const A
//   6. upper, transposed const A
//   7. lower, conjugated A
//   8. upper, conjugated A
// For a transposed or conjugated A, BLAS operates on the origin matrix.
// Transposition swaps the lower and upper triangles, so lower_tri_solve uses
// trsv_upper on the origin and upper_tri_solve uses trsv_lower.
// NOTE(review): each invocation's second line (original lines 931, 933, ...,
// presumably "<prefix>trsv_, BLAS_<T>)") is not shown.
930 trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
932 trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
934 trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
936 trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
940 trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
942 trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
944 trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
946 trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
950 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
952 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
954 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
956 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
960 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
962 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
964 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
966 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
970 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
972 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
974 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
976 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
980 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
982 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
984 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
986 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
990 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
992 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
994 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
996 trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
1000 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
1002 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
1004 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
1006 trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
Basic linear algebra functions.
gmm interface for STL vectors.
Declaration of some matrix types (gmm::dense_matrix, gmm::row_matrix, gmm::col_matrix,...
size_t size_type
used as the common size type in the library