Note: This website is archived. For up-to-date information about D projects and development, please visit wiki.dlang.org.

root/trunk/multiarray/dflat/Blas.d

Revision 127, 16.7 kB (checked in by baxissimo, 17 years ago)

.dup handling with Views is kind of weird and needs more looking into. Most views probably don't duplicate the underlying data when they are .dup'ed right now. Fixed DenseVector/Array, but others probably need this too.
This was revealed by adding an opNeg to GeMatrix and DenseVector.
Finally, the mm() templates in Blas were not distinguishable by good old brain-dead DMD.

Line 
1 /*==========================================================================
2  * Blas.d
3  *    Written in the D Programming Language (http://www.digitalmars.com/d)
4  */
5 /***************************************************************************
6  * Wrappers of BLAS operations taking dFlat types.
7  *
8  * <TODO: Description>
9  *
10  * Authors:  William V. Baxter III, OLM Digital, Inc.
11  * Date: 18 Feb 2008
12  * Copyright: (C) 2008  William Baxter, OLM Digital, Inc.
13  *            Based in part on FLENS, Copyright (c) 2007, Michael Lehn
14  *            All rights reserved.
15  * License:
16  *
17  *   Redistribution and use in source and binary forms, with or without
18  *   modification, are permitted provided that the following conditions
19  *   are met:
20  *
21  *   1) Redistributions of source code must retain the above copyright
22  *      notice, this list of conditions and the following disclaimer.
23  *   2) Redistributions in binary form must reproduce the above copyright
24  *      notice, this list of conditions and the following disclaimer in
25  *      the documentation and/or other materials provided with the
26  *      distribution.
27  *   3) Neither the name of the DFLAT development group nor the names of
28  *      its contributors may be used to endorse or promote products derived
29  *      from this software without specific prior written permission.
30  *
31  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
32  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
33  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
34  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
35  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
36  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
37  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
38  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
39  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
40  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
41  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
42  */
43 //===========================================================================
44
45 module dflat.Blas;
46
47 import dflat.Storage;
48 import dflat.DenseVector;
49 import dflat.GeneralMatrix;
50 import dflat.Range;
51 import dflat.IdMatrix;
52
53 import dfb = dflat.dflat_blas;
54
55 alias dflat.Storage.Transpose Transpose;
56
/// Maps an element type T to the type of its real part, i.e. the type of
/// the expression T.init.re.
private template ReT(T)
{
    alias typeof(T.init.re) ReT;
}
/** Private helpers that give matrices, DenseVectors and built-in arrays a
 *  uniform interface when they appear as right-hand-side operands.
 *
 *  _NumCols: number of columns of mat.  Types exposing numCols report
 *  that value; a DenseVector or a built-in array counts as one column.
 *  Any other type is rejected at compile time.
 */
private int _NumCols(MT)(ref MT mat)
{
    static if (is(typeof(mat.numCols)))
    {
        // Matrix-like type: ask the object itself
        return mat.numCols();
    }
    else static if (is(MT MTc : DenseVector!(MTc)))
    {
        // A dense vector is a single column
        return 1;
    }
    else static if (is(MT MTc: MTc[]))
    {
        // So is a plain array
        return 1;
    }
    else
    {
        static assert(false, "No NumCols for type: "~typeof(mat).stringof);
    }
}
/** Number of rows of a right-hand-side operand.
 *
 *  Types exposing numRows report that value; for a DenseVector or a
 *  built-in array the row count is its length.  Any other type is
 *  rejected at compile time.
 */
private int _NumRows(MT)(ref MT mat)
{
    static if(is(typeof(mat.numRows))) {
        return mat.numRows();
    }
    else static if (is(MT MTc : DenseVector!(MTc))) {
        return mat.length;
    }
    // Match "array of some element type MTc", exactly like the sibling
    // helpers do.  The previous test is(MT : MT[]) asked whether MT converts
    // to an array of itself, which never holds, so arrays fell through to
    // the static assert below.
    else static if (is(MT MTc : MTc[])) {
        return mat.length;
    }
    else {
        static assert(false, "No NumRows for type: "~typeof(mat).stringof);
    }
}
/** Stride between consecutive rows of a right-hand-side operand.
 *
 *  Types exposing strideRow report that value, a DenseVector reports its
 *  stride(), and a built-in array always has stride 1.  Any other type is
 *  rejected at compile time.
 */
private int _RowStride(MT)(ref MT mat)
{
    static if (is(typeof(mat.strideRow)))
    {
        return mat.strideRow;
    }
    else static if (is(MT MTc : DenseVector!(MTc)))
    {
        return mat.stride();
    }
    else static if (is(MT MTc: MTc[]))
    {
        // Built-in arrays are contiguous
        return 1;
    }
    else
    {
        static assert(false, "No RowStride for type: "~typeof(mat).stringof);
    }
}
/** Leading dimension of a right-hand-side operand.
 *
 *  Types exposing leadingDimension report that value; for a DenseVector or
 *  a built-in array the length is used instead.  Any other type is rejected
 *  at compile time.
 */
private int _LeadingDim(MT)(ref MT mat)
{
    static if (is(typeof(mat.leadingDimension)))
    {
        return mat.leadingDimension;
    }
    else static if (is(MT MTc : DenseVector!(MTc)))
    {
        return mat.length;
    }
    else static if (is(MT MTc : MTc[]))
    {
        return mat.length;
    }
    else
    {
        static assert(false, "No leading dimension for type: "~typeof(mat).stringof);
    }
}
126
127
128 //----------------------------------------------------------------------------
129 // Level 1 routines for DenseVector
130
/** Dot product of two dense vectors.
 *
 *  For float/double element types this forwards to dfb.dot (x.T * y);
 *  for every other element type it forwards to dfb.dotc (x.H * y).
 *
 *  Params:
 *      x = left operand; its length determines how many elements are read
 *      y = right operand; must have the same length as x
 *  Returns: the value produced by the underlying dfb routine.
 */
X.ElementType
dot_dv(X,Y)(/*const*/ ref DenseVector!(X) x, /*const*/ ref DenseVector!(Y) y)
{
    alias X.ElementType T;

    // Only x.length is handed to the BLAS call, so a shorter y would be read
    // past its end.  Check the sizes the same way axpy_dv does.
    assert(y.length==x.length);

    static if(is(T==float)||is(T==double)) {
        // x.T * y
        return dfb.dot(x.length, x.ptr, x.stride, y.ptr, y.stride);
    }
    else {
        // x.H * y
        return dfb.dotc(x.length, x.ptr, x.stride, y.ptr, y.stride);
    }
}
144
/** Copies dense vector x into dense vector y.
 *
 *  When the lengths differ, y's engine is first resized to x's length and
 *  x's begin index; the elements are then copied through dfb.copyT.
 */
void copy_dv(X,Y)(/*const*/ ref DenseVector!(X) x, ref DenseVector!(Y) y)
{
    // Self-copy check intentionally left disabled:
    //    assert(&x!=&y);
    bool sameLength = (y.length==x.length);
    if (!sameLength)
    {
        y.engine().resize(x.length, x.beginIndex());
    }

    dfb.copyT(x.length,
              x.ptr, x.stride(),
              y.ptr, y.stride());
}
153
/** Scales dense vector x by alpha.
 *  The work is delegated to dfb.scalT, which operates on x's storage.
 */
void scal_dv(X)(DenseVector!(X).ElementType alpha, ref DenseVector!(X) x)
{
    dfb.scalT(x.length,
              alpha,
              x.ptr, x.stride());
}
158
/** Returns the result of dfb.asum applied to dense vector x
 *  (length, data pointer and stride are taken from x).
 */
DenseVector!(I).ElementType  asum_dv(I)(/*const*/ ref DenseVector!(I) x)
{
    return dfb.asum(x.length,
                    x.ptr, x.stride());
}
163
/** Returns the index reported by dfb.amax for dense vector x, shifted by
 *  x.beginIndex() -- presumably so that it is expressed in x's own index
 *  base (confirm against dfb.amax).
 */
int amax_dv(I)(/*const*/ ref DenseVector!(I) x)
{
    auto relative = dfb.amax(x.length, x.ptr, x.stride());
    return relative + x.beginIndex();
}
168
/** Returns the index reported by dfb.amin for dense vector x, shifted by
 *  x.beginIndex() -- presumably so that it is expressed in x's own index
 *  base (confirm against dfb.amin).
 *
 *  The template parameter is named I to match its use in the parameter
 *  list; it was previously declared as T, leaving I undefined so the
 *  template could never be instantiated.
 */
int amin_dv(I)(/*const*/ ref DenseVector!(I) x)
{
    return dfb.amin(x.length, x.ptr, x.stride()) + x.beginIndex();
}
173
174
/** Vector update forwarded to dfb.axpy (y := alpha*x + y in BLAS terms).
 *
 *  Params:
 *      alpha = scale factor applied to x
 *      x     = source vector
 *      y     = vector updated in place; must have the same length as x
 */
void axpy_dv(X,Y)(DenseVector!(X).ElementType alpha,
                  /*const*/ref DenseVector!(X) x,
                  ref DenseVector!(Y) y)
{
    // No resizing here: the operands have to agree already
    assert(y.length==x.length);

    dfb.axpy(x.length,
             alpha,
             x.ptr, x.stride(),
             y.ptr, y.stride());
}
183
/** Returns the result of dfb.nrm2 for dense vector x.
 *  The return type is the real-part type of x's element type.
 */
ReT!(X.ElementType)
nrm2_dv(X)(/*const*/ref DenseVector!(X) x)
{
    return dfb.nrm2(x.length,
                    x.ptr, x.stride());
}
189
190 //----------------------------------------
191 // Level 1 for ge matrices
/** Scales every element of general matrix x by alpha.
 *  Each row view x[row,_] is handed to scal_dv in turn.
 */
void scal_ge(X)(GeMatrix!(X).ElementType alpha, ref GeMatrix!(X) x)
{
    int row = x.beginRow();
    while (row < x.endRow())
    {
        scal_dv!(X.VectorView)( alpha, x[row,_] );
        ++row;
    }
}
198
/** Copies general matrix x into general matrix y, row by row.
 *
 *  If the shapes differ, y is first resized to x's row/column counts and
 *  x's begin indices.  Rows are paired by position, so x and y may use
 *  different index bases.
 */
void copy_ge(X,Y)(/*const*/ ref GeMatrix!(X) x, ref GeMatrix!(Y) y)
{
    bool sameShape = (y.numRows()==x.numRows()) && (y.numCols()==x.numCols());
    if (!sameShape)
    {
        y.resize(x.numRows(), x.numCols(),
                 x.beginRow(), x.beginCol());
    }

    // Read y's first row only after the potential resize
    int dstRow = y.beginRow();
    for (int srcRow = x.beginRow(); srcRow < x.endRow(); ++srcRow)
    {
        y[dstRow,_] = x[srcRow,_];
        ++dstRow;
    }
}
210
/** Stores the transpose of general matrix x in general matrix y.
 *
 *  If y does not already have the transposed shape it is resized to
 *  x.numCols() by x.numRows(), with begin indices swapped accordingly.
 *  Each row of x is then assigned to the matching column of y.
 */
void copyTrans_ge(X,Y)(/*const*/ ref GeMatrix!(X) x, ref GeMatrix!(Y) y)
{
    bool shapeOk = (y.numRows()==x.numCols()) && (y.numCols()==x.numRows());
    if (!shapeOk)
    {
        y.resize(x.numCols(), x.numRows(),
                 x.beginCol(), x.beginRow());
    }

    // Read y's first column only after the potential resize
    int dstCol = y.beginCol();
    for (int srcRow = x.beginRow(); srcRow < x.endRow(); ++srcRow)
    {
        y[_,dstCol] = x[srcRow,_];
        ++dstCol;
    }
}
222
/** Stores the conjugate transpose of general matrix x in general matrix y.
 *
 *  If y does not already have the transposed shape it is resized to
 *  x.numCols() by x.numRows(), with begin indices swapped accordingly.
 *  Elements are copied one at a time through conjugate(), because BLAS
 *  offers no routine for the conjugate transpose.
 */
void copyConjugateTrans_ge(X,Y)(/*const*/ ref GeMatrix!(X) x, ref GeMatrix!(Y) y)
{
    bool shapeOk = (y.numRows()==x.numCols()) && (y.numCols()==x.numRows());
    if (!shapeOk)
    {
        y.resize(x.numCols(), x.numRows(),
                 x.beginCol(), x.beginRow());
    }

    // Row srcRow of x becomes column dstCol of y
    int dstCol = y.beginCol();
    for (int srcRow = x.beginRow(); srcRow < x.endRow(); ++srcRow)
    {
        int dstRow = y.beginRow();
        for (int srcCol = x.beginCol(); srcCol < x.endCol(); srcCol++)
        {
            y[dstRow,dstCol] = conjugate(x[srcRow,srcCol]);
            dstRow++;
        }
        ++dstCol;
    }
}
236
/** Turns general matrix B into an identity matrix of size A.dim().
 *
 *  B is resized (with begin indices 1,1) unless it is already
 *  A.dim() x A.dim().  It is then cleared and its main diagonal set to one.
 *
 *  Both dimensions are checked: comparing only numRows() left a
 *  dim x k matrix (k != dim) unresized.  B is also cleared after a resize,
 *  since nothing visible here guarantees that resize() zero-fills or
 *  discards old contents.
 */
void copy_ge(Y)(/*const*/ ref IdMatrix A, ref GeMatrix!(Y) B)
{
    if ((B.numRows()!=A.dim()) || (B.numCols()!=A.dim())) {
        B.resize(A.dim(),A.dim(),1,1);
    }
    B = 0.;
    B.diag(0) = 1.;
}
246
247
/** Matrix update y := alpha*x + y for general matrices, applied row by row
 *  through the generic axpy().
 *
 *  Params:
 *      alpha = scale factor applied to x
 *      x     = source matrix
 *      y     = matrix updated in place; must have the same shape as x
 *
 *  Rows are paired by position, so x and y may use different index bases.
 *  (Two storage aliases that were declared but never used were removed.)
 */
void axpy_ge(X,Y)(GeMatrix!(X).ElementType alpha,
               /*const*/ ref GeMatrix!(X) x,
               ref GeMatrix!(Y) y)
{
    assert(y.numRows()==x.numRows());
    assert(y.numCols()==x.numCols());

    int I = y.beginRow();
    for (int i=x.beginRow(); i<x.endRow(); ++i, ++I) {
        axpy(alpha, x[i,_], y[I,_]);
        //y[I,_] += alpha*x[i,_];
    }
}
265
266 //----------------------------------------
267 // Level 1 for sy matrices
268
269
/** Scales the stored triangle of symmetric matrix x by alpha.
 *
 *  Only the triangle selected by x.upLo() is touched.  In both cases the
 *  diagonal is included: for row i the diagonal element sits in column
 *  i+offset, where offset is the difference between the column and row
 *  index bases.
 *
 *  The lower-triangle loop previously ran while j < i+offset and therefore
 *  skipped the diagonal, unlike the upper-triangle branch.
 */
void scal_sy(X)(SyMatrix!(X).ElementType alpha, ref SyMatrix!(X) x)
{
    if (x.upLo()==StorageUpLo.Upper) {
        int offset = x.beginCol() - x.beginRow();
        for (int i=x.beginRow(); i<x.endRow(); ++i) {
            // diagonal .. last column
            for (int j=i+offset; j<x.endCol(); ++j) {
                x[i,j] *= alpha;
            }
        }
    } else {
        int offset = x.endCol() - x.endRow();
        for (int i=x.beginRow(); i<x.endRow(); ++i) {
            // first column .. diagonal (inclusive)
            for (int j=x.beginCol; j<=i+offset; ++j) {
                x[i,j] *= alpha;
            }
        }
    }
}
288
289
290 //---------------------------------------------------------------------------
291 // Generic versions
292 //   D is incapable of picking the best match from multiple templates.
293 //   You have to have one big template that inspects all the
294 //   cases one by one in a big static if - else if - else block.
295
/** Generic dot product; dispatches on the argument types.
 *
 *  Currently only DenseVector x DenseVector is supported (see dot_dv).
 *  Any other combination is rejected at compile time.
 */
X.ElementType dot(X,Y)(/*const*/ ref X x, /*const*/ ref Y y)
{
    static if (is(X Xc: DenseVector!(Xc)) && is(Y Yc: DenseVector!(Yc))) {
        return dot_dv!(Xc,Yc)(x,y);
    } else {
        // static assert("...") tests the (always true) string itself; the
        // condition has to be an explicit false for the message to fire.
        static assert(false, "Bad arg types for template dot()");
    }
}
306
/** Generic nrm2; dispatches on the argument type.
 *
 *  Currently only DenseVector is supported (see nrm2_dv).  Any other type
 *  is rejected at compile time.
 */
ReT!(X.ElementType) nrm2(X)(/*const*/ref X x)
{
    static if (is(X Xc: DenseVector!(Xc))) {
        return nrm2_dv!(Xc)(x);
    } else {
        // Explicit false condition (a bare string is always true); the
        // message now names this template instead of dot().
        static assert(false, "Bad arg types for template nrm2()");
    }
}
315
/** Generic copy of x into y; dispatches on the argument types.
 *
 *  Supported: GeMatrix -> GeMatrix (copy_ge) and
 *  DenseVector -> DenseVector (copy_dv).  Any other combination is
 *  rejected at compile time.
 */
void copy(X,Y)(/*const*/ ref X x, ref Y y)
{
    static if(is(X Xc : GeMatrix!(Xc)) && is(Y Yc : GeMatrix!(Yc))) {
        copy_ge!(Xc,Yc)(x,y);
    } else static if (is(X Xc: DenseVector!(Xc)) && is(Y Yc: DenseVector!(Yc))) {
        copy_dv!(Xc,Yc)(x,y);
    } else {
        // Explicit false condition: static assert("...") always passes, which
        // silently turned unsupported calls into no-ops.  The message now
        // names this template instead of axpy().
        static assert(false, "Bad arg types for template copy()");
    }
}
326
/** Generic transposed copy of x into y; dispatches on the argument types.
 *
 *  Currently only GeMatrix -> GeMatrix is supported (copyTrans_ge).  Any
 *  other combination is rejected at compile time.
 */
void copyTrans(X,Y)(/*const*/ ref X x, ref Y y)
{
    static if(is(X Xc : GeMatrix!(Xc)) && is(Y Yc : GeMatrix!(Yc))) {
        copyTrans_ge!(Xc,Yc)(x,y);
    } else {
        // Explicit false condition: static assert("...") always passes, which
        // silently turned unsupported calls into no-ops.
        static assert(false, "Bad arg types for template copyTrans()");
    }
}
335
/** Generic conjugate-transposed copy of x into y; dispatches on the
 *  argument types.
 *
 *  Currently only GeMatrix -> GeMatrix is supported
 *  (copyConjugateTrans_ge).  Any other combination is rejected at compile
 *  time.
 */
void copyConjugateTrans(X,Y)(/*const*/ ref X x, ref Y y)
{
    static if(is(X Xc : GeMatrix!(Xc)) && is(Y Yc : GeMatrix!(Yc))) {
        copyConjugateTrans_ge!(Xc,Yc)(x,y);
    } else {
        // Explicit false condition: static assert("...") always passes, which
        // silently turned unsupported calls into no-ops.
        static assert(false, "Bad arg types for template copyConjugateTrans()");
    }
}
344
345
/** Generic y := alpha*x + y; dispatches on the argument types.
 *
 *  Supported: GeMatrix/GeMatrix (axpy_ge) and DenseVector/DenseVector
 *  (axpy_dv).  Any other combination is rejected at compile time.
 */
void axpy(ALPHA,X,Y)(ALPHA alpha, ref X x, ref Y y)
{
    static if(is(X Xc : GeMatrix!(Xc)) && is(Y Yc : GeMatrix!(Yc))) {
        axpy_ge!(Xc,Yc)(alpha,x,y);
    } else static if (is(X Xc: DenseVector!(Xc)) && is(Y Yc: DenseVector!(Yc))) {
        axpy_dv!(Xc,Yc)(alpha,x,y);
    } else {
        // Explicit false condition: static assert("...") always passes, which
        // silently turned unsupported calls into no-ops.
        static assert(false, "Bad arg types for template axpy()");
    }
}
356
/** Generic in-place scaling x := alpha*x; dispatches on the argument type.
 *
 *  Supported: GeMatrix (scal_ge), SyMatrix (scal_sy) and DenseVector
 *  (scal_dv).  Any other type is rejected at compile time.
 */
void scal(ALPHA,X)(ALPHA alpha, ref X x)
{
    static if(is(X Xc : GeMatrix!(Xc))) {
        scal_ge!(Xc)(alpha,x);
    }
    else static if(is(X Xc : SyMatrix!(Xc))) {
        scal_sy!(Xc)(alpha,x);
    }
    else static if (is(X Xc: DenseVector!(Xc))) {
        scal_dv!(Xc)(alpha,x);
    } else {
        // Explicit false condition: static assert("...") always passes, which
        // silently turned unsupported calls into no-ops.
        static assert(false, "Bad arg types for template scal()");
    }
}
371
372
373 //- Level 2 --------------------------------------------------------------------
374
// gemv
/** General matrix-vector multiply.
 *
 *      y = alpha*A*x   + beta*y
 * OR   y = alpha*A.T*x + beta*y
 * OR   y = alpha*A.H*x + beta*y,  with A an mxn matrix
 *
 *  x and y must not share storage.  x has to match the inner dimension of
 *  op(A).  y is given the outer dimension when its length differs, which is
 *  only permitted when beta is zero.
 */
void mv(ALPHA,MA,VX,BETA,VY)(
    Transpose trans,
    ALPHA alpha, /*const*/ ref GeMatrix!(MA) A, /*const*/ ref VX x,
    BETA beta, ref VY y)
{
    assert(y.ptr !is x.ptr);

    bool noTrans = (trans==Transpose.NoTrans);

    assert(x.length==(noTrans ? A.numCols() : A.numRows()));

    // Length the result must have: rows of op(A)
    int yLength;
    if (noTrans)
    {
        yLength = A.numRows();
    }
    else
    {
        yLength = A.numCols();
    }

    assert((beta==0) || (y.length==yLength));

    if (y.length!=yLength)
    {
        y.length = yLength;
    }

    dfb.gemv(MA.order,
             trans,
             A.numRows(), A.numCols(),
             alpha,
             A.ptr, _LeadingDim(A),
             x.ptr, _RowStride(x),
             beta,
             y.ptr, _RowStride(y));
}
409
410
// symv
/** Symmetric matrix-vector multiply, forwarded to dfb.symv
 *  (y = alpha*A*x + beta*y in BLAS terms).
 *
 *  x must have length A.dim.  y is resized to A.dim when its length
 *  differs, which is only permitted when beta is zero.
 */
void mv(ALPHA, MA, VX, BETA, VY)(
    ALPHA alpha, /*const*/ ref SyMatrix!(MA) A, /*const*/ ref VX x,
    BETA beta, ref VY y)
{
    assert(x.length==A.dim());
    assert((beta==0) || (y.length==A.dim));

    bool sizeOk = (y.length==A.dim);
    if (!sizeOk)
    {
        y.resize(A.dim);
    }

    dfb.symv(MA.order,
             A.upLo,
             A.dim,
             alpha,
             A.ptr, _LeadingDim(A),
             x.ptr, _RowStride(x),
             beta,
             y.ptr, _RowStride(y));
}
431
432
433 //- Level 3 --------------------------------------------------------------------
434
// gemm
/** General matrix-matrix multiply
 *      C := alpha * transa(A) * transb(B) + beta * C
 *
 *  All three matrices must use the same storage order.  The inner
 *  dimensions of op(A) and op(B) have to agree (checked unless the NDEBUG
 *  version is set).  C is resized to m x n when its shape differs, which is
 *  only permitted when beta is zero.
 */
void mm(ALPHA,MA,MB,BETA,MC)(
    Transpose transA, Transpose transB,
    ALPHA alpha, /*const*/ ref GeMatrix!(MA) A, /*const*/ ref GeMatrix!(MB) B,
    BETA beta, ref GeMatrix!(MC) C)
{
    assert(MA.order==MB.order);
    assert(MA.order==MC.order);

    bool plainA = (transA==Transpose.NoTrans);
    bool plainB = (transB==Transpose.NoTrans);

    // op(A) is M x K, op(B) is K x N
    version(NDEBUG)
    {
    }
    else
    {
        int K_A = plainA ? A.numCols() : A.numRows();
        int K_B = plainB ? B.numRows() : B.numCols();
        assert(K_A==K_B);
    }

    // Shape of the result
    int m = plainA ? A.numRows() : A.numCols();
    int n = plainB ? B.numCols() : B.numRows();

    assert((beta==0) || (C.numRows()==m));
    assert((beta==0) || (C.numCols()==n));

    bool shapeOk = (C.numRows()==m) && (C.numCols()==n);
    if (!shapeOk)
    {
        C.resize(m,n);
    }

    dfb.gemm(MA.order,
             transA, transB,
             C.numRows(),
             C.numCols(),
             plainA ? A.numCols() : A.numRows(),   // K
             alpha,
             A.ptr, _LeadingDim(A),
             B.ptr, _LeadingDim(B),
             beta,
             C.ptr, _LeadingDim(C));
}
476
477
// symm
/** Symmetric matrix-matrix multiply, forwarded to dfb.symm.
 *
 *  With side==BlasSide.Left the symmetric matrix A multiplies B from the
 *  left (A.dim must equal B.numRows); otherwise from the right (A.dim must
 *  equal B.numCols).  All matrices must use the same storage order.  C is
 *  resized to m x n when its shape differs, which is only permitted when
 *  beta is zero.
 *
 *  The BLAS routine is reached through the dfb alias, like gemm/gemv/symv
 *  in this module: the BLAS module is imported under that name only, so an
 *  unqualified symm does not resolve.
 */
void mm(ALPHA, MA, MB, BETA, MC, dum=void)(
    BlasSide side,
    ALPHA alpha, /*const*/ ref SyMatrix!(MA) A, /*const*/ ref GeMatrix!(MB) B,
    BETA beta, ref GeMatrix!(MC) C)
{
    assert(MA.order==MB.order);
    assert(MA.order==MC.order);
    debug{
        if (side==BlasSide.Left) {
            assert(A.dim==B.numRows);
        } else {
            assert(B.numCols==A.dim);
        }
    }

    // Shape of the result
    int m = (side==BlasSide.Left) ? A.dim() : B.numRows();
    int n = (side==BlasSide.Left) ? B.numCols() : A.dim();

    assert((beta==0) || (C.numRows()==m));
    assert((beta==0) || (C.numCols()==n));

    if ((C.numRows()!=m) || (C.numCols()!=n)) {
        C.resize(m,n);
    }

    dfb.symm(MA.order,
             side, A.upLo,
             C.numRows, C.numCols,
             alpha,
             A.ptr, A.leadingDimension,
             B.ptr, B.leadingDimension,
             beta, C.ptr, C.leadingDimension);
}
512
// trmm
/** Triangular matrix-matrix multiply, forwarded to dfb.trmm; B is
 *  overwritten with the result.
 *
 *  With side==BlasSide.Left the triangular matrix A is applied from the
 *  left (A.dim must equal B.numRows); otherwise from the right (A.dim must
 *  equal B.numCols).  Both matrices must use the same storage order.
 *
 *  The BLAS routine is reached through the dfb alias, like gemm/gemv/symv
 *  in this module: the BLAS module is imported under that name only, so an
 *  unqualified trmm does not resolve.
 */
void mm(ALPHA, MA, MB)(
    BlasSide side,
    Transpose transA, ALPHA alpha, /*const*/ ref TrMatrix!(MA) A,
    ref GeMatrix!(MB) B)
{
    assert(MA.order==MB.order);
    debug{
        if (side==BlasSide.Left) {
            assert(A.dim==B.numRows);
        } else {
            assert(B.numCols==A.dim);
        }
    }

    dfb.trmm(MA.order,
             side, A.upLo, transA,
             A.unitDiag,
             B.numRows, B.numCols,
             alpha,
             A.ptr, A.leadingDimension,
             B.ptr, B.leadingDimension);
}
537
// trsv
/** Solving triangular matrix problems
 *      x := A.inv * x
 *  OR  x := A.inv.T * x
 *  OR  x := A.inv.H * x
 *
 *  Params:
 *      transpose = which form of A to use
 *      A         = triangular system matrix
 *      xb        = right-hand side on entry, overwritten with the solution;
 *                  its length must equal A.dim
 *
 *  Fixes relative to the previous body: the call referenced the undefined
 *  names transA and diag (now the transpose parameter and A.unitDiag, as
 *  used by the trmm/trsm wrappers), and the unqualified trsv referred back
 *  to this template rather than the BLAS routine behind the dfb alias.
 */
void trsv(E, S)(
    Transpose transpose, /*const*/ ref TrMatrix!(S) A, ref DenseVector!(E) xb)
{
    assert(xb.length==A.dim());

    dfb.trsv(S.order,                    // ORDER
             A.upLo,                     // UPLO
             transpose,                  // TRANSA
             A.unitDiag,                 // DIAG
             A.dim,                      // N
             A.ptr,                      // A
             A.leadingDimension,         // LDA
             xb.ptr,                     // X
             xb.stride);                 // INCX
}
559
// trsm
/** Solving triangular matrix with multiple right hand sides
 *      B := alpha * transa(A.inv) * B
 *  OR  B := alpha * B * transa(A.inv)
 *
 *  With side==BlasSide.Left, A.dim must equal B.numRows; otherwise it must
 *  equal B.numCols.  Both matrices must use the same storage order.  B is
 *  overwritten with the result.
 *
 *  The BLAS routine is called through the dfb alias: an unqualified trsm
 *  inside this template refers back to the template itself, not to the
 *  BLAS binding.
 */
void trsm(ALPHA, MA, MB)(
    BlasSide side,
    Transpose transA, ALPHA alpha, /*const*/ ref TrMatrix!(MA) A,
    ref GeMatrix!(MB) B)
{
    assert(MA.order==MB.order);
    debug {
        if (side==BlasSide.Left) {
            assert(A.dim==B.numRows);
        } else {
            assert(B.numCols==A.dim);
        }
    }

    dfb.trsm(MA.order,
             side, A.upLo, transA,
             A.unitDiag,
             B.numRows, B.numCols,
             alpha,
             A.ptr, A.leadingDimension,
             B.ptr, B.leadingDimension);
}
587
588
589
590 //--- Emacs setup ---
591 // Local Variables:
592 // c-basic-offset: 4
593 // indent-tabs-mode: nil
594 // End:
Note: See TracBrowser for help on using the browser.