root/trunk/blade/Blade.d

Revision 187, 27.2 kB (checked in by Don Clugston, 3 months ago)

Added prod(). Use .ptr to get raw data, so it works with Bill Baxter's ArrayView.

  • Property svn:eol-style set to native
Line 
1 //  Written in the D programming language 1.0
2 /**
3 * BLADE 0.6Alpha -- Basic Linear Algebra D Expressions
4 *
5 * Generate near-optimal x87/SSE2 asm code for BLAS1 basic vector operations at compile time.
6 * 32, 64 and 80 bit vectors are all supported.
7 * Uses techniques described in Agner Fog's superb Pentium optimisation manual (www.agner.org).
8 *
9 * Author:
10 *   Don Clugston.
11 * License:
12 *   Public domain.
13 *
14 * FEATURES:
15 *  - Supports any mix of vector addition, subtraction, unary minus,
16 *    multiplication by a scalar,
17 *    cumulation via dot product, sum() and prod(), and multidimensional slicing.
18 *  - Generates either x87 asm code, SSE or SSE2 asm code or pure D, depending on
19 *    the complexity of the expression, and the availability of inline asm.
20 *  - When static arrays are used, mismatches in array length are detected
21 *    at compile time.
22 *  - Error messages refer to the line of user code which generated the error.
23 *    The library never produces a torrent of undecipherable template error messages.
24 *  - A 'const folding' (actually vector/scalar folding) step is performed.
25 *
26 * SPEED/ACCURACY TRADEOFF:
27 * The tradeoff arises because IEEE floating point multiplication and addition are not associative.
28 *  Assuming that overflow and underflow do not occur:
29 *  (a*b)*c may differ from a*(b*c) in the last bit.
30 *  (a+b)+c may differ from a+(b+c) by a factor of a million or more.
31 *
32 * - Multiplication is assumed to be associative.
33 * - Addition and subtraction are not treated as associative.
34 *    Except: Addition inside a dot product or vector sum is treated as associative.
35 *
36 * FUTURE DIRECTIONS (in order of expected implementation):
37 * - nested D expressions
38 * - cumulative operations min, max
39 * - Loop unrolling for cumulative operations dot, sum, prod.
40 * - Dense matrix support.
41 * - Triangular, banded, symmetric, and sparse matrix support
42 *
43 * THEORY:
44 * An expression string of the form "A+(B*C)*D" is used, together with a tuple, the
45 * entries of which correspond to A, B, C, ...
46 * This string is converted to postfix. The postfix string is converted to
47 * a string containing asm instructions, which is then mixed into a function
48 * which accepts the tuple.
49 *
50 * COMPILER BUGS/LIMITATIONS AFFECTING THIS LIBRARY
51 * - Local arrays are not aligned to a 128-bit boundary, so use of aligned SSE is not
52 *   always possible.
53 * - Bugzilla #1125 -- structs in a tuple can't be used in asm.
54 * - Bugzilla #1382 -- CTFE strings never get deleted --> SLOOOOOW compilation. KILLER BUG.
55 * - Bugzilla #1768 -- in CTFE, arrays of arrays aren't initialized properly
56 *
57 * HISTORY:
58 * 0.1 - Used classes to make expression templates.
59 * 0.2 - Support for a wider variety of expressions. Dot product, imaginary numbers, etc.
60 * 0.3 - Based on string mixins. Most of the new features of 0.2 are gone, but SSE2 is added.
61 * 0.4 - Added D code generator. Nice error messages. Optimal parameter passing.
62 *       (passes pointers, not arrays).
63 * 0.5 - Expression simplification step. Slicing support.
64 * 0.6 - Dot product, nested expressions (asm only), intrinsics: abs, sqrt, sum.
65 * 0.7 - Intrinsics: prod
66 */
67
68 module blade.Blade;
69
70 public import blade.SyntaxTree : AbstractSyntaxTree, mixin_tupleAndSyntaxtreeof, AST, Symbol;
71 private import blade.BladeUtil : enquote, itoa;
72 private import blade.BladeRank : isStrided, exprRank;
73 private import blade.BladeSimplify : simplifySyntaxTree, RevisedExpression, remapCompounds;
74 private import blade.CodegenX86 : generateCodeForAsmX87, MAX_X87_VECTORS,
75                                  MAX_87_REALSCALARSPLUSTEMPORARIES,
76                                  generateCodeForSSE,  MAX_SSE_VECTORS;
77 private import blade.BladeVisitor: expressionContainsAssignment;
78
79 private import blade.PostfixX86 : makePostfixForX87, makePostfixForSSE;
80
81 public:
82
/**
 * CTFE, for use with mixin: produce code that evaluates the vector expression `expr`.
 *
 * The result is a `mixin(...)` statement whose argument is the code built by
 * mixin_tupleAndSyntaxtreeof for the "makeVectorCode" generator.
 * When compiled with -debug=BladeFrontEnd, a pragma(msg) is prepended which
 * prints __FILE__/__LINE__, the quoted expression and the generated code.
 */
char [] vectorize(char [] expr)
{
    char [] generator = mixin_tupleAndSyntaxtreeof("makeVectorCode", expr);
    char [] mixinStatement = "mixin(" ~ generator ~ ");";
    debug (BladeFrontEnd) {
        char [] echo = `pragma(msg, \n ~ "// " __FILE__ ~ "(" ~__LINE__.stringof[0..$-1] ~ ") `
            ~ enquote(expr) ~ `" ~ \n ~ ` ~ generator ~ "~\\n);";
        return echo ~ mixinStatement;
    } else {
        return mixinStatement;
    }
}
93
/**
 * CTFE: simplify the syntax tree, classify it with categorizeExpression, and
 * dispatch to the matching code generator (SSE/SSE2 asm, x87 asm, or plain D).
 *
 * Returns: code for mixin. If simplification reports an error, this is a
 * `static assert(0, "BLADE: ...")` carrying the message. Otherwise it is the
 * generator's assertion code, then its invocation, then a terminating ';'.
 */
char [] makeVectorCode(Types...)(AbstractSyntaxTree tree)
{
    RevisedExpression revised = simplifySyntaxTree(tree);
    if (revised.errorMessage.length > 0) {
        // Turn the simplifier's complaint into a compile-time error in the mixed-in code.
        return `static assert(0, "BLADE: ` ~ enquote(revised.errorMessage) ~ `");`;
    }

    VecExpressionType kind = categorizeExpression(revised);
    InvocationCode generated;
    if (kind == VecExpressionType.X87Expression) {
        generated = invokeX87(revised);
    } else if (kind == VecExpressionType.DExpression) {
        generated = DCodeGenerator(revised);
    } else {
        // SSE1Expression or SSE2Expression: one generator, parameterised by precision.
        generated = invokeSSE(kind == VecExpressionType.SSE2Expression, revised);
    }
    return generated.assertions ~ generated.invoker ~ ";";
}
111
/**
 * Placeholder for compounds of a different dimensionality (e.g. a dot product),
 * which may need to be calculated separately.
 *
 * Not implemented yet: `expression` is returned unchanged and `tree` is unused.
 */
char [] makeVectorCodeForDimensionalCompound(char [] expression, AbstractSyntaxTree tree)
{
    return expression; // TODO: generate separate code for the compound.
}
119
/**
 * Return type of an X87VECGEN instantiation, derived from its postfix string:
 * `real` when `expr` begins with '0' or '1', otherwise `void`.
 * (invokeSSE applies the same leading-character test to choose its return type.)
 */
template X87RetType(char [] expr) {
    static if (expr[0] == '0' || expr[0] == '1')
        alias real X87RetType;
    else
        alias void X87RetType;
}
124
125 // These functions have the complete expression encoded in the template type.
126 // One of these functions is instantiated for each expression.
127 // A difficulty is that the only way to transfer information from the CTFE code
128 // into the function, is via the template parameters. So from inside the function,
129 // we must re-assemble the type information, and use this to generate the asm code.
130
/**
 * BLAS1 kernel in SSE/SSE2 assembler; one instantiation per expression.
 *
 * As generated by invokeSSE:
 *   RetType   = void, or the scalar type (double for SSE2, float for SSE1)
 *               when the expression yields a value.
 *   expr      = postfix expression produced by makePostfixForSSE.
 *   Values    = one type per symbol: double / double* for SSE2,
 *               float / float* for SSE1 (scalars / vector pointers).
 *   veclength = length of the vector chosen by findVectorForLength.
 */
RetType SSEVECGEN(RetType, char [] expr, Values...)(int veclength, Values values)
{
    // -debug=BladeBackEnd prints the generated asm at compile time.
    debug (BladeBackEnd) { pragma(msg, generateCodeForSSE!(Values)(expr)); }

    // The generated code is the entire body; this function has no other statements.
    mixin(generateCodeForSSE!(Values)(expr));
}
140
/**
 * BLAS1 kernel in x87 assembler; one instantiation per expression.
 *
 * As generated by invokeX87:
 *   expr       = postfix expression produced by makePostfixForX87; its first
 *                character selects the return type (see X87RetType).
 *   numStrides = number of strided vectors.
 *   Values     = one type per symbol. Scalars are real, double or float.
 *                Vectors are float*, double*, real*, or BladeStrided!(T).
 *                These are followed by numStrides ints holding the strides.
 *   veclength  = length of the vector chosen by findVectorForLength.
 */
X87RetType!(expr) X87VECGEN(char [] expr, int numStrides, Values...)(int veclength, Values values)
{
    // -debug=BladeBackEnd prints the generated asm at compile time.
    debug (BladeBackEnd) { pragma(msg, generateCodeForAsmX87!(numStrides, Values)(expr)); }

    // The generated code is the entire body; this function has no other statements.
    mixin(generateCodeForAsmX87!(numStrides, Values)(expr));
}
151
152 private:
// 128-bit constant operands for SSE code, each value replicated across every lane
// (pd = 2 x 64-bit double lanes, ps = 4 x 32-bit float lanes).
// SIGNMASK has every bit except the sign bit set (AND with it clears the sign);
// SIGNBIT has only the sign bit set (OR sets the sign, XOR flips it).
static ulong[2] SSE_SIGNMASKpd = [0x7FFF_FFFF_FFFF_FFFFL, 0x7FFF_FFFF_FFFF_FFFFL];
static uint[4] SSE_SIGNMASKps = [0x7FFF_FFFF, 0x7FFF_FFFF, 0x7FFF_FFFF, 0x7FFF_FFFF];
static ulong[2] SSE_SIGNBITpd = [0x8000_0000_0000_0000L, 0x8000_0000_0000_0000L];
static uint[4] SSE_SIGNBITps = [0x8000_0000, 0x8000_0000, 0x8000_0000, 0x8000_0000];
// The value 1.0 in every lane, as IEEE 754 bit patterns:
// 1.0 (binary64) is 0x3FF0_0000_0000_0000 and 1.0f (binary32) is 0x3F80_0000.
static ulong[2] SSE_ONEpd = [0x3FF0_0000_0000_0000L, 0x3FF0_0000_0000_0000L];
static uint[4] SSE_ONEps = [0x3F80_0000, 0x3F80_0000, 0x3F80_0000, 0x3F80_0000];
161
162 private:
163
164 // ------------------------------------
165
166 // ------------------------------------
// Back end selected by categorizeExpression:
//   SSE1Expression - SSE asm; possible only if every vector is float.
//   SSE2Expression - SSE2 asm; possible only if every vector is double.
//   X87Expression  - x87 asm; any mix of real, double and float vectors.
//   DExpression    - plain D code.
// BUG: for X87, the number of temporaries should also be checked, so that the
// FPU stack cannot overflow.
enum VecExpressionType {
    SSE1Expression,
    SSE2Expression,
    X87Expression,
    DExpression
}
173
/**
 * Choose the back end for a simplified expression.
 *
 * - Without D_InlineAsm_X86 every asm back end is disabled, so the result is D.
 * - A rank-0 expression that contains an assignment always uses D.
 * - Any operand of rank above 1 forces D.
 * - SSE1 requires every vector to be float; SSE2 requires every vector to be double.
 *   Neither is used with strided vectors or with more than MAX_SSE_VECTORS vectors.
 * - X87 accepts float, double and real vectors, at most MAX_X87_VECTORS of them.
 *   It also allows at most MAX_87_REALSCALARSPLUSTEMPORARIES scalars that the x87
 *   path holds as real.
 * Preference order: SSE1, SSE2, X87, D.
 */
VecExpressionType categorizeExpression(RevisedExpression tree)
{
    bool SSE2 = true;
    bool SSE1 = true;
    bool X87 = true;
    bool strided = false; // true if any strided vector or matrix operations exist
    version (D_InlineAsm_X86) {
    } else {
        // Without an assembler, there's no chance!
        SSE2 = false;
        SSE1 = false;
        X87 = false;
    }
    int wholerank = exprRank(tree.expression, tree.rank);
    if (wholerank == 0 && expressionContainsAssignment(tree.expression)) {
        return VecExpressionType.DExpression; // Scalar assignments always use inline D.
    }

    int numvectors = 0;
    // Scalars counted against the x87 real-scalar limit. This covers original
    // scalars other than float/double/ifloat/idouble, plus every compound scalar.
    int numRealScalars = 0;
    for (int i = 0; i < tree.mapping.length; ++i) {
        char r = tree.rank[i];
        int x = tree.mapping[i] - 'A';
        if (r == '0') {
            if (x < tree.symbolTable.length) {
                char [] t = tree.symbolTable[x].element;
                if (t != "double" && t != "float" && t != "idouble" && t != "ifloat") {
                    ++numRealScalars;
                }
            } else {
                // invokeX87 always passes compound scalars (e.g. indexed elements
                // or dot products) as real, so they occupy real FPU stack slots too.
                ++numRealScalars;
            }
            // TODO: disallow asm if any non-int/not FP types are used.
            continue;
        }
        if (r > '1') return VecExpressionType.DExpression; // can only do scalars and vectors right now.
        // At this point, all compounds are an original symbol + indexing/slicing.
        int y = x; // for compounds, get the original type
        if (x >= tree.symbolTable.length) {
            y = tree.compounds[x - tree.symbolTable.length][0] - 'A';
            // Check for a stride.
            if (tree.compounds[x - tree.symbolTable.length][$-1] == ']') {
                strided |= isStrided(tree.compounds[x - tree.symbolTable.length]);
            }
        }

        char [] t = tree.symbolTable[y].element;
        if (t == "double") {
            ++numvectors;
            SSE1 = false;
        } else if (t == "float") {
            ++numvectors;
            SSE2 = false;
        } else {
            SSE1 = false;
            SSE2 = false;
            if (t == "real") { ++numvectors; }
            else X87 = false;
        }
    }
    // It's not worth doing strided operations with SSE.
    if (strided) { SSE1 = false; SSE2 = false; }
    if (numRealScalars > MAX_87_REALSCALARSPLUSTEMPORARIES) X87 = false;
    if (numvectors > MAX_X87_VECTORS) X87 = false;
    if (numvectors > MAX_SSE_VECTORS) { SSE1 = false; SSE2 = false; }
    if (SSE1) return VecExpressionType.SSE1Expression;
    if (SSE2) return VecExpressionType.SSE2Expression;
    return X87 ? VecExpressionType.X87Expression : VecExpressionType.DExpression;
}
242
243
/**
 * Handle for a strided vector passed to X87VECGEN.
 *
 * Only the address of the first element is stored; the stride is passed as a
 * separate trailing int argument. This split is mainly a workaround for compiler
 * bug #1125 (structs in a tuple can't be used in asm). Ideally the pointer and
 * the stride would be stored together.
 */
struct BladeStrided(T)
{
    T* data; // address of element 0
}
250
251 //-------------------------------------------------------
252 //                Invoker functions
253 //-------------------------------------------------------
254 // These are CTFE functions which, when mixed in, will call the BLAS template
255 // function. They ensure that all types are converted into standard
256 // simple forms, ensure that the vector lengths are equal, and pass in all
257 // of the parameters.
258
/// Pair of mixin-ready code strings returned by the invoker functions.
/// Field order matters: invokers construct it positionally as
/// InvocationCode(invoker, assertions), and makeVectorCode emits assertions first.
struct InvocationCode {
    char [] invoker;    // code that calls the generated function
    char [] assertions; // code checking preconditions (vector lengths, alignment, ...)
}
263
/**
 * Generate code which will call X87VECGEN for the expression in `tree`.
 *
 * Template arguments of the generated call, in order:
 *  1. the postfix expression (from makePostfixForX87);
 *  2. the number of strided vectors;
 *  3. one type per symbol;
 *  4. one `int` per strided vector.
 *
 * How each symbol is passed:
 *  - Scalars: real, double and float keep their type. long and ulong become
 *    real, which holds every 64-bit integer exactly. Any other scalar type
 *    becomes double. Compound scalars are passed as real.
 *  - Vectors: a pointer to the element type (`v.ptr`). A strided vector is
 *    passed as BladeStrided!(T)(&v[0]), with its stride appended after the values.
 *
 * Runtime arguments: the length of the vector chosen by findVectorForLength,
 * then the values, then the strides.
 *
 * Returns: the call in `invoker`. `assertions` holds the vector-length asserts
 * plus any asserts reported by getValueForSymbol.
 */
InvocationCode invokeX87(RevisedExpression tree)
{
    char [] assertions = assertAllVectorLengthsEqual(tree);
    char [] result = "";
    char [] stridelist = "";
    char [] alltypes = "";
    char [][] typelist;
    char [] vals;
    int numstrides = 0;

    for (int i = 0; i < tree.mapping.length; ++i) {
        char rnk = tree.rank[i];
        vals ~= ",";
        InvocationCode q = getValueForSymbol(tree.mapping[i], tree);
        char [] v = q.invoker;
        assertions ~= q.assertions;
        int x = tree.mapping[i] - 'A';
        char [] t;
        bool strided = false;
        if (x < tree.symbolTable.length) {
            // An original symbol: scalars use their type, vectors their element type.
            if (rnk == '0') t = tree.symbolTable[x].type;
            else t = tree.symbolTable[x].element;
        } else if (rnk == '0') {
            // A scalar compound: convert all compounds to real.
            // TODO (tricky): if the number is exactly representable as a double
            // or float, it could use less FPU stack space.
            t = "real";
        } else {
            // A vector compound has the element type of the original array.
            t = tree.symbolTable[tree.compounds[x - tree.symbolTable.length][0] - 'A'].element;
            // Check for a stride.
            if (tree.compounds[x - tree.symbolTable.length][$-1] == ']') {
                strided = isStrided(tree.compounds[x - tree.symbolTable.length]);
                if (strided) ++numstrides;
            }
        }

        alltypes ~= ",";
        if (rnk == '0') {
            // Convert scalars into standard form. Everything that does not need
            // real precision becomes double, since that uses less FPU stack space.
            if (t == "real" || t == "double" || t == "float") alltypes ~= t;
            else if (t == "long" || t == "ulong") alltypes ~= "real"; // goes in the type list, not in `result`
            else alltypes ~= "double";
            vals ~= v;
        } else {
            assert(t == "real" || t == "double" || t == "float", "BLADE BUG");
            if (strided) {
                alltypes ~= "BladeStrided!(" ~ t ~ ")";
                vals ~= "BladeStrided!(" ~ t ~ ")(&" ~ v ~ "[0])";
                stridelist ~= "," ~ getStrideForSymbol(tree.mapping[i], tree);
            } else {
                // For vectors, only the pointer is needed, not the length.
                alltypes ~= t ~ "*";
                vals ~= v ~ ".ptr";
            }
        }
        typelist ~= t;
    }
    char [] postfixops = makePostfixForX87(tree.expression, typelist, tree.rank);

    result ~= `X87VECGEN!("` ~ enquote(postfixops) ~ `"`;
    result ~= "," ~ itoa(numstrides);
    result ~= alltypes;
    for (int i = 0; i < numstrides; ++i) result ~= ",int";
    result ~= ")(";
    int firstVector = findVectorForLength(tree);
    return InvocationCode(result ~ getValueForSymbol(tree.mapping[firstVector], tree).invoker ~ ".length"
        ~ vals ~ stridelist ~ ")", assertions);
}
342
/// Assertion code for a call: equal vector lengths and, when `checkAlignment`
/// is set, 16-byte alignment of every vector (length checks come first).
char [] generateAsserts(RevisedExpression tree, bool checkAlignment)
{
    char [] lengthChecks = assertAllVectorLengthsEqual(tree);
    if (!checkAlignment) return lengthChecks;
    return lengthChecks ~ assertAllVectorsAlign128(tree);
}
349
/**
 * Generate code which will call SSEVECGEN for the expression in `tree`.
 *
 * `SSE2` selects double precision; otherwise float (SSE1) is used. Every scalar
 * is passed as that type and every vector as a pointer to it (`.ptr`). The
 * return type is that scalar type when the postfix string begins with '0' or
 * '1', and void otherwise.
 *
 * Runtime arguments: the length of the vector chosen by findVectorForLength,
 * then the values.
 *
 * Returns: the call in `invoker`. `assertions` holds the length and alignment
 * asserts plus any asserts reported by getValueForSymbol.
 */
InvocationCode invokeSSE(bool SSE2, RevisedExpression tree)
{
    char [] assertions = generateAsserts(tree, true);

    char [] postfix = makePostfixForSSE(tree.expression, tree.rank);
    char [] scalarType = SSE2 ? "double" : "float";
    char [] retType = "void";
    if (postfix[0] == '0' || postfix[0] == '1') retType = scalarType;

    char [] typeArgs = "";
    char [] valueArgs = "";
    for (int i = 0; i < tree.mapping.length; ++i) {
        char rnk = tree.rank[i];
        typeArgs ~= "," ~ scalarType;
        if (rnk != '0') typeArgs ~= "*";

        InvocationCode operand = getValueForSymbol(tree.mapping[i], tree);
        valueArgs ~= "," ~ operand.invoker;
        assertions ~= operand.assertions;
        if (rnk == '1') valueArgs ~= ".ptr"; // vectors: only the pointer is needed
    }

    InvocationCode lengthSource = getValueForSymbol(tree.mapping[findVectorForLength(tree)], tree);
    char [] call = "SSEVECGEN!(" ~ retType ~ `,"` ~ enquote(postfix) ~ `"` ~ typeArgs ~ ")("
        ~ lengthSource.invoker ~ ".length" ~ valueArgs ~ ")";
    return InvocationCode(call, assertions);
}
383
/**
 * Generate runtime asserts checking that every vector operand has the same
 * length as the vector chosen by findVectorForLength.
 *
 * TODO: when both lengths are known at compile time (see arrayLengthIsStatic),
 * emit a `static assert` instead, so the mismatch is reported at compile time.
 */
char [] assertAllVectorLengthsEqual(RevisedExpression tree)
{
    int reference = findVectorForLength(tree);
    char [] checks = "";
    for (int i = 0; i < tree.mapping.length; ++i) {
        if (tree.rank[i] != '1' || i == reference) continue;
        checks ~= "assert("
            ~ getDimensionLengthForSymbol(tree.mapping[i], tree, 1)
            ~ "==" ~ getDimensionLengthForSymbol(tree.mapping[reference], tree, 1)
            ~ ", `Vector length mismatch`);\n";
    }
    return checks;
}
409
/// Generate runtime asserts that every vector operand's `.ptr` is 16-byte
/// (128-bit) aligned, for the aligned SSE code path.
char [] assertAllVectorsAlign128(RevisedExpression tree)
{
    char [] checks = "";
    for (int i = 0; i < tree.mapping.length; ++i) {
        if (tree.rank[i] != '1') continue;
        char [] vec = getValueForSymbol(tree.mapping[i], tree).invoker;
        checks ~= "assert( (cast(size_t)(" ~ vec
            ~ ".ptr)& 0x0F) == 0, `SSE Vector misalignment: " ~ vec ~ "`);\n";
    }
    return checks;
}
421
/**
 * Return true if `type` (a type as a string, e.g. "float[4]") is an array
 * whose outermost length is known at compile time.
 *
 * Only the final bracket pair is examined. Its contents must begin with a digit,
 * which a static array dimension does. An associative-array key type such as
 * "double[int]" or "T[char[]]" never starts with a digit, so it is rejected.
 */
bool arrayLengthIsStatic(char [] type)
{
    if (type.length < 3 || type[$-1] != ']') return false;
    // Walk back to the '[' that matches the final ']' (key types may nest brackets).
    int depth = 0;
    for (int k = type.length - 1; k >= 0; --k) {
        if (type[k] == ']') {
            ++depth;
        } else if (type[k] == '[') {
            --depth;
            if (depth == 0) {
                // "[]" is a dynamic array; "[<digits>...]" is a static one.
                return k + 1 < type.length - 1 && type[k+1] >= '0' && type[k+1] <= '9';
            }
        }
    }
    return false; // unbalanced brackets
}
429
/**
 * Return the index (into tree.mapping) of the vector whose length defines the
 * length of the whole expression. Preference order:
 *  1. the first original symbol whose type has a static length;
 *  2. otherwise, the last non-strided vector, i.e. an original dynamic array
 *     or an indexed/sliced compound that is not strided;
 *  3. otherwise, the last vector of any kind (e.g. a strided one).
 * Returns 0 if there are no rank-1 operands at all.
 */
int findVectorForLength(RevisedExpression tree)
{
    int lastPlain = -1; // last non-strided vector
    int lastAny = 0;    // last vector of any kind
    for (int i = 0; i < tree.mapping.length; ++i) {
        if (tree.rank[i] != '1') continue;
        lastAny = i;
        int sym = tree.mapping[i] - 'A';
        if (sym < tree.symbolTable.length) {
            if (arrayLengthIsStatic(tree.symbolTable[sym].type)) return i;
            lastPlain = i;
        } else {
            char [] comp = tree.compounds[sym - tree.symbolTable.length];
            if (comp[$-1] == ']' && !isStrided(comp)) lastPlain = i;
        }
    }
    if (lastPlain >= 0) return lastPlain;
    return lastAny;
}
456
/// True if `s` contains a '$' character (the array length inside D index expressions).
bool hasDollar(char [] s)
{
    for (int k = 0; k < s.length; ++k) {
        if (s[k] == '$') return true;
    }
    return false;
}
462
463 char [] getDimensionLengthForSymbol(char c, RevisedExpression tree, int dimension)
464 {
465     int numSlicesRemaining = 1;
466     assert(dimension == 1);
467     char [] v = "";
468     // is it an original symbol?
469     if (c-'A'<tree.symbolTable.length) {
470         v = tree.symbolTable[c-'A'].value;
471         return v ~ ".length";
472     } else {  // else it's a compound or an indexed array
473         char [] comp = tree.compounds[c-'A'-tree.symbolTable.length];
474
475         if (comp[$-1]!=']') { // simple compound expression
476             foreach(d; comp) {
477                 if (d=='{') assert(0, "BLADE ICE");
478                 if (d>='A' && d<='Z') v ~= tree.symbolTable[d-'A'].value;
479                 else v ~= d;
480             }
481             return v ~ ".length";
482         } else {
483             // indexed array, possibly involving slicing
484             int numbrack=0;
485             bool hasSliced=false;
486             // it's easier if we go backwards
487             // Replace the last slice [a..b] operation with [a+firstIndexExpr]
488             // or if no slices exist, append [firstIndexExpr] to the end.
489             int numbracks = 0;
490             bool isSlice = false;
491             char [] nextIndex;
492             char [] sliceTo;
493
494             for (int k = comp.length-1;k>=1; --k) {
495                 char d = comp[k];
496                 if (d == ']') { ++numbracks; }
497                 if (d == '[') { --numbracks; }
498
499                 if (d == ']' && numbracks == 1) { nextIndex = ""; }
500                 else if (numbracks == 1 && comp[k-1..k+1]=="..") {
501                     isSlice = true;
502                     sliceTo = nextIndex;
503                     nextIndex = "";
504                     --k;
505                 } else if ((d == '[' && numbracks==0) || (d==',' && numbracks==1)) {
506                     if (isSlice && numSlicesRemaining>0) {
507                         if (numSlicesRemaining==1) {
508                             if (!hasDollar(nextIndex) && !hasDollar(sliceTo)) {
509                                 return "(" ~ sliceTo ~ "-" ~ nextIndex ~ ")";
510                             }
511                         }
512                         v = "[" ~ nextIndex ~ ".." ~ sliceTo
513                          ~ "].length";
514                         --numSlicesRemaining;
515                     } else {
516                         if (isSlice)
517                             v = "[" ~ nextIndex ~ ".." ~ sliceTo ~ "]" ~ v;
518                         else v = "[" ~ nextIndex ~ "]" ~ v;
519                     }
520                     nextIndex = "";
521                     isSlice = false;
522                 } else {
523                     if (d>='A' && d<='Z') nextIndex = tree.symbolTable[d-'A'].value ~ nextIndex;
524                     else nextIndex = d ~ nextIndex;
525                 }
526             }
527             if (numSlicesRemaining>0) v ~= ".length";
528             return tree.symbolTable[comp[0]-'A'].