diff --git a/blas/blas.h b/blas/blas.h
index 2f218a8b5ac857f3bfa61817a51b8b012336e075..220e956dbc9e31495383ccd9b55fb0a09155c7de 100644
--- a/blas/blas.h
+++ b/blas/blas.h
@@ -59,6 +59,20 @@ EIGEN_BLAS_API void BLASFUNC(caxpyc)(const int *, const float *, const float *,
 EIGEN_BLAS_API void BLASFUNC(zaxpyc)(const int *, const double *, const double *, const int *, double *, const int *);
 EIGEN_BLAS_API void BLASFUNC(xaxpyc)(const int *, const double *, const double *, const int *, double *, const int *);
 
+/* axpby: y := alpha * x + beta * y. Argument order: n, alpha, x, incx, beta, y, incy. */
+EIGEN_BLAS_API void BLASFUNC(saxpby)(const int *, const float *, const float *, const int *, const float *, float *,
+                                     const int *);
+EIGEN_BLAS_API void BLASFUNC(daxpby)(const int *, const double *, const double *, const int *, const double *, double *,
+                                     const int *);
+EIGEN_BLAS_API void BLASFUNC(qaxpby)(const int *, const double *, const double *, const int *, const double *, double *,
+                                     const int *);
+EIGEN_BLAS_API void BLASFUNC(caxpby)(const int *, const float *, const float *, const int *, const float *, float *,
+                                     const int *);
+EIGEN_BLAS_API void BLASFUNC(zaxpby)(const int *, const double *, const double *, const int *, const double *, double *,
+                                     const int *);
+EIGEN_BLAS_API void BLASFUNC(xaxpby)(const int *, const double *, const double *, const int *, const double *, double *,
+                                     const int *);
+
 EIGEN_BLAS_API void BLASFUNC(scopy)(int *, float *, int *, float *, int *);
 EIGEN_BLAS_API void BLASFUNC(dcopy)(int *, double *, int *, double *, int *);
 EIGEN_BLAS_API void BLASFUNC(qcopy)(int *, double *, int *, double *, int *);
diff --git a/blas/eigen_blas.def b/blas/eigen_blas.def
index fb18d44a046e52e7a0da06a626d9a8ab96908a14..20f26dc764bbbf9c9ed502796ed69287dba7377a 100644
--- a/blas/eigen_blas.def
+++ b/blas/eigen_blas.def
@@ -13,6 +13,10 @@ EXPORTS
     zaxpy_
     ; caxpyc_
     ; zaxpyc_
+    saxpby_
+    daxpby_
+    caxpby_
+    zaxpby_
     scopy_
     dcopy_
     ccopy_
@@ -91,8 +95,8 @@ EXPORTS
     ; dmin_
     ; cmin_
     ; zmin_
-    
-    
+
+
 ; Level 2
     sgemv_
     dgemv_
diff --git a/blas/level1_impl.h b/blas/level1_impl.h
index a65af92d5e2c2603d0538204093c9a1e9df38ca7..085b35650cc36da7729921372e3e3c88bb106264 100644
--- a/blas/level1_impl.h
+++ b/blas/level1_impl.h
@@ -22,11 +22,48 @@ EIGEN_BLAS_FUNC(axpy)
   else if (*incx > 0 && *incy > 0)
     make_vector(y, *n, *incy) += alpha * make_vector(x, *n, *incx);
   else if (*incx > 0 && *incy < 0)
-    make_vector(y, *n, -*incy).reverse() += alpha * make_vector(x, *n, *incx);
+    make_vector(y, *n, -*incy) += alpha * make_vector(x, *n, *incx).reverse();
   else if (*incx < 0 && *incy > 0)
     make_vector(y, *n, *incy) += alpha * make_vector(x, *n, -*incx).reverse();
   else if (*incx < 0 && *incy < 0)
-    make_vector(y, *n, -*incy).reverse() += alpha * make_vector(x, *n, -*incx).reverse();
+    make_vector(y, *n, -*incy) += alpha * make_vector(x, *n, -*incx);
+}
+
+// Computes y := alpha * x + beta * y over n elements of the strided vectors x and y.
+// Every argument is passed by pointer; returns immediately when n <= 0.
+EIGEN_BLAS_FUNC(axpby)
+(const int *pn, const RealScalar *palpha, const RealScalar *px, const int *pincx, const RealScalar *pbeta,
+ RealScalar *py, const int *pincy) {
+  const Scalar *x = reinterpret_cast<const Scalar *>(px);
+  Scalar *y = reinterpret_cast<Scalar *>(py);
+  const Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
+  const Scalar beta = *reinterpret_cast<const Scalar *>(pbeta);
+  const int n = *pn;
+
+  if (n <= 0) return;
+
+  // With beta exactly equal to 1 the update reduces to y += alpha * x, so defer to axpy.
+  if (Eigen::numext::equal_strict(beta, Scalar(1))) {
+    EIGEN_BLAS_FUNC_NAME(axpy)(pn, palpha, px, pincx, py, pincy);
+    return;
+  }
+
+  const int incx = *pincx;
+  const int incy = *pincy;
+
+  // make_vector always receives the stride magnitude. x is reversed when incx and incy differ
+  // in sign and left as-is when both are negative, mirroring the axpy branches above.
+  // If either increment is zero, no branch matches and y is left unchanged.
+  if (incx == 1 && incy == 1)
+    make_vector(y, n) = alpha * make_vector(x, n) + beta * make_vector(y, n);
+  else if (incx > 0 && incy > 0)
+    make_vector(y, n, incy) = alpha * make_vector(x, n, incx) + beta * make_vector(y, n, incy);
+  else if (incx > 0 && incy < 0)
+    make_vector(y, n, -incy) = alpha * make_vector(x, n, incx).reverse() + beta * make_vector(y, n, -incy);
+  else if (incx < 0 && incy > 0)
+    make_vector(y, n, incy) = alpha * make_vector(x, n, -incx).reverse() + beta * make_vector(y, n, incy);
+  else if (incx < 0 && incy < 0)
+    make_vector(y, n, -incy) = alpha * make_vector(x, n, -incx) + beta * make_vector(y, n, -incy);
 }
 
 EIGEN_BLAS_FUNC(copy)(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy) {