Changeset 156

Show
Ignore:
Timestamp:
12/05/07 14:22:09 (9 months ago)
Author:
Don Clugston
Message:

Fixed constant folding for dot (first attempt)

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/blade/BladeDemo.d

    r155 r156  
    5151 
    5252// Parses, and rank checks OK. Doesn't simplify yet, no codegen. 
    53 //    mixin(vectorize("dot(q,q*dot(q,q))")); // BUG: should simplify to: dot(q.q) * dot(q,q) 
     53 //   mixin(vectorize("dot(q,q*dot(q,q))")); // BUG: should simplify to: dot(q.q) * dot(q,q) 
    5454 
    5555    writefln("a=", a); 
  • trunk/blade/BladeRank.d

    r155 r156  
    280280} 
    281281 
     282 
    282283// Return true if the entire expression contains a multiplication by a scalar 
    283284bool hasScalarMultiply(char [] expr, char [] rank) 
  • trunk/blade/BladeSimplify.d

    r155 r156  
    150150} 
    151151 
     152char [] getCommonMultiplucation(char [] expr, char [] rank) 
     153{ 
     154    if (expr.length>1 && expr[0]=='(') expr = expr[1..$-1]; 
     155    else if (expr.length==1 && expr[0]>='A' && expr[0]<='Z'){ 
     156        return rank[expr[0]-'A']=='0'? expr : ""; 
     157    } else assert(0, "BLADE ICE: " ~ expr); 
     158    if (expr.length>2 && (expr[0..2]=="++" || expr[0..2]=="--" || expr[$-2..$]=="++" || expr[$-2..$]=="--")) { 
     159        return ""; 
     160    } 
     161    if (expr[0]=='+' || expr[0]=='-') return getCommonMultiplucation(expr[1..$], rank); 
     162    int x = exprLength(expr); 
     163    int y = x+1; 
     164    assert(y < expr.length, "BLADE ICE:" ~ expr); 
     165    // Deal with shifts, op=, and NCEG operators 
     166    while (expr[y+1]=='<' || expr[y+1]=='>' || expr[y+1]=='=') ++y; 
     167    char [] op = expr[x+1..y+1];     
     168    char [] left = expr[0..x+1]; 
     169    char [] right = expr[y+1..$]; 
     170    if (op=="[") { 
     171        // (A)[C] can still have a multiply by scalar, if A contains a 
     172        // multiply. 
     173        if (left.length==1) return ""; 
     174        return getCommonMultiplucation(left[1..$], rank); 
     175    } 
     176    if (op!="*") return "";     
     177    int leftrnk = subexprRank(left, rank); 
     178    int rightrnk = subexprRank(right, rank); 
     179    char [] leftMul = ""; 
     180    char [] rightMul = ""; 
     181    if (leftrnk == 0 && rightrnk == 0) return expr; 
     182    if (leftrnk == 0) leftMul = left; else leftMul = getCommonMultiplucation(left, rank); 
     183    if (rightrnk== 0) rightMul = right; else rightMul = getCommonMultiplucation(right, rank); 
     184    if (leftMul!="" && rightMul!="") return "(" ~ leftMul ~ "*" ~ rightMul ~ ")"; 
     185    if (leftMul!="") return leftMul; 
     186    return rightMul; 
     187} 
     188 
     189// Simplify the expression, assuming global scalar multiply has already been removed. 
     190char [] simplifyWithoutMul(char [] rawExpr, char [] rank) 
     191{ 
     192    char [] expr = rawExpr; 
     193    if (expr.length==1 && expr[0]>='A' && expr[0]<='Z') { 
     194        return rank[expr[0]-'A']=='0'? "":expr; 
     195    } else if (expr.length>1 && expr[0]=='(') expr = rawExpr[1..$-1]; 
     196 
     197    if (expr.length>2 && (expr[0..2]=="++" || expr[0..2]=="--" || expr[$-2..$]=="++" || expr[$-2..$]=="--")) { 
     198        return expr; 
     199    } 
     200    int x = exprLength(expr); 
     201    int y = x+1; 
     202    assert(y < expr.length, "BLADE ICE:" ~ expr); 
     203    // Deal with shifts, op=, and NCEG operators 
     204    while (expr[y+1]=='<' || expr[y+1]=='>' || expr[y+1]=='=') ++y; 
     205    char [] op = expr[x+1..y+1];     
     206    char [] left = expr[0..x+1]; 
     207    char [] right = expr[y+1..$]; 
     208/*     
     209    if (op=="[") { 
     210        // (A)[C] can still have a multiply by scalar, if A contains a 
     211        // multiply. 
     212        if (left.length==1) return left; 
     213        return simplifyWithoutMul(left[1..$], rank); 
     214    } 
     215*/ 
     216    if (op!="*") return subexprSimplify(rawExpr, rank, "", ""); 
     217     
     218    int leftrnk = subexprRank(left, rank); 
     219    int rightrnk = subexprRank(right, rank); 
     220    char [] leftMul = ""; 
     221    char [] rightMul = ""; 
     222    if (leftrnk == 0 && rightrnk == 0) return ""; 
     223    if (leftrnk == 0) leftMul = ""; else leftMul = simplifyWithoutMul(left, rank); 
     224    if (rightrnk== 0) rightMul = ""; else rightMul = simplifyWithoutMul(right, rank); 
     225    if (leftMul!="" && rightMul!="") return "(" ~ leftMul ~ "*" ~ rightMul ~ ")"; 
     226    if (leftMul!="") return leftMul; 
     227    return rightMul; 
     228} 
     229 
    152230/** 
    153231 * Rewrite the expression, taking advantage of distributivity of [] and 
     
    160238char [] exprSimplify(char [] expr, char [] rank, char [] mulexpr, char [] indexexpr) 
    161239{ 
    162     if (expr.length>3 && expr[0..2]=="d(") { // dot product     
     240    if (expr.length>3 && expr[0..2]=="d(") { // dot product 
     241        assert(indexexpr=="", "BLADE ICE"); 
    163242        int x = exprLength(expr[2..$-1]); 
    164243            char [] left = expr[2..x+3]; 
    165244            char [] right = expr[x+4..$-1]; 
     245            char [] leftmul = ""; 
     246            char [] rightmul = "";             
     247            leftmul = getCommonMultiplucation(left, rank); 
     248            rightmul = getCommonMultiplucation(right, rank); 
     249             
     250//            assert(0, leftmul~"#" ~ rightmul); 
    166251//            if (hasScalarMultiply(left, rank)) { 
    167 //                // pull the scalar mul out 
     252                // pull the scalar mul out 
    168253//            } 
    169254//            if (hasScalarMultiply(right, rank)) { 
    170255//            } // ditto for right. 
    171             char [] m = ""; 
    172             if (mulexpr.length>0) m ~= "*" ~ mulexpr; 
     256            char [] m = leftmul; 
     257            if (rightmul.length>0) m = m==""? rightmul : "(" ~ m ~ "*" ~rightmul~")"; 
     258            if (mulexpr.length>0) m = m=="" ? mulexpr : "(" ~ m ~ "*" ~mulexpr~")"; 
     259            if (m.length>1) m= "* {" ~ m ~ "}"; 
     260            else if (m.length==1) m= "*" ~ m; 
    173261            assert(indexexpr.length==0, "BLADE ICE: rank mismatch in dot product"); 
    174             return "d(" ~ subexprSimplify(left, rank, mulexpr,"") ~ "," ~ 
    175                           subexprSimplify(right, rank, mulexpr,"") ~ ")" ~ m; 
     262            //return "#" ~ subexprSimplify("A,B","01", "",""); 
     263//           assert(0, expr ~ "#" ~ left ~ "#" ~ right ~"#" ~ rank ~ "#");// ~ subexprSimplify(right, rank, mulexpr,"")~"#"); 
     264            return "d(" ~ simplifyWithoutMul(left, rank) ~ "," ~ 
     265                          simplifyWithoutMul(right, rank) ~ ")" ~ m; 
    176266    } 
    177267    // Deal with ++ and --. Only for scalars 
     
    313403 
    314404unittest { 
     405//    assert(0, exprSimplify("d(A,A*d(A,A))", "1", "","")); // == "d(A,A)*d(A,A)"); 
    315406    assert(exprSimplify("A+=(B*(C[D,D..$]))","1020","","")=="(A+=(B* {C[D,D..$]} ))"); 
    316407    assert(exprSimplify("A+=(((D[E])*B)[E])", "103300","","")=="(A+=(B* {D[E][E]} ))"); 
     
    320411    assert(exprSimplify("A=(((B*E)+(C*E))*D)", "11100","","")=="(A=(( {(D*E)} *B)+( {(D*E)} *C)))"); 
    321412    assert(exprSimplify("A=(D*((B*E)+(C*E)))", "11100","","")=="(A=(( {(D*E)} *B)+( {(D*E)} *C)))"); 
    322 //    assert(exprSimplify("d(A,A*d(A,A))", "1", "","") == "d(A,A)*d(A,A)"); 
     413    assert(exprSimplify("d((A*(B*C)),(B*A))","010","","")== "d(B,B)* {((A*C)*A)}"); 
    323414 
    324415    RevisedExpression e = simplifyVectorExpression("A+=(((D[B])*C)[B])", "2004");