root/trunk/blade/CodegenX86.d

Revision 187, 26.0 kB (checked in by Don Clugston, 2 months ago)

Added prod(). Use .ptr to get raw data, so it works with Bill Baxter's ArrayView.

  • Property svn:eol-style set to native
Line 
1 //  Written in the D programming language 1.0
2 /**
3 * BLADE Alpha -- Basic Linear Algebra D Expressions
4 *
5 * Generate near-optimal x87/SSE/SSE2 asm code for BLAS1 basic vector operations
6 * at compile time.
7 * 32, 64 and 80 bit vectors are all supported.
8 * Uses techniques described in Agner Fog's superb Pentium optimisation manual (www.agner.org).
9 *
10 * Author:
11 *   Don Clugston.
12 * License:
13 *   Public domain.
14 *
15 * FEATURES:
16 *  - Supports any mix of vector addition, subtraction, dot product, and multiplication
17 *    by a scalar, with strided vector access.
18 *  - Generates either x87, SSE, or SSE2 asm code.
19 *  - If x87 code is generated, 80-bit precision is used whenever possible.
20 *  - Supports mixed-length operations (eg, real[] + double[] + float[]).
21 *
22 * BUGS/ FUTURE DIRECTIONS:
23 *  None of these support matrix operations.
24 * X87:
25 *  - Not optimal for the case of multiple real vectors (they could share a counter).
26 *  - Not optimal for the case where all vectors are 80-bit (two counters are used, but only one is required).
27 *  - Doesn't take full advantage of length being known at compile time (loop unrolling
28 *     is possible).
29 *  - Doesn't use EBP register -- this would allow an extra vector in expressions.
30 *   (to do this, need naked asm with no stack frame).
31 * SSE/SSE2:
32 *  - SSE functions don't support unaligned data. Need to generate separate code
33 *    for that case (NOTE: probably only worth doing for small expressions).
34 *
35 * THEORY:
36 * An expression string of the form "A+(B*C)*D" is used, together with a tuple, the
37 * entries of which correspond to A, B, C, ...
38 * This string is converted to postfix. The postfix string is converted to
39 * a string containing x87 asm, which is then mixed into a function which accepts the tuple.
40 */
41
42 /*
43  * POTENTIAL FROM RECENT INSTRUCTION SETS:
44  * SSE5(AMD): fmaddpd can dramatically improve both performance and accuracy.
45  * SSE4(Intel): dppd has some limited use.
46  */
47
48 module blade.CodegenX86;
49 private import blade.BladeUtil;
50 private import blade.PostfixX86;
51
52 private:
53
54 // --------------
55 // Ranklist functions
56
// Count how many tuple entries are vectors, i.e. the number of '1'
// characters in the ranklist ('0' marks a scalar).
int countVectors(char[] ranklist)
{
    int total = 0;
    foreach (char rank; ranklist) {
        if (rank == '1')
            total++;
    }
    return total;
}
66
// Position of tuple entry 'var' ('A' = entry 0) among the vectors only:
// the number of vector entries ('1' in ranklist) that precede it.
// The result is used to pick the register holding that vector's pointer.
int vectorNum(char [] ranklist, char var)
{
    int limit = var - 'A';
    int count = 0;
    int pos = 0;
    while (pos < limit) {
        if (ranklist[pos] == '1')
            ++count;
        ++pos;
    }
    return count;
}
75
// Number of strided vectors (ranklist '1' and stridelist '1') that precede
// tuple entry 'var' ('A' = entry 0).
int strideVectorNum(char [] ranklist, char [] stridelist, char var)
{
    int limit = var - 'A';
    int count = 0;
    int pos = 0;
    while (pos < limit) {
        bool isVector  = ranklist[pos] == '1';
        bool isStrided = stridelist[pos] == '1';
        if (isVector && isStrided)
            ++count;
        ++pos;
    }
    return count;
}
84
// Number of scalars ('0' in ranklist) that precede tuple entry 'var'
// ('A' = entry 0).
int scalarNum(char [] ranklist, char var)
{
    int limit = var - 'A';
    int count = 0;
    int pos = 0;
    while (pos < limit) {
        if (ranklist[pos] == '0')
            ++count;
        ++pos;
    }
    return count;
}
93
// Number of real (80-bit) scalars that precede tuple entry 'var'
// ('A' = entry 0). Real scalars are kept on the FPU stack, so this is the
// index of 'var' among them.
int realScalarNum(char [][] typelist, char [] ranklist, char var)
{
    int limit = var - 'A';
    int count = 0;
    int pos = 0;
    while (pos < limit) {
        bool isScalar = ranklist[pos] == '0';
        if (isScalar && typelist[pos] == "real")
            ++count;
        ++pos;
    }
    return count;
}
102 private:
103 // -------------------------------
104 //   Mixins to generate x87 ASM code
105 // -------------------------------
106
/// True if the character is one of the operation characters '+', '*', '-',
/// '_' or '='. All other characters return false. Note that this includes the
/// unary markers 'a', 'n', 'q' and ','; the x87 generator handles those separately.
bool isInstruction(char op)
{
    switch (op) {
        case '+', '*', '-', '_', '=':
            return true;
        default:
            return false;
    }
}
112
/// Count the number of temporaries which occur in the postfix expression.
/// A temporary occurs whenever two operands appear back to back. That is, a
/// second value is loaded before any operation has been performed on the first.
int countTemporaries(char [] postfix)
{
    int temps = 0;
    for (int pos = 1; pos < postfix.length; pos++) {
        bool prevIsOperand = !isInstruction(postfix[pos - 1]);
        bool curIsOperand  = !isInstruction(postfix[pos]);
        if (prevIsOperand && curIsOperand)
            temps++;
    }
    return temps;
}
124
125
/// The maximum number of simultaneous temporary values in the postfix expression.
/// Two operands in a row open a temporary; two operations in a row close one.
int maxActiveTemporaries(char [] postfix)
{
    int live = 0;
    int peak = 0;
    for (int pos = 1; pos < postfix.length; pos++) {
        bool prevIsOp = isInstruction(postfix[pos - 1]);
        bool curIsOp  = isInstruction(postfix[pos]);
        if (!prevIsOp && !curIsOp) {
            live++;
        } else if (prevIsOp && curIsOp) {
            live--;
        }
        if (live > peak)
            peak = live;
    }
    return peak;
}
139
unittest {
    // Three back-to-back operand pairs (AB, BC, DE), of which at most two
    // are counted as alive at the same time.
    assert(countTemporaries("AB*BC*+DE*+")==3);
    assert(maxActiveTemporaries("AB*BC*+DE*+")==2);
    // Boundary cases: empty expression, single operand, single operation.
    assert(countTemporaries("")==0);
    assert(maxActiveTemporaries("")==0);
    assert(countTemporaries("A")==0);
    assert(maxActiveTemporaries("A")==0);
    assert(countTemporaries("AB+")==1);
    assert(maxActiveTemporaries("AB+")==1);
}
144
// Asm size qualifier ("real ptr ", "double ptr ", "float ptr ") for an
// element type name. Any other name fails an assertion whose message is the name.
char [] operandSize(char [] typestr)
{
    if (typestr == "real")   return "real ptr ";
    if (typestr == "double") return "double ptr ";
    if (typestr == "float")  return "float ptr ";
    assert(0, typestr);
}
155
// Map a postfix operator character to its x87 mnemonic. '_' maps to fsubr
// (reversed subtraction). There is no '/' entry. The x87 generator appends "p"
// when it wants the popping form (e.g. "faddp ST(1), ST").
// NOTE(review): the "[]" on the first value appears to force the literal's value
// type to char[] (rather than a fixed-length char array) so that strings of
// different lengths can share one associative array.
char [][char] opToX87() {
    return ['*':"fmul"[], '+': "fadd", '-': "fsub", '_': "fsubr"]; }
158
// Map a postfix operator character to the stem of the packed SSE mnemonic
// ("mulp", "addp", ...). The caller presumably appends the precision letter
// ("d" or "s") to form e.g. mulpd / mulps. TODO confirm in generateCodeForSSEImpl.
char [][char] opToSSE() {
    return ['*':"mulp"[], '+': "addp", '-': "subp", '/': "divp"]; }
161
// Map a postfix operator character to the stem of the scalar (single-element)
// SSE mnemonic ("muls", "adds", ...). The caller presumably appends the
// precision letter ("d" or "s") to form e.g. mulsd / mulss.
char [][char] opToSSESingle() {
    return ['*':"muls"[], '+': "adds", '-': "subs", '/': "divs"]; }
164
// Size in bytes of the D 'real' type on this platform, as an asm operand string.
// It is the step used to advance the pointer of a real[] vector.
// Any other size is rejected here with a clear message, rather than leaving
// REALSIZE undefined and failing later at its first use.
static if (real.sizeof==10)      const char [] REALSIZE = "10";
else static if (real.sizeof==12) const char [] REALSIZE = "12";
else static if (real.sizeof==16) const char [] REALSIZE = "16";
else static assert(0, "blade.CodegenX86: unsupported real.sizeof (expected 10, 12 or 16)");
168
// Element size in bytes for the given element type name, as an asm operand string.
// For float and double it is the scale in [reg + size*EAX] addressing. For real
// it is the byte offset used when stepping relative to the current element.
// Unknown type names fail an assertion carrying the name, matching operandSize().
// Previously the switch had no default and fell off the end of the function.
char [] vectorSize(char [] typestr)
{
    switch (typestr) {
        case "double": return "8";
        case "float":  return "4";
        case "real":   return REALSIZE;
        default:
            assert(0, typestr);
    }
}
177
178
// Registers that hold the vector pointers, in order of use. EAX is not in this
// list because it is always the loop counter / index register.
// The first two vectors use the scratch registers ECX and EDX. If there are
// more than 2 vectors, EBX, ESI and EDI are used, which need to be pushed and
// popped (see pushRegisters / popRegisters).
// TODO: Finally, use the frame register EBP.
const char [][5] vectorRegister = ["ECX", "EDX", "EBX", "ESI", "EDI"];
184
185 public:
186
// Maximum number of vectors allowable in an expression for this code generator:
// one per entry of vectorRegister.
const int MAX_X87_VECTORS = vectorRegister.length;
const int MAX_SSE_VECTORS = vectorRegister.length;
// Maximum number of real scalars plus temporaries allowable in an expression.
// (max # temporaries + max # real scalars) must be <= 8, otherwise the FPU
// stack will overflow.
const int MAX_87_REALSCALARSPLUSTEMPORARIES = 8;
194
195 private:
196
// Create code to push all used vector registers.
// Only registers from index 2 on (EBX, ESI, EDI) need preserving. The first
// two vectors live in ECX and EDX, which are not saved.
char [] pushRegisters(int numVectors)
{
    char [] code = "";
    int reg = 2;
    while (reg < numVectors) {
        code ~= " push " ~ vectorRegister[reg] ~ ";";
        ++reg;
    }
    code ~= "\n";
    return code;
}
204
// Create code to pop the vector registers saved by pushRegisters, in reverse order.
// NOTE(review): the output starts with ";", presumably so that a label placed
// directly before it (L3 in the generated x87 code) always has a statement to
// attach to, even when nothing needs popping.
char [] popRegisters(int numVectors)
{
    char [] code = ";  ";
    int reg = numVectors - 1;
    while (reg >= 2) {
        code ~= "pop " ~ vectorRegister[reg] ~ "; ";
        --reg;
    }
    return code ~ "\n";
}
212
// Memory operand for element i of vector 'var', where i is held in EAX.
// Real and strided vectors are addressed through the bare pointer register,
// which the generated loop advances every iteration. Float and double vectors
// are addressed as base + elementSize*EAX.
char [] indexedVector(char [][] typelist, char [] ranklist, char [] stridelist, char var)
{
    int idx = var - 'A';
    char [] elemType = typelist[idx];
    char [] reg = vectorRegister[vectorNum(ranklist, var)];
    if (elemType == "real")
        return " real ptr [" ~ reg ~ "]";
    char [] sized = operandSize(elemType);
    if (stridelist[idx] == '1')
        return sized ~ "[" ~ reg ~ "]";
    return sized ~ "[" ~ reg ~ " + " ~ vectorSize(elemType) ~ "*EAX]";
}
221
// Memory operand for element i-1 of vector 'var' (i held in EAX), i.e. one
// element size before the current position. Strided vectors are not
// special-cased here.
char [] indexedVectorPrev(char [][] typelist, char [] ranklist, char var)
{
    char [] elemType = typelist[var - 'A'];
    char [] back = " - " ~ vectorSize(elemType);
    char [] reg = vectorRegister[vectorNum(ranklist, var)];
    if (elemType == "real")
        return " real ptr [" ~ reg ~ back ~ "]";
    return operandSize(elemType) ~ "[" ~ reg ~ " + " ~ vectorSize(elemType) ~ "*EAX" ~ back ~ "]";
}
230
// SSE memory operand [reg + vecsize*EAX] for vector 'var' at the current index.
char [] indexedSSEVector(char [] ranklist, char var, char [] vecsize)
{
    char [] reg = vectorRegister[vectorNum(ranklist, var)];
    return "[" ~ reg ~ " + " ~ vecsize ~ "*EAX]";
}
235
// Like indexedSSEVector, but 16 bytes further on (the next 128-bit group).
char [] indexedSSENext(char [] ranklist, char var, char [] vecsize)
{
    char [] reg = vectorRegister[vectorNum(ranklist, var)];
    return "[" ~ reg ~ " + " ~ vecsize ~ "*EAX+16]";
}
240
// Memory operand for vector 'var', 'stride' elements before the current
// position (a byte offset of -elementSize*stride).
char [] indexedVectorWithStride(char [][] typelist, char [] ranklist, char var, int stride)
{
    char [] elemType = typelist[var - 'A'];
    char [] elemSize = vectorSize(elemType);
    char [] back = " - " ~ elemSize ~ "*" ~ itoa(stride);
    char [] reg = vectorRegister[vectorNum(ranklist, var)];
    if (elemType == "real")
        return " real ptr [" ~ reg ~ back ~ "]";
    return operandSize(elemType) ~ "[" ~ reg ~ " + " ~ elemSize ~ "*EAX" ~ back ~ "]";
}
248
// Generate code that pops n values off the FPU stack. Pairs are removed with a
// single fcompp, which pops twice; its comparison result is ignored. An odd
// leftover is removed with fstp. Non-positive n yields no code.
char [] discardFromStack(int n)
{
    char [] code = "";
    for (int k = 0; k < n / 2; ++k)
        code ~= "  fcompp ST(0), ST;\n";
    if (n % 2 == 1)
        code ~= "  fstp ST(0), ST;\n";
    return code;
}
260
261 public:
262
/** Generate asm code which is optimal for x87 CPUs without SSE2.
 (Pentium, PMMX, PII, PIII). It is also optimal for recent x86 CPUs
 where vector sizes are mixed.

 There are two cases:
 (A) DAXPY-style loops, where every element is independent of the other indices;
 (B) DDOT-style loops, where the result for every element is accumulated.

 For cumulative loops, best performance is achieved with loop unrolling and
 multiple accumulators, in order to break dependency chains.

The key optimisation rules for DAXPY loops are:
 1. keep the loop overhead to one clock cycle if possible.
 2. (FMUL latency) don't use the result of a multiply immediately
Techniques to address these are:
 1. Use EAX as a counter and index variable, which begins negative and counts UP to zero.
    Combine counters for all packed doubles and floats into this single counter.
 2. The latency of fmul is avoided by swapping fadd/fsub with fmul whenever possible.

The generated code is of the form:
----
 load scalars onto FPU stack
 load vector pointers into ECX, EDX, EBX, ... (EAX is the counter)
L1:
 calculate result into ST(0)
 increment pointers
 goto L1 if not done
 pop scalars off FPU stack
----

 Params:
   numStrides = number of trailing entries of Values that are strides rather
                than operands.
   Values     = operand types. A type for which T[0] is valid is a vector; one
                with a .data member is a strided vector; anything else is a scalar.
   postfixOperations = the postfix expression to compile.
*/
char [] generateCodeForAsmX87(int numStrides, Values...)(char [] postfixOperations)
{
    // Because of compiler bug #1125, no structs can be stored in the 'values' tuple.
    // Thus, lengths and strides must be stored separately from vector pointers.
    char [] ranks;     // '1' = vector, '0' = scalar
    char [][] types;   // element type name of each operand
    char [] strides;   // '1' = strided, '0' = unstrided
    foreach (T; Values[0 .. $ - numStrides]) {
        static if (is(typeof(T[0]))) {
            // Unstrided vector.
            ranks ~= "1";
            strides ~= "0";
            types ~= typeof(T[0]).stringof;
        } else static if (is(typeof(T.data))) {
            // Strided vector: elements are reached through .data.
            ranks ~= "1";
            strides ~= "1";
            types ~= typeof(T.data[0]).stringof;
        } else {
            // Scalar.
            ranks ~= "0";
            strides ~= "0";
            types ~= T.stringof;
        }
    }
    return generateCodeForAsmX87Impl(ranks, types, strides, postfixOperations);
}
318
319 private:
// This is split off from the template to make code coverage easier.
/*
 * Build the x87 asm block (as a string to be mixed in) for one postfix expression.
 *
 * Params:
 *  ranklist   = per tuple entry: '1' = vector, '0' = scalar.
 *  typelist   = per tuple entry: element type name ("real", "double" or "float").
 *  stridelist = per tuple entry: '1' = strided vector, '0' = otherwise.
 *  operations = postfix expression. A leading '0' (dot product / sum) or '1'
 *               (product) makes the loop cumulative: the running result is
 *               kept in ST(0) and is left there when the asm block ends.
 *
 * The generated text refers to 'values' (the operand tuple, followed by the
 * strides) and 'veclength', which must be in scope where it is mixed in.
 */
char [] generateCodeForAsmX87Impl(char [] ranklist, char [][] typelist, char [] stridelist, char [] operations)
{
    char [] result = "";
    char [] incrementRealVectors = "";

    result ~= "// Operation : " ~ operations ~ "\n";

    // Real and strided vectors cannot be addressed as base + size*EAX, so their
    // pointer registers are advanced explicitly at the end of every iteration.
    // The stride of the k-th strided vector is stored in the tuple after the
    // operands, at values[ranklist.length + k] (avoid bug #1125).
    int vecnum = 0;
    int stridecount = 0;
    for (int i = 0; i < ranklist.length; ++i) {
        if (ranklist[i] == '1') {
            if (stridelist[i] == '1') {
                incrementRealVectors ~= "  add " ~ vectorRegister[vecnum] ~ ", values[" ~ itoa(ranklist.length+stridecount) ~ "];\n";
                ++stridecount;
            } else if (typelist[i] == "real") {
                incrementRealVectors ~= "  add " ~ vectorRegister[vecnum] ~ ", " ~ REALSIZE ~ ";\n";
            }
            ++vecnum;
        }
    }

    int numScalarsOnStack = 0;

    result ~= "\nasm {\n" ~ pushRegisters(vecnum);
    // EAX temporarily holds the length, so that the lea below can point each
    // unstrided float/double vector register at the END of its vector. The loop
    // counter then runs from -length up to 0.
    result ~= "  mov EAX, veclength;\n";

    // Load all the vector pointers into registers, and push all the real scalars
    // onto the FPU stack (float/double scalars are used directly from memory).
    int numvecs = 0;
    for (int i = 0; i < ranklist.length; ++i) {
        if (ranklist[i] == '1') {
            if (typelist[i] == "real" || stridelist[i] == '1') {
                // Walked by pointer increment: start at the first element.
                result ~= "  mov " ~ vectorRegister[numvecs] ~ ", values[" ~ itoa(i) ~ "];";
            } else {
                // Indexed by the negative counter: point one past the last element.
                result ~= "  lea " ~ vectorRegister[numvecs]
                    ~ ", [" ~ vectorSize(typelist[i]) ~ "*EAX];   "
                    ~ "  add " ~ vectorRegister[numvecs] ~ ", values[" ~ itoa(i) ~ "];";
            }
            result ~= "  //" ~ cast(char)('A'+i) ~ "\n";
            ++numvecs;
        } else if (typelist[i] == "real") {
            result ~= "  fld real ptr values[" ~ itoa(i) ~ "];";
            ++numScalarsOnStack;
            result ~= "  //" ~ cast(char)('A'+i) ~ "\n";
        }
    }
    int done = 0;

    // We need to keep track of how many things are on the FPU stack.
    // Every time something is pushed, the indices of our variables change!
    int numOnStack = 0; // How much of the FP stack is being used by the loop body?

    bool isCumulative = (operations[0] == '0' || operations[0] == '1');
    if (operations[0] == '0') {
        result ~= "  fldz;\n"; // dot product: accumulator starts at 0
        ++numOnStack;
        done = 1;
    } else if (operations[0] == '1') {
        result ~= "  fld1;\n"; // prod: accumulator starts at 1
        ++numOnStack;
        done = 1;
    }
    result ~= "  xor EAX, EAX; \n"
        ~ "  sub EAX, veclength; // counter=-length\n"
        ~ "  jz short L3; // test for length==0\n";

    // Construct the main body of the loop (the main body does not include
    // the final storage instruction, because of the FST latency).
    char [] mainbody = "";

    while (done < operations.length) {
        char op = operations[done];
        if (isInstruction(op)) {
            // Perform an arithmetic operation on the top two FPU stack items.
            mainbody ~= "  " ~ opToX87[op] ~ "p ST(1), ST;  //" ~ op ~ "\n";
            ++done;
            numOnStack--;
        } else if (op == 'a') {
            mainbody ~= "  fabs;\n";
            ++done;
        } else if (op == 'n') {
            mainbody ~= "  fchs;\n";
            ++done;
        } else if (op == 'q') {
            mainbody ~= "  fsqrt;\n";
            ++done;
        } else if (!isInstruction(operations[done+1])) {
            // Two operands in a row: load this one onto the FPU stack, to begin
            // a new subexpression.
            if (ranklist[op - 'A'] == '1') {
                mainbody ~= "  fld " ~ indexedVector(typelist, ranklist, stridelist, op) ~ ";  //" ~ op ~ "\n";
            } else { // load constant. Will never be a real
                mainbody ~= "  fld " ~ operandSize(typelist[op - 'A']) ~ "values[" ~ itoa(op - 'A') ~ "]; // * " ~ operations[done..done+1] ~ "\n";
            }
            ++done;
            numOnStack++;
        } else if (op == ',') {
            // ',' followed by an operator applies it to ST(0) with itself.
            mainbody ~= "  " ~ opToX87[operations[done+1]] ~ " ST, ST(0);    // dup " ~ operations[done+1] ~ "\n";
            done += 2;
        } else if (ranklist[op - 'A'] == '1') {
            // An operation will be performed between the stack top and a vector.
            // If it's a float or double, we can combine the load+arithmetic op
            // into a single instruction. Stores of reals can also be done in one instr.
            char [] comment = ";  // " ~ operations[done..done+2] ~ "\n";
            if (operations[done+1] == '=') {
                // Store into vector 'op'. If it's the last operation, pop it from
                // the stack; otherwise the value stays in ST(0) for the next
                // (chained) store. The target must be this pair's operand: using the
                // expression's final target would make every chained store write there.
                mainbody ~= ((done + 2 == operations.length) ? "  fstp " : "  fst ")
                    ~ indexedVector(typelist, ranklist, stridelist, op) ~ comment;
            } else if (typelist[op - 'A'] == "real") {
                // 80-bit vectors must be loaded onto the FPU stack first
                mainbody ~= "  fld real ptr [" ~ vectorRegister[vectorNum(ranklist, op)] ~ "]; //" ~ op ~ "\n"
                    ~ "  " ~ opToX87[operations[done+1]] ~ "p ST(1), ST; //" ~ operations[done+1] ~ "\n";
            } else { // floats and doubles can be used directly
                mainbody ~= "  " ~ opToX87[operations[done+1]] ~ " "
                    ~ indexedVector(typelist, ranklist, stridelist, op) ~ comment;
            }
            done += 2;
        } else { // multiply by scalar.
            // NOTE(review): fmul is emitted regardless of the operator that follows
            // the scalar; the postfix generator is assumed to only produce '*' here.
            if (typelist[op - 'A'] == "real") {
                // Multiply by real scalar, which is already on the stack. Real
                // scalars were pushed in tuple order beneath everything this loop
                // body pushes, so the k-th one (0-based) is at
                // ST(numOnStack + numScalarsOnStack - 1 - k).
                // realScalarNum takes the operand letter itself; it subtracts 'A'.
                mainbody ~= "  fmul ST, ST(" ~ itoa(numOnStack + numScalarsOnStack - realScalarNum(typelist, ranklist, op) - 1) ~ "); // * " ~ op ~ "\n";
            } else {
                // For scalar float or double values, we can multiply directly, saving one slot on the FP stack.
                mainbody ~= "  fmul " ~ operandSize(typelist[op - 'A']) ~ "values[" ~ itoa(op - 'A') ~ "]; // * " ~ operations[done..done+1] ~ "\n";
            }
            done += 2;
        }
    }

    result ~= "\n"
        ~ "  align 4;\n"
        ~ "L1:\n" ~ mainbody;

    result ~= incrementRealVectors // Advance the real and strided vector pointers
           ~ "  inc EAX;\n  jnz L1;\n";

    // The zero-length jump lands here, BEFORE the clean-up. The real scalars (and,
    // for cumulative loops, the initial accumulator) were pushed before that jump,
    // so they must be removed on that path too, or the FPU stack is left unbalanced.
    result ~= "L3:\n";

    // Discard any scalars that are left on the stack
    if (isCumulative && numScalarsOnStack > 0) {
        // Preserve the result of the dot product: swap it below the scalars.
        result ~= "  fxch ST(" ~ itoa(numScalarsOnStack) ~ "), ST;\n";
    }
    result ~= discardFromStack(numScalarsOnStack);

    result ~= popRegisters(vecnum) ~ "}\r\n";

    return result;
}
481
482 //-----------------------------
483
// Name of SSE register number k, e.g. "XMM3" for k == 3.
char [] XMM(int k)
{
    char [] prefix = "XMM";
    return prefix ~ itoa(k);
}
485
486 public:
487
/** Generate BLAS1 asm code using SSE or SSE2.
 * For SSE2, all scalars are double, vectors are double*; for SSE1, all are float.
 * At entry, all vector parameters are aligned.
 *
 * Each type in Values for which T[0] is valid is treated as a vector; every
 * other type is treated as a scalar. Double-precision code is requested as soon
 * as any operand type is double or double*.
 */
char [] generateCodeForSSE(Values...)(char [] operations)
{
    char [] ranks;
    bool wantDouble = false;
    foreach (T; Values) {
        static if (is(typeof(T[0]))) {
            ranks ~= "1";
        } else {
            ranks ~= "0";
        }
        static if (is(T == double) || is(T == double *)) {
            wantDouble = true;
        }
    }
    return generateCodeForSSEImpl(wantDouble, ranks, operations);
}
505
506 // Note: If SSE4 is available, the dppd and dpps instructions could be
507 // used to replace the final *+ in dot-product operations. This would allow
508 // an in-order dot product to be performed, but otherwise doesn't give much speed
509 // improvement.
510
511 private:
512
513 // split off from the template to make code coverage work
514 char [] generateCodeForSSEImpl(bool usingDoubles, char [] ranklist, char [] operations, char cumulatingOp=0)
515 {
516     char [] result="";
517
518     result ~= "// Operation : " ~  operations ~ \n;
519
520     int numvecs = countVectors(ranklist);
521     int numScalarsOnStack=0;
522     bool isCumulative = (operations[0]=='0') || (operations[0]=='1');
523     if (isCumulative) result ~= (usingDoubles? "  double" : "  float") ~" sum;"\n;
524
525     result~= \n"asm {"\n ~ pushRegisters(numvecs);
526     // EAX will be the counter
527     result ~= "  mov EAX, veclength;"\n;
528     // Load all the vector pointers into registers
529
530     char [] vectorsize = usingDoubles? "8" :"4"; // size of a double
531     char [] suffix = usingDoubles? "d " :"s ";
532
533     int vecregnum = 0;
534     int numconsts=0;
535     for (int i=0; i<ranklist.length; ++i) {
536       if (ranklist[i]=='1') {
537         result ~= "  lea " ~ vectorRegister[vecregnum]
538           ~ ", [" ~ vectorsize ~ "*EAX];   "