| 1 |
// Written in the D programming language 1.0 |
|---|
| 2 |
/** |
|---|
| 3 |
* Simplify a vector expression. |
|---|
| 4 |
* Part of BLADE : Basic Linear Algebra D Expressions |
|---|
| 5 |
* |
|---|
| 6 |
* The following transformations are performed: |
|---|
| 7 |
* (A) Remove all duplicate symbols. |
|---|
| 8 |
* (B) Create 'compound' variables, consisting of multiple symbols: |
|---|
| 9 |
* - Combine all scalars into a single scalar. |
|---|
| 10 |
* - Rank reduction: Combine slicing operations into vectors |
|---|
| 11 |
* (C) Arithmetic transformations. |
|---|
| 12 |
* - Use slicing distributive law: Given A[B..C] for expressions A,B,C |
|---|
| 13 |
* where B and C are both rank 0, and A is rank 1 or more, the slice can |
|---|
| 14 |
* be moved to every vector inside A. |
|---|
| 15 |
* - Use associativity of *: A*(B*C[]) == (A*B)*C[] (Not strictly true for |
|---|
| 16 |
* floating point; results may differ by 1ulp, |
|---|
| 17 |
* eg (1.3L*3.1L)*4.7L < 1.3L*(3.1L*4.7L) |
|---|
| 18 |
* Note that floating point addition is not associative at all). |
|---|
| 19 |
* - Remove unary minus where possible, eg A-(-B) => A+B, abs(-A) => abs(A). |
|---|
| 20 |
* - Use associativity of * in intrinsics: |
|---|
| 21 |
* sum(A*V) => A*sum(V), abs(A*B) => abs(A)*abs(B) |
|---|
| 22 |
* (D) Expression standardisation |
|---|
| 23 |
* - Move multiplies to left: Convert A[]*B into B*A[] (assumes * is commutative, |
|---|
| 24 |
* not valid for quaternions). |
|---|
| 25 |
* - Convert C-A*B into C+(-A)*B whenever possible. |
|---|
| 26 |
* |
|---|
| 27 |
* Author: |
|---|
| 28 |
* Don Clugston. |
|---|
| 29 |
* License: |
|---|
| 30 |
* Public domain. |
|---|
| 31 |
*/ |
|---|
| 32 |
|
|---|
| 33 |
module blade.BladeSimplify; |
|---|
| 34 |
|
|---|
| 35 |
public import blade.SyntaxTree : AbstractSyntaxTree, Symbol; |
|---|
| 36 |
private import blade.BladeVisitor; |
|---|
| 37 |
private import blade.BladeRank : exprRank, subexprRank, getRankErrorText; |
|---|
| 38 |
|
|---|
| 39 |
|
|---|
| 40 |
/// A simplified vector expression
///
/// Built by simplifySyntaxTree and remapCompounds in this module. Those
/// functions fill the fields with positional struct literals, so the field
/// order below must not change. Fields omitted from such a literal
/// (errorMessage, on the success path) keep their default value, null.
/// Ranks are stored as one ASCII digit per symbol, e.g. "202".
struct RevisedExpression {
    char [] braceExpression; // the expression with compounds in braces
    char [] expression;      // the revised expression using new variable names
                             // (so, for example, B+=(D-F) becomes A+=(B-C) ).
    Symbol[] symbolTable;    // the original symbol table, with all the old names
    char [][] compounds;     // the compound variables, defined using the old names
    char [] rank;            // rank of all symbols (including original & compounds)
    char [] mapping;         // mapping from new names onto old names.
    char [] errorMessage;    // or null if no errors.
}
|---|
| 51 |
|
|---|
| 52 |
// Reports whether str is the name of one of the BLADE intrinsic functions
// (dot, sum, prod, min, max, abs, sqrt). The comparison is exact and
// case-sensitive.
// TODO: ".re", ".im" should join this list.
bool isBladeIntrinsic(char [] str)
{
    // Reductions.
    if (str == "dot" || str == "sum" || str == "prod") return true;
    // Extrema.
    if (str == "min" || str == "max") return true;
    // Element-wise functions.
    return str == "abs" || str == "sqrt";
}
|---|
| 58 |
|
|---|
| 59 |
|
|---|
| 60 |
// Turns a parsed BLADE expression into a RevisedExpression.
//
// Two kinds of failure are reported through errorMessage rather than thrown:
//  * symbols that have no type and are not intrinsics ("Undefined symbols:"),
//  * a negative rank from exprRank (text from getRankErrorText).
// On success the expression has duplicate symbols removed, its indexing and
// scalars folded, and its compounds remapped to new variable names.
RevisedExpression simplifySyntaxTree(AbstractSyntaxTree tree)
{
    char [] symbolRanks;
    char [] missing = "";
    for (int n = 0; n < tree.symbolTable.length; ++n) {
        symbolRanks ~= tree.symbolTable[n].rank;
        bool isUntyped = (tree.symbolTable[n].type == "");
        if (isUntyped && !isBladeIntrinsic(tree.symbolTable[n].value)) {
            missing ~= " " ~ tree.symbolTable[n].value;
        }
    }

    // An untyped symbol that is not an intrinsic was never defined.
    if (missing.length != 0) {
        return RevisedExpression(tree.expression, "", tree.symbolTable, [""], "", "",
            "Undefined symbols:" ~ missing);
    }

    // Collapse duplicated symbols and write intrinsic names in-line.
    char [] deduped = removeDuplicates(tree);

    // exprRank signals a rank error with a negative result.
    int overallRank = exprRank(deduped, symbolRanks);
    if (overallRank < 0) {
        return RevisedExpression(deduped, "", tree.symbolTable, [""], "", "",
            getRankErrorText(overallRank));
    }

    // Index folding (dimension reduction) first, then scalar folding.
    char [] folded = foldScalars(foldIndices(deduped, symbolRanks), symbolRanks);
    return remapCompounds(folded, symbolRanks, tree.symbolTable);
}
|---|
| 84 |
|
|---|
| 85 |
private:

/* Rewrite the expression so that no two letters refer to symbol-table
 * entries with the same value. Duplicates can arise when aliases or
 * constants are resolved. A later duplicate is renamed to the letter of the
 * first entry holding that value. A symbol whose value names a BLADE
 * intrinsic is replaced in-line by that name.
 * The original expression is returned untouched when no symbol needed
 * renaming or inlining.
 */
char [] removeDuplicates(AbstractSyntaxTree tree)
{
    // letters[k] is the letter symbol k should become, or '!' when the symbol
    // is an intrinsic whose name is written straight into the expression.
    char [] letters = "";
    bool changed = false;
    for (int k = 0; k < tree.symbolTable.length; ++k) {
        char target = 'A' + k;
        if (isBladeIntrinsic(tree.symbolTable[k].value)) {
            target = '!';
            changed = true;
        } else {
            for (int prior = 0; prior < k; ++prior) {
                if (tree.symbolTable[prior].value == tree.symbolTable[k].value) {
                    target = 'A' + prior;
                    changed = true;
                    break;
                }
            }
        }
        letters ~= target;
    }
    if (!changed) {
        return tree.expression;
    }

    char [] result = "";
    for (int p = 0; p < tree.expression.length; ++p) {
        char ch = tree.expression[p];
        if (ch < 'A' || ch > 'Z') {
            result ~= ch;
        } else if (letters[ch - 'A'] == '!') {
            result ~= tree.symbolTable[ch - 'A'].value;
        } else {
            result ~= letters[ch - 'A'];
        }
    }
    return result;
}

unittest {
    // "125" appears twice, so the third symbol (C) is folded into the first (A).
    AbstractSyntaxTree t = AbstractSyntaxTree("A+(B*C)", [Symbol("int", "125", 0),
        Symbol("int", "7", 0), Symbol("int", "125", 0)]);
    assert(removeDuplicates(t)=="A+(B*A)");
}
|---|
| 129 |
|
|---|
| 130 |
// Determine rank of a multidimensional index
//
// Returns how many dimensions the top-level [...] part of s removes from its
// base symbol:
//  * each top-level '[' and each comma inside it counts one dimension
//    (empty entries count too);
//  * each ".." inside it (a slice) cancels one.
// s[0] is skipped as the base symbol. Text inside parentheses or nested
// brackets is ignored.
// Examples: "D[B,B]" -> 2, "A[C..D]" -> 0, "A[C..D,]" -> 1.
int indexRank(char [] s)
{
    int reduced = 0;
    int depth = 0;    // top-level bracket nesting
    int nesting = 0;  // parenthesis nesting
    int pos = 1;
    while (pos < s.length) {
        char ch = s[pos];
        if (ch == '(') {
            ++nesting;
        } else if (ch == ')') {
            --nesting;
        }
        if (nesting == 0) {
            if (ch == ']') {
                --depth;
            } else if (ch == '[') {
                if (depth == 0) ++reduced;
                ++depth;
            } else if (depth == 1 && ch == ',') {
                ++reduced;
            } else if (depth == 1 && ch == '.' && s[pos - 1] == '.') {
                // a slice keeps its dimension
                --reduced;
            }
        }
        ++pos;
    }
    return reduced;
}
|---|
| 152 |
|
|---|
| 153 |
public:
// Remove everything inside {} from expr, and create new variables for it.
//
// Each distinct {...} group becomes a compound variable. Compounds are first
// given letters after all of the original symbols, and identical groups share
// one compound. Then the original symbols that still appear outside braces
// are renamed to consecutive letters from 'A', and the compounds take the
// letters that follow them.
//
// Params:
//   expr     = expression in which compounds are delimited by {}.
//   rank     = one rank digit per original symbol (symbol 'A'+i has rank[i]).
//   symTable = original symbol table; stored unchanged in the result.
// Returns:
//   A RevisedExpression with these fields:
//     braceExpression = expr
//     expression      = the renamed expression
//     compounds       = brace contents, written with the old names
//     rank            = kept symbols' ranks, then the compound ranks
//     mapping         = for each new name, the old name it stands for
//   errorMessage is left null.
RevisedExpression remapCompounds(char [] expr, char [] rank, Symbol[] symTable)
{
    char [][] comp;    // distinct compound bodies, in order of first appearance
    char [] used = ""; // which of the old variables are used; gives the new mapping
                       // of the name, or - if not used.
    for (int i=0; i<rank.length; ++i) used ~= "-";
    char [] r;         // one rank digit per entry of comp
    // Compounds are provisionally named with the letters after the old symbols.
    char next = cast(char)('A' + rank.length);
    char [] e = "";    // expr with each {...} replaced by its compound letter
    for (int i=0; i<expr.length; ++i) {
        if (expr[i]=='{') {
            // Find the matching '}' (braces may nest). After --k, k indexes it.
            int k;
            int bracecount=1;
            for (k=i+1; bracecount>0; ++k) {
                if (expr[k]=='{') ++bracecount;
                if (expr[k]=='}') --bracecount;
            }
            --k;
            char [] newexpr = expr[i+1..k]; // strip off the {}
            int newi = k;
            // "({...})": the surrounding parentheses become redundant once the
            // compound is a single letter, so drop them too.
            if (i>0 && k<expr.length-1 && expr[i-1]=='(' && expr[k+1]==')') {
                e = e[0..$-1]; // remove the last '('
                newi=k+1; // don't add ')'
            }
            // Check for a duplicate
            int z;
            for (z=0; z<comp.length && comp[z]!=newexpr; ++z) {}
            if (z==comp.length) {
                // First occurrence: allocate the next compound letter.
                e ~= next;
                ++next;
                comp ~= expr[i+1..k]; // strip off the {}
                if (expr[k-1]==']') {
                    // it's a vector/matrix of some kind, with rank reduced
                    // by indices. Can't just use exprRank, because the []
                    // aren't wrapped by ().
                    // NOTE(review): assumes the compound body starts with its
                    // base symbol letter (expr[i+1]); confirm for all producers.
                    r ~= (rank[expr[i+1]-'A'] - indexRank(expr[i+1..k]));
                } else {
                    // it's a scalar expression. Note that it could involve
                    // a vector expression.
                    r~='0';
                }
            } else e ~= cast(char)('A'+z+rank.length); // reuse the existing compound's letter
            // Continue after the '}' (or after the dropped ')').
            i = newi;
        } else {
            e ~= expr[i];
            // Only symbols that appear outside every compound are marked used.
            if (expr[i]>='A' && expr[i]<='Z') used[expr[i]-'A']=expr[i];
        }
    }
    // Create a mapping from old to new variable names

    char [] old_ranks = "";
    char [] mapping="";
    char knt = 'A';    // next new letter to hand out
    for (int i=0; i<used.length; ++i) {
        if (used[i]!='-') {
            mapping ~= ('A'+i);
            old_ranks ~= rank[i];
            used[i] = knt;   // used[] now maps old letter -> new letter
            ++knt;
        }
    }
    // Compounds map onto their provisional (old-namespace) letters.
    for (int i=0; i<r.length; ++i) {
        mapping ~= ('A'+used.length+i);
    }
    // and set the expression to use the new names
    // NOTE(review): letters above 'Z' (more than 26 symbols+compounds) are not
    // renamed by the c<='Z' test below; presumably never reached -- confirm.
    char [] f = "";
    for (int i=0; i<e.length; ++i) {
        char c = e[i];
        if (c>='A' && c<='Z') {
            if ((c-'A')<rank.length) f ~= used[c-'A'];
            else f~= (c-'A') - rank.length + knt;   // compounds follow the kept symbols
        } else f ~= c;
    }
    return RevisedExpression(expr, f, symTable, comp, old_ranks~r, mapping);
}
|---|
| 230 |
|
|---|
| 231 |
private:
// Runs the simplification pipeline on an expression whose symbol ranks are
// already known: fold indexing, then fold scalars, then remap compounds.
RevisedExpression simplifyVectorExpression(char [] expr, char [] rank, Symbol[] symTable)
{
    char [] indexFolded = foldIndices(expr, rank);
    char [] scalarFolded = foldScalars(indexFolded, rank);
    return remapCompounds(scalarFolded, rank, symTable);
}

unittest {
    // D is rank 4 and is indexed twice by the scalar B, so the compound
    // D[B,B] has rank 2. Only A and C survive outside the compound.
    RevisedExpression e = simplifyVectorExpression("A+=(((D[B])*C)[B])", "2004",[]);
    assert(e.rank=="202");
    assert(e.compounds[0]=="D[B,B]");
    assert(e.mapping=="ACE");
    assert(e.expression== "A+=(C*B)");
}
|---|
| 244 |
|
|---|
| 245 |
// ----------------------------------------------------------- |
|---|
| 246 |
|
|---|
| 247 |
// Renders a list of per-dimension indices/slices as one bracketed suffix.
// Each entry is:
//   [0] = index, or lower bound of a slice
//   [1] = upper bound of a slice, or "" for a plain index.
// Entries are separated by ',' and a slice is written "lower..upper".
// An empty list gives "[]".
char [] createMultiSlice(char [][2][] slicing)
{
    char [] inner = "";
    for (int d = 0; d < slicing.length; ++d) {
        char [] lower = slicing[d][0];
        char [] upper = slicing[d][1];
        if (d != 0) {
            inner ~= ",";
        }
        if (upper.length > 0) {
            inner ~= lower ~ ".." ~ upper;
        } else {
            inner ~= lower;
        }
    }
    return "[" ~ inner ~ "]";
}
|---|
| 261 |
|
|---|
| 262 |
// Combines all the indexing and slicing operations together (dimension reduction).
// Multiplication of sliced matrices and/or vectors is dimensionally
// reduced where possible (may even be converted into dot product).
// Returns the new expression. This eliminates all unnecessary slice operations.
// Furthermore, *any* value followed by '[' should be used as a new compound.
//
// Visitor state:
//   rank    - one rank digit per symbol letter (symbol 'A'+i has rank[i]).
//   dollar  - text written in place of '$' while no indexing is pending.
//   slicing - indexing still to be applied to the expression being visited.
//             Each entry is [index-or-lower-bound, upper-bound or ""].
// All callbacks are static and take the state as this_. The state is passed
// by value, so changes made inside a callback stay local to it.
// NOTE(review): callbacks are presumably dispatched by doVisit/beginVisit
// from blade.BladeVisitor, which is not shown here -- confirm.
struct IndexFoldingVisitor {
    alias typeof(*this) This;
    alias char [] ReturnType;
    char [] rank;
    char [] dollar;
    char [][2][] slicing; // the indexing which applies to this complete expr.
static:
    // Leaf: with no pending indexing, replace '$' by the current dollar text.
    // Otherwise attach all pending indexing to the (non-scalar) symbol.
    ReturnType onVisitSymbol(This this_, char[] sym) {
        if (this_.slicing.length==0) {
            if (sym=="$") {
                return this_.dollar;
            }
            else return sym;
        } else {
            assert(sym!="$" && this_.rank[sym[0]-'A']>'0', "Rank error " ~ sym);
            // Note: Later, we'll want this to be a new terminal.
            return sym ~ createMultiSlice(this_.slicing);
        }
    }
    // Intrinsic call: the pending indexing is pushed unchanged into every
    // argument, e.g. f(X)[i] becomes f(X[i]).
    ReturnType onVisitFunction(This this_, char [] func, char [][] args) {
        // Intrinsics have no effect on slicing
        char [] result = "";
        for (int i=0; i<args.length; ++i) {
            if (i>0) result ~= ",";
            result ~= wrapInParens(doVisit(this_,args[i]));
        }
        return func ~ "(" ~ result ~ ")";
    }
    // Prefix/postfix operators: indexing is never expected to be pending here.
    ReturnType onVisitPrefix(This this_, char [] op, char [] expr) {
        assert(this_.slicing.length==0, "BLADE ICE");
        return op ~ wrapInParens(doVisit(this_, expr));
    }
    ReturnType onVisitPostfix(This this_, char [] op, char [] expr) {
        assert(this_.slicing.length==0, "BLADE ICE");
        return wrapInParens(doVisit(this_, expr)) ~ op;
    }
    // Includes multi-dimensional slicing and indexing.
    // base[slices] is visited with slices prepended to the pending indexing.
    // When both an inner slice and pending outer indexing exist, they are
    // merged arithmetically instead.
    ReturnType onVisitIndex(This this_, char [] base, char [][2][] slices) {
        if (slices.length==0) { // [] -- has no effect.
            return doVisit(this_, base);
        }
        // printf(" %.*s --> %.*s %.*s\n", base, slices[0][0], slices[0][1]);
        if (this_.slicing.length >0 && slices[$-1][1].length>0) {
            // the new dimension block ends with a slice. This needs to be combined
            // with the earliest existing dimension.
            // * If the existing dimension is an index,
            //    it might contain a dollar, which we need to replace.
            // * If the existing dimension is a slice, the two slices will combine.
            //
            // The items inside the slice are top-level, ie have no slice or dollar.
            char [] a = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank,"$",[]), slices[$-1][0]));
            char [] b = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank,"$",[]), slices[$-1][1]));
            // b-a is the length of the inner slice, i.e. what '$' means inside
            // the outer index.
            char [] dollr = b ~ "-" ~ a;
            char [] c = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank,dollr,[]), this_.slicing[0][0]));
            char [][2][] newslice=[];
            if (this_.slicing[0][1].length>0) { // slicing a slice
                if (b=="$" && this_.slicing[0][1]=="$") {
                    // very common special case, where both are sliced from end
                    newslice ~= ["(" ~ a ~ "+" ~ c ~")","$"];
                } else {
                    char [] d = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank,dollr,[]), this_.slicing[0][1]));
                    newslice ~= ["(" ~ a ~ "+" ~ c~ ")", "(" ~ a ~ "+" ~ d~ ")"];
                }
            } else {
                // indexing a slice: offset the index by the slice start.
                newslice ~= [a ~ "+" ~ c, ""];
            }
            if (slices.length>1) {
                // append other slices, if any.
                return doVisit(IndexFoldingVisitor(this_.rank, "$", slices[0..$-1] ~ newslice ~ this_.slicing[1..$]), base);
            } else {
                return doVisit(IndexFoldingVisitor(this_.rank, "$",newslice ~ this_.slicing[1..$]), base);
            }
        } else { // just append them.
            return doVisit(IndexFoldingVisitor(this_.rank, "$", slices ~ this_.slicing), base);
        }
    }
    // Binary operators: for '*' / '*=' with pending indexing, each dimension
    // goes to the operand it belongs to, possibly turning the result into
    // dot(...). For every other case, both operands receive the full pending
    // indexing.
    ReturnType onVisitBinaryOp(This this_, char [] op, char [] left, char [] right) {
        int lrank = subexprRank(left, this_.rank);
        int rrank = subexprRank(right, this_.rank);
        char [] first="";
        char [] second="";
        if ((op=="*" || op=="*=") && this_.slicing.length>0) {
            // If one of these is a matrix, the slicing gets interesting...
            // .. extremely so for slicing of matrix chain multiplication.
            if (lrank==0) {
                // All dimensions apply to right operand
                first = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), left));
                second = wrapInParens(doVisit(this_, right));
            } else if (rrank==0) {
                // All dimensions apply to the left operand
                first = wrapInParens(doVisit(this_, left));
                second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right));
            } else {
                assert(lrank>0 && rrank>0 && lrank<=2 && rrank<=2, "BLADE ICE: Tensor*tensor is unsupported");
                bool isDotProduct = false; // was it reduced to a dot product?

                // In the case of chained matrix multiplies, we can end up with an empty slice.
                // (Only the local copy of this_ is modified.)
                if (this_.slicing.length>0 && this_.slicing[$-1][0]=="") {
                    this_.slicing=this_.slicing[0..$-1];
                }
                if (lrank==2) {
                    // First dimension applies to rows of the left operand
                    // If it's a slice, it will be a strided slice -- unless
                    // it comes from another matrix multiply, in which case the
                    // stride will drop out. (A[x]*B is strided).
                    // The appended empty entry leaves the column dimension open.
                    char [][2][] newslice=[];
                    newslice ~= this_.slicing[0];
                    newslice ~= ["",""];
                    first = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,newslice), left));
                } else {
                    assert(this_.slicing.length==1, "BLADE ICE: Rank error");
                    first = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), left));
                }
                if (lrank==2 && rrank==2) {
                    // Matrix * matrix
                    if (this_.slicing.length>1) {
                        // Second dimension applies to the right operand.
                        second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,this_.slicing[1..$]), right));
                        if (this_.slicing[0][1].length==0 && this_.slicing[1][1].length==0) {
                            // It's indices in both cases -- so it's a dot product.
                            isDotProduct = true;
                        }
                    } else {
                        second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right));
                    }
                } else if (lrank==1 && rrank==2) {
                    // vector * matrix, Matrix uses all the slicing
                    second = wrapInParens(doVisit(this_, right));
                    if (this_.slicing[0][1].length==0) isDotProduct = true;
                } else if (lrank==2 && rrank==1) {
                    // matrix * vector, vector is unsliced.
                    second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right));
                    if (this_.slicing[0][1].length==0) isDotProduct = true;
                // NOTE(review): lrank==1 && rrank==1 reaches the assert below;
                // presumably rejected earlier by rank checking -- confirm.
                } else assert(0, "BLADE ICE");
                if (isDotProduct) {
                    return "dot(" ~ first ~ "," ~ second ~ ")";
                }
            }
        } else { // not a multiplication, or no pending indexing to distribute
            return wrapInParens(doVisit(this_, left)) ~ op ~ wrapInParens(doVisit(this_, right));
        }
        return first ~ op ~ second;
    }
}
|---|
| 411 |
|
|---|
| 412 |
// Applies IndexFoldingVisitor to expr: merges nested indexing and slicing
// into single multi-dimensional index expressions, pushes indexing through
// operators onto the operands, and turns fully indexed products into dot().
// ranks gives one rank digit per symbol letter. '$' is left as "$" at the
// top level.
char [] foldIndices(char [] expr, char [] ranks)
{
    return beginVisit(IndexFoldingVisitor(ranks,"$",[]), expr);
}

unittest {
    assert(foldIndices("((A[C..D])+B)[($-E)]", "21000")=="(A[C+((D-C)-E)])+(B[($-E)])");
    assert(foldIndices("(A[C])[D]", "3100")=="A[C,D]");
    assert(foldIndices("(A[B..C])[D]", "3000")=="A[B+D]");
    assert(foldIndices("(A[B])[C..D]", "3000")=="A[B,C..D]");
    assert(foldIndices("((A[$])[(C-$)])[D]", "3000")=="A[$,(C-$),D]");
    assert(foldIndices("(A[B..$])[C..$]", "3000")=="A[(B+C)..$]");
    assert(foldIndices("((A[])[C..$])[]", "3000")=="A[C..$]");
    assert(foldIndices("((A[])[(B[C])..$])[]", "3100")=="A[(B[C])..$]");
    assert(foldIndices("A[,B..$,C]", "300")=="A[,B..$,C]");
    // Multidimensional slicing
    assert(foldIndices("(C*((A*B)[C]))[D]", "2200")=="C*dot((A[C,]),(B[D]))");
    assert(foldIndices("(A*B)[C..D,D]", "2200")=="(A[C..D,])*(B[D])");
    assert(foldIndices("(A*B)[C..D]", "2200")=="(A[C..D,])*B");
    assert(foldIndices("(A*B)[C..D]", "2100")=="(A[C..D,])*B");
    assert(foldIndices("(A*B)[C..D]", "1200")=="A*(B[C..D])");
    assert(foldIndices("(A*B)[C]", "120")=="dot(A,(B[C]))");

    // Chained products and sums
    assert(foldIndices("((A*B)*C)[D]", "2220")=="((A[D,])*B)*C");
    assert(foldIndices("((A+B)*C)[D]", "2220")=="((A[D,])+(B[D,]))*C");
    assert(foldIndices("((D*A)*B)[C]", "2100")=="dot((D*(A[C,])),B)");
    assert(foldIndices("(((A*B)*C)[D..E])[D]", "12200")=="dot((A*B),(C[D+D]))");
    assert(foldIndices("A+=(((D[B])*C)[B])", "2004")=="A+=((D[B,B])*C)");
    assert(foldIndices("dot(A,(A*dot(A,A)))","1")=="dot(A,(A*dot(A,A)))");
}
|---|
| 442 |
|
|---|
| 443 |
// Pairs a vector/matrix expression with the scalar factor applied to it.
// NOTE(review): presumably the intermediate result of foldScalars, which is
// not visible in this chunk -- confirm before relying on this.
struct ScalarFold
{
    char [] expr;       // vector or matrix expression; empty for a pure scalar expression
    char [] multiplier; // scalar multiply of the entire expression. or "-" for unary minus
}
|---|
| 4 |
|---|