Changeset 187 for trunk/blade/BladeSimplify.d
- Timestamp:
- 04/30/08 16:05:32 (5 months ago)
- Files:
-
- trunk/blade/BladeSimplify.d (modified) (25 diffs)
Legend:
- Unmodified
- Added
- Removed
- Modified
- Copied
- Moved
trunk/blade/BladeSimplify.d
r175 r187 14 14 * be moved to every vector inside A. 15 15 * - Use associativity of *: A*(B*C[]) == (A*B)*C[] (Not strictly true for 16 * floating point; results may differ by 1ulp, 16 * floating point; results may differ by 1ulp, 17 17 * eg (1.3L*3.1L)*4.7L < 1.3L*(3.1L*4.7L) 18 18 * Note that floating point addition is not associative at all). 19 19 * - Remove unary minus where possible, eg A-(-B) => A+B, abs(-A) => abs(A). 20 * - Use associativity of * in intrinsics: 20 * - Use associativity of * in intrinsics: 21 21 * sum(A*V) => A*sum(V), abs(A*B) => abs(A)*abs(B) 22 * (D) Expression standardisation 22 * (D) Expression standardisation 23 23 * - Move multiplies to left: Convert A[]*B into B*A[] (assumes * is commutative, 24 24 * not valid for quaternions). … … 54 54 { 55 55 return str=="dot" || str=="sum" || str=="max" || str=="min" 56 || str=="abs" || str=="sqrt" ;56 || str=="abs" || str=="sqrt" || str=="prod"; 57 57 } 58 58 … … 68 68 } 69 69 // Check for undefined symbols 70 if (err.length > 0) 70 if (err.length > 0) 71 71 return RevisedExpression(tree.expression, "", tree.symbolTable, [""], "","", "Undefined symbols:" ~ err); 72 72 else { … … 119 119 } else e~=c; 120 120 } 121 return e; 121 return e; 122 122 } 123 123 … … 171 171 } 172 172 --k; 173 char [] newexpr = expr[i+1..k]; // strip off the {} 173 char [] newexpr = expr[i+1..k]; // strip off the {} 174 174 int newi = k; 175 175 if (i>0 && k<expr.length-1 && expr[i-1]=='(' && expr[k+1]==')') { … … 184 184 ++next; 185 185 comp ~= expr[i+1..k]; // strip off the {} 186 if (expr[k-1]==']') { 186 if (expr[k-1]==']') { 187 187 // it's a vector/matrix of some kind, with rank reduced 188 188 // by indices. Can't just use exprRank, because the [] … … 192 192 // it's a scalar expression. Note that it could involve 193 193 // a vector expression. 194 r~='0'; 195 } 194 r~='0'; 195 } 196 196 } else e ~= cast(char)('A'+z+rank.length); 197 197 i = newi; … … 202 202 } 203 203 // Create a mapping from old to new variable names 204 204 205 205 char [] old_ranks = ""; 206 206 char [] mapping=""; … … 235 235 } 236 236 237 unittest { 237 unittest { 238 238 RevisedExpression e = simplifyVectorExpression("A+=(((D[B])*C)[B])", "2004",[]); 239 239 assert(e.rank=="202"); … … 281 281 assert(sym!="$" && this_.rank[sym[0]-'A']>'0', "Rank error " ~ sym); 282 282 // Note: Later, we'll want this to be a new terminal. 283 return sym ~ createMultiSlice(this_.slicing); 283 return sym ~ createMultiSlice(this_.slicing); 284 284 } 285 285 } … … 301 301 return wrapInParens(doVisit(this_, expr)) ~ op; 302 302 } 303 // Includes multi-dimensional slicing and indexing. 303 // Includes multi-dimensional slicing and indexing. 304 304 ReturnType onVisitIndex(This this_, char [] base, char [][2][] slices) { 305 305 if (slices.length==0) { // [] -- has no effect. … … 311 311 // with the earliest existing dimension. 312 312 // * If the existing dimension is an index, 313 // it might contain a dollar, which we need to replace. 313 // it might contain a dollar, which we need to replace. 314 314 // * If the existing dimension is a slice, the two slices will combine. 315 315 // … … 331 331 newslice ~= [a ~ "+" ~ c, ""]; 332 332 } 333 if (slices.length>1) { 333 if (slices.length>1) { 334 334 // append other slices, if any. 335 335 return doVisit(IndexFoldingVisitor(this_.rank, "$", slices[0..$-1] ~ newslice ~ this_.slicing[1..$]), base); … … 360 360 assert(lrank>0 && rrank>0 && lrank<=2 && rrank<=2, "BLADE ICE: Tensor*tensor is unsupported"); 361 361 bool isDotProduct = false; // was it reduced to a dot product? 362 362 363 363 // In the case of chained matrix multiplies, we can end up with an empty slice. 364 364 if (this_.slicing.length>0 && this_.slicing[$-1][0]=="") { … … 368 368 // First dimension applies to rows of the left operand 369 369 // If it's a slice, it will be a strided slice -- unless 370 // it comes from another matrix multiply, in which case the 370 // it comes from another matrix multiply, in which case the 371 371 // stride will drop out. (A[x]*B is strided). 372 372 char [][2][] newslice=[]; … … 390 390 second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right)); 391 391 } 392 } else if (lrank==1 && rrank==2) { 392 } else if (lrank==1 && rrank==2) { 393 393 // vector * matrix, Matrix uses all the slicing 394 394 second = wrapInParens(doVisit(this_, right)); … … 403 403 } 404 404 } 405 } else { 406 // in DMD1.024, nasty compiler bug if you save the first & second results into a local variable 405 } else { // not a multiplication 407 406 return wrapInParens(doVisit(this_, left)) ~ op ~ wrapInParens(doVisit(this_, right)); 408 407 } … … 416 415 } 417 416 418 unittest { 417 unittest { 419 418 assert(foldIndices("((A[C..D])+B)[($-E)]", "21000")=="(A[C+((D-C)-E)])+(B[($-E)])"); 420 419 assert(foldIndices("(A[C])[D]", "3100")=="A[C,D]"); … … 427 426 assert(foldIndices("A[,B..$,C]", "300")=="A[,B..$,C]"); 428 427 // Multidimensional slicing 429 assert(foldIndices("(C*((A*B)[C]))[D]", "2200")=="C*dot((A[C,]),(B[D]))"); 428 assert(foldIndices("(C*((A*B)[C]))[D]", "2200")=="C*dot((A[C,]),(B[D]))"); 430 429 assert(foldIndices("(A*B)[C..D,D]", "2200")=="(A[C..D,])*(B[D])"); 431 430 assert(foldIndices("(A*B)[C..D]", "2200")=="(A[C..D,])*B"); … … 433 432 assert(foldIndices("(A*B)[C..D]", "1200")=="A*(B[C..D])"); 434 433 assert(foldIndices("(A*B)[C]", "120")=="dot(A,(B[C]))"); 435 434 436 435 assert(foldIndices("((A*B)*C)[D]", "2220")=="((A[D,])*B)*C"); 437 436 assert(foldIndices("((A+B)*C)[D]", "2220")=="((A[D,])+(B[D,]))*C"); 438 437 assert(foldIndices("((D*A)*B)[C]", "2100")=="dot((D*(A[C,])),B)"); 439 assert(foldIndices("(((A*B)*C)[D..E])[D]", "12200")=="dot((A*B),(C[D+D]))"); 438 assert(foldIndices("(((A*B)*C)[D..E])[D]", "12200")=="dot((A*B),(C[D+D]))"); 440 439 assert(foldIndices("A+=(((D[B])*C)[B])", "2004")=="A+=((D[B,B])*C)"); 441 440 assert(foldIndices("dot(A,(A*dot(A,A)))","1")=="dot(A,(A*dot(A,A)))"); … … 466 465 ScalarFold right = doVisit(this_, args[1]); 467 466 return ScalarFold("", combineMul(combineMul(left.multiplier, right.multiplier), "{" ~ func ~ "(" ~ wrapInParens(left.expr) ~ "," ~ wrapInParens(right.expr) ~ ")}")); 468 case "sum": 467 case "sum": 468 case "prod": 469 469 // sum(A*V) = A*sum(V) is a scalar. 470 // prod(A*V) = A*prod(V) is a scalar. 470 471 return ScalarFold("", combineMul(left.multiplier, "{" ~ func ~ "(" ~ wrapInParens(left.expr) ~ ")}")); 471 472 case "abs": … … 483 484 case "max": 484 485 case "min": // max(A*B) can't be simplified unless we know that they are not negative. 485 return ScalarFold("", "{" ~ func ~ "(" ~ combineMulWithCompound(left.expr, left.multiplier) ~ ")}"); 486 // return ScalarFold("", "{" ~ func ~ "(@>" ~ left.expr ~ "@" ~ left.multiplier ~ "<@)}"); 486 return ScalarFold("", "{" ~ func ~ "(" ~ combineMulWithCompound(left.expr, left.multiplier) ~ ")}"); 487 // return ScalarFold("", "{" ~ func ~ "(@>" ~ left.expr ~ "@" ~ left.multiplier ~ "<@)}"); 487 488 default: 488 489 assert(0, "BLADE: Unsupported function"); 489 490 return ScalarFold("",""); 490 491 } 491 } 492 } 492 493 ReturnType onVisitPrefix(This this_, char [] op, char [] expr) { 493 494 if (op=="-") { … … 498 499 else return ScalarFold(left.expr, "-"); 499 500 } else if (op=="+") { // just ignore unary plus 500 return doVisit(this_, expr); 501 return doVisit(this_, expr); 501 502 } else { 502 503 ScalarFold f = doVisit(this_, expr); … … 530 531 assert(first.multiplier=="" && second.expr=="", "BLADE ICE"); 531 532 assert(second.multiplier!="-", "BLADE ICE"); // this would be a*=-b, where b is a vector 532 if (second.multiplier.length>1) return ScalarFold(wrapInParens(first.expr) ~ op ~ "{" ~ wrapInParens(second.multiplier) ~ "}",""); 533 else return ScalarFold(wrapInParens(first.expr) ~ op ~ wrapInParens(second.multiplier),""); 533 if (second.multiplier.length>1) return ScalarFold(wrapInParens(first.expr) ~ op ~ "{" ~ wrapInParens(second.multiplier) ~ "}",""); 534 else return ScalarFold(wrapInParens(first.expr) ~ op ~ wrapInParens(second.multiplier),""); 534 535 } 535 536 if (op=="*") { … … 588 589 assert(left.length>0); 589 590 if (right.length==0) return left; 590 if (right=="-") return "-" ~ wrapInParens(left); 591 if (right=="-") return "-" ~ wrapInParens(left); 591 592 if (right.length==1) return wrapInParens(left) ~ "*" ~ right; 592 return wrapInParens(left) ~ "*{" ~ wrapInParens(right) ~ "}"; 593 return wrapInParens(left) ~ "*{" ~ wrapInParens(right) ~ "}"; 593 594 } 594 595
