Changeset 168

Show
Ignore:
Timestamp:
01/04/08 07:56:21 (8 months ago)
Author:
Don Clugston
Message:

Added back-end code generation for dot product for X87 and SSE/SSE2.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/blade/Blade.d

    r167 r168  
    2323*  - A 'const folding' (actually vector/scalar folding) step is performed. 
    2424* 
     25* SPEED/ACCURACY TRADEOFF: 
     26* IEEE floating point multiplication and addition are not associative. 
     27*  Assuming that overflow and underflow do not occur: 
     28*  (a*b)*c may differ from a*(b*c) in the last bit. 
     29*  (a+b)+c may differ from a+(b+c) by a factor a million or more. 
     30* 
     31* - Multiplication is assumed to be associative. 
     32* - Addition and subtraction are not treated as associative. 
     33*    Except: Addition inside a dot product or vector sum is treated as associative. 
     34* 
     35* BUGS: 
     36*  Need to create asserts for nested expressions as well as for the primary one. 
     37* 
    2538* FUTURE DIRECTIONS (in order of expected implementation): 
    26 * - Dot product (which was present in BLADE 0.2). 
     39* - sum(), trace() 
     40* - Loop unrolling for cumulative operations dot, sum, trace. 
    2741* - Dense matrix support. 
    2842* - Triangular, banded, symmetric, and sparse matrix support 
     
    3953* 0.2 - Support for a wider variety of expressions. Dot product, imaginary numbers, etc. 
    4054* 0.3 - Based on string mixins. Most of the new features of 0.2 are gone, but SSE2 is added. 
    41 * 0.4 - Added D code generator. Nice error messages. Optimal parameter passing  
     55* 0.4 - Added D code generator. Nice error messages. Optimal parameter passing.  
    4256*       (passes pointers, not arrays). 
    43 * 0.5 - Expression simplification step. 
     57* 0.5 - Expression simplification step. Slicing support. 
     58* 0.6 - Dot product, nested expressions. 
    4459*/ 
    4560 
     
    5570private import blade.BladeVisitor: expressionContainsAssignment; 
    5671 
    57 private import blade.PostfixX86 : makePostfixForX87
     72private import blade.PostfixX86 : makePostfixForX87, makePostfixForSSE
    5873 
    5974public: 
     
    98113} 
    99114 
     115template SSERetType(int SSEVersion, char [] expr) { 
     116    static if (expr[0]!='0') alias void SSERetType; 
     117    else static if (SSEVersion==1) alias float SSERetType; 
     118    else alias double SSERetType; 
     119} 
     120template X87RetType(char [] expr) { 
     121    static if (expr[0]!='0') alias void X87RetType; 
     122    else alias real X87RetType; 
     123} 
     124 
    100125// These functions have the complete expression encoded in the template type. 
    101126// One of these functions is instantiated for each expression. 
     
    107132 * Every member of the Values tuple must only be double or double *. 
    108133 */ 
    109 void SSEVECGEN(int SSEVersion, char [] expr, Values...)(int veclength, Values values) { 
     134SSERetType!(SSEVersion, expr) SSEVECGEN(int SSEVersion, char [] expr, Values...)(int veclength, Values values) { 
    110135    debug(BladeBackEnd) { 
    111136       pragma(msg, generateCodeForSSE!(Values)(SSEVersion, expr)); 
     
    117142 * Every member of the Values tuple must only be real, float[], double [], or real[]. 
    118143 */ 
    119 void X87VECGEN(char [] expr, int numStrides, Values...)(int veclength, Values values) { 
     144X87RetType!(expr) X87VECGEN(char [] expr, int numStrides, Values...)(int veclength, Values values) { 
    120145    debug(BladeBackEnd) {     
    121146        pragma(msg, generateCodeForAsmX87!(numStrides, Values)(expr)); 
     
    307332char [] invokeSSE(bool SSE2, RevisedExpression tree) 
    308333{       
    309     char [] result = "SSEVECGEN!(" ~ (SSE2?"2":"1") ~ `,"` ~ enquote(tree.expression) ~ `"`; 
     334   char [] result = "SSEVECGEN!(" ~ (SSE2?"2":"1") ~ `,"` ~ enquote(makePostfixForSSE(tree.expression, tree.rank)) ~ `"`; 
    310335    // For SSE2, everything must be implicitly convertible to double. 
    311336    char [] vals; 
  • trunk/blade/BladeDemo.d

    r167 r168  
    1414// Use heap-allocated arrays, or static arrays (DMD 1.023 or later) 
    1515// cdouble[] always remains aligned, even when sliced. 
    16  
    17 float dot_product(float[] a, float[] b) 
    18 { 
    19     return 0; 
    20 } 
    2116   
    2217void main() 
     
    3025    q[0..$]= [17.0f, 28.25, 1, 0]; 
    3126    float [4] r; 
    32     idouble [] p = [2.3i, 254i, 0.1i, 1.2i]; 
     27    real [] p = [2.3, 254, 0.1, 1.2]; 
    3328    for(int i=0; i<r.length;++i) { 
    3429        r[i]= q[i]*2213.3L; 
     
    5752    mixin(vectorize("another[0..$,1]=6*a[0..2]")); 
    5853 
    59 // Parses, and simplifies to A*A, where A = dot(q,q). No asm codegen yet. 
     54// Simplifies to q*= 2*dot(q,q)*dot(q*q). 
     55   mixin(vectorize("q *=dot(q,q*dot(2*q,q))")); 
    6056   double u; 
    61 //   mixin(vectorize("u = dot(q,q*dot(q,q))")); 
    62 //   mixin(vectorize("q *=dot(q,q*dot(q+q,q))")); 
     57   mixin(vectorize("u = dot(q,q*dot(q,q))")); 
     58   mixin(vectorize("u = dot(a, q)")); 
    6359 
    6460    writefln("a=", a); 
  • trunk/blade/BladeSimplify.d

    r167 r168  
    455455            ScalarFold left = doVisit(this_,args[0]); 
    456456            ScalarFold right = doVisit(this_, args[1]); 
    457             return ScalarFold("", combineMul(combineMul(left.multiplier, right.multiplier), "{" ~ func ~ "(" ~ left.expr ~ "," ~ right.expr ~ ")}")); 
     457            return ScalarFold("", combineMul(combineMul(left.multiplier, right.multiplier), "{" ~ func ~ "(" ~ wrapInParens(left.expr) ~ "," ~ wrapInParens(right.expr) ~ ")}")); 
    458458        } else { 
    459459            assert(0, "BLADE: Unsupported function"); 
  • trunk/blade/CodegenX86.d

    r167 r168  
    297297        static if (is(typeof(T[0]))) { 
    298298            stridelist~="0"; 
    299             ranklist~="1";  
     299            ranklist~="1"; 
    300300            typelist ~= typeof(T[0]).stringof; 
    301301        } else static if (is(typeof(T.data))) {             
     
    314314private: 
    315315// This is split off from the template to make code coverage easier. 
    316 char [] generateCodeForAsmX87Impl(char [] ranklist, char [][] typelist, char [] stridelist, char [] operations, char cumulatingOp=0
     316char [] generateCodeForAsmX87Impl(char [] ranklist, char [][] typelist, char [] stridelist, char [] operations
    317317{ 
    318318    char [] result=""; 
     
    365365      } 
    366366    } 
    367     if (cumulatingOp=='+') { 
     367    int done=0; 
     368 
     369    // We need to keep track of how many things are on the FPU stack. 
     370    // Every time something is pushed, the indices of our variables change! 
     371    int numOnStack = 0; // How much of the FP stack is being used? 
     372 
     373    bool isDotProduct = (operations[0]=='0'); 
     374    if (operations[0]=='0') { 
    368375        result ~= "  fldz;"\n; // dot product 
    369     } else if (cumulatingOp=='*') { // trace 
    370         result ~= "fld1;"\n
     376        ++numOnStack; 
     377        done = 1
    371378    } 
    372379    result ~= "  xor EAX, EAX; "\n 
    373380        "  sub EAX, veclength; // counter=-length"\n 
    374381        "  jz short L3; // test for length==0"\n; 
    375     int done=0; 
    376382 
    377383    // Construct the main body of the loop (the main body does not include 
    378384    // the final storage instruction, because of the FST latency). 
    379385    char [] mainbody = ""; 
    380  
    381     // We need to keep track of how many things are on the FPU stack. 
    382     // Every time something is pushed, the indices of our variables change! 
    383     int numOnStack = 0; // How much of the FP stack is being used? 
    384      
     386             
    385387    while(done<operations.length) { 
    386388        char [] next; 
     
    443445        ~ "L1:\n" ~ mainbody; 
    444446         
    445     if (cumulatingOp) result ~= "  " ~ opToX87[cumulatingOp] ~ "p ST(2), ST;"\n; 
     447//    if (cumulatingOp) result ~= "  " ~ opToX87[cumulatingOp] ~ "p ST(2), ST;"\n; 
    446448 
    447449    result ~= incrementRealVectors // Update the counters 
     
    449451 
    450452    // Discard any scalars that are left on the stack 
    451     if (cumulatingOp!=0 && numScalarsOnStack>0) { 
     453    if (isDotProduct && numScalarsOnStack>0) { 
    452454        // Preserve the result of the dot product 
    453455        result ~= "  fxch ST(" ~ itoa(numScalarsOnStack) ~ "), ST;"\n; 
     
    470472 * At entry, all vector parameters are aligned. 
    471473 */ 
    472 char [] generateCodeForSSE(Values...)(int SSEVer, char [] infixOperations) 
     474char [] generateCodeForSSE(Values...)(int SSEVer, char [] operations) 
    473475{ 
    474476    char [] ranklist; 
     
    476478        static if (is(typeof(T[0]))) ranklist~="1"; else ranklist~="0"; 
    477479    } 
    478     return generateCodeForSSEImpl(SSEVer, ranklist, makePostfixForSSE(infixOperations, ranklist)); 
     480    return generateCodeForSSEImpl(SSEVer, ranklist, operations); 
     481//    makePostfixForSSE(infixOperations, ranklist)); 
    479482} 
    480483 
     
    489492    int numvecs = countVectors(ranklist); 
    490493    int numScalarsOnStack=0; 
     494    bool isDotProduct = (operations[0]=='0'); 
     495    if (isDotProduct) result ~= ((SSEVer == 2)? "  double" : "  float") ~" sum;"\n; 
    491496 
    492497    result~= \n"asm {"\n ~ pushRegisters(numvecs); 
     
    516521      } 
    517522    } 
    518     result ~= "  xor EAX, EAX; "\n 
    519         "  sub EAX, veclength; // counter=-length"\n 
    520         "  jz short L2; // test for length==0"\n; 
    521     int done=0; 
    522523 
    523524    char [] mainbody = ""; 
     
    530531     
    531532    int numOnStack = numScalarsOnStack; // How much of the FP stack is being used? 
     533    int done=0; 
     534    if (operations[0]=='0') { 
     535        result ~= "  pxor " ~ XMM(numOnStack) ~ "," ~ XMM(numOnStack) ~ ";  // 0\n"; 
     536        ++numOnStack; 
     537        ++done; 
     538    } 
     539    result ~= "  xor EAX, EAX; "\n 
     540        "  sub EAX, veclength; // counter=-length"\n 
     541        "  jz short L2; // test for length==0"\n; 
    532542    while(done<operations.length) { 
    533543      char [] comment; 
     
    552562                 mainbody ~= "  movap" ~ suffix ~ indexedSSEVector(ranklist, operations[$-2], vectorsize) ~ ", XMM" ~ itoa(numOnStack-1) ~ comment; 
    553563                 extra ~= "  movs" ~ suffix ~ indexedSSENext(ranklist, operations[$-2], vectorsize) ~ ", XMM" ~ itoa(numOnStack-1) ~ comment; 
     564            } else  
     565            if (operations[done-1]==operations[done]) { 
     566                // operation on self, eg XX+ --> don't need to load it again. 
     567                int cumvector = (operations[done-1]=='0')? numScalarsOnStack : numOnStack-1; 
     568                mainbody ~= "  " ~ opToSSE[operations[done+1]] ~ suffix ~ " " ~ XMM(numOnStack-1) ~ ", " 
     569                     ~ XMM(numOnStack-1) ~ comment; 
     570                extra ~= "  " ~ opToSSESingle[operations[done+1]] ~ suffix ~ " " ~ XMM(numOnStack-1) ~ ", " 
     571                    ~ XMM(numOnStack-1) ~ comment;             
    554572            } else { 
    555573                mainbody ~= "  " ~ opToSSE[operations[done+1]] ~ suffix ~ " " ~ XMM(numOnStack-1) ~ ", " 
     
    579597        result ~= "  add EAX,4;\n" ~ "  js L1;\n" 
    580598            ~ "L2:\n  sub EAX, 4;\n  jns L4;\n" 
    581         // Now the extra calculations for the 0-3 float, or 0-1 double
     599        // Now the extra calculations for the 0-3 float
    582600            ~ "L3:"\n ~ extra 
    583601            ~ "  add EAX,1;\n  js L3;\n"; 
    584602    } 
    585     result~= "L4:" \n ~ popRegisters(numvecs) ~ "}\n"; 
     603    result ~= "L4:" \n; 
     604    if (isDotProduct) { 
     605        // Result is now in XMM(numScalarsOnStack). We need to do a horizontal 
     606        // add to get the final sum. 
     607        if (SSEVer==2) { 
     608            // For SSE3, use   haddpd XMM(numScalarsOnStack). 
     609            result ~= "  movhlps " ~ XMM(numScalarsOnStack+1) ~ "," ~ XMM(numScalarsOnStack) ~ ";"\n 
     610            ~ "  addsd "  ~ XMM(numScalarsOnStack) ~ "," ~  XMM(numScalarsOnStack+1) ~ ";\n";            
     611        } else { // floats 
     612            result ~= "  movhlps " ~ XMM(numScalarsOnStack+1) ~ "," ~ XMM(numScalarsOnStack) ~ ";"\n 
     613            ~ "  addps "  ~ XMM(numScalarsOnStack) ~ "," ~  XMM(numScalarsOnStack+1) ~ ";\n" 
     614            ~ "  pshufd " ~ XMM(numScalarsOnStack+1) ~ "," ~ XMM(numScalarsOnStack) ~ ",1;"\n 
     615            ~ "  addss "  ~ XMM(numScalarsOnStack) ~ "," ~  XMM(numScalarsOnStack+1) ~ ";\n"; 
     616        } 
     617        result ~= "  movs" ~ suffix ~ " sum," ~ XMM(numScalarsOnStack) ~ ";"\n; 
     618        //result ~= "// Move to ST(0)\n"; 
     619    } 
     620    result ~= popRegisters(numvecs) ~ "}\n"; 
     621    if (isDotProduct) result ~= "  return sum;"\n; 
    586622    
    587623    return result; 
  • trunk/blade/PostfixX86.d

    r166 r168  
    4242    } 
    4343    ReturnType onVisitFunction(This this_, char [] func, char [][] args) { 
     44        if (func=="d") { 
     45            return "0" ~ doVisit(this_,args[0]) ~ doVisit(this_, args[1]) ~ "*+"; 
     46        } 
    4447        assert(0, "BLADE ICE: Unsupported"); 
    4548    } 
     
    114117        return sym; 
    115118    } 
    116     ReturnType onVisitFunction(This this_, char [] func, char [][] args) { 
     119    ReturnType onVisitFunction(This this_, char [] func, char [][] args) {       
     120        if (func=="d") { 
     121            return "0" ~ doVisit(this_,args[0]) ~ doVisit(this_, args[1]) ~ "*+"; 
     122        } 
    117123        assert(0, "BLADE ICE: Unsupported"); 
    118124    }