Changeset 170

Show
Ignore:
Timestamp:
01/07/08 09:21:45 (9 months ago)
Author:
Don Clugston
Message:

Now generates asserts for all nested sub-expressions.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/blade/Blade.d

    r169 r170  
    3737* 
    3838* FUTURE DIRECTIONS (in order of expected implementation): 
    39 * - sum(), trace() 
     39* - trace() 
    4040* - Loop unrolling for cumulative operations dot, sum, trace. 
    4141* - Dense matrix support. 
     
    9292    if (revised.errorMessage.length>0)  return `static assert(0, "BLADE: ` ~ enquote(revised.errorMessage) ~ `");`; 
    9393    VecExpressionType exprType = categorizeExpression(revised); 
    94     char [] result = generateAsserts(revised, (exprType == VecExpressionType.SSE2Expression || exprType == VecExpressionType.SSE1Expression)); 
    95     debug(BladeFrontEnd) { 
    96         result ~= "// Simplified: " ~ revised.expression ~ \n; 
    97     }     
     94    InvocationCode q;     
    9895    if (exprType == VecExpressionType.SSE2Expression || exprType == VecExpressionType.SSE1Expression) { 
    99         return result ~ invokeSSE((exprType == VecExpressionType.SSE2Expression), revised)~ ";"
     96        q = invokeSSE((exprType == VecExpressionType.SSE2Expression), revised)
    10097    } else if (exprType == VecExpressionType.X87Expression) { 
    101         return result ~ invokeX87(revised) ~ ";"
     98        q = invokeX87(revised)
    10299    } else { 
    103         return result ~ DCodeGenerator(revised); 
    104     }     
     100        q = DCodeGenerator(revised); 
     101    } 
     102    return q.assertions ~ q.invoker ~ ";"; 
    105103} 
    106104 
     
    113111} 
    114112 
    115 template SSERetType(int SSEVersion, char [] expr) { 
    116     static if (expr[0]!='0') alias void SSERetType; 
    117     else static if (SSEVersion==1) alias float SSERetType; 
    118     else alias double SSERetType; 
    119 } 
    120113template X87RetType(char [] expr) { 
    121    static if (expr[0]!='0') alias void X87RetType; 
    122    else alias real X87RetType; 
     114    static if (expr[0]!='0') alias void X87RetType; 
     115    else alias real X87RetType; 
    123116} 
    124117 
     
    132125 * Every member of the Values tuple must only be double or double *. 
    133126 */ 
    134 SSERetType!(SSEVersion, expr) SSEVECGEN(int SSEVersion, char [] expr, Values...)(int veclength, Values values) { 
     127RetType SSEVECGEN(RetType, char [] expr, Values...)(int veclength, Values values) { 
    135128    debug(BladeBackEnd) { 
    136        pragma(msg, generateCodeForSSE!(Values)(SSEVersion, expr)); 
    137     } 
    138     mixin(generateCodeForSSE!(Values)(SSEVersion, expr)); 
     129       pragma(msg, generateCodeForSSE!(Values)(expr)); 
     130    } 
     131    mixin(generateCodeForSSE!(Values)(expr)); 
    139132} 
    140133 
    141134/** Function to implement BLAS1 operations using X87 assembler. 
    142  * Every member of the Values tuple must only be real, float[], double [], or real[]. 
     135 * Every member of the Values tuple must be one of: real, 
     136 * float[], double[], real[], BladeStrided!(float), BladeStrided!(double), or BladeStrided!(real). 
    143137 */ 
    144138X87RetType!(expr) X87VECGEN(char [] expr, int numStrides, Values...)(int veclength, Values values) { 
     
    247241// of the parameters. 
    248242 
    249 /// Generate code which will call the X87 function 
    250 char [] invokeX87(RevisedExpression tree) 
    251 
     243struct InvocationCode { 
     244    char [] invoker; // For mixin: code to invoke the functions. 
     245    char [] assertions; // For mixin: code to assert that everything is correct 
     246
     247 
     248/// Generate code which will call the X87 function. 
     249InvocationCode invokeX87(RevisedExpression tree) 
     250
     251    char [] assertions = assertAllVectorLengthsEqual(tree); 
    252252    char [] result = ""; 
    253253    char [] stridelist=""; 
     
    261261        char rnk = tree.rank[i]; 
    262262        vals ~= ","; 
    263         char [] v = getValueForSymbol(tree.mapping[i], tree); 
     263        InvocationCode q = getValueForSymbol(tree.mapping[i], tree); 
     264        char [] v = q.invoker; 
     265        assertions ~= q.assertions; 
    264266        int x = tree.mapping[i]-'A'; 
    265267        char [] t; 
     
    274276            if (rnk=='0') { 
    275277                t = "real"; // convert all compounds to real. 
    276                 // TODO: if the number is exactly representable as a double 
     278                // TODO (tricky): if the number is exactly representable as a double 
    277279                // or float, it could use less FPU stack space. 
    278280            } else { // for arrays, the type is the type of the original array 
     
    318320    result ~= ")("; 
    319321    int firstVector = findVectorForLength(tree); 
    320     return result ~ getValueForSymbol(tree.mapping[firstVector], tree) ~ ".length"  
    321         ~ vals ~ stridelist  ~ ")"
     322    return InvocationCode(result ~ getValueForSymbol(tree.mapping[firstVector], tree).invoker ~ ".length"  
     323        ~ vals ~ stridelist  ~ ")", assertions)
    322324} 
    323325 
     
    330332 
    331333/// Generate code which will call the SSE/SSE2 code generation function 
    332 char [] invokeSSE(bool SSE2, RevisedExpression tree) 
    333 {       
    334     char [] result = "SSEVECGEN!(" ~ (SSE2?"2":"1") ~ `,"` ~ enquote(makePostfixForSSE(tree.expression, tree.rank)) ~ `"`; 
     334InvocationCode invokeSSE(bool SSE2, RevisedExpression tree) 
     335
     336    char [] assertions = assertAllVectorLengthsEqual(tree) 
     337    ~ assertAllVectorsAlign128(tree); 
     338 
     339    char [] postfix = makePostfixForSSE(tree.expression, tree.rank); 
     340    char [] retType = "void"; 
     341    if (postfix[0]=='0') retType = (SSE2? "double" : "float"); 
     342 
     343    char [] result = "SSEVECGEN!(" ~ retType ~ `,"` ~ enquote(postfix) ~ `"`; 
    335344    // For SSE2, everything must be implicitly convertible to double. 
    336345    char [] vals; 
     
    341350        vals ~= ","; 
    342351        if (rnk=='1') vals ~= "&"; 
    343         vals ~= getValueForSymbol(tree.mapping[i], tree); 
     352        InvocationCode q = getValueForSymbol(tree.mapping[i], tree); 
     353        vals ~= q.invoker; 
     354        assertions ~= q.assertions; 
    344355        // for vectors, we only need the pointer, not the length 
    345 //        if (rnk=='1') vals ~= ".ptr"; 
    346356        if (rnk=='1') vals ~= "[0]"; 
    347357    } 
     
    349359    result ~= ")("; 
    350360    int firstVector = findVectorForLength(tree); 
    351     result ~= getValueForSymbol(tree.mapping[firstVector], tree) ~ ".length"; 
     361    result ~= getValueForSymbol(tree.mapping[firstVector], tree).invoker ~ ".length"; 
    352362//    result ~= tree.symbolTable[firstVector].value ~ ".length"; 
    353363    result ~= vals; 
    354364 
    355     return result ~ ")"
     365    return InvocationCode(result ~ ")", assertions)
    356366} 
    357367 
     
    387397    for (int i=0; i<tree.mapping.length;++i) { 
    388398        if (tree.rank[i]=='1'){ 
    389             result ~= "assert( (cast(size_t)(&" ~ getValueForSymbol(tree.mapping[i], tree) 
    390                     ~ "[0])& 0x0F) == 0, `SSE Vector misalignment: " ~ getValueForSymbol(tree.mapping[i], tree) ~ "`);"\n; 
     399            result ~= "assert( (cast(size_t)(&" ~ getValueForSymbol(tree.mapping[i], tree).invoker 
     400                    ~ "[0])& 0x0F) == 0, `SSE Vector misalignment: " ~ getValueForSymbol(tree.mapping[i], tree).invoker ~ "`);"\n; 
    391401        } 
    392402    } 
     
    449459        if (comp[$-1]!=']') { // simple compound expression 
    450460            foreach(d; comp) { 
    451                if (d=='{') assert(0, "BLADE ICE"); 
     461                if (d=='{') assert(0, "BLADE ICE"); 
    452462                if (d>='A' && d<='Z') v ~= tree.symbolTable[d-'A'].value; 
    453463                else v ~= d; 
     
    515525    // TODO: This only works for packed types. Doesn't work for jagged arrays, and 
    516526    // is probably very inefficient for user-defined types. 
    517     return "(&" ~ getValueForSymbol(c, tree, "1") ~ "-&" 
    518         ~ getValueForSymbol(c, tree, "0") ~ ")*" ~ tree.symbolTable[comp[0]-'A'].element ~ ".sizeof"; 
    519 } 
    520  
    521  
    522 char [] invokeNestedExpression(char [] expr, Symbol[] symbolTable) 
    523 { 
    524    char [] ranks; 
     527    return "(&" ~ getValueForSymbol(c, tree, "1").invoker ~ "-&" 
     528        ~ getValueForSymbol(c, tree, "0").invoker ~ ")*" ~ tree.symbolTable[comp[0]-'A'].element ~ ".sizeof"; 
     529} 
     530 
     531 
     532InvocationCode invokeNestedExpression(char [] expr, Symbol[] symbolTable) 
     533{ 
     534    char [] ranks; 
    525535    for (int i=0; i<symbolTable.length; ++i) { 
    526536        ranks ~= symbolTable[i].rank; 
    527    
    528    RevisedExpression revised = remapCompounds(expr, ranks, symbolTable); 
    529      
     537   
     538    RevisedExpression revised = remapCompounds(expr, ranks, symbolTable); 
     539     
    530540    VecExpressionType exprType = categorizeExpression(revised); 
    531541    if (exprType == VecExpressionType.SSE2Expression || exprType == VecExpressionType.SSE1Expression) { 
     
    534544        return invokeX87(revised); 
    535545    } else { 
    536        assert(0, "BLADE ICE: Nested D expressions are not yet supported"); 
     546        assert(0, "BLADE ICE: Nested D expressions are not yet supported"); 
    537547        return DCodeGenerator(revised); 
    538548    } 
    539549} 
    540550 
    541 char [] getValueForSymbol(char c, RevisedExpression tree, char [] firstIndexExpr="") 
     551InvocationCode getValueForSymbol(char c, RevisedExpression tree, char [] firstIndexExpr="") 
    542552{ 
    543553    int numSlicesRemaining=1; 
    544554    if (firstIndexExpr=="") numSlicesRemaining=0; 
    545555    char [] v = ""; 
     556    char [] assertions = ""; 
    546557    // is it an original symbol? 
    547558    if (c-'A'<tree.symbolTable.length) { 
    548559        v = tree.symbolTable[c-'A'].value; 
    549     } else {  // else it's a compound or an indexed array       
     560    } else {  // else it's a compound or an indexed array        
    550561        char [] comp = tree.compounds[c-'A'-tree.symbolTable.length]; 
    551562         
    552563        if (comp[$-1]!=']') { // compound expression (not an indexed array) 
    553            if (comp[0]>='a' && comp[0]<='z') { 
    554                // dot product is a nested expression 
    555                return invokeNestedExpression(comp, tree.symbolTable); 
    556            } 
     564            if (comp[0]>='a' && comp[0]<='z') { 
     565                // dot product is a nested expression 
     566                return invokeNestedExpression(comp, tree.symbolTable); 
     567            } 
    557568            for (int k=0; k<comp.length; ++k) { 
    558                 char d = comp[k]; 
    559                  if (d=='{') { 
    560                      int braceStart = k+1; 
    561                      for (; comp[k]!='}'; ++k) {}                     
    562                      v ~= invokeNestedExpression(comp[braceStart..k], tree.symbolTable); 
    563                  } else if (d>='A' && d<='Z') { 
    564                      v ~= tree.symbolTable[d-'A'].value; 
     569                char d = comp[k]; 
     570                 if (d=='{') { 
     571                     int braceStart = k+1; 
     572                     for (; comp[k]!='}'; ++k) {} 
     573                     InvocationCode q = invokeNestedExpression(comp[braceStart..k], tree.symbolTable); 
     574                     v ~= q.invoker; 
     575                     assertions ~= q.assertions; 
     576                 } else if (d>='A' && d<='Z') { 
     577                     v ~= tree.symbolTable[d-'A'].value; 
    565578                 } else v ~= d; 
    566579            } 
     
    614627                v ~= "[" ~ firstIndexExpr ~ "]"; 
    615628            } 
    616             return tree.symbolTable[comp[0]-'A'].value ~ v
     629            return InvocationCode(tree.symbolTable[comp[0]-'A'].value ~ v, assertions)
    617630        } 
    618631    } 
     
    620633        v ~= "[" ~ firstIndexExpr ~ "]"; 
    621634    } 
    622     return v
     635    return InvocationCode(v, assertions)
    623636} 
    624637 
    625638 
    626639// Generate inline D code for the expression 
    627 char [] DCodeGenerator(RevisedExpression tree) 
     640InvocationCode DCodeGenerator(RevisedExpression tree) 
    628641{ 
    629642    char [] result = "// D generate:" ~ tree.braceExpression ~ \n; 
     643    char [] assertions=""; 
    630644    int wholerank = exprRank(tree.expression, tree.rank); 
    631645    if (wholerank ==1) {   
     
    639653            // restore all symbols into the expression 
    640654            // If it's a vector, index it 
    641             if (tree.rank[c-'A']=='1') 
    642                 result ~= getValueForSymbol(tree.mapping[c-'A'], tree, "blade_index"); 
    643             else result ~= getValueForSymbol(tree.mapping[c-'A'], tree); 
     655            if (tree.rank[c-'A']=='1') { 
     656                InvocationCode q = getValueForSymbol(tree.mapping[c-'A'], tree, "blade_index"); 
     657                result ~= q.invoker; 
     658                assertions ~= q.assertions; 
     659            } 
     660            else { 
     661                InvocationCode q = getValueForSymbol(tree.mapping[c-'A'], tree); 
     662                result ~= q.invoker; 
     663                assertions ~= q.assertions; 
     664            } 
    644665        } else result ~= c; 
    645666    } 
    646     if (wholerank==0) return result ~ ";";    
    647     return result ~ "; }"
    648 } 
    649  
     667    if (wholerank==0) return InvocationCode(result, assertions);    
     668    return InvocationCode(result ~ "; }", assertions)
     669} 
     670 
  • trunk/blade/CodegenX86.d

    r169 r170  
    377377    int numOnStack = 0; // How much of the FP stack is being used? 
    378378 
    379     bool isDotProduct = (operations[0]=='0'); 
     379    bool isCumulative = (operations[0]=='0'); 
    380380    if (operations[0]=='0') { 
    381381        result ~= "  fldz;"\n; // dot product 
     
    457457 
    458458    // Discard any scalars that are left on the stack 
    459     if (isDotProduct && numScalarsOnStack>0) { 
     459    if (isCumulative && numScalarsOnStack>0) { 
    460460        // Preserve the result of the dot product 
    461461        result ~= "  fxch ST(" ~ itoa(numScalarsOnStack) ~ "), ST;"\n; 
     
    478478 * At entry, all vector parameters are aligned. 
    479479 */ 
    480 char [] generateCodeForSSE(Values...)(int SSEVer, char [] operations) 
     480char [] generateCodeForSSE(Values...)(char [] operations) 
    481481{ 
    482482    char [] ranklist; 
     483    bool usingDoubles=false; 
    483484    foreach(T; Values) { 
    484485        static if (is(typeof(T[0]))) ranklist~="1"; else ranklist~="0"; 
    485     } 
    486     return generateCodeForSSEImpl(SSEVer, ranklist, operations); 
     486        static if (is(T == double) || is(T == double *)) { 
     487            usingDoubles = true; 
     488        } //else assert(is(T==float)|| is(T==float*)); 
     489    } 
     490    return generateCodeForSSEImpl(usingDoubles, ranklist, operations); 
    487491//    makePostfixForSSE(infixOperations, ranklist)); 
    488492} 
     
    495499private: 
    496500// split off from the template to make code coverage work 
    497 char [] generateCodeForSSEImpl(int SSEVer, char [] ranklist, char [] operations, char cumulatingOp=0) 
     501char [] generateCodeForSSEImpl(bool usingDoubles, char [] ranklist, char [] operations, char cumulatingOp=0) 
    498502{ 
    499503    char [] result=""; 
     
    503507    int numvecs = countVectors(ranklist); 
    504508    int numScalarsOnStack=0; 
    505     bool isDotProduct = (operations[0]=='0'); 
    506     if (isDotProduct) result ~= ((SSEVer == 2)? "  double" : "  float") ~" sum;"\n; 
     509    bool isCumulative = (operations[0]=='0'); 
     510    if (isCumulative) result ~= (usingDoubles? "  double" : "  float") ~" sum;"\n; 
    507511 
    508512    result~= \n"asm {"\n ~ pushRegisters(numvecs); 
     
    511515    // Load all the vector pointers into registers 
    512516 
    513     char [] vectorsize = (SSEVer == 2) ? "8" :"4"; // size of a double 
    514     char [] suffix = (SSEVer == 2) ? "d " :"s "; 
     517    char [] vectorsize = usingDoubles? "8" :"4"; // size of a double 
     518    char [] suffix = usingDoubles? "d " :"s "; 
    515519     
    516520    int vecregnum = 0; 
     
    544548    int done=0; 
    545549    if (operations[0]=='0') { 
    546        result ~= "  pxor " ~ XMM(numOnStack) ~ "," ~ XMM(numOnStack) ~ ";  // 0\n"; 
    547        ++numOnStack; 
    548        ++done; 
     550        result ~= "  pxor " ~ XMM(numOnStack) ~ "," ~ XMM(numOnStack) ~ ";  // 0\n"; 
     551        ++numOnStack; 
     552        ++done; 
    549553    } 
    550554    result ~= "  xor EAX, EAX; "\n 
     
    575579            } else  
    576580            if (operations[done-1]==operations[done]) { 
    577                // operation on self, eg XX+ --> don't need to load it again. 
    578                int cumvector = (operations[done-1]=='0')? numScalarsOnStack : numOnStack-1; 
     581                // operation on self, eg XX+ --> don't need to load it again. 
     582                int cumvector = (operations[done-1]=='0')? numScalarsOnStack : numOnStack-1; 
    579583                mainbody ~= "  " ~ opToSSE[operations[done+1]] ~ suffix ~ " " ~ XMM(numOnStack-1) ~ ", " 
    580584                     ~ XMM(numOnStack-1) ~ comment; 
     
    600604        ~ "  align 16;\n"  
    601605        ~ "L1:\n" ~ mainbody; 
    602     if (SSEVer == 2) { 
     606    if (usingDoubles) { 
    603607        result ~= "  add EAX,2;\n  js L1;\n" 
    604608             ~ "L2:\n  sub EAX, 2;\n  jns L4;\n" 
     
    613617    } 
    614618    result ~= "L4:" \n; 
    615     if (isDotProduct) { 
    616        // Result is now in XMM(numScalarsOnStack). We need to do a horizontal 
    617        // add to get the final sum. 
    618        if (SSEVer==2) { 
    619            // For SSE3, use   haddpd XMM(numScalarsOnStack). 
    620            result ~= "  movhlps " ~ XMM(numScalarsOnStack+1) ~ "," ~ XMM(numScalarsOnStack) ~ ";"\n 
    621             ~ "  addsd "  ~ XMM(numScalarsOnStack) ~ "," ~  XMM(numScalarsOnStack+1) ~ ";\n";            
    622        } else { // floats 
    623            result ~= "  movhlps " ~ XMM(numScalarsOnStack+1) ~ "," ~ XMM(numScalarsOnStack) ~ ";"\n 
    624            ~ "  addps "  ~ XMM(numScalarsOnStack) ~ "," ~  XMM(numScalarsOnStack+1) ~ ";\n" 
    625            ~ "  pshufd " ~ XMM(numScalarsOnStack+1) ~ "," ~ XMM(numScalarsOnStack) ~ ",1;"\n 
    626            ~ "  addss "  ~ XMM(numScalarsOnStack) ~ "," ~  XMM(numScalarsOnStack+1) ~ ";\n"; 
    627        
    628        result ~= "  movs" ~ suffix ~ " sum," ~ XMM(numScalarsOnStack) ~ ";"\n; 
    629        //result ~= "// Move to ST(0)\n"; 
     619    if (isCumulative) { 
     620        // Result is now in XMM(numScalarsOnStack). We need to do a horizontal 
     621        // add to get the final sum. 
     622        if (usingDoubles) { 
     623            // For SSE3, use   haddpd XMM(numScalarsOnStack). 
     624            result ~= "  movhlps " ~ XMM(numScalarsOnStack+1) ~ "," ~ XMM(numScalarsOnStack) ~ ";"\n 
     625            ~ "  addsd "  ~ XMM(numScalarsOnStack) ~ "," ~  XMM(numScalarsOnStack+1) ~ ";\n";            
     626        } else { // floats 
     627            result ~= "  movhlps " ~ XMM(numScalarsOnStack+1) ~ "," ~ XMM(numScalarsOnStack) ~ ";"\n 
     628            ~ "  addps "  ~ XMM(numScalarsOnStack) ~ "," ~  XMM(numScalarsOnStack+1) ~ ";\n" 
     629            ~ "  pshufd " ~ XMM(numScalarsOnStack+1) ~ "," ~ XMM(numScalarsOnStack) ~ ",1;"\n 
     630            ~ "  addss "  ~ XMM(numScalarsOnStack) ~ "," ~  XMM(numScalarsOnStack+1) ~ ";\n"; 
     631       
     632        result ~= "  movs" ~ suffix ~ " sum," ~ XMM(numScalarsOnStack) ~ ";"\n; 
     633        //result ~= "// Move to ST(0)\n"; 
    630634    } 
    631635    result ~= popRegisters(numvecs) ~ "}\n"; 
    632     if (isDotProduct) result ~= "  return sum;"\n; 
     636    if (isCumulative) result ~= "  return sum;"\n; 
    633637    
    634638    return result; 
  • trunk/blade/PostfixX86.d

    r169 r170  
    77* values being'0'=scalar, '1'=vector, '2'=matrix. 
    88* This string is converted to postfix, applying simple X86-specific optimisations. 
     9* 
     10* The elements of the postfix string may be: 
     11*  ABC          a variable or constant, to be pushed onto the stack 
     12*  0            the literal zero (used to initialize a dot product, for example) 
     13*  *+-/         ST(0)*ST(1), ST(0)+ST(1), ST(0)-ST(1), ST(0)/ST(1) and pop stack 
     14*  _            ST(1)-ST(0) and pop stack 
     15*  =            store stack top and pop stack 
     16* 
     17*  NOT YET IMPLEMENTED: 
     18*  1            the literal one (used to initialize a product, for example) 
     19*  sc           ST(0) = sin(ST(0)), ST(0) = cos(ST(0)) 
     20*  q            ST(0) = sqrt(ST(0)) 
    921*/ 
    1022 
     
    4254    } 
    4355    ReturnType onVisitFunction(This this_, char [] func, char [][] args) { 
    44        if (func=="dot") { 
    45            return "0" ~ doVisit(this_,args[0]) ~ doVisit(this_, args[1]) ~ "*+"; 
    46        } 
     56        if (func=="dot") { 
     57            return "0" ~ doVisit(this_,args[0]) ~ doVisit(this_, args[1]) ~ "*+"; 
     58        } 
    4759        assert(0, "BLADE ICE: Unsupported"); 
    4860    } 
     
    117129        return sym; 
    118130    } 
    119     ReturnType onVisitFunction(This this_, char [] func, char [][] args) {      
    120        if (func=="dot") { 
    121            return "0" ~ doVisit(this_,args[0]) ~ doVisit(this_, args[1]) ~ "*+"; 
    122        } 
    123        if (func=="sum") return "0" ~ doVisit(this_, args[0]) ~ "+"; 
     131    ReturnType onVisitFunction(This this_, char [] func, char [][] args) {       
     132        if (func=="dot") { 
     133            return "0" ~ doVisit(this_,args[0]) ~ doVisit(this_, args[1]) ~ "*+"; 
     134        } 
     135        if (func=="sum") return "0" ~ doVisit(this_, args[0]) ~ "+"; 
    124136        assert(0, "BLADE ICE: Unsupported"); 
    125137    }