Changeset 162

Show
Ignore:
Timestamp:
12/14/07 07:40:53 (9 months ago)
Author:
Don Clugston
Message:

Scalar folding now done by a visitor -- all deprecated code removed.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/blade/Blade.d

    r157 r162  
    512512    int lenvec = findVectorForLength(tree); 
    513513    char [] result = assertAllVectorLengthsEqual(tree); 
     514    result ~= "// " ~ tree.expression ~ \n; 
    514515    result ~= "for (int blade_index=0; blade_index<"  
    515516    ~ getDimensionLengthForSymbol(tree.mapping[lenvec], tree, 1) ~ 
  • trunk/blade/BladeDemo.d

    r161 r162  
    3333                            [5.14, 455, 554, 2.43]]; 
    3434    real k=3.4; 
    35  
     35     
    3636    mixin(vectorize(` a += (d[2..$-1]*2.01*a[2]-another[][1])["abc".length-3..$]`)); 
    3737    mixin(vectorize(" a-= 2.01*(        3.04+k)*r"));     
    38      
     38    
    3939    mixin(vectorize("q+= q*2.01")); 
    4040    
     
    4444    mixin(vectorize("a+=6*another[1]")); 
    4545    mixin(vectorize("a+=6*another[1][]")); 
    46     
    4746    mixin(vectorize("another[0..$,1]+=6*a[0..2]")); 
    4847    mixin(vectorize("r-=another[0]")); 
  • trunk/blade/BladeSimplify.d

    r161 r162  
    114114} 
    115115 
    116 // DEPRECATED 
    117 /// As for exprSimplify, but allows the whole thing to be wrapped in parentheses. 
    118 char [] subexprSimplify(char [] expr, char [] rank, char [] mulexpr, char [] indexexpr) 
    119 { 
    120     if (expr.length==1) { 
    121         int r = subexprRank(expr, rank); 
    122         char [] e = expr; 
    123         if(indexexpr.length>0) { 
    124             assert(r>0, "BLADE BUG: MISMATCHED INDEX " ~ expr ~ " " ~ indexexpr ~ " " ~ mulexpr); 
    125             e = " {" ~ expr ~ indexexpr ~ "} "; 
    126         } 
    127         if (mulexpr.length>1) { 
    128             // in this case, it's worth creating a new variable 
    129             return "( {" ~ mulexpr ~ "} *" ~ e ~ ")"; 
    130         } 
    131         if (mulexpr.length>0) return "(" ~ mulexpr ~ "*" ~ e ~ ")"; 
    132         return e; 
    133     } 
    134     // strip off the parentheses before simplifying 
    135     return exprSimplify(expr[1..$-1], rank, mulexpr, indexexpr); 
    136 } 
    137  
    138 // DEPRECATED 
    139 // Simplify a scalar*tensor expression. 
    140 char [] simplifyScalarMul(char [] scalar, char [] tensor, char [] mulexpr, char [] rank, char [] indexexpr) 
    141 { 
    142     char [] m = scalar; 
    143     if (mulexpr.length>0) m = "(" ~ mulexpr ~ "*" ~ scalar ~ ")"; 
    144     // BUG: It's also worth scalar folding A*(B*U[]-V[]) into (A*B)*U[]+(-A)*V[], 
    145     if (tensor.length > 1 && hasScalarMultiply(tensor[1..$-1], rank)) { 
    146         // opportunity for scalar folding 
    147         return subexprSimplify(tensor, rank, m, indexexpr); 
    148     } else { 
    149         if (m.length>1) m = " {" ~ m ~ "} "; 
    150         return "(" ~ m ~ "*" ~ subexprSimplify(tensor, rank, "", indexexpr) ~ ")"; 
    151     } 
    152 } 
    153  
    154 // DEPRECATED 
    155 char [] getCommonMultiplucation(char [] expr, char [] rank) 
    156 { 
    157     if (expr.length>1 && expr[0]=='(') expr = expr[1..$-1]; 
    158     else if (expr.length==1 && expr[0]>='A' && expr[0]<='Z'){ 
    159         return rank[expr[0]-'A']=='0'? expr : ""; 
    160     } else assert(0, "BLADE ICE: " ~ expr); 
    161     if (expr.length>2 && (expr[0..2]=="++" || expr[0..2]=="--" || expr[$-2..$]=="++" || expr[$-2..$]=="--")) { 
    162         return ""; 
    163     } 
    164     if (expr[0]=='+' || expr[0]=='-') return getCommonMultiplucation(expr[1..$], rank); 
    165     int x = exprLength(expr); 
    166     int y = x+1; 
    167     assert(y < expr.length, "BLADE ICE:" ~ expr); 
    168     // Deal with shifts, op=, and NCEG operators 
    169     while (expr[y+1]=='<' || expr[y+1]=='>' || expr[y+1]=='=') ++y; 
    170     char [] op = expr[x+1..y+1];     
    171     char [] left = expr[0..x+1]; 
    172     char [] right = expr[y+1..$]; 
    173     if (op=="[") { 
    174         // (A)[C] can still have a multiply by scalar, if A contains a 
    175         // multiply. 
    176         if (left.length==1) return ""; 
    177         return getCommonMultiplucation(left[1..$], rank); 
    178     } 
    179     if (op!="*") return "";     
    180     int leftrnk = subexprRank(left, rank); 
    181     int rightrnk = subexprRank(right, rank); 
    182     char [] leftMul = ""; 
    183     char [] rightMul = ""; 
    184     if (leftrnk == 0 && rightrnk == 0) { return expr; } 
    185     if (leftrnk == 0) leftMul = left; else leftMul = getCommonMultiplucation(left, rank); 
    186     if (rightrnk== 0) rightMul = right; else rightMul = getCommonMultiplucation(right, rank); 
    187     if (leftMul!="" && rightMul!="") return "(" ~ leftMul ~ "*" ~ rightMul ~ ")"; 
    188     if (leftMul!="") return leftMul; 
    189     return rightMul; 
    190 } 
    191  
    192 // DEPRECATED 
    193 // Simplify the expression, assuming global scalar multiply has already been removed. 
    194 char [] simplifyWithoutMul(char [] rawExpr, char [] rank) 
    195 { 
    196     char [] expr = rawExpr; 
    197     if (expr.length==1 && expr[0]>='A' && expr[0]<='Z') { 
    198         return rank[expr[0]-'A']=='0'? "":expr; 
    199     } else if (expr.length>1 && expr[0]=='(') expr = rawExpr[1..$-1]; 
    200  
    201     if (expr.length>2 && (expr[0..2]=="++" || expr[0..2]=="--" || expr[$-2..$]=="++" || expr[$-2..$]=="--")) { 
    202         return expr; 
    203     } 
    204     int x = exprLength(expr); 
    205     int y = x+1; 
    206     assert(y < expr.length, "BLADE ICE:" ~ expr); 
    207     // Deal with shifts, op=, and NCEG operators 
    208     while (expr[y+1]=='<' || expr[y+1]=='>' || expr[y+1]=='=') ++y; 
    209     char [] op = expr[x+1..y+1];     
    210     char [] left = expr[0..x+1]; 
    211     char [] right = expr[y+1..$]; 
    212 /*     
    213     if (op=="[") { 
    214         // (A)[C] can still have a multiply by scalar, if A contains a 
    215         // multiply. 
    216         if (left.length==1) return left; 
    217         return simplifyWithoutMul(left[1..$], rank); 
    218     } 
    219 */ 
    220     if (op!="*") return subexprSimplify(rawExpr, rank, "", ""); 
    221      
    222     int leftrnk = subexprRank(left, rank); 
    223     int rightrnk = subexprRank(right, rank); 
    224     char [] leftMul = ""; 
    225     char [] rightMul = ""; 
    226     if (leftrnk == 0 && rightrnk == 0) return ""; 
    227     if (leftrnk == 0) leftMul = ""; else leftMul = simplifyWithoutMul(left, rank); 
    228     if (rightrnk== 0) rightMul = ""; else rightMul = simplifyWithoutMul(right, rank); 
    229     if (leftMul!="" && rightMul!="") return "(" ~ leftMul ~ "*" ~ rightMul ~ ")"; 
    230     if (leftMul!="") return leftMul; 
    231     return rightMul; 
    232 } 
    233  
    234 // DEPRECATED 
    235 /** 
    236  * Rewrite the expression, taking advantage of distributivity of [] and 
    237  * associativity of *. Additionally, we group all scalars together, whenever 
    238  * possible. 
    239  * 
    240  * This process creates compound scalars and vectors, delineated by " {" and "} ". 
    241  * They will be removed in a subsequent step. 
    242  */ 
    243 char [] exprSimplify(char [] expr, char [] rank, char [] mulexpr, char [] indexexpr) 
    244 { 
    245     if (expr.length>3 && expr[0..2]=="d(") { // dot product 
    246         assert(indexexpr=="", "BLADE ICE"); 
    247         int x = exprLength(expr[2..$-1]); 
    248             char [] left = expr[2..x+3]; 
    249             char [] right = expr[x+4..$-1]; 
    250             char [] leftmul = ""; 
    251             char [] rightmul = "";             
    252             leftmul = getCommonMultiplucation(left, rank); 
    253             rightmul = getCommonMultiplucation(right, rank); 
    254             char [] m = leftmul; 
    255             if (rightmul.length>0) m = m==""? rightmul : "(" ~ m ~ "*" ~rightmul~")"; 
    256             if (mulexpr.length>0) m = m=="" ? mulexpr : "(" ~ m ~ "*" ~mulexpr~")"; 
    257             if (m.length>1) m= "* {" ~ m ~ "} "; 
    258             else if (m.length==1) m= "*" ~ m; 
    259             assert(indexexpr.length==0, "BLADE ICE: rank mismatch in dot product"); 
    260             return " {d(" ~ simplifyWithoutMul(left, rank) ~ "," ~ 
    261                           simplifyWithoutMul(right, rank) ~ ")} " ~ m; 
    262     } 
    263     // Deal with ++ and --. Only for scalars 
    264     if (expr.length>2 && (expr[0..2]=="++" || expr[0..2]=="--")) { 
    265         return expr[0..2] ~ subexprSimplify(expr[2..$], rank, mulexpr, indexexpr); 
    266     } 
    267     if (expr.length>2 && (expr[$-2..$]=="++" || expr[$-2..$]=="--")) { 
    268         return subexprSimplify(expr[0..$-2], rank, mulexpr, indexexpr) ~ expr[$-2..$]; 
    269     } 
    270     // Deal with unary operators 
    271     if (expr[0]=='-') { 
    272         // Fold unary minus into a multiply, if possible. 
    273         if (mulexpr.length>0) return subexprSimplify(expr[1..$], rank, "-" ~ mulexpr, indexexpr); 
    274         return "-" ~ subexprSimplify(expr[1..$], rank, mulexpr, indexexpr); 
    275     } 
    276     // Just remove unary plus. 
    277     if (expr[0]=='+') return subexprSimplify(expr[1..$], rank, mulexpr, indexexpr); 
    278     
    279     int x = exprLength(expr); 
    280     int y = x+1; 
    281     assert(y < expr.length, expr); 
    282     // Deal with shifts, op=, and NCEG operators 
    283     while (expr[y+1]=='<' || expr[y+1]=='>' || expr[y+1]=='=') ++y;     
    284  
    285     char [] op = expr[x+1..y+1];     
    286     char [] left = expr[0..x+1]; 
    287     char [] right = expr[y+1..$]; 
    288     if (expr[x+1]=='[') right = expr[y+1..$-1]; // drop off the ']'. 
    289     if (op=="[") { 
    290         // accumulate indexing and slicing operations 
    291         return subexprSimplify(left, rank, mulexpr, "[" ~ right ~ "]" ~ indexexpr); 
    292     } 
    293     int lrank = subexprRank(left, rank); 
    294     int rrank = subexprRank(right, rank); 
    295     // Fold scalars together 
    296     if (op=="*") { 
    297         if (lrank==0) 
    298             return simplifyScalarMul(left, right, mulexpr, rank, indexexpr); 
    299         else if (rrank==0)  
    300             return simplifyScalarMul(right, left, mulexpr, rank, indexexpr); 
    301         // else it's matrix * matrix 
    302     } 
    303     if (op=="*=") { 
    304         if (rrank==0) { 
    305             char [] m = right; 
    306             if (mulexpr.length>0) m ~= "*" ~ mulexpr; 
    307             if (m.length>1)  m= " {" ~ m ~ "} "; 
    308             return "(" ~ subexprSimplify(left, rank, "", indexexpr)~ "*=" ~ m ~ ")"; 
    309         } // else it's matrix *= matrix 
    310     } 
    311     return "(" ~ subexprSimplify(left, rank, mulexpr, indexexpr) ~ op ~ subexprSimplify(right, rank, mulexpr, indexexpr) ~ ")"; 
    312 } 
    313  
    314116// Determine rank of a multidimensional index 
    315117int indexRank(char [] s) 
     
    337139RevisedExpression simplifyVectorExpression(char [] expr, char [] rank, Symbol[] symTable=[]) 
    338140{ 
    339     char [] s = exprSimplify(foldIndices(expr, rank), rank, "", ""); 
    340 //    char [] s = foldScalars(foldIndices(expr, rank), rank); 
    341 //    char [] s = exprSimplify(expr, rank, "", ""); 
    342     if (s.length>1 && s[0]=='(') s = s[1..$-1]; // strip off () 
     141    char [] s = foldScalars(foldIndices(expr, rank), rank); 
    343142    char [][] comp; 
    344143    char [] used = ""; // which of the old variables are used; gives the new mapping 
     
    353152            for (k=i+1; s[k]!=' '; ++k) {} 
    354153            char [] newexpr = s[i+2..k-1]; // strip off the {} 
     154             
     155            int newi = k; 
     156            if (i>0 && k<s.length && s[i-1]=='(' && s[k+1]==')') { 
     157                e = e[0..$-1]; // remove the last '(' 
     158                newi=k+1; // don't add ')' 
     159            } 
    355160            // Check for a duplicate 
    356161            int z; 
     
    371176                }                 
    372177            } else e ~= cast(char)('A'+z+rank.length); 
    373             i = k
     178            i = newi
    374179        } else { 
    375180            e ~= s[i]; 
     
    405210} 
    406211 
    407 unittest { 
    408     assert(exprSimplify("d(A,(A*d(A,A)))", "1", "","")== " {d(A,A)} * {d(A,A)} "); 
    409     assert(exprSimplify("A+=(B*(C[D,D..$]))","1020","","")=="(A+=(B* {C[D,D..$]} ))"); 
    410     assert(exprSimplify("A+=(((D[E])*B)[E])", "103300","","")=="(A+=(B* {D[E][E]} ))"); 
    411     assert(exprSimplify("A+=(B*((C[B])[B..E]))", "103300","","")=="(A+=(B* {C[B][B..E]} ))"); 
    412     assert(exprSimplify("A*=(B*C)", "100","","")== "(A*= {(B*C)} )"); 
    413     assert(exprSimplify("A=((B*C)-D)", "1101","","")=="(A=((C*B)-D))"); 
    414     assert(exprSimplify("A=(((B*E)+(C*E))*D)", "11100","","")=="(A=(( {(D*E)} *B)+( {(D*E)} *C)))"); 
    415     assert(exprSimplify("A=(D*((B*E)+(C*E)))", "11100","","")=="(A=(( {(D*E)} *B)+( {(D*E)} *C)))"); 
    416     assert(exprSimplify("d((A*(B*C)),(B*A))","010","","")== " {d(B,B)} * {((A*C)*A)} "); 
    417  
     212unittest {     
    418213    RevisedExpression e = simplifyVectorExpression("A+=(((D[B])*C)[B])", "2004"); 
    419     assert(e.expression == "A+=(B*C)"); 
    420214    assert(e.rank=="202"); 
    421215    assert(e.compounds[0]=="D[B,B]"); 
    422216    assert(e.mapping=="ACE"); 
     217    assert(e.expression== "A+=(C*B)"); 
    423218} 
    424219 
     
    460255       } else { 
    461256           assert(sym!="$" && this_.rank[sym[0]-'A']>'0', "Rank error " ~ sym); 
    462            // TODO: We want this to be a new terminal. 
     257           // Note: Later, we'll want this to be a new terminal. 
    463258           return sym ~ createMultiSlice(this_.slicing);            
    464259       } 
     
    494289            //   it might contain a dollar, which we need to replace.  
    495290            // * If the existing dimension is a slice, the two slices will combine. 
    496             //  
     291            // 
    497292            // The items inside the slice are top-level, ie have no slice or dollar. 
    498293            char [] a = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank,"$",[]), slices[$-1][0])); 
     
    625420struct ScalarFold 
    626421{ 
    627     char [] expr; 
     422    char [] expr;       // vector or matrix expression; empty for a pure scalar expression 
    628423    char [] multiplier; // scalar multiply of the entire expression 
    629424} 
     
    662457        } else { 
    663458            ScalarFold f = doVisit(this_, expr); 
    664             return ScalarFold(op ~ wrapInParens(combineMul(f.expr, f.multiplier)),""); 
     459            assert(f.expr==""); 
     460            return ScalarFold("", op ~ wrapInParens(f.multiplier)); 
    665461        } 
    666462    } 
    667463    ReturnType onVisitPostfix(This this_, char [] op, char [] expr) { 
    668464        ScalarFold f = doVisit(this_, expr); 
    669         return ScalarFold(wrapInParens(combineMul(f.expr, f.multiplier))~ op,""); 
     465        assert(f.expr==""); 
     466        return ScalarFold("", wrapInParens(f.multiplier)~ op); 
    670467    } 
    671468    ReturnType onVisitIndex(This this_, char [] base, char [][2][] slices) { 
    672469        // Base is always a single symbol. 
     470        assert(base.length==1); 
    673471        ScalarFold left = doVisit(this_, base); 
    674         // BUG: This whole thing could be a scalar. 
     472        // Determine the rank of this expression 
     473        int r = this_.rank[base[0]-'A']-'0'; 
     474        for (int i=0; i<slices.length;++i) { 
     475            if (slices[i][1].length==0) --r; 
     476        } 
     477        if (r==0) { // the whole thing is a scalar 
     478            return ScalarFold("",combineMul(left.expr ~ createMultiSlice(slices), left.multiplier)); 
     479        } else 
    675480        return ScalarFold(" {" ~ left.expr ~ createMultiSlice(slices)~ "} ", left.multiplier); 
    676481    } 
     
    688493        if (first.expr=="" && second.expr=="") { // both are 100% scalars -- it remains a scalar. 
    689494            return ScalarFold("", 
    690             wrapInParens(combineMulWithCompound(first.expr, first.multiplier)) ~ op ~ 
    691             wrapInParens(combineMulWithCompound(second.expr, second.multiplier))); 
     495            wrapInParens(combineMul(first.expr, first.multiplier)) ~ op ~ 
     496            wrapInParens(combineMul(second.expr, second.multiplier))); 
    692497        } 
    693498        return ScalarFold(wrapInParens(combineMulWithCompound(first.expr, first.multiplier)) ~ op ~ 
     
    717522    if (f.multiplier=="") return f.expr; 
    718523    else if (f.expr=="") return " {" ~ f.multiplier ~ "} "; 
    719     else return " {" ~ f.multiplier ~ "} *" ~ wrapInParens(f.expr); 
     524    else return combineMulWithCompound(f.expr, f.multiplier); 
    720525} 
    721526