| 9 | | |
|---|
| 10 | | // -------------- |
|---|
| 11 | | // Ranklist functions |
|---|
| 12 | | |
|---|
| 13 | | // Count the number of vectors |
|---|
| 14 | | int countVectors(char[] ranklist) |
|---|
| 15 | | { |
|---|
| 16 | | int numVecs=0; |
|---|
| 17 | | for (int i=0; i<ranklist.length; ++i) { |
|---|
| 18 | | if (ranklist[i]=='1') ++numVecs; |
|---|
| 19 | | } |
|---|
| 20 | | return numVecs; |
|---|
| 21 | | } |
|---|
| 22 | | |
|---|
| 23 | | int vectorNum(char [] ranklist, char var) |
|---|
| 24 | | { |
|---|
| 25 | | int numVecs=0; |
|---|
| 26 | | for (int i=0; i<var-'A'; ++i) { |
|---|
| 27 | | if (ranklist[i]=='1') ++numVecs; |
|---|
| 28 | | } |
|---|
| 29 | | return numVecs; |
|---|
| 30 | | } |
|---|
| 31 | | |
|---|
| 32 | | int scalarNum(char [] ranklist, char var) |
|---|
| 33 | | { |
|---|
| 34 | | int k=0; |
|---|
| 35 | | for (int i=0; i<var-'A'; ++i) { |
|---|
| 36 | | if (ranklist[i]=='0') ++k; |
|---|
| 37 | | } |
|---|
| 38 | | return k; |
|---|
| 39 | | } |
|---|
| 40 | | |
|---|
| 41 | | int realScalarNum(char [][] typelist, char [] ranklist, char var) |
|---|
| 42 | | { |
|---|
| 43 | | int k=0; |
|---|
| 44 | | for (int i=0; i<var-'A'; ++i) { |
|---|
| 45 | | if (ranklist[i]=='0' && typelist[i]=="real") ++k; |
|---|
| 46 | | } |
|---|
| 47 | | return k; |
|---|
| 48 | | } |
|---|
| 49 | | |
|---|
| 50 | | /** Return the length of a sub-expression |
|---|
| 51 | | * The sub-expression must be |
|---|
| 52 | | * - a single character (eg "X"), OR |
|---|
| 53 | | * - a lower-case function (eg "a(B,(C*D))"), OR |
|---|
| 54 | | * - an expression in parenthesis, OR |
|---|
| 55 | | * - an array literal |
|---|
| 56 | | */ |
|---|
| 57 | | int exprLength(char [] s) |
|---|
| 58 | | { |
|---|
| 59 | | if ((s[0]>='A' && s[0]<='Z') || s[0]=='$') return 0; |
|---|
| 60 | | int i = 0; |
|---|
| 61 | | if (s[0]>='a' && s[0]<='z'){ // function call |
|---|
| 62 | | i=1; // next char is a parenthesis - so the code |
|---|
| 63 | | // below works |
|---|
| 64 | | } |
|---|
| 65 | | int numParens = 0; |
|---|
| 66 | | int numBrack = 0; |
|---|
| 67 | | for (; i<s.length; ++i) { |
|---|
| 68 | | if (s[i]=='(') ++numParens; |
|---|
| 69 | | if (s[i]==')') numParens--; |
|---|
| 70 | | if (s[i]=='[') ++numBrack; |
|---|
| 71 | | if (s[i]==']') --numBrack; |
|---|
| 72 | | if (numParens == 0 && numBrack == 0) { |
|---|
| 73 | | return i; |
|---|
| 74 | | } |
|---|
| 75 | | } |
|---|
| 76 | | assert(0, "BLADE ICE: " ~ s); |
|---|
| 77 | | } |
|---|
| 78 | | |
|---|
| 79 | | /** Determine the (tensor) rank of a sub-expression |
|---|
| 80 | | * The sub-expression must be a single character, or an expression in |
|---|
| 81 | | * parentheses. |
|---|
| 82 | | */ |
|---|
| 83 | | int subexprRank(char [] expr, char [] rank) |
|---|
| 84 | | { |
|---|
| 85 | | if (expr.length==1) { |
|---|
| 86 | | if (expr=="$") return 0; |
|---|
| 87 | | assert(expr[0]>='A' && expr[0]<='Z', "BLADE ICE: " ~ expr); |
|---|
| 88 | | return rank[expr[0]-'A']-'0'; |
|---|
| 89 | | } |
|---|
| 90 | | if (expr[0]=='d') return 0; |
|---|
| 91 | | assert(expr[0]=='(', "BLADE ICE:" ~ expr); |
|---|
| 92 | | // strip off the parentheses |
|---|
| 93 | | return exprRank(expr[1..$-1], rank); |
|---|
| 94 | | } |
|---|
| 95 | | |
|---|
| 96 | | enum RankError : int { |
|---|
| 97 | | UnsupportedOperation = -1, |
|---|
| 98 | | RankIncrement = -2, |
|---|
| 99 | | AttemptToIndexAScalar = -3, |
|---|
| 100 | | NonScalarIndex = -4, |
|---|
| 101 | | NonScalarSlice = -5, |
|---|
| 102 | | DotDotExpected = -6, |
|---|
| 103 | | CommaExpected = -7, |
|---|
| 104 | | RankMismatch = -8, |
|---|
| 105 | | RankMismatchConcatenation = -9, |
|---|
| 106 | | RankMismatchDotProduct = -10, |
|---|
| 107 | | ExtraCharsAfterArrayLiteral = -11, |
|---|
| 108 | | ArrayLiteralRankMismatch = -12 |
|---|
| 109 | | } |
|---|
| 110 | | |
|---|
| 111 | | char [] getRankErrorText(int err) |
|---|
| 112 | | { |
|---|
| 113 | | return ["Unsupported vector operation", |
|---|
| 114 | | "Can only use ++ and -- on scalars", |
|---|
| 115 | | "Cannot index a scalar", |
|---|
| 116 | | "Vector can only be indexed by a scalar", |
|---|
| 117 | | "Vector can only be sliced by a scalar", |
|---|
| 118 | | ".. expected", |
|---|
| 119 | | ", expected", |
|---|
| 120 | | "Dimensionality mismatch (addition or subtraction)", |
|---|
| 121 | | "Dimensionality mismatch in concatenation", |
|---|
| 122 | | "Dimenionality error in dot product" |
|---|
| 123 | | "Extra characters after array literal" |
|---|
| 124 | | "Rank mismatch in array literal" |
|---|
| 125 | | ][-err-1]; |
|---|
| 126 | | } |
|---|
| 127 | | |
|---|
| 128 | | /** Returns the (tensor) rank of the expression expr. |
|---|
| 129 | | * A negative number will be returned if an error is detected. |
|---|
| 130 | | * |
|---|
| 131 | | * Params: |
|---|
| 132 | | * expr Placeholder expression (A,B,... correspond to tuple[0],[1],...) |
|---|
| 133 | | * rank The rank of each tuple member A, B, C, ... |
|---|
| 134 | | */ |
|---|
| 135 | | int exprRank(char [] expr, char [] rank) |
|---|
| 136 | | { |
|---|
| 137 | | // BUG: also need to deal with comma, ?:, &&, ||, is, !is, in, |
|---|
| 138 | | // unary &, unary ! |
|---|
| 139 | | |
|---|
| 140 | | if (expr.length>3 && expr[0..2]=="d(") { // dot product |
|---|
| 141 | | int x = exprLength(expr[2..$-1]); |
|---|
| 142 | | if (expr[x+3]!=',') return RankError.CommaExpected; |
|---|
| 143 | | int lrank = subexprRank(expr[2..x+3], rank); |
|---|
| 144 | | if (lrank<0) return lrank; // propagate errors |
|---|
| 145 | | int rrank = subexprRank(expr[x+4..$-1], rank); |
|---|
| 146 | | if (rrank<0) return rrank; // propagate errors |
|---|
| 147 | | if (lrank!=1 || rrank!=1) return RankError.RankMismatchDotProduct; |
|---|
| 148 | | return 0; |
|---|
| 149 | | } |
|---|
| 150 | | |
|---|
| 151 | | // Deal with ++ and --. |
|---|
| 152 | | if (expr.length>2 && (expr[0..2]=="++" || expr[0..2]=="--")) { |
|---|
| 153 | | int r = subexprRank(expr[2..$], rank); |
|---|
| 154 | | if (r!=0) return RankError.RankIncrement; |
|---|
| 155 | | return r; |
|---|
| 156 | | } |
|---|
| 157 | | if (expr.length>2 && (expr[$-2..$]=="++" || expr[$-2..$]=="--")) { |
|---|
| 158 | | int r = subexprRank(expr[0..$-2], rank); |
|---|
| 159 | | if (r!=0) return RankError.RankIncrement; |
|---|
| 160 | | return r; |
|---|
| 161 | | } |
|---|
| 162 | | // Deal with unary operators |
|---|
| 163 | | if (expr[0]=='+' || expr[0]=='-') return subexprRank(expr[1..$], rank); |
|---|
| 164 | | |
|---|
| 165 | | int x = exprLength(expr); |
|---|
| 166 | | if (expr[0]=='[') { // array literal |
|---|
| 167 | | if (x!=expr.length-1) return RankError.ExtraCharsAfterArrayLiteral; |
|---|
| 168 | | expr = expr[1..$-1]; |
|---|
| 169 | | x = exprLength(expr); |
|---|
| 170 | | int lrank = subexprRank(expr[0..x+1], rank); |
|---|
| 171 | | while (x<expr.length-1) { |
|---|
| 172 | | if (expr[x+1]!=',') return RankError.CommaExpected; |
|---|
| 173 | | expr = expr[x+2.. $]; |
|---|
| 174 | | x = exprLength(expr); |
|---|
| 175 | | int rrank = subexprRank(expr[0..x+1], rank); |
|---|
| 176 | | if (lrank!=rrank) return RankError.ArrayLiteralRankMismatch; |
|---|
| 177 | | } |
|---|
| 178 | | return lrank+1; |
|---|
| 179 | | } |
|---|
| 180 | | int y = x+1; |
|---|
| 181 | | // Deal with shifts, op=, and NCEG operators |
|---|
| 182 | | while (expr[y+1]=='<' || expr[y+1]=='>' || expr[y+1]=='=') ++y; |
|---|
| 183 | | |
|---|
| 184 | | char [] op = expr[x+1..y+1]; |
|---|
| 185 | | char [] left = expr[0..x+1]; |
|---|
| 186 | | char [] right = expr[y+1..$]; |
|---|
| 187 | | if (expr[x+1]=='[') right = expr[y+1..$-1]; // drop off the ']'. |
|---|
| 188 | | int lrank = subexprRank(left, rank); |
|---|
| 189 | | if (lrank<0) return lrank; // propagate errors |
|---|
| 190 | | if (op=="[") { |
|---|
| 191 | | if (lrank==0) return RankError.AttemptToIndexAScalar; |
|---|
| 192 | | if (right.length==0) { |
|---|
| 193 | | return lrank; // was [], which doesn't change the length |
|---|
| 194 | | } |
|---|
| 195 | | int z = exprLength(right); |
|---|
| 196 | | if (z+1 == right.length) { |
|---|
| 197 | | // indexing -- reduces the rank by 1. |
|---|
| 198 | | int rrank = subexprRank(right, rank); |
|---|
| 199 | | if (rrank!=0) return RankError.NonScalarIndex; |
|---|
| 200 | | return lrank - 1; |
|---|
| 201 | | } else { |
|---|
| 202 | | int totrank = lrank; |
|---|
| 203 | | do { |
|---|
| 204 | | int rrank = subexprRank(right[0..z+1], rank); |
|---|
| 205 | | if (z==right.length-1 || right[z+1]==',') { |
|---|
| 206 | | // allow rank of 1 to be a slice operation |
|---|
| 207 | | // (so A[1,[2,$-1], $] is possible). |
|---|
| 208 | | if (rrank<0) return rrank; |
|---|
| 209 | | if (rrank>1) return RankError.NonScalarIndex; |
|---|
| 210 | | if (rrank==0) --totrank; |
|---|
| 211 | | if (z==right.length-1) return totrank; |
|---|
| 212 | | } else if (!(z+3 < right.length && right[z+1..z+3]=="..")) { |
|---|
| 213 | | return RankError.DotDotExpected; |
|---|
| 214 | | } else {// slice |
|---|
| 215 | | char [] start = right[0..z+1]; |
|---|
| 216 | | char [] end = right[z+3..$]; |
|---|
| 217 | | int startrank = subexprRank(start, rank); |
|---|
| 218 | | if (startrank<0) return startrank; |
|---|
| 219 | | z = exprLength(end); |
|---|
| 220 | | int endrank = subexprRank(end[0..z+1], rank); |
|---|
| 221 | | if (endrank<0) return endrank; |
|---|
| 222 | | if (startrank!=0 || endrank!=0) return RankError.NonScalarSlice; |
|---|
| 223 | | right = end; |
|---|
| 224 | | } |
|---|
| 225 | | if (z==right.length-1) return totrank; |
|---|
| 226 | | right = right[z+2..$]; |
|---|
| 227 | | z = exprLength(right); |
|---|
| 228 | | //assert(0, right[0..z+1]); |
|---|
| 229 | | }while (true); |
|---|
| 230 | | } |
|---|
| 231 | | } |
|---|
| 232 | | int rrank = subexprRank(right, rank); |
|---|
| 233 | | if (rrank<0) return rrank; // propagate errors |
|---|
| 234 | | if (op=="+" || op=="-" || op=="=" || op=="+=" || op=="-=") { |
|---|
| 235 | | if (lrank!=rrank) { |
|---|
| 236 | | return RankError.RankMismatch; |
|---|
| 237 | | } |
|---|
| 238 | | return lrank; |
|---|
| 239 | | } |
|---|
| 240 | | if (op=="~") { // concatentating scalars and vectors, or vectors and matrices, is permitted |
|---|
| 241 | | if (lrank==rrank || lrank==(rrank+1) || rrank==(lrank+1)) |
|---|
| 242 | | return (lrank>rrank)? lrank: rrank; |
|---|
| 243 | | else return RankError.RankMismatchConcatenation; |
|---|
| 244 | | } |
|---|
| 245 | | if (op=="~=") { // can do vector~=scalar, but not scalar~=vector. |
|---|
| 246 | | if (lrank==rrank || lrank==(rrank+1)) return lrank; |
|---|
| 247 | | else return RankError.RankMismatchConcatenation; |
|---|
| 248 | | } |
|---|
| 249 | | // For *, /, only scalar operations are permitted |
|---|
| 250 | | if ((op=="*=" || op=="/=") && rrank==0) return lrank; |
|---|
| 251 | | if (op=="*" || op=="/") { |
|---|
| 252 | | if (lrank==0) return rrank; |
|---|
| 253 | | if (rrank==0) return lrank; |
|---|
| 254 | | } |
|---|
| 255 | | // All other operations are only valid for scalars. |
|---|
| 256 | | if (lrank==0 && rrank==0) return 0; |
|---|
| 257 | | return RankError.UnsupportedOperation; |
|---|
| 258 | | } |
|---|
| 259 | | |
|---|
| 260 | | unittest { |
|---|
| 261 | | assert(exprRank("A+((((++B)+D)--)*C)", "1010")==1); |
|---|
| 262 | | assert(exprRank("A+(B*C)", "000")==0); |
|---|
| 263 | | assert(exprRank("A=(B*C)", "202")==2); |
|---|
| 264 | | assert(exprRank("B*=(C*A)", "010")==1); |
|---|
| 265 | | assert(exprRank("(A[])+B", "22")==2); |
|---|
| 266 | | assert(exprRank("D+=((A+C)*B)", "2022")==2); |
|---|
| 267 | | assert(exprRank("D+=((A&C)*B)", "0101")==1); |
|---|
| 268 | | assert(exprRank("A+=(A[B..C])", "300")==3); |
|---|
| 269 | | assert(exprRank("C+=(A[B])", "302")==2); |
|---|
| 270 | | assert(exprRank("C~=(((A[B])[B])~C)", "302")==2); |
|---|
| 271 | | assert(exprRank("((D[E])[E])+(-((C[B])[B..E]))", "202300")==1); |
|---|
| 272 | | assert(subexprRank("((A[B..C])[C])", "300")==2); |
|---|
| 273 | | assert(exprRank("d(A)", "1")==RankError.CommaExpected); |
|---|
| 274 | | assert(exprRank("d(A,B)", "10")==RankError.RankMismatchDotProduct); |
|---|
| 275 | | assert(exprRank("d(B,(A*(d(B,B))))", "11")==0); |
|---|
| 276 | | assert(exprRank("A[B,B,B]", "60")==3); |
|---|
| 277 | | assert(exprRank("A[B,B,C,B]", "600")==2); |
|---|
| 278 | | assert(exprRank("A[B,([B,C]),B]", "600")==4); |
|---|
| 279 | | assert(exprRank("A[B,(([B,C])[B]),B]", "600")==3); |
|---|
| 280 | | assert(exprRank("A+=(B[C..$])", "110")==1); |
|---|
| 281 | | assert(exprRank("A+=(B[C,D..$])", "2300")==2); |
|---|
| 282 | | } |
|---|
| 283 | | |
|---|
| 284 | | |
|---|
| 285 | | // Return true if the entire expression contains a multiplication by a scalar |
|---|
| 286 | | bool hasScalarMultiply(char [] expr, char [] rank) |
|---|
| 287 | | { |
|---|
| 288 | | if (expr.length>2 && (expr[0..2]=="++" || expr[0..2]=="--" || expr[$-2..$]=="++" || expr[$-2..$]=="--")) { |
|---|
| 289 | | return false; |
|---|
| 290 | | } |
|---|
| 291 | | if (expr[0]=='+' || expr[0]=='-') return hasScalarMultiply(expr[1..$], rank); |
|---|
| 292 | | |
|---|
| 293 | | int x = exprLength(expr); |
|---|
| 294 | | int y = x+1; |
|---|
| 295 | | assert(y < expr.length, "BLADE BUG:" ~ expr); |
|---|
| 296 | | // Deal with shifts, op=, and NCEG operators |
|---|
| 297 | | while (expr[y+1]=='<' || expr[y+1]=='>' || expr[y+1]=='=') ++y; |
|---|
| 298 | | |
|---|
| 299 | | char [] op = expr[x+1..y+1]; |
|---|
| 300 | | char [] left = expr[0..x+1]; |
|---|
| 301 | | char [] right = expr[y+1..$]; |
|---|
| 302 | | if (op=="[") { |
|---|
| 303 | | // (A)[C] can still have a multiply by scalar, if A contains a |
|---|
| 304 | | // multiply. |
|---|
| 305 | | if (left.length==1) return false; |
|---|
| 306 | | return hasScalarMultiply(left[1..$], rank); |
|---|
| 307 | | } |
|---|
| 308 | | if (op=="/") return true; |
|---|
| 309 | | if (op!="*" && op!="/") { |
|---|
| 310 | | if (left.length==1 || right.length==1) return false; |
|---|
| 311 | | // (A+B) could contain a multiply by scalar, if both A and B |
|---|
| 312 | | // contain multiplies. |
|---|
| 313 | | return hasScalarMultiply(left[1..$-1], rank) && hasScalarMultiply(right[1..$-1], rank); |
|---|
| 314 | | } |
|---|
| 315 | | // it's not true for matrix*matrix multiplies. |
|---|
| 316 | | if (subexprRank(left, rank)==0) return true; |
|---|
| 317 | | return subexprRank(right, rank) == 0; |
|---|
| 318 | | } |
|---|
| 319 | | |
|---|
| 320 | | unittest { |
|---|
| 321 | | assert(hasScalarMultiply("(A*B)+(B*C)","101")); |
|---|
| 322 | | assert(!hasScalarMultiply("(A*B)-(C*C)","101")); |
|---|
| 323 | | assert(!hasScalarMultiply("A+(B*C)","101")); |
|---|
| 324 | | assert(hasScalarMultiply("(A/B)-((A*B)+(C*B))","101")); |
|---|
| 325 | | assert(!hasScalarMultiply("A[B]","20")); |
|---|
| 326 | | assert(!hasScalarMultiply("(C[B])[B..A]","002") ); |
|---|
| 327 | | } |
|---|
| | 9 | private import blade.BladeVisitor; |
|---|
| | 62 | |
|---|
| | 63 | public: |
|---|
| | 64 | /** Returns the (tensor) rank of the expression expr. |
|---|
| | 65 | * A negative number will be returned if an error is detected. |
|---|
| | 66 | * |
|---|
| | 67 | * Params: |
|---|
| | 68 | * expr Placeholder expression (A,B,... correspond to tuple[0],[1],...) |
|---|
| | 69 | * rank The rank of each tuple member A, B, C, ... |
|---|
| | 70 | */ |
|---|
| | 71 | int exprRank(char [] expr, char [] ranks) |
|---|
| | 72 | { |
|---|
| | 73 | return beginVisit(RankVisitor(ranks), expr); |
|---|
| | 74 | } |
|---|
| | 75 | |
|---|
| | 76 | int subexprRank(char [] expr, char [] ranks) |
|---|
| | 77 | { |
|---|
| | 78 | return doVisit(RankVisitor(ranks), expr); |
|---|
| | 79 | } |
|---|
| | 80 | |
|---|
| | 81 | |
|---|
| | 82 | enum RankError : int { |
|---|
| | 83 | UnsupportedOperation = -1, |
|---|
| | 84 | RankIncrement = -2, |
|---|
| | 85 | AttemptToIndexAScalar = -3, |
|---|
| | 86 | NonScalarIndex = -4, |
|---|
| | 87 | NonScalarSlice = -5, |
|---|
| | 88 | DotDotExpected = -6, |
|---|
| | 89 | CommaExpected = -7, |
|---|
| | 90 | RankMismatch = -8, |
|---|
| | 91 | RankMismatchConcatenation = -9, |
|---|
| | 92 | RankMismatchDotProduct = -10, |
|---|
| | 93 | ExtraCharsAfterArrayLiteral = -11, |
|---|
| | 94 | ArrayLiteralRankMismatch = -12 |
|---|
| | 95 | } |
|---|
| | 96 | |
|---|
| | 97 | char [] getRankErrorText(int err) |
|---|
| | 98 | { |
|---|
| | 99 | return ["Unsupported vector operation", |
|---|
| | 100 | "Can only use ++ and -- on scalars", |
|---|
| | 101 | "Cannot index a scalar", |
|---|
| | 102 | "Vector can only be indexed by a scalar", |
|---|
| | 103 | "Vector can only be sliced by a scalar", |
|---|
| | 104 | ".. expected", |
|---|
| | 105 | ", expected", |
|---|
| | 106 | "Dimensionality mismatch (addition or subtraction)", |
|---|
| | 107 | "Dimensionality mismatch in concatenation", |
|---|
| | 108 | "Dimenionality error in dot product" |
|---|
| | 109 | "Extra characters after array literal" |
|---|
| | 110 | "Rank mismatch in array literal" |
|---|
| | 111 | ][-err-1]; |
|---|
| | 112 | } |
|---|
| | 113 | |
|---|
| | 114 | struct RankVisitor { |
|---|
| | 115 | alias typeof(*this) This; |
|---|
| | 116 | alias int ReturnType; |
|---|
| | 117 | char [] rank; |
|---|
| | 118 | static: |
|---|
| | 119 | ReturnType onVisitSymbol(This this_, char sym) { |
|---|
| | 120 | if (sym=='$') return 0; |
|---|
| | 121 | return this_.rank[sym-'A']-'0'; |
|---|
| | 122 | } |
|---|
| | 123 | ReturnType onVisitFunction(This this_, char [] func, char [][] args) { |
|---|
| | 124 | if (func=="d") { // dot product |
|---|
| | 125 | if (args.length!=2) return RankError.CommaExpected; |
|---|
| | 126 | auto lrank = doVisit(this_,args[0]); |
|---|
| | 127 | if (lrank<0) return lrank; // propagate errors |
|---|
| | 128 | auto rrank = doVisit(this_, args[1]); |
|---|
| | 129 | if (rrank<0) return rrank; // propagate errors |
|---|
| | 130 | if (lrank!=1 || rrank!=1) return RankError.RankMismatchDotProduct; |
|---|
| | 131 | return 0; |
|---|
| | 132 | } |
|---|
| | 133 | assert(0, "BLADE ICE: Unsupported function"); |
|---|
| | 134 | return 0; |
|---|
| | 135 | } |
|---|
| | 136 | ReturnType onVisitPrefix(This this_, char [] op, char [] expr) { |
|---|
| | 137 | if (op=="+" || op=="-") return doVisit(this_, expr); |
|---|
| | 138 | auto r = doVisit(this_, expr); |
|---|
| | 139 | if (r<=0) return r; |
|---|
| | 140 | return RankError.RankIncrement; |
|---|
| | 141 | } |
|---|
| | 142 | ReturnType onVisitPostfix(This this_, char [] op, char [] expr) { |
|---|
| | 143 | auto r = doVisit(this_, expr); |
|---|
| | 144 | if (r<=0) return r; |
|---|
| | 145 | return RankError.RankIncrement; |
|---|
| | 146 | } |
|---|
| | 147 | // Includes multi-dimensional slicing and indexing. |
|---|
| | 148 | ReturnType onVisitIndex(This this_, char [] base, char [][] startSlice, char [][] endSlice) { |
|---|
| | 149 | int totrank = doVisit(this_, base); |
|---|
| | 150 | for(int i=0; i<endSlice.length; ++i) { |
|---|
| | 151 | int r = doVisit(this_,startSlice[i]); |
|---|
| | 152 | if (r!=0) return (r<0)? r :RankError.NonScalarIndex; |
|---|
| | 153 | if (endSlice[i]==""){ |
|---|
| | 154 | --totrank; |
|---|
| | 155 | } else { |
|---|
| | 156 | r = doVisit(this_,endSlice[i]); |
|---|
| | 157 | if (r!=0) return (r<0)?r:RankError.NonScalarSlice; |
|---|
| | 158 | } |
|---|
| | 159 | } |
|---|
| | 160 | return totrank; |
|---|
| | 161 | } |
|---|
| | 162 | ReturnType onVisitBinaryOp(This this_, char [] op, char [] left, char [] right) { |
|---|
| | 163 | int lrank = doVisit(this_, left); |
|---|
| | 164 | int rrank = doVisit(this_, right); |
|---|
| | 165 | if (rrank<0) return rrank; // propagate errors |
|---|
| | 166 | if (op=="+" || op=="-" || op=="=" || op=="+=" || op=="-=") { |
|---|
| | 167 | if (lrank!=rrank) { |
|---|
| | 168 | return RankError.RankMismatch; |
|---|
| | 169 | } |
|---|
| | 170 | return lrank; |
|---|
| | 171 | } |
|---|
| | 172 | if (op=="~") { // concatentating scalars and vectors, or vectors and matrices, is permitted |
|---|
| | 173 | if (lrank==rrank || lrank==(rrank+1) || rrank==(lrank+1)) |
|---|
| | 174 | return (lrank>rrank)? lrank: rrank; |
|---|
| | 175 | else return RankError.RankMismatchConcatenation; |
|---|
| | 176 | } |
|---|
| | 177 | if (op=="~=") { // can do vector~=scalar, but not scalar~=vector. |
|---|
| | 178 | if (lrank==rrank || lrank==(rrank+1)) return lrank; |
|---|
| | 179 | else return RankError.RankMismatchConcatenation; |
|---|
| | 180 | } |
|---|
| | 181 | // For *, /, only scalar operations are permitted |
|---|
| | 182 | if ((op=="*=" || op=="/=") && rrank==0) return lrank; |
|---|
| | 183 | if (op=="*" || op=="/") { |
|---|
| | 184 | if (lrank==0) return rrank; |
|---|
| | 185 | if (rrank==0) return lrank; |
|---|
| | 186 | } |
|---|
| | 187 | // All other operations are only valid for scalars. |
|---|
| | 188 | if (lrank==0 && rrank==0) return 0; |
|---|
| | 189 | return RankError.UnsupportedOperation; |
|---|
| | 190 | |
|---|
| | 191 | } |
|---|
| | 192 | } |
|---|
| | 193 | |
|---|
| | 194 | unittest { |
|---|
| | 195 | assert(exprRank("(A[B..C])[C]", "300")==2); |
|---|
| | 196 | assert(exprRank("A+=(A[B..C])", "300")==3); |
|---|
| | 197 | |
|---|
| | 198 | assert(exprRank("A+(B*C)", "000")==0); |
|---|
| | 199 | assert(exprRank("A=(B*C)", "202")==2); |
|---|
| | 200 | assert(exprRank("B*=(C*A)", "010")==1); |
|---|
| | 201 | assert(exprRank("(A[])+B", "22")==2); |
|---|
| | 202 | assert(exprRank("D+=((A+C)*B)", "2022")==2); |
|---|
| | 203 | assert(exprRank("D+=((A&C)*B)", "0101")==1); |
|---|
| | 204 | |
|---|
| | 205 | assert(exprRank("C~=(((A[B])[B])~C)", "302")==2); |
|---|
| | 206 | assert(exprRank("((D[E])[E])+(-((C[B])[B..E]))", "202300")==1); |
|---|
| | 207 | |
|---|
| | 208 | assert(exprRank("A+((((++B)+D)--)*C)", "1010")==1); |
|---|
| | 209 | |
|---|
| | 210 | assert(exprRank("C+=(A[B])", "302")==2); |
|---|
| | 211 | assert(exprRank("d(A)", "1")==RankError.CommaExpected); |
|---|
| | 212 | assert(exprRank("d(A,B)", "10")==RankError.RankMismatchDotProduct); |
|---|
| | 213 | |
|---|
| | 214 | assert(exprRank("d(B,(A*(d(B,B))))", "11")==0); |
|---|
| | 215 | |
|---|
| | 216 | assert(exprRank("A[B,B,B]", "60")==3); |
|---|
| | 217 | assert(exprRank("A[B,B,C,B]", "600")==2); |
|---|
| | 218 | assert(exprRank("A+=(B[C..$])", "110")==1); |
|---|
| | 219 | assert(exprRank("A+=(B[C,D..$])", "2300")==2); |
|---|
| | 220 | |
|---|
| | 221 | // bug fixes: |
|---|
| | 222 | assert(exprRank("(A[B..$,C])+=D", "2001")==1); |
|---|
| | 223 | |
|---|
| | 224 | //NO LONGER SUPPORTED |
|---|
| | 225 | // assert(exprRank("A[B,([B,C]),B]", "600")==4); |
|---|
| | 226 | // assert(exprRank("A[B,(([B,C])[B]),B]", "600")==3); |
|---|
| | 227 | |
|---|
| | 228 | } |
|---|
| | 229 | |
|---|
| | 230 | |
|---|
| | 231 | // Return true if the entire expression contains a multiplication by a scalar |
|---|
| | 232 | bool hasScalarMultiply(char [] expr, char [] rank) |
|---|
| | 233 | { |
|---|
| | 234 | if (expr.length>2 && (expr[0..2]=="++" || expr[0..2]=="--" || expr[$-2..$]=="++" || expr[$-2..$]=="--")) { |
|---|
| | 235 | return false; |
|---|
| | 236 | } |
|---|
| | 237 | if (expr[0]=='+' || expr[0]=='-') return hasScalarMultiply(expr[1..$], rank); |
|---|
| | 238 | |
|---|
| | 239 | int x = exprLength(expr); |
|---|
| | 240 | int y = x+1; |
|---|
| | 241 | assert(y < expr.length, "BLADE BUG:" ~ expr); |
|---|
| | 242 | // Deal with shifts, op=, and NCEG operators |
|---|
| | 243 | while (expr[y+1]=='<' || expr[y+1]=='>' || expr[y+1]=='=') ++y; |
|---|
| | 244 | |
|---|
| | 245 | char [] op = expr[x+1..y+1]; |
|---|
| | 246 | char [] left = expr[0..x+1]; |
|---|
| | 247 | char [] right = expr[y+1..$]; |
|---|
| | 248 | if (op=="[") { |
|---|
| | 249 | // (A)[C] can still have a multiply by scalar, if A contains a |
|---|
| | 250 | // multiply. |
|---|
| | 251 | if (left.length==1) return false; |
|---|
| | 252 | return hasScalarMultiply(left[1..$], rank); |
|---|
| | 253 | } |
|---|
| | 254 | if (op=="/") return true; |
|---|
| | 255 | if (op!="*" && op!="/") { |
|---|
| | 256 | if (left.length==1 || right.length==1) return false; |
|---|
| | 257 | // (A+B) could contain a multiply by scalar, if both A and B |
|---|
| | 258 | // contain multiplies. |
|---|
| | 259 | return hasScalarMultiply(left[1..$-1], rank) && hasScalarMultiply(right[1..$-1], rank); |
|---|
| | 260 | } |
|---|
| | 261 | // it's not true for matrix*matrix multiplies. |
|---|
| | 262 | if (subexprRank(left, rank)==0) return true; |
|---|
| | 263 | return subexprRank(right, rank) == 0; |
|---|
| | 264 | } |
|---|
| | 265 | |
|---|
| | 266 | unittest { |
|---|
| | 267 | assert(hasScalarMultiply("(A*B)+(B*C)","101")); |
|---|
| | 268 | assert(!hasScalarMultiply("(A*B)-(C*C)","101")); |
|---|
| | 269 | assert(!hasScalarMultiply("A+(B*C)","101")); |
|---|
| | 270 | assert(hasScalarMultiply("(A/B)-((A*B)+(C*B))","101")); |
|---|
| | 271 | assert(!hasScalarMultiply("A[B]","20")); |
|---|
| | 272 | assert(!hasScalarMultiply("(C[B])[B..A]","002") ); |
|---|
| | 273 | } |
|---|