Changeset 94
- Timestamp:
- 04/11/07 15:32:52 (2 years ago)
- Files:
-
- trunk/mathextra/Blade.d (modified) (24 diffs)
Legend:
- Unmodified
- Added
- Removed
- Modified
- Copied
- Moved
trunk/mathextra/Blade.d
r93 r94 13 13 * FEATURES: 14 14 * - Supports any mix of vector addition, subtraction, dot product, and multiplication 15 * by a real or pure imaginary scalar.15 * by a real, complex, or pure imaginary scalar. 16 16 * - Generates either x87 asm code or pure D, depending on the complexity of 17 17 * the expression, and the availability of inline asm. 18 18 * - 80-bit precision is used whenever possible. 19 19 * - Supports mixed-length operations (eg, real[] + double[] + float[]). 20 * - Supports both real and imaginary vectors, and detects type mismatches between them. 21 * (Complex numbers are not yet supported). 20 * - Supports real, complex and imaginary vectors, and detects type mismatches between them. 22 21 * - When static arrays are used, mismatches in array length are detected 23 22 * at compile time. … … 36 35 * is possible). 37 36 * - Doesn't use EBP register -- this would allow an extra vector in expressions. 38 * - Doesn't support complex numbers. 37 * (to do this, need naked asm with no stack frame). 38 * - There is no asm support for complex numbers. 39 39 * - Need an SSE2 version. 40 40 * 41 41 * THEORY: 42 * Expression templates are used to create an expression string of the form "( a+b*c)+d"43 * and a tuple, the entries of which correspond to a, b, c,d, ...42 * Expression templates are used to create an expression string of the form "(#a+#b*#c)+#d" 43 * and a tuple, the entries of which correspond to #a, #b, #c, #d, ... 44 44 * This string is converted to postfix. The postfix string is converted to 45 45 * a string containing x87 asm, which is then mixed into a function which accepts the tuple. … … 86 86 // 'abstract syntax tree' constructed from multiple templates. Instead, 87 87 // the expression is constructed as a normal, human-readable text string 88 // (for example, " a+(b*c-d)").88 // (for example, "#a+(#b*#c-#d)"). 89 89 // This string is a template parameter for a struct, which 90 90 // contains a tuple of all the arguments used in the expression. … … 98 98 // Find the highest used variable in the first expression... 99 99 char nextarg = 'a'; 100 for (int i= 0; i<first.length; ++i) {101 if (first[i ]>nextarg && first[i]<='z') nextarg=first[i];100 for (int i=1; i<first.length; ++i) { 101 if (first[i-1]=='#' && first[i]>nextarg) nextarg=first[i]; 102 102 } 103 103 // Add parentheses round the first expression if required. 104 104 char [] ret=""; 105 if (first.length> 1) {105 if (first.length>2) { 106 106 ret = "(" ~ first ~ ")"; 107 107 } else { … … 110 110 // ... and add it to all variables in the second expression, adding 111 111 // parentheses if required. 112 if (second==" a") return ret ~ op~ cast(char)(nextarg+1);112 if (second=="#a") return ret ~ op ~ "#" ~ cast(char)(nextarg+1); 113 113 ret ~= op ~ "("; 114 114 for (int i=0; i<second.length; ++i) { 115 if (second[i]>='a' && second[i]<='z') ret~= (second[i]+(nextarg-'a'+1)); 116 else ret ~=second[i]; 115 if (second[i]=='#') { 116 ret~= "#" ~ cast(char)(second[i+1]+(nextarg-'a'+1)); 117 ++i; 118 } else ret ~=second[i]; 117 119 } 118 120 return ret ~ ")"; … … 134 136 * Stores all the arguments as a tuple, and the operations 135 137 * as a character string. 136 * Note that if operations==" a", it's a single vector; otherwise, it's a temporary.138 * Note that if operations=="#a", it's a single vector; otherwise, it's a temporary. 137 139 * 'knownlength' is the length of vector, when known at compile time; 138 140 * if 'knownlength' == 0, it is unknown at compile time. … … 159 161 // All they do is update the expression string and the tuple, creating a new 160 162 // VectorExpr. The existing VectorExpr will not be used again. 161 static if (operations.length> 1&& operations[$-2]=='*') {163 static if (operations.length>2 && operations[$-2]=='*') { 162 164 // Optimisation: already a scalar multiply, so constant fold it. 163 165 VectorExpr!(BaseType, operations, knownlength, B) opMul(real x) { … … 168 170 } 169 171 } else { 170 JoinResult!(BaseType, "*", "a", real) opMul(real x) { 171 return JoinResult!(BaseType, "*", "a", real)(values, x); 172 } 173 JoinResult!(typeof(BaseType*1.0fi), "*", "a", real) opMul(ireal x) { 174 return JoinResult!(typeof(BaseType*1.0fi), "*", "a", real)(values, x.im); 172 // trick: typeof(C*C) converts imag to real, but leaves real & complex unchanged. 173 JoinResult!(typeof(BaseType*C), "*", "#a", typeof(C*C*1.0L)) opMul(C)(C x) { 174 static assert(is(C: real) || is(C:ireal) || is(C:creal), "Can only multiply by scalars"); 175 static if (is(C: ireal)) { 176 return JoinResult!(typeof(BaseType*C), "*", "#a", real)(values, x.im); 177 } else static if (is(C: real)) { 178 return JoinResult!(BaseType, "*", "#a", real)(values, x); 179 } else { 180 return JoinResult!(typeof(BaseType*C), "*", "#a", C)(values, x); 181 } 175 182 } 176 183 } … … 184 191 // The opAssign operations are only valid for single vectors, not for temporaries 185 192 // They actually perform the calculation. 186 static if (operations==" a") {193 static if (operations=="#a") { 187 194 void opAssign(A)(A expr) { 188 195 static assert(CompatibleVectors!(BaseType,A.BaseType), "Vector type mismatch in " ~ BaseType.stringof ~ "[] = " ~ A.BaseType.stringof ~ "[]"); … … 201 208 } 202 209 void opMulAssign(A)(A w) { // Use a template to avoid unnecessary code generation 203 static assert( is (A: real), "Vector type mismatch in " ~ BaseType.stringof ~ "[] *= real");204 performOperation!(void, " a", "*=", knownlength, real, B[0])(w, values);210 static assert((is (BaseType: creal) && is(A:ireal) || is(A:creal)) || is (A: real), "Vector type mismatch in " ~ BaseType.stringof ~ "[] *= " ~ A.stringof); 211 performOperation!(void, "#a", "*=", knownlength, A, B[0])(w, values); 205 212 } 206 213 } … … 208 215 209 216 // Convert static arrays to dynamic, but remember the length as a compile-time parameter. 210 VectorExpr!(X, " a", Q, X[]) Vec(X, int Q)(X[Q] vals) {211 VectorExpr!(X, " a", Q, X[]) a;217 VectorExpr!(X, "#a", Q, X[]) Vec(X, int Q)(X[Q] vals) { 218 VectorExpr!(X, "#a", Q, X[]) a; 212 219 a.values[0]=vals; 213 220 return a; 214 221 } 215 222 216 VectorExpr!(X, " a", 0, X[]) Vec(X)(X[] vals) {217 VectorExpr!(X, " a", 0, X[]) a;223 VectorExpr!(X, "#a", 0, X[]) Vec(X)(X[] vals) { 224 VectorExpr!(X, "#a", 0, X[]) a; 218 225 a.values[0]=vals; return a; 219 226 } … … 241 248 "\nPostfix: " ~ makePostfixForX87(operations, tupstr) ~ "\nTuple: " ~ tupstr); 242 249 static if (knownlength!=0) pragma(msg, "Length is known!"); 250 static if (isSSE2AsmPossible(tupstr, operations)) { 251 const char [] q1 = generateCodeForAsmSSE2(knownlength, tupstr, makePostfixForX87(operations, tupstr), finaloperation); 252 pragma(msg, q1); 253 } else static if (isX87AsmPossible(tupstr, operations)) { 243 254 244 255 const char [] qqq = generateCodeForAsmX87(knownlength, tupstr, makePostfixForX87(operations, tupstr), finaloperation); 245 256 pragma(msg, "Generated code:"\n ~ qqq); 257 } else pragma(msg, "Too complicated for x87 -- generating D code instead"); 246 258 } 247 259 248 260 // Decide which code generator to use, based on expression complexity and 249 261 // assembler availability. 250 if (isX87AsmPossible(tupstr, operations)) { 262 static if (isSSE2AsmPossible(tupstr, operations)) { 263 mixin(generateCodeForAsmSSE2(knownlength, tupstr, makePostfixForX87(operations, tupstr), finaloperation)); 264 } else static if (isX87AsmPossible(tupstr, operations)) { 251 265 mixin(generateCodeForAsmX87(knownlength, tupstr, makePostfixForX87(operations, tupstr), finaloperation)); 252 266 } else { 253 267 mixin(generateCodeForD!(ReturnType)(knownlength, tupstr, finaloperation,operations)); 254 268 } 255 256 269 } 257 270 … … 266 279 { 267 280 static if (is(A == real[]) || is(A==ireal[])) const char [] singleType = "R"; 281 else static if (is(A == creal[]) || is(A==cdouble[])||is(A==cfloat[])) const char [] singleType = "Z"; 268 282 else static if (is(A == double[])|| is(A == idouble[]))const char [] singleType = "D"; 269 283 else static if (is(A == float[]) || is(A==ifloat[])) const char [] singleType = "F"; 270 else static if (is(A == real)) const char [] singleType = "S"; 284 else static if (is(A == real) || is (A == ireal)) const char [] singleType = "S"; 285 else static if (is(A == creal)) const char [] singleType = "C"; 271 286 else const char [] singleType = "?"; 272 287 } … … 288 303 if (isVector(typelist[i])) return i; 289 304 } 305 assert(0, typelist); 290 306 } 291 307 292 308 bool isVector(char var) 293 309 { 294 return (var=='R' || var=='D' || var=='F' );310 return (var=='R' || var=='D' || var=='F' || var=='Z'); 295 311 } 296 312 … … 320 336 { 321 337 char [] iter=""; 322 foreach(int i, ch; operation) { 323 if (ch=='.') { iter~="*"; } else 324 if (ch>='a' && ch<='z') { 325 iter ~= "expr[" ~ itoa(ch-'a') ~ "]"; 326 if (isVector(typelist[ch-'a'])) iter~="[i]"; 327 } else iter ~= ch; 338 for(int i=0; i< operation.length; ++i) { 339 if (operation[i]=='.') iter~="*"; 340 else if (operation[i]=='#') { 341 int n = operation[i+1]-'a'; 342 iter ~= "expr[" ~ itoa(n) ~ "]"; 343 if (isVector(typelist[n])) iter~="[i]"; 344 ++i; 345 } else iter ~= operation[i]; 328 346 } 329 347 char [] result; … … 349 367 int exprLength(char [] s) 350 368 { 369 if (s[0]=='#') return 1; 351 370 int numParens=0; 352 371 for (int i=0; i<s.length; ++i) { … … 357 376 } 358 377 359 // Converts an infix string into postfix360 char [] makePostfix(char [] operations)361 {362 if (operations.length==1) return operations;363 364 int x = exprLength(operations);365 char [] first = operations[0..x+1];366 char [] second = operations[x+2..$];367 if (first[0]=='(') {368 first = makePostfix(first[1..$-1]);369 }370 if (second[0]=='(') {371 second = makePostfix(second[1..$-1]);372 }373 return first ~ second ~ operations[x+1..x+2];374 }375 376 378 // Converts an infix string into postfix. 377 379 // Apply x87-specific optimisations during the conversion. 378 380 char [] makePostfixForX87(char [] operations, char [] typelist) 379 381 { 380 if (operations.length==1) return operations; 382 // if (operations.length==1) return operations; 383 if (operations.length==2 && operations[0]=='#') return operations[1..$]; 381 384 382 385 int x = exprLength(operations); … … 385 388 if (first[0]=='(') { 386 389 first = makePostfixForX87(first[1..$-1], typelist); 387 } 390 } else if (first[0]=='#') first = operations[1..x+1]; 388 391 if (second[0]=='(') { 389 392 second = makePostfixForX87(second[1..$-1], typelist); 390 } 393 } else if (second[0]=='#') second = operations[x+3..$]; 394 391 395 // x87 OPTIMISATION #1 392 396 // On x87, fmul has a long latency, so we want to delay using the … … 431 435 int numVecs=0; 432 436 for (int i=0; i<typelist.length; ++i) { 433 if (typelist[i]=='R' || typelist[i]=='D' || typelist[i]=='F' ) ++numVecs;437 if (typelist[i]=='R' || typelist[i]=='D' || typelist[i]=='F' || typelist[i]=='Z') ++numVecs; 434 438 } 435 439 return numVecs; … … 440 444 int numVecs=0; 441 445 for (int i=0; i<var-'a'; ++i) { 442 if (typelist[i]=='R' || typelist[i]=='D' || typelist[i]=='F' ) ++numVecs;446 if (typelist[i]=='R' || typelist[i]=='D' || typelist[i]=='F' || typelist[i]=='Z') ++numVecs; 443 447 } 444 448 return numVecs; … … 492 496 const char [][5] vectorRegister = ["EAX", "ECX", "EDX", "EBX", "EDI"]; 493 497 494 // Is this expression simple enough for ourcode generator?498 // Is this expression simple enough for the x87 code generator? 495 499 bool isX87AsmPossible(char [] typelist, char [] operations) { 496 500 version (D_InlineAsm_X86) { 497 501 // Are there enough index registers? 498 if (countVectors(typelist) >= vectorRegister.length) return false; 502 if (countVectors(typelist) > vectorRegister.length) return false; 503 // Does it contain any types we can't deal with? 504 foreach(ch; typelist) { 505 // can only do float, double, and 80-bit vectors and scalars. 506 if (ch!='R' && ch!='D' && ch!='F' && ch!='S') return false; 507 } 499 508 // BUG: should also check if it will overflow the FPU stack 500 509 return true; … … 505 514 } 506 515 516 // Is this expression simple enough for the SSE2 code generator? 517 bool isSSE2AsmPossible(char [] typelist, char [] operations) 518 { 519 version (D_InlineAsm_X86) { 520 // Does it contain any types we can't deal with? 521 foreach(ch; typelist) { 522 // can only do double vectors and scalars. 523 if (ch!='D' && ch!='S') return false; 524 } 525 return false; // not yet implemented 526 } else { 527 // Without an assembler, there's no chance! 528 return false; 529 } 530 } 531 507 532 // Create code to push all used vector registors. 508 533 char [] pushRegisters(int numVectors) … … 538 563 } 539 564 540 /** Generate asm code which is optimal for x87 CPUs without SSE2. It is also 541 optimal for recent x86 CPUs where vector sizes are mixed. 542 (Pentium, PMMX, PII, PIII). 565 char [] generateCodeForAsmSSE2(int knownlength, char [] typelist, char [] operations, char [] finaloperation) 566 { 567 char [] result="asm {"\n 568 ~"L1: \n" 569 ~ " movapd XMM1, [ESI+EAX];"\n 570 ~ " mulpd XMM1, XMM2;"\n 571 ~ " addpd XMM1, [EDI+EAX];"\n 572 ~ " movapd [EDI+EAX], XMM1;"\n 573 ~ " add EAX, 16;"\n 574 ~ " js L1;"\n 575 ~ "}"\n; 576 return result; 577 } 578 579 /** Generate asm code which is optimal for x87 CPUs without SSE2. 580 (Pentium, PMMX, PII, PIII). It is also optimal for recent x86 CPUs 581 where vector sizes are mixed. 543 582 The key optimisation rules are: 544 583 1. keep the loop overhead to one clock cycle if possible. … … 754 793 { 755 794 auto p = Vec([1.0L, 2, 18]); 756 auto q = Vec([3.5L, 1.1, 3.8]); 795 796 auto q = Vec([3.5, 1.1, 3.8]); 757 797 auto r = Vec([17.0f, 28.25, 1]); 758 798 auto z = Vec([17.0i, 28.1i, 1i]); 799 auto w = Vec([2.0+17.0i, 0+28.1i, 8.1+1i]); 800 801 w*= (35.0 + 2.1i); 802 759 803 real d = dot(r, p+r+r); 760 804 assert(d==2267.625); 805 761 806 ireal e = dot(r, z); 762 807 writefln(d, " ", e); 763 808 764 q -= ((r+p)*18.0L*314.1L - (p-r))* 35;809 q -= ((r+p)*18.0L*314.1L - (p-r))*35.0; 765 810 d = dot(r, p+r+r); 766 811 writefln(d, " ", e); 767 812 assert(d==2267.625); 813 q*=2.1L; 814 768 815 /* 769 816 p = r - q*2.0;
