root/trunk/blade/BladeRank.d

Revision 187, 8.3 kB (checked in by Don Clugston, 5 months ago)

Added prod(). Use .ptr to get raw data, so it works with Bill Baxter's ArrayView.

  • Property svn:eol-style set to native
Line 
1 //  Written in the D programming language 1.0
2 //  Part of BLADE : Basic Linear Algebra D Expressions
3 /**
4 *  Determine the tensor rank of an expression.
5 *
6 */
7
8 module blade.BladeRank;
9 private import blade.BladeVisitor;
10
11 public:
/**
 * Returns true if the symbol expression s is strided: a slice (including a
 * complete slice) is followed by a later incomplete slice or index.
 * NOTE(review): the original comment called this a "non-zero stride";
 * presumably a stride other than 1 is meant.
 *
 * Only brackets outside parentheses are inspected. At bracket depth 1, ','
 * separates dimensions and ".." marks a slice; at depth 2 a "[lo,hi]" pair
 * also counts as a slice. s must end in ']' (asserted), and an expression
 * ending in a complete slice "[]" is never reported as strided.
 *
 * Examples: "A[[2,7],3]" and "C[D..$,D]" are strided; "A[2..7][3]" is not.
 */
bool isStrided(char [] s)
{
    assert(s[$-1]==']', "BLADE ICE");
    if (s[$-2..$]=="[]") return false;

    int parenDepth = 0;     // '(' nesting; anything inside parentheses is skipped
    int bracketDepth = 0;   // '[' nesting, counted only outside parentheses
    bool sliced = false;    // a slice has been seen in an earlier position
    bool sliceHere = false; // ".." seen since the last '['
    bool indexOpen = false; // set by '[' and by a depth-1 ','; cleared by "..",
                            // by a complete "[]", and by a depth-2 ','

    for (size_t pos = 1; pos < s.length; ++pos) {
        char c = s[pos];
        if (c == '(') { ++parenDepth; continue; }
        if (c == ')') { --parenDepth; continue; }
        if (parenDepth != 0) continue;

        switch (c) {
        case ']':
            if (indexOpen && sliced) return true;
            --bracketDepth;
            // "[]" is a complete slice, so it does not leave an index open.
            if (s[pos-1] == '[') indexOpen = false;
            break;
        case '[':
            indexOpen = true;
            sliceHere = false;
            ++bracketDepth;
            break;
        case ',':
            if (bracketDepth == 1) {
                // Moving on to the next dimension.
                if (sliced && indexOpen) return true;
                if (sliceHere) sliced = true;
                indexOpen = true;
            } else if (bracketDepth == 2) {
                // "[lo,hi]" slice syntax.
                indexOpen = false;
                sliced = true;
            }
            break;
        case '.':
            if (bracketDepth == 1 && s[pos-1] == '.') {
                // [..] slices don't move the index to the next value.
                indexOpen = false;
                sliceHere = true;
            }
            break;
        default:
            // Required in D1: a switch with no matching case and no default
            // throws SwitchError at run time.
            break;
        }
    }
    return false;
}
51
52 private:
unittest {
    // Not strided.
    assert(!isStrided("A[][2]"));
    assert(!isStrided("A[3..5][]"));
    assert(!isStrided("A[2..7][3]"));
    assert(!isStrided("A[2,[2,7]]"));
    // Nested brackets inside a slice bound must not be mistaken for our own.
    assert(!isStrided("A[7][B[[1,3],2]..6]"));

    // Strided: a slice followed by an index in a later position.
    assert(isStrided("A[[2,7],3]"));
    assert(isStrided("C[D..$,D]"));
}
62
63 public:
/** Returns the (tensor) rank of the expression expr.
 * A negative number (see RankError) will be returned if an error is detected.
 * TODO: Should warn of expressions with no effect (ie, with no =).
 *
 * Params:
 * expr  = Placeholder expression (A,B,... correspond to tuple[0],[1],...)
 * ranks = The rank of each tuple member A, B, C, ..., one decimal digit per
 *         member; e.g. "302" means A has rank 3, B rank 0 and C rank 2.
 */
int exprRank(char [] expr, char [] ranks)
{
    return beginVisit(RankVisitor(ranks), expr);
}

/** Like exprRank, but enters the rank visitor through doVisit rather than
 * beginVisit.
 * NOTE(review): presumably intended for sub-expressions of a larger
 * expression whose top level is handled by beginVisit; confirm against
 * blade.BladeVisitor.
 *
 * Params:
 * expr  = Placeholder expression (A,B,... correspond to tuple[0],[1],...)
 * ranks = Rank of each placeholder A, B, C, ..., one decimal digit each.
 * Returns: the rank of expr, or a negative number on error.
 */
int subexprRank(char [] expr, char [] ranks)
{
    return doVisit(RankVisitor(ranks), expr);
}
81
82
/// Error codes returned (as negative ints) by exprRank and RankVisitor.
/// Code -k is described by entry k-1 of the table in getRankErrorText, so the
/// two declarations must be kept in the same order.
enum RankError : int {
    UnsupportedOperation = -1,
    RankIncrement = -2,
    AttemptToIndexAScalar = -3,
    NonScalarIndex = -4,
    NonScalarSlice = -5,
    DotDotExpected = -6,
    CommaExpected = -7,
    RankMismatch = -8,
    RankMismatchConcatenation = -9,
    RankMismatchDotProduct = -10,
    ExtraCharsAfterArrayLiteral = -11,
    ArrayLiteralRankMismatch = -12,
    AbsDimensionality = -13 // must stay the most negative code (see getRankErrorText)
}

/**
 * Returns a human-readable message for a RankError code.
 *
 * Params:
 * err = A negative RankError code.
 * Returns: the message text, or "Unknown rank error" when err is not a
 *          RankError code (err >= 0 or err < RankError.AbsDimensionality).
 */
char [] getRankErrorText(int err)
{
    // The table below is indexed by -err-1; reject anything outside it instead
    // of indexing out of bounds.
    if (err >= 0 || err < RankError.AbsDimensionality) return "Unknown rank error";
    return ["Unsupported vector operation",
            "Can only use ++ and -- on scalars",
            "Cannot index a scalar",
            "Vector can only be indexed by a scalar",
            "Vector can only be sliced by a scalar",
            ".. expected",
            ", expected",
            "Dimensionality mismatch (addition or subtraction)",
            "Dimensionality mismatch in concatenation",
            "Dimensionality error in dot product",
            "Extra characters after array literal",
            "Rank mismatch in array literal",
            "Can only use abs, sqrt with scalar or vector"
            ][-err-1];
}
108             "Dimensionality mismatch (addition or subtraction)",
109             "Dimensionality mismatch in concatenation",
110             "Dimenionality error in dot product",
111             "Extra characters after array literal",
112             "Rank mismatch in array literal",
113             "Can only use abs, sqrt with scalar or vector"
114             ][-err-1];
115 }
116
/**
 * Visitor that computes the tensor rank of each node of a placeholder
 * expression. Every callback returns the rank (>= 0) of the visited node or a
 * negative RankError code; an error from a sub-expression is returned
 * unchanged.
 */
struct RankVisitor {
    alias typeof(*this) This;
    alias int ReturnType;
    /// Rank of each placeholder as one ASCII digit: rank[0] is A, rank[1] is B, ...
    char [] rank;
static:
    /// Rank of a single placeholder symbol; "$" is treated as a scalar.
    ReturnType onVisitSymbol(This this_, char [] sym) {
        if (sym=="$") return 0;
        return this_.rank[sym[0]-'A']-'0';
    }

    /// Rank of a call to one of the supported built-in functions.
    ReturnType onVisitFunction(This this_, char [] func, char [][] args) {
        switch(func) {
        case "dot":
            // dot(x, y) takes exactly two vectors and yields a scalar.
            if (args.length!=2) return RankError.CommaExpected;
            auto lrank = doVisit(this_,args[0]);
            if (lrank<0) return lrank; // propagate errors
            auto rrank = doVisit(this_, args[1]);
            if (rrank<0) return rrank; // propagate errors
            if (lrank!=1 || rrank!=1) return RankError.RankMismatchDotProduct;
            return 0;
        case "sum":
        case "prod":
            // Accept an argument of any rank and yield a scalar.
            auto lrank = doVisit(this_,args[0]);
            if (lrank<0) return lrank; // propagate errors
            return 0;
        case "abs":
        case "sqrt":
            auto lrank = doVisit(this_,args[0]);
            // Without this check an error in the argument was reported as a
            // successful scalar result.
            if (lrank<0) return lrank; // propagate errors
            if (lrank>1) return RankError.AbsDimensionality;
            // NOTE(review): a vector argument also yields rank 0 here; confirm
            // this is intended (e.g. a norm) rather than an elementwise rank-1 result.
            return 0;
        default:
            assert(0, "BLADE ICE: Unsupported function:" ~ func);
            return 0;
        }
    }

    /// Unary + and - preserve rank; any other prefix operator (e.g. ++, --)
    /// is only allowed on scalars.
    ReturnType onVisitPrefix(This this_, char [] op, char [] expr) {
        if (op=="+" || op=="-") return doVisit(this_, expr);
        auto r = doVisit(this_, expr);
        if (r<=0) return r;
        return RankError.RankIncrement;
    }

    /// Postfix operators are only allowed on scalars.
    ReturnType onVisitPostfix(This this_, char [] op, char [] expr) {
        auto r = doVisit(this_, expr);
        if (r<=0) return r;
        return RankError.RankIncrement;
    }

    // Rank of base[...]; includes multi-dimensional slicing and indexing.
    // Each entry of slices is [lower, upper]. An empty upper bound marks a plain
    // index, which removes one dimension; a slice keeps the dimension.
    // All bounds must be scalars.
    ReturnType onVisitIndex(This this_, char [] base, char [][2][] slices) {
        int totrank = doVisit(this_, base);
        if (totrank<0) return totrank; // propagate errors from the base
        for(int i=0; i<slices.length; ++i) {
            bool isSlice = (slices[i][1] != "");
            // Every dimension has already been indexed away.
            if (totrank==0) return RankError.AttemptToIndexAScalar;
            int r = doVisit(this_,slices[i][0]);
            if (r<0) return r; // propagate errors
            if (r!=0) return isSlice ? RankError.NonScalarSlice : RankError.NonScalarIndex;
            if (isSlice) {
                r = doVisit(this_,slices[i][1]);
                if (r<0) return r; // propagate errors
                if (r!=0) return RankError.NonScalarSlice;
            } else {
                --totrank;
            }
        }
        return totrank;
    }

    /// Rank of `left op right`.
    ReturnType onVisitBinaryOp(This this_, char [] op, char [] left, char [] right) {
        int lrank = doVisit(this_, left);
        int rrank = doVisit(this_, right);
        if (rrank<0) return rrank; // propagate errors
        if (lrank<0) return lrank; // propagate errors
        // Addition, subtraction and (compound) assignment need equal ranks.
        if (op=="+" || op=="-" || op=="=" || op=="+=" || op=="-=") {
            if (lrank!=rrank) {
                return RankError.RankMismatch;
            }
            return lrank;
        }
        if (op=="~") { // concatenating scalars and vectors, or vectors and matrices, is permitted
            if (lrank==rrank || lrank==(rrank+1) || rrank==(lrank+1))
                return (lrank>rrank)? lrank: rrank;
            else return RankError.RankMismatchConcatenation;
        }
        if (op=="~=") { // can do vector~=scalar, but not scalar~=vector.
            if (lrank==rrank || lrank==(rrank+1))  return lrank;
            else return RankError.RankMismatchConcatenation;
        }
        // /= only accepts a scalar right-hand side; *= also accepts the
        // matrix forms below.
        if ((op=="*=" || op=="/=") && rrank==0) return lrank;
        if (op=="*=" && lrank==2 && rrank==2) return lrank; // mat *= mat
        if (op=="*=" && lrank==1 && rrank==2) return lrank; // vec *= mat
        if (op=="*" || op=="/") {
            if (lrank==0) return rrank;
            if (rrank==0) return lrank;
            // NOTE(review): the two rules below are shared with "/", so mat/mat,
            // vec/mat and mat/vec are accepted too; confirm that is intended.
            if (lrank==2 && rrank==2) return lrank; // mat*mat
            if (lrank+rrank==3) return 1; // vec*mat or mat*vec
        }
        // All other operations are only valid for scalars.
        if (lrank==0 && rrank==0) return 0;
        return RankError.UnsupportedOperation;
    }
}
213
unittest {
    // Scalar/vector/matrix arithmetic and assignment.
    assert(exprRank("A+(B*C)", "000")==0);
    assert(exprRank("A=(B*C)", "202")==2);
    assert(exprRank("B*=(C*A)", "010")==1);
    assert(exprRank("(A[])+B", "22")==2);
    assert(exprRank("D+=((A+C)*B)", "2022")==2);
    assert(exprRank("D+=((A&C)*B)", "0101")==1);
    assert(exprRank("A+((((++B)+D)--)*C)", "1010")==1);

    // Indexing and slicing, including multi-dimensional forms.
    assert(exprRank("(A[B..C])[C]", "300")==2);
    assert(exprRank("A+=(A[B..C])", "300")==3);
    assert(exprRank("C+=(A[B])", "302")==2);
    assert(exprRank("A[B,B,B]", "60")==3);
    assert(exprRank("A[B,B,C,B]", "600")==2);
    assert(exprRank("A+=(B[C..$])", "110")==1);
    assert(exprRank("A+=(B[C,D..$])", "2300")==2);
    assert(exprRank("(A[B..$,C])+=D", "2001")==1); // regression test (bug fix)

    // Concatenation and unary minus mixed with indexing.
    assert(exprRank("C~=(((A[B])[B])~C)", "302")==2);
    assert(exprRank("((D[E])[E])+(-((C[B])[B..E]))", "202300")==1);

    // Built-in functions.
    assert(exprRank("dot(A)", "1")==RankError.CommaExpected);
    assert(exprRank("dot(A,B)", "10")==RankError.RankMismatchDotProduct);
    assert(exprRank("dot(B,(A*(dot(B,B))))", "11")==0);
    assert(exprRank("prod(A*B)", "10")==0);
}
Note: See TracBrowser for help on using the browser.