| 1 |
// Written in the D programming language 1.0 |
|---|
| 2 |
// Part of BLADE : Basic Linear Algebra D Expressions |
|---|
| 3 |
/** |
|---|
| 4 |
* Determine the tensor rank of an expression. |
|---|
| 5 |
* |
|---|
| 6 |
*/ |
|---|
| 7 |
|
|---|
| 8 |
module blade.BladeRank; |
|---|
| 9 |
private import blade.BladeVisitor; |
|---|
| 10 |
|
|---|
| 11 |
public: |
|---|
| 12 |
// Returns true if the array expression s (a symbol character followed by one
// or more bracketed index/slice groups, e.g. "C[D..$,D]") has a non-zero
// stride. That is the case whenever a slice (including a complete slice) is
// followed by a later incomplete slice or index.
// Anything nested inside parentheses is a sub-expression and is skipped.
// s must end with ']'; its first character (the symbol) is not examined.
bool isStrided(char [] s)
{
    assert(s[$-1]==']', "BLADE ICE");
    if (s[$-2..$]=="[]") return false;

    int depth = 0;            // current [ ] nesting level outside parentheses
    int parenDepth = 0;       // current ( ) nesting level
    bool sliceSeen = false;   // a completed slice has been found
    bool pendingSlice = false;// the current top-level element contains ".."
    bool indexOpen = false;   // an index/slice element is in progress

    for (size_t pos = 1; pos < s.length; ++pos) {
        char c = s[pos];
        if (c == '(') ++parenDepth;
        else if (c == ')') --parenDepth;

        // Characters inside parentheses belong to a sub-expression.
        if (parenDepth != 0) continue;

        switch (c) {
        case ']':
            if (indexOpen && sliceSeen) return true;
            --depth;
            // "[]" is a complete slice: nothing remains open.
            if (s[pos-1] == '[') indexOpen = false;
            break;
        case '[':
            indexOpen = true;
            pendingSlice = false;
            ++depth;
            break;
        case ',':
            if (depth == 1) {
                if (sliceSeen && indexOpen) return true;
                if (pendingSlice) sliceSeen = true;
                indexOpen = true;
            } else if (depth == 2) {
                indexOpen = false;
                sliceSeen = true;
            }
            break;
        case '.':
            // The second dot of "..": a slice does not advance the index.
            if (depth == 1 && s[pos-1] == '.') {
                indexOpen = false;
                pendingSlice = true;
            }
            break;
        default:
            break;
        }
    }
    return false;
}
|---|
| 51 |
|
|---|
| 52 |
private: |
|---|
| 53 |
unittest {
    // Contiguous: trailing complete slices, later indices in a new bracket,
    // nested index lists after a scalar index, and parenthesised sub-expressions.
    assert(!isStrided("A[3..5][]"));
    assert(!isStrided("A[2..7][3]"));
    assert(!isStrided("A[][2]"));
    assert(!isStrided("A[2,[2,7]]"));
    assert(!isStrided("A[7][B[[1,3],2]..6]"));

    // Strided: a slice followed by a later index inside the same brackets.
    assert(isStrided("A[[2,7],3]"));
    assert(isStrided("C[D..$,D]"));
}
|---|
| 62 |
|
|---|
| 63 |
public: |
|---|
| 64 |
/**
 * Computes the (tensor) rank of a complete placeholder expression.
 * A negative value (a RankError code) is returned if an error is detected.
 * TODO: Should warn of expressions with no effect (i.e., with no =).
 *
 * Params:
 *  expr  = Placeholder expression (A,B,... correspond to tuple[0],[1],...)
 *  ranks = One digit per tuple member, giving the rank of A, B, C, ...
 */
int exprRank(char [] expr, char [] ranks)
{
    int result = beginVisit(RankVisitor(ranks), expr);
    return result;
}
|---|
| 76 |
|
|---|
| 77 |
// Same as exprRank, but visits expr through doVisit instead of beginVisit.
// NOTE(review): presumably intended for sub-expressions; confirm the
// beginVisit/doVisit distinction in BladeVisitor.
int subexprRank(char [] expr, char [] ranks)
{
    int result = doVisit(RankVisitor(ranks), expr);
    return result;
}
|---|
| 81 |
|
|---|
| 82 |
|
|---|
| 83 |
// Negative error codes that are returned in place of a rank.
// getRankErrorText(code) gives the human-readable message for each one.
enum RankError : int {
    UnsupportedOperation        = -1,  // operator not valid for these ranks
    RankIncrement               = -2,  // ++ / -- applied to a non-scalar
    AttemptToIndexAScalar       = -3,  // index/slice on a scalar
    NonScalarIndex              = -4,  // index expression is not a scalar
    NonScalarSlice              = -5,  // slice bound is not a scalar
    DotDotExpected              = -6,  // ".." missing
    CommaExpected               = -7,  // "," missing (e.g. dot() arity)
    RankMismatch                = -8,  // +, -, =, +=, -= with differing ranks
    RankMismatchConcatenation   = -9,  // ~ or ~= with incompatible ranks
    RankMismatchDotProduct      = -10, // dot() arguments are not both vectors
    ExtraCharsAfterArrayLiteral = -11,
    ArrayLiteralRankMismatch    = -12,
    AbsDimensionality           = -13  // abs/sqrt of rank greater than 1
}
|---|
| 98 |
|
|---|
| 99 |
// Returns the human-readable message for a RankError code.
// Message i corresponds to error code -(i+1), so the table order must match
// the RankError enum. Codes that are not in the table (including any
// non-negative value, which is a valid rank rather than an error) yield
// "Unknown rank error" instead of indexing out of bounds.
char [] getRankErrorText(int err)
{
    char [][] messages = ["Unsupported vector operation",
          "Can only use ++ and -- on scalars",
          "Cannot index a scalar",
          "Vector can only be indexed by a scalar",
          "Vector can only be sliced by a scalar",
          ".. expected",
          ", expected",
          "Dimensionality mismatch (addition or subtraction)",
          "Dimensionality mismatch in concatenation",
          "Dimensionality error in dot product",
          "Extra characters after array literal",
          "Rank mismatch in array literal",
          "Can only use abs, sqrt with scalar or vector"
          ];
    if (err >= 0) return "Unknown rank error";
    // -(err+1) cannot overflow for any negative int (unlike -err-1 for int.min).
    size_t idx = cast(size_t)(-(err + 1));
    if (idx >= messages.length) return "Unknown rank error";
    return messages[idx];
}
|---|
| 116 |
|
|---|
| 117 |
// Visitor passed to beginVisit/doVisit that computes the tensor rank of each
// sub-expression. Every callback returns a rank (0 = scalar, 1 = vector,
// 2 = matrix, ...) or a negative RankError code. A negative code from a
// sub-expression is propagated unchanged.
struct RankVisitor {
    alias typeof(*this) This;
    alias int ReturnType;
    // One ASCII digit per placeholder: rank[0] is the rank of A, rank[1] of B, ...
    char [] rank;
static:
    // Rank of a placeholder symbol. "$" (the array length) is a scalar.
    ReturnType onVisitSymbol(This this_, char [] sym) {
        if (sym=="$") return 0;
        return this_.rank[sym[0]-'A']-'0';
    }

    // Rank of a built-in function call:
    //   dot(x, y)       : both arguments must be vectors; result is a scalar.
    //   sum(x), prod(x) : result is a scalar.
    //   abs(x), sqrt(x) : argument must be a scalar or vector; result is 0.
    ReturnType onVisitFunction(This this_, char [] func, char [][] args) {
        switch(func) {
        case "dot":
            if (args.length!=2) return RankError.CommaExpected;
            auto lrank = doVisit(this_,args[0]);
            if (lrank<0) return lrank; // propagate errors
            auto rrank = doVisit(this_, args[1]);
            if (rrank<0) return rrank; // propagate errors
            if (lrank!=1 || rrank!=1) return RankError.RankMismatchDotProduct;
            return 0;
        case "sum":
        case "prod":
            auto lrank = doVisit(this_,args[0]);
            if (lrank<0) return lrank; // propagate errors
            return 0;
        case "abs":
        case "sqrt":
            auto lrank = doVisit(this_,args[0]);
            // An error in the argument must not be reported as a valid scalar.
            if (lrank<0) return lrank; // propagate errors
            if (lrank>1) return RankError.AbsDimensionality;
            return 0;
        default:
            assert(0, "BLADE ICE: Unsupported function:" ~ func);
            return 0;
        }
    }

    // Unary + and - preserve rank. Any other prefix operator (++, --) is
    // only valid on scalars.
    ReturnType onVisitPrefix(This this_, char [] op, char [] expr) {
        if (op=="+" || op=="-") return doVisit(this_, expr);
        auto r = doVisit(this_, expr);
        if (r<=0) return r;
        return RankError.RankIncrement;
    }

    // Postfix operators (++, --) are only valid on scalars.
    ReturnType onVisitPostfix(This this_, char [] op, char [] expr) {
        auto r = doVisit(this_, expr);
        if (r<=0) return r;
        return RankError.RankIncrement;
    }

    // Includes multi-dimensional slicing and indexing.
    // slices[i][0] is the index (or slice start); slices[i][1] is the slice
    // end, or "" for a plain index. Each plain index removes one dimension;
    // a slice keeps it. Both need at least one remaining dimension.
    ReturnType onVisitIndex(This this_, char [] base, char [][2][] slices) {
        int totrank = doVisit(this_, base);
        // Decrementing an error code would turn it into a different error.
        if (totrank<0) return totrank; // propagate errors
        for(int i=0; i<slices.length; ++i) {
            int r = doVisit(this_,slices[i][0]);
            if (r!=0) return (r<0)? r :RankError.NonScalarIndex;
            if (slices[i][1]==""){
                if (totrank==0) return RankError.AttemptToIndexAScalar;
                --totrank;
            } else {
                r = doVisit(this_,slices[i][1]);
                if (r!=0) return (r<0)?r:RankError.NonScalarSlice;
                if (totrank==0) return RankError.AttemptToIndexAScalar;
            }
        }
        return totrank;
    }

    // Rank of a binary operation, or a RankError if the operand ranks are
    // incompatible with the operator.
    ReturnType onVisitBinaryOp(This this_, char [] op, char [] left, char [] right) {
        int lrank = doVisit(this_, left);
        int rrank = doVisit(this_, right);
        if (rrank<0) return rrank; // propagate errors
        if (lrank<0) return lrank; // propagate errors
        // Addition, subtraction and (compound) assignment need identical ranks.
        if (op=="+" || op=="-" || op=="=" || op=="+=" || op=="-=") {
            if (lrank!=rrank) {
                return RankError.RankMismatch;
            }
            return lrank;
        }
        if (op=="~") { // concatenating scalars and vectors, or vectors and matrices, is permitted
            if (lrank==rrank || lrank==(rrank+1) || rrank==(lrank+1))
                return (lrank>rrank)? lrank: rrank;
            else return RankError.RankMismatchConcatenation;
        }
        if (op=="~=") { // can do vector~=scalar, but not scalar~=vector.
            if (lrank==rrank || lrank==(rrank+1)) return lrank;
            else return RankError.RankMismatchConcatenation;
        }
        // *= and /= accept a scalar right-hand side. Only *= also allows
        // matrix products; there is no matrix division.
        if ((op=="*=" || op=="/=") && rrank==0) return lrank;
        if (op=="*=" && lrank==2 && rrank==2) return lrank; // mat *= mat
        if (op=="*=" && lrank==1 && rrank==2) return lrank; // vec *= mat
        if (op=="*" || op=="/") {
            // A scalar operand on either side keeps the other operand's rank.
            if (lrank==0) return rrank;
            if (rrank==0) return lrank;
        }
        if (op=="*") {
            // Linear-algebra products are defined for * only, not for /.
            if (lrank==2 && rrank==2) return lrank; // mat*mat
            if (lrank+rrank==3) return 1; // vec*mat or mat*vec
        }
        // All other operations are only valid for scalars.
        if (lrank==0 && rrank==0) return 0;
        return RankError.UnsupportedOperation;
    }
}
|---|
| 213 |
|
|---|
| 214 |
unittest {
    // Indexing and slicing: each plain index drops one dimension, slices keep it.
    assert(exprRank("(A[B..C])[C]", "300")==2);
    assert(exprRank("A+=(A[B..C])", "300")==3);
    assert(exprRank("A[B,B,B]", "60")==3);
    assert(exprRank("A[B,B,C,B]", "600")==2);
    assert(exprRank("A+=(B[C..$])", "110")==1);
    assert(exprRank("A+=(B[C,D..$])", "2300")==2);
    assert(exprRank("C+=(A[B])", "302")==2);
    assert(exprRank("(A[])+B", "22")==2);

    // Arithmetic, assignment, scalar scaling and increment/decrement.
    assert(exprRank("A+(B*C)", "000")==0);
    assert(exprRank("A=(B*C)", "202")==2);
    assert(exprRank("B*=(C*A)", "010")==1);
    assert(exprRank("D+=((A+C)*B)", "2022")==2);
    assert(exprRank("D+=((A&C)*B)", "0101")==1);
    assert(exprRank("A+((((++B)+D)--)*C)", "1010")==1);

    // Concatenation and chained index/slice expressions.
    assert(exprRank("C~=(((A[B])[B])~C)", "302")==2);
    assert(exprRank("((D[E])[E])+(-((C[B])[B..E]))", "202300")==1);

    // Built-in functions.
    assert(exprRank("dot(A)", "1")==RankError.CommaExpected);
    assert(exprRank("dot(A,B)", "10")==RankError.RankMismatchDotProduct);
    assert(exprRank("dot(B,(A*(dot(B,B))))", "11")==0);
    assert(exprRank("prod(A*B)", "10")==0);

    // Regression (bug fix): slice followed by an index within one bracket.
    assert(exprRank("(A[B..$,C])+=D", "2001")==1);
}
|---|