Changeset 187

Show
Ignore:
Timestamp:
04/30/08 16:05:32 (4 months ago)
Author:
Don Clugston
Message:

Added prod(). Use .ptr to get raw data, so it works with Bill Baxter's ArrayView.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/blade/Blade.d

    r176 r187  
    11//  Written in the D programming language 1.0 
    22/** 
    3 * BLADE 0.4Alpha -- Basic Linear Algebra D Expressions 
     3* BLADE 0.6Alpha -- Basic Linear Algebra D Expressions 
    44* 
    55* Generate near-optimal x87/SSE2 asm code for BLAS1 basic vector operations at compile time. 
     
    1313* 
    1414* FEATURES: 
    15 *  - Supports any mix of vector addition, subtraction, dot product, unary minus, 
    16 *    multiplication by a scalar, sum(), abs(), and multidimensional slicing. 
     15*  - Supports any mix of vector addition, subtraction, unary minus, 
     16*    multiplication by a scalar, 
     17*    cumulation via dot product, sum() and prod(), and multidimensional slicing. 
    1718*  - Generates either x87 asm code, SSE or SSE2 asm code or pure D, depending on 
    1819*    the complexity of the expression, and the availability of inline asm. 
     
    2425* 
    2526* SPEED/ACCURACY TRADEOFF: 
    26 * IEEE floating point multiplication and addition are not associative. 
     27* Tradeoff arises because IEEE floating point multiplication and addition are not associative. 
    2728*  Assuming that overflow and underflow do not occur: 
    2829*  (a*b)*c may differ from a*(b*c) in the last bit. 
     
    3435* 
    3536* FUTURE DIRECTIONS (in order of expected implementation): 
    36 * - trace() 
    37 * - Loop unrolling for cumulative operations dot, sum, trace. 
     37* - nested D expressions 
     38* - cumulative operations min, max 
     39* - Loop unrolling for cumulative operations dot, sum, prod. 
    3840* - Dense matrix support. 
    3941* - Triangular, banded, symmetric, and sparse matrix support 
     
    4648* which accepts the tuple. 
    4749* 
     50* COMPILER BUGS/LIMITATIONS AFFECTING THIS LIBRARY 
     51* - Local arrays are not aligned to a 128-bit boundary, so use of aligned SSE is not 
     52*   always possible. 
     53* - Bugzilla #1125 -- structs in a tuple can't be used in asm. 
     54* - Bugzilla #1382 -- CTFE strings never get deleted --> SLOOOOOW compilation. KILLER BUG. 
     55* - Bugzilla #1768 -- in CTFE, arrays of arrays aren't initialized properly 
     56* 
    4857* HISTORY: 
    4958* 0.1 - Used classes to make expression templates. 
    5059* 0.2 - Support for a wider variety of expressions. Dot product, imaginary numbers, etc. 
    5160* 0.3 - Based on string mixins. Most of the new features of 0.2 are gone, but SSE2 is added. 
    52 * 0.4 - Added D code generator. Nice error messages. Optimal parameter passing.  
     61* 0.4 - Added D code generator. Nice error messages. Optimal parameter passing. 
    5362*       (passes pointers, not arrays). 
    5463* 0.5 - Expression simplification step. Slicing support. 
    55 * 0.6 - Dot product, nested expressions, intrinsics: abs, sqrt, sum. 
     64* 0.6 - Dot product, nested expressions (asm only), intrinsics: abs, sqrt, sum. 
     65* 0.7 - Intrinsics: prod 
    5666*/ 
    5767 
     
    7383// FOR MIXIN: Generate code to evaluate the given vector expression. 
    7484char [] vectorize(char [] expr) 
    75 {     
     85{ 
    7686    debug (BladeFrontEnd) { 
    7787    return `pragma(msg, \n ~ "// " __FILE__ ~ "(" ~__LINE__.stringof[0..$-1] ~ ") ` ~ enquote(expr) ~ `" ~ \n ~ ` ~ mixin_tupleAndSyntaxtreeof("makeVectorCode", expr) ~ "~\\n);" 
     
    8292} 
    8393 
    84 // Simplify the expression, categorise it,  
     94// Simplify the expression, categorise it, 
    8595// and dispatch to the appropriate code generator. 
    8696char [] makeVectorCode(Types...)(AbstractSyntaxTree tree) 
     
    8999    if (revised.errorMessage.length>0)  return `static assert(0, "BLADE: ` ~ enquote(revised.errorMessage) ~ `");`; 
    90100    VecExpressionType exprType = categorizeExpression(revised); 
    91     InvocationCode q;     
     101    InvocationCode q; 
    92102    if (exprType == VecExpressionType.SSE2Expression || exprType == VecExpressionType.SSE1Expression) { 
    93103        q = invokeSSE((exprType == VecExpressionType.SSE2Expression), revised); 
     
    105115{ 
    106116    // TODO: 
    107     return expression;     
     117    return expression; 
    108118} 
    109119 
    110120template X87RetType(char [] expr) { 
    111     static if (expr[0]!='0') alias void X87RetType; 
     121    static if (expr[0]!='0' && expr[0]!='1') alias void X87RetType; 
    112122    else alias real X87RetType; 
    113123} 
     
    130140 
    131141/** Function to implement BLAS1 operations using X87 assembler. 
    132  * Every member of the Values tuple must only be real,  
     142 * Every member of the Values tuple must only be real, 
    133143 * float[], double [], or real[], or BladeStrided!(float), !(double), !(real) 
    134144 */ 
    135145X87RetType!(expr) X87VECGEN(char [] expr, int numStrides, Values...)(int veclength, Values values) { 
    136     debug(BladeBackEnd) {     
     146    debug(BladeBackEnd) { 
    137147        pragma(msg, generateCodeForAsmX87!(numStrides, Values)(expr)); 
    138148    } 
     
    146156static ulong[2] SSE_SIGNBITpd = [0x8000_0000_0000_0000L, 0x8000_0000_0000_0000L]; 
    147157static uint[4] SSE_SIGNBITps = [0x8000_0000,0x8000_0000,0x8000_0000, 0x8000_0000]; 
     158// The value 1.0 for a parallel SSE register 
     159static ulong[2] SSE_ONEpd = [0x3FF0_0000_0000_0000L, 0x3FF0_0000_0000_0000L]; 
     160static uint[4] SSE_ONEps = [0x3F0_000, 0x3F0_000, 0x3F0_000, 0x3F0_000]; 
    148161 
    149162private: 
     
    156169// SSE1 is possible only if all vectors are floats. 
    157170// X87 is possible for any mix of real, double, and float vectors. 
    158 // BUG: for X87, should also check number of temporaries (don't overflow the FPU stack 
     171// BUG: for X87, should also check number of temporaries (don't overflow the FPU stack) 
    159172enum VecExpressionType { SSE1Expression, SSE2Expression, X87Expression, DExpression }; 
    160173 
     
    164177    bool SSE1 = true; 
    165178    bool X87 = true; 
    166     bool strided = false; // true if any strided vector or matrix operations exist     
     179    bool strided = false; // true if any strided vector or matrix operations exist 
    167180version (D_InlineAsm_X86) {} else { 
    168181    // Without an assembler, there's no chance! 
     
    199212            y = tree.compounds[x-tree.symbolTable.length][0]-'A'; 
    200213            // Check for a stride.. 
    201             if (tree.compounds[x-tree.symbolTable.length][$-1]==']') {                
     214            if (tree.compounds[x-tree.symbolTable.length][$-1]==']') { 
    202215                strided |= isStrided(tree.compounds[x-tree.symbolTable.length]); 
    203216            } 
    204217        } 
    205          
     218 
    206219        char [] t = tree.symbolTable[y].element; 
    207220        if (t == "double") { 
     
    218231        } 
    219232    } 
    220     // It's not worth doing strided operations with SSE.    
     233    // It's not worth doing strided operations with SSE. 
    221234    if (strided) { SSE1=false; SSE2=false; } 
    222235    if (numRealScalars > MAX_87_REALSCALARSPLUSTEMPORARIES) X87 = false; 
     
    224237    if (numvectors > MAX_SSE_VECTORS) { SSE1=false; SSE2=false; } 
    225238    if (SSE1) return VecExpressionType.SSE1Expression; 
    226     if (SSE2) return VecExpressionType.SSE2Expression;     
    227     return X87 ? VecExpressionType.X87Expression : VecExpressionType.DExpression;  
     239    if (SSE2) return VecExpressionType.SSE2Expression; 
     240    return X87 ? VecExpressionType.X87Expression : VecExpressionType.DExpression; 
    228241} 
    229242 
     
    256269    char [] stridelist=""; 
    257270    char [] alltypes=""; 
    258      
     271 
    259272    char [][] typelist; 
    260      
     273 
    261274    char [] vals; 
    262275    int numstrides=0; 
     
    283296            } else { // for arrays, the type is the type of the original array 
    284297                t = tree.symbolTable[tree.compounds[x-tree.symbolTable.length][0]-'A'].element; 
    285                         // Check for a stride..                 
     298                        // Check for a stride.. 
    286299                if (tree.compounds[x-tree.symbolTable.length][$-1]==']') { 
    287300                    strided = isStrided(tree.compounds[x-tree.symbolTable.length]); 
    288301                    if (strided) ++numstrides; 
    289302                } 
    290                  
     303 
    291304            } 
    292305        } 
     
    296309            // long, ulong, and real must become real. 
    297310            // We convert everything else to double, since that uses less 
    298             // FPU stack space.            
     311            // FPU stack space. 
    299312            if (t == "real" || t == "double" || t == "float") alltypes ~= t; 
    300313            else if (t == "long" || t == "ulong") result ~= "real"; 
     
    310323                alltypes ~= t ~ "*"; 
    311324                // for vectors, we only need the pointer, not the length 
    312                 vals ~= "&" ~  v ~ "[0]"; 
     325                //vals ~= "&" ~  v ~ "[0]"; 
     326                vals ~= v ~ ".ptr"; 
    313327            } 
    314328        } 
     
    323337    result ~= ")("; 
    324338    int firstVector = findVectorForLength(tree); 
    325     return InvocationCode(result ~ getValueForSymbol(tree.mapping[firstVector], tree).invoker ~ ".length"  
     339    return InvocationCode(result ~ getValueForSymbol(tree.mapping[firstVector], tree).invoker ~ ".length" 
    326340        ~ vals ~ stridelist  ~ ")", assertions); 
    327341} 
     
    342356    char [] postfix = makePostfixForSSE(tree.expression, tree.rank); 
    343357    char [] retType = "void"; 
    344     if (postfix[0]=='0') retType = (SSE2? "double" : "float"); 
     358    if (postfix[0]=='0' || postfix[0]=='1') retType = (SSE2? "double" : "float"); 
    345359 
    346360    char [] result = "SSEVECGEN!(" ~ retType ~ `,"` ~ enquote(postfix) ~ `"`; 
     
    352366        else result ~= SSE2? ",double*" : ",float*"; 
    353367        vals ~= ","; 
    354         if (rnk=='1') vals ~= "&"; 
     368//        if (rnk=='1') vals ~= "&"; 
    355369        InvocationCode q = getValueForSymbol(tree.mapping[i], tree); 
    356370        vals ~= q.invoker; 
    357371        assertions ~= q.assertions; 
    358372        // for vectors, we only need the pointer, not the length 
    359         if (rnk=='1') vals ~= "[0]"; 
    360     } 
    361              
     373        if (rnk=='1') vals ~= ".ptr"; 
     374    } 
     375 
    362376    result ~= ")("; 
    363377    int firstVector = findVectorForLength(tree); 
     
    384398//                    result ~= "static "; 
    385399//                } 
    386                 result ~= "assert("  
     400                result ~= "assert(" 
    387401                 ~ getDimensionLengthForSymbol(tree.mapping[i], tree, 1) 
    388402                    ~ "==" ~ getDimensionLengthForSymbol(tree.mapping[firstVector], tree, 1) 
     
    399413    for (int i=0; i<tree.mapping.length;++i) { 
    400414        if (tree.rank[i]=='1'){ 
    401             result ~= "assert( (cast(size_t)(&" ~ getValueForSymbol(tree.mapping[i], tree).invoker 
    402                     ~ "[0])& 0x0F) == 0, `SSE Vector misalignment: " ~ getValueForSymbol(tree.mapping[i], tree).invoker ~ "`);"\n; 
     415            result ~= "assert( (cast(size_t)(" ~ getValueForSymbol(tree.mapping[i], tree).invoker 
     416                    ~ ".ptr)& 0x0F) == 0, `SSE Vector misalignment: " ~ getValueForSymbol(tree.mapping[i], tree).invoker ~ "`);"\n; 
    403417        } 
    404418    } 
     
    438452        } 
    439453    } 
    440     return dynamic>=0? dynamic : strided;     
     454    return dynamic>=0? dynamic : strided; 
    441455} 
    442456 
     
    458472    } else {  // else it's a compound or an indexed array 
    459473        char [] comp = tree.compounds[c-'A'-tree.symbolTable.length]; 
    460          
     474 
    461475        if (comp[$-1]!=']') { // simple compound expression 
    462476            foreach(d; comp) { 
     
    477491            char [] nextIndex; 
    478492            char [] sliceTo; 
    479      
    480             for (int k = comp.length-1;k>=1; --k) {             
     493 
     494            for (int k = comp.length-1;k>=1; --k) { 
    481495                char d = comp[k]; 
    482496                if (d == ']') { ++numbracks; } 
    483497                if (d == '[') { --numbracks; } 
    484                  
     498 
    485499                if (d == ']' && numbracks == 1) { nextIndex = ""; } 
    486500                else if (numbracks == 1 && comp[k-1..k+1]=="..") { 
    487501                    isSlice = true; 
    488502                    sliceTo = nextIndex; 
    489                     nextIndex = "";                      
     503                    nextIndex = ""; 
    490504                    --k; 
    491505                } else if ((d == '[' && numbracks==0) || (d==',' && numbracks==1)) { 
     
    539553    } 
    540554    RevisedExpression revised = remapCompounds(expr, ranks, symbolTable); 
    541      
     555 
    542556    VecExpressionType exprType = categorizeExpression(revised); 
    543557    if (exprType == VecExpressionType.SSE2Expression || exprType == VecExpressionType.SSE1Expression) { 
     
    560574    if (c-'A'<tree.symbolTable.length) { 
    561575        v = tree.symbolTable[c-'A'].value; 
    562     } else {  // else it's a compound or an indexed array        
     576    } else {  // else it's a compound or an indexed array 
    563577        char [] comp = tree.compounds[c-'A'-tree.symbolTable.length]; 
    564          
     578 
    565579        if (comp[$-1]!=']') { // compound expression (not an indexed array) 
    566580            if (comp[0]>='a' && comp[0]<='z') { 
     
    590604            bool isSlice = false; 
    591605            char [] newSlice; 
    592      
    593             for (int k = comp.length-1;k>=1; --k) {             
     606 
     607            for (int k = comp.length-1;k>=1; --k) { 
    594608                char d = comp[k]; 
    595609                if (d == ']') { ++numbracks; } 
    596610                if (d == '[') { --numbracks; } 
    597                 
     611 
    598612                if (d==']' && numbracks == 1) { newSlice = ""; } 
    599                 else if (d=='.' && numbracks == 1) { isSlice = true;  
    600                     if(numSlicesRemaining>0){ newSlice = ""; }  
     613                else if (d=='.' && numbracks == 1) { isSlice = true; 
     614                    if(numSlicesRemaining>0){ newSlice = ""; } 
    601615                    else newSlice = "." ~ newSlice; 
    602616                } 
     
    645659    char [] assertions=""; 
    646660    int wholerank = exprRank(tree.expression, tree.rank); 
    647     if (wholerank ==1) {   
     661    if (wholerank ==1) { 
    648662        int lenvec = findVectorForLength(tree); 
    649         result = "for (int blade_index=0; blade_index<"  
     663        result = "for (int blade_index=0; blade_index<" 
    650664               ~ getDimensionLengthForSymbol(tree.mapping[lenvec], tree, 1) 
    651665               ~ "; ++blade_index) {"\n; 
    652666    } 
    653667    foreach (c; tree.expression) { 
    654         if (c>='A' && c<'Z') {             
     668        if (c>='A' && c<'Z') { 
    655669            // restore all symbols into the expression 
    656670            // If it's a vector, index it 
     
    667681        } else result ~= c; 
    668682    } 
    669     if (wholerank==0) return InvocationCode(result, assertions);    
     683    if (wholerank==0) return InvocationCode(result, assertions); 
    670684    return InvocationCode(result ~ "; }", assertions); 
    671685} 
  • trunk/blade/BladeDemo.d

    r172 r187  
    11//  Written in the D programming language 1.0. 
    22/** 
    3 * BLADE 0.3Alpha -- Basic Linear Algebra D Expressions 
     3* BLADE Alpha -- Basic Linear Algebra D Expressions 
    44* 
    55*/ 
     
    99import std.stdio; 
    1010 
    11 // Local arrays in D aren't aligned to 128-bit boundaries. 
     11// Local arrays in D aren't currently aligned to 128-bit boundaries. 
    1212// In such cases, the library generates an 'SSE misalignment' assert error, 
    1313// to avoid segfaults. 
    1414// Use heap-allocated arrays, or static arrays (DMD 1.023 or later) 
    1515// cdouble[] always remains aligned, even when sliced. 
    16    
     16 
    1717void main() 
    18 {  
     18{ 
    1919    static z = [3.4, 565, 31.3, 0]; 
    2020    double [] a = new double[4]; 
     
    2929        r[i]= q[i]*2213.3L; 
    3030    } 
    31     double [4][] another = [[33.1, 4543, 43, 878.7],  
     31    double [4][] another = [[33.1, 4543, 43, 878.7], 
    3232                            [5.14, 455, 554, 2.43]]; 
    3333    real k=3.4; 
    34      
    3534    mixin(vectorize(` a += (d[2..$-1]*2.01*a[2]-another[][1])["abc".length-3..$]`)); 
    36     mixin(vectorize(" a-= 2.01*(        3.04+k)*r"));     
    37     
     35    mixin(vectorize(" a-= 2.01*(        3.04+k)*r")); 
     36 
    3837    mixin(vectorize("q+= q*2.01")); 
    3938    mixin(vectorize("another[0]+=r-=another[0]+another[0]")); 
    40     
     39 
    4140    // All of the next four are equivalent 
    4241    mixin(vectorize("a+=6*another[1,0..$]")); 
    4342    mixin(vectorize("a+=6*(another[1,0..$]+another[1,0..$])")); 
    44    
     43 
    4544 
    4645    mixin(vectorize("a+=6*another[1][0..$]")); 
     
    4948    // I don't think I'll support this syntax long-term. 
    5049    mixin(vectorize("a+=6*another[1,[0,$]]")); 
    51      
     50 
    5251    // Strided vector 
    5352    mixin(vectorize("another[0..$,1]=6*a[0..2]")); 
     
    6160   mixin(vectorize("a = -a")); 
    6261   mixin(vectorize("u = sum(sqrt(abs(p))) + sum(sqrt(abs(q)))")); 
     62   mixin(vectorize("u = prod(q)")); 
     63    writefln("a=", a); 
    6364 
    64     writefln("a=", a); 
    6565} 
  • trunk/blade/BladeRank.d

    r172 r187  
    2727        else if (s[i]==')') --paren; 
    2828        if (paren==0 && s[i]==']') { 
    29             if (startIndex && hasSliced) return true;    
     29            if (startIndex && hasSliced) return true; 
    3030            numbrack--; 
    31             if (s[i-1]=='[') { startIndex=false; }  
     31            if (s[i-1]=='[') { startIndex=false; } 
    3232        } 
    3333        if (paren==0 && s[i]=='[') { 
     
    3636            numbrack++; 
    3737        } 
    38         if (paren==0 && numbrack==1 && s[i]==',') {            
     38        if (paren==0 && numbrack==1 && s[i]==',') { 
    3939            if (hasSliced && startIndex) return true; 
    4040            if (maybeSlice) hasSliced = true; 
     
    8383enum RankError : int { 
    8484    UnsupportedOperation = -1, 
    85     RankIncrement = -2,  
     85    RankIncrement = -2, 
    8686    AttemptToIndexAScalar = -3, 
    8787    NonScalarIndex = -4, 
    8888    NonScalarSlice = -5, 
    8989    DotDotExpected = -6, 
    90     CommaExpected = -7,  
     90    CommaExpected = -7, 
    9191    RankMismatch = -8, 
    9292    RankMismatchConcatenation = -9, 
     
    132132            auto rrank = doVisit(this_, args[1]); 
    133133            if (rrank<0) return rrank; // propagate errors 
    134             if (lrank!=1 || rrank!=1) return RankError.RankMismatchDotProduct;        
     134            if (lrank!=1 || rrank!=1) return RankError.RankMismatchDotProduct; 
    135135            return 0; 
    136136        case "sum": 
     137        case "prod": 
    137138            auto lrank = doVisit(this_,args[0]); 
    138139            if (lrank<0) return lrank; // propagate errors 
     
    146147            assert(0, "BLADE ICE: Unsupported function:" ~ func); 
    147148            return 0; 
    148         }         
     149        } 
    149150    } 
    150151    ReturnType onVisitPrefix(This this_, char [] op, char [] expr) { 
     
    159160        return RankError.RankIncrement; 
    160161    } 
    161     // Includes multi-dimensional slicing and indexing.     
     162    // Includes multi-dimensional slicing and indexing. 
    162163    ReturnType onVisitIndex(This this_, char [] base, char [][2][] slices) { 
    163164        int totrank = doVisit(this_, base); 
    164165        for(int i=0; i<slices.length; ++i) { 
    165166            int r = doVisit(this_,slices[i][0]); 
    166             if (r!=0) return (r<0)? r :RankError.NonScalarIndex;  
     167            if (r!=0) return (r<0)? r :RankError.NonScalarIndex; 
    167168            if (slices[i][1]==""){ 
    168169                --totrank; 
    169170            } else { 
    170171                r = doVisit(this_,slices[i][1]); 
    171                 if (r!=0) return (r<0)?r:RankError.NonScalarSlice;  
     172                if (r!=0) return (r<0)?r:RankError.NonScalarSlice; 
    172173            } 
    173174        } 
     
    186187        } 
    187188        if (op=="~") { // concatentating scalars and vectors, or vectors and matrices, is permitted 
    188             if (lrank==rrank || lrank==(rrank+1) || rrank==(lrank+1))  
     189            if (lrank==rrank || lrank==(rrank+1) || rrank==(lrank+1)) 
    189190                return (lrank>rrank)? lrank: rrank; 
    190191            else return RankError.RankMismatchConcatenation; 
     
    211212} 
    212213 
    213 unittest {     
     214unittest { 
    214215    assert(exprRank("(A[B..C])[C]", "300")==2); 
    215216    assert(exprRank("A+=(A[B..C])", "300")==3); 
    216      
    217     assert(exprRank("A+(B*C)", "000")==0);     
     217 
     218    assert(exprRank("A+(B*C)", "000")==0); 
    218219    assert(exprRank("A=(B*C)", "202")==2); 
    219220    assert(exprRank("B*=(C*A)", "010")==1); 
     
    221222    assert(exprRank("D+=((A+C)*B)", "2022")==2); 
    222223    assert(exprRank("D+=((A&C)*B)", "0101")==1); 
    223      
     224 
    224225    assert(exprRank("C~=(((A[B])[B])~C)", "302")==2); 
    225226    assert(exprRank("((D[E])[E])+(-((C[B])[B..E]))", "202300")==1); 
    226227 
    227228    assert(exprRank("A+((((++B)+D)--)*C)", "1010")==1); 
    228      
     229 
    229230    assert(exprRank("C+=(A[B])", "302")==2); 
    230231    assert(exprRank("dot(A)", "1")==RankError.CommaExpected); 
    231232    assert(exprRank("dot(A,B)", "10")==RankError.RankMismatchDotProduct); 
    232      
    233     assert(exprRank("dot(B,(A*(dot(B,B))))", "11")==0);     
    234      
     233 
     234    assert(exprRank("dot(B,(A*(dot(B,B))))", "11")==0); 
     235 
     236    assert(exprRank("prod(A*B)", "10")==0); 
     237 
    235238    assert(exprRank("A[B,B,B]", "60")==3); 
    236239    assert(exprRank("A[B,B,C,B]", "600")==2); 
    237240    assert(exprRank("A+=(B[C..$])", "110")==1); 
    238241    assert(exprRank("A+=(B[C,D..$])", "2300")==2); 
    239      
     242 
    240243    // bug fixes: 
    241244    assert(exprRank("(A[B..$,C])+=D", "2001")==1); 
  • trunk/blade/BladeSimplify.d

    r175 r187  
    1414*      be moved to every vector inside A. 
    1515*    - Use associativity of *: A*(B*C[]) == (A*B)*C[] (Not strictly true for 
    16 *      floating point; results may differ by 1ulp,  
     16*      floating point; results may differ by 1ulp, 
    1717*       eg (1.3L*3.1L)*4.7L < 1.3L*(3.1L*4.7L) 
    1818*      Note that floating point addition is not associative at all). 
    1919*    - Remove unary minus where possible, eg A-(-B) => A+B, abs(-A) => abs(A). 
    20 *    - Use associativity of * in intrinsics:  
     20*    - Use associativity of * in intrinsics: 
    2121*         sum(A*V) => A*sum(V), abs(A*B) => abs(A)*abs(B) 
    22 * (D) Expression standardisation  
     22* (D) Expression standardisation 
    2323*    - Move multiplies to left: Convert A[]*B into B*A[] (assumes * is commutative, 
    2424*      not valid for quaternions). 
     
    5454{ 
    5555    return str=="dot" || str=="sum" || str=="max" || str=="min" 
    56            || str=="abs" || str=="sqrt"
     56           || str=="abs" || str=="sqrt" || str=="prod"
    5757} 
    5858 
     
    6868    } 
    6969    // Check for undefined symbols 
    70     if (err.length > 0)  
     70    if (err.length > 0) 
    7171        return RevisedExpression(tree.expression, "", tree.symbolTable, [""], "","", "Undefined symbols:" ~ err); 
    7272    else { 
     
    119119        } else e~=c; 
    120120    } 
    121     return e;     
     121    return e; 
    122122} 
    123123 
     
    171171            } 
    172172            --k; 
    173             char [] newexpr = expr[i+1..k]; // strip off the {}             
     173            char [] newexpr = expr[i+1..k]; // strip off the {} 
    174174            int newi = k; 
    175175            if (i>0 && k<expr.length-1 && expr[i-1]=='(' && expr[k+1]==')') { 
     
    184184                ++next; 
    185185                comp ~= expr[i+1..k]; // strip off the {} 
    186                 if (expr[k-1]==']') {                 
     186                if (expr[k-1]==']') { 
    187187                    // it's a vector/matrix of some kind, with rank reduced 
    188188                    // by indices. Can't just use exprRank, because the [] 
     
    192192                    // it's a scalar expression. Note that it could involve 
    193193                    // a vector expression. 
    194                     r~='0';  
    195                 }                 
     194                    r~='0'; 
     195                } 
    196196            } else e ~= cast(char)('A'+z+rank.length); 
    197197            i = newi; 
     
    202202    } 
    203203    // Create a mapping from old to new variable names 
    204          
     204 
    205205    char [] old_ranks = ""; 
    206206    char [] mapping=""; 
     
    235235} 
    236236 
    237 unittest {     
     237unittest { 
    238238    RevisedExpression e = simplifyVectorExpression("A+=(((D[B])*C)[B])", "2004",[]); 
    239239    assert(e.rank=="202"); 
     
    281281           assert(sym!="$" && this_.rank[sym[0]-'A']>'0', "Rank error " ~ sym); 
    282282           // Note: Later, we'll want this to be a new terminal. 
    283            return sym ~ createMultiSlice(this_.slicing);            
     283           return sym ~ createMultiSlice(this_.slicing); 
    284284       } 
    285285    } 
     
    301301        return wrapInParens(doVisit(this_, expr)) ~ op; 
    302302    } 
    303     // Includes multi-dimensional slicing and indexing.     
     303    // Includes multi-dimensional slicing and indexing. 
    304304    ReturnType onVisitIndex(This this_, char [] base, char [][2][] slices) { 
    305305        if (slices.length==0) { // []  -- has no effect. 
     
    311311            // with the earliest existing dimension. 
    312312            // * If the existing dimension is an index, 
    313             //   it might contain a dollar, which we need to replace.  
     313            //   it might contain a dollar, which we need to replace. 
    314314            // * If the existing dimension is a slice, the two slices will combine. 
    315315            // 
     
    331331                newslice ~= [a ~ "+" ~ c, ""]; 
    332332            } 
    333             if (slices.length>1) {                 
     333            if (slices.length>1) { 
    334334                // append other slices, if any. 
    335335                return doVisit(IndexFoldingVisitor(this_.rank, "$", slices[0..$-1] ~ newslice ~ this_.slicing[1..$]), base); 
     
    360360                assert(lrank>0 && rrank>0 && lrank<=2 && rrank<=2, "BLADE ICE: Tensor*tensor is unsupported"); 
    361361                bool isDotProduct = false; // was it reduced to a dot product? 
    362                  
     362 
    363363                // In the case of chained matrix multiplies, we can end up with an empty slice. 
    364364                if (this_.slicing.length>0 && this_.slicing[$-1][0]=="") { 
     
    368368                    // First dimension applies to rows of the left operand 
    369369                    // If it's a slice, it will be a strided slice -- unless 
    370                     // it comes from another matrix multiply, in which case the                     
     370                    // it comes from another matrix multiply, in which case the 
    371371                    // stride will drop out. (A[x]*B is strided). 
    372372                    char [][2][] newslice=[]; 
     
    390390                        second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right)); 
    391391                    } 
    392                 } else if (lrank==1 && rrank==2) {                     
     392                } else if (lrank==1 && rrank==2) { 
    393393                    // vector * matrix, Matrix uses all the slicing 
    394394                    second = wrapInParens(doVisit(this_, right)); 
     
    403403                } 
    404404            } 
    405         } else { 
    406             // in DMD1.024, nasty compiler bug if you save the first & second results into a local variable 
     405        } else { // not a multiplication 
    407406            return wrapInParens(doVisit(this_, left)) ~ op ~ wrapInParens(doVisit(this_, right)); 
    408407        } 
     
    416415} 
    417416 
    418 unittest {    
     417unittest { 
    419418    assert(foldIndices("((A[C..D])+B)[($-E)]", "21000")=="(A[C+((D-C)-E)])+(B[($-E)])"); 
    420419    assert(foldIndices("(A[C])[D]", "3100")=="A[C,D]"); 
     
    427426    assert(foldIndices("A[,B..$,C]", "300")=="A[,B..$,C]"); 
    428427    // Multidimensional slicing 
    429     assert(foldIndices("(C*((A*B)[C]))[D]", "2200")=="C*dot((A[C,]),(B[D]))");     
     428    assert(foldIndices("(C*((A*B)[C]))[D]", "2200")=="C*dot((A[C,]),(B[D]))"); 
    430429    assert(foldIndices("(A*B)[C..D,D]", "2200")=="(A[C..D,])*(B[D])"); 
    431430    assert(foldIndices("(A*B)[C..D]", "2200")=="(A[C..D,])*B"); 
     
    433432    assert(foldIndices("(A*B)[C..D]", "1200")=="A*(B[C..D])"); 
    434433    assert(foldIndices("(A*B)[C]", "120")=="dot(A,(B[C]))"); 
    435      
     434 
    436435    assert(foldIndices("((A*B)*C)[D]", "2220")=="((A[D,])*B)*C"); 
    437436    assert(foldIndices("((A+B)*C)[D]", "2220")=="((A[D,])+(B[D,]))*C"); 
    438437    assert(foldIndices("((D*A)*B)[C]", "2100")=="dot((D*(A[C,])),B)"); 
    439     assert(foldIndices("(((A*B)*C)[D..E])[D]", "12200")=="dot((A*B),(C[D+D]))");  
     438    assert(foldIndices("(((A*B)*C)[D..E])[D]", "12200")=="dot((A*B),(C[D+D]))"); 
    440439    assert(foldIndices("A+=(((D[B])*C)[B])", "2004")=="A+=((D[B,B])*C)"); 
    441440    assert(foldIndices("dot(A,(A*dot(A,A)))","1")=="dot(A,(A*dot(A,A)))"); 
     
    466465            ScalarFold right = doVisit(this_, args[1]); 
    467466            return ScalarFold("", combineMul(combineMul(left.multiplier, right.multiplier), "{" ~ func ~ "(" ~ wrapInParens(left.expr) ~ "," ~ wrapInParens(right.expr) ~ ")}")); 
    468         case "sum":  
     467        case "sum": 
     468        case "prod": 
    469469            //  sum(A*V) = A*sum(V) is a scalar. 
     470            //  prod(A*V) = A*prod(V) is a scalar. 
    470471            return ScalarFold("", combineMul(left.multiplier, "{" ~ func ~ "(" ~ wrapInParens(left.expr) ~ ")}")); 
    471472        case "abs": 
     
    483484        case "max": 
    484485        case "min": // max(A*B) can't be simplified unless we know that they are not negative. 
    485             return ScalarFold("", "{" ~ func ~ "(" ~ combineMulWithCompound(left.expr, left.multiplier) ~ ")}");  
    486 //            return ScalarFold("", "{" ~ func ~ "(@>"  ~ left.expr ~ "@" ~ left.multiplier ~ "<@)}");  
     486            return ScalarFold("", "{" ~ func ~ "(" ~ combineMulWithCompound(left.expr, left.multiplier) ~ ")}"); 
     487//            return ScalarFold("", "{" ~ func ~ "(@>"  ~ left.expr ~ "@" ~ left.multiplier ~ "<@)}"); 
    487488        default: 
    488489            assert(0, "BLADE: Unsupported function"); 
    489490            return ScalarFold("",""); 
    490491        } 
    491     }     
     492    } 
    492493    ReturnType onVisitPrefix(This this_, char [] op, char [] expr) { 
    493494        if (op=="-") { 
     
    498499            else return ScalarFold(left.expr, "-"); 
    499500        } else if (op=="+") { // just ignore unary plus 
    500             return doVisit(this_, expr);         
     501            return doVisit(this_, expr); 
    501502        } else { 
    502503            ScalarFold f = doVisit(this_, expr); 
     
    530531            assert(first.multiplier=="" && second.expr=="", "BLADE ICE"); 
    531532            assert(second.multiplier!="-", "BLADE ICE"); // this would be a*=-b, where b is a vector 
    532             if (second.multiplier.length>1)  return ScalarFold(wrapInParens(first.expr) ~ op ~ "{" ~ wrapInParens(second.multiplier) ~ "}","");  
    533             else return ScalarFold(wrapInParens(first.expr) ~ op ~ wrapInParens(second.multiplier),"");  
     533            if (second.multiplier.length>1)  return ScalarFold(wrapInParens(first.expr) ~ op ~ "{" ~ wrapInParens(second.multiplier) ~ "}",""); 
     534            else return ScalarFold(wrapInParens(first.expr) ~ op ~ wrapInParens(second.multiplier),""); 
    534535        } 
    535536        if (op=="*") { 
     
    588589    assert(left.length>0); 
    589590    if (right.length==0) return left; 
    590     if (right=="-") return "-" ~ wrapInParens(left);     
     591    if (right=="-") return "-" ~ wrapInParens(left); 
    591592    if (right.length==1) return wrapInParens(left) ~ "*" ~ right; 
    592     return wrapInParens(left) ~ "*{" ~ wrapInParens(right) ~ "}";     
     593    return wrapInParens(left) ~ "*{" ~ wrapInParens(right) ~ "}"; 
    593594} 
    594595 
  • trunk/blade/CodegenX86.d

    r172 r187  
    11//  Written in the D programming language 1.0 
    22/** 
    3 * BLADE 0.4Alpha -- Basic Linear Algebra D Expressions 
     3* BLADE Alpha -- Basic Linear Algebra D Expressions 
    44* 
    55* Generate near-optimal x87/SSE/SSE2 asm code for BLAS1 basic vector operations 
     
    2121* 
    2222* BUGS/ FUTURE DIRECTIONS: 
    23 *  None of these support dot product, or matrix operations. 
     23*  None of these support matrix operations. 
    2424* X87: 
    2525*  - Not optimal for the case of multiple real vectors (they could share a counter). 
     
    3030*   (to do this, need naked asm with no stack frame). 
    3131* SSE/SSE2: 
    32 *  - SSE functions don't support unaligned data. Need to generate seperate code  
     32*  - SSE functions don't support unaligned data. Need to generate separate code
    3333*    for that case (NOTE: probably only worth doing for small expressions). 
    3434* 
     
    191191// (max # temporaries + max # real scalars) must be <=8, otherwise FPU stack 
    192192// will overflow). 
    193 const int MAX_87_REALSCALARSPLUSTEMPORARIES = 8;  
     193const int MAX_87_REALSCALARSPLUSTEMPORARIES = 8; 
    194194 
    195195private: 
     
    213213// indexed by i. 
    214214char [] indexedVector(char [][] typelist, char [] ranklist, char [] stridelist, char var) 
    215 {     
     215{ 
    216216    if (typelist[var-'A']=="real") return " real ptr [" ~ vectorRegister[vectorNum(ranklist, var)] ~ "]"; 
    217217    else if (stridelist[var-'A']=='1') return operandSize(typelist[var-'A']) ~ "[" ~ vectorRegister[vectorNum(ranklist, var)] ~ "]"; 
     
    264264 (Pentium, PMMX, PII, PIII). It is also optimal for recent x86 CPUs 
    265265 where vector sizes are mixed. 
    266   
     266 
    267267 There are two cases: 
    268268 (A) DAXPY-style loops, where every element is independent of the other indices; 
     
    271271 For cumulative loops, best performance is achieved with loop unrolling and 
    272272 multiple accumulators, in order to break dependency chains. 
    273    
     273 
    274274The key optimisation rules for DAXPY loops are: 
    275275 1. keep the loop overhead to one clock cycle if possible. 
     
    304304            ranklist~="1"; 
    305305            typelist ~= typeof(T[0]).stringof; 
    306         } else static if (is(typeof(T.data))) {             
     306        } else static if (is(typeof(T.data))) { 
    307307            stridelist~="1"; 
    308308            ranklist~="1"; 
     
    323323    char [] result=""; 
    324324    char [] incrementRealVectors=""; 
    325      
     325 
    326326    result ~= "// Operation : " ~  operations ~ \n; 
    327     
     327 
    328328    // Create local variables for pointers to vectors (avoid bug #1125) 
    329329 
     
    361361              ~ "  add " ~ vectorRegister[numvecs] ~ ", values[" ~ itoa(i) ~ "];"; 
    362362         } 
    363          result ~= "  //" ~ cast(char)('A'+i) ~ \n;  
     363         result ~= "  //" ~ cast(char)('A'+i) ~ \n; 
    364364        ++numvecs; 
    365365      } else if (typelist[i]=="real") { 
     
    367367          ++numconsts; 
    368368          ++numScalarsOnStack; 
    369          result ~= "  //" ~ cast(char)('A'+i) ~ \n;  
     369         result ~= "  //" ~ cast(char)('A'+i) ~ \n; 
    370370      } 
    371371    } 
     
    376376    int numOnStack = 0; // How much of the FP stack is being used? 
    377377 
    378     bool isCumulative = (operations[0]=='0'); 
     378    bool isCumulative = (operations[0]=='0' || operations[0]=='1'); 
    379379    if (operations[0]=='0') { 
    380380        result ~= "  fldz;"\n; // dot product 
    381381        ++numOnStack; 
    382382        done = 1; 
     383    } else if (operations[0]=='1') { 
     384        result ~= "  fld1;"\n; // prod 
     385        ++numOnStack; 
     386        done = 1; 
    383387    } 
    384388    result ~= "  xor EAX, EAX; "\n 
     
    389393    // the final storage instruction, because of the FST latency). 
    390394    char [] mainbody = ""; 
    391              
     395 
    392396    while(done<operations.length) { 
    393397        char [] next; 
     
    420424        } else if (operations[done]==',') { 
    421425            mainbody ~= "  " ~ opToX87[operations[done+1]] ~ " ST, ST(0);    // dup " ~ operations[done+1] ~ \n; 
    422             done+=2;           
     426            done+=2; 
    423427        } else if (ranklist[operations[done]-'A']=='1') { 
    424428             // An operation will be performed between the stack top and a vector. 
     
    430434                // it chains. 
    431435                next = ((done+2 == operations.length)? "  fstp " : "  fst ") 
    432                     ~ indexedVector(typelist, ranklist, stridelist, operations[$-2] ) ~ comment;             
     436                    ~ indexedVector(typelist, ranklist, stridelist, operations[$-2] ) ~ comment; 
    433437            } else if (typelist[operations[done]-'A']=="real") { 
    434438                 // 80-bit vectors must be loaded onto the FPU stack first 
     
    445449            // Multiply by real scalar, which is already on the stack. 
    446450            next = "  fmul ST, ST(" ~ itoa(numOnStack + numScalarsOnStack - realScalarNum(typelist, ranklist, operations[done]-'A')-1) ~ "); // * " ~ operations[done] ~ \n; 
    447             mainbody ~= next;             
     451            mainbody ~= next; 
    448452          } else { 
    449453            // For scalar float or double values, we can multiply directly, saving one slot on the FP stack. 
     
    452456          } 
    453457            done +=2; 
    454         }       
    455     } 
    456          
    457     result ~= \n  
    458         ~ "  align 4;\n"  
     458        } 
     459    } 
     460 
     461    result ~= \n 
     462        ~ "  align 4;\n" 
    459463        ~ "L1:\n" ~ mainbody; 
    460          
     464 
    461465//    if (cumulatingOp) result ~= "  " ~ opToX87[cumulatingOp] ~ "p ST(2), ST;"\n; 
    462466 
     
    472476 
    473477    result~= "L3:" \n ~ popRegisters(vecnum) ~ "}\r\n"; 
    474