Changeset 161

Show
Ignore:
Timestamp:
12/14/07 03:25:52 (9 months ago)
Author:
Don Clugston
Message:

The index folding visitor is now part of the main line. A preliminary scalar folding visitor is implemented, not yet in the main line.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/blade/BladeDemo.d

    r159 r161  
    3434    real k=3.4; 
    3535 
    36 //    mixin(vectorize(` a += (d[2..$-1]*2.01*a[2]-another[][1])["abc".length-3..$]`)); 
    37 /+     
    38     mixin(vectorize(" a-= 2.01*(        3.04+k)*r")); 
     36    mixin(vectorize(` a += (d[2..$-1]*2.01*a[2]-another[][1])["abc".length-3..$]`)); 
     37    mixin(vectorize(" a-= 2.01*(        3.04+k)*r"));     
     38     
    3939    mixin(vectorize("q+= q*2.01")); 
     40    
    4041    // All of the next four are equivalent 
    4142    mixin(vectorize("a+=6*another[1,0..$]")); 
     
    4647    mixin(vectorize("another[0..$,1]+=6*a[0..2]")); 
    4748    mixin(vectorize("r-=another[0]")); 
    48 +/ 
    49     // Parses OK, but I don't think I'll support this
    50 //    mixin(vectorize("a+=6*another[1,[1,$]]")); 
     49 
     50    // I don't think I'll support this syntax long-term
     51    mixin(vectorize("a+=6*another[1,[0,$]]")); 
    5152 
    5253// Parses, and simplifies to A*A, where A = dot(q,q). No codegen yet. 
    53 //    mixin(vectorize("dot(q,q*dot(q,q))")); 
     54//   mixin(vectorize("dot(q,q*dot(q,q))")); 
    5455 
    5556    writefln("a=", a); 
  • trunk/blade/BladeRank.d

    r159 r161  
    164164        int rrank = doVisit(this_, right); 
    165165        if (rrank<0) return rrank; // propagate errors 
     166        if (lrank<0) return lrank; // propagate errors 
    166167        if (op=="+" || op=="-" || op=="=" || op=="+=" || op=="-=") { 
    167168            if (lrank!=rrank) { 
     
    179180            else return RankError.RankMismatchConcatenation; 
    180181        } 
    181         // For *, /, only scalar operations are permitted 
     182        // For / and /=, only scalar operations are permitted 
    182183        if ((op=="*=" || op=="/=") && rrank==0) return lrank; 
     184        if (op=="*=" && lrank==2 && rrank==2) return lrank; // mat *= mat 
     185        if (op=="*=" && lrank==1 && rrank==2) return lrank; // vec *= mat 
    183186        if (op=="*" || op=="/") { 
    184187            if (lrank==0) return rrank; 
    185188            if (rrank==0) return lrank; 
     189            if (lrank==2 && rrank==2) return lrank; 
     190            if (lrank==2 && rrank==2) return lrank; 
     191            if (lrank+rrank==3) return 1; // vec*mat or mat*vec 
    186192        } 
    187193        // All other operations are only valid for scalars. 
    188194        if (lrank==0 && rrank==0) return 0; 
    189195        return RankError.UnsupportedOperation; 
    190  
    191196    } 
    192197} 
  • trunk/blade/BladeSimplify.d

    r160 r161  
    114114} 
    115115 
     116// DEPRECATED 
    116117/// As for exprSimplify, but allows the whole thing to be wrapped in parentheses. 
    117118char [] subexprSimplify(char [] expr, char [] rank, char [] mulexpr, char [] indexexpr) 
     
    135136} 
    136137 
     138// DEPRECATED 
    137139// Simplify a scalar*tensor expression. 
    138140char [] simplifyScalarMul(char [] scalar, char [] tensor, char [] mulexpr, char [] rank, char [] indexexpr) 
     
    150152} 
    151153 
     154// DEPRECATED 
    152155char [] getCommonMultiplucation(char [] expr, char [] rank) 
    153156{ 
     
    187190} 
    188191 
    189  
     192// DEPRECATED 
    190193// Simplify the expression, assuming global scalar multiply has already been removed. 
    191194char [] simplifyWithoutMul(char [] rawExpr, char [] rank) 
     
    229232} 
    230233 
    231  
     234// DEPRECATED 
    232235/** 
    233236 * Rewrite the expression, taking advantage of distributivity of [] and 
     
    255258            else if (m.length==1) m= "*" ~ m; 
    256259            assert(indexexpr.length==0, "BLADE ICE: rank mismatch in dot product"); 
    257 //           assert(0, expr ~ "#" ~ left ~ "#" ~ right ~"#" ~ leftmul ~ "#"~ rightmul);// ~ subexprSimplify(right, rank, mulexpr,"")~"#"); 
    258260            return " {d(" ~ simplifyWithoutMul(left, rank) ~ "," ~ 
    259261                          simplifyWithoutMul(right, rank) ~ ")} " ~ m; 
     
    310312} 
    311313 
    312 // Allows [a,[b,c],d..e,f] syntax for indices, where [b,c] is a range. 
     314// Determine rank of a multidimensional index 
    313315int indexRank(char [] s) 
    314316{ 
     
    319321        if (s[i]=='(') ++paren; 
    320322        else if (s[i]==')') --paren; 
    321         if (paren==0 && s[i]==']') { numbrack--; if (s[i-1]=='[') --r;
     323        if (paren==0 && s[i]==']') { numbrack--;
    322324        if (paren==0 && s[i]=='[') { 
    323325            if (numbrack==0) ++r; 
     
    325327        } 
    326328        if (paren==0 && numbrack==1 && s[i]==',') ++r; // commas increase the rank 
    327         if (paren==0 && numbrack==2 && s[i]==',') --r; // slices commas don't increase rank. 
    328329        if (paren==0 && numbrack==1 && s[i]=='.' && s[i-1]=='.') { 
    329330            // if it's a slice, it does not increase the rank 
     
    336337RevisedExpression simplifyVectorExpression(char [] expr, char [] rank, Symbol[] symTable=[]) 
    337338{ 
    338 //    char [] s = exprSimplify(foldIndices(expr, rank), rank, "", ""); 
    339     char [] s = exprSimplify(expr, rank, "", ""); 
     339    char [] s = exprSimplify(foldIndices(expr, rank), rank, "", ""); 
     340//    char [] s = foldScalars(foldIndices(expr, rank), rank); 
     341//    char [] s = exprSimplify(expr, rank, "", ""); 
    340342    if (s.length>1 && s[0]=='(') s = s[1..$-1]; // strip off () 
    341343    char [][] comp; 
     
    417419    assert(e.expression == "A+=(B*C)"); 
    418420    assert(e.rank=="202"); 
    419     assert(e.compounds[0]=="D[B][B]"); 
     421    assert(e.compounds[0]=="D[B,B]"); 
    420422    assert(e.mapping=="ACE"); 
    421423} 
     
    523525        int lrank = subexprRank(left, this_.rank); 
    524526        int rrank = subexprRank(right, this_.rank); 
    525         char [] first
    526         char [] second
     527        char [] first=""
     528        char [] second=""
    527529        if ((op=="*" || op=="*=") && this_.slicing.length>0) { 
    528530            // If one of these is a matrix, the slicing gets interesting... 
    529531            // .. extremely so for slicing of matrix chain multiplication. 
    530             // Given U a row vector, V a column vector; A,B,C matrices; x, y scalars: 
    531             //  (U*V)[x][y]  = U[x]*V[y] 
    532             //  (U*V)[x] =  
    533             //  (U*A)[x]  = U*A[,x] 
    534             //  (A*U)[x]  = A[x]*U 
    535             //  (A*B)[x] = A[x]*B. 
    536             //  (A*B)[x..y] = A[x..y]*B[,x..y] 
    537             //  (A*B)[x][y] = A[x]*B[,y] 
    538             //  (A*B*C)[x][y] 
    539532            if (lrank==0) { 
    540533                // All dimensions apply to right operand 
     
    545538                first = wrapInParens(doVisit(this_, left)); 
    546539                second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right)); 
    547             } else {                
    548                 assert(lrank<=2 && rrank<=2, "BLADE ICE: Tensor*tensor is unsupported"); 
     540            } else { 
     541                assert(lrank>0 && rrank>0 && lrank<=2 && rrank<=2, "BLADE ICE: Tensor*tensor is unsupported"); 
    549542                bool isDotProduct = false; // was it reduced to a dot product? 
     543                 
     544                // In the case of chained matrix multiplies, we can end up with an empty slice. 
     545                if (this_.slicing.length>0 && this_.slicing[$-1][0]=="") { 
     546                    this_.slicing=this_.slicing[0..$-1]; 
     547                } 
    550548                if (lrank==2) { 
    551549                    // First dimension applies to rows of the left operand 
    552                     // If it's a slice, it will be a strided slice. 
     550                    // If it's a slice, it will be a strided slice -- unless 
     551                    // it comes from another matrix multiply, in which case the                     
     552                    // stride will drop out. (A[x]*B is strided). 
    553553                    char [][2][] newslice=[]; 
    554554                    newslice ~= this_.slicing[0]; 
     
    579579                    second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right)); 
    580580                    if (this_.slicing[0][1].length==0)  isDotProduct = true; 
    581                 } 
     581                } else assert(0, "BLADE ICE"); 
    582582                if (isDotProduct) { 
    583583                  return "d(" ~ first ~ "," ~ second ~ ")"; 
     
    585585            } 
    586586        } else { 
    587             first = wrapInParens(doVisit(this_, left)); 
    588             second = wrapInParens(doVisit(this_, right)); 
    589         }         
     587            // in DMD1.024, nasty compiler bug if you save the first & second results into a local variable 
     588            return wrapInParens(doVisit(this_, left)) ~ op ~ wrapInParens(doVisit(this_, right)); 
     589        } 
    590590        return first ~ op ~ second; 
    591591    } 
     
    597597} 
    598598 
    599 unittest { 
    600     assert(foldIndices("A+=(((D[B])*C)[B])", "2004")=="A+=((D[B,B])*C)"); 
     599unittest {    
    601600    assert(foldIndices("((A[C..D])+B)[($-E)]", "21000")=="(A[C+((D-C)-E)])+(B[($-E)])"); 
    602601    assert(foldIndices("(A[C])[D]", "3100")=="A[C,D]"); 
     
    615614    assert(foldIndices("(A*B)[C..D]", "1200")=="A*(B[C..D])"); 
    616615    assert(foldIndices("(A*B)[C]", "120")=="d(A,(B[C]))"); 
     616     
     617    assert(foldIndices("((A*B)*C)[D]", "2220")=="((A[D,])*B)*C"); 
     618    assert(foldIndices("((A+B)*C)[D]", "2220")=="((A[D,])+(B[D,]))*C"); 
    617619    assert(foldIndices("((D*A)*B)[C]", "2100")=="d((D*(A[C,])),B)"); 
    618 
     620    assert(foldIndices("(((A*B)*C)[D..E])[D]", "12200")=="d((A*B),(C[D+D]))");  
     621    assert(foldIndices("A+=(((D[B])*C)[B])", "2004")=="A+=((D[B,B])*C)"); 
     622    assert(foldIndices("d(A,(A*d(A,A)))","1")=="d(A,(A*(d(A,A))))"); 
     623
     624 
     625struct ScalarFold 
     626
     627    char [] expr; 
     628    char [] multiplier; // scalar multiply of the entire expression 
     629
     630 
     631// Fold all scalars together, extracting common multiplies. 
     632struct ScalarFoldingVisitor { 
     633    alias typeof(*this) This; 
     634    alias ScalarFold ReturnType; 
     635    char [] rank; 
     636static: 
     637    ReturnType onVisitSymbol(This this_, char[] sym) { 
     638        if (sym=="$" || this_.rank[sym[0]-'A']=='0') return ScalarFold("",sym); 
     639        else return ScalarFold(sym, ""); 
     640    } 
     641    ReturnType onVisitFunction(This this_, char [] func, char [][] args) {         
     642        if (func=="d") { // dot product. 
     643            ScalarFold left = doVisit(this_,args[0]); 
     644            ScalarFold right = doVisit(this_, args[1]); 
     645            char [] s = left.multiplier; 
     646            if (s.length>0 && right.multiplier.length>0) s~= "*"  ~ right.multiplier; 
     647            if (s.length>0) s ~="*"; 
     648            return ScalarFold("", combineMul(combineMul(left.multiplier, right.multiplier), func ~ "(" ~ left.expr ~ "," ~ right.expr ~ ")")); 
     649        } else { 
     650            assert(0, "BLADE: Unsupported function"); 
     651            return ScalarFold("",""); 
     652        } 
     653    }     
     654    ReturnType onVisitPrefix(This this_, char [] op, char [] expr) { 
     655        if (op=="-") { 
     656            // Fold unary minus into a multiply, if possible. 
     657            ScalarFold left = doVisit(this_, expr); 
     658            if (left.multiplier!="") return ScalarFold(left.expr, "-" ~ wrapInParens(left.multiplier)); 
     659            else return ScalarFold("-" ~ wrapInParens(left.expr), left.multiplier); 
     660        } else if (op=="+") { // just ignore unary plus 
     661            return doVisit(this_, expr);         
     662        } else { 
     663            ScalarFold f = doVisit(this_, expr); 
     664            return ScalarFold(op ~ wrapInParens(combineMul(f.expr, f.multiplier)),""); 
     665        } 
     666    } 
     667    ReturnType onVisitPostfix(This this_, char [] op, char [] expr) { 
     668        ScalarFold f = doVisit(this_, expr); 
     669        return ScalarFold(wrapInParens(combineMul(f.expr, f.multiplier))~ op,""); 
     670    } 
     671    ReturnType onVisitIndex(This this_, char [] base, char [][2][] slices) { 
     672        // Base is always a single symbol. 
     673        ScalarFold left = doVisit(this_, base); 
     674        // BUG: This whole thing could be a scalar. 
     675        return ScalarFold(" {" ~ left.expr ~ createMultiSlice(slices)~ "} ", left.multiplier); 
     676    } 
     677    ReturnType onVisitBinaryOp(This this_, char [] op, char [] left, char [] right) { 
     678        ScalarFold first = doVisit(this_, left); 
     679        ScalarFold second = doVisit(this_, right); 
     680        if (op=="*=") { 
     681            assert(first.multiplier=="" && second.expr=="", "BLADE ICE"); 
     682            if (second.multiplier.length>1)  return ScalarFold(wrapInParens(first.expr) ~ op ~ " {" ~ wrapInParens(second.multiplier) ~ "} ","");  
     683            else return ScalarFold(wrapInParens(first.expr) ~ op ~ wrapInParens(second.multiplier),"");  
     684        } 
     685        if (op=="*") { 
     686            return ScalarFold(combineMul(first.expr,second.expr), combineMul(first.multiplier,second.multiplier)); 
     687        } 
     688        if (first.expr=="" && second.expr=="") { // both are 100% scalars -- it remains a scalar. 
     689            return ScalarFold("", 
     690            wrapInParens(combineMulWithCompound(first.expr, first.multiplier)) ~ op ~ 
     691            wrapInParens(combineMulWithCompound(second.expr, second.multiplier))); 
     692        } 
     693        return ScalarFold(wrapInParens(combineMulWithCompound(first.expr, first.multiplier)) ~ op ~ 
     694            wrapInParens(combineMulWithCompound(second.expr, second.multiplier)), ""); 
     695    } 
     696
     697 
     698char [] combineMul(char [] left, char [] right) 
     699
     700    if (left.length==0) return right; 
     701    if (right.length==0) return left; 
     702    return wrapInParens(left) ~ "*" ~ wrapInParens(right); 
     703
     704 
     705// 'right' should become a new compound expression 
     706char [] combineMulWithCompound(char [] left, char [] right) 
     707
     708    assert(left.length>0); 
     709    if (right.length==0) return left; 
     710    if (right.length==1) return wrapInParens(left) ~ "*" ~ right; 
     711    return wrapInParens(left) ~ "* {" ~ wrapInParens(right) ~ "} ";     
     712
     713 
     714char [] foldScalars(char [] expr, char [] ranks) 
     715
     716    ScalarFold f = beginVisit(ScalarFoldingVisitor(ranks), expr); 
     717    if (f.multiplier=="") return f.expr; 
     718    else if (f.expr=="") return " {" ~ f.multiplier ~ "} "; 
     719    else return " {" ~ f.multiplier ~ "} *" ~ wrapInParens(f.expr); 
     720
     721 
     722unittest { 
     723    assert(foldScalars("A*=(B*C)", "100")== "A*= {(B*C)} "); 
     724    assert(foldScalars("d(A,(A*d(A,A)))", "1")==" {(d(A,A))*(d(A,A))} "); 
     725
  • trunk/blade/BladeVisitor.d

    r159 r161  
    7171                ++z; 
    7272            } else { 
    73                 int w = exprLength(right[z..$]); 
    74                 if (z+w+3 < right.length && right[z+w+1..z+w+3]=="..") { 
    75                     int q = z+w+3; 
    76                     int t = exprLength(right[q..$]); 
    77                     allslices ~= [ right[z..z+w+1], right[q..q+t+1]]; 
    78                     z = q+t+2; // no comma to skip 
    79                 } else { 
    80                     allslices ~= [ right[z..z+w+1], ""]; 
    81                     z = z+w+2; // skip the comma, if any 
    82                  } 
     73                    int w = exprLength(right[z..$]); 
     74                    if (z+w+3 < right.length && right[z+w+1..z+w+3]=="..") { 
     75                        int q = z+w+3; 
     76                        int t = exprLength(right[q..$]); 
     77                        allslices ~= [ right[z..z+w+1], right[q..q+t+1]]; 
     78                        z = q+t+2; // no comma to skip 
     79                    } else { 
     80  
     81                     if (right.length>z+1 && right[z+1]=='[') { // fake slice -- support this for now 
     82                            int w2 = exprLength(right[z+2..$-2]); 
     83                            int q = z+w2+2; // skip the , 
     84                            int t = exprLength(right[q+2..$]); 
     85                            allslices ~= [ right[z+2..q+1], right[q+2..q+2+t+1]]; 
     86                            z = z+w+2; 
     87                    } else { 
     88                        allslices ~= [ right[z..z+w+1], ""]; 
     89                        z = z+w+2; // skip the comma, if any 
     90                    } 
     91                } 
    8392             } 
    8493        }