Changeset 161
- Timestamp:
- 12/14/07 03:25:52 (9 months ago)
- Files:
-
- trunk/blade/BladeDemo.d (modified) (2 diffs)
- trunk/blade/BladeRank.d (modified) (2 diffs)
- trunk/blade/BladeSimplify.d (modified) (17 diffs)
- trunk/blade/BladeVisitor.d (modified) (1 diff)
Legend:
- Unmodified
- Added
- Removed
- Modified
- Copied
- Moved
trunk/blade/BladeDemo.d
r159 r161 34 34 real k=3.4; 35 35 36 //mixin(vectorize(` a += (d[2..$-1]*2.01*a[2]-another[][1])["abc".length-3..$]`));37 /+38 mixin(vectorize(" a-= 2.01*( 3.04+k)*r"));36 mixin(vectorize(` a += (d[2..$-1]*2.01*a[2]-another[][1])["abc".length-3..$]`)); 37 mixin(vectorize(" a-= 2.01*( 3.04+k)*r")); 38 39 39 mixin(vectorize("q+= q*2.01")); 40 40 41 // All of the next four are equivalent 41 42 mixin(vectorize("a+=6*another[1,0..$]")); … … 46 47 mixin(vectorize("another[0..$,1]+=6*a[0..2]")); 47 48 mixin(vectorize("r-=another[0]")); 48 +/ 49 // Parses OK, but I don't think I'll support this.50 // mixin(vectorize("a+=6*another[1,[1,$]]"));49 50 // I don't think I'll support this syntax long-term. 51 mixin(vectorize("a+=6*another[1,[0,$]]")); 51 52 52 53 // Parses, and simplifies to A*A, where A = dot(q,q). No codegen yet. 53 // mixin(vectorize("dot(q,q*dot(q,q))"));54 // mixin(vectorize("dot(q,q*dot(q,q))")); 54 55 55 56 writefln("a=", a); trunk/blade/BladeRank.d
r159 r161 164 164 int rrank = doVisit(this_, right); 165 165 if (rrank<0) return rrank; // propagate errors 166 if (lrank<0) return lrank; // propagate errors 166 167 if (op=="+" || op=="-" || op=="=" || op=="+=" || op=="-=") { 167 168 if (lrank!=rrank) { … … 179 180 else return RankError.RankMismatchConcatenation; 180 181 } 181 // For *, /, only scalar operations are permitted182 // For / and /=, only scalar operations are permitted 182 183 if ((op=="*=" || op=="/=") && rrank==0) return lrank; 184 if (op=="*=" && lrank==2 && rrank==2) return lrank; // mat *= mat 185 if (op=="*=" && lrank==1 && rrank==2) return lrank; // vec *= mat 183 186 if (op=="*" || op=="/") { 184 187 if (lrank==0) return rrank; 185 188 if (rrank==0) return lrank; 189 if (lrank==2 && rrank==2) return lrank; 190 if (lrank==2 && rrank==2) return lrank; 191 if (lrank+rrank==3) return 1; // vec*mat or mat*vec 186 192 } 187 193 // All other operations are only valid for scalars. 188 194 if (lrank==0 && rrank==0) return 0; 189 195 return RankError.UnsupportedOperation; 190 191 196 } 192 197 } trunk/blade/BladeSimplify.d
r160 r161 114 114 } 115 115 116 // DEPRECATED 116 117 /// As for exprSimplify, but allows the whole thing to be wrapped in parentheses. 117 118 char [] subexprSimplify(char [] expr, char [] rank, char [] mulexpr, char [] indexexpr) … … 135 136 } 136 137 138 // DEPRECATED 137 139 // Simplify a scalar*tensor expression. 138 140 char [] simplifyScalarMul(char [] scalar, char [] tensor, char [] mulexpr, char [] rank, char [] indexexpr) … … 150 152 } 151 153 154 // DEPRECATED 152 155 char [] getCommonMultiplucation(char [] expr, char [] rank) 153 156 { … … 187 190 } 188 191 189 192 // DEPRECATED 190 193 // Simplify the expression, assuming global scalar multiply has already been removed. 191 194 char [] simplifyWithoutMul(char [] rawExpr, char [] rank) … … 229 232 } 230 233 231 234 // DEPRECATED 232 235 /** 233 236 * Rewrite the expression, taking advantage of distributivity of [] and … … 255 258 else if (m.length==1) m= "*" ~ m; 256 259 assert(indexexpr.length==0, "BLADE ICE: rank mismatch in dot product"); 257 // assert(0, expr ~ "#" ~ left ~ "#" ~ right ~"#" ~ leftmul ~ "#"~ rightmul);// ~ subexprSimplify(right, rank, mulexpr,"")~"#");258 260 return " {d(" ~ simplifyWithoutMul(left, rank) ~ "," ~ 259 261 simplifyWithoutMul(right, rank) ~ ")} " ~ m; … … 310 312 } 311 313 312 // Allows [a,[b,c],d..e,f] syntax for indices, where [b,c] is a range.314 // Determine rank of a multidimensional index 313 315 int indexRank(char [] s) 314 316 { … … 319 321 if (s[i]=='(') ++paren; 320 322 else if (s[i]==')') --paren; 321 if (paren==0 && s[i]==']') { numbrack--; if (s[i-1]=='[') --r;}323 if (paren==0 && s[i]==']') { numbrack--; } 322 324 if (paren==0 && s[i]=='[') { 323 325 if (numbrack==0) ++r; … … 325 327 } 326 328 if (paren==0 && numbrack==1 && s[i]==',') ++r; // commas increase the rank 327 if (paren==0 && numbrack==2 && s[i]==',') --r; // slices commas don't increase rank.328 329 if (paren==0 && numbrack==1 && s[i]=='.' && s[i-1]=='.') { 329 330 // if it's a slice, it does not increase the rank … … 336 337 RevisedExpression simplifyVectorExpression(char [] expr, char [] rank, Symbol[] symTable=[]) 337 338 { 338 // char [] s = exprSimplify(foldIndices(expr, rank), rank, "", ""); 339 char [] s = exprSimplify(expr, rank, "", ""); 339 char [] s = exprSimplify(foldIndices(expr, rank), rank, "", ""); 340 // char [] s = foldScalars(foldIndices(expr, rank), rank); 341 // char [] s = exprSimplify(expr, rank, "", ""); 340 342 if (s.length>1 && s[0]=='(') s = s[1..$-1]; // strip off () 341 343 char [][] comp; … … 417 419 assert(e.expression == "A+=(B*C)"); 418 420 assert(e.rank=="202"); 419 assert(e.compounds[0]=="D[B ][B]");421 assert(e.compounds[0]=="D[B,B]"); 420 422 assert(e.mapping=="ACE"); 421 423 } … … 523 525 int lrank = subexprRank(left, this_.rank); 524 526 int rrank = subexprRank(right, this_.rank); 525 char [] first ;526 char [] second ;527 char [] first=""; 528 char [] second=""; 527 529 if ((op=="*" || op=="*=") && this_.slicing.length>0) { 528 530 // If one of these is a matrix, the slicing gets interesting... 529 531 // .. extremely so for slicing of matrix chain multiplication. 530 // Given U a row vector, V a column vector; A,B,C matrices; x, y scalars:531 // (U*V)[x][y] = U[x]*V[y]532 // (U*V)[x] =533 // (U*A)[x] = U*A[,x]534 // (A*U)[x] = A[x]*U535 // (A*B)[x] = A[x]*B.536 // (A*B)[x..y] = A[x..y]*B[,x..y]537 // (A*B)[x][y] = A[x]*B[,y]538 // (A*B*C)[x][y]539 532 if (lrank==0) { 540 533 // All dimensions apply to right operand … … 545 538 first = wrapInParens(doVisit(this_, left)); 546 539 second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right)); 547 } else { 548 assert(lrank <=2 && rrank<=2, "BLADE ICE: Tensor*tensor is unsupported");540 } else { 541 assert(lrank>0 && rrank>0 && lrank<=2 && rrank<=2, "BLADE ICE: Tensor*tensor is unsupported"); 549 542 bool isDotProduct = false; // was it reduced to a dot product? 543 544 // In the case of chained matrix multiplies, we can end up with an empty slice. 545 if (this_.slicing.length>0 && this_.slicing[$-1][0]=="") { 546 this_.slicing=this_.slicing[0..$-1]; 547 } 550 548 if (lrank==2) { 551 549 // First dimension applies to rows of the left operand 552 // If it's a slice, it will be a strided slice. 550 // If it's a slice, it will be a strided slice -- unless 551 // it comes from another matrix multiply, in which case the 552 // stride will drop out. (A[x]*B is strided). 553 553 char [][2][] newslice=[]; 554 554 newslice ~= this_.slicing[0]; … … 579 579 second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right)); 580 580 if (this_.slicing[0][1].length==0) isDotProduct = true; 581 } 581 } else assert(0, "BLADE ICE"); 582 582 if (isDotProduct) { 583 583 return "d(" ~ first ~ "," ~ second ~ ")"; … … 585 585 } 586 586 } else { 587 first = wrapInParens(doVisit(this_, left));588 second =wrapInParens(doVisit(this_, right));589 } 587 // in DMD1.024, nasty compiler bug if you save the first & second results into a local variable 588 return wrapInParens(doVisit(this_, left)) ~ op ~ wrapInParens(doVisit(this_, right)); 589 } 590 590 return first ~ op ~ second; 591 591 } … … 597 597 } 598 598 599 unittest { 600 assert(foldIndices("A+=(((D[B])*C)[B])", "2004")=="A+=((D[B,B])*C)"); 599 unittest { 601 600 assert(foldIndices("((A[C..D])+B)[($-E)]", "21000")=="(A[C+((D-C)-E)])+(B[($-E)])"); 602 601 assert(foldIndices("(A[C])[D]", "3100")=="A[C,D]"); … … 615 614 assert(foldIndices("(A*B)[C..D]", "1200")=="A*(B[C..D])"); 616 615 assert(foldIndices("(A*B)[C]", "120")=="d(A,(B[C]))"); 616 617 assert(foldIndices("((A*B)*C)[D]", "2220")=="((A[D,])*B)*C"); 618 assert(foldIndices("((A+B)*C)[D]", "2220")=="((A[D,])+(B[D,]))*C"); 617 619 assert(foldIndices("((D*A)*B)[C]", "2100")=="d((D*(A[C,])),B)"); 618 } 620 assert(foldIndices("(((A*B)*C)[D..E])[D]", "12200")=="d((A*B),(C[D+D]))"); 621 assert(foldIndices("A+=(((D[B])*C)[B])", "2004")=="A+=((D[B,B])*C)"); 622 assert(foldIndices("d(A,(A*d(A,A)))","1")=="d(A,(A*(d(A,A))))"); 623 } 624 625 struct ScalarFold 626 { 627 char [] expr; 628 char [] multiplier; // scalar multiply of the entire expression 629 } 630 631 // Fold all scalars together, extracting common multiplies. 632 struct ScalarFoldingVisitor { 633 alias typeof(*this) This; 634 alias ScalarFold ReturnType; 635 char [] rank; 636 static: 637 ReturnType onVisitSymbol(This this_, char[] sym) { 638 if (sym=="$" || this_.rank[sym[0]-'A']=='0') return ScalarFold("",sym); 639 else return ScalarFold(sym, ""); 640 } 641 ReturnType onVisitFunction(This this_, char [] func, char [][] args) { 642 if (func=="d") { // dot product. 643 ScalarFold left = doVisit(this_,args[0]); 644 ScalarFold right = doVisit(this_, args[1]); 645 char [] s = left.multiplier; 646 if (s.length>0 && right.multiplier.length>0) s~= "*" ~ right.multiplier; 647 if (s.length>0) s ~="*"; 648 return ScalarFold("", combineMul(combineMul(left.multiplier, right.multiplier), func ~ "(" ~ left.expr ~ "," ~ right.expr ~ ")")); 649 } else { 650 assert(0, "BLADE: Unsupported function"); 651 return ScalarFold("",""); 652 } 653 } 654 ReturnType onVisitPrefix(This this_, char [] op, char [] expr) { 655 if (op=="-") { 656 // Fold unary minus into a multiply, if possible. 657 ScalarFold left = doVisit(this_, expr); 658 if (left.multiplier!="") return ScalarFold(left.expr, "-" ~ wrapInParens(left.multiplier)); 659 else return ScalarFold("-" ~ wrapInParens(left.expr), left.multiplier); 660 } else if (op=="+") { // just ignore unary plus 661 return doVisit(this_, expr); 662 } else { 663 ScalarFold f = doVisit(this_, expr); 664 return ScalarFold(op ~ wrapInParens(combineMul(f.expr, f.multiplier)),""); 665 } 666 } 667 ReturnType onVisitPostfix(This this_, char [] op, char [] expr) { 668 ScalarFold f = doVisit(this_, expr); 669 return ScalarFold(wrapInParens(combineMul(f.expr, f.multiplier))~ op,""); 670 } 671 ReturnType onVisitIndex(This this_, char [] base, char [][2][] slices) { 672 // Base is always a single symbol. 673 ScalarFold left = doVisit(this_, base); 674 // BUG: This whole thing could be a scalar. 675 return ScalarFold(" {" ~ left.expr ~ createMultiSlice(slices)~ "} ", left.multiplier); 676 } 677 ReturnType onVisitBinaryOp(This this_, char [] op, char [] left, char [] right) { 678 ScalarFold first = doVisit(this_, left); 679 ScalarFold second = doVisit(this_, right); 680 if (op=="*=") { 681 assert(first.multiplier=="" && second.expr=="", "BLADE ICE"); 682 if (second.multiplier.length>1) return ScalarFold(wrapInParens(first.expr) ~ op ~ " {" ~ wrapInParens(second.multiplier) ~ "} ",""); 683 else return ScalarFold(wrapInParens(first.expr) ~ op ~ wrapInParens(second.multiplier),""); 684 } 685 if (op=="*") { 686 return ScalarFold(combineMul(first.expr,second.expr), combineMul(first.multiplier,second.multiplier)); 687 } 688 if (first.expr=="" && second.expr=="") { // both are 100% scalars -- it remains a scalar. 689 return ScalarFold("", 690 wrapInParens(combineMulWithCompound(first.expr, first.multiplier)) ~ op ~ 691 wrapInParens(combineMulWithCompound(second.expr, second.multiplier))); 692 } 693 return ScalarFold(wrapInParens(combineMulWithCompound(first.expr, first.multiplier)) ~ op ~ 694 wrapInParens(combineMulWithCompound(second.expr, second.multiplier)), ""); 695 } 696 } 697 698 char [] combineMul(char [] left, char [] right) 699 { 700 if (left.length==0) return right; 701 if (right.length==0) return left; 702 return wrapInParens(left) ~ "*" ~ wrapInParens(right); 703 } 704 705 // 'right' should become a new compound expression 706 char [] combineMulWithCompound(char [] left, char [] right) 707 { 708 assert(left.length>0); 709 if (right.length==0) return left; 710 if (right.length==1) return wrapInParens(left) ~ "*" ~ right; 711 return wrapInParens(left) ~ "* {" ~ wrapInParens(right) ~ "} "; 712 } 713 714 char [] foldScalars(char [] expr, char [] ranks) 715 { 716 ScalarFold f = beginVisit(ScalarFoldingVisitor(ranks), expr); 717 if (f.multiplier=="") return f.expr; 718 else if (f.expr=="") return " {" ~ f.multiplier ~ "} "; 719 else return " {" ~ f.multiplier ~ "} *" ~ wrapInParens(f.expr); 720 } 721 722 unittest { 723 assert(foldScalars("A*=(B*C)", "100")== "A*= {(B*C)} "); 724 assert(foldScalars("d(A,(A*d(A,A)))", "1")==" {(d(A,A))*(d(A,A))} "); 725 } trunk/blade/BladeVisitor.d
r159 r161 71 71 ++z; 72 72 } else { 73 int w = exprLength(right[z..$]); 74 if (z+w+3 < right.length && right[z+w+1..z+w+3]=="..") { 75 int q = z+w+3; 76 int t = exprLength(right[q..$]); 77 allslices ~= [ right[z..z+w+1], right[q..q+t+1]]; 78 z = q+t+2; // no comma to skip 79 } else { 80 allslices ~= [ right[z..z+w+1], ""]; 81 z = z+w+2; // skip the comma, if any 82 } 73 int w = exprLength(right[z..$]); 74 if (z+w+3 < right.length && right[z+w+1..z+w+3]=="..") { 75 int q = z+w+3; 76 int t = exprLength(right[q..$]); 77 allslices ~= [ right[z..z+w+1], right[q..q+t+1]]; 78 z = q+t+2; // no comma to skip 79 } else { 80 81 if (right.length>z+1 && right[z+1]=='[') { // fake slice -- support this for now 82 int w2 = exprLength(right[z+2..$-2]); 83 int q = z+w2+2; // skip the , 84 int t = exprLength(right[q+2..$]); 85 allslices ~= [ right[z+2..q+1], right[q+2..q+2+t+1]]; 86 z = z+w+2; 87 } else { 88 allslices ~= [ right[z..z+w+1], ""]; 89 z = z+w+2; // skip the comma, if any 90 } 91 } 83 92 } 84 93 }
