root/trunk/blade/BladeSimplify.d

Revision 187, 27.0 kB (checked in by Don Clugston, 2 months ago)

Added prod(). Use .ptr to get raw data, so it works with Bill Baxter's ArrayView.

  • Property svn:mime-type set to text/x-dsrc
  • Property svn:eol-style set to native
Line 
1 //  Written in the D programming language 1.0
2 /**
3 * Simplify a vector expression.
4 * Part of BLADE : Basic Linear Algebra D Expressions
5 *
6 *  The following transformations are performed:
7 * (A) Remove all duplicate symbols.
8 * (B) Create 'compound' variables, consisting of multiple symbols:
9 *    - Combine all scalars into a single scalar.
10 *    - Rank reduction: Combine slicing operations into vectors
11 * (C) Arithmetic transformations.
12 *    - Use slicing distributive law: Given A[B..C] for expressions A,B,C
13 *      where B and C are both rank 0, and A is rank 1 or more, the slice can
14 *      be moved to every vector inside A.
15 *    - Use associativity of *: A*(B*C[]) == (A*B)*C[] (Not strictly true for
16 *      floating point; results may differ by 1ulp,
17 *       eg (1.3L*3.1L)*4.7L < 1.3L*(3.1L*4.7L)
18 *      Note that floating point addition is not associative at all).
19 *    - Remove unary minus where possible, eg A-(-B) => A+B, abs(-A) => abs(A).
20 *    - Use associativity of * in intrinsics:
21 *         sum(A*V) => A*sum(V), abs(A*B) => abs(A)*abs(B)
22 * (D) Expression standardisation
23 *    - Move multiplies to left: Convert A[]*B into B*A[] (assumes * is commutative,
24 *      not valid for quaternions).
25 *    - Convert C-A*B into C+(-A)*B whenever possible.
26 *
27 * Author:
28 *   Don Clugston.
29 * License:
30 *   Public domain.
31 */
32
33 module blade.BladeSimplify;
34
35 public import blade.SyntaxTree : AbstractSyntaxTree, Symbol;
36 private import blade.BladeVisitor;
37 private import blade.BladeRank : exprRank, subexprRank, getRankErrorText;
38
39
/**
 * A simplified vector expression, as returned by simplifySyntaxTree() and
 * remapCompounds().
 *
 * The order of the fields matters: instances are created with positional
 * struct literals.
 */
struct RevisedExpression {
    /// The expression with compounds in braces.
    char[] braceExpression;
    /// The revised expression using new variable names
    /// (so, for example, B+=(D-F) becomes A+=(B-C) ).
    char[] expression;
    /// The original symbol table, with all the old names.
    Symbol[] symbolTable;
    /// The compound variables, defined using the old names.
    char[][] compounds;
    /// Rank of all symbols (including original & compounds).
    char[] rank;
    /// Mapping from new names onto old names.
    char[] mapping;
    /// Error text, or null if no errors.
    char[] errorMessage;
}
51
// TODO: ".re", ".im" should join this list.
/// Returns true when str is exactly one of the function names that are
/// treated as intrinsics ("dot", "sum", "max", "min", "abs", "sqrt", "prod").
bool isBladeIntrinsic(char [] str)
{
    if (str=="dot" || str=="sum") return true;
    if (str=="max" || str=="min") return true;
    if (str=="abs" || str=="sqrt") return true;
    if (str=="prod") return true;
    return false;
}
58
59
// Given an abstract syntax tree, returns a RevisedExpression.
//
// Steps, in order:
//   1. Collect the rank of every symbol and the names of all undefined symbols
//      (those with an empty type that are not intrinsics).
//   2. Remove duplicate symbols and inline intrinsics.
//   3. Check the rank of the whole expression.
//   4. Fold indices, fold scalars, and remap the resulting compounds.
// When step 1 or step 3 fails, the returned RevisedExpression carries the
// problem description in errorMessage and an empty revised expression.
RevisedExpression simplifySyntaxTree(AbstractSyntaxTree tree)
{
    char [] symbolRanks;
    char [] undefinedNames = "";
    for (int k=0; k<tree.symbolTable.length; ++k) {
        symbolRanks ~= tree.symbolTable[k].rank;
        bool untyped = tree.symbolTable[k].type=="";
        if (untyped && !isBladeIntrinsic(tree.symbolTable[k].value)) {
            undefinedNames ~= " " ~ tree.symbolTable[k].value;
        }
    }

    // Check for undefined symbols
    if (undefinedNames.length > 0) {
        return RevisedExpression(tree.expression, "", tree.symbolTable, [""], "","",
                                 "Undefined symbols:" ~ undefinedNames);
    }

    // Remove duplicate symbols, convert intrinsics
    char [] deduped = removeDuplicates(tree);

    // Check for rank errors; a negative rank is an error code.
    int overallRank = exprRank(deduped, symbolRanks);
    if (overallRank<0) {
        return RevisedExpression(deduped, "", tree.symbolTable, [""], "","",
                                 getRankErrorText(overallRank));
    }

    // Perform scalar foldings and dimension reduction
    char [] indexFolded = foldIndices(deduped, symbolRanks);
    char [] scalarFolded = foldScalars(indexFolded, symbolRanks);
    return remapCompounds(scalarFolded, symbolRanks, tree.symbolTable);
}
84
85 private:
86
/* Adjust the expression to remove all references to duplicated symbol table
 * entries. (Duplicates can occur as a result of resolving aliases or constants).
 * Also move intrinsics in-line.
 *
 * Symbols are the letters 'A'..'Z' in tree.expression; letter number N refers
 * to tree.symbolTable[N]. A symbol whose value equals that of an earlier entry
 * is rewritten to the first such entry's letter, and a symbol naming an
 * intrinsic is replaced by the intrinsic's name. If nothing needs replacing,
 * tree.expression is returned as-is.
 */
char [] removeDuplicates(AbstractSyntaxTree tree)
{
    int replaced=0;
    // For each symbol: the letter it should become, or '!' for an intrinsic.
    char [] target = "";
    for (int cur=0; cur<tree.symbolTable.length; ++cur) {
        char letter = 'A'+cur;
        if (isBladeIntrinsic(tree.symbolTable[cur].value)) {
            letter = '!';
            ++replaced;
        } else {
            // Look for the first earlier symbol with the same value.
            int prev;
            for (prev=0; prev<cur && tree.symbolTable[cur].value!=tree.symbolTable[prev].value; ++prev) {}
            if (prev<cur) {
                letter = ('A'+prev);
                ++replaced;
            }
        }
        target ~= letter;
    }
    if (replaced==0) return tree.expression;

    // Rewrite the expression; everything that is not a symbol letter is copied.
    char [] result = "";
    for (int pos=0; pos<tree.expression.length; ++pos) {
        char ch = tree.expression[pos];
        if (ch<'A' || ch>'Z') result ~= ch;
        else if (target[ch-'A']=='!') result ~= tree.symbolTable[ch-'A'].value;
        else result ~= target[ch-'A'];
    }
    return result;
}
123
unittest {
    // Symbols A and C both hold the value "125", so every C must become A.
    AbstractSyntaxTree tree = AbstractSyntaxTree(
        "A+(B*C)",
        [Symbol("int", "125", 0), Symbol("int", "7", 0), Symbol("int", "125", 0)]);
    assert(removeDuplicates(tree)=="A+(B*A)");
}
129
// Determine rank of a multidimensional index
//
// s is a symbol followed by its index blocks, e.g. "D[B,B]"; scanning starts
// at s[1]. Only characters outside parentheses are considered. Each
// outermost '[' and each comma directly inside it counts one dimension;
// each ".." directly inside it (a slice) takes one back.
int indexRank(char [] s)
{
    int reduction=0;
    int bracketDepth=0;
    int parenDepth=0;
    for (int pos=1; pos<s.length; ++pos) {
        char ch = s[pos];
        if (ch=='(') {
            ++parenDepth;
        } else if (ch==')') {
            --parenDepth;
        } else if (parenDepth==0) {
            if (ch==']') {
                --bracketDepth;
            } else if (ch=='[') {
                if (bracketDepth==0) ++reduction;
                ++bracketDepth;
            } else if (bracketDepth==1) {
                if (ch==',') {
                    ++reduction; // commas increase the rank
                } else if (ch=='.' && s[pos-1]=='.') {
                    // if it's a slice, it does not increase the rank
                    --reduction;
                }
            }
        }
    }
    return reduction;
}
152
153 public:
// Remove everything inside {} from expr, and create new variables for it.
//
// Each distinct {...} group becomes one compound variable; a repeated group
// reuses the variable created for its first occurrence. The original symbols
// that still appear outside the braces are then renamed consecutively from
// 'A', and the compounds are named after them.
//   expr     = expression containing brace-delimited compounds
//   rank     = one rank character per original symbol
//   symTable = original symbol table, stored unchanged in the result
RevisedExpression remapCompounds(char [] expr, char [] rank, Symbol[] symTable)
{
    char [][] compounds;
    char [] compoundRanks;
    // Per original symbol: its new name, or '-' if it is not used.
    char [] renamed = "";
    for (int n=0; n<rank.length; ++n) renamed ~= "-";

    // Pass 1: replace every {...} by a provisional letter located after the
    // original symbols, and note which original symbols are still referenced.
    char [] provisional = "";
    for (int pos=0; pos<expr.length; ++pos) {
        if (expr[pos]!='{') {
            provisional ~= expr[pos];
            if (expr[pos]>='A' && expr[pos]<='Z') renamed[expr[pos]-'A']=expr[pos];
        } else {
            // Find the matching '}'.
            int closing = pos;
            for (int depth=1; depth>0; ) {
                ++closing;
                if (expr[closing]=='{') ++depth;
                else if (expr[closing]=='}') --depth;
            }
            char [] inner = expr[pos+1..closing]; // strip off the {}
            int resume = closing;
            if (pos>0 && closing<expr.length-1 && expr[pos-1]=='(' && expr[closing+1]==')') {
                provisional = provisional[0..$-1]; // remove the last '('
                resume = closing+1; // don't add ')'
            }
            // Check for a duplicate
            int slot;
            for (slot=0; slot<compounds.length && compounds[slot]!=inner; ++slot) {}
            if (slot==compounds.length) {
                compounds ~= inner;
                if (expr[closing-1]==']') {
                    // it's a vector/matrix of some kind, with rank reduced
                    // by indices. Can't just use exprRank, because the []
                    // aren't wrapped by ().
                    compoundRanks ~= (rank[expr[pos+1]-'A'] - indexRank(inner));
                } else {
                    // it's a scalar expression. Note that it could involve
                    // a vector expression.
                    compoundRanks ~= '0';
                }
            }
            // New compounds are numbered in order of creation, so the slot
            // index gives the provisional letter for new and repeated groups.
            provisional ~= cast(char)('A'+slot+rank.length);
            pos = resume;
        }
    }

    // Pass 2: create a mapping from old to new variable names
    char [] keptRanks = "";
    char [] mapping = "";
    char nextName = 'A';
    for (int n=0; n<renamed.length; ++n) {
        if (renamed[n]!='-') {
            mapping ~= ('A'+n);
            keptRanks ~= rank[n];
            renamed[n] = nextName;
            ++nextName;
        }
    }
    for (int n=0; n<compoundRanks.length; ++n) {
        mapping ~= ('A'+renamed.length+n);
    }

    // Pass 3: set the expression to use the new names
    char [] finalExpr = "";
    for (int pos=0; pos<provisional.length; ++pos) {
        char ch = provisional[pos];
        if (ch<'A' || ch>'Z') finalExpr ~= ch;
        else if ((ch-'A')<rank.length) finalExpr ~= renamed[ch-'A'];
        else finalExpr ~= (ch-'A') - rank.length + nextName;
    }
    return RevisedExpression(expr, finalExpr, symTable, compounds, keptRanks~compoundRanks, mapping);
}
230
231 private:
// Runs index folding, then scalar folding, then compound remapping on an
// expression whose symbol ranks are already known.
RevisedExpression simplifyVectorExpression(char [] expr, char [] rank, Symbol[] symTable)
{
    char [] indexFolded = foldIndices(expr, rank);
    char [] scalarFolded = foldScalars(indexFolded, rank);
    return remapCompounds(scalarFolded, rank, symTable);
}
236
unittest {
    // The two [B] indexings of D are merged into the single compound D[B,B].
    // Afterwards the symbols are renamed: A stays A, the old C becomes B,
    // and the compound becomes C.
    RevisedExpression revised = simplifyVectorExpression("A+=(((D[B])*C)[B])", "2004",[]);
    assert(revised.compounds[0]=="D[B,B]");
    assert(revised.rank=="202");
    assert(revised.mapping=="ACE");
    assert(revised.expression== "A+=(C*B)");
}
244
245 // -----------------------------------------------------------
246
// Given an array of slices/indices where
//  slicing[0]= start of slice, or index
//  slicing[1]= end of slice, or "" if it's an index,
// combine everything into a single slicing expression.
// The dimensions are separated by commas and the whole list is enclosed in
// square brackets; an empty array gives "[]".
char [] createMultiSlice(char [][2][] slicing)
{
    char [] joined = "";
    for (int dim=0; dim<slicing.length; ++dim) {
        // Build with binary ~ so the caller's strings are never appended to.
        char [] entry = slicing[dim][0];
        if (slicing[dim][1].length>0) entry = entry ~ ".." ~ slicing[dim][1];
        if (dim==0) joined = entry;
        else joined = joined ~ "," ~ entry;
    }
    return "[" ~ joined ~ "]";
}
261
262 // Combines all the indexing and slicing operations together (dimension reduction).
263 // Multiplication of sliced matrices and/or vectors is dimensionally
264 // reduced where possible (may even be converted into dot product).
265 // Returns the new expression. This eliminates all unnecessary slice operations.
266 // Furthermore, *any* value followed by '[' should be used as a new compound.
//
// Every callback below is static and receives the visitor state as `this_`.
// Sub-expressions are processed by calling doVisit() either with the same state
// or with a freshly built IndexFoldingVisitor(rank, dollar, slicing).
// NOTE(review): doVisit and wrapInParens are not defined in this file;
// presumably they come from blade.BladeVisitor -- confirm.
267 struct IndexFoldingVisitor {
268     alias typeof(*this) This;
269     alias char [] ReturnType;
    // One rank character per symbol; looked up as rank[sym[0]-'A'].
270     char [] rank;
    // Text returned in place of a "$" symbol when no slicing is pending.
271     char [] dollar;
272     char [][2][] slicing; // the indexing which applies to this complete expr.
273 static:
    // Leaf case. With no pending slicing the symbol is returned unchanged,
    // except "$", which is replaced by this_.dollar. Otherwise the pending
    // slicing is appended to the symbol as one [...] block; the assert rejects
    // "$" and any symbol whose rank character is not above '0'.
274     ReturnType onVisitSymbol(This this_, char[] sym) {
275        if (this_.slicing.length==0) {
276            if (sym=="$") {
277                return this_.dollar;
278            }
279            else return sym;
280        } else {
281            assert(sym!="$" && this_.rank[sym[0]-'A']>'0', "Rank error " ~ sym);
282            // Note: Later, we'll want this to be a new terminal.
283            return sym ~ createMultiSlice(this_.slicing);
284        }
285     }
    // Rebuilds "func(arg,arg,...)": each argument is visited with the current
    // state and passed through wrapInParens.
286     ReturnType onVisitFunction(This this_, char [] func, char [][] args) {
287         // Intrinsics have no effect on slicing
288         char [] result = "";
289         for (int i=0; i<args.length; ++i) {
290             if (i>0) result ~= ",";
291             result ~= wrapInParens(doVisit(this_,args[i]));
292         }
293         return func ~ "(" ~ result ~ ")";
294     }
    // Prefix operator: asserts that no slicing is pending, then re-emits the
    // operator in front of the visited operand.
295     ReturnType onVisitPrefix(This this_, char [] op, char [] expr) {
296         assert(this_.slicing.length==0, "BLADE ICE");
297         return op ~ wrapInParens(doVisit(this_, expr));
298     }
    // Postfix operator: asserts that no slicing is pending, then re-emits the
    // operator after the visited operand.
299     ReturnType onVisitPostfix(This this_, char [] op, char [] expr) {
300         assert(this_.slicing.length==0, "BLADE ICE");
301         return wrapInParens(doVisit(this_, expr)) ~ op;
302     }
303     // Includes multi-dimensional slicing and indexing.
    // Merges the index block `slices` (applied to `base`) with the slicing
    // already pending in this_, then visits `base` with the merged slicing.
    // No [...] is emitted here; it is emitted once a symbol is reached.
304     ReturnType onVisitIndex(This this_, char [] base, char [][2][] slices) {
305         if (slices.length==0) { // []  -- has no effect.
306              return doVisit(this_, base);
307          }
308 //        printf("  %.*s --> %.*s %.*s\n", base, slices[0][0], slices[0][1]);
309         if (this_.slicing.length >0 && slices[$-1][1].length>0) {
310             // the new dimension block ends with a slice. This needs to be combined
311             // with the earliest existing dimension.
312             // * If the existing dimension is an index,
313             //   it might contain a dollar, which we need to replace.
314             // * If the existing dimension is a slice, the two slices will combine.
315             //
316             // The items inside the slice are top-level, ie have no slice or dollar.
            // a = start and b = end of the new trailing slice.
317             char [] a = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank,"$",[]), slices[$-1][0]));
318             char [] b = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank,"$",[]), slices[$-1][1]));
            // Inside the existing first dimension, "$" is replaced by (end - start)
            // of the new slice.
319             char [] dollr = b ~ "-" ~ a;
            // c = start (or index) of the existing first dimension, with "$" replaced.
320             char [] c = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank,dollr,[]), this_.slicing[0][0]));
321             char [][2][] newslice=[];
322              if (this_.slicing[0][1].length>0) { // slicing a slice
323                if (b=="$" && this_.slicing[0][1]=="$") {
324                    // very common special case, where both are sliced from end
325                     newslice ~= ["(" ~ a ~ "+" ~ c ~")","$"];
326                 } else {
                   // d = end of the existing slice; both ends are offset by a.
327                    char [] d = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank,dollr,[]), this_.slicing[0][1]));
328                     newslice ~= ["(" ~ a ~ "+" ~ c~ ")", "(" ~ a ~ "+" ~ d~ ")"];
329                 }
330             } else {
                // The existing dimension is an index: offset it by the slice start.
331                 newslice ~= [a ~ "+" ~ c, ""];
332             }
333             if (slices.length>1) {
334                 // append other slices, if any.
335                 return doVisit(IndexFoldingVisitor(this_.rank, "$", slices[0..$-1] ~ newslice ~ this_.slicing[1..$]), base);
336             } else {
337                 return doVisit(IndexFoldingVisitor(this_.rank, "$",newslice ~ this_.slicing[1..$]), base);
338             }
339         } else { // just append them.
            // The new block's dimensions go in front of the pending ones.
340             return doVisit(IndexFoldingVisitor(this_.rank, "$", slices ~ this_.slicing), base);
341         }
342     }
    // Binary operator. For "*" or "*=" with slicing pending, the slicing is
    // distributed over the operands according to their ranks, as reported by
    // subexprRank; in some cases the product is emitted as dot(first,second).
    // In every other case both operands are visited with the unchanged state.
343     ReturnType onVisitBinaryOp(This this_, char [] op, char [] left, char [] right) {
344         int lrank = subexprRank(left, this_.rank);
345         int rrank = subexprRank(right, this_.rank);
346         char [] first="";
347         char [] second="";
348         if ((op=="*" || op=="*=") && this_.slicing.length>0) {
349             // If one of these is a matrix, the slicing gets interesting...
350             // .. extremely so for slicing of matrix chain multiplication.
351             if (lrank==0) {
352                 // All dimensions apply to right operand
353                 first = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), left));
354                 second = wrapInParens(doVisit(this_, right));
355             } else if (rrank==0) {
356                 // All dimensions apply to the left operand
357                 first = wrapInParens(doVisit(this_, left));
358                 second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right));
359             } else {
360                 assert(lrank>0 && rrank>0 && lrank<=2 && rrank<=2, "BLADE ICE: Tensor*tensor is unsupported");
361                 bool isDotProduct = false; // was it reduced to a dot product?
362
363                 // In the case of chained matrix multiplies, we can end up with an empty slice.
364                 if (this_.slicing.length>0 && this_.slicing[$-1][0]=="") {
365                     this_.slicing=this_.slicing[0..$-1];
366                 }
367                 if (lrank==2) {
368                     // First dimension applies to rows of the left operand
369                     // If it's a slice, it will be a strided slice -- unless
370                     // it comes from another matrix multiply, in which case the
371                     // stride will drop out. (A[x]*B is strided).
372                     char [][2][] newslice=[];
373                     newslice ~= this_.slicing[0];
                    // The empty second entry is rendered as an empty dimension, e.g. "A[C,]".
374                     newslice ~= ["",""];
375                     first = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,newslice), left));
376                 } else {
377                     assert(this_.slicing.length==1, "BLADE ICE: Rank error");
378                     first = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), left));
379                 }
380                 if (lrank==2 && rrank==2) {
381                     // Matrix * matrix
382                     if (this_.slicing.length>1) {
383                         // Second dimension applies to the right operand.
384                         second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,this_.slicing[1..$]), right));
385                         if (this_.slicing[0][1].length==0 && this_.slicing[1][1].length==0) {
386                             // It's indices in both cases -- so it's a dot product.
387                             isDotProduct = true;
388                         }
389                     } else {
390                         second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right));
391                     }
392                 } else if (lrank==1 && rrank==2) {
393                     // vector * matrix, Matrix uses all the slicing
394                     second = wrapInParens(doVisit(this_, right));
395                     if (this_.slicing[0][1].length==0)  isDotProduct = true;
396                 } else if (lrank==2 && rrank==1) {
397                     // matrix * vector, vector is unsliced.
398                     second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right));
399                     if (this_.slicing[0][1].length==0)  isDotProduct = true;
400                 } else assert(0, "BLADE ICE");
401                 if (isDotProduct) {
402                   return "dot(" ~ first ~ "," ~ second ~ ")";
403                 }
404             }
405         } else { // not a multiplication
            // Also reached for a multiplication when no slicing is pending.
406             return wrapInParens(doVisit(this_, left)) ~ op ~ wrapInParens(doVisit(this_, right));
407         }
408         return first ~ op ~ second;
409     }
410 }
411
// Folds all indexing and slicing operations in expr, given one rank character
// per symbol in ranks. The visit starts with no pending slicing and with "$"
// standing for itself.
char [] foldIndices(char [] expr, char [] ranks)
{
    IndexFoldingVisitor topLevel = IndexFoldingVisitor(ranks,"$",[]);
    return beginVisit(topLevel, expr);
}
416
unittest {
    // An index applied to a sum is moved onto every term of the sum.
    assert(foldIndices("((A[C..D])+B)[($-E)]", "21000")=="(A[C+((D-C)-E)])+(B[($-E)])");

    // Consecutive index/slice blocks on one operand are merged into one block.
    assert(foldIndices("(A[C])[D]", "3100")=="A[C,D]");
    assert(foldIndices("(A[B..C])[D]", "3000")=="A[B+D]");
    assert(foldIndices("(A[B])[C..D]", "3000")=="A[B,C..D]");
    assert(foldIndices("((A[$])[(C-$)])[D]", "3000")=="A[$,(C-$),D]");
    assert(foldIndices("(A[B..$])[C..$]", "3000")=="A[(B+C)..$]");

    // An empty [] has no effect.
    assert(foldIndices("((A[])[C..$])[]", "3000")=="A[C..$]");
    assert(foldIndices("((A[])[(B[C])..$])[]", "3100")=="A[(B[C])..$]");

    // A block that is already multi-dimensional is left unchanged.
    assert(foldIndices("A[,B..$,C]", "300")=="A[,B..$,C]");

    // Multidimensional slicing
    assert(foldIndices("(C*((A*B)[C]))[D]", "2200")=="C*dot((A[C,]),(B[D]))");
    assert(foldIndices("(A*B)[C..D,D]", "2200")=="(A[C..D,])*(B[D])");
    assert(foldIndices("(A*B)[C..D]", "2200")=="(A[C..D,])*B");
    assert(foldIndices("(A*B)[C..D]", "2100")=="(A[C..D,])*B");
    assert(foldIndices("(A*B)[C..D]", "1200")=="A*(B[C..D])");
    assert(foldIndices("(A*B)[C]", "120")=="dot(A,(B[C]))");

    // Chained products and sums inside an indexed product.
    assert(foldIndices("((A*B)*C)[D]", "2220")=="((A[D,])*B)*C");
    assert(foldIndices("((A+B)*C)[D]", "2220")=="((A[D,])+(B[D,]))*C");
    assert(foldIndices("((D*A)*B)[C]", "2100")=="dot((D*(A[C,])),B)");
    assert(foldIndices("(((A*B)*C)[D..E])[D]", "12200")=="dot((A*B),(C[D+D]))");

    // Assignment target and intrinsic calls.
    assert(foldIndices("A+=(((D[B])*C)[B])", "2004")=="A+=((D[B,B])*C)");
    assert(foldIndices("dot(A,(A*dot(A,A)))","1")=="dot(A,(A*dot(A,A)))");
}
442
/// An expression split into a vector/matrix part and a scalar multiplier.
struct ScalarFold
{
    /// Vector or matrix expression; empty for a pure scalar expression.
    char[] expr;
    /// Scalar multiply of the entire expression, or "-" for unary minus.
    char[] multiplier;
}
4