| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943 |
- /*******************************************************************************
- * Copyright 2014-2022 Intel Corporation.
- *
- * This software and the related documents are Intel copyrighted materials, and
- * your use of them is governed by the express license under which they were
- * provided to you (License). Unless the License provides otherwise, you may not
- * use, modify, copy, publish, distribute, disclose or transmit this software or
- * the related documents without Intel's prior written permission.
- *
- * This software and the related documents are provided as is, with no express
- * or implied warranties, other than those that are expressly stated in the
- * License.
- *******************************************************************************/
- /*
- ! Content:
- ! Intel(R) oneAPI Math Kernel Library (oneMKL) C functions that can be inlined
- !******************************************************************************/
- #include "mkl_types.h"
- #include "immintrin.h"
/*
 * Bind the type-generic routine names used in the rest of this file to the
 * precision-specific names.  The active group is chosen by whichever of
 * MKL_DOUBLE / MKL_SINGLE / MKL_COMPLEX / MKL_COMPLEX16 is defined; when none
 * is defined the generic names are left unaliased.
 *
 * Any earlier gemm/syrk/trsm aliases are dropped before new ones are
 * installed.  The dot-product aliases exist for the two real types only.
 */
#undef mkl_dc_gemm
#undef mkl_dc_syrk
#undef mkl_dc_trsm

#if defined(MKL_DOUBLE)
#  define mkl_dc_gemm        mkl_dc_dgemm
#  define mkl_dc_syrk        mkl_dc_dsyrk
#  define mkl_dc_trsm        mkl_dc_dtrsm
#  define mkl_dc_axpy        mkl_dc_daxpy
#  define mkl_dc_dot         mkl_dc_ddot
#  define MKL_DC_DOT_CONVERT mkl_dc_ddot_convert
#elif defined(MKL_SINGLE)
#  define mkl_dc_gemm        mkl_dc_sgemm
#  define mkl_dc_syrk        mkl_dc_ssyrk
#  define mkl_dc_trsm        mkl_dc_strsm
#  define mkl_dc_axpy        mkl_dc_saxpy
#  define mkl_dc_dot         mkl_dc_sdot
#  define MKL_DC_DOT_CONVERT mkl_dc_sdot_convert
#elif defined(MKL_COMPLEX)
#  define mkl_dc_gemm        mkl_dc_cgemm
#  define mkl_dc_syrk        mkl_dc_csyrk
#  define mkl_dc_trsm        mkl_dc_ctrsm
#  define mkl_dc_axpy        mkl_dc_caxpy
#elif defined(MKL_COMPLEX16)
#  define mkl_dc_gemm        mkl_dc_zgemm
#  define mkl_dc_syrk        mkl_dc_zsyrk
#  define mkl_dc_trsm        mkl_dc_ztrsm
#  define mkl_dc_axpy        mkl_dc_zaxpy
#endif
/*
 * Portable GEMM loop nest: for every (i, j) of the m x n matrix C,
 *
 *     C(i,j) := c_op(C(i,j), beta) + alpha * sum_l a_op(A[i,l]) * b_op(B[l,j])
 *
 * a_access / b_access fetch one element of A / B from (matrix, ld, row, col),
 * a_op / b_op are applied to each fetched element in place, and c_op is
 * applied to the old C entry together with beta before the product is added.
 * C is addressed as C[i + j * ldc].
 */
#define mkl_dc_gemm_xx_mnk_pst_beta(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, \
                                    a_op, b_op, c_op, a_access, b_access) \
    do { \
        MKL_INT i, j, l; \
        MKL_DC_PRAGMA_VECTOR \
        for (i = 0; i < m; i++) { \
            for (j = 0; j < n; j++) { \
                mkl_dc_type c, temp; \
                MKL_DC_SET_ZERO(temp); \
                /* temp accumulates the length-k inner product */ \
                for (l = 0; l < k; l++) { \
                    mkl_dc_type a, b, temp1; \
                    a = a_access(A, lda, i, l); \
                    a_op(a, a); \
                    b = b_access(B, ldb, l, j); \
                    b_op(b, b); \
                    MKL_DC_MUL(temp1, a, b); \
                    MKL_DC_ADD(temp, temp, temp1); \
                } \
                MKL_DC_MUL(temp, alpha, temp); \
                c = C[i + j * ldc]; \
                c_op(c, beta); \
                MKL_DC_ADD(C[i + j * ldc], c, temp); \
            } \
        } \
    } while (0)
/*
 * Front end for the portable GEMM loop nest: picks how the old contents of C
 * are folded in.  When beta tests as zero the C-update operator is
 * MKL_DC_ZERO_C, otherwise it is MKL_DC_MUL_C.  All other arguments are
 * forwarded unchanged to mkl_dc_gemm_xx_mnk_pst_beta.
 */
#define mkl_dc_gemm_xx_mnk_pst(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, \
                               a_op, b_op, a_access, b_access) \
    do { \
        if (MKL_DC_IS_ZERO(beta)) { \
            mkl_dc_gemm_xx_mnk_pst_beta(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, \
                                        a_op, b_op, MKL_DC_ZERO_C, a_access, b_access); \
        } else { \
            mkl_dc_gemm_xx_mnk_pst_beta(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, \
                                        a_op, b_op, MKL_DC_MUL_C, a_access, b_access); \
        } \
    } while (0)
/*
 * Double-precision vector vocabulary, available only when both __AVX2__ and
 * MKL_DOUBLE are defined.  Each MKL_DC_*_YMM name maps to a 256-bit
 * (_mm256_*_pd) intrinsic, each MKL_DC_*_XMM name to a 128-bit (_mm_*_pd)
 * intrinsic, and the *_XMM_S names to the scalar (_sd) forms.
 *
 * NOTE(review): the multiply-add wrappers use the FMA intrinsics
 * (_mm*_fmadd_*) while the guard only tests __AVX2__ -- confirm that every
 * supported build that enables AVX2 also enables FMA.
 */
#if defined(__AVX2__) && defined(MKL_DOUBLE)

/* Register types. */
typedef __m128d MKL_DC_XMMTYPE;
typedef __m256d MKL_DC_YMMTYPE;

/*
 * Fused multiply-add: c := a * b + c.  The fourth (tmp) argument is accepted
 * but not used, and each expansion already ends in a semicolon.
 */
#define MKL_DC_MUL_ADD_YMM(ymm_a, ymm_b, ymm_c, ymm_tmp) \
    ymm_c = _mm256_fmadd_pd(ymm_a, ymm_b, ymm_c);
#define MKL_DC_MUL_ADD_XMM(xmm_a, xmm_b, xmm_c, xmm_tmp) \
    xmm_c = _mm_fmadd_pd(xmm_a, xmm_b, xmm_c);
#define MKL_DC_MUL_ADD_XMM_S(xmm_a, xmm_b, xmm_c, xmm_tmp) \
    xmm_c = _mm_fmadd_sd(xmm_a, xmm_b, xmm_c);

/* 256-bit operations. */
#define MKL_DC_XOR_YMM         _mm256_xor_pd
#define MKL_DC_SETZERO_YMM     _mm256_setzero_pd
#define MKL_DC_BCAST_YMM       _mm256_broadcast_sd
#define MKL_DC_LOAD_YMM        _mm256_loadu_pd
#define MKL_DC_STORE_YMM       _mm256_storeu_pd
#define MKL_DC_ADD_YMM         _mm256_add_pd
#define MKL_DC_MUL_YMM         _mm256_mul_pd
#define MKL_DC_MASKLOAD_YMM    _mm256_maskload_pd
#define MKL_DC_MASKSTORE_YMM   _mm256_maskstore_pd
#define MKL_DC_HADD_YMM        _mm256_hadd_pd
#define MKL_DC_PERM2F128_YMM   _mm256_permute2f128_pd
#define MKL_DC_UNPACKHI_YMM    _mm256_unpackhi_pd
#define MKL_DC_UNPACKLO_YMM    _mm256_unpacklo_pd

/* 256 <-> 128 bit register casts. */
#define MKL_DC_CAST_YMM_TO_XMM _mm256_castpd256_pd128
#define MKL_DC_CAST_XMM_TO_YMM _mm256_castpd128_pd256

/* 128-bit packed operations. */
#define MKL_DC_XOR_XMM         _mm_xor_pd
#define MKL_DC_SETZERO_XMM     _mm_setzero_pd
#define MKL_DC_LOAD_XMM        _mm_loadu_pd
#define MKL_DC_STORE_XMM       _mm_storeu_pd
#define MKL_DC_ADD_XMM         _mm_add_pd
#define MKL_DC_MUL_XMM         _mm_mul_pd
#define MKL_DC_LOADDUP_XMM     _mm_loaddup_pd
#define MKL_DC_UNPACKHI_XMM    _mm_unpackhi_pd
#define MKL_DC_UNPACKLO_XMM    _mm_unpacklo_pd

/* 128-bit scalar (low element) operations. */
#define MKL_DC_LOAD_XMM_S      _mm_load_sd
#define MKL_DC_STORE_XMM_S     _mm_store_sd
#define MKL_DC_ADD_XMM_S       _mm_add_sd
#define MKL_DC_MUL_XMM_S       _mm_mul_sd

/*
 * In-place transpose of the 4x4 block of doubles held in x1..x4, using
 * tmp1..tmp4 as scratch: pairs are interleaved within each 128-bit lane,
 * then the lanes are recombined (0x20 = both low lanes, 0x31 = both high).
 */
#define MKL_DC_VEC_TRANSPOSE_YMM(x1, x2, x3, x4, tmp1, tmp2, tmp3, tmp4) \
    do { \
        tmp1 = _mm256_unpacklo_pd(x1, x2); \
        tmp2 = _mm256_unpackhi_pd(x1, x2); \
        tmp3 = _mm256_unpacklo_pd(x3, x4); \
        tmp4 = _mm256_unpackhi_pd(x3, x4); \
        x1 = _mm256_permute2f128_pd(tmp1, tmp3, 0x20); \
        x2 = _mm256_permute2f128_pd(tmp2, tmp4, 0x20); \
        x3 = _mm256_permute2f128_pd(tmp1, tmp3, 0x31); \
        x4 = _mm256_permute2f128_pd(tmp2, tmp4, 0x31); \
    } while (0)

/* 2x2 transpose: out1 gets the low elements of in1/in2, out2 the high ones. */
#define MKL_DC_VEC_TRANSPOSE_XMM(out1, out2, in1, in2) \
    do { \
        out1 = _mm_unpacklo_pd(in1, in2); \
        out2 = _mm_unpackhi_pd(in1, in2); \
    } while (0)

#endif /* __AVX2__ && MKL_DOUBLE */
/*
 * Call the double-precision GEMM kernel matching the transpose flags found in
 * the caller's scope.  The callee name is assembled by token pasting:
 *
 *     mkl_dc_dgemm_<nn|nt|tn|tt>_mnk_<kernel_type>_<arch>_pst
 *
 * where the first letter is 'n' when AisN is true and the second is 'n' when
 * BisN is true.  The macro reads AisN, BisN, m, n, k, ALPHA, A, lda, B, ldb,
 * BETA, C and ldc from the expansion site by name.
 */
#define MKL_DC_DGEMM_KERNELS(kernel_type, arch) \
    do { \
        if (AisN) { \
            if (BisN) { \
                mkl_dc_dgemm_nn_mnk_ ## kernel_type ## _ ## arch ## _pst(m, n, k, ALPHA, A, lda, B, ldb, BETA, C, ldc); \
            } else { \
                mkl_dc_dgemm_nt_mnk_ ## kernel_type ## _ ## arch ## _pst(m, n, k, ALPHA, A, lda, B, ldb, BETA, C, ldc); \
            } \
        } else if (BisN) { \
            mkl_dc_dgemm_tn_mnk_ ## kernel_type ## _ ## arch ## _pst(m, n, k, ALPHA, A, lda, B, ldb, BETA, C, ldc); \
        } else { \
            mkl_dc_dgemm_tt_mnk_ ## kernel_type ## _ ## arch ## _pst(m, n, k, ALPHA, A, lda, B, ldb, BETA, C, ldc); \
        } \
    } while (0)
- #include "mkl_direct_blas_kernels.h"
- #define MKL_DC_ALPHA_ONE
- #include "mkl_direct_blas_kernels.h"
- #undef MKL_DC_ALPHA_ONE
- #define MKL_DC_BETA_ZERO
- #include "mkl_direct_blas_kernels.h"
- #undef MKL_DC_BETA_ZERO
- #define MKL_DC_BETA_ONE
- #include "mkl_direct_blas_kernels.h"
- #undef MKL_DC_BETA_ONE
- #define MKL_DC_BETA_ONE
- #define MKL_DC_ALPHA_ONE
- #include "mkl_direct_blas_kernels.h"
- #undef MKL_DC_BETA_ONE
- #undef MKL_DC_ALPHA_ONE
- #define MKL_DC_BETA_ZERO
- #define MKL_DC_ALPHA_ONE
- #include "mkl_direct_blas_kernels.h"
- #undef MKL_DC_BETA_ZERO
- #undef MKL_DC_ALPHA_ONE
- static __inline void mkl_dc_gemm(const char * TRANSA, const char * TRANSB,
- const MKL_INT * M, const MKL_INT * N, const MKL_INT * K,
- const mkl_dc_type * ALPHA,
- const mkl_dc_type * A, const MKL_INT * LDA,
- const mkl_dc_type * B, const MKL_INT * LDB,
- const mkl_dc_type * BETA,
- mkl_dc_type * C, const MKL_INT * LDC)
- {
- int AisN, BisN;
- #ifndef MKL_REAL_DATA_TYPE
- int AisT, AisC;
- int BisT, BisC;
- #endif
- mkl_dc_type alpha = *ALPHA, beta = *BETA;
- MKL_INT m = *M, n = *N, k = *K;
- MKL_INT lda = *LDA, ldb = *LDB, ldc = *LDC;
- if (m <= 0 || n <= 0 || ((MKL_DC_IS_ZERO(alpha) || k <= 0) && MKL_DC_IS_ONE(beta)))
- return;
- AisN = MKL_DC_MisN(*TRANSA);
- BisN = MKL_DC_MisN(*TRANSB);
- #ifndef MKL_REAL_DATA_TYPE
- AisT = MKL_DC_MisT(*TRANSA);
- BisT = MKL_DC_MisT(*TRANSB);
- AisC = !(AisN || AisT);
- BisC = !(BisN || BisT);
- #endif
- if (MKL_DC_IS_ZERO(alpha)) {
- MKL_INT i, j;
- if (MKL_DC_IS_ZERO(beta))
- for (j = 0; j < n; j++)
- #pragma vector
- for (i = 0; i < m; i++)
- MKL_DC_SET_ZERO(C[i + ldc * j]);
- else
- for (j = 0; j < n; j++)
- #pragma vector
- for (i = 0; i < m; i++)
- MKL_DC_MUL(C[i + ldc * j], beta, C[i + ldc * j]);
- return;
- }
- #if defined(MKL_DOUBLE) && defined(__AVX2__)
- if (MKL_DC_IS_ONE(alpha) && MKL_DC_IS_ONE(beta)) {
- MKL_DC_DGEMM_KERNELS(a1b1, avx2);
- } else if (MKL_DC_IS_ONE(alpha) && MKL_DC_IS_ZERO(beta)) {
- MKL_DC_DGEMM_KERNELS(a1b0, avx2);
- } else if (MKL_DC_IS_ZERO(beta)) {
- MKL_DC_DGEMM_KERNELS(axb0, avx2);
- } else if (MKL_DC_IS_ONE(beta)) {
- MKL_DC_DGEMM_KERNELS(axb1, avx2);
- } else if (MKL_DC_IS_ONE(alpha)) {
- MKL_DC_DGEMM_KERNELS(a1bx, avx2);
- } else {
- MKL_DC_DGEMM_KERNELS(axbx, avx2);
- }
- #else
- if (AisN && BisN)
- mkl_dc_gemm_xx_mnk_pst(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, MKL_DC_MOV, MKL_DC_MOV, MKL_DC_MN, MKL_DC_MN);
- #ifndef MKL_REAL_DATA_TYPE
- else if (AisC && BisN)
- mkl_dc_gemm_xx_mnk_pst(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, MKL_DC_CONJ, MKL_DC_MOV, MKL_DC_MT, MKL_DC_MN);
- else if (AisC && BisT)
- mkl_dc_gemm_xx_mnk_pst(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, MKL_DC_CONJ, MKL_DC_MOV, MKL_DC_MT, MKL_DC_MT);
- else if (AisN && BisC)
- mkl_dc_gemm_xx_mnk_pst(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, MKL_DC_MOV, MKL_DC_CONJ, MKL_DC_MN, MKL_DC_MT);
- else if (AisT && BisC)
- mkl_dc_gemm_xx_mnk_pst(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, MKL_DC_MOV, MKL_DC_CONJ, MKL_DC_MT, MKL_DC_MT);
- else if (AisC && BisC)
- mkl_dc_gemm_xx_mnk_pst(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, MKL_DC_CONJ, MKL_DC_CONJ, MKL_DC_MT, MKL_DC_MT);
- #endif
- else if (AisN && !BisN)
- mkl_dc_gemm_xx_mnk_pst(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, MKL_DC_MOV, MKL_DC_MOV, MKL_DC_MN, MKL_DC_MT);
- else if (!AisN && BisN)
- mkl_dc_gemm_xx_mnk_pst(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, MKL_DC_MOV, MKL_DC_MOV, MKL_DC_MT, MKL_DC_MN);
- else
- mkl_dc_gemm_xx_mnk_pst(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, MKL_DC_MOV, MKL_DC_MOV, MKL_DC_MT, MKL_DC_MT);
- #endif
- }
- /* ?TRSM */
/*
 * Left side, A not transposed: overwrite the m x n matrix B with the
 * solution X of  A * X = alpha * B, one column at a time.
 *
 * Each column is first scaled by alpha (skipped when alpha tests as one).
 * An upper-triangular A is then eliminated from the last row upwards, a
 * lower-triangular A from the first row downwards.  diag_op(b, d) is applied
 * to every pivot entry b with the matching diagonal element d of A before
 * that entry is eliminated from the remaining rows.
 */
#define mkl_dc_trsm_lxnx_mn_pst(uplo, m, n, alpha, A, lda, B, ldb, diag_op) \
    do { \
        MKL_INT i, j, k; \
        if (MKL_DC_MisU(uplo)) { \
            for (j = 0; j < n; j++) { \
                if (!(MKL_DC_IS_ONE(alpha))) { \
                    for (i = 0; i < m; i++) { \
                        MKL_DC_MUL_C(B[i + j * ldb], alpha); \
                    } \
                } \
                /* back substitution: rows m-1 .. 0 */ \
                for (k = m - 1; k >= 0; k--) { \
                    diag_op(B[k + j * ldb], A[k + k * lda]); \
                    MKL_DC_PRAGMA_VECTOR \
                    for (i = 0; i < k; i++) { \
                        mkl_dc_type a, temp1; \
                        a = A[i + k * lda]; \
                        MKL_DC_MUL(temp1, B[k + j * ldb], a); \
                        MKL_DC_SUB(B[i + j * ldb], B[i + j * ldb], temp1); \
                    } \
                } \
            } \
        } else { \
            for (j = 0; j < n; j++) { \
                if (!(MKL_DC_IS_ONE(alpha))) { \
                    for (i = 0; i < m; i++) { \
                        MKL_DC_MUL_C(B[i + j * ldb], alpha); \
                    } \
                } \
                /* forward substitution: rows 0 .. m-1 */ \
                for (k = 0; k < m; k++) { \
                    diag_op(B[k + j * ldb], A[k + k * lda]); \
                    MKL_DC_PRAGMA_VECTOR \
                    for (i = k + 1; i < m; i++) { \
                        mkl_dc_type a, temp1; \
                        a = A[i + k * lda]; \
                        MKL_DC_MUL(temp1, B[k + j * ldb], a); \
                        MKL_DC_SUB(B[i + j * ldb], B[i + j * ldb], temp1); \
                    } \
                } \
            } \
        } \
    } while (0)
/*
 * Left side, A transposed: overwrite the m x n matrix B with the solution X
 * of  A**T * X = alpha * B, one column at a time.
 *
 * Each entry is built in a scalar: start from alpha * B(i,j), subtract the
 * contributions of the entries of that column that are already solved
 * (read from A as A[k + i * lda]), apply diag_op with the diagonal element
 * A(i,i), and store the result back.  Upper-triangular A runs i upwards
 * from 0, lower-triangular A runs i downwards from m-1.
 */
#define mkl_dc_trsm_lxtx_mn_pst(uplo, m, n, alpha, A, lda, B, ldb, diag_op) \
    do { \
        MKL_INT i, j, k; \
        if (MKL_DC_MisU(uplo)) { \
            for (j = 0; j < n; j++) { \
                mkl_dc_type temp; \
                for (i = 0; i < m; i++) { \
                    MKL_DC_MUL(temp, alpha, B[i + j * ldb]); \
                    MKL_DC_PRAGMA_VECTOR \
                    for (k = 0; k < i; k++) { \
                        mkl_dc_type a, temp1; \
                        a = A[k + i * lda]; \
                        MKL_DC_MUL(temp1, B[k + j * ldb], a); \
                        MKL_DC_SUB(temp, temp, temp1); \
                    } \
                    diag_op(temp, A[i + i * lda]); \
                    MKL_DC_MOV(B[i + j * ldb], temp); \
                } \
            } \
        } else { \
            for (j = 0; j < n; j++) { \
                mkl_dc_type temp; \
                for (i = m - 1; i >= 0; i--) { \
                    MKL_DC_MUL(temp, alpha, B[i + j * ldb]); \
                    MKL_DC_PRAGMA_VECTOR \
                    for (k = i + 1; k < m; k++) { \
                        mkl_dc_type a, temp1; \
                        a = A[k + i * lda]; \
                        MKL_DC_MUL(temp1, B[k + j * ldb], a); \
                        MKL_DC_SUB(temp, temp, temp1); \
                    } \
                    diag_op(temp, A[i + i * lda]); \
                    MKL_DC_MOV(B[i + j * ldb], temp); \
                } \
            } \
        } \
    } while (0)
/*
 * Left side, A conjugate-transposed: same recurrence as the transposed
 * left-side solve, except that every element read from A -- the off-diagonal
 * entries A[k + i * lda] and the diagonal entry handed to diag_op -- is first
 * passed through MKL_DC_CONJ.  B (m x n) is overwritten column by column.
 */
#define mkl_dc_trsm_lxcx_mn_pst(uplo, m, n, alpha, A, lda, B, ldb, diag_op) \
    do { \
        MKL_INT i, j, k; \
        if (MKL_DC_MisU(uplo)) { \
            for (j = 0; j < n; j++) { \
                mkl_dc_type temp; \
                for (i = 0; i < m; i++) { \
                    mkl_dc_type a; \
                    MKL_DC_MUL(temp, alpha, B[i + j * ldb]); \
                    MKL_DC_PRAGMA_VECTOR \
                    for (k = 0; k < i; k++) { \
                        mkl_dc_type a1, temp1; \
                        a1 = A[k + i * lda]; \
                        MKL_DC_CONJ(a1, a1); \
                        MKL_DC_MUL(temp1, B[k + j * ldb], a1); \
                        MKL_DC_SUB(temp, temp, temp1); \
                    } \
                    /* conjugated diagonal element for diag_op */ \
                    a = A[i + i * lda]; \
                    MKL_DC_CONJ(a, a); \
                    diag_op(temp, a); \
                    MKL_DC_MOV(B[i + j * ldb], temp); \
                } \
            } \
        } else { \
            for (j = 0; j < n; j++) { \
                mkl_dc_type temp; \
                for (i = m - 1; i >= 0; i--) { \
                    mkl_dc_type a; \
                    MKL_DC_MUL(temp, alpha, B[i + j * ldb]); \
                    MKL_DC_PRAGMA_VECTOR \
                    for (k = i + 1; k < m; k++) { \
                        mkl_dc_type a1, temp1; \
                        a1 = A[k + i * lda]; \
                        MKL_DC_CONJ(a1, a1); \
                        MKL_DC_MUL(temp1, B[k + j * ldb], a1); \
                        MKL_DC_SUB(temp, temp, temp1); \
                    } \
                    a = A[i + i * lda]; \
                    MKL_DC_CONJ(a, a); \
                    diag_op(temp, a); \
                    MKL_DC_MOV(B[i + j * ldb], temp); \
                } \
            } \
        } \
    } while (0)
/*
 * Right side, A not transposed, non-unit diagonal: overwrite the m x n
 * matrix B with the solution X of  X * A = alpha * B.
 *
 * Columns of B are finished one after another (left to right for an upper
 * triangular A, right to left for a lower triangular A).  A column is scaled
 * by alpha (skipped when alpha tests as one), the already finished columns
 * are eliminated using A[k + j * lda], and the column is finally multiplied
 * by the reciprocal of the diagonal element A(j,j).
 */
#define mkl_dc_trsm_rxnn_mn_pst(uplo, m, n, alpha, A, lda, B, ldb) \
    do { \
        MKL_INT i, j, k; \
        if (MKL_DC_MisU(uplo)) { \
            for (j = 0; j < n; j++) { \
                mkl_dc_type temp, one; \
                if (!(MKL_DC_IS_ONE(alpha))) { \
                    for (i = 0; i < m; i++) { \
                        MKL_DC_MUL_C(B[i + j * ldb], alpha); \
                    } \
                } \
                for (k = 0; k < j; k++) { \
                    MKL_DC_PRAGMA_VECTOR \
                    for (i = 0; i < m; i++) { \
                        mkl_dc_type a, temp1; \
                        a = A[k + j * lda]; \
                        MKL_DC_MUL(temp1, B[i + k * ldb], a); \
                        MKL_DC_SUB(B[i + j * ldb], B[i + j * ldb], temp1); \
                    } \
                } \
                /* temp := 1 / A(j,j) */ \
                MKL_DC_SET_ONE(one); \
                MKL_DC_DIV(temp, one, A[j + j * lda]); \
                for (i = 0; i < m; i++) { \
                    MKL_DC_MUL(B[i + j * ldb], temp, B[i + j * ldb]); \
                } \
            } \
        } else { \
            for (j = n - 1; j >= 0; j--) { \
                mkl_dc_type temp, one; \
                if (!(MKL_DC_IS_ONE(alpha))) { \
                    for (i = 0; i < m; i++) { \
                        MKL_DC_MUL_C(B[i + j * ldb], alpha); \
                    } \
                } \
                for (k = j + 1; k < n; k++) { \
                    MKL_DC_PRAGMA_VECTOR \
                    for (i = 0; i < m; i++) { \
                        mkl_dc_type a, temp1; \
                        a = A[k + j * lda]; \
                        MKL_DC_MUL(temp1, B[i + k * ldb], a); \
                        MKL_DC_SUB(B[i + j * ldb], B[i + j * ldb], temp1); \
                    } \
                } \
                MKL_DC_SET_ONE(one); \
                MKL_DC_DIV(temp, one, A[j + j * lda]); \
                for (i = 0; i < m; i++) { \
                    MKL_DC_MUL(B[i + j * ldb], temp, B[i + j * ldb]); \
                } \
            } \
        } \
    } while (0)
/*
 * Right side, A not transposed, unit diagonal: overwrite the m x n matrix B
 * with the solution X of  X * A = alpha * B.  Identical to the non-unit
 * right-side solve except that the diagonal of A is never read and no
 * final per-column scaling takes place.
 */
#define mkl_dc_trsm_rxnu_mn_pst(uplo, m, n, alpha, A, lda, B, ldb) \
    do { \
        MKL_INT i, j, k; \
        if (MKL_DC_MisU(uplo)) { \
            for (j = 0; j < n; j++) { \
                if (!(MKL_DC_IS_ONE(alpha))) { \
                    for (i = 0; i < m; i++) { \
                        MKL_DC_MUL_C(B[i + j * ldb], alpha); \
                    } \
                } \
                /* eliminate the columns already finished (0 .. j-1) */ \
                for (k = 0; k < j; k++) { \
                    MKL_DC_PRAGMA_VECTOR \
                    for (i = 0; i < m; i++) { \
                        mkl_dc_type a, temp1; \
                        a = A[k + j * lda]; \
                        MKL_DC_MUL(temp1, B[i + k * ldb], a); \
                        MKL_DC_SUB(B[i + j * ldb], B[i + j * ldb], temp1); \
                    } \
                } \
            } \
        } else { \
            for (j = n - 1; j >= 0; j--) { \
                if (!(MKL_DC_IS_ONE(alpha))) { \
                    for (i = 0; i < m; i++) { \
                        MKL_DC_MUL_C(B[i + j * ldb], alpha); \
                    } \
                } \
                /* eliminate the columns already finished (j+1 .. n-1) */ \
                for (k = j + 1; k < n; k++) { \
                    MKL_DC_PRAGMA_VECTOR \
                    for (i = 0; i < m; i++) { \
                        mkl_dc_type a, temp1; \
                        a = A[k + j * lda]; \
                        MKL_DC_MUL(temp1, B[i + k * ldb], a); \
                        MKL_DC_SUB(B[i + j * ldb], B[i + j * ldb], temp1); \
                    } \
                } \
            } \
        } \
    } while (0)
/*
 * Right side, A transposed, non-unit diagonal: overwrite the m x n matrix B
 * with the solution X of  X * A**T = alpha * B.
 *
 * Column k of B is divided by A(k,k) (as a multiplication by the
 * reciprocal), then its multiples A(j,k) * B(:,k) are subtracted from the
 * columns j that are still pending, and only afterwards is column k scaled
 * by alpha (skipped when alpha tests as one).  k runs downwards for an
 * upper-triangular A and upwards for a lower-triangular A.
 */
#define mkl_dc_trsm_rxtn_mn_pst(uplo, m, n, alpha, A, lda, B, ldb) \
    do { \
        MKL_INT i, j, k; \
        if (MKL_DC_MisU(uplo)) { \
            for (k = n - 1; k >= 0; k--) { \
                mkl_dc_type temp, one; \
                MKL_DC_SET_ONE(one); \
                MKL_DC_DIV(temp, one, A[k + k * lda]); \
                for (i = 0; i < m; i++) { \
                    MKL_DC_MUL(B[i + k * ldb], temp, B[i + k * ldb]); \
                } \
                for (j = 0; j < k; j++) { \
                    /* temp is reused for the off-diagonal factor A(j,k) */ \
                    MKL_DC_MOV(temp, A[j + k * lda]); \
                    MKL_DC_PRAGMA_VECTOR \
                    for (i = 0; i < m; i++) { \
                        mkl_dc_type b, temp1; \
                        b = B[i + k * ldb]; \
                        MKL_DC_MUL(temp1, temp, b); \
                        MKL_DC_SUB(B[i + j * ldb], B[i + j * ldb], temp1); \
                    } \
                } \
                if (!(MKL_DC_IS_ONE(alpha))) { \
                    for (i = 0; i < m; i++) { \
                        MKL_DC_MUL(B[i + k * ldb], alpha, B[i + k * ldb]); \
                    } \
                } \
            } \
        } else { \
            for (k = 0; k < n; k++) { \
                mkl_dc_type temp, one; \
                MKL_DC_SET_ONE(one); \
                MKL_DC_DIV(temp, one, A[k + k * lda]); \
                for (i = 0; i < m; i++) { \
                    MKL_DC_MUL(B[i + k * ldb], temp, B[i + k * ldb]); \
                } \
                for (j = k + 1; j < n; j++) { \
                    MKL_DC_MOV(temp, A[j + k * lda]); \
                    MKL_DC_PRAGMA_VECTOR \
                    for (i = 0; i < m; i++) { \
                        mkl_dc_type b, temp1; \
                        b = B[i + k * ldb]; \
                        MKL_DC_MUL(temp1, temp, b); \
                        MKL_DC_SUB(B[i + j * ldb], B[i + j * ldb], temp1); \
                    } \
                } \
                if (!(MKL_DC_IS_ONE(alpha))) { \
                    for (i = 0; i < m; i++) { \
                        MKL_DC_MUL(B[i + k * ldb], alpha, B[i + k * ldb]); \
                    } \
                } \
            } \
        } \
    } while (0)
/*
 * Right side, A conjugate-transposed, non-unit diagonal.  Same recurrence as
 * the transposed right-side solve, except that both the diagonal element
 * A(k,k) and every off-diagonal factor A(j,k) are passed through MKL_DC_CONJ
 * before use.  B (m x n) is overwritten in place.
 */
#define mkl_dc_trsm_rxcn_mn_pst(uplo, m, n, alpha, A, lda, B, ldb) \
    do { \
        MKL_INT i, j, k; \
        if (MKL_DC_MisU(uplo)) { \
            for (k = n - 1; k >= 0; k--) { \
                mkl_dc_type temp, one, a; \
                MKL_DC_SET_ONE(one); \
                /* temp := 1 / conj(A(k,k)) */ \
                a = A[k + k * lda]; \
                MKL_DC_CONJ(a, a); \
                MKL_DC_DIV(temp, one, a); \
                for (i = 0; i < m; i++) { \
                    MKL_DC_MUL(B[i + k * ldb], temp, B[i + k * ldb]); \
                } \
                for (j = 0; j < k; j++) { \
                    MKL_DC_MOV(temp, A[j + k * lda]); \
                    MKL_DC_CONJ(temp, temp); \
                    MKL_DC_PRAGMA_VECTOR \
                    for (i = 0; i < m; i++) { \
                        mkl_dc_type b, temp1; \
                        b = B[i + k * ldb]; \
                        MKL_DC_MUL(temp1, temp, b); \
                        MKL_DC_SUB(B[i + j * ldb], B[i + j * ldb], temp1); \
                    } \
                } \
                if (!(MKL_DC_IS_ONE(alpha))) { \
                    for (i = 0; i < m; i++) { \
                        MKL_DC_MUL(B[i + k * ldb], alpha, B[i + k * ldb]); \
                    } \
                } \
            } \
        } else { \
            for (k = 0; k < n; k++) { \
                mkl_dc_type temp, one, a; \
                MKL_DC_SET_ONE(one); \
                a = A[k + k * lda]; \
                MKL_DC_CONJ(a, a); \
                MKL_DC_DIV(temp, one, a); \
                for (i = 0; i < m; i++) { \
                    MKL_DC_MUL(B[i + k * ldb], temp, B[i + k * ldb]); \
                } \
                for (j = k + 1; j < n; j++) { \
                    MKL_DC_MOV(temp, A[j + k * lda]); \
                    MKL_DC_CONJ(temp, temp); \
                    MKL_DC_PRAGMA_VECTOR \
                    for (i = 0; i < m; i++) { \
                        mkl_dc_type b, temp1; \
                        b = B[i + k * ldb]; \
                        MKL_DC_MUL(temp1, temp, b); \
                        MKL_DC_SUB(B[i + j * ldb], B[i + j * ldb], temp1); \
                    } \
                } \
                if (!(MKL_DC_IS_ONE(alpha))) { \
                    for (i = 0; i < m; i++) { \
                        MKL_DC_MUL(B[i + k * ldb], alpha, B[i + k * ldb]); \
                    } \
                } \
            } \
        } \
    } while (0)
/*
 * Right side, A transposed, unit diagonal: overwrite the m x n matrix B with
 * the solution X of  X * A**T = alpha * B.  The diagonal of A is never read.
 * For each column k, A(j,k) * B(:,k) is subtracted from every pending
 * column j, after which column k is scaled by alpha (skipped when alpha
 * tests as one).
 */
#define mkl_dc_trsm_rxtu_mn_pst(uplo, m, n, alpha, A, lda, B, ldb) \
    do { \
        MKL_INT i, j, k; \
        if (MKL_DC_MisU(uplo)) { \
            for (k = n - 1; k >= 0; k--) { \
                for (j = 0; j < k; j++) { \
                    mkl_dc_type temp; \
                    MKL_DC_MOV(temp, A[j + k * lda]); \
                    MKL_DC_PRAGMA_VECTOR \
                    for (i = 0; i < m; i++) { \
                        mkl_dc_type b, temp1; \
                        b = B[i + k * ldb]; \
                        MKL_DC_MUL(temp1, temp, b); \
                        MKL_DC_SUB(B[i + j * ldb], B[i + j * ldb], temp1); \
                    } \
                } \
                if (!(MKL_DC_IS_ONE(alpha))) { \
                    for (i = 0; i < m; i++) { \
                        MKL_DC_MUL(B[i + k * ldb], alpha, B[i + k * ldb]); \
                    } \
                } \
            } \
        } else { \
            for (k = 0; k < n; k++) { \
                for (j = k + 1; j < n; j++) { \
                    mkl_dc_type temp; \
                    MKL_DC_MOV(temp, A[j + k * lda]); \
                    MKL_DC_PRAGMA_VECTOR \
                    for (i = 0; i < m; i++) { \
                        mkl_dc_type b, temp1; \
                        b = B[i + k * ldb]; \
                        MKL_DC_MUL(temp1, temp, b); \
                        MKL_DC_SUB(B[i + j * ldb], B[i + j * ldb], temp1); \
                    } \
                } \
                if (!(MKL_DC_IS_ONE(alpha))) { \
                    for (i = 0; i < m; i++) { \
                        MKL_DC_MUL(B[i + k * ldb], alpha, B[i + k * ldb]); \
                    } \
                } \
            } \
        } \
    } while (0)
/*
 * Right side, A conjugate-transposed, unit diagonal.  Same recurrence as the
 * transposed unit-diagonal right-side solve, except that every off-diagonal
 * factor A(j,k) is passed through MKL_DC_CONJ before it is used.  The
 * diagonal of A is never read; B (m x n) is overwritten in place.
 */
#define mkl_dc_trsm_rxcu_mn_pst(uplo, m, n, alpha, A, lda, B, ldb) \
    do { \
        MKL_INT i, j, k; \
        if (MKL_DC_MisU(uplo)) { \
            for (k = n - 1; k >= 0; k--) { \
                for (j = 0; j < k; j++) { \
                    mkl_dc_type temp; \
                    temp = A[j + k * lda]; \
                    MKL_DC_CONJ(temp, temp); \
                    MKL_DC_PRAGMA_VECTOR \
                    for (i = 0; i < m; i++) { \
                        mkl_dc_type b, temp1; \
                        b = B[i + k * ldb]; \
                        MKL_DC_MUL(temp1, temp, b); \
                        MKL_DC_SUB(B[i + j * ldb], B[i + j * ldb], temp1); \
                    } \
                } \
                if (!(MKL_DC_IS_ONE(alpha))) { \
                    for (i = 0; i < m; i++) { \
                        MKL_DC_MUL(B[i + k * ldb], alpha, B[i + k * ldb]); \
                    } \
                } \
            } \
        } else { \
            for (k = 0; k < n; k++) { \
                for (j = k + 1; j < n; j++) { \
                    mkl_dc_type temp; \
                    temp = A[j + k * lda]; \
                    MKL_DC_CONJ(temp, temp); \
                    MKL_DC_PRAGMA_VECTOR \
                    for (i = 0; i < m; i++) { \
                        mkl_dc_type b, temp1; \
                        b = B[i + k * ldb]; \
                        MKL_DC_MUL(temp1, temp, b); \
                        MKL_DC_SUB(B[i + j * ldb], B[i + j * ldb], temp1); \
                    } \
                } \
                if (!(MKL_DC_IS_ONE(alpha))) { \
                    for (i = 0; i < m; i++) { \
                        MKL_DC_MUL(B[i + k * ldb], alpha, B[i + k * ldb]); \
                    } \
                } \
            } \
        } \
    } while (0)
- static __inline void mkl_dc_trsm(const char * SIDE, const char * UPLO,
- const char * TRANSA, const char * DIAG,
- const MKL_INT * M, const MKL_INT * N,
- const mkl_dc_type * ALPHA,
- const mkl_dc_type * A, const MKL_INT * LDA,
- mkl_dc_type * B, const MKL_INT * LDB)
- {
- int AisN;
- int lside, unit;
- #ifndef MKL_REAL_DATA_TYPE
- int noconj;
- #endif
- mkl_dc_type alpha = *ALPHA;
- MKL_INT m = *M, n = *N;
- MKL_INT lda = *LDA, ldb = *LDB;
- char uplo = *UPLO;
- if (m <= 0 || n <= 0 )
- return;
- AisN = MKL_DC_MisN(*TRANSA);
- lside = MKL_DC_MisL(*SIDE);
- unit = MKL_DC_MisU(*DIAG);
- #ifndef MKL_REAL_DATA_TYPE
- noconj = MKL_DC_MisT(*TRANSA);
- #endif
- if (MKL_DC_IS_ZERO(alpha)) {
- MKL_INT i, j;
- for (j = 0; j < n; j++)
- #pragma vector
- for (i = 0; i < m; i++)
- MKL_DC_SET_ZERO(B[i + ldb * j]);
- return;
- }
- if (lside)
- if (AisN)
- if (unit)
- mkl_dc_trsm_lxnx_mn_pst(uplo, m, n, alpha, A, lda, B, ldb, MKL_DC_NOOP);
- else
- mkl_dc_trsm_lxnx_mn_pst(uplo, m, n, alpha, A, lda, B, ldb, MKL_DC_DIV_B);
- else
- #ifndef MKL_REAL_DATA_TYPE
- if (!noconj)
- if (unit)
- mkl_dc_trsm_lxcx_mn_pst(uplo, m, n, alpha, A, lda, B, ldb, MKL_DC_NOOP);
- else
- mkl_dc_trsm_lxcx_mn_pst(uplo, m, n, alpha, A, lda, B, ldb, MKL_DC_DIV_B);
- else
- #endif
- if (unit)
- mkl_dc_trsm_lxtx_mn_pst(uplo, m, n, alpha, A, lda, B, ldb, MKL_DC_NOOP);
- else
- mkl_dc_trsm_lxtx_mn_pst(uplo, m, n, alpha, A, lda, B, ldb, MKL_DC_DIV_B);
- else
- if (AisN)
- if (unit)
- mkl_dc_trsm_rxnu_mn_pst(uplo, m, n, alpha, A, lda, B, ldb);
- else
- mkl_dc_trsm_rxnn_mn_pst(uplo, m, n, alpha, A, lda, B, ldb);
- else
- #ifndef MKL_REAL_DATA_TYPE
- if (!noconj)
- if (unit)
- mkl_dc_trsm_rxcu_mn_pst(uplo, m, n, alpha, A, lda, B, ldb);
- else
- mkl_dc_trsm_rxcn_mn_pst(uplo, m, n, alpha, A, lda, B, ldb);
- else
- #endif
- if (unit)
- mkl_dc_trsm_rxtu_mn_pst(uplo, m, n, alpha, A, lda, B, ldb);
- else
- mkl_dc_trsm_rxtn_mn_pst(uplo, m, n, alpha, A, lda, B, ldb);
- }
- /* ?SYRK */
/*
 * Portable SYRK loop nest over one triangle of the n x n matrix C:
 *
 *     C(i,j) := c_op(C(i,j), beta) + alpha * sum_l a_access(A,lda,j,l) * a_access(A,lda,i,l)
 *
 * When uplo tests as upper (MKL_DC_MisU) the entries with i <= j are
 * updated, otherwise those with i >= j; the other triangle is never touched.
 * a_access fetches one element of A from (matrix, ld, row, col) and c_op is
 * applied to the old C entry together with beta before the product is added.
 */
#define mkl_dc_syrk_xx_nk_pst_beta(uplo, n, k, alpha, A, lda, beta, C, ldc, \
                                   c_op, a_access) \
    do { \
        MKL_INT i, j, l; \
        if (MKL_DC_MisU(uplo)) { \
            for (j = 0; j < n; j++) { \
                MKL_DC_PRAGMA_VECTOR \
                for (i = 0; i <= j; i++) { \
                    mkl_dc_type c, temp; \
                    MKL_DC_SET_ZERO(temp); \
                    for (l = 0; l < k; l++) { \
                        mkl_dc_type a, b, temp1; \
                        a = a_access(A, lda, j, l); \
                        b = a_access(A, lda, i, l); \
                        MKL_DC_MUL(temp1, a, b); \
                        MKL_DC_ADD(temp, temp, temp1); \
                    } \
                    MKL_DC_MUL(temp, alpha, temp); \
                    c = C[i + j * ldc]; \
                    c_op(c, beta); \
                    MKL_DC_ADD(C[i + j * ldc], c, temp); \
                } \
            } \
        } else { \
            for (j = 0; j < n; j++) { \
                MKL_DC_PRAGMA_VECTOR \
                for (i = j; i < n; i++) { \
                    mkl_dc_type c, temp; \
                    MKL_DC_SET_ZERO(temp); \
                    for (l = 0; l < k; l++) { \
                        mkl_dc_type a, b, temp1; \
                        a = a_access(A, lda, j, l); \
                        b = a_access(A, lda, i, l); \
                        MKL_DC_MUL(temp1, a, b); \
                        MKL_DC_ADD(temp, temp, temp1); \
                    } \
                    MKL_DC_MUL(temp, alpha, temp); \
                    c = C[i + j * ldc]; \
                    c_op(c, beta); \
                    MKL_DC_ADD(C[i + j * ldc], c, temp); \
                } \
            } \
        } \
    } while (0)
/*
 * Front end for the portable SYRK loop nest: picks how the old contents of C
 * are folded in.  When beta tests as zero the C-update operator is
 * MKL_DC_ZERO_C, otherwise it is MKL_DC_MUL_C.  All other arguments are
 * forwarded unchanged to mkl_dc_syrk_xx_nk_pst_beta.
 */
#define mkl_dc_syrk_xx_nk_pst(uplo, n, k, alpha, A, lda, beta, C, ldc, \
                              a_access) \
    do { \
        if (MKL_DC_IS_ZERO(beta)) { \
            mkl_dc_syrk_xx_nk_pst_beta(uplo, n, k, alpha, A, lda, beta, C, ldc, \
                                       MKL_DC_ZERO_C, a_access); \
        } else { \
            mkl_dc_syrk_xx_nk_pst_beta(uplo, n, k, alpha, A, lda, beta, C, ldc, \
                                       MKL_DC_MUL_C, a_access); \
        } \
    } while (0)
- static __inline void mkl_dc_syrk(const char * UPLO, const char * TRANS,
- const MKL_INT * N, const MKL_INT * K,
- const mkl_dc_type * ALPHA,
- const mkl_dc_type * A, const MKL_INT * LDA,
- const mkl_dc_type * BETA,
- mkl_dc_type * C, const MKL_INT * LDC)
- {
- int AisN, CisU;
- mkl_dc_type alpha = *ALPHA, beta = *BETA;
- MKL_INT n = *N, k = *K;
- MKL_INT lda = *LDA, ldc = *LDC;
- char uplo = *UPLO;
- if (n <= 0 || ((MKL_DC_IS_ZERO(alpha) || k <= 0) && MKL_DC_IS_ONE(beta)))
- return;
- AisN = MKL_DC_MisN(*TRANS);
- CisU = MKL_DC_MisU(*UPLO);
- if (MKL_DC_IS_ZERO(alpha)) {
- MKL_INT i, j;
- if (MKL_DC_IS_ZERO(beta))
- if (CisU)
- for (j = 0; j < n; j++)
- #pragma vector
- for (i = 0; i <= j; i++)
- MKL_DC_SET_ZERO(C[i + ldc * j]);
- else
- for (j = 0; j < n; j++)
- #pragma vector
- for (i = j; i < n; i++)
- MKL_DC_SET_ZERO(C[i + ldc * j]);
- else
- if (CisU)
- for (j = 0; j < n; j++)
- #pragma vector
- for (i = 0; i <= j; i++)
- MKL_DC_MUL(C[i + ldc * j], beta, C[i + ldc * j]);
- else
- for (j = 0; j < n; j++)
- #pragma vector
- for (i = j; i < n; i++)
- MKL_DC_MUL(C[i + ldc * j], beta, C[i + ldc * j]);
- return;
- }
- if (AisN)
- mkl_dc_syrk_xx_nk_pst(uplo, n, k, alpha, A, lda, beta, C, ldc, MKL_DC_MN);
- else
- mkl_dc_syrk_xx_nk_pst(uplo, n, k, alpha, A, lda, beta, C, ldc, MKL_DC_MT);
- }
- static __inline void mkl_dc_axpy (const MKL_INT *N, const mkl_dc_type *ALPHA, const mkl_dc_type *x, const MKL_INT *INCX, mkl_dc_type *y, const MKL_INT *INCY)
- {
- MKL_INT i;
- MKL_INT n = *N, ix = 0, iy = 0;
- if (*INCX == 1 && *INCY == 1) {
- #pragma vector vecremainder
- for(i = 0; i < n; i++) MKL_DC_MUL_ADD(y[i], (*ALPHA), x[i], y[i]);
- } else {
- if (*INCX < 0) ix = (-(n) + 1) * *INCX;
- if (*INCY < 0) iy = (-(n) + 1) * *INCY;
- for(i = 0; i < n; i++, ix += *INCX, iy += *INCY) MKL_DC_MUL_ADD(y[iy], (*ALPHA), x[ix], y[iy]);
- }
- }
#if defined(MKL_DOUBLE) || defined(MKL_SINGLE)
/*
 * Inlinable ?DOT (real data types only): returns the sum over n element
 * pairs of y_elem * x_elem, accumulated in mkl_dc_type starting from 0.
 *
 * Every scalar is passed by address.  With both strides equal to 1 the
 * vectors are walked contiguously.  Otherwise each vector is stepped by its
 * own stride; for a negative stride the walk starts at offset
 * (1 - n) * stride.  The result is 0 when n <= 0.
 */
static __inline mkl_dc_type mkl_dc_dot (const MKL_INT* N, const mkl_dc_type *x,
                                        const MKL_INT* INCX, const mkl_dc_type *y, const MKL_INT* INCY)
{
    const MKL_INT n = *N;
    const MKL_INT incx = *INCX;
    const MKL_INT incy = *INCY;
    mkl_dc_type ret = 0.0;
    MKL_INT i;
    MKL_INT ix = 0, iy = 0;

    if (incx == 1 && incy == 1) {
        /* unit-stride fast path */
#pragma vector vecremainder
        for (i = 0; i < n; i++)
            ret += y[i] * x[i];
        return ret;
    }

    if (incx < 0)
        ix = (1 - n) * incx;
    if (incy < 0)
        iy = (1 - n) * incy;
    for (i = 0; i < n; i++, ix += incx, iy += incy)
        ret += y[iy] * x[ix];
    return ret;
}
#endif
/*
 * Tear-down: drop the file-private helper macros and the axpy/dot aliases so
 * they are no longer defined after this point.
 */

/* GEMM / SYRK loop-nest helpers. */
#undef mkl_dc_gemm_xx_mnk_pst_beta
#undef mkl_dc_gemm_xx_mnk_pst
#undef mkl_dc_syrk_xx_nk_pst_beta
#undef mkl_dc_syrk_xx_nk_pst

/* TRSM helpers, left side. */
#undef mkl_dc_trsm_lxnx_mn_pst
#undef mkl_dc_trsm_lxtx_mn_pst
#undef mkl_dc_trsm_lxcx_mn_pst

/* TRSM helpers, right side. */
#undef mkl_dc_trsm_rxnn_mn_pst
#undef mkl_dc_trsm_rxnu_mn_pst
#undef mkl_dc_trsm_rxtn_mn_pst
#undef mkl_dc_trsm_rxtu_mn_pst
#undef mkl_dc_trsm_rxcn_mn_pst
#undef mkl_dc_trsm_rxcu_mn_pst

/* Level-1 aliases. */
#undef MKL_DC_DOT_CONVERT
#undef mkl_dc_axpy
#undef mkl_dc_dot
|