25 #ifndef EIGEN_GENERAL_BLOCK_PANEL_H
26 #define EIGEN_GENERAL_BLOCK_PANEL_H
// Template parameter list for the gebp_traits class (lhs/rhs scalar types
// plus per-side conjugation flags, both defaulting to "no conjugation").
// NOTE(review): the class-head line that this parameter list belongs to is
// missing from this excerpt — confirm against the full file.
32 template<
typename _LhsScalar,
typename _RhsScalar,
bool _ConjLhs=false,
bool _ConjRhs=false>
// computeProductBlockingSizes fragments: derives cache-friendly blocking
// sizes for the panel product from the L1/L2 cache sizes (l1, l2).
// NOTE(review): the function signature, the enum header and the surrounding
// braces are missing from this excerpt.
82 template<
typename LhsScalar,
typename RhsScalar,
int KcFactor>
93 std::ptrdiff_t l1, l2;
95 typedef gebp_traits<LhsScalar,RhsScalar> Traits;
// kdiv: per-unit-of-k footprint of the rhs panel in bytes, scaled by
// KcFactor; used below to bound k by the L1 capacity.
97 kdiv = KcFactor * 2 * Traits::nr
98 * Traits::RhsProgress *
sizeof(RhsScalar),
99 mr = gebp_traits<LhsScalar,RhsScalar>::mr,
// mr_mask rounds a value down to a multiple of mr via the & below.
// NOTE(review): 0xffffffff assumes the masked value fits in 32 bits —
// confirm this is intentional for ptrdiff_t-sized m.
100 mr_mask = (0xffffffff/mr)*mr
// Clamp k so one k-slice of the rhs panel fits in L1.
104 k = std::min<std::ptrdiff_t>(k, l1/kdiv);
// Bound m by what fits in L2 for the chosen k (guarding against k==0),
// then round it down to a multiple of mr.
105 std::ptrdiff_t _m = k>0 ? l2/(4 *
sizeof(LhsScalar) * k) : 0;
106 if(_m<m) m = _m & mr_mask;
// Two-type convenience overload: forwards with KcFactor == 1.
109 template<
typename LhsScalar,
typename RhsScalar>
112 computeProductBlockingSizes<LhsScalar,RhsScalar,1>(k, m, n);
// MADD abstracts one multiply-accumulate step of the scalar fallback path.
// When the target has a fused conjugate-multiply-add, use it directly.
115 #ifdef EIGEN_HAS_FUSE_CJMADD
116 #define MADD(CJ,A,B,C,T) C = CJ.pmadd(A,B,C);
// Otherwise dispatch through a selector so homogeneous-type calls can
// route the product through an explicit temporary.
121 template<
typename CJ,
typename A,
typename B,
typename C,
typename T>
struct gebp_madd_selector {
// Specialization for the case where all operand types are identical:
// stage the rhs in t, multiply through the conjugation helper, then add.
128 template<
typename CJ,
typename T>
struct gebp_madd_selector<CJ,T,T,T,T> {
131 t = b; t = cj.pmul(a,t); c =
padd(c,t);
// gebp_madd fragment: forwards to the selector's static run().
135 template<
typename CJ,
typename A,
typename B,
typename C,
typename T>
138 gebp_madd_selector<CJ,A,B,C,T>::run(cj,a,b,c,t);
// Non-fused MADD: go through gebp_madd / the selector above.
141 #define MADD(CJ,A,B,C,T) gebp_madd(CJ,A,B,C,T);
// gebp_traits primary definition fragments: packet types, register blocking
// factors (mr x nr) and the micro-kernel primitives for the generic case.
155 template<
typename _LhsScalar,
typename _RhsScalar,
bool _ConjLhs,
bool _ConjRhs>
159 typedef _LhsScalar LhsScalar;
160 typedef _RhsScalar RhsScalar;
161 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
// Vectorize only when both scalar types have packet support.
166 Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable,
167 LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
168 RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
169 ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
// nr: rhs columns handled per kernel iteration; mr: lhs rows per iteration.
174 nr = NumberOfRegisters/4,
177 mr = 2 * LhsPacketSize,
// WorkSpaceFactor sizes the unpacked-rhs scratch buffer per depth unit.
179 WorkSpaceFactor = nr * RhsPacketSize,
181 LhsProgress = LhsPacketSize,
182 RhsProgress = RhsPacketSize
185 typedef typename packet_traits<LhsScalar>::type _LhsPacket;
186 typedef typename packet_traits<RhsScalar>::type _RhsPacket;
187 typedef typename packet_traits<ResScalar>::type _ResPacket;
// Fall back to plain scalars as "packets" when vectorization is off.
189 typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
190 typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
191 typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
193 typedef ResPacket AccPacket;
// initAcc body fragment: zero an accumulator packet.
197 p = pset1<ResPacket>(ResScalar(0));
// unpackRhs body fragment: broadcast each rhs coefficient into a packet slot.
203 pstore1<RhsPacket>(&b[k*RhsPacketSize], rhs[k]);
// loadRhs / loadLhs body fragments: aligned packet loads.
208 dest = pload<RhsPacket>(b);
213 dest = pload<LhsPacket>(a);
// madd: c += a*b, staged through tmp (presumably to help the compiler
// schedule independent multiply/add chains — confirm against full file).
216 EIGEN_STRONG_INLINE void madd(
const LhsPacket& a,
const RhsPacket& b, AccPacket& c, AccPacket& tmp)
const
218 tmp = b; tmp =
pmul(a,tmp); c =
padd(c,tmp);
// acc: r += alpha * c (scale the accumulator into the result packet).
221 EIGEN_STRONG_INLINE void acc(
const AccPacket& c,
const ResPacket& alpha, ResPacket& r)
const
223 r =
pmadd(c,alpha,r);
// gebp_traits specialization fragments for (complex lhs) x (real rhs).
// Mirrors the primary template but multiplies the complex packet's raw
// vector (.v) by the real rhs packet, and applies conjugation in acc().
231 template<
typename RealScalar,
bool _ConjLhs>
232 class gebp_traits<std::complex<RealScalar>, RealScalar, _ConjLhs, false>
235 typedef std::complex<RealScalar> LhsScalar;
236 typedef RealScalar RhsScalar;
237 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
242 Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable,
243 LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
244 RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
245 ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
248 nr = NumberOfRegisters/4,
249 mr = 2 * LhsPacketSize,
250 WorkSpaceFactor = nr*RhsPacketSize,
252 LhsProgress = LhsPacketSize,
253 RhsProgress = RhsPacketSize
256 typedef typename packet_traits<LhsScalar>::type _LhsPacket;
257 typedef typename packet_traits<RhsScalar>::type _RhsPacket;
258 typedef typename packet_traits<ResScalar>::type _ResPacket;
260 typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
261 typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
262 typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
264 typedef ResPacket AccPacket;
268 p = pset1<ResPacket>(ResScalar(0));
274 pstore1<RhsPacket>(&b[k*RhsPacketSize], rhs[k]);
279 dest = pload<RhsPacket>(b);
284 dest = pload<LhsPacket>(a);
// madd dispatches on Vectorizable via tag dispatch (true_type/false_type).
287 EIGEN_STRONG_INLINE void madd(
const LhsPacket& a,
const RhsPacket& b, AccPacket& c, RhsPacket& tmp)
const
289 madd_impl(a, b, c, tmp,
typename conditional<Vectorizable,true_type,false_type>::type());
// Vectorized path: multiply the complex packet's underlying real vector
// (a.v) by the real rhs packet and accumulate into c.v.
292 EIGEN_STRONG_INLINE void madd_impl(
const LhsPacket& a,
const RhsPacket& b, AccPacket& c, RhsPacket& tmp,
const true_type&)
const
294 tmp = b; tmp =
pmul(a.v,tmp); c.v =
padd(c.v,tmp);
// Scalar fallback path (body not visible in this excerpt).
297 EIGEN_STRONG_INLINE void madd_impl(
const LhsScalar& a,
const RhsScalar& b, ResScalar& c, RhsScalar& ,
const false_type&)
const
// acc applies the lhs conjugation (via the conj_helper member below)
// while scaling by alpha.
302 EIGEN_STRONG_INLINE void acc(
const AccPacket& c,
const ResPacket& alpha, ResPacket& r)
const
304 r = cj.pmadd(c,alpha,r);
308 conj_helper<ResPacket,ResPacket,ConjLhs,false> cj;
// gebp_traits specialization fragments for (complex) x (complex).
// Stores the rhs as a pair of real packets (real parts, imaginary parts —
// the DoublePacket) so the kernel does real-only FMAs; the complex
// recombination and conjugation happen once, in acc().
311 template<
typename RealScalar,
bool _ConjLhs,
bool _ConjRhs>
312 class gebp_traits<std::complex<RealScalar>, std::complex<RealScalar>, _ConjLhs, _ConjRhs >
315 typedef std::complex<RealScalar> Scalar;
316 typedef std::complex<RealScalar> LhsScalar;
317 typedef std::complex<RealScalar> RhsScalar;
318 typedef std::complex<RealScalar> ResScalar;
323 Vectorizable = packet_traits<RealScalar>::Vectorizable
324 && packet_traits<Scalar>::Vectorizable,
325 RealPacketSize = Vectorizable ? packet_traits<RealScalar>::size : 1,
326 ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
329 mr = 2 * ResPacketSize,
// Rhs scratch holds two real packets (real+imag) per coefficient.
330 WorkSpaceFactor = Vectorizable ? 2*nr*RealPacketSize : nr,
332 LhsProgress = ResPacketSize,
333 RhsProgress = Vectorizable ? 2*ResPacketSize : 1
336 typedef typename packet_traits<RealScalar>::type RealPacket;
337 typedef typename packet_traits<Scalar>::type ScalarPacket;
344 typedef typename conditional<Vectorizable,RealPacket, Scalar>::type LhsPacket;
345 typedef typename conditional<Vectorizable,DoublePacket,Scalar>::type RhsPacket;
346 typedef typename conditional<Vectorizable,ScalarPacket,Scalar>::type ResPacket;
347 typedef typename conditional<Vectorizable,DoublePacket,Scalar>::type AccPacket;
// initAcc fragment: zero both halves of the DoublePacket accumulator.
353 p.first = pset1<RealPacket>(RealScalar(0));
354 p.second = pset1<RealPacket>(RealScalar(0));
// unpackRhs fragment: broadcast real and imaginary parts into adjacent
// real packets inside the scratch buffer.
367 pstore1<RealPacket>((RealScalar*)&b[k*ResPacketSize*2+0],
real(rhs[k]));
368 pstore1<RealPacket>((RealScalar*)&b[k*ResPacketSize*2+ResPacketSize],
imag(rhs[k]));
// loadRhs fragment: read the real/imag packet pair back.
379 dest.first = pload<RealPacket>((
const RealScalar*)b);
380 dest.second = pload<RealPacket>((
const RealScalar*)(b+ResPacketSize));
386 dest = pload<LhsPacket>((
const typename unpacket_traits<LhsPacket>::type*)(a));
// madd (vectorized): two independent real multiply-adds, one per half.
389 EIGEN_STRONG_INLINE void madd(
const LhsPacket& a,
const RhsPacket& b, DoublePacket& c, RhsPacket& )
const
391 c.first =
padd(
pmul(a,b.first), c.first);
392 c.second =
padd(
pmul(a,b.second),c.second);
// Scalar-result overload (body not visible in this excerpt).
395 EIGEN_STRONG_INLINE void madd(
const LhsPacket& a,
const RhsPacket& b, ResPacket& c, RhsPacket& )
const
// Scalar acc: plain complex multiply-accumulate.
400 EIGEN_STRONG_INLINE void acc(
const Scalar& c,
const Scalar& alpha, Scalar& r)
const { r += alpha * c; }
// Packet acc: recombine the two real halves into a complex packet,
// selecting add/sub/conjugate per the four conjugation combinations.
// NOTE(review): the lines combining c.second into tmp are missing from
// this excerpt; only the c.first combination is visible.
402 EIGEN_STRONG_INLINE void acc(
const DoublePacket& c,
const ResPacket& alpha, ResPacket& r)
const
406 if((!ConjLhs)&&(!ConjRhs))
409 tmp =
padd(ResPacket(c.first),tmp);
411 else if((!ConjLhs)&&(ConjRhs))
414 tmp =
padd(ResPacket(c.first),tmp);
416 else if((ConjLhs)&&(!ConjRhs))
419 tmp =
padd(
pconj(ResPacket(c.first)),tmp);
421 else if((ConjLhs)&&(ConjRhs))
424 tmp =
psub(
pconj(ResPacket(c.first)),tmp);
// Finally scale the recombined packet by alpha into the result.
427 r =
pmadd(tmp,alpha,r);
431 conj_helper<LhsScalar,RhsScalar,ConjLhs,ConjRhs> cj;
// gebp_traits specialization fragments for (real lhs) x (complex rhs).
// The real lhs is duplicated across packet lanes (ploaddup) so one real
// value multiplies both real and imaginary parts of the complex rhs.
434 template<
typename RealScalar,
bool _ConjRhs>
435 class gebp_traits<RealScalar, std::complex<RealScalar>, false, _ConjRhs >
438 typedef std::complex<RealScalar> Scalar;
439 typedef RealScalar LhsScalar;
440 typedef Scalar RhsScalar;
441 typedef Scalar ResScalar;
446 Vectorizable = packet_traits<RealScalar>::Vectorizable
447 && packet_traits<Scalar>::Vectorizable,
448 LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
449 RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
450 ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
454 mr = 2*ResPacketSize,
455 WorkSpaceFactor = nr*RhsPacketSize,
457 LhsProgress = ResPacketSize,
458 RhsProgress = ResPacketSize
461 typedef typename packet_traits<LhsScalar>::type _LhsPacket;
462 typedef typename packet_traits<RhsScalar>::type _RhsPacket;
463 typedef typename packet_traits<ResScalar>::type _ResPacket;
465 typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
466 typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
467 typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
469 typedef ResPacket AccPacket;
473 p = pset1<ResPacket>(ResScalar(0));
479 pstore1<RhsPacket>(&b[k*RhsPacketSize], rhs[k]);
484 dest = pload<RhsPacket>(b);
// loadLhs duplicates each real lhs coefficient across adjacent lanes.
489 dest = ploaddup<LhsPacket>(a);
// madd dispatches on Vectorizable via tag dispatch.
492 EIGEN_STRONG_INLINE void madd(
const LhsPacket& a,
const RhsPacket& b, AccPacket& c, RhsPacket& tmp)
const
494 madd_impl(a, b, c, tmp,
typename conditional<Vectorizable,true_type,false_type>::type());
// Vectorized path: multiply the complex rhs packet's raw vector (tmp.v)
// by the duplicated real lhs and accumulate.
497 EIGEN_STRONG_INLINE void madd_impl(
const LhsPacket& a,
const RhsPacket& b, AccPacket& c, RhsPacket& tmp,
const true_type&)
const
499 tmp = b; tmp.v =
pmul(a,tmp.v); c =
padd(c,tmp);
// Scalar fallback path (body not visible in this excerpt).
502 EIGEN_STRONG_INLINE void madd_impl(
const LhsScalar& a,
const RhsScalar& b, ResScalar& c, RhsScalar& ,
const false_type&)
const
// acc applies the rhs conjugation while scaling by alpha.
507 EIGEN_STRONG_INLINE void acc(
const AccPacket& c,
const ResPacket& alpha, ResPacket& r)
const
509 r = cj.pmadd(alpha,c,r);
513 conj_helper<ResPacket,ResPacket,false,ConjRhs> cj;
// gebp_kernel fragments: the general block-panel product micro-kernel,
// res += alpha * blockA * blockB on an mr x nr register blocking.
523 template<
typename LhsScalar,
typename RhsScalar,
typename Index,
int mr,
int nr,
bool ConjugateLhs,
bool ConjugateRhs>
526 typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> Traits;
527 typedef typename Traits::ResScalar ResScalar;
528 typedef typename Traits::LhsPacket LhsPacket;
529 typedef typename Traits::RhsPacket RhsPacket;
530 typedef typename Traits::ResPacket ResPacket;
531 typedef typename Traits::AccPacket AccPacket;
534 Vectorizable = Traits::Vectorizable,
535 LhsProgress = Traits::LhsProgress,
536 RhsProgress = Traits::RhsProgress,
537 ResPacketSize = Traits::ResPacketSize
// Entry point. strideA/strideB default to depth; offsets support panel
// mode packing; unpackedB is optional caller-provided scratch.
541 void operator()(ResScalar* res, Index resStride,
const LhsScalar* blockA,
const RhsScalar* blockB, Index rows, Index depth, Index cols, ResScalar alpha,
542 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0, RhsScalar* unpackedB = 0)
546 if(strideA==-1) strideA = depth;
547 if(strideB==-1) strideB = depth;
548 conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
// Peel counts: columns handled nr at a time, rows handled mr at a time,
// then one extra LhsProgress-wide row band, then scalar leftovers.
550 Index packet_cols = (cols/nr) * nr;
551 const Index peeled_mc = (rows/mr)*mr;
553 const Index peeled_mc2 = peeled_mc + (rows-peeled_mc >= LhsProgress ? LhsProgress : 0);
554 const Index peeled_kc = (depth/4)*4;
// Fallback scratch: reuse memory located just before blockB.
// NOTE(review): relies on the caller having reserved that region —
// confirm against the blocking/allocation code in the full file.
557 unpackedB =
const_cast<RhsScalar*
>(blockB - strideB * nr * RhsProgress);
// Main loop over groups of nr columns: unpack this rhs panel once, then
// sweep the lhs rows mr at a time with eight register accumulators.
560 for(Index j2=0; j2<packet_cols; j2+=nr)
562 traits.unpackRhs(depth*nr,&blockB[j2*strideB+offsetB*nr],unpackedB);
567 for(Index i=0; i<peeled_mc; i+=mr)
569 const LhsScalar* blA = &blockA[i*strideA+offsetA*mr];
// C0..C3 accumulate the first lhs packet row, C4..C7 the second;
// C2/C3/C6/C7 are only live when nr==4.
573 AccPacket C0, C1, C2, C3, C4, C5, C6, C7;
576 if(nr==4) traits.initAcc(C2);
577 if(nr==4) traits.initAcc(C3);
580 if(nr==4) traits.initAcc(C6);
581 if(nr==4) traits.initAcc(C7);
// Result column pointers for the nr destination columns.
583 ResScalar* r0 = &res[(j2+0)*resStride + i];
584 ResScalar* r1 = r0 + resStride;
585 ResScalar* r2 = r1 + resStride;
586 ResScalar* r3 = r2 + resStride;
// Depth loop peeled by 4. This branch is the nr==2 variant: each of the
// four unrolled steps loads two lhs packets and streams the rhs through
// a single register (B_0), reusing it as the madd temporary.
596 const RhsScalar* blB = unpackedB;
597 for(Index k=0; k<peeled_kc; k+=4)
// --- unrolled step 1 ---
606 traits.loadLhs(&blA[0*LhsProgress], A0);
607 traits.loadLhs(&blA[1*LhsProgress], A1);
608 traits.loadRhs(&blB[0*RhsProgress], B_0);
609 traits.madd(A0,B_0,C0,T0);
610 traits.madd(A1,B_0,C4,B_0);
611 traits.loadRhs(&blB[1*RhsProgress], B_0);
612 traits.madd(A0,B_0,C1,T0);
613 traits.madd(A1,B_0,C5,B_0);
// --- unrolled step 2 ---
615 traits.loadLhs(&blA[2*LhsProgress], A0);
616 traits.loadLhs(&blA[3*LhsProgress], A1);
617 traits.loadRhs(&blB[2*RhsProgress], B_0);
618 traits.madd(A0,B_0,C0,T0);
619 traits.madd(A1,B_0,C4,B_0);
620 traits.loadRhs(&blB[3*RhsProgress], B_0);
621 traits.madd(A0,B_0,C1,T0);
622 traits.madd(A1,B_0,C5,B_0);
// --- unrolled step 3 ---
624 traits.loadLhs(&blA[4*LhsProgress], A0);
625 traits.loadLhs(&blA[5*LhsProgress], A1);
626 traits.loadRhs(&blB[4*RhsProgress], B_0);
627 traits.madd(A0,B_0,C0,T0);
628 traits.madd(A1,B_0,C4,B_0);
629 traits.loadRhs(&blB[5*RhsProgress], B_0);
630 traits.madd(A0,B_0,C1,T0);
631 traits.madd(A1,B_0,C5,B_0);
// --- unrolled step 4 ---
633 traits.loadLhs(&blA[6*LhsProgress], A0);
634 traits.loadLhs(&blA[7*LhsProgress], A1);
635 traits.loadRhs(&blB[6*RhsProgress], B_0);
636 traits.madd(A0,B_0,C0,T0);
637 traits.madd(A1,B_0,C4,B_0);
638 traits.loadRhs(&blB[7*RhsProgress], B_0);
639 traits.madd(A0,B_0,C1,T0);
640 traits.madd(A1,B_0,C5,B_0);
// nr==4 variant of the 4x-peeled depth loop: four rhs registers
// (B_0..B3) are rotated so loads for step k+1 are interleaved with the
// madds of step k, hiding load latency.
647 RhsPacket B_0, B1, B2, B3;
// --- k+0 ---
650 traits.loadLhs(&blA[0*LhsProgress], A0);
651 traits.loadLhs(&blA[1*LhsProgress], A1);
652 traits.loadRhs(&blB[0*RhsProgress], B_0);
653 traits.loadRhs(&blB[1*RhsProgress], B1);
655 traits.madd(A0,B_0,C0,T0);
656 traits.loadRhs(&blB[2*RhsProgress], B2);
657 traits.madd(A1,B_0,C4,B_0);
658 traits.loadRhs(&blB[3*RhsProgress], B3);
659 traits.loadRhs(&blB[4*RhsProgress], B_0);
660 traits.madd(A0,B1,C1,T0);
661 traits.madd(A1,B1,C5,B1);
662 traits.loadRhs(&blB[5*RhsProgress], B1);
663 traits.madd(A0,B2,C2,T0);
664 traits.madd(A1,B2,C6,B2);
665 traits.loadRhs(&blB[6*RhsProgress], B2);
666 traits.madd(A0,B3,C3,T0);
667 traits.loadLhs(&blA[2*LhsProgress], A0);
668 traits.madd(A1,B3,C7,B3);
669 traits.loadLhs(&blA[3*LhsProgress], A1);
670 traits.loadRhs(&blB[7*RhsProgress], B3);
// --- k+1 ---
671 traits.madd(A0,B_0,C0,T0);
672 traits.madd(A1,B_0,C4,B_0);
673 traits.loadRhs(&blB[8*RhsProgress], B_0);
674 traits.madd(A0,B1,C1,T0);
675 traits.madd(A1,B1,C5,B1);
676 traits.loadRhs(&blB[9*RhsProgress], B1);
677 traits.madd(A0,B2,C2,T0);
678 traits.madd(A1,B2,C6,B2);
679 traits.loadRhs(&blB[10*RhsProgress], B2);
680 traits.madd(A0,B3,C3,T0);
681 traits.loadLhs(&blA[4*LhsProgress], A0);
682 traits.madd(A1,B3,C7,B3);
683 traits.loadLhs(&blA[5*LhsProgress], A1);
684 traits.loadRhs(&blB[11*RhsProgress], B3);
// --- k+2 ---
686 traits.madd(A0,B_0,C0,T0);
687 traits.madd(A1,B_0,C4,B_0);
688 traits.loadRhs(&blB[12*RhsProgress], B_0);
689 traits.madd(A0,B1,C1,T0);
690 traits.madd(A1,B1,C5,B1);
691 traits.loadRhs(&blB[13*RhsProgress], B1);
692 traits.madd(A0,B2,C2,T0);
693 traits.madd(A1,B2,C6,B2);
694 traits.loadRhs(&blB[14*RhsProgress], B2);
695 traits.madd(A0,B3,C3,T0);
696 traits.loadLhs(&blA[6*LhsProgress], A0);
697 traits.madd(A1,B3,C7,B3);
698 traits.loadLhs(&blA[7*LhsProgress], A1);
699 traits.loadRhs(&blB[15*RhsProgress], B3);
// --- k+3: no more loads, just drain the madds ---
700 traits.madd(A0,B_0,C0,T0);
701 traits.madd(A1,B_0,C4,B_0);
702 traits.madd(A0,B1,C1,T0);
703 traits.madd(A1,B1,C5,B1);
704 traits.madd(A0,B2,C2,T0);
705 traits.madd(A1,B2,C6,B2);
706 traits.madd(A0,B3,C3,T0);
707 traits.madd(A1,B3,C7,B3);
// Advance rhs pointer past the 4 consumed depth steps.
710 blB += 4*nr*RhsProgress;
// Remainder depth loop (depth not a multiple of 4): one step per
// iteration, nr==2 branch first, nr==4 branch after.
714 for(Index k=peeled_kc; k<depth; k++)
722 traits.loadLhs(&blA[0*LhsProgress], A0);
723 traits.loadLhs(&blA[1*LhsProgress], A1);
724 traits.loadRhs(&blB[0*RhsProgress], B_0);
725 traits.madd(A0,B_0,C0,T0);
726 traits.madd(A1,B_0,C4,B_0);
727 traits.loadRhs(&blB[1*RhsProgress], B_0);
728 traits.madd(A0,B_0,C1,T0);
729 traits.madd(A1,B_0,C5,B_0);
// nr==4 branch of the remainder step.
734 RhsPacket B_0, B1, B2, B3;
737 traits.loadLhs(&blA[0*LhsProgress], A0);
738 traits.loadLhs(&blA[1*LhsProgress], A1);
739 traits.loadRhs(&blB[0*RhsProgress], B_0);
740 traits.loadRhs(&blB[1*RhsProgress], B1);
742 traits.madd(A0,B_0,C0,T0);
743 traits.loadRhs(&blB[2*RhsProgress], B2);
744 traits.madd(A1,B_0,C4,B_0);
745 traits.loadRhs(&blB[3*RhsProgress], B3);
746 traits.madd(A0,B1,C1,T0);
747 traits.madd(A1,B1,C5,B1);
748 traits.madd(A0,B2,C2,T0);
749 traits.madd(A1,B2,C6,B2);
750 traits.madd(A0,B3,C3,T0);
751 traits.madd(A1,B3,C7,B3);
754 blB += nr*RhsProgress;
// Accumulator write-back for the mr x nr tile: load the existing result
// packets (unaligned), fold in alpha * C*, and store back. The nr==4
// branch reuses R0 for the eighth value to save a register.
760 ResPacket R0, R1, R2, R3, R4, R5, R6;
761 ResPacket alphav = pset1<ResPacket>(alpha);
763 R0 = ploadu<ResPacket>(r0);
764 R1 = ploadu<ResPacket>(r1);
765 R2 = ploadu<ResPacket>(r2);
766 R3 = ploadu<ResPacket>(r3);
767 R4 = ploadu<ResPacket>(r0 + ResPacketSize);
768 R5 = ploadu<ResPacket>(r1 + ResPacketSize);
769 R6 = ploadu<ResPacket>(r2 + ResPacketSize);
770 traits.acc(C0, alphav, R0);
// R0 is recycled here: C0's result was presumably stored on a line
// missing from this excerpt — confirm against the full file.
772 R0 = ploadu<ResPacket>(r3 + ResPacketSize);
774 traits.acc(C1, alphav, R1);
775 traits.acc(C2, alphav, R2);
776 traits.acc(C3, alphav, R3);
777 traits.acc(C4, alphav, R4);
778 traits.acc(C5, alphav, R5);
779 traits.acc(C6, alphav, R6);
780 traits.acc(C7, alphav, R0);
785 pstoreu(r0 + ResPacketSize, R4);
786 pstoreu(r1 + ResPacketSize, R5);
787 pstoreu(r2 + ResPacketSize, R6);
788 pstoreu(r3 + ResPacketSize, R0);
// nr==2 write-back: only C0, C1, C4, C5 are live.
792 ResPacket R0, R1, R4;
793 ResPacket alphav = pset1<ResPacket>(alpha);
795 R0 = ploadu<ResPacket>(r0);
796 R1 = ploadu<ResPacket>(r1);
797 R4 = ploadu<ResPacket>(r0 + ResPacketSize);
798 traits.acc(C0, alphav, R0);
800 R0 = ploadu<ResPacket>(r1 + ResPacketSize);
801 traits.acc(C1, alphav, R1);
802 traits.acc(C4, alphav, R4);
803 traits.acc(C5, alphav, R0);
805 pstoreu(r0 + ResPacketSize, R4);
806 pstoreu(r1 + ResPacketSize, R0);
// Row tail: one extra band of LhsProgress rows (half an mr tile).
// Same structure as the main tile but with a single lhs packet (A0)
// and only accumulators C0..C3.
811 if(rows-peeled_mc>=LhsProgress)
814 const LhsScalar* blA = &blockA[i*strideA+offsetA*LhsProgress];
818 AccPacket C0, C1, C2, C3;
821 if(nr==4) traits.initAcc(C2);
822 if(nr==4) traits.initAcc(C3);
825 const RhsScalar* blB = unpackedB;
826 for(Index k=0; k<peeled_kc; k+=4)
// nr==2 branch, peeled x4, rhs double-buffered through B_0/B1.
833 traits.loadLhs(&blA[0*LhsProgress], A0);
834 traits.loadRhs(&blB[0*RhsProgress], B_0);
835 traits.loadRhs(&blB[1*RhsProgress], B1);
836 traits.madd(A0,B_0,C0,B_0);
837 traits.loadRhs(&blB[2*RhsProgress], B_0);
838 traits.madd(A0,B1,C1,B1);
839 traits.loadLhs(&blA[1*LhsProgress], A0);
840 traits.loadRhs(&blB[3*RhsProgress], B1);
841 traits.madd(A0,B_0,C0,B_0);
842 traits.loadRhs(&blB[4*RhsProgress], B_0);
843 traits.madd(A0,B1,C1,B1);
844 traits.loadLhs(&blA[2*LhsProgress], A0);
845 traits.loadRhs(&blB[5*RhsProgress], B1);
846 traits.madd(A0,B_0,C0,B_0);
847 traits.loadRhs(&blB[6*RhsProgress], B_0);
848 traits.madd(A0,B1,C1,B1);
849 traits.loadLhs(&blA[3*LhsProgress], A0);
850 traits.loadRhs(&blB[7*RhsProgress], B1);
851 traits.madd(A0,B_0,C0,B_0);
852 traits.madd(A0,B1,C1,B1);
// nr==4 branch, peeled x4, four rotating rhs registers.
857 RhsPacket B_0, B1, B2, B3;
859 traits.loadLhs(&blA[0*LhsProgress], A0);
860 traits.loadRhs(&blB[0*RhsProgress], B_0);
861 traits.loadRhs(&blB[1*RhsProgress], B1);
863 traits.madd(A0,B_0,C0,B_0);
864 traits.loadRhs(&blB[2*RhsProgress], B2);
865 traits.loadRhs(&blB[3*RhsProgress], B3);
866 traits.loadRhs(&blB[4*RhsProgress], B_0);
867 traits.madd(A0,B1,C1,B1);
868 traits.loadRhs(&blB[5*RhsProgress], B1);
869 traits.madd(A0,B2,C2,B2);
870 traits.loadRhs(&blB[6*RhsProgress], B2);
871 traits.madd(A0,B3,C3,B3);
872 traits.loadLhs(&blA[1*LhsProgress], A0);
873 traits.loadRhs(&blB[7*RhsProgress], B3);
874 traits.madd(A0,B_0,C0,B_0);
875 traits.loadRhs(&blB[8*RhsProgress], B_0);
876 traits.madd(A0,B1,C1,B1);
877 traits.loadRhs(&blB[9*RhsProgress], B1);
878 traits.madd(A0,B2,C2,B2);
879 traits.loadRhs(&blB[10*RhsProgress], B2);
880 traits.madd(A0,B3,C3,B3);
881 traits.loadLhs(&blA[2*LhsProgress], A0);
882 traits.loadRhs(&blB[11*RhsProgress], B3);
884 traits.madd(A0,B_0,C0,B_0);
885 traits.loadRhs(&blB[12*RhsProgress], B_0);
886 traits.madd(A0,B1,C1,B1);
887 traits.loadRhs(&blB[13*RhsProgress], B1);
888 traits.madd(A0,B2,C2,B2);
889 traits.loadRhs(&blB[14*RhsProgress], B2);
890 traits.madd(A0,B3,C3,B3);
892 traits.loadLhs(&blA[3*LhsProgress], A0);
893 traits.loadRhs(&blB[15*RhsProgress], B3);
894 traits.madd(A0,B_0,C0,B_0);
895 traits.madd(A0,B1,C1,B1);
896 traits.madd(A0,B2,C2,B2);
897 traits.madd(A0,B3,C3,B3);
// Advance past the 4 consumed depth steps.
900 blB += nr*4*RhsProgress;
901 blA += 4*LhsProgress;
// Depth remainder for the LhsProgress row band, then write-back of
// C0..C3 into the nr destination columns.
904 for(Index k=peeled_kc; k<depth; k++)
// nr==2 remainder step.
911 traits.loadLhs(&blA[0*LhsProgress], A0);
912 traits.loadRhs(&blB[0*RhsProgress], B_0);
913 traits.loadRhs(&blB[1*RhsProgress], B1);
914 traits.madd(A0,B_0,C0,B_0);
915 traits.madd(A0,B1,C1,B1);
// nr==4 remainder step.
920 RhsPacket B_0, B1, B2, B3;
922 traits.loadLhs(&blA[0*LhsProgress], A0);
923 traits.loadRhs(&blB[0*RhsProgress], B_0);
924 traits.loadRhs(&blB[1*RhsProgress], B1);
925 traits.loadRhs(&blB[2*RhsProgress], B2);
926 traits.loadRhs(&blB[3*RhsProgress], B3);
928 traits.madd(A0,B_0,C0,B_0);
929 traits.madd(A0,B1,C1,B1);
930 traits.madd(A0,B2,C2,B2);
931 traits.madd(A0,B3,C3,B3);
934 blB += nr*RhsProgress;
// Write-back: read-modify-write one result packet per column.
938 ResPacket R0, R1, R2, R3;
939 ResPacket alphav = pset1<ResPacket>(alpha);
941 ResScalar* r0 = &res[(j2+0)*resStride + i];
942 ResScalar* r1 = r0 + resStride;
943 ResScalar* r2 = r1 + resStride;
944 ResScalar* r3 = r2 + resStride;
946 R0 = ploadu<ResPacket>(r0);
947 R1 = ploadu<ResPacket>(r1);
948 if(nr==4) R2 = ploadu<ResPacket>(r2);
949 if(nr==4) R3 = ploadu<ResPacket>(r3);
951 traits.acc(C0, alphav, R0);
952 traits.acc(C1, alphav, R1);
953 if(nr==4) traits.acc(C2, alphav, R2);
954 if(nr==4) traits.acc(C3, alphav, R3);
// Scalar row tail for the packet columns: remaining rows (< LhsProgress)
// are processed one scalar at a time via the MADD macro, reading the rhs
// straight from blockB (no unpacking needed for scalars).
961 for(Index i=peeled_mc2; i<rows; i++)
963 const LhsScalar* blA = &blockA[i*strideA+offsetA];
967 ResScalar C0(0), C1(0), C2(0), C3(0);
969 const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
970 for(Index k=0; k<depth; k++)
// nr==2 scalar step.
980 MADD(cj,A0,B_0,C0,B_0);
981 MADD(cj,A0,B1,C1,B1);
// nr==4 scalar step.
986 RhsScalar B_0, B1, B2, B3;
994 MADD(cj,A0,B_0,C0,B_0);
995 MADD(cj,A0,B1,C1,B1);
996 MADD(cj,A0,B2,C2,B2);
997 MADD(cj,A0,B3,C3,B3);
// Scale by alpha and accumulate into the result matrix.
1002 res[(j2+0)*resStride + i] += alpha*C0;
1003 res[(j2+1)*resStride + i] += alpha*C1;
1004 if(nr==4) res[(j2+2)*resStride + i] += alpha*C2;
1005 if(nr==4) res[(j2+3)*resStride + i] += alpha*C3;
// Leftover columns (cols not a multiple of nr): process one column at a
// time, with the same row tiering — mr rows, then LhsProgress rows.
1010 for(Index j2=packet_cols; j2<cols; j2++)
1013 traits.unpackRhs(depth, &blockB[j2*strideB+offsetB], unpackedB);
1015 for(Index i=0; i<peeled_mc; i+=mr)
1017 const LhsScalar* blA = &blockA[i*strideA+offsetA*mr];
1027 const RhsScalar* blB = unpackedB;
1028 for(Index k=0; k<depth; k++)
// Two lhs packets against one rhs coefficient per depth step.
1034 traits.loadLhs(&blA[0*LhsProgress], A0);
1035 traits.loadLhs(&blA[1*LhsProgress], A1);
1036 traits.loadRhs(&blB[0*RhsProgress], B_0);
1037 traits.madd(A0,B_0,C0,T0);
1038 traits.madd(A1,B_0,C4,B_0);
1041 blA += 2*LhsProgress;
// Write-back of the two result packets for this column.
1044 ResPacket alphav = pset1<ResPacket>(alpha);
1046 ResScalar* r0 = &res[(j2+0)*resStride + i];
1048 R0 = ploadu<ResPacket>(r0);
1049 R4 = ploadu<ResPacket>(r0+ResPacketSize);
1051 traits.acc(C0, alphav, R0);
1052 traits.acc(C4, alphav, R4);
1055 pstoreu(r0+ResPacketSize, R4);
// Single LhsProgress row band for this leftover column.
1057 if(rows-peeled_mc>=LhsProgress)
1059 Index i = peeled_mc;
1060 const LhsScalar* blA = &blockA[i*strideA+offsetA*LhsProgress];
1066 const RhsScalar* blB = unpackedB;
1067 for(Index k=0; k<depth; k++)
1071 traits.loadLhs(blA, A0);
1072 traits.loadRhs(blB, B_0);
1073 traits.madd(A0, B_0, C0, B_0);
1078 ResPacket alphav = pset1<ResPacket>(alpha);
1079 ResPacket R0 = ploadu<ResPacket>(&res[(j2+0)*resStride + i]);
1080 traits.acc(C0, alphav, R0);
1081 pstoreu(&res[(j2+0)*resStride + i], R0);
// Fully scalar tail: remaining rows of the leftover columns, one
// coefficient at a time via MADD, then a single scaled accumulate.
1083 for(Index i=peeled_mc2; i<rows; i++)
1085 const LhsScalar* blA = &blockA[i*strideA+offsetA];
1091 const RhsScalar* blB = &blockB[j2*strideB+offsetB];
1092 for(Index k=0; k<depth; k++)
1094 LhsScalar A0 = blA[k];
1095 RhsScalar B_0 = blB[k];
1096 MADD(cj, A0, B_0, C0, B_0);
1098 res[(j2+0)*resStride + i] += alpha*C0;
// gemm_pack_lhs fragments: copies (and optionally conjugates) a
// rows x depth lhs panel into the contiguous blockA layout consumed by
// gebp_kernel, packing Pack1 rows at a time, then Pack2, then singles.
// PanelMode inserts offset/stride padding around each packed sub-panel.
1120 template<
typename Scalar,
typename Index,
int Pack1,
int Pack2,
int StorageOrder,
bool Conjugate,
bool PanelMode>
1121 struct gemm_pack_lhs
1124 Index stride=0, Index offset=0)
1126 typedef typename packet_traits<Scalar>::type Packet;
1127 enum { PacketSize = packet_traits<Scalar>::size };
// In panel mode the caller must reserve stride >= depth with offset room.
1130 eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
1133 const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs,lhsStride);
1135 Index peeled_mc = (rows/Pack1)*Pack1;
1136 for(Index i=0; i<peeled_mc; i+=Pack1)
1138 if(PanelMode) count += Pack1 * offset;
// Vectorized path: copy up to 4 packets of rows per depth step.
1142 for(Index k=0; k<depth; k++)
1145 if(Pack1>=1*PacketSize) A = ploadu<Packet>(&lhs(i+0*PacketSize, k));
1146 if(Pack1>=2*PacketSize) B = ploadu<Packet>(&lhs(i+1*PacketSize, k));
1147 if(Pack1>=3*PacketSize) C = ploadu<Packet>(&lhs(i+2*PacketSize, k));
1148 if(Pack1>=4*PacketSize) D = ploadu<Packet>(&lhs(i+3*PacketSize, k));
1149 if(Pack1>=1*PacketSize) {
pstore(blockA+count, cj.pconj(A)); count+=PacketSize; }
1150 if(Pack1>=2*PacketSize) {
pstore(blockA+count, cj.pconj(B)); count+=PacketSize; }
1151 if(Pack1>=3*PacketSize) {
pstore(blockA+count, cj.pconj(C)); count+=PacketSize; }
1152 if(Pack1>=4*PacketSize) {
pstore(blockA+count, cj.pconj(D)); count+=PacketSize; }
// Scalar path: copy Pack1 rows four scalars at a time, then singles.
1157 for(Index k=0; k<depth; k++)
1161 for(; w<Pack1-3; w+=4)
1163 Scalar a(cj(lhs(i+w+0, k))),
1164 b(cj(lhs(i+w+1, k))),
1165 c(cj(lhs(i+w+2, k))),
1166 d(cj(lhs(i+w+3, k)));
1167 blockA[count++] = a;
1168 blockA[count++] = b;
1169 blockA[count++] = c;
1170 blockA[count++] = d;
1174 blockA[count++] = cj(lhs(i+w, k));
1177 if(PanelMode) count += Pack1 * (stride-offset-depth);
// One Pack2-wide band if enough rows remain.
1179 if(rows-peeled_mc>=Pack2)
1181 if(PanelMode) count += Pack2*offset;
1182 for(Index k=0; k<depth; k++)
1183 for(Index w=0; w<Pack2; w++)
1184 blockA[count++] = cj(lhs(peeled_mc+w, k));
1185 if(PanelMode) count += Pack2 * (stride-offset-depth);
// Remaining rows one at a time.
1188 for(Index i=peeled_mc; i<rows; i++)
1190 if(PanelMode) count += offset;
1191 for(Index k=0; k<depth; k++)
1192 blockA[count++] = cj(lhs(i, k));
1193 if(PanelMode) count += (stride-offset-depth);
// gemm_pack_rhs (column-major rhs) fragments: interleaves nr columns of
// the rhs into blockB so the kernel reads nr coefficients of the same
// depth index contiguously. Leftover columns are copied one by one.
1205 template<
typename Scalar,
typename Index,
int nr,
bool Conjugate,
bool PanelMode>
1206 struct gemm_pack_rhs<Scalar, Index, nr,
ColMajor, Conjugate, PanelMode>
1208 typedef typename packet_traits<Scalar>::type Packet;
1209 enum { PacketSize = packet_traits<Scalar>::size };
1210 EIGEN_DONT_INLINE void operator()(Scalar* blockB,
const Scalar* rhs, Index rhsStride, Index depth, Index cols,
1211 Index stride=0, Index offset=0)
1214 eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
1216 Index packet_cols = (cols/nr) * nr;
1218 for(Index j2=0; j2<packet_cols; j2+=nr)
1221 if(PanelMode) count += nr * offset;
// Pointers to the nr source columns (b2/b3 only used when nr==4).
1222 const Scalar* b0 = &rhs[(j2+0)*rhsStride];
1223 const Scalar* b1 = &rhs[(j2+1)*rhsStride];
1224 const Scalar* b2 = &rhs[(j2+2)*rhsStride];
1225 const Scalar* b3 = &rhs[(j2+3)*rhsStride];
1226 for(Index k=0; k<depth; k++)
1228 blockB[count+0] = cj(b0[k]);
1229 blockB[count+1] = cj(b1[k]);
1230 if(nr==4) blockB[count+2] = cj(b2[k]);
1231 if(nr==4) blockB[count+3] = cj(b3[k]);
1235 if(PanelMode) count += nr * (stride-offset-depth);
// Leftover columns, copied individually.
1239 for(Index j2=packet_cols; j2<cols; ++j2)
1241 if(PanelMode) count += offset;
1242 const Scalar* b0 = &rhs[(j2+0)*rhsStride];
1243 for(Index k=0; k<depth; k++)
1245 blockB[count] = cj(b0[k]);
1248 if(PanelMode) count += (stride-offset-depth);
// gemm_pack_rhs (row-major rhs) fragments: same output layout as the
// column-major variant, but nr consecutive coefficients of a row are
// already adjacent in memory, so each depth step reads one row segment.
1254 template<
typename Scalar,
typename Index,
int nr,
bool Conjugate,
bool PanelMode>
1255 struct gemm_pack_rhs<Scalar, Index, nr,
RowMajor, Conjugate, PanelMode>
1257 enum { PacketSize = packet_traits<Scalar>::size };
1258 EIGEN_DONT_INLINE void operator()(Scalar* blockB,
const Scalar* rhs, Index rhsStride, Index depth, Index cols,
1259 Index stride=0, Index offset=0)
1262 eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
1264 Index packet_cols = (cols/nr) * nr;
1266 for(Index j2=0; j2<packet_cols; j2+=nr)
1269 if(PanelMode) count += nr * offset;
1270 for(Index k=0; k<depth; k++)
// b0 points at the nr-wide row segment for this depth index.
1272 const Scalar* b0 = &rhs[k*rhsStride + j2];
1273 blockB[count+0] = cj(b0[0]);
1274 blockB[count+1] = cj(b0[1]);
1275 if(nr==4) blockB[count+2] = cj(b0[2]);
1276 if(nr==4) blockB[count+3] = cj(b0[3]);
1280 if(PanelMode) count += nr * (stride-offset-depth);
// Leftover columns: strided walk down one column.
1283 for(Index j2=packet_cols; j2<cols; ++j2)
1285 if(PanelMode) count += offset;
1286 const Scalar* b0 = &rhs[j2];
1287 for(Index k=0; k<depth; k++)
1289 blockB[count] = cj(b0[k*rhsStride]);
1292 if(PanelMode) count += stride-offset-depth;
// Local L1/L2 cache-size declarations. NOTE(review): these two lines are
// fragments of two separate functions whose signatures are missing from
// this excerpt — presumably the cache-size query/set helpers; confirm
// against the full file.
1303 std::ptrdiff_t l1, l2;
1312 std::ptrdiff_t l1, l2;
1329 #endif // EIGEN_GENERAL_BLOCK_PANEL_H