Changeset 170
- Timestamp: 01/07/08 09:21:45 (9 months ago)
- Files:
- trunk/blade/Blade.d (modified) (18 diffs)
- trunk/blade/CodegenX86.d (modified) (10 diffs)
- trunk/blade/PostfixX86.d (modified) (3 diffs)
Legend:
- Unmodified
- Added
- Removed
- Modified
- Copied
- Moved
trunk/blade/Blade.d
r169 r170 37 37 * 38 38 * FUTURE DIRECTIONS (in order of expected implementation): 39 * - sum(),trace()39 * - trace() 40 40 * - Loop unrolling for cumulative operations dot, sum, trace. 41 41 * - Dense matrix support. … … 92 92 if (revised.errorMessage.length>0) return `static assert(0, "BLADE: ` ~ enquote(revised.errorMessage) ~ `");`; 93 93 VecExpressionType exprType = categorizeExpression(revised); 94 char [] result = generateAsserts(revised, (exprType == VecExpressionType.SSE2Expression || exprType == VecExpressionType.SSE1Expression)); 95 debug(BladeFrontEnd) { 96 result ~= "// Simplified: " ~ revised.expression ~ \n; 97 } 94 InvocationCode q; 98 95 if (exprType == VecExpressionType.SSE2Expression || exprType == VecExpressionType.SSE1Expression) { 99 return result ~ invokeSSE((exprType == VecExpressionType.SSE2Expression), revised)~ ";";96 q = invokeSSE((exprType == VecExpressionType.SSE2Expression), revised); 100 97 } else if (exprType == VecExpressionType.X87Expression) { 101 return result ~ invokeX87(revised) ~ ";";98 q = invokeX87(revised); 102 99 } else { 103 return result ~ DCodeGenerator(revised); 104 } 100 q = DCodeGenerator(revised); 101 } 102 return q.assertions ~ q.invoker ~ ";"; 105 103 } 106 104 … … 113 111 } 114 112 115 template SSERetType(int SSEVersion, char [] expr) {116 static if (expr[0]!='0') alias void SSERetType;117 else static if (SSEVersion==1) alias float SSERetType;118 else alias double SSERetType;119 }120 113 template X87RetType(char [] expr) { 121 static if (expr[0]!='0') alias void X87RetType;122 else alias real X87RetType;114 static if (expr[0]!='0') alias void X87RetType; 115 else alias real X87RetType; 123 116 } 124 117 … … 132 125 * Every member of the Values tuple must only be double or double *. 
133 126 */ 134 SSERetType!(SSEVersion, expr) SSEVECGEN(int SSEVersion, char [] expr, Values...)(int veclength, Values values) {127 RetType SSEVECGEN(RetType, char [] expr, Values...)(int veclength, Values values) { 135 128 debug(BladeBackEnd) { 136 pragma(msg, generateCodeForSSE!(Values)( SSEVersion,expr));137 } 138 mixin(generateCodeForSSE!(Values)( SSEVersion,expr));129 pragma(msg, generateCodeForSSE!(Values)(expr)); 130 } 131 mixin(generateCodeForSSE!(Values)(expr)); 139 132 } 140 133 141 134 /** Function to implement BLAS1 operations using X87 assembler. 142 * Every member of the Values tuple must only be real, float[], double [], or real[]. 135 * Every member of the Values tuple must only be real, 136 * float[], double [], or real[], or BladeStrided!(float), !(double), !(real) 143 137 */ 144 138 X87RetType!(expr) X87VECGEN(char [] expr, int numStrides, Values...)(int veclength, Values values) { … … 247 241 // of the parameters. 248 242 249 /// Generate code which will call the X87 function 250 char [] invokeX87(RevisedExpression tree) 251 { 243 struct InvocationCode { 244 char [] invoker; // For mixin: code to invoke the functions. 245 char [] assertions; // For mixin: code to assert that everything is correct 246 } 247 248 /// Generate code which will call the X87 function. 249 InvocationCode invokeX87(RevisedExpression tree) 250 { 251 char [] assertions = assertAllVectorLengthsEqual(tree); 252 252 char [] result = ""; 253 253 char [] stridelist=""; … … 261 261 char rnk = tree.rank[i]; 262 262 vals ~= ","; 263 char [] v = getValueForSymbol(tree.mapping[i], tree); 263 InvocationCode q = getValueForSymbol(tree.mapping[i], tree); 264 char [] v = q.invoker; 265 assertions ~= q.assertions; 264 266 int x = tree.mapping[i]-'A'; 265 267 char [] t; … … 274 276 if (rnk=='0') { 275 277 t = "real"; // convert all compounds to real. 
276 // TODO : if the number is exactly representable as a double278 // TODO (tricky): if the number is exactly representable as a double 277 279 // or float, it could use less FPU stack space. 278 280 } else { // for arrays, the type is the type of the original array … … 318 320 result ~= ")("; 319 321 int firstVector = findVectorForLength(tree); 320 return result ~ getValueForSymbol(tree.mapping[firstVector], tree)~ ".length"321 ~ vals ~ stridelist ~ ")" ;322 return InvocationCode(result ~ getValueForSymbol(tree.mapping[firstVector], tree).invoker ~ ".length" 323 ~ vals ~ stridelist ~ ")", assertions); 322 324 } 323 325 … … 330 332 331 333 /// Generate code which will call the SSE/SSE2 code generation function 332 char [] invokeSSE(bool SSE2, RevisedExpression tree) 333 { 334 char [] result = "SSEVECGEN!(" ~ (SSE2?"2":"1") ~ `,"` ~ enquote(makePostfixForSSE(tree.expression, tree.rank)) ~ `"`; 334 InvocationCode invokeSSE(bool SSE2, RevisedExpression tree) 335 { 336 char [] assertions = assertAllVectorLengthsEqual(tree) 337 ~ assertAllVectorsAlign128(tree); 338 339 char [] postfix = makePostfixForSSE(tree.expression, tree.rank); 340 char [] retType = "void"; 341 if (postfix[0]=='0') retType = (SSE2? "double" : "float"); 342 343 char [] result = "SSEVECGEN!(" ~ retType ~ `,"` ~ enquote(postfix) ~ `"`; 335 344 // For SSE2, everything must be implicitly convertible to double. 
336 345 char [] vals; … … 341 350 vals ~= ","; 342 351 if (rnk=='1') vals ~= "&"; 343 vals ~= getValueForSymbol(tree.mapping[i], tree); 352 InvocationCode q = getValueForSymbol(tree.mapping[i], tree); 353 vals ~= q.invoker; 354 assertions ~= q.assertions; 344 355 // for vectors, we only need the pointer, not the length 345 // if (rnk=='1') vals ~= ".ptr";346 356 if (rnk=='1') vals ~= "[0]"; 347 357 } … … 349 359 result ~= ")("; 350 360 int firstVector = findVectorForLength(tree); 351 result ~= getValueForSymbol(tree.mapping[firstVector], tree) ~ ".length";361 result ~= getValueForSymbol(tree.mapping[firstVector], tree).invoker ~ ".length"; 352 362 // result ~= tree.symbolTable[firstVector].value ~ ".length"; 353 363 result ~= vals; 354 364 355 return result ~ ")";365 return InvocationCode(result ~ ")", assertions); 356 366 } 357 367 … … 387 397 for (int i=0; i<tree.mapping.length;++i) { 388 398 if (tree.rank[i]=='1'){ 389 result ~= "assert( (cast(size_t)(&" ~ getValueForSymbol(tree.mapping[i], tree) 390 ~ "[0])& 0x0F) == 0, `SSE Vector misalignment: " ~ getValueForSymbol(tree.mapping[i], tree) ~ "`);"\n;399 result ~= "assert( (cast(size_t)(&" ~ getValueForSymbol(tree.mapping[i], tree).invoker 400 ~ "[0])& 0x0F) == 0, `SSE Vector misalignment: " ~ getValueForSymbol(tree.mapping[i], tree).invoker ~ "`);"\n; 391 401 } 392 402 } … … 449 459 if (comp[$-1]!=']') { // simple compound expression 450 460 foreach(d; comp) { 451 if (d=='{') assert(0, "BLADE ICE");461 if (d=='{') assert(0, "BLADE ICE"); 452 462 if (d>='A' && d<='Z') v ~= tree.symbolTable[d-'A'].value; 453 463 else v ~= d; … … 515 525 // TODO: This only works for packed types. Doesn't work for jagged arrays, and 516 526 // is probably very inefficient for user-defined types. 
517 return "(&" ~ getValueForSymbol(c, tree, "1") ~ "-&"518 ~ getValueForSymbol(c, tree, "0") ~ ")*" ~ tree.symbolTable[comp[0]-'A'].element ~ ".sizeof";519 } 520 521 522 char []invokeNestedExpression(char [] expr, Symbol[] symbolTable)523 { 524 char [] ranks;527 return "(&" ~ getValueForSymbol(c, tree, "1").invoker ~ "-&" 528 ~ getValueForSymbol(c, tree, "0").invoker ~ ")*" ~ tree.symbolTable[comp[0]-'A'].element ~ ".sizeof"; 529 } 530 531 532 InvocationCode invokeNestedExpression(char [] expr, Symbol[] symbolTable) 533 { 534 char [] ranks; 525 535 for (int i=0; i<symbolTable.length; ++i) { 526 536 ranks ~= symbolTable[i].rank; 527 }528 RevisedExpression revised = remapCompounds(expr, ranks, symbolTable);529 537 } 538 RevisedExpression revised = remapCompounds(expr, ranks, symbolTable); 539 530 540 VecExpressionType exprType = categorizeExpression(revised); 531 541 if (exprType == VecExpressionType.SSE2Expression || exprType == VecExpressionType.SSE1Expression) { … … 534 544 return invokeX87(revised); 535 545 } else { 536 assert(0, "BLADE ICE: Nested D expressions are not yet supported");546 assert(0, "BLADE ICE: Nested D expressions are not yet supported"); 537 547 return DCodeGenerator(revised); 538 548 } 539 549 } 540 550 541 char []getValueForSymbol(char c, RevisedExpression tree, char [] firstIndexExpr="")551 InvocationCode getValueForSymbol(char c, RevisedExpression tree, char [] firstIndexExpr="") 542 552 { 543 553 int numSlicesRemaining=1; 544 554 if (firstIndexExpr=="") numSlicesRemaining=0; 545 555 char [] v = ""; 556 char [] assertions = ""; 546 557 // is it an original symbol? 
547 558 if (c-'A'<tree.symbolTable.length) { 548 559 v = tree.symbolTable[c-'A'].value; 549 } else { // else it's a compound or an indexed array 560 } else { // else it's a compound or an indexed array 550 561 char [] comp = tree.compounds[c-'A'-tree.symbolTable.length]; 551 562 552 563 if (comp[$-1]!=']') { // compound expression (not an indexed array) 553 if (comp[0]>='a' && comp[0]<='z') {554 // dot product is a nested expression555 return invokeNestedExpression(comp, tree.symbolTable);556 }564 if (comp[0]>='a' && comp[0]<='z') { 565 // dot product is a nested expression 566 return invokeNestedExpression(comp, tree.symbolTable); 567 } 557 568 for (int k=0; k<comp.length; ++k) { 558 char d = comp[k]; 559 if (d=='{') { 560 int braceStart = k+1; 561 for (; comp[k]!='}'; ++k) {} 562 v ~= invokeNestedExpression(comp[braceStart..k], tree.symbolTable); 563 } else if (d>='A' && d<='Z') { 564 v ~= tree.symbolTable[d-'A'].value; 569 char d = comp[k]; 570 if (d=='{') { 571 int braceStart = k+1; 572 for (; comp[k]!='}'; ++k) {} 573 InvocationCode q = invokeNestedExpression(comp[braceStart..k], tree.symbolTable); 574 v ~= q.invoker; 575 assertions ~= q.assertions; 576 } else if (d>='A' && d<='Z') { 577 v ~= tree.symbolTable[d-'A'].value; 565 578 } else v ~= d; 566 579 } … … 614 627 v ~= "[" ~ firstIndexExpr ~ "]"; 615 628 } 616 return tree.symbolTable[comp[0]-'A'].value ~ v;629 return InvocationCode(tree.symbolTable[comp[0]-'A'].value ~ v, assertions); 617 630 } 618 631 } … … 620 633 v ~= "[" ~ firstIndexExpr ~ "]"; 621 634 } 622 return v;635 return InvocationCode(v, assertions); 623 636 } 624 637 625 638 626 639 // Generate inline D code for the expression 627 char []DCodeGenerator(RevisedExpression tree)640 InvocationCode DCodeGenerator(RevisedExpression tree) 628 641 { 629 642 char [] result = "// D generate:" ~ tree.braceExpression ~ \n; 643 char [] assertions=""; 630 644 int wholerank = exprRank(tree.expression, tree.rank); 631 645 if (wholerank ==1) { … … 639 653 // 
restore all symbols into the expression 640 654 // If it's a vector, index it 641 if (tree.rank[c-'A']=='1') 642 result ~= getValueForSymbol(tree.mapping[c-'A'], tree, "blade_index"); 643 else result ~= getValueForSymbol(tree.mapping[c-'A'], tree); 655 if (tree.rank[c-'A']=='1') { 656 InvocationCode q = getValueForSymbol(tree.mapping[c-'A'], tree, "blade_index"); 657 result ~= q.invoker; 658 assertions ~= q.assertions; 659 } 660 else { 661 InvocationCode q = getValueForSymbol(tree.mapping[c-'A'], tree); 662 result ~= q.invoker; 663 assertions ~= q.assertions; 664 } 644 665 } else result ~= c; 645 666 } 646 if (wholerank==0) return result ~ ";";647 return result ~ "; }";648 } 649 667 if (wholerank==0) return InvocationCode(result, assertions); 668 return InvocationCode(result ~ "; }", assertions); 669 } 670 trunk/blade/CodegenX86.d
r169 r170 377 377 int numOnStack = 0; // How much of the FP stack is being used? 378 378 379 bool is DotProduct= (operations[0]=='0');379 bool isCumulative = (operations[0]=='0'); 380 380 if (operations[0]=='0') { 381 381 result ~= " fldz;"\n; // dot product … … 457 457 458 458 // Discard any scalars that are left on the stack 459 if (is DotProduct&& numScalarsOnStack>0) {459 if (isCumulative && numScalarsOnStack>0) { 460 460 // Preserve the result of the dot product 461 461 result ~= " fxch ST(" ~ itoa(numScalarsOnStack) ~ "), ST;"\n; … … 478 478 * At entry, all vector parameters are aligned. 479 479 */ 480 char [] generateCodeForSSE(Values...)( int SSEVer,char [] operations)480 char [] generateCodeForSSE(Values...)(char [] operations) 481 481 { 482 482 char [] ranklist; 483 bool usingDoubles=false; 483 484 foreach(T; Values) { 484 485 static if (is(typeof(T[0]))) ranklist~="1"; else ranklist~="0"; 485 } 486 return generateCodeForSSEImpl(SSEVer, ranklist, operations); 486 static if (is(T == double) || is(T == double *)) { 487 usingDoubles = true; 488 } //else assert(is(T==float)|| is(T==float*)); 489 } 490 return generateCodeForSSEImpl(usingDoubles, ranklist, operations); 487 491 // makePostfixForSSE(infixOperations, ranklist)); 488 492 } … … 495 499 private: 496 500 // split off from the template to make code coverage work 497 char [] generateCodeForSSEImpl( int SSEVer, char [] ranklist, char [] operations, char cumulatingOp=0)501 char [] generateCodeForSSEImpl(bool usingDoubles, char [] ranklist, char [] operations, char cumulatingOp=0) 498 502 { 499 503 char [] result=""; … … 503 507 int numvecs = countVectors(ranklist); 504 508 int numScalarsOnStack=0; 505 bool is DotProduct= (operations[0]=='0');506 if (is DotProduct) result ~= ((SSEVer == 2)? " double" : " float") ~" sum;"\n;509 bool isCumulative = (operations[0]=='0'); 510 if (isCumulative) result ~= (usingDoubles? 
" double" : " float") ~" sum;"\n; 507 511 508 512 result~= \n"asm {"\n ~ pushRegisters(numvecs); … … 511 515 // Load all the vector pointers into registers 512 516 513 char [] vectorsize = (SSEVer == 2)? "8" :"4"; // size of a double514 char [] suffix = (SSEVer == 2)? "d " :"s ";517 char [] vectorsize = usingDoubles? "8" :"4"; // size of a double 518 char [] suffix = usingDoubles? "d " :"s "; 515 519 516 520 int vecregnum = 0; … … 544 548 int done=0; 545 549 if (operations[0]=='0') { 546 result ~= " pxor " ~ XMM(numOnStack) ~ "," ~ XMM(numOnStack) ~ "; // 0\n";547 ++numOnStack;548 ++done;550 result ~= " pxor " ~ XMM(numOnStack) ~ "," ~ XMM(numOnStack) ~ "; // 0\n"; 551 ++numOnStack; 552 ++done; 549 553 } 550 554 result ~= " xor EAX, EAX; "\n … … 575 579 } else 576 580 if (operations[done-1]==operations[done]) { 577 // operation on self, eg XX+ --> don't need to load it again.578 int cumvector = (operations[done-1]=='0')? numScalarsOnStack : numOnStack-1;581 // operation on self, eg XX+ --> don't need to load it again. 582 int cumvector = (operations[done-1]=='0')? numScalarsOnStack : numOnStack-1; 579 583 mainbody ~= " " ~ opToSSE[operations[done+1]] ~ suffix ~ " " ~ XMM(numOnStack-1) ~ ", " 580 584 ~ XMM(numOnStack-1) ~ comment; … … 600 604 ~ " align 16;\n" 601 605 ~ "L1:\n" ~ mainbody; 602 if ( SSEVer == 2) {606 if (usingDoubles) { 603 607 result ~= " add EAX,2;\n js L1;\n" 604 608 ~ "L2:\n sub EAX, 2;\n jns L4;\n" … … 613 617 } 614 618 result ~= "L4:" \n; 615 if (is DotProduct) {616 // Result is now in XMM(numScalarsOnStack). 
We need to do a horizontal617 // add to get the final sum.618 if (SSEVer==2) {619 // For SSE3, use haddpd XMM(numScalarsOnStack).620 result ~= " movhlps " ~ XMM(numScalarsOnStack+1) ~ "," ~ XMM(numScalarsOnStack) ~ ";"\n621 ~ " addsd " ~ XMM(numScalarsOnStack) ~ "," ~ XMM(numScalarsOnStack+1) ~ ";\n"; 622 } else { // floats623 result ~= " movhlps " ~ XMM(numScalarsOnStack+1) ~ "," ~ XMM(numScalarsOnStack) ~ ";"\n624 ~ " addps " ~ XMM(numScalarsOnStack) ~ "," ~ XMM(numScalarsOnStack+1) ~ ";\n"625 ~ " pshufd " ~ XMM(numScalarsOnStack+1) ~ "," ~ XMM(numScalarsOnStack) ~ ",1;"\n626 ~ " addss " ~ XMM(numScalarsOnStack) ~ "," ~ XMM(numScalarsOnStack+1) ~ ";\n";627 }628 result ~= " movs" ~ suffix ~ " sum," ~ XMM(numScalarsOnStack) ~ ";"\n;629 //result ~= "// Move to ST(0)\n";619 if (isCumulative) { 620 // Result is now in XMM(numScalarsOnStack). We need to do a horizontal 621 // add to get the final sum. 622 if (usingDoubles) { 623 // For SSE3, use haddpd XMM(numScalarsOnStack). 624 result ~= " movhlps " ~ XMM(numScalarsOnStack+1) ~ "," ~ XMM(numScalarsOnStack) ~ ";"\n 625 ~ " addsd " ~ XMM(numScalarsOnStack) ~ "," ~ XMM(numScalarsOnStack+1) ~ ";\n"; 626 } else { // floats 627 result ~= " movhlps " ~ XMM(numScalarsOnStack+1) ~ "," ~ XMM(numScalarsOnStack) ~ ";"\n 628 ~ " addps " ~ XMM(numScalarsOnStack) ~ "," ~ XMM(numScalarsOnStack+1) ~ ";\n" 629 ~ " pshufd " ~ XMM(numScalarsOnStack+1) ~ "," ~ XMM(numScalarsOnStack) ~ ",1;"\n 630 ~ " addss " ~ XMM(numScalarsOnStack) ~ "," ~ XMM(numScalarsOnStack+1) ~ ";\n"; 631 } 632 result ~= " movs" ~ suffix ~ " sum," ~ XMM(numScalarsOnStack) ~ ";"\n; 633 //result ~= "// Move to ST(0)\n"; 630 634 } 631 635 result ~= popRegisters(numvecs) ~ "}\n"; 632 if (is DotProduct) result ~= " return sum;"\n;636 if (isCumulative) result ~= " return sum;"\n; 633 637 634 638 return result; trunk/blade/PostfixX86.d
r169 r170 7 7 * values being'0'=scalar, '1'=vector, '2'=matrix. 8 8 * This string is converted to postfix, applying simple X86-specific optimisations. 9 * 10 * The elements of the postfix string may be: 11 * ABC a variable or constant, to be pushed onto the stack 12 * 0 the literal zero (used to initialize a dot product, for example) 13 * *+-/ ST(0)*ST(1), ST(0)+ST(1), ST(0)-ST(1), ST(0)/ST(1) and pop stack 14 * _ ST(1)-ST(0) and pop stack 15 * = store stack top and pop stack 16 * 17 * NOT YET IMPLEMENTED: 18 * 1 the literal one (used to initialize a product, for example) 19 * sc ST(0) = sine(ST(0)) ST(0) = cos(ST(0)) 20 * q ST(0) = sqrt(ST(0)) 9 21 */ 10 22 … … 42 54 } 43 55 ReturnType onVisitFunction(This this_, char [] func, char [][] args) { 44 if (func=="dot") {45 return "0" ~ doVisit(this_,args[0]) ~ doVisit(this_, args[1]) ~ "*+";46 }56 if (func=="dot") { 57 return "0" ~ doVisit(this_,args[0]) ~ doVisit(this_, args[1]) ~ "*+"; 58 } 47 59 assert(0, "BLADE ICE: Unsupported"); 48 60 } … … 117 129 return sym; 118 130 } 119 ReturnType onVisitFunction(This this_, char [] func, char [][] args) { 120 if (func=="dot") {121 return "0" ~ doVisit(this_,args[0]) ~ doVisit(this_, args[1]) ~ "*+";122 }123 if (func=="sum") return "0" ~ doVisit(this_, args[0]) ~ "+";131 ReturnType onVisitFunction(This this_, char [] func, char [][] args) { 132 if (func=="dot") { 133 return "0" ~ doVisit(this_,args[0]) ~ doVisit(this_, args[1]) ~ "*+"; 134 } 135 if (func=="sum") return "0" ~ doVisit(this_, args[0]) ~ "+"; 124 136 assert(0, "BLADE ICE: Unsupported"); 125 137 }
