Changeset 157
- Timestamp:
- 12/06/07 12:47:16 (9 months ago)
- Files:
-
- trunk/blade/Blade.d (modified) (18 diffs)
- trunk/blade/BladeDemo.d (modified) (2 diffs)
- trunk/blade/BladeRank.d (modified) (2 diffs)
- trunk/blade/BladeSimplify.d (modified) (10 diffs)
- trunk/blade/SyntaxTree.d (modified) (1 diff)
Legend:
- Unmodified
- Added
- Removed
- Modified
- Copied
- Moved
trunk/blade/Blade.d
r155 r157 72 72 RevisedExpression revised = simplifySyntaxTree(tree); 73 73 if (revised.errorMessage.length>0) return `static assert(0, "BLADE: ` ~ enquote(revised.errorMessage) ~ `");`; 74 VecExpressionType exprType = categorizeExpression( tree,revised);74 VecExpressionType exprType = categorizeExpression(revised); 75 75 if (exprType == VecExpressionType.SSE2Expression || exprType == VecExpressionType.SSE1Expression) { 76 return invokeSSE((exprType == VecExpressionType.SSE2Expression), tree,revised);76 return invokeSSE((exprType == VecExpressionType.SSE2Expression), revised); 77 77 } else if (exprType == VecExpressionType.X87Expression) { 78 return invokeX87( tree,revised);78 return invokeX87(revised); 79 79 } else { 80 return DCodeGenerator( tree,revised);80 return DCodeGenerator(revised); 81 81 } 82 82 } … … 122 122 enum VecExpressionType { SSE1Expression, SSE2Expression, X87Expression, DExpression }; 123 123 124 VecExpressionType categorizeExpression( AbstractSyntaxTree tree, RevisedExpression revised)124 VecExpressionType categorizeExpression(RevisedExpression tree) 125 125 { 126 126 bool SSE2 = true; … … 137 137 int numscalars = 0; 138 138 int numRealScalars = 0; // scalars other than float or double. 139 for (int i=0; i< revised.mapping.length;++i) {140 char r = revised.rank[i];141 int x = revised.mapping[i]-'A';139 for (int i=0; i<tree.mapping.length;++i) { 140 char r = tree.rank[i]; 141 int x = tree.mapping[i]-'A'; 142 142 if (r=='0') { 143 143 ++numscalars; … … 155 155 int y = x; // for compounds, get the original type 156 156 if (x>=tree.symbolTable.length) { 157 y = revised.compounds[x-tree.symbolTable.length][0]-'A';157 y = tree.compounds[x-tree.symbolTable.length][0]-'A'; 158 158 // Check for a stride.. 159 if ( revised.compounds[x-tree.symbolTable.length][$-1]==']') {160 strided |= isStrided( revised.compounds[x-tree.symbolTable.length]);159 if (tree.compounds[x-tree.symbolTable.length][$-1]==']') { 160 strided |= isStrided(tree.compounds[x-tree.symbolTable.length]); 161 161 } 162 162 } … … 194 194 195 195 /// Generate code which will call the X87 function 196 char [] invokeX87( AbstractSyntaxTree tree, RevisedExpression revised)197 { 198 char [] result = assertAllVectorLengthsEqual(tree , revised);199 result ~= `X87VECGEN!("` ~ enquote( revised.expression) ~ `"`;196 char [] invokeX87(RevisedExpression tree) 197 { 198 char [] result = assertAllVectorLengthsEqual(tree); 199 result ~= `X87VECGEN!("` ~ enquote(tree.expression) ~ `"`; 200 200 201 201 char [] vals; 202 for (int i=0; i< revised.mapping.length;++i) {203 char rnk = revised.rank[i];202 for (int i=0; i<tree.mapping.length;++i) { 203 char rnk = tree.rank[i]; 204 204 vals ~= ","; 205 205 if (rnk=='1') vals ~= "&"; 206 vals ~= getValueForSymbol( revised.mapping[i], tree, revised);207 int x = revised.mapping[i]-'A';206 vals ~= getValueForSymbol(tree.mapping[i], tree); 207 int x = tree.mapping[i]-'A'; 208 208 char [] t; 209 209 if (x<tree.symbolTable.length) { … … 219 219 // or float, it could use less FPU stack space. 220 220 } else { // for arrays, the type is the type of the original array 221 t = tree.symbolTable[ revised.compounds[x-tree.symbolTable.length][0]-'A'].element;221 t = tree.symbolTable[tree.compounds[x-tree.symbolTable.length][0]-'A'].element; 222 222 } 223 223 } … … 239 239 } 240 240 result ~= ")("; 241 int firstVector = findVectorForLength(tree , revised);242 return result ~ getValueForSymbol( revised.mapping[firstVector], tree, revised) ~ ".length" ~ vals ~ ");";241 int firstVector = findVectorForLength(tree); 242 return result ~ getValueForSymbol(tree.mapping[firstVector], tree) ~ ".length" ~ vals ~ ");"; 243 243 } 244 244 245 245 /// Generate code which will call the SSE/SSE2 code generation function 246 char [] invokeSSE(bool SSE2, AbstractSyntaxTree tree, RevisedExpression revised)247 { 248 char [] result = assertAllVectorLengthsEqual(tree , revised);249 result ~= assertAllVectorsAlign128(tree , revised);246 char [] invokeSSE(bool SSE2, RevisedExpression tree) 247 { 248 char [] result = assertAllVectorLengthsEqual(tree); 249 result ~= assertAllVectorsAlign128(tree); 250 250 251 251 252 result ~= "SSEVECGEN!(" ~ (SSE2?"2":"1") ~ `,"` ~ enquote( revised.expression) ~ `"`;252 result ~= "SSEVECGEN!(" ~ (SSE2?"2":"1") ~ `,"` ~ enquote(tree.expression) ~ `"`; 253 253 // For SSE2, everything must be implicitly convertible to double. 254 254 char [] vals; 255 for (int i=0; i< revised.mapping.length;++i) {256 char rnk = revised.rank[i];255 for (int i=0; i<tree.mapping.length;++i) { 256 char rnk = tree.rank[i]; 257 257 if (rnk=='0') result ~= SSE2? ",double" : ",float"; 258 258 else result ~= SSE2? ",double*" : ",float*"; 259 259 vals ~= ","; 260 260 if (rnk=='1') vals ~= "&"; 261 vals ~= getValueForSymbol( revised.mapping[i], tree, revised);261 vals ~= getValueForSymbol(tree.mapping[i], tree); 262 262 // for vectors, we only need the pointer, not the length 263 263 // if (rnk=='1') vals ~= ".ptr"; … … 266 266 267 267 result ~= ")("; 268 int firstVector = findVectorForLength(tree , revised);269 result ~= getValueForSymbol( revised.mapping[firstVector], tree, revised) ~ ".length";268 int firstVector = findVectorForLength(tree); 269 result ~= getValueForSymbol(tree.mapping[firstVector], tree) ~ ".length"; 270 270 // result ~= tree.symbolTable[firstVector].value ~ ".length"; 271 271 result ~= vals; … … 277 277 * If possible, the error will be detected at compile time. 278 278 */ 279 char [] assertAllVectorLengthsEqual( AbstractSyntaxTree tree, RevisedExpression revised)279 char [] assertAllVectorLengthsEqual(RevisedExpression tree) 280 280 { 281 281 char [] result =""; 282 int firstVector = findVectorForLength(tree , revised);282 int firstVector = findVectorForLength(tree); 283 283 // bool known = arrayLengthIsStatic(tree.symbolTable[firstVector].type); 284 for (int i=0; i< revised.mapping.length;++i) {285 if ( revised.rank[i]=='1') {284 for (int i=0; i<tree.mapping.length;++i) { 285 if (tree.rank[i]=='1') { 286 286 if (firstVector != i) { 287 287 // if (known && arrayLengthIsStatic(tree.symbolTable[i].type)) { … … 291 291 // } 292 292 result ~= "assert(" 293 ~ getDimensionLengthForSymbol( revised.mapping[i], tree, revised, 1)294 ~ "==" ~ getDimensionLengthForSymbol( revised.mapping[firstVector], tree, revised, 1)293 ~ getDimensionLengthForSymbol(tree.mapping[i], tree, 1) 294 ~ "==" ~ getDimensionLengthForSymbol(tree.mapping[firstVector], tree, 1) 295 295 ~ ", `Vector length mismatch`);"\n; 296 // ~ ".length==" ~ getValueForSymbol(revised.mapping[firstVector], tree, revised)297 // ~ ".length, `Vector length mismatch`);"\n;298 296 } 299 297 } … … 302 300 } 303 301 304 char [] assertAllVectorsAlign128( AbstractSyntaxTree tree, RevisedExpression revised)302 char [] assertAllVectorsAlign128(RevisedExpression tree) 305 303 { 306 304 char [] result =""; 307 for (int i=0; i< revised.mapping.length;++i) {308 if ( revised.rank[i]=='1'){309 result ~= "assert( (cast(size_t)(&" ~ getValueForSymbol( revised.mapping[i], tree, revised)310 ~ "[0])& 0x0F) == 0, `SSE Vector misalignment: " ~ getValueForSymbol( revised.mapping[i], tree, revised) ~ "`);"\n;305 for (int i=0; i<tree.mapping.length;++i) { 306 if (tree.rank[i]=='1'){ 307 result ~= "assert( (cast(size_t)(&" ~ getValueForSymbol(tree.mapping[i], tree) 308 ~ "[0])& 0x0F) == 0, `SSE Vector misalignment: " ~ getValueForSymbol(tree.mapping[i], tree) ~ "`);"\n; 311 309 } 312 310 } … … 326 324 // If this is not possible, a normal dynamic array will be used. 327 325 // If all else fails, a sliced vector will be used. 328 int findVectorForLength( AbstractSyntaxTree tree, RevisedExpression revised)326 int findVectorForLength(RevisedExpression tree) 329 327 { 330 328 int dynamic = -1; // last dynamic vector 331 329 int strided = 0; // last unstrided vector 332 for (int i = 0; i < revised.mapping.length; ++i) {333 if ( revised.rank[i]!='1') continue;334 int x = revised.mapping[i]-'A';330 for (int i = 0; i < tree.mapping.length; ++i) { 331 if (tree.rank[i]!='1') continue; 332 int x = tree.mapping[i]-'A'; 335 333 strided = i; 336 334 if (x < tree.symbolTable.length) { … … 339 337 } else { 340 338 // Check for a stride. 341 if ( revised.compounds[x-tree.symbolTable.length][$-1]==']') {342 if (!isStrided( revised.compounds[x-tree.symbolTable.length])) {339 if (tree.compounds[x-tree.symbolTable.length][$-1]==']') { 340 if (!isStrided(tree.compounds[x-tree.symbolTable.length])) { 343 341 dynamic = i; 344 342 } … … 355 353 } 356 354 357 char [] getDimensionLengthForSymbol(char c, AbstractSyntaxTree tree, RevisedExpression revised, int dimension)355 char [] getDimensionLengthForSymbol(char c, RevisedExpression tree, int dimension) 358 356 { 359 357 int numSlicesRemaining = 1; … … 365 363 return v ~ ".length"; 366 364 } else { // else it's a compound or an indexed array 367 char [] comp = revised.compounds[c-'A'-tree.symbolTable.length];365 char [] comp = tree.compounds[c-'A'-tree.symbolTable.length]; 368 366 369 367 if (comp[$-1]!=']') { // simple compound expression … … 424 422 } 425 423 426 char [] getValueForSymbol(char c, AbstractSyntaxTree tree, RevisedExpression revised, char [] firstIndexExpr="")424 char [] getValueForSymbol(char c, RevisedExpression tree, char [] firstIndexExpr="") 427 425 { 428 426 int numSlicesRemaining=1; … … 433 431 v = tree.symbolTable[c-'A'].value; 434 432 } else { // else it's a compound or an indexed array 435 char [] comp = revised.compounds[c-'A'-tree.symbolTable.length];433 char [] comp = tree.compounds[c-'A'-tree.symbolTable.length]; 436 434 437 435 if (comp[$-1]!=']') { // simple compound expression … … 510 508 511 509 // Generate inline D code for the expression 512 char [] DCodeGenerator( AbstractSyntaxTree tree, RevisedExpression revised)513 { 514 int lenvec = findVectorForLength(tree , revised);515 char [] result = assertAllVectorLengthsEqual(tree , revised);510 char [] DCodeGenerator(RevisedExpression tree) 511 { 512 int lenvec = findVectorForLength(tree); 513 char [] result = assertAllVectorLengthsEqual(tree); 516 514 result ~= "for (int blade_index=0; blade_index<" 517 ~ getDimensionLengthForSymbol( revised.mapping[lenvec], tree, revised, 1) ~515 ~ getDimensionLengthForSymbol(tree.mapping[lenvec], tree, 1) ~ 518 516 "; ++blade_index) {"\n; 519 foreach (c; revised.expression) {517 foreach (c; tree.expression) { 520 518 if (c>='A' && c<'Z') { 521 519 // restore all symbols into the expression 522 520 // If it's a vector, index it 523 if ( revised.rank[c-'A']=='1')524 result ~= getValueForSymbol( revised.mapping[c-'A'], tree, revised, "blade_index");525 else result ~= getValueForSymbol( revised.mapping[c-'A'], tree, revised);521 if (tree.rank[c-'A']=='1') 522 result ~= getValueForSymbol(tree.mapping[c-'A'], tree, "blade_index"); 523 else result ~= getValueForSymbol(tree.mapping[c-'A'], tree); 526 524 } else result ~= c; 527 525 } trunk/blade/BladeDemo.d
r156 r157 32 32 double [4][] another = [[33.1, 4543, 43, 878.7], [5.14, 455, 554, 2.43]]; 33 33 real k=3.4; 34 34 35 35 mixin(vectorize(` a += (d[2..$-1]*2.01*a[2]-another[][1])["abc".length-3..$]`)); 36 36 … … 45 45 mixin(vectorize("another[0..$,1]+=6*a[0..2]")); 46 46 mixin(vectorize("r-=another[0]")); 47 47 48 48 // Parses OK, but I don't think I'll support this. 49 49 // mixin(vectorize("a+=6*another[1,[1,$]]")); 50 50 51 51 52 // Parses, and rank checks OK. Doesn't simplify yet, no codegen.53 // mixin(vectorize("dot(q,q*dot(q,q))")); // BUG: should simplify to: dot(q.q) * dot(q,q) 52 // Parses, and simplifies to A*A, where A = dot(q,q). No codegen yet. 53 // mixin(vectorize("dot(q,q*dot(q,q))")); 54 54 55 55 writefln("a=", a); trunk/blade/BladeRank.d
r156 r157 51 51 * The sub-expression must be 52 52 * - a single character (eg "X"), OR 53 * - a lower-case function (eg "a(B,(C*D))"), OR 53 54 * - an expression in parenthesis, OR 54 55 * - an array literal … … 87 88 return rank[expr[0]-'A']-'0'; 88 89 } 90 if (expr[0]=='d') return 0; 89 91 assert(expr[0]=='(', "BLADE ICE:" ~ expr); 90 92 // strip off the parentheses trunk/blade/BladeSimplify.d
r156 r157 37 37 char [] expression; // the revised expression using new variable names 38 38 // (so, for example, B+=(D-F) becomes A+=(B-C) ). 39 Symbol[] symbolTable; // the original symbol table, with all the old names 39 40 char [][] compounds; // the compound variables, defined using the old names 40 41 char [] rank; // rank of all symbols (including original & compounds) … … 59 60 // Check for undefined symbols 60 61 if (err.length > 0) 61 return RevisedExpression(tree.expression, [""], "","", "Undefined symbols:" ~ err);62 return RevisedExpression(tree.expression, tree.symbolTable, [""], "","", "Undefined symbols:" ~ err); 62 63 else { 63 64 char [] expr2 = removeDuplicates(tree); … … 65 66 int wholerank = exprRank(expr2, ranks); 66 67 if (wholerank<0) 67 return RevisedExpression(expr2, [""], "","", getRankErrorText(wholerank));68 return simplifyVectorExpression(expr2, ranks );68 return RevisedExpression(expr2, tree.symbolTable, [""], "","", getRankErrorText(wholerank)); 69 return simplifyVectorExpression(expr2, ranks, tree.symbolTable); 69 70 } 70 71 } … … 179 180 char [] leftMul = ""; 180 181 char [] rightMul = ""; 181 if (leftrnk == 0 && rightrnk == 0) return expr;182 if (leftrnk == 0 && rightrnk == 0) { return expr; } 182 183 if (leftrnk == 0) leftMul = left; else leftMul = getCommonMultiplucation(left, rank); 183 184 if (rightrnk== 0) rightMul = right; else rightMul = getCommonMultiplucation(right, rank); … … 228 229 } 229 230 231 230 232 /** 231 233 * Rewrite the expression, taking advantage of distributivity of [] and … … 247 249 leftmul = getCommonMultiplucation(left, rank); 248 250 rightmul = getCommonMultiplucation(right, rank); 249 250 // assert(0, leftmul~"#" ~ rightmul);251 // if (hasScalarMultiply(left, rank)) {252 // pull the scalar mul out253 // }254 // if (hasScalarMultiply(right, rank)) {255 // } // ditto for right.256 251 char [] m = leftmul; 257 252 if (rightmul.length>0) m = m==""? rightmul : "(" ~ m ~ "*" ~rightmul~")"; 258 253 if (mulexpr.length>0) m = m=="" ? mulexpr : "(" ~ m ~ "*" ~mulexpr~")"; 259 if (m.length>1) m= "* {" ~ m ~ "} ";254 if (m.length>1) m= "* {" ~ m ~ "} "; 260 255 else if (m.length==1) m= "*" ~ m; 261 256 assert(indexexpr.length==0, "BLADE ICE: rank mismatch in dot product"); 262 //return "#" ~ subexprSimplify("A,B","01", "",""); 263 // assert(0, expr ~ "#" ~ left ~ "#" ~ right ~"#" ~ rank ~ "#");// ~ subexprSimplify(right, rank, mulexpr,"")~"#"); 264 return "d(" ~ simplifyWithoutMul(left, rank) ~ "," ~ 265 simplifyWithoutMul(right, rank) ~ ")" ~ m; 257 // assert(0, expr ~ "#" ~ left ~ "#" ~ right ~"#" ~ leftmul ~ "#"~ rightmul);// ~ subexprSimplify(right, rank, mulexpr,"")~"#"); 258 return " {d(" ~ simplifyWithoutMul(left, rank) ~ "," ~ 259 simplifyWithoutMul(right, rank) ~ ")} " ~ m; 266 260 } 267 261 // Deal with ++ and --. Only for scalars … … 340 334 } 341 335 342 RevisedExpression simplifyVectorExpression(char [] expr, char [] rank )336 RevisedExpression simplifyVectorExpression(char [] expr, char [] rank, Symbol[] symTable=[]) 343 337 { 344 338 char [] s = exprSimplify(expr, rank, "", ""); 345 if (s.length>1 ) s = s[1..$-1]; // strip off ()339 if (s.length>1 && s[0]=='(') s = s[1..$-1]; // strip off () 346 340 char [][] comp; 347 341 char [] used=""; // which of the old variables are used; gives the new mapping … … 355 349 int k; 356 350 for (k=i+1; s[k]!=' '; ++k) {} 357 comp ~= s[i+2..k-1]; 358 if (s[k-2]==']') { 359 // it's a vector/matrix of some kind, with rank reduced 360 // by indices. Can't just use exprRank, because the [] 361 // aren't wrapped by (). 362 r ~= (rank[s[i+2]-'A'] - indexRank(s[i+2..k-1])); 363 } else { 364 // it's a scalar expression. Note that it could involve 365 // a vector expression. 366 r~='0'; 367 } 368 e ~= next; 369 ++next; 351 char [] newexpr = s[i+2..k-1]; // strip off the {} 352 // Check for a duplicate 353 int z; 354 for (z=0; z<comp.length && comp[z]!=newexpr; ++z) {} 355 if (z==comp.length) { 356 e ~= next; 357 ++next; 358 comp ~= s[i+2..k-1]; // strip off the {} 359 if (s[k-2]==']') { 360 // it's a vector/matrix of some kind, with rank reduced 361 // by indices. Can't just use exprRank, because the [] 362 // aren't wrapped by (). 363 r ~= (rank[s[i+2]-'A'] - indexRank(s[i+2..k-1])); 364 } else { 365 // it's a scalar expression. Note that it could involve 366 // a vector expression. 367 r~='0'; 368 } 369 } else e ~= cast(char)('A'+z+rank.length); 370 370 i = k; 371 371 } else { … … 399 399 } else f ~= c; 400 400 } 401 return RevisedExpression(f, comp, old_ranks~r, mapping);401 return RevisedExpression(f, symTable, comp, old_ranks~r, mapping); 402 402 } 403 403 404 404 unittest { 405 // assert(0, exprSimplify("d(A,A*d(A,A))", "1", "","")); // == "d(A,A)*d(A,A)");405 assert(exprSimplify("d(A,(A*d(A,A)))", "1", "","")== " {d(A,A)} * {d(A,A)} "); 406 406 assert(exprSimplify("A+=(B*(C[D,D..$]))","1020","","")=="(A+=(B* {C[D,D..$]} ))"); 407 407 assert(exprSimplify("A+=(((D[E])*B)[E])", "103300","","")=="(A+=(B* {D[E][E]} ))"); … … 411 411 assert(exprSimplify("A=(((B*E)+(C*E))*D)", "11100","","")=="(A=(( {(D*E)} *B)+( {(D*E)} *C)))"); 412 412 assert(exprSimplify("A=(D*((B*E)+(C*E)))", "11100","","")=="(A=(( {(D*E)} *B)+( {(D*E)} *C)))"); 413 assert(exprSimplify("d((A*(B*C)),(B*A))","010","","")== " d(B,B)* {((A*C)*A)}");413 assert(exprSimplify("d((A*(B*C)),(B*A))","010","","")== " {d(B,B)} * {((A*C)*A)} "); 414 414 415 415 RevisedExpression e = simplifyVectorExpression("A+=(((D[B])*C)[B])", "2004"); trunk/blade/SyntaxTree.d
r151 r157 82 82 char [] expression; /// syntax tree in Placeholder format, eg A+=(B*C) 83 83 Symbol[] symbolTable; /// Textual form of the types and values of A,B,C,... 84 }85 86 struct TemplateSyntaxTree(T...) {87 AbstractSyntaxTree tree;88 84 } 89 85
