GetFEM  5.4.3
gmm_blas_interface.h
Go to the documentation of this file.
1 /* -*- c++ -*- (enables emacs c++ mode) */
2 /*===========================================================================
3 
4  Copyright (C) 2003-2020 Yves Renard
5 
6  This file is a part of GetFEM
7 
8  GetFEM is free software; you can redistribute it and/or modify it
9  under the terms of the GNU Lesser General Public License as published
10  by the Free Software Foundation; either version 3 of the License, or
11  (at your option) any later version along with the GCC Runtime Library
12  Exception either version 3.1 or (at your option) any later version.
13  This program is distributed in the hope that it will be useful, but
14  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
15  or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
16  License and GCC Runtime Library Exception for more details.
17  You should have received a copy of the GNU Lesser General Public License
18  along with this program; if not, write to the Free Software Foundation,
19  Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA.
20 
21  As a special exception, you may use this file as it is a part of a free
22  software library without restriction. Specifically, if other files
23  instantiate templates or use macros or inline functions from this file,
24  or you compile this file and link it with other files to produce an
25  executable, this file does not by itself cause the resulting executable
26  to be covered by the GNU Lesser General Public License. This exception
27  does not however invalidate any other reasons why the executable file
28  might be covered by the GNU Lesser General Public License.
29 
30 ===========================================================================*/
31 
32 /**@file gmm_blas_interface.h
33  @author Yves Renard <[email protected]>
34  @date October 7, 2003.
35  @brief gmm interface for fortran BLAS.
36 */
37 
38 #if defined(GMM_USES_BLAS) || defined(GMM_USES_LAPACK)
39 
40 #ifndef GMM_BLAS_INTERFACE_H
41 #define GMM_BLAS_INTERFACE_H
42 
43 #include "gmm_blas.h"
44 #include "gmm_interface.h"
45 #include "gmm_matrix.h"
46 
47 namespace gmm {
48 
49  // Use ./configure --enable-blas-interface to activate this interface.
50 
// GMMLAPACK_TRACE(f) is a debugging hook invoked at the top of every
// wrapper in this file with the wrapper's name. It expands to nothing by
// default; switch to the commented-out variant to print each call.
#define GMMLAPACK_TRACE(f)
// #define GMMLAPACK_TRACE(f) cout << "function " << f << " called" << endl;

// BLAS_INT is the integer type passed (by pointer) for every size, leading
// dimension and increment argument of the BLAS routines: `long` when
// WeirdNEC or GMM_USE_BLAS64_INTERFACE is defined, plain `int` otherwise.
#if !defined(WeirdNEC) && !defined(GMM_USE_BLAS64_INTERFACE)
# define BLAS_INT int
#else
# define BLAS_INT long
#endif
59 
60  /* ********************************************************************* */
61  /* Operations interfaced for T = float, double, std::complex<float> */
62  /* or std::complex<double> : */
63  /* */
64  /* vect_norm2(std::vector<T>) */
65  /* */
66  /* vect_sp(std::vector<T>, std::vector<T>) */
67  /* vect_sp(scaled(std::vector<T>), std::vector<T>) */
68  /* vect_sp(std::vector<T>, scaled(std::vector<T>)) */
69  /* vect_sp(scaled(std::vector<T>), scaled(std::vector<T>)) */
70  /* */
71  /* vect_hp(std::vector<T>, std::vector<T>) */
72  /* vect_hp(scaled(std::vector<T>), std::vector<T>) */
73  /* vect_hp(std::vector<T>, scaled(std::vector<T>)) */
74  /* vect_hp(scaled(std::vector<T>), scaled(std::vector<T>)) */
75  /* */
76  /* add(std::vector<T>, std::vector<T>) */
77  /* add(scaled(std::vector<T>, a), std::vector<T>) */
78  /* */
79  /* mult(dense_matrix<T>, dense_matrix<T>, dense_matrix<T>) */
80  /* mult(transposed(dense_matrix<T>), dense_matrix<T>, dense_matrix<T>) */
81  /* mult(dense_matrix<T>, transposed(dense_matrix<T>), dense_matrix<T>) */
82  /* mult(transposed(dense_matrix<T>), transposed(dense_matrix<T>), */
83  /* dense_matrix<T>) */
84  /* mult(conjugated(dense_matrix<T>), dense_matrix<T>, dense_matrix<T>) */
85  /* mult(dense_matrix<T>, conjugated(dense_matrix<T>), dense_matrix<T>) */
86  /* mult(conjugated(dense_matrix<T>), conjugated(dense_matrix<T>), */
87  /* dense_matrix<T>) */
88  /* */
89  /* mult(dense_matrix<T>, std::vector<T>, std::vector<T>) */
90  /* mult(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>) */
91  /* mult(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>) */
92  /* mult(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>) */
93  /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
94  /* std::vector<T>) */
95  /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
96  /* std::vector<T>) */
97  /* */
98  /* mult_add(dense_matrix<T>, std::vector<T>, std::vector<T>) */
99  /* mult_add(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>) */
100  /* mult_add(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>) */
101  /* mult_add(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>) */
102  /* mult_add(transposed(dense_matrix<T>), scaled(std::vector<T>), */
103  /* std::vector<T>) */
104  /* mult_add(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
105  /* std::vector<T>) */
106  /* */
107  /* mult(dense_matrix<T>, std::vector<T>, std::vector<T>, std::vector<T>) */
108  /* mult(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>, */
109  /* std::vector<T>) */
110  /* mult(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>, */
111  /* std::vector<T>) */
112  /* mult(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>, */
113  /* std::vector<T>) */
114  /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
115  /* std::vector<T>, std::vector<T>) */
116  /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
117  /* std::vector<T>, std::vector<T>) */
118  /* mult(dense_matrix<T>, std::vector<T>, scaled(std::vector<T>), */
119  /* std::vector<T>) */
120  /* mult(transposed(dense_matrix<T>), std::vector<T>, */
121  /* scaled(std::vector<T>), std::vector<T>) */
122  /* mult(conjugated(dense_matrix<T>), std::vector<T>, */
123  /* scaled(std::vector<T>), std::vector<T>) */
124  /* mult(dense_matrix<T>, scaled(std::vector<T>), scaled(std::vector<T>), */
125  /* std::vector<T>) */
126  /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
127  /* scaled(std::vector<T>), std::vector<T>) */
128  /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
129  /* scaled(std::vector<T>), std::vector<T>) */
130  /* */
131  /* lower_tri_solve(dense_matrix<T>, std::vector<T>, k, b) */
132  /* upper_tri_solve(dense_matrix<T>, std::vector<T>, k, b) */
133  /* lower_tri_solve(transposed(dense_matrix<T>), std::vector<T>, k, b) */
134  /* upper_tri_solve(transposed(dense_matrix<T>), std::vector<T>, k, b) */
135  /* lower_tri_solve(conjugated(dense_matrix<T>), std::vector<T>, k, b) */
136  /* upper_tri_solve(conjugated(dense_matrix<T>), std::vector<T>, k, b) */
137  /* */
138  /* rank_one_update(dense_matrix<T>, std::vector<T>, std::vector<T>) */
139  /* rank_one_update(dense_matrix<T>, scaled(std::vector<T>), */
140  /* std::vector<T>) */
141  /* rank_one_update(dense_matrix<T>, std::vector<T>, */
142  /* scaled(std::vector<T>)) */
143  /* */
144  /* ********************************************************************* */
145 
146  /* ********************************************************************* */
147  /* Basic defines. */
148  /* ********************************************************************* */
149 
// Scalar types matching the BLAS precision prefixes s, d, c and z.
# define BLAS_S float
# define BLAS_D double
# define BLAS_C std::complex<float>
# define BLAS_Z std::complex<double>

// BLAS_CPLX_FUNC_CALL(blasname, res, args...) stores in `res` the result of
// a complex-valued BLAS function such as cdotu_ or zdotc_. BLAS builds do
// not agree on the ABI for complex results, so two conventions exist:
//  - GMM_BLAS_RETURN_COMPLEX_AS_ARGUMENT defined: the routine writes the
//    result through a pointer passed as an extra first argument;
//  - otherwise: the routine returns the value and it is assigned to res.
#ifndef GMM_BLAS_RETURN_COMPLEX_AS_ARGUMENT
# define BLAS_CPLX_FUNC_CALL(blasname, res, ...) res = blasname(__VA_ARGS__)
#else
# define BLAS_CPLX_FUNC_CALL(blasname, res, ...) blasname(&res, __VA_ARGS__)
#endif
161 
162  /* ********************************************************************* */
163  /* BLAS functions used. */
164  /* ********************************************************************* */
165  extern "C" {
166  void daxpy_(const BLAS_INT *n, const double *alpha, const double *x,
167  const BLAS_INT *incx, double *y, const BLAS_INT *incy);
168  void dgemm_(const char *tA, const char *tB, const BLAS_INT *m,
169  const BLAS_INT *n, const BLAS_INT *k, const double *alpha,
170  const double *A, const BLAS_INT *ldA, const double *B,
171  const BLAS_INT *ldB, const double *beta, double *C,
172  const BLAS_INT *ldC);
173  void sgemm_(...); void cgemm_(...); void zgemm_(...);
174  void sgemv_(...); void dgemv_(...); void cgemv_(...); void zgemv_(...);
175  void strsv_(...); void dtrsv_(...); void ctrsv_(...); void ztrsv_(...);
176  void saxpy_(...); /*void daxpy_(...); */void caxpy_(...); void zaxpy_(...);
177  BLAS_S sdot_ (...); BLAS_D ddot_ (...);
178  BLAS_C cdotu_(...); BLAS_Z zdotu_(...);
179  // Hermitian product in {c,z}dotc is defined in reverse order than usually
180  BLAS_C cdotc_(...); BLAS_Z zdotc_(...);
181  BLAS_S snrm2_(...); BLAS_D dnrm2_(...);
182  BLAS_S scnrm2_(...); BLAS_D dznrm2_(...);
183  void sger_(...); void dger_(...); void cgerc_(...); void zgerc_(...);
184  }
185 
186 
187  /* ********************************************************************* */
188  /* vect_norm2(x). */
189  /* ********************************************************************* */
190 
191 # define nrm2_interface(blas_name, base_type) \
192  inline number_traits<base_type>::magnitude_type \
193  vect_norm2(const std::vector<base_type> &x) { \
194  GMMLAPACK_TRACE("nrm2_interface"); \
195  BLAS_INT inc(1), n(BLAS_INT(vect_size(x))); \
196  return blas_name(&n, &x[0], &inc); \
197  }
198 
199  nrm2_interface(snrm2_, BLAS_S)
200  nrm2_interface(dnrm2_, BLAS_D)
201  nrm2_interface(scnrm2_, BLAS_C)
202  nrm2_interface(dznrm2_, BLAS_Z)
203 
204  /* ********************************************************************* */
205  /* vect_sp(x,y) = vect_hp(x,y) for real vectors */
206  /* ********************************************************************* */
207 
208 # define dot_interface(funcname, msg, blas_name, base_type) \
209  inline base_type funcname(const std::vector<base_type> &x, \
210  const std::vector<base_type> &y) { \
211  GMMLAPACK_TRACE(msg); \
212  BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
213  return blas_name(&n, &x[0], &inc, &y[0], &inc); \
214  } \
215  inline base_type funcname \
216  (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
217  const std::vector<base_type> &y) { \
218  GMMLAPACK_TRACE(msg); \
219  const std::vector<base_type> &x = *(linalg_origin(x_)); \
220  base_type a(x_.r); \
221  BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
222  return a * blas_name(&n, &x[0], &inc, &y[0], &inc); \
223  } \
224  inline base_type funcname \
225  (const std::vector<base_type> &x, \
226  const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
227  GMMLAPACK_TRACE(msg); \
228  const std::vector<base_type> &y = *(linalg_origin(y_)); \
229  base_type b(y_.r); \
230  BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
231  return b * blas_name(&n, &x[0], &inc, &y[0], &inc); \
232  } \
233  inline base_type funcname \
234  (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
235  const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
236  GMMLAPACK_TRACE(msg); \
237  const std::vector<base_type> &x = *(linalg_origin(x_)); \
238  const std::vector<base_type> &y = *(linalg_origin(y_)); \
239  base_type a(x_.r), b(y_.r); \
240  BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
241  return a*b * blas_name(&n, &x[0], &inc, &y[0], &inc); \
242  }
243 
244  dot_interface(vect_sp, "dot_interface", sdot_, BLAS_S)
245  dot_interface(vect_sp, "dot_interface", ddot_, BLAS_D)
246  dot_interface(vect_hp, "dotc_interface", sdot_, BLAS_S)
247  dot_interface(vect_hp, "dotc_interface", ddot_, BLAS_D)
248 
249  /* ********************************************************************* */
250  /* vect_sp(x,y) and vect_hp(x,y) for complex vectors */
251  /* vect_hp(x, y) = x.conj(y) (different order than in BLAS) */
252  /* switching x,y before passed to BLAS is important only for vect_hp */
253  /* ********************************************************************* */
254 
255 # define dot_interface_cplx(funcname, msg, blas_name, base_type, bdef) \
256  inline base_type funcname(const std::vector<base_type> &x, \
257  const std::vector<base_type> &y) { \
258  GMMLAPACK_TRACE(msg); \
259  base_type res; \
260  BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
261  BLAS_CPLX_FUNC_CALL(blas_name, res, &n, &y[0], &inc, &x[0], &inc); \
262  return res; \
263  } \
264  inline base_type funcname \
265  (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
266  const std::vector<base_type> &y) { \
267  GMMLAPACK_TRACE(msg); \
268  const std::vector<base_type> &x = *(linalg_origin(x_)); \
269  base_type res, a(x_.r); \
270  BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
271  BLAS_CPLX_FUNC_CALL(blas_name, res, &n, &y[0], &inc, &x[0], &inc); \
272  return a*res; \
273  } \
274  inline base_type funcname \
275  (const std::vector<base_type> &x, \
276  const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
277  GMMLAPACK_TRACE(msg); \
278  const std::vector<base_type> &y = *(linalg_origin(y_)); \
279  base_type res, b(bdef); \
280  BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
281  BLAS_CPLX_FUNC_CALL(blas_name, res, &n, &y[0], &inc, &x[0], &inc); \
282  return b*res; \
283  } \
284  inline base_type funcname \
285  (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
286  const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
287  GMMLAPACK_TRACE(msg); \
288  const std::vector<base_type> &x = *(linalg_origin(x_)); \
289  const std::vector<base_type> &y = *(linalg_origin(y_)); \
290  base_type res, a(x_.r), b(bdef); \
291  BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
292  BLAS_CPLX_FUNC_CALL(blas_name, res, &n, &y[0], &inc, &x[0], &inc); \
293  return a*b*res; \
294  }
295 
296  dot_interface_cplx(vect_sp, "dot_interface", cdotu_, BLAS_C, y_.r)
297  dot_interface_cplx(vect_sp, "dot_interface", zdotu_, BLAS_Z, y_.r)
298  dot_interface_cplx(vect_hp, "dotc_interface", cdotc_, BLAS_C, gmm::conj(y_.r))
299  dot_interface_cplx(vect_hp, "dotc_interface", zdotc_, BLAS_Z, gmm::conj(y_.r))
300 
301 
302  /* ********************************************************************* */
303  /* add(x, y). */
304  /* ********************************************************************* */
305  template<size_type N, class V1, class V2>
306  inline void add_fixed(const V1 &x, V2 &y)
307  {
308  for(size_type i = 0; i != N; ++i) y[i] += x[i];
309  }
310 
311  template<class V1, class V2>
312  inline void add_for_short_vectors(const V1 &x, V2 &y, size_type n)
313  {
314  switch(n)
315  {
316  case 1: add_fixed<1>(x, y); break;
317  case 2: add_fixed<2>(x, y); break;
318  case 3: add_fixed<3>(x, y); break;
319  case 4: add_fixed<4>(x, y); break;
320  case 5: add_fixed<5>(x, y); break;
321  case 6: add_fixed<6>(x, y); break;
322  case 7: add_fixed<7>(x, y); break;
323  case 8: add_fixed<8>(x, y); break;
324  case 9: add_fixed<9>(x, y); break;
325  case 10: add_fixed<10>(x, y); break;
326  case 11: add_fixed<11>(x, y); break;
327  case 12: add_fixed<12>(x, y); break;
328  case 13: add_fixed<13>(x, y); break;
329  case 14: add_fixed<14>(x, y); break;
330  case 15: add_fixed<15>(x, y); break;
331  case 16: add_fixed<16>(x, y); break;
332  case 17: add_fixed<17>(x, y); break;
333  case 18: add_fixed<18>(x, y); break;
334  case 19: add_fixed<19>(x, y); break;
335  case 20: add_fixed<20>(x, y); break;
336  case 21: add_fixed<21>(x, y); break;
337  case 22: add_fixed<22>(x, y); break;
338  case 23: add_fixed<23>(x, y); break;
339  case 24: add_fixed<24>(x, y); break;
340  default:
341  GMM_ASSERT2(false, "add_for_short_vectors used with unsupported size");
342  break;
343  }
344  }
345 
346  template<size_type N, class V1, class V2, class T>
347  inline void add_fixed(const V1 &x, V2 &y, const T &a)
348  {
349  for(size_type i = 0; i != N; ++i) y[i] += a*x[i];
350  }
351 
352  template<class V1, class V2, class T>
353  inline void add_for_short_vectors(const V1 &x, V2 &y, const T &a, size_type n)
354  {
355  switch(n)
356  {
357  case 1: add_fixed<1>(x, y, a); break;
358  case 2: add_fixed<2>(x, y, a); break;
359  case 3: add_fixed<3>(x, y, a); break;
360  case 4: add_fixed<4>(x, y, a); break;
361  case 5: add_fixed<5>(x, y, a); break;
362  case 6: add_fixed<6>(x, y, a); break;
363  case 7: add_fixed<7>(x, y, a); break;
364  case 8: add_fixed<8>(x, y, a); break;
365  case 9: add_fixed<9>(x, y, a); break;
366  case 10: add_fixed<10>(x, y, a); break;
367  case 11: add_fixed<11>(x, y, a); break;
368  case 12: add_fixed<12>(x, y, a); break;
369  case 13: add_fixed<13>(x, y, a); break;
370  case 14: add_fixed<14>(x, y, a); break;
371  case 15: add_fixed<15>(x, y, a); break;
372  case 16: add_fixed<16>(x, y, a); break;
373  case 17: add_fixed<17>(x, y, a); break;
374  case 18: add_fixed<18>(x, y, a); break;
375  case 19: add_fixed<19>(x, y, a); break;
376  case 20: add_fixed<20>(x, y, a); break;
377  case 21: add_fixed<21>(x, y, a); break;
378  case 22: add_fixed<22>(x, y, a); break;
379  case 23: add_fixed<23>(x, y, a); break;
380  case 24: add_fixed<24>(x, y, a); break;
381  default:
382  GMM_ASSERT2(false, "add_for_short_vectors used with unsupported size");
383  break;
384  }
385  }
386 
387 
388 # define axpy_interface(blas_name, base_type) \
389  inline void add(const std::vector<base_type> &x, \
390  std::vector<base_type> &y) { \
391  GMMLAPACK_TRACE("axpy_interface"); \
392  BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); base_type a(1); \
393  if(n == 0) return; \
394  else if(n < 25) add_for_short_vectors(x, y, n); \
395  else blas_name(&n, &a, &x[0], &inc, &y[0], &inc); \
396  }
397 
398  axpy_interface(saxpy_, BLAS_S)
399  axpy_interface(daxpy_, BLAS_D)
400  axpy_interface(caxpy_, BLAS_C)
401  axpy_interface(zaxpy_, BLAS_Z)
402 
403 
404 # define axpy2_interface(blas_name, base_type) \
405  inline void add \
406  (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
407  std::vector<base_type> &y) { \
408  GMMLAPACK_TRACE("axpy_interface"); \
409  BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \
410  const std::vector<base_type>& x = *(linalg_origin(x_)); \
411  base_type a(x_.r); \
412  if(n == 0) return; \
413  else if(n < 25) add_for_short_vectors(x, y, a, n); \
414  else blas_name(&n, &a, &x[0], &inc, &y[0], &inc); \
415  }
416 
417  axpy2_interface(saxpy_, BLAS_S)
418  axpy2_interface(daxpy_, BLAS_D)
419  axpy2_interface(caxpy_, BLAS_C)
420  axpy2_interface(zaxpy_, BLAS_Z)
421 
422 
423  /* ********************************************************************* */
424  /* mult_add(A, x, z). */
425  /* ********************************************************************* */
426 
// mult_add_spec(op(A), x, z, orien): z <- op(A) * x + z through BLAS ?gemv.
// trans1 binds the dense matrix `A` and the transpose flag `t`
// ('N', 'T' or 'C'); trans2 binds the vector `x` and `alpha` (1, or the
// scale factor of a scaled x). beta = 1 keeps the previous content of z.
// When A has no rows or no columns, op(A) * x contributes nothing, so z is
// left untouched. Clearing z here would discard the accumulated value that
// mult_add must preserve.
# define gemv_interface(param1, trans1, param2, trans2, blas_name,          \
                        base_type, orien)                                   \
  inline void mult_add_spec(param1(base_type), param2(base_type),           \
              std::vector<base_type> &z, orien) {                           \
    GMMLAPACK_TRACE("gemv_interface");                                      \
    trans1(base_type); trans2(base_type); base_type beta(1);                \
    BLAS_INT m(BLAS_INT(mat_nrows(A))), lda(m);                             \
    BLAS_INT n(BLAS_INT(mat_ncols(A))), inc(1);                             \
    if (m && n) blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc,   \
                          &beta, &z[0], &inc);                              \
  }
439 
// Parameter and prologue fragments spliced into the gemv and ger wrapper
// macros. Each gem_p1_* expands to the declaration of the matrix
// parameter. The matching gem_trans1_* is used as a statement and binds the
// dense_matrix `A` and the BLAS transpose flag `t`.

// A itself: no transposition.
# define gem_p1_n(base_type) const dense_matrix<base_type> &A
# define gem_trans1_n(base_type) const char t = 'N'

// transposed(A), over a mutable (gem_p1_t) or const (gem_p1_tc) matrix.
# define gem_p1_t(base_type)                                                \
  const transposed_col_ref<dense_matrix<base_type> *> &A_
# define gem_p1_tc(base_type)                                               \
  const transposed_col_ref<const dense_matrix<base_type> *> &A_
# define gem_trans1_t(base_type)                                            \
  const char t = 'T';                                                       \
  const dense_matrix<base_type> &A = *linalg_origin(A_)

// conjugated(A), handed to BLAS with the 'C' flag.
# define gem_p1_c(base_type)                                                \
  const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_
# define gem_trans1_c(base_type)                                            \
  const char t = 'C';                                                       \
  const dense_matrix<base_type> &A = *linalg_origin(A_)

// Vector operand: a plain x (alpha = 1) or scaled(x, r) (alpha = r).
# define gemv_p2_n(base_type) const std::vector<base_type> &x
# define gemv_trans2_n(base_type) base_type alpha(1)
# define gemv_p2_s(base_type)                                               \
  const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_
# define gemv_trans2_s(base_type)                                           \
  base_type alpha(x_.r);                                                    \
  const std::vector<base_type> &x = *linalg_origin(x_)
464 
465  // Z <- AX + Z.
466  gemv_interface(gem_p1_n, gem_trans1_n,
467  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, col_major)
468  gemv_interface(gem_p1_n, gem_trans1_n,
469  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, col_major)
470  gemv_interface(gem_p1_n, gem_trans1_n,
471  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, col_major)
472  gemv_interface(gem_p1_n, gem_trans1_n,
473  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, col_major)
474 
475  // Z <- transposed(A)X + Z.
476  gemv_interface(gem_p1_t, gem_trans1_t,
477  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
478  gemv_interface(gem_p1_t, gem_trans1_t,
479  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
480  gemv_interface(gem_p1_t, gem_trans1_t,
481  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
482  gemv_interface(gem_p1_t, gem_trans1_t,
483  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
484 
485  // Z <- transposed(const A)X + Z.
486  gemv_interface(gem_p1_tc, gem_trans1_t,
487  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
488  gemv_interface(gem_p1_tc, gem_trans1_t,
489  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
490  gemv_interface(gem_p1_tc, gem_trans1_t,
491  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
492  gemv_interface(gem_p1_tc, gem_trans1_t,
493  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
494 
495  // Z <- conjugated(A)X + Z.
496  gemv_interface(gem_p1_c, gem_trans1_c,
497  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
498  gemv_interface(gem_p1_c, gem_trans1_c,
499  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
500  gemv_interface(gem_p1_c, gem_trans1_c,
501  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
502  gemv_interface(gem_p1_c, gem_trans1_c,
503  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
504 
505  // Z <- A scaled(X) + Z.
506  gemv_interface(gem_p1_n, gem_trans1_n,
507  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, col_major)
508  gemv_interface(gem_p1_n, gem_trans1_n,
509  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, col_major)
510  gemv_interface(gem_p1_n, gem_trans1_n,
511  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, col_major)
512  gemv_interface(gem_p1_n, gem_trans1_n,
513  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, col_major)
514 
515  // Z <- transposed(A) scaled(X) + Z.
516  gemv_interface(gem_p1_t, gem_trans1_t,
517  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
518  gemv_interface(gem_p1_t, gem_trans1_t,
519  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
520  gemv_interface(gem_p1_t, gem_trans1_t,
521  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
522  gemv_interface(gem_p1_t, gem_trans1_t,
523  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
524 
525  // Z <- transposed(const A) scaled(X) + Z.
526  gemv_interface(gem_p1_tc, gem_trans1_t,
527  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
528  gemv_interface(gem_p1_tc, gem_trans1_t,
529  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
530  gemv_interface(gem_p1_tc, gem_trans1_t,
531  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
532  gemv_interface(gem_p1_tc, gem_trans1_t,
533  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
534 
535  // Z <- conjugated(A) scaled(X) + Z.
536  gemv_interface(gem_p1_c, gem_trans1_c,
537  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
538  gemv_interface(gem_p1_c, gem_trans1_c,
539  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
540  gemv_interface(gem_p1_c, gem_trans1_c,
541  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
542  gemv_interface(gem_p1_c, gem_trans1_c,
543  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
544 
545 
546  /* ********************************************************************* */
547  /* mult(A, x, y). */
548  /* ********************************************************************* */
549 
// mult_spec(op(A), x, z, orien): z <- op(A) * x through BLAS ?gemv with
// beta = 0, so the previous content of z is overwritten. trans1 binds `A`
// and the transpose flag `t`; trans2 binds `x` and `alpha`. If A has no rows
// or no columns the product is the zero vector, hence gmm::clear(z).
# define gemv_interface2(param1, trans1, param2, trans2, blas_name,         \
                         base_type, orien)                                  \
  inline void mult_spec(param1(base_type), param2(base_type),               \
                        std::vector<base_type> &z, orien) {                 \
    GMMLAPACK_TRACE("gemv_interface2");                                     \
    trans1(base_type);                                                      \
    trans2(base_type);                                                      \
    BLAS_INT m(BLAS_INT(mat_nrows(A)));                                     \
    BLAS_INT n(BLAS_INT(mat_ncols(A)));                                     \
    if (m == 0 || n == 0) { gmm::clear(z); return; }                        \
    BLAS_INT lda(m), inc(1);                                                \
    base_type beta(0);                                                      \
    blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc, &beta,        \
              &z[0], &inc);                                                 \
  }
563 
564  // Y <- AX.
565  gemv_interface2(gem_p1_n, gem_trans1_n,
566  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, col_major)
567  gemv_interface2(gem_p1_n, gem_trans1_n,
568  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, col_major)
569  gemv_interface2(gem_p1_n, gem_trans1_n,
570  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, col_major)
571  gemv_interface2(gem_p1_n, gem_trans1_n,
572  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, col_major)
573 
574  // Y <- transposed(A)X.
575  gemv_interface2(gem_p1_t, gem_trans1_t,
576  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
577  gemv_interface2(gem_p1_t, gem_trans1_t,
578  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
579  gemv_interface2(gem_p1_t, gem_trans1_t,
580  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
581  gemv_interface2(gem_p1_t, gem_trans1_t,
582  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
583 
584  // Y <- transposed(const A)X.
585  gemv_interface2(gem_p1_tc, gem_trans1_t,
586  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
587  gemv_interface2(gem_p1_tc, gem_trans1_t,
588  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
589  gemv_interface2(gem_p1_tc, gem_trans1_t,
590  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
591  gemv_interface2(gem_p1_tc, gem_trans1_t,
592  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
593 
594  // Y <- conjugated(A)X.
595  gemv_interface2(gem_p1_c, gem_trans1_c,
596  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
597  gemv_interface2(gem_p1_c, gem_trans1_c,
598  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
599  gemv_interface2(gem_p1_c, gem_trans1_c,
600  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
601  gemv_interface2(gem_p1_c, gem_trans1_c,
602  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
603 
604  // Y <- A scaled(X).
605  gemv_interface2(gem_p1_n, gem_trans1_n,
606  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, col_major)
607  gemv_interface2(gem_p1_n, gem_trans1_n,
608  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, col_major)
609  gemv_interface2(gem_p1_n, gem_trans1_n,
610  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, col_major)
611  gemv_interface2(gem_p1_n, gem_trans1_n,
612  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, col_major)
613 
614  // Y <- transposed(A) scaled(X).
615  gemv_interface2(gem_p1_t, gem_trans1_t,
616  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
617  gemv_interface2(gem_p1_t, gem_trans1_t,
618  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
619  gemv_interface2(gem_p1_t, gem_trans1_t,
620  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
621  gemv_interface2(gem_p1_t, gem_trans1_t,
622  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
623 
624  // Y <- transposed(const A) scaled(X).
625  gemv_interface2(gem_p1_tc, gem_trans1_t,
626  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
627  gemv_interface2(gem_p1_tc, gem_trans1_t,
628  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
629  gemv_interface2(gem_p1_tc, gem_trans1_t,
630  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
631  gemv_interface2(gem_p1_tc, gem_trans1_t,
632  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
633 
634  // Y <- conjugated(A) scaled(X).
635  gemv_interface2(gem_p1_c, gem_trans1_c,
636  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
637  gemv_interface2(gem_p1_c, gem_trans1_c,
638  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
639  gemv_interface2(gem_p1_c, gem_trans1_c,
640  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
641  gemv_interface2(gem_p1_c, gem_trans1_c,
642  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
643 
644 
645  /* ********************************************************************* */
646  /* Rank one update. */
647  /* ********************************************************************* */
648 
649 # define ger_interface(blas_name, base_type) \
650  inline void rank_one_update(const dense_matrix<base_type> &A, \
651  const std::vector<base_type> &V, \
652  const std::vector<base_type> &W) { \
653  GMMLAPACK_TRACE("ger_interface"); \
654  BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
655  BLAS_INT n(BLAS_INT(mat_ncols(A))); \
656  BLAS_INT incx = 1, incy = 1; \
657  base_type alpha(1); \
658  if (m && n) \
659  blas_name(&m, &n, &alpha, &V[0], &incx, &W[0], &incy, &A(0,0), &lda);\
660  }
661 
662  ger_interface(sger_, BLAS_S)
663  ger_interface(dger_, BLAS_D)
664  ger_interface(cgerc_, BLAS_C)
665  ger_interface(zgerc_, BLAS_Z)
666 
667 # define ger_interface_sn(blas_name, base_type) \
668  inline void rank_one_update(const dense_matrix<base_type> &A, \
669  gemv_p2_s(base_type), \
670  const std::vector<base_type> &W) { \
671  GMMLAPACK_TRACE("ger_interface"); \
672  gemv_trans2_s(base_type); \
673  BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
674  BLAS_INT n(BLAS_INT(mat_ncols(A))); \
675  BLAS_INT incx = 1, incy = 1; \
676  if (m && n) \
677  blas_name(&m, &n, &alpha, &x[0], &incx, &W[0], &incy, &A(0,0), &lda);\
678  }
679 
680  ger_interface_sn(sger_, BLAS_S)
681  ger_interface_sn(dger_, BLAS_D)
682  ger_interface_sn(cgerc_, BLAS_C)
683  ger_interface_sn(zgerc_, BLAS_Z)
684 
685 # define ger_interface_ns(blas_name, base_type) \
686  inline void rank_one_update(const dense_matrix<base_type> &A, \
687  const std::vector<base_type> &V, \
688  gemv_p2_s(base_type)) { \
689  GMMLAPACK_TRACE("ger_interface"); \
690  gemv_trans2_s(base_type); \
691  BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
692  BLAS_INT n(BLAS_INT(mat_ncols(A))); \
693  BLAS_INT incx = 1, incy = 1; \
694  base_type al2 = gmm::conj(alpha); \
695  if (m && n) \
696  blas_name(&m, &n, &al2, &V[0], &incx, &x[0], &incy, &A(0,0), &lda); \
697  }
698 
699  ger_interface_ns(sger_, BLAS_S)
700  ger_interface_ns(dger_, BLAS_D)
701  ger_interface_ns(cgerc_, BLAS_C)
702  ger_interface_ns(zgerc_, BLAS_Z)
703 
704  /* ********************************************************************* */
705  /* dense matrix x dense matrix multiplication. */
706  /* ********************************************************************* */
707 
708 # define gemm_interface_nn(blas_name, base_type) \
709  inline void mult_spec(const dense_matrix<base_type> &A, \
710  const dense_matrix<base_type> &B, \
711  dense_matrix<base_type> &C, c_mult) { \
712  GMMLAPACK_TRACE("gemm_interface_nn"); \
713  const char t = 'N'; \
714  BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
715  BLAS_INT k(BLAS_INT(mat_ncols(A))); \
716  BLAS_INT n(BLAS_INT(mat_ncols(B))); \
717  BLAS_INT ldb = k, ldc = m; \
718  base_type alpha(1), beta(0); \
719  if (m && k && n) \
720  blas_name(&t, &t, &m, &n, &k, &alpha, \
721  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
722  else gmm::clear(C); \
723  }
724 
725  gemm_interface_nn(sgemm_, BLAS_S)
726  gemm_interface_nn(dgemm_, BLAS_D)
727  gemm_interface_nn(cgemm_, BLAS_C)
728  gemm_interface_nn(zgemm_, BLAS_Z)
729 
730  /* ********************************************************************* */
731  /* transposed(dense matrix) x dense matrix multiplication. */
732  /* ********************************************************************* */
733 
734 # define gemm_interface_tn(blas_name, base_type, is_const) \
735  inline void mult_spec( \
736  const transposed_col_ref<is_const<base_type> *> &A_, \
737  const dense_matrix<base_type> &B, \
738  dense_matrix<base_type> &C, rcmult) { \
739  GMMLAPACK_TRACE("gemm_interface_tn"); \
740  const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
741  const char t = 'T', u = 'N'; \
742  BLAS_INT m(BLAS_INT(mat_ncols(A))), k(BLAS_INT(mat_nrows(A))); \
743  BLAS_INT n(BLAS_INT(mat_ncols(B))); \
744  BLAS_INT lda = k, ldb = k, ldc = m; \
745  base_type alpha(1), beta(0); \
746  if (m && k && n) \
747  blas_name(&t, &u, &m, &n, &k, &alpha, \
748  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
749  else gmm::clear(C); \
750  }
751 
752  gemm_interface_tn(sgemm_, BLAS_S, dense_matrix)
753  gemm_interface_tn(dgemm_, BLAS_D, dense_matrix)
754  gemm_interface_tn(cgemm_, BLAS_C, dense_matrix)
755  gemm_interface_tn(zgemm_, BLAS_Z, dense_matrix)
756  gemm_interface_tn(sgemm_, BLAS_S, const dense_matrix)
757  gemm_interface_tn(dgemm_, BLAS_D, const dense_matrix)
758  gemm_interface_tn(cgemm_, BLAS_C, const dense_matrix)
759  gemm_interface_tn(zgemm_, BLAS_Z, const dense_matrix)
760 
761  /* ********************************************************************* */
762  /* dense matrix x transposed(dense matrix) multiplication. */
763  /* ********************************************************************* */
764 
765 # define gemm_interface_nt(blas_name, base_type, is_const) \
766  inline void mult_spec(const dense_matrix<base_type> &A, \
767  const transposed_col_ref<is_const<base_type> *> &B_, \
768  dense_matrix<base_type> &C, r_mult) { \
769  GMMLAPACK_TRACE("gemm_interface_nt"); \
770  const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
771  const char t = 'N', u = 'T'; \
772  BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
773  BLAS_INT k(BLAS_INT(mat_ncols(A))); \
774  BLAS_INT n(BLAS_INT(mat_nrows(B))); \
775  BLAS_INT ldb = n, ldc = m; \
776  base_type alpha(1), beta(0); \
777  if (m && k && n) \
778  blas_name(&t, &u, &m, &n, &k, &alpha, \
779  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
780  else gmm::clear(C); \
781  }
782 
783  gemm_interface_nt(sgemm_, BLAS_S, dense_matrix)
784  gemm_interface_nt(dgemm_, BLAS_D, dense_matrix)
785  gemm_interface_nt(cgemm_, BLAS_C, dense_matrix)
786  gemm_interface_nt(zgemm_, BLAS_Z, dense_matrix)
787  gemm_interface_nt(sgemm_, BLAS_S, const dense_matrix)
788  gemm_interface_nt(dgemm_, BLAS_D, const dense_matrix)
789  gemm_interface_nt(cgemm_, BLAS_C, const dense_matrix)
790  gemm_interface_nt(zgemm_, BLAS_Z, const dense_matrix)
791 
792  /* ********************************************************************* */
793  /* transposed(dense matrix) x transposed(dense matrix) multiplication. */
794  /* ********************************************************************* */
795 
796 # define gemm_interface_tt(blas_name, base_type, isA_const, isB_const) \
797  inline void mult_spec( \
798  const transposed_col_ref<isA_const <base_type> *> &A_, \
799  const transposed_col_ref<isB_const <base_type> *> &B_, \
800  dense_matrix<base_type> &C, r_mult) { \
801  GMMLAPACK_TRACE("gemm_interface_tt"); \
802  const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
803  const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
804  const char t = 'T', u = 'T'; \
805  BLAS_INT m(BLAS_INT(mat_ncols(A))), k(BLAS_INT(mat_nrows(A))); \
806  BLAS_INT n(BLAS_INT(mat_nrows(B))); \
807  BLAS_INT lda = k, ldb = n, ldc = m; \
808  base_type alpha(1), beta(0); \
809  if (m && k && n) \
810  blas_name(&t, &u, &m, &n, &k, &alpha, \
811  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
812  else gmm::clear(C); \
813  }
814 
815  gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, dense_matrix)
816  gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, dense_matrix)
817  gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, dense_matrix)
818  gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, dense_matrix)
819  gemm_interface_tt(sgemm_, BLAS_S, const dense_matrix, dense_matrix)
820  gemm_interface_tt(dgemm_, BLAS_D, const dense_matrix, dense_matrix)
821  gemm_interface_tt(cgemm_, BLAS_C, const dense_matrix, dense_matrix)
822  gemm_interface_tt(zgemm_, BLAS_Z, const dense_matrix, dense_matrix)
823  gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, const dense_matrix)
824  gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, const dense_matrix)
825  gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, const dense_matrix)
826  gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, const dense_matrix)
827  gemm_interface_tt(sgemm_, BLAS_S, const dense_matrix, const dense_matrix)
828  gemm_interface_tt(dgemm_, BLAS_D, const dense_matrix, const dense_matrix)
829  gemm_interface_tt(cgemm_, BLAS_C, const dense_matrix, const dense_matrix)
830  gemm_interface_tt(zgemm_, BLAS_Z, const dense_matrix, const dense_matrix)
831 
832 
833  /* ********************************************************************* */
834  /* conjugated(dense matrix) x dense matrix multiplication. */
835  /* ********************************************************************* */
836 
837 # define gemm_interface_cn(blas_name, base_type) \
838  inline void mult_spec( \
839  const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_, \
840  const dense_matrix<base_type> &B, \
841  dense_matrix<base_type> &C, rcmult) { \
842  GMMLAPACK_TRACE("gemm_interface_cn"); \
843  const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
844  const char t = 'C', u = 'N'; \
845  BLAS_INT m(BLAS_INT(mat_ncols(A))), k(BLAS_INT(mat_nrows(A))); \
846  BLAS_INT n(BLAS_INT(mat_ncols(B))); \
847  BLAS_INT lda = k, ldb = k, ldc = m; \
848  base_type alpha(1), beta(0); \
849  if (m && k && n) \
850  blas_name(&t, &u, &m, &n, &k, &alpha, \
851  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
852  else gmm::clear(C); \
853  }
854 
855  gemm_interface_cn(sgemm_, BLAS_S)
856  gemm_interface_cn(dgemm_, BLAS_D)
857  gemm_interface_cn(cgemm_, BLAS_C)
858  gemm_interface_cn(zgemm_, BLAS_Z)
859 
860  /* ********************************************************************* */
861  /* dense matrix x conjugated(dense matrix) multiplication. */
862  /* ********************************************************************* */
863 
864 # define gemm_interface_nc(blas_name, base_type) \
865  inline void mult_spec(const dense_matrix<base_type> &A, \
866  const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &B_, \
867  dense_matrix<base_type> &C, c_mult, row_major) { \
868  GMMLAPACK_TRACE("gemm_interface_nc"); \
869  const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
870  const char t = 'N', u = 'C'; \
871  BLAS_INT m(BLAS_INT(mat_nrows(A))), lda = m; \
872  BLAS_INT k(BLAS_INT(mat_ncols(A))); \
873  BLAS_INT n(BLAS_INT(mat_nrows(B))), ldb = n, ldc = m; \
874  base_type alpha(1), beta(0); \
875  if (m && k && n) \
876  blas_name(&t, &u, &m, &n, &k, &alpha, \
877  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
878  else gmm::clear(C); \
879  }
880 
881  gemm_interface_nc(sgemm_, BLAS_S)
882  gemm_interface_nc(dgemm_, BLAS_D)
883  gemm_interface_nc(cgemm_, BLAS_C)
884  gemm_interface_nc(zgemm_, BLAS_Z)
885 
886  /* ********************************************************************* */
887  /* conjugated(dense matrix) x conjugated(dense matrix) multiplication. */
888  /* ********************************************************************* */
889 
890 # define gemm_interface_cc(blas_name, base_type) \
891  inline void mult_spec( \
892  const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_, \
893  const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &B_, \
894  dense_matrix<base_type> &C, r_mult) { \
895  GMMLAPACK_TRACE("gemm_interface_cc"); \
896  const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
897  const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
898  const char t = 'C', u = 'C'; \
899  BLAS_INT m(BLAS_INT(mat_ncols(A))), k(BLAS_INT(mat_nrows(A))); \
900  BLAS_INT lda = k, n(BLAS_INT(mat_nrows(B))), ldb = n, ldc = m; \
901  base_type alpha(1), beta(0); \
902  if (m && k && n) \
903  blas_name(&t, &u, &m, &n, &k, &alpha, \
904  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
905  else gmm::clear(C); \
906  }
907 
908  gemm_interface_cc(sgemm_, BLAS_S)
909  gemm_interface_cc(dgemm_, BLAS_D)
910  gemm_interface_cc(cgemm_, BLAS_C)
911  gemm_interface_cc(zgemm_, BLAS_Z)
912 
913  /* ********************************************************************* */
914  /* Tri solve. */
915  /* ********************************************************************* */
916 
917 # define trsv_interface(f_name, loru, param1, trans1, blas_name, base_type)\
918  inline void f_name(param1(base_type), std::vector<base_type> &x, \
919  size_type k, bool is_unit) { \
920  GMMLAPACK_TRACE("trsv_interface"); \
921  loru; trans1(base_type); char d = is_unit ? 'U' : 'N'; \
922  BLAS_INT lda(BLAS_INT(mat_nrows(A))), inc(1), n = BLAS_INT(k); \
923  if (lda) blas_name(&l, &t, &d, &n, &A(0,0), &lda, &x[0], &inc); \
924  }
925 
926 # define trsv_upper const char l = 'U'
927 # define trsv_lower const char l = 'L'
928 
929  // X <- LOWER(A)^{-1}X.
930  trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
931  strsv_, BLAS_S)
932  trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
933  dtrsv_, BLAS_D)
934  trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
935  ctrsv_, BLAS_C)
936  trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
937  ztrsv_, BLAS_Z)
938 
939  // X <- UPPER(A)^{-1}X.
940  trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
941  strsv_, BLAS_S)
942  trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
943  dtrsv_, BLAS_D)
944  trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
945  ctrsv_, BLAS_C)
946  trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
947  ztrsv_, BLAS_Z)
948 
949  // X <- LOWER(transposed(A))^{-1}X.
950  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
951  strsv_, BLAS_S)
952  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
953  dtrsv_, BLAS_D)
954  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
955  ctrsv_, BLAS_C)
956  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
957  ztrsv_, BLAS_Z)
958 
959  // X <- UPPER(transposed(A))^{-1}X.
960  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
961  strsv_, BLAS_S)
962  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
963  dtrsv_, BLAS_D)
964  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
965  ctrsv_, BLAS_C)
966  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
967  ztrsv_, BLAS_Z)
968 
969  // X <- LOWER(transposed(const A))^{-1}X.
970  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
971  strsv_, BLAS_S)
972  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
973  dtrsv_, BLAS_D)
974  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
975  ctrsv_, BLAS_C)
976  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
977  ztrsv_, BLAS_Z)
978 
979  // X <- UPPER(transposed(const A))^{-1}X.
980  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
981  strsv_, BLAS_S)
982  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
983  dtrsv_, BLAS_D)
984  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
985  ctrsv_, BLAS_C)
986  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
987  ztrsv_, BLAS_Z)
988 
989  // X <- LOWER(conjugated(A))^{-1}X.
990  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
991  strsv_, BLAS_S)
992  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
993  dtrsv_, BLAS_D)
994  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
995  ctrsv_, BLAS_C)
996  trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
997  ztrsv_, BLAS_Z)
998 
999  // X <- UPPER(conjugated(A))^{-1}X.
1000  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
1001  strsv_, BLAS_S)
1002  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
1003  dtrsv_, BLAS_D)
1004  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
1005  ctrsv_, BLAS_C)
1006  trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
1007  ztrsv_, BLAS_Z)
1008 
1009 }
1010 
1011 #endif // GMM_BLAS_INTERFACE_H
1012 
1013 #endif // GMM_USES_BLAS
Basic linear algebra functions.
gmm interface for STL vectors.
Declaration of some matrix types (gmm::dense_matrix, gmm::row_matrix, gmm::col_matrix,...
size_t size_type
used as the common size type in the library
Definition: bgeot_poly.h:49