Changeset 158

Show
Ignore:
Timestamp:
12/07/07 14:51:35 (9 months ago)
Author:
Don Clugston
Message:

BLADE is now object-oriented! Now uses the Visitor pattern for parsing. ExpressionSimplify? still needs to be refactored to use the new scheme.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/blade/BladeDemo.d

    r157 r158  
    3333    real k=3.4; 
    3434 
    35     mixin(vectorize(` a += (d[2..$-1]*2.01*a[2]-another[][1])["abc".length-3..$]`)); 
     35//    mixin(vectorize(` a += (d[2..$-1]*2.01*a[2]-another[][1])["abc".length-3..$]`)); 
    3636     
    3737    mixin(vectorize(" a-= 2.01*(        3.04+k)*r")); 
     
    4242    mixin(vectorize("a+=6*another[1]")); 
    4343    mixin(vectorize("a+=6*another[1][]")); 
    44      
     44    
    4545    mixin(vectorize("another[0..$,1]+=6*a[0..2]")); 
    4646    mixin(vectorize("r-=another[0]")); 
  • trunk/blade/BladeRank.d

    r157 r158  
    77 
    88module blade.BladeRank; 
    9  
    10 // -------------- 
    11 // Ranklist functions 
    12  
    13 // Count the number of vectors 
    14 int countVectors(char[] ranklist) 
    15 
    16     int numVecs=0; 
    17     for (int i=0; i<ranklist.length; ++i) { 
    18         if (ranklist[i]=='1') ++numVecs; 
    19     } 
    20     return numVecs; 
    21 
    22  
    23 int vectorNum(char [] ranklist, char var) 
    24 
    25     int numVecs=0; 
    26     for (int i=0; i<var-'A'; ++i) { 
    27         if (ranklist[i]=='1') ++numVecs; 
    28     } 
    29     return numVecs; 
    30 
    31  
    32 int scalarNum(char [] ranklist, char var) 
    33 
    34     int k=0; 
    35     for (int i=0; i<var-'A'; ++i) { 
    36         if (ranklist[i]=='0') ++k; 
    37     } 
    38     return k; 
    39 
    40  
    41 int realScalarNum(char [][] typelist, char [] ranklist, char var) 
    42 
    43     int k=0; 
    44     for (int i=0; i<var-'A'; ++i) { 
    45         if (ranklist[i]=='0' && typelist[i]=="real") ++k; 
    46     } 
    47     return k; 
    48 
    49  
    50 /** Return the length of a sub-expression 
    51  * The sub-expression must be  
    52  *  - a single character (eg "X"), OR 
    53  *  - a lower-case function (eg "a(B,(C*D))"), OR 
    54  *  - an expression in parenthesis, OR 
    55  *  - an array literal  
    56  */ 
    57 int exprLength(char [] s) 
    58 
    59     if ((s[0]>='A' && s[0]<='Z') || s[0]=='$') return 0; 
    60     int i = 0;    
    61     if (s[0]>='a' && s[0]<='z'){ // function call 
    62          i=1; // next char is a parenthesis - so the code 
    63             // below works 
    64     } 
    65     int numParens = 0; 
    66     int numBrack = 0; 
    67     for (; i<s.length; ++i) { 
    68         if (s[i]=='(')  ++numParens; 
    69         if (s[i]==')') numParens--;        
    70         if (s[i]=='[') ++numBrack; 
    71         if (s[i]==']') --numBrack; 
    72         if (numParens == 0 && numBrack == 0) { 
    73             return i; 
    74         } 
    75     } 
    76     assert(0, "BLADE ICE: " ~ s);    
    77 
    78  
    79 /** Determine the (tensor) rank of a sub-expression 
    80 * The sub-expression must be a single character, or an expression in 
    81 * parentheses. 
    82 */ 
    83 int subexprRank(char [] expr, char [] rank) 
    84 
    85     if (expr.length==1) { 
    86         if (expr=="$") return 0; 
    87         assert(expr[0]>='A' && expr[0]<='Z', "BLADE ICE: " ~ expr); 
    88         return rank[expr[0]-'A']-'0'; 
    89     } 
    90     if (expr[0]=='d') return 0; 
    91     assert(expr[0]=='(', "BLADE ICE:" ~ expr); 
    92     // strip off the parentheses 
    93     return exprRank(expr[1..$-1], rank); 
    94 
    95  
    96 enum RankError : int { 
    97     UnsupportedOperation = -1, 
    98     RankIncrement = -2,  
    99     AttemptToIndexAScalar = -3, 
    100     NonScalarIndex = -4, 
    101     NonScalarSlice = -5, 
    102     DotDotExpected = -6, 
    103     CommaExpected = -7,  
    104     RankMismatch = -8, 
    105     RankMismatchConcatenation = -9, 
    106     RankMismatchDotProduct = -10, 
    107     ExtraCharsAfterArrayLiteral = -11, 
    108     ArrayLiteralRankMismatch = -12 
    109 
    110  
    111 char [] getRankErrorText(int err) 
    112 
    113     return ["Unsupported vector operation", 
    114             "Can only use ++ and -- on scalars", 
    115             "Cannot index a scalar", 
    116             "Vector can only be indexed by a scalar", 
    117             "Vector can only be sliced by a scalar", 
    118             ".. expected", 
    119             ", expected", 
    120             "Dimensionality mismatch (addition or subtraction)", 
    121             "Dimensionality mismatch in concatenation", 
    122             "Dimenionality error in dot product" 
    123             "Extra characters after array literal" 
    124             "Rank mismatch in array literal" 
    125             ][-err-1]; 
    126 
    127  
    128 /** Returns the (tensor) rank of the expression expr. 
    129  * A negative number will be returned if an error is detected. 
    130  * 
    131  * Params: 
    132  * expr   Placeholder expression (A,B,... correspond to tuple[0],[1],...) 
    133  * rank   The rank of each tuple member A, B, C, ... 
    134  */ 
    135 int exprRank(char [] expr, char [] rank) 
    136 
    137     // BUG: also need to deal with comma, ?:, &&, ||, is, !is, in,  
    138     // unary &, unary ! 
    139      
    140     if (expr.length>3 && expr[0..2]=="d(") { // dot product     
    141             int x = exprLength(expr[2..$-1]); 
    142             if (expr[x+3]!=',') return RankError.CommaExpected; 
    143             int lrank = subexprRank(expr[2..x+3], rank); 
    144             if (lrank<0) return lrank; // propagate errors 
    145             int rrank = subexprRank(expr[x+4..$-1], rank); 
    146             if (rrank<0) return rrank; // propagate errors 
    147             if (lrank!=1 || rrank!=1) return RankError.RankMismatchDotProduct;        
    148             return 0; 
    149     }         
    150          
    151     // Deal with ++ and --. 
    152     if (expr.length>2 && (expr[0..2]=="++" || expr[0..2]=="--")) { 
    153         int r = subexprRank(expr[2..$], rank); 
    154         if (r!=0) return RankError.RankIncrement; 
    155         return r; 
    156     } 
    157     if (expr.length>2 && (expr[$-2..$]=="++" || expr[$-2..$]=="--")) { 
    158         int r = subexprRank(expr[0..$-2], rank); 
    159         if (r!=0) return RankError.RankIncrement; 
    160         return r; 
    161     } 
    162     // Deal with unary operators 
    163     if (expr[0]=='+' || expr[0]=='-') return subexprRank(expr[1..$], rank); 
    164      
    165     int x = exprLength(expr); 
    166     if (expr[0]=='[') { // array literal 
    167         if (x!=expr.length-1) return RankError.ExtraCharsAfterArrayLiteral; 
    168         expr = expr[1..$-1]; 
    169         x = exprLength(expr); 
    170         int lrank = subexprRank(expr[0..x+1], rank); 
    171         while (x<expr.length-1) { 
    172             if (expr[x+1]!=',') return RankError.CommaExpected; 
    173             expr = expr[x+2.. $]; 
    174             x = exprLength(expr); 
    175             int rrank = subexprRank(expr[0..x+1], rank); 
    176             if (lrank!=rrank) return RankError.ArrayLiteralRankMismatch; 
    177         } 
    178         return lrank+1; 
    179     } 
    180     int y = x+1; 
    181     // Deal with shifts, op=, and NCEG operators 
    182     while (expr[y+1]=='<' || expr[y+1]=='>' || expr[y+1]=='=') ++y;     
    183  
    184     char [] op = expr[x+1..y+1]; 
    185     char [] left = expr[0..x+1]; 
    186     char [] right = expr[y+1..$]; 
    187     if (expr[x+1]=='[') right = expr[y+1..$-1]; // drop off the ']'. 
    188     int lrank = subexprRank(left, rank); 
    189     if (lrank<0) return lrank; // propagate errors 
    190     if (op=="[") { 
    191         if (lrank==0) return RankError.AttemptToIndexAScalar; 
    192         if (right.length==0) { 
    193             return lrank; // was [], which doesn't change the length 
    194         } 
    195         int z = exprLength(right); 
    196         if (z+1 == right.length) { 
    197             // indexing  --  reduces the rank by 1. 
    198             int rrank = subexprRank(right, rank); 
    199             if (rrank!=0) return RankError.NonScalarIndex; 
    200             return lrank - 1; 
    201         } else { 
    202             int totrank = lrank; 
    203             do { 
    204                 int rrank = subexprRank(right[0..z+1], rank); 
    205                 if (z==right.length-1 || right[z+1]==',') { 
    206                     // allow rank of 1 to be a slice operation 
    207                     // (so A[1,[2,$-1], $] is possible). 
    208                     if (rrank<0) return rrank; 
    209                     if (rrank>1) return RankError.NonScalarIndex; 
    210                     if (rrank==0) --totrank; 
    211                     if (z==right.length-1) return totrank; 
    212                 } else if (!(z+3 < right.length && right[z+1..z+3]=="..")) { 
    213                         return RankError.DotDotExpected; 
    214                 } else {// slice 
    215                     char [] start = right[0..z+1];  
    216                     char [] end = right[z+3..$]; 
    217                     int startrank = subexprRank(start, rank); 
    218                         if (startrank<0) return startrank; 
    219                     z = exprLength(end); 
    220                     int endrank = subexprRank(end[0..z+1], rank); 
    221                         if (endrank<0) return endrank; 
    222                     if (startrank!=0 || endrank!=0) return RankError.NonScalarSlice; 
    223                     right = end; 
    224                 } 
    225                 if (z==right.length-1) return totrank; 
    226                 right = right[z+2..$]; 
    227                 z = exprLength(right); 
    228                 //assert(0, right[0..z+1]); 
    229             }while (true); 
    230         } 
    231     } 
    232     int rrank = subexprRank(right, rank); 
    233     if (rrank<0) return rrank; // propagate errors 
    234     if (op=="+" || op=="-" || op=="=" || op=="+=" || op=="-=") { 
    235         if (lrank!=rrank) { 
    236             return RankError.RankMismatch; 
    237         } 
    238         return lrank; 
    239     } 
    240     if (op=="~") { // concatentating scalars and vectors, or vectors and matrices, is permitted 
    241         if (lrank==rrank || lrank==(rrank+1) || rrank==(lrank+1))  
    242             return (lrank>rrank)? lrank: rrank; 
    243         else return RankError.RankMismatchConcatenation; 
    244     } 
    245     if (op=="~=") { // can do vector~=scalar, but not scalar~=vector. 
    246         if (lrank==rrank || lrank==(rrank+1))  return lrank; 
    247         else return RankError.RankMismatchConcatenation; 
    248     } 
    249     // For *, /, only scalar operations are permitted 
    250     if ((op=="*=" || op=="/=") && rrank==0) return lrank; 
    251     if (op=="*" || op=="/") { 
    252         if (lrank==0) return rrank; 
    253         if (rrank==0) return lrank; 
    254     } 
    255     // All other operations are only valid for scalars. 
    256     if (lrank==0 && rrank==0) return 0; 
    257     return RankError.UnsupportedOperation; 
    258 
    259  
    260 unittest {     
    261     assert(exprRank("A+((((++B)+D)--)*C)", "1010")==1); 
    262     assert(exprRank("A+(B*C)", "000")==0); 
    263     assert(exprRank("A=(B*C)", "202")==2); 
    264     assert(exprRank("B*=(C*A)", "010")==1); 
    265     assert(exprRank("(A[])+B", "22")==2); 
    266     assert(exprRank("D+=((A+C)*B)", "2022")==2); 
    267     assert(exprRank("D+=((A&C)*B)", "0101")==1); 
    268     assert(exprRank("A+=(A[B..C])", "300")==3); 
    269     assert(exprRank("C+=(A[B])", "302")==2); 
    270     assert(exprRank("C~=(((A[B])[B])~C)", "302")==2); 
    271     assert(exprRank("((D[E])[E])+(-((C[B])[B..E]))", "202300")==1); 
    272     assert(subexprRank("((A[B..C])[C])", "300")==2); 
    273     assert(exprRank("d(A)", "1")==RankError.CommaExpected); 
    274     assert(exprRank("d(A,B)", "10")==RankError.RankMismatchDotProduct); 
    275     assert(exprRank("d(B,(A*(d(B,B))))", "11")==0);     
    276     assert(exprRank("A[B,B,B]", "60")==3); 
    277     assert(exprRank("A[B,B,C,B]", "600")==2); 
    278     assert(exprRank("A[B,([B,C]),B]", "600")==4); 
    279     assert(exprRank("A[B,(([B,C])[B]),B]", "600")==3); 
    280     assert(exprRank("A+=(B[C..$])", "110")==1); 
    281     assert(exprRank("A+=(B[C,D..$])", "2300")==2); 
    282 
    283  
    284  
    285 // Return true if the entire expression contains a multiplication by a scalar 
    286 bool hasScalarMultiply(char [] expr, char [] rank) 
    287 
    288     if (expr.length>2 && (expr[0..2]=="++" || expr[0..2]=="--" || expr[$-2..$]=="++" || expr[$-2..$]=="--")) { 
    289         return false; 
    290     } 
    291     if (expr[0]=='+' || expr[0]=='-') return hasScalarMultiply(expr[1..$], rank); 
    292      
    293     int x = exprLength(expr); 
    294     int y = x+1; 
    295     assert(y < expr.length, "BLADE BUG:" ~ expr); 
    296     // Deal with shifts, op=, and NCEG operators 
    297     while (expr[y+1]=='<' || expr[y+1]=='>' || expr[y+1]=='=') ++y;     
    298  
    299     char [] op = expr[x+1..y+1];     
    300     char [] left = expr[0..x+1]; 
    301     char [] right = expr[y+1..$]; 
    302     if (op=="[") { 
    303         // (A)[C] can still have a multiply by scalar, if A contains a 
    304         // multiply. 
    305         if (left.length==1) return false; 
    306         return hasScalarMultiply(left[1..$], rank); 
    307     } 
    308     if (op=="/") return true; 
    309     if (op!="*" && op!="/") { 
    310         if (left.length==1 || right.length==1) return false;  
    311         // (A+B) could contain a multiply by scalar, if both A and B 
    312         // contain multiplies. 
    313         return hasScalarMultiply(left[1..$-1], rank) && hasScalarMultiply(right[1..$-1], rank); 
    314     } 
    315     // it's not true for matrix*matrix multiplies. 
    316     if (subexprRank(left, rank)==0) return true; 
    317     return subexprRank(right, rank) == 0; 
    318 
    319  
    320 unittest { 
    321     assert(hasScalarMultiply("(A*B)+(B*C)","101")); 
    322     assert(!hasScalarMultiply("(A*B)-(C*C)","101")); 
    323     assert(!hasScalarMultiply("A+(B*C)","101")); 
    324     assert(hasScalarMultiply("(A/B)-((A*B)+(C*B))","101")); 
    325     assert(!hasScalarMultiply("A[B]","20")); 
    326     assert(!hasScalarMultiply("(C[B])[B..A]","002") ); 
    327 
     9private import blade.BladeVisitor; 
    32810 
    32911public: 
     
    37860    assert(!isStrided("A[7][B[[1,3],2]..6]")); 
    37961} 
     62 
     63public: 
     64/** Returns the (tensor) rank of the expression expr. 
     65 * A negative number will be returned if an error is detected. 
     66 * 
     67 * Params: 
     68 * expr   Placeholder expression (A,B,... correspond to tuple[0],[1],...) 
     69 * rank   The rank of each tuple member A, B, C, ... 
     70 */ 
     71int exprRank(char [] expr, char [] ranks) 
     72{ 
     73    return beginVisit(RankVisitor(ranks), expr); 
     74} 
     75 
     76int subexprRank(char [] expr, char [] ranks) 
     77{ 
     78    return doVisit(RankVisitor(ranks), expr); 
     79} 
     80 
     81 
     82enum RankError : int { 
     83    UnsupportedOperation = -1, 
     84    RankIncrement = -2,  
     85    AttemptToIndexAScalar = -3, 
     86    NonScalarIndex = -4, 
     87    NonScalarSlice = -5, 
     88    DotDotExpected = -6, 
     89    CommaExpected = -7,  
     90    RankMismatch = -8, 
     91    RankMismatchConcatenation = -9, 
     92    RankMismatchDotProduct = -10, 
     93    ExtraCharsAfterArrayLiteral = -11, 
     94    ArrayLiteralRankMismatch = -12 
     95} 
     96 
     97char [] getRankErrorText(int err) 
     98{ 
     99    return ["Unsupported vector operation", 
     100            "Can only use ++ and -- on scalars", 
     101            "Cannot index a scalar", 
     102            "Vector can only be indexed by a scalar", 
     103            "Vector can only be sliced by a scalar", 
     104            ".. expected", 
     105            ", expected", 
     106            "Dimensionality mismatch (addition or subtraction)", 
     107            "Dimensionality mismatch in concatenation", 
     108            "Dimenionality error in dot product" 
     109            "Extra characters after array literal" 
     110            "Rank mismatch in array literal" 
     111            ][-err-1]; 
     112} 
     113 
     114struct RankVisitor { 
     115    alias typeof(*this) This; 
     116    alias int ReturnType; 
     117    char [] rank; 
     118static: 
     119    ReturnType onVisitSymbol(This this_, char sym) { 
     120        if (sym=='$') return 0; 
     121        return this_.rank[sym-'A']-'0'; 
     122    } 
     123    ReturnType onVisitFunction(This this_, char [] func, char [][] args) { 
     124        if (func=="d") { // dot product 
     125            if (args.length!=2) return RankError.CommaExpected; 
     126            auto lrank = doVisit(this_,args[0]); 
     127            if (lrank<0) return lrank; // propagate errors 
     128            auto rrank = doVisit(this_, args[1]); 
     129            if (rrank<0) return rrank; // propagate errors 
     130            if (lrank!=1 || rrank!=1) return RankError.RankMismatchDotProduct;        
     131            return 0; 
     132        } 
     133        assert(0, "BLADE ICE: Unsupported function"); 
     134        return 0; 
     135    } 
     136    ReturnType onVisitPrefix(This this_, char [] op, char [] expr) { 
     137        if (op=="+" || op=="-") return doVisit(this_, expr); 
     138        auto r = doVisit(this_, expr); 
     139        if (r<=0) return r; 
     140        return RankError.RankIncrement; 
     141    } 
     142    ReturnType onVisitPostfix(This this_, char [] op, char [] expr) { 
     143        auto r = doVisit(this_, expr); 
     144        if (r<=0) return r; 
     145        return RankError.RankIncrement; 
     146    } 
     147    // Includes multi-dimensional slicing and indexing.     
     148    ReturnType onVisitIndex(This this_, char [] base, char [][] startSlice, char [][] endSlice) { 
     149        int totrank = doVisit(this_, base); 
     150        for(int i=0; i<endSlice.length; ++i) { 
     151            int r = doVisit(this_,startSlice[i]); 
     152            if (r!=0) return (r<0)? r :RankError.NonScalarIndex;  
     153            if (endSlice[i]==""){ 
     154                --totrank; 
     155            } else { 
     156                r = doVisit(this_,endSlice[i]); 
     157                if (r!=0) return (r<0)?r:RankError.NonScalarSlice;  
     158            } 
     159        } 
     160        return totrank; 
     161    } 
     162    ReturnType onVisitBinaryOp(This this_, char [] op, char [] left, char [] right) { 
     163        int lrank = doVisit(this_, left); 
     164        int rrank = doVisit(this_, right); 
     165        if (rrank<0) return rrank; // propagate errors 
     166        if (op=="+" || op=="-" || op=="=" || op=="+=" || op=="-=") { 
     167            if (lrank!=rrank) { 
     168                return RankError.RankMismatch; 
     169            } 
     170            return lrank; 
     171        } 
     172        if (op=="~") { // concatentating scalars and vectors, or vectors and matrices, is permitted 
     173            if (lrank==rrank || lrank==(rrank+1) || rrank==(lrank+1))  
     174                return (lrank>rrank)? lrank: rrank; 
     175            else return RankError.RankMismatchConcatenation; 
     176        } 
     177        if (op=="~=") { // can do vector~=scalar, but not scalar~=vector. 
     178            if (lrank==rrank || lrank==(rrank+1))  return lrank; 
     179            else return RankError.RankMismatchConcatenation; 
     180        } 
     181        // For *, /, only scalar operations are permitted 
     182        if ((op=="*=" || op=="/=") && rrank==0) return lrank; 
     183        if (op=="*" || op=="/") { 
     184            if (lrank==0) return rrank; 
     185            if (rrank==0) return lrank; 
     186        } 
     187        // All other operations are only valid for scalars. 
     188        if (lrank==0 && rrank==0) return 0; 
     189        return RankError.UnsupportedOperation; 
     190 
     191    } 
     192} 
     193 
     194unittest {     
     195    assert(exprRank("(A[B..C])[C]", "300")==2); 
     196    assert(exprRank("A+=(A[B..C])", "300")==3); 
     197     
     198    assert(exprRank("A+(B*C)", "000")==0);     
     199    assert(exprRank("A=(B*C)", "202")==2); 
     200    assert(exprRank("B*=(C*A)", "010")==1); 
     201    assert(exprRank("(A[])+B", "22")==2); 
     202    assert(exprRank("D+=((A+C)*B)", "2022")==2); 
     203    assert(exprRank("D+=((A&C)*B)", "0101")==1); 
     204     
     205    assert(exprRank("C~=(((A[B])[B])~C)", "302")==2); 
     206    assert(exprRank("((D[E])[E])+(-((C[B])[B..E]))", "202300")==1); 
     207 
     208    assert(exprRank("A+((((++B)+D)--)*C)", "1010")==1); 
     209     
     210    assert(exprRank("C+=(A[B])", "302")==2); 
     211    assert(exprRank("d(A)", "1")==RankError.CommaExpected); 
     212    assert(exprRank("d(A,B)", "10")==RankError.RankMismatchDotProduct); 
     213     
     214    assert(exprRank("d(B,(A*(d(B,B))))", "11")==0);     
     215     
     216    assert(exprRank("A[B,B,B]", "60")==3); 
     217    assert(exprRank("A[B,B,C,B]", "600")==2); 
     218    assert(exprRank("A+=(B[C..$])", "110")==1); 
     219    assert(exprRank("A+=(B[C,D..$])", "2300")==2); 
     220     
     221    // bug fixes: 
     222    assert(exprRank("(A[B..$,C])+=D", "2001")==1); 
     223 
     224    //NO LONGER SUPPORTED 
     225    //    assert(exprRank("A[B,([B,C]),B]", "600")==4); 
     226//    assert(exprRank("A[B,(([B,C])[B]),B]", "600")==3); 
     227 
     228} 
     229 
     230 
     231// Return true if the entire expression contains a multiplication by a scalar 
     232bool hasScalarMultiply(char [] expr, char [] rank) 
     233{ 
     234    if (expr.length>2 && (expr[0..2]=="++" || expr[0..2]=="--" || expr[$-2..$]=="++" || expr[$-2..$]=="--")) { 
     235        return false; 
     236    } 
     237    if (expr[0]=='+' || expr[0]=='-') return hasScalarMultiply(expr[1..$], rank); 
     238     
     239    int x = exprLength(expr); 
     240    int y = x+1; 
     241    assert(y < expr.length, "BLADE BUG:" ~ expr); 
     242    // Deal with shifts, op=, and NCEG operators 
     243    while (expr[y+1]=='<' || expr[y+1]=='>' || expr[y+1]=='=') ++y;     
     244 
     245    char [] op = expr[x+1..y+1];     
     246    char [] left = expr[0..x+1]; 
     247    char [] right = expr[y+1..$]; 
     248    if (op=="[") { 
     249        // (A)[C] can still have a multiply by scalar, if A contains a 
     250        // multiply. 
     251        if (left.length==1) return false; 
     252        return hasScalarMultiply(left[1..$], rank); 
     253    } 
     254    if (op=="/") return true; 
     255    if (op!="*" && op!="/") { 
     256        if (left.length==1 || right.length==1) return false;  
     257        // (A+B) could contain a multiply by scalar, if both A and B 
     258        // contain multiplies. 
     259        return hasScalarMultiply(left[1..$-1], rank) && hasScalarMultiply(right[1..$-1], rank); 
     260    } 
     261    // it's not true for matrix*matrix multiplies. 
     262    if (subexprRank(left, rank)==0) return true; 
     263    return subexprRank(right, rank) == 0; 
     264} 
     265 
     266unittest { 
     267    assert(hasScalarMultiply("(A*B)+(B*C)","101")); 
     268    assert(!hasScalarMultiply("(A*B)-(C*C)","101")); 
     269    assert(!hasScalarMultiply("A+(B*C)","101")); 
     270    assert(hasScalarMultiply("(A/B)-((A*B)+(C*B))","101")); 
     271    assert(!hasScalarMultiply("A[B]","20")); 
     272    assert(!hasScalarMultiply("(C[B])[B..A]","002") ); 
     273} 
  • trunk/blade/BladeSimplify.d

    r157 r158  
    2929 
    3030public import blade.SyntaxTree : AbstractSyntaxTree, Symbol; 
    31 private import blade.BladeRank : exprLength, exprRank, subexprRank,  
    32         hasScalarMultiply, getRankErrorText, isStrided; 
     31//private import blade.BladeVisitor; 
     32private import blade.BladeRank : exprLength, exprRank, subexprRank, 
     33        hasScalarMultiply, getRankErrorText, isStrided;  
    3334 
    3435 
     
    6364    else { 
    6465        char [] expr2 = removeDuplicates(tree); 
    65         // Check for rank errors 
     66        // Check for rank errors         
    6667        int wholerank = exprRank(expr2, ranks); 
    6768        if (wholerank<0) 
     
    108109 
    109110unittest { 
    110  
    111111    AbstractSyntaxTree t = AbstractSyntaxTree("A+(B*C)", [Symbol("int", "125", 0), 
    112112    Symbol("int", "7", 0), Symbol("int", "125", 0)]); 
    113113    assert(removeDuplicates(t)=="A+(B*A)"); 
    114      
    115114} 
    116115 
     
    187186    return rightMul; 
    188187} 
     188 
    189189 
    190190// Simplify the expression, assuming global scalar multiply has already been removed. 
     
    419419    assert(e.mapping=="ACE"); 
    420420} 
     421 
     422 
     423 
  • trunk/blade/CodegenX86.d

    r155 r158  
    4545module blade.CodegenX86; 
    4646private import blade.BladeUtil; 
    47 private import blade.BladeRank
     47private import blade.PostfixX86
    4848 
    4949private: 
    50 // ------------------------------------------------ 
    51 //   Convert infix string to postfix 
    52 // ------------------------------------------------ 
    53  
    54 /// Converts an infix string into postfix. 
    55 /// Apply x87-specific optimisations during the conversion. 
    56 char [] makePostfixForX87(char [] operations, char [][] typelist, char[] ranklist) 
    57 { 
    58     if (operations.length==1) return operations; 
    59     int x = exprLength(operations); 
    60      
    61     char [] op = operations[x+1..x+2];     
    62     char [] first = operations[0..x+1]; 
    63     char [] second = operations[x+2..$]; 
    64     if (operations[x+2]=='=') { // +=, -=, *=, /= 
    65         // Convert "A+=B" into "A=A+B" 
    66         second = makePostfixForX87(operations[0..x+2] ~ operations[x+3..$], typelist, ranklist); 
    67         return second ~ first ~ "="; 
    68     } 
    69     char [] oprvs = op; 
    70     if (op=="-") oprvs="_";  // We use _ to mean reversed subtraction. 
    71      
    72     if (first[0]=='(') { 
    73         first = makePostfixForX87(first[1..first.length-1], typelist, ranklist); 
    74     }else assert(first.length<2, "Missing () in expression: " ~ first); 
    75     if (second[0]=='(') { 
    76         second = makePostfixForX87(second[1..second.length-1], typelist, ranklist); 
    77     }else assert(second.length<2, "Missing () in expression: " ~ second); 
    78     if (op=="=") { 
    79         return second ~ first ~ "="; 
    80     } 
    81  
    82     // x87 OPTIMISATION #1 
    83     // On x87, fmul has a long latency, so we want to delay using the 
    84     // result of a multiply. Since + is commutative, we can achieve this 
    85     // by calculating the value with the multiply, before the other one. 
    86     // We can also do the same thing with -, but we'll need to use fsubr 
    87     // instead of fsub. We use _ to mean reversed subtraction. 
    88     if (op=="+" || op=="-") { 
    89         if (second[second.length-1]=='*'&& first[first.length-1]!='*') { 
    90            return second ~ first ~ oprvs; 
    91         } 
    92         // x87 OPTIMISATION #2 
    93         // When an operation is performed between a real[] and a non-real[], 
    94         // we want to have the real[] being the one which is loaded first. 
    95         if (second.length==1 && typelist[second[0]-'A']=="real" && ranklist[second[0]-'A']=='1') { 
    96                return second ~ first ~ oprvs; 
    97         } 
    98     } 
    99     return first ~ second ~ op; 
    100 } 
    101  
    102  
    103 unittest { 
    104 assert(makePostfixForX87("A=B", ["double", "double"],"11")=="BA="); 
    105 assert(makePostfixForX87("(B*C)+A", ["double", "float", "float"],"111")=="BC*A+"); 
    106 assert(makePostfixForX87("(B*C)+A", ["real", "float", "float"],"111")=="ABC*+"); 
    107 assert(makePostfixForX87("A-(B*C)", ["double", "float", "float"],"100")=="BC*A_"); 
    108 assert(makePostfixForX87("(B*C)-A", ["float", "float", "float"],"100")=="BC*A-"); 
    109 assert(makePostfixForX87("(B*C)-A", ["real", "float", "float"],"100")=="ABC*_"); 
    110 assert(makePostfixForX87("C+=((B*C)-A)", ["real", "float", "float"],"101")=="CABC*_+C="); 
    111 assert(makePostfixForX87("C-=((B*C)-A)", ["real", "float", "float"],"101")=="CABC*_-C="); 
    112 assert(makePostfixForX87("C-=(B*A)", ["real", "float", "float"],"101") =="BA*C_C="); 
    113 assert(makePostfixForX87("C-=(B*A)", ["real", "float", "real"],"101") =="BA*C_C="); 
    114 assert(makePostfixForX87("((A*B)+(C*D))+(E*F)", ["int", "int", "int"],"000")=="EF*AB*CD*++"); 
    115  
    116 } 
    11750 
    11851// num chars before we get a comma. 
     
    12558} 
    12659 
    127 /// Converts an infix string into postfix. 
    128 /// Apply SSE/SSE2-specific optimisations during the conversion. 
    129 char [] makePostfixForSSE(char [] operations, char [] ranklist) 
    130 
    131     if (operations.length==1) return operations; 
    132     int x = exprLength(operations); 
    133      
    134     char [] op = operations[x+1..x+2];     
    135     char [] first = operations[0..x+1]; 
    136     char [] second = operations[x+2..$]; 
    137     if (operations[x+2]=='=') { // +=, -=, *=, /= 
    138         // Convert "A+=B" into "A=A+B" 
    139         second = makePostfixForSSE(operations[0..x+2] ~ operations[x+3..$], ranklist); 
    140         return second ~ first ~ "="; 
    141     } 
    142      
    143     if (first[0]=='(') { 
    144         first = makePostfixForSSE(first[1..first.length-1], ranklist); 
    145     } else assert(first.length<2, "Missing () in expression: " ~ first); 
    146     if (second[0]=='(') { 
    147         second = makePostfixForSSE(second[1..second.length-1], ranklist); 
    148     }else assert(second.length<2, "Missing () in expression: " ~ second); 
    149     if (op=="=") { 
    150         return second ~ first ~ "="; 
    151     } 
    152  
    153     // Multiplies have a long latency, so we want to delay using the 
    154     // result of a multiply. Since + is commutative, we can achieve this 
    155     // by calculating the value with the multiply, before the other one. 
    156     if (op=="+") { 
    157         if (second[second.length-1]=='*'&& first[first.length-1]!='*') { 
    158            return second ~ first ~ op; 
    159         } 
    160     } 
    161     if (op=="*") { 
    162         // SSE OPTIMISATION #2 
    163         // When an operation is performed between a vector and a scalar 
    164         // we want to have the vector being the one which is loaded first. 
    165         if (first.length==1 && ranklist[first[0]-'A']=='0') { 
    166                return second ~ first ~ op; 
    167         } 
    168     } 
    169  
    170     return first ~ second ~ op; 
    171 
    172  
     60// -------------- 
     61// Ranklist functions 
     62 
     63// Count the number of vectors 
     64int countVectors(char[] ranklist) 
     65
     66    int numVecs=0; 
     67    for (int i=0; i<ranklist.length; ++i) { 
     68        if (ranklist[i]=='1') ++numVecs; 
     69    } 
     70    return numVecs; 
     71
     72 
     73int vectorNum(char [] ranklist, char var) 
     74
     75    int numVecs=0; 
     76    for (int i=0; i<var-'A'; ++i) { 
     77        if (ranklist[i]=='1') ++numVecs; 
     78    } 
     79    return numVecs; 
     80
     81 
     82int scalarNum(char [] ranklist, char var) 
     83
     84    int k=0; 
     85    for (int i=0; i<var-'A'; ++i) { 
     86        if (ranklist[i]=='0') ++k; 
     87    } 
     88    return k; 
     89
     90 
     91int realScalarNum(char [][] typelist, char [] ranklist, char var) 
     92
     93    int k=0; 
     94    for (int i=0; i<var-'A'; ++i) { 
     95        if (ranklist[i]=='0' && typelist[i]=="real") ++k; 
     96    } 
     97    return k; 
     98
    17399private: 
    174  
    175 unittest { 
    176 assert(makePostfixForSSE("A=B", "11")=="BA="); 
    177 assert(makePostfixForSSE("(A*B)+C", "101")=="AB*C+"); 
    178 assert(makePostfixForSSE("A=(B*C)", "110")=="BC*A="); 
    179 } 
    180  
    181100// ------------------------------- 
    182101//   Mixins to generate x87 ASM code