Changeset 157

Show
Ignore:
Timestamp:
12/06/07 12:47:16 (9 months ago)
Author:
Don Clugston
Message:

Simplify drop product.
Refactor

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/blade/Blade.d

    r155 r157  
    7272    RevisedExpression revised = simplifySyntaxTree(tree); 
    7373    if (revised.errorMessage.length>0)  return `static assert(0, "BLADE: ` ~ enquote(revised.errorMessage) ~ `");`; 
    74     VecExpressionType exprType = categorizeExpression(tree, revised); 
     74    VecExpressionType exprType = categorizeExpression(revised); 
    7575    if (exprType == VecExpressionType.SSE2Expression || exprType == VecExpressionType.SSE1Expression) { 
    76         return invokeSSE((exprType == VecExpressionType.SSE2Expression), tree, revised); 
     76        return invokeSSE((exprType == VecExpressionType.SSE2Expression), revised); 
    7777    } else if (exprType == VecExpressionType.X87Expression) { 
    78         return invokeX87(tree, revised); 
     78        return invokeX87(revised); 
    7979    } else { 
    80         return DCodeGenerator(tree, revised); 
     80        return DCodeGenerator(revised); 
    8181    }     
    8282} 
     
    122122enum VecExpressionType { SSE1Expression, SSE2Expression, X87Expression, DExpression }; 
    123123 
    124 VecExpressionType categorizeExpression(AbstractSyntaxTree tree, RevisedExpression revised
     124VecExpressionType categorizeExpression(RevisedExpression tree
    125125{ 
    126126    bool SSE2 = true; 
     
    137137    int numscalars = 0; 
    138138    int numRealScalars = 0; // scalars other than float or double. 
    139     for (int i=0; i<revised.mapping.length;++i) { 
    140         char r = revised.rank[i]; 
    141         int x = revised.mapping[i]-'A'; 
     139    for (int i=0; i<tree.mapping.length;++i) { 
     140        char r = tree.rank[i]; 
     141        int x = tree.mapping[i]-'A'; 
    142142        if (r=='0') { 
    143143            ++numscalars; 
     
    155155        int y = x; // for compounds, get the original type 
    156156        if (x>=tree.symbolTable.length) { 
    157             y = revised.compounds[x-tree.symbolTable.length][0]-'A'; 
     157            y = tree.compounds[x-tree.symbolTable.length][0]-'A'; 
    158158            // Check for a stride.. 
    159             if (revised.compounds[x-tree.symbolTable.length][$-1]==']') {                
    160                 strided |= isStrided(revised.compounds[x-tree.symbolTable.length]); 
     159            if (tree.compounds[x-tree.symbolTable.length][$-1]==']') {                
     160                strided |= isStrided(tree.compounds[x-tree.symbolTable.length]); 
    161161            } 
    162162        } 
     
    194194 
    195195/// Generate code which will call the X87 function 
    196 char [] invokeX87(AbstractSyntaxTree tree, RevisedExpression revised
    197 { 
    198     char [] result = assertAllVectorLengthsEqual(tree, revised); 
    199     result ~= `X87VECGEN!("` ~ enquote(revised.expression) ~ `"`; 
     196char [] invokeX87(RevisedExpression tree
     197{ 
     198    char [] result = assertAllVectorLengthsEqual(tree); 
     199    result ~= `X87VECGEN!("` ~ enquote(tree.expression) ~ `"`; 
    200200     
    201201    char [] vals; 
    202     for (int i=0; i<revised.mapping.length;++i) { 
    203         char rnk = revised.rank[i]; 
     202    for (int i=0; i<tree.mapping.length;++i) { 
     203        char rnk = tree.rank[i]; 
    204204        vals ~= ","; 
    205205        if (rnk=='1') vals ~= "&"; 
    206         vals ~= getValueForSymbol(revised.mapping[i], tree, revised); 
    207         int x = revised.mapping[i]-'A'; 
     206        vals ~= getValueForSymbol(tree.mapping[i], tree); 
     207        int x = tree.mapping[i]-'A'; 
    208208        char [] t; 
    209209        if (x<tree.symbolTable.length) { 
     
    219219                // or float, it could use less FPU stack space. 
    220220            } else { // for arrays, the type is the type of the original array 
    221                 t = tree.symbolTable[revised.compounds[x-tree.symbolTable.length][0]-'A'].element; 
     221                t = tree.symbolTable[tree.compounds[x-tree.symbolTable.length][0]-'A'].element; 
    222222            } 
    223223        } 
     
    239239    } 
    240240    result ~= ")("; 
    241     int firstVector = findVectorForLength(tree, revised); 
    242     return result ~ getValueForSymbol(revised.mapping[firstVector], tree, revised) ~ ".length" ~ vals ~ ");"; 
     241    int firstVector = findVectorForLength(tree); 
     242    return result ~ getValueForSymbol(tree.mapping[firstVector], tree) ~ ".length" ~ vals ~ ");"; 
    243243} 
    244244 
    245245/// Generate code which will call the SSE/SSE2 code generation function 
    246 char [] invokeSSE(bool SSE2, AbstractSyntaxTree tree, RevisedExpression revised
    247 { 
    248     char [] result = assertAllVectorLengthsEqual(tree, revised); 
    249     result ~= assertAllVectorsAlign128(tree, revised); 
     246char [] invokeSSE(bool SSE2, RevisedExpression tree
     247{ 
     248    char [] result = assertAllVectorLengthsEqual(tree); 
     249    result ~= assertAllVectorsAlign128(tree); 
    250250 
    251251        
    252     result ~= "SSEVECGEN!(" ~ (SSE2?"2":"1") ~ `,"` ~ enquote(revised.expression) ~ `"`; 
     252    result ~= "SSEVECGEN!(" ~ (SSE2?"2":"1") ~ `,"` ~ enquote(tree.expression) ~ `"`; 
    253253    // For SSE2, everything must be implicitly convertible to double. 
    254254    char [] vals; 
    255     for (int i=0; i<revised.mapping.length;++i) { 
    256         char rnk = revised.rank[i]; 
     255    for (int i=0; i<tree.mapping.length;++i) { 
     256        char rnk = tree.rank[i]; 
    257257        if (rnk=='0') result ~= SSE2? ",double" : ",float"; 
    258258        else result ~= SSE2? ",double*" : ",float*"; 
    259259        vals ~= ","; 
    260260        if (rnk=='1') vals ~= "&"; 
    261         vals ~= getValueForSymbol(revised.mapping[i], tree, revised); 
     261        vals ~= getValueForSymbol(tree.mapping[i], tree); 
    262262        // for vectors, we only need the pointer, not the length 
    263263//        if (rnk=='1') vals ~= ".ptr"; 
     
    266266             
    267267    result ~= ")("; 
    268     int firstVector = findVectorForLength(tree, revised); 
    269     result ~= getValueForSymbol(revised.mapping[firstVector], tree, revised) ~ ".length"; 
     268    int firstVector = findVectorForLength(tree); 
     269    result ~= getValueForSymbol(tree.mapping[firstVector], tree) ~ ".length"; 
    270270//    result ~= tree.symbolTable[firstVector].value ~ ".length"; 
    271271    result ~= vals; 
     
    277277 * If possible, the error will be detected at compile time. 
    278278 */ 
    279 char [] assertAllVectorLengthsEqual(AbstractSyntaxTree tree, RevisedExpression revised
     279char [] assertAllVectorLengthsEqual(RevisedExpression tree
    280280{ 
    281281    char [] result =""; 
    282     int firstVector = findVectorForLength(tree, revised); 
     282    int firstVector = findVectorForLength(tree); 
    283283//    bool known = arrayLengthIsStatic(tree.symbolTable[firstVector].type); 
    284     for (int i=0; i<revised.mapping.length;++i) { 
    285         if (revised.rank[i]=='1') { 
     284    for (int i=0; i<tree.mapping.length;++i) { 
     285        if (tree.rank[i]=='1') { 
    286286            if (firstVector != i) { 
    287287//                if (known && arrayLengthIsStatic(tree.symbolTable[i].type)) { 
     
    291291//                } 
    292292                result ~= "assert("  
    293                  ~ getDimensionLengthForSymbol(revised.mapping[i], tree, revised, 1) 
    294                     ~ "==" ~ getDimensionLengthForSymbol(revised.mapping[firstVector], tree, revised, 1) 
     293                 ~ getDimensionLengthForSymbol(tree.mapping[i], tree, 1) 
     294                    ~ "==" ~ getDimensionLengthForSymbol(tree.mapping[firstVector], tree, 1) 
    295295                    ~ ", `Vector length mismatch`);"\n; 
    296 //                    ~ ".length==" ~ getValueForSymbol(revised.mapping[firstVector], tree, revised) 
    297 //                    ~ ".length, `Vector length mismatch`);"\n; 
    298296            } 
    299297        } 
     
    302300} 
    303301 
    304 char [] assertAllVectorsAlign128(AbstractSyntaxTree tree, RevisedExpression revised
     302char [] assertAllVectorsAlign128(RevisedExpression tree
    305303{ 
    306304    char [] result =""; 
    307     for (int i=0; i<revised.mapping.length;++i) { 
    308         if (revised.rank[i]=='1'){ 
    309             result ~= "assert( (cast(size_t)(&" ~ getValueForSymbol(revised.mapping[i], tree, revised
    310                     ~ "[0])& 0x0F) == 0, `SSE Vector misalignment: " ~ getValueForSymbol(revised.mapping[i], tree, revised) ~ "`);"\n; 
     305    for (int i=0; i<tree.mapping.length;++i) { 
     306        if (tree.rank[i]=='1'){ 
     307            result ~= "assert( (cast(size_t)(&" ~ getValueForSymbol(tree.mapping[i], tree
     308                    ~ "[0])& 0x0F) == 0, `SSE Vector misalignment: " ~ getValueForSymbol(tree.mapping[i], tree) ~ "`);"\n; 
    311309        } 
    312310    } 
     
    326324// If this is not possible, a normal dynamic array will be used. 
    327325// If all else fails, a sliced vector will be used. 
    328 int findVectorForLength(AbstractSyntaxTree tree, RevisedExpression revised
     326int findVectorForLength(RevisedExpression tree
    329327{ 
    330328    int dynamic = -1; // last dynamic vector 
    331329    int strided = 0; // last unstrided vector 
    332     for (int i = 0; i < revised.mapping.length; ++i) { 
    333         if (revised.rank[i]!='1') continue; 
    334         int x = revised.mapping[i]-'A'; 
     330    for (int i = 0; i < tree.mapping.length; ++i) { 
     331        if (tree.rank[i]!='1') continue; 
     332        int x = tree.mapping[i]-'A'; 
    335333        strided = i; 
    336334        if (x < tree.symbolTable.length) { 
     
    339337        } else { 
    340338            // Check for a stride. 
    341             if (revised.compounds[x-tree.symbolTable.length][$-1]==']') { 
    342                 if (!isStrided(revised.compounds[x-tree.symbolTable.length])) { 
     339            if (tree.compounds[x-tree.symbolTable.length][$-1]==']') { 
     340                if (!isStrided(tree.compounds[x-tree.symbolTable.length])) { 
    343341                    dynamic = i; 
    344342                } 
     
    355353} 
    356354 
    357 char [] getDimensionLengthForSymbol(char c, AbstractSyntaxTree tree, RevisedExpression revised, int dimension) 
     355char [] getDimensionLengthForSymbol(char c, RevisedExpression tree, int dimension) 
    358356{ 
    359357    int numSlicesRemaining = 1; 
     
    365363        return v ~ ".length"; 
    366364    } else {  // else it's a compound or an indexed array 
    367         char [] comp = revised.compounds[c-'A'-tree.symbolTable.length]; 
     365        char [] comp = tree.compounds[c-'A'-tree.symbolTable.length]; 
    368366         
    369367        if (comp[$-1]!=']') { // simple compound expression 
     
    424422} 
    425423 
    426 char [] getValueForSymbol(char c, AbstractSyntaxTree tree, RevisedExpression revised, char [] firstIndexExpr="") 
     424char [] getValueForSymbol(char c, RevisedExpression tree, char [] firstIndexExpr="") 
    427425{ 
    428426    int numSlicesRemaining=1; 
     
    433431        v = tree.symbolTable[c-'A'].value; 
    434432    } else {  // else it's a compound or an indexed array 
    435         char [] comp = revised.compounds[c-'A'-tree.symbolTable.length]; 
     433        char [] comp = tree.compounds[c-'A'-tree.symbolTable.length]; 
    436434         
    437435        if (comp[$-1]!=']') { // simple compound expression 
     
    510508 
    511509// Generate inline D code for the expression 
    512 char [] DCodeGenerator(AbstractSyntaxTree tree, RevisedExpression revised
    513 { 
    514     int lenvec = findVectorForLength(tree, revised); 
    515     char [] result = assertAllVectorLengthsEqual(tree, revised); 
     510char [] DCodeGenerator(RevisedExpression tree
     511{ 
     512    int lenvec = findVectorForLength(tree); 
     513    char [] result = assertAllVectorLengthsEqual(tree); 
    516514    result ~= "for (int blade_index=0; blade_index<"  
    517     ~ getDimensionLengthForSymbol(revised.mapping[lenvec], tree, revised, 1) ~ 
     515    ~ getDimensionLengthForSymbol(tree.mapping[lenvec], tree, 1) ~ 
    518516        "; ++blade_index) {"\n; 
    519     foreach (c; revised.expression) { 
     517    foreach (c; tree.expression) { 
    520518        if (c>='A' && c<'Z') { 
    521519            // restore all symbols into the expression 
    522520            // If it's a vector, index it 
    523             if (revised.rank[c-'A']=='1') 
    524                 result ~= getValueForSymbol(revised.mapping[c-'A'], tree, revised, "blade_index"); 
    525             else result ~= getValueForSymbol(revised.mapping[c-'A'], tree, revised); 
     521            if (tree.rank[c-'A']=='1') 
     522                result ~= getValueForSymbol(tree.mapping[c-'A'], tree, "blade_index"); 
     523            else result ~= getValueForSymbol(tree.mapping[c-'A'], tree); 
    526524        } else result ~= c; 
    527525    } 
  • trunk/blade/BladeDemo.d

    r156 r157  
    3232    double [4][] another = [[33.1, 4543, 43, 878.7], [5.14, 455, 554, 2.43]]; 
    3333    real k=3.4; 
    34     
     34 
    3535    mixin(vectorize(` a += (d[2..$-1]*2.01*a[2]-another[][1])["abc".length-3..$]`)); 
    3636     
     
    4545    mixin(vectorize("another[0..$,1]+=6*a[0..2]")); 
    4646    mixin(vectorize("r-=another[0]")); 
    47    
     47 
    4848    // Parses OK, but I don't think I'll support this. 
    4949//    mixin(vectorize("a+=6*another[1,[1,$]]")); 
    5050 
    5151 
    52 // Parses, and rank checks OK. Doesn't simplify yet, no codegen
    53  //   mixin(vectorize("dot(q,q*dot(q,q))")); // BUG: should simplify to: dot(q.q) * dot(q,q) 
     52// Parses, and simplifies to A*A, where A = dot(q,q). No codegen yet
     53//    mixin(vectorize("dot(q,q*dot(q,q))")); 
    5454 
    5555    writefln("a=", a); 
  • trunk/blade/BladeRank.d

    r156 r157  
    5151 * The sub-expression must be  
    5252 *  - a single character (eg "X"), OR 
     53 *  - a lower-case function (eg "a(B,(C*D))"), OR 
    5354 *  - an expression in parenthesis, OR 
    5455 *  - an array literal  
     
    8788        return rank[expr[0]-'A']-'0'; 
    8889    } 
     90    if (expr[0]=='d') return 0; 
    8991    assert(expr[0]=='(', "BLADE ICE:" ~ expr); 
    9092    // strip off the parentheses 
  • trunk/blade/BladeSimplify.d

    r156 r157  
    3737    char [] expression; // the revised expression using new variable names 
    3838                        // (so, for example, B+=(D-F) becomes A+=(B-C) ). 
     39    Symbol[] symbolTable; // the original symbol table, with all the old names 
    3940    char [][] compounds; // the compound variables, defined using the old names 
    4041    char [] rank;   // rank of all symbols (including original & compounds) 
     
    5960    // Check for undefined symbols 
    6061    if (err.length > 0)  
    61         return RevisedExpression(tree.expression, [""], "","", "Undefined symbols:" ~ err); 
     62        return RevisedExpression(tree.expression, tree.symbolTable, [""], "","", "Undefined symbols:" ~ err); 
    6263    else { 
    6364        char [] expr2 = removeDuplicates(tree); 
     
    6566        int wholerank = exprRank(expr2, ranks); 
    6667        if (wholerank<0) 
    67             return RevisedExpression(expr2, [""], "","", getRankErrorText(wholerank)); 
    68       return simplifyVectorExpression(expr2, ranks); 
     68            return RevisedExpression(expr2, tree.symbolTable, [""], "","", getRankErrorText(wholerank)); 
     69      return simplifyVectorExpression(expr2, ranks, tree.symbolTable); 
    6970    } 
    7071} 
     
    179180    char [] leftMul = ""; 
    180181    char [] rightMul = ""; 
    181     if (leftrnk == 0 && rightrnk == 0) return expr; 
     182    if (leftrnk == 0 && rightrnk == 0) { return expr; } 
    182183    if (leftrnk == 0) leftMul = left; else leftMul = getCommonMultiplucation(left, rank); 
    183184    if (rightrnk== 0) rightMul = right; else rightMul = getCommonMultiplucation(right, rank); 
     
    228229} 
    229230 
     231 
    230232/** 
    231233 * Rewrite the expression, taking advantage of distributivity of [] and 
     
    247249            leftmul = getCommonMultiplucation(left, rank); 
    248250            rightmul = getCommonMultiplucation(right, rank); 
    249              
    250 //            assert(0, leftmul~"#" ~ rightmul); 
    251 //            if (hasScalarMultiply(left, rank)) { 
    252                 // pull the scalar mul out 
    253 //            } 
    254 //            if (hasScalarMultiply(right, rank)) { 
    255 //            } // ditto for right. 
    256251            char [] m = leftmul; 
    257252            if (rightmul.length>0) m = m==""? rightmul : "(" ~ m ~ "*" ~rightmul~")"; 
    258253            if (mulexpr.length>0) m = m=="" ? mulexpr : "(" ~ m ~ "*" ~mulexpr~")"; 
    259             if (m.length>1) m= "* {" ~ m ~ "}"; 
     254            if (m.length>1) m= "* {" ~ m ~ "} "; 
    260255            else if (m.length==1) m= "*" ~ m; 
    261256            assert(indexexpr.length==0, "BLADE ICE: rank mismatch in dot product"); 
    262             //return "#" ~ subexprSimplify("A,B","01", "",""); 
    263 //           assert(0, expr ~ "#" ~ left ~ "#" ~ right ~"#" ~ rank ~ "#");// ~ subexprSimplify(right, rank, mulexpr,"")~"#"); 
    264             return "d(" ~ simplifyWithoutMul(left, rank) ~ "," ~ 
    265                           simplifyWithoutMul(right, rank) ~ ")" ~ m; 
     257//           assert(0, expr ~ "#" ~ left ~ "#" ~ right ~"#" ~ leftmul ~ "#"~ rightmul);// ~ subexprSimplify(right, rank, mulexpr,"")~"#"); 
     258            return " {d(" ~ simplifyWithoutMul(left, rank) ~ "," ~ 
     259                          simplifyWithoutMul(right, rank) ~ ")} " ~ m; 
    266260    } 
    267261    // Deal with ++ and --. Only for scalars 
     
    340334} 
    341335 
    342 RevisedExpression simplifyVectorExpression(char [] expr, char [] rank
     336RevisedExpression simplifyVectorExpression(char [] expr, char [] rank, Symbol[] symTable=[]
    343337{ 
    344338    char [] s = exprSimplify(expr, rank, "", ""); 
    345     if (s.length>1) s = s[1..$-1]; // strip off () 
     339    if (s.length>1 && s[0]=='(') s = s[1..$-1]; // strip off () 
    346340    char [][] comp; 
    347341    char [] used=""; // which of the old variables are used; gives the new mapping 
     
    355349            int k; 
    356350            for (k=i+1; s[k]!=' '; ++k) {} 
    357             comp ~= s[i+2..k-1]; 
    358             if (s[k-2]==']') {                 
    359                 // it's a vector/matrix of some kind, with rank reduced 
    360                 // by indices. Can't just use exprRank, because the [] 
    361                 // aren't wrapped by (). 
    362                 r ~= (rank[s[i+2]-'A'] - indexRank(s[i+2..k-1])); 
    363             } else { 
    364                 // it's a scalar expression. Note that it could involve 
    365                 // a vector expression. 
    366                 r~='0';  
    367             } 
    368             e ~= next; 
    369             ++next; 
     351            char [] newexpr = s[i+2..k-1]; // strip off the {} 
     352            // Check for a duplicate 
     353            int z; 
     354            for (z=0; z<comp.length && comp[z]!=newexpr; ++z) {} 
     355            if (z==comp.length) { 
     356                e ~= next; 
     357                ++next; 
     358                comp ~= s[i+2..k-1]; // strip off the {} 
     359                if (s[k-2]==']') {                 
     360                    // it's a vector/matrix of some kind, with rank reduced 
     361                    // by indices. Can't just use exprRank, because the [] 
     362                    // aren't wrapped by (). 
     363                    r ~= (rank[s[i+2]-'A'] - indexRank(s[i+2..k-1])); 
     364                } else { 
     365                    // it's a scalar expression. Note that it could involve 
     366                    // a vector expression. 
     367                    r~='0';  
     368                }                 
     369            } else e ~= cast(char)('A'+z+rank.length); 
    370370            i = k; 
    371371        } else { 
     
    399399        } else f ~= c; 
    400400    } 
    401     return RevisedExpression(f, comp, old_ranks~r, mapping); 
     401    return RevisedExpression(f, symTable, comp, old_ranks~r, mapping); 
    402402} 
    403403 
    404404unittest { 
    405 //    assert(0, exprSimplify("d(A,A*d(A,A))", "1", "","")); // == "d(A,A)*d(A,A)"); 
     405    assert(exprSimplify("d(A,(A*d(A,A)))", "1", "","")== " {d(A,A)} * {d(A,A)} "); 
    406406    assert(exprSimplify("A+=(B*(C[D,D..$]))","1020","","")=="(A+=(B* {C[D,D..$]} ))"); 
    407407    assert(exprSimplify("A+=(((D[E])*B)[E])", "103300","","")=="(A+=(B* {D[E][E]} ))"); 
     
    411411    assert(exprSimplify("A=(((B*E)+(C*E))*D)", "11100","","")=="(A=(( {(D*E)} *B)+( {(D*E)} *C)))"); 
    412412    assert(exprSimplify("A=(D*((B*E)+(C*E)))", "11100","","")=="(A=(( {(D*E)} *B)+( {(D*E)} *C)))"); 
    413     assert(exprSimplify("d((A*(B*C)),(B*A))","010","","")== "d(B,B)* {((A*C)*A)}"); 
     413    assert(exprSimplify("d((A*(B*C)),(B*A))","010","","")== " {d(B,B)} * {((A*C)*A)} "); 
    414414 
    415415    RevisedExpression e = simplifyVectorExpression("A+=(((D[B])*C)[B])", "2004"); 
  • trunk/blade/SyntaxTree.d

    r151 r157  
    8282    char [] expression; /// syntax tree in Placeholder format, eg A+=(B*C) 
    8383    Symbol[] symbolTable; /// Textual form of the types and values of A,B,C,... 
    84 } 
    85  
    86 struct TemplateSyntaxTree(T...) { 
    87     AbstractSyntaxTree tree; 
    8884} 
    8985