Changeset 133

Show
Ignore:
Timestamp:
11/13/07 04:47:28 (10 months ago)
Author:
Don Clugston
Message:

Factored out the expression simplification code. Duplicate symbols now get removed. SSE/SSE2 now uses simplified expressions (not tested, but works for the existing cases at least).

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/blade/Blade.d

    r132 r133  
    4848private import blade.BladeUtil : wrapInQuotes, startsWith; 
    4949private import blade.BladeRank; 
     50private import blade.BladeSimplify : simplifySyntaxTree, RevisedExpression; 
    5051public import blade.CodegenX86 : generateCodeForAsmX87, generateCodeForSSE, MAX_X87_VECTORS, MAX_SSE_VECTORS; 
    5152 
    5253public: 
     54 
     55// FOR MIXIN: Generate code to evaluate the given vector expression. 
     56char [] vectorize(char [] expr) 
     57{ 
     58    return "mixin(makeVectorCode(" ~ syntaxtreeof(expr) ~ "));"; 
     59}         
     60 
     61// Simplify the expression, categorise it,  
     62// and dispatch to the appropriate code generator. 
     63char [] makeVectorCode(AbstractSyntaxTree tree) 
     64{ 
     65    RevisedExpression revised = simplifySyntaxTree(tree); 
     66     
     67    VecExpressionType exprType = categorizeExpression(tree); 
     68    if (exprType == VecExpressionType.SSE2Expression || exprType == VecExpressionType.SSE1Expression) { 
     69        return invokeSSE((exprType == VecExpressionType.SSE2Expression), tree, revised); 
     70    } else if (exprType == VecExpressionType.X87Expression) { 
     71        return invokeX87(tree); 
     72    } else { 
     73//        return "pragma(msg, " ~ wrapInQuotes(DCodeGenerator(tree)) ~ ");" ~ DCodeGenerator(tree); 
     74        return DCodeGenerator(tree); 
     75    }     
     76} 
     77 
    5378// These functions have the complete expression encoded in the template type. 
    5479// One of these functions is instantiated for each expression. 
     
    77102    mixin(generateCodeForAsmX87(typelist, ranklist, expr)); 
    78103} 
     104 
     105 
     106private: 
    79107 
    80108// ------------------------------------ 
     
    171199 
    172200/// Generate code which will call the SSE/SSE2 code generation function 
    173 char [] invokeSSE(bool SSE2, AbstractSyntaxTree tree
     201char [] invokeSSE(bool SSE2, AbstractSyntaxTree tree, RevisedExpression revised
    174202{ 
    175203    char [] result = assertAllVectorLengthsEqual(tree); 
    176204    result ~= assertAllVectorsAlign128(tree); 
    177205     
    178     result ~= "SSEVECGEN!(" ~ (SSE2?"2":"1") ~ "," ~ wrapInQuotes(tree.expression);     
     206    char [] scalartype = SSE2? ",double" : ",float"; 
     207    char [] vectortype = SSE2? ",double*" : ",float*"; 
     208     
     209    result ~= "SSEVECGEN!(" ~ (SSE2?"2":"1") ~ "," ~ wrapInQuotes(revised.expr);     
    179210    // For SSE2, everything must be implicitly convertible to double. 
    180     for (int i=0; i<tree.symbolTable.length;++i) { 
    181         if (SSE2) { 
    182             if (tree.symbolTable[i].rank==0) result ~= ",double"; 
    183             else result ~= ",double*"; 
    184         } else { 
    185             if (tree.symbolTable[i].rank==0) result ~= ",float"; 
    186             else result ~= ",float*"; 
    187         } 
     211    char [] vals; 
     212    for (int i=0; i<tree.symbolTable.length;++i) { 
     213        if (revised.used[i]=='-') continue; // ignore unused symbols 
     214        if (tree.symbolTable[i].rank==0) result ~= scalartype; 
     215        else result ~= vectortype;         
     216        vals ~= "," ~ tree.symbolTable[i].value; 
     217        // for vectors, we only need the pointer, not the length 
     218        if (tree.symbolTable[i].rank==1) vals ~= ".ptr"; 
     219    } 
     220    // Now deal with all of the compound expressions 
     221    for (int i=0; i<revised.rank.length;++i) { 
     222        if (revised.rank[i]==0) result ~= scalartype; 
     223        else result ~= vectortype; 
     224        char [] s = ""; 
     225        foreach(c; revised.compounds[i]) { 
     226            if (c>='A' && c<='Z') s ~= tree.symbolTable[c-'A'].value; 
     227            else s ~= c; 
     228        } 
     229        if (revised.rank[i]==1) vals ~= ",(" ~ s ~ ").ptr"; 
     230        else vals ~= "," ~ s; 
    188231    } 
    189232    result ~= ")("; 
    190233    int firstVector = findVectorForLength(tree); 
    191234    result ~= tree.symbolTable[firstVector].value ~ ".length"; 
    192  
    193     for (int i=0; i<tree.symbolTable.length;++i) { 
    194         result ~= "," ~ tree.symbolTable[i].value; 
    195         // for vectors, we only need the pointer, not the length 
    196         if (tree.symbolTable[i].rank==1) result ~= ".ptr"; 
    197     } 
     235    result ~= vals; 
     236 
    198237    return result ~ ");"; 
    199238} 
     
    275314} 
    276315 
    277  
    278 // Simplify a vector expression 
    279 //  - Use slicing distributive law: A[B..C] for expressions A,B,C 
    280 //     where B and C are both rank 0, and A is rank 1, the slice can 
    281 //     be moved to every vector inside A. 
    282 //  - Convert A[]*B into B*A[] (assumes * is commutative, 
    283 //      not valid for quaternions). 
    284 //  - Use associativity of *: A*(B*C[]) == (A*B)*C[] 
    285 //  - Use * distributive law over + and -. (Not strictly correct). 
    286 // Convert -A*B into +(-A)*B whenever possible. 
    287 // Combine all scalars into a single scalar. 
    288  
    289 // Categorise the expression, and dispatch to the appropriate code generator. 
    290 char [] makeVectorCode(AbstractSyntaxTree tree) 
    291 { 
    292     int [] ranks=[]; 
    293     for (int i=0; i<tree.symbolTable.length; ++i) {ranks~=tree.symbolTable[i].rank; }     
    294     RevisedExpression e = simplifyVectorExpression(tree.expression, ranks); 
    295      
    296     VecExpressionType exprType = categorizeExpression(tree); 
    297     if (exprType == VecExpressionType.SSE2Expression || exprType == VecExpressionType.SSE1Expression) { 
    298         return invokeSSE((exprType == VecExpressionType.SSE2Expression), tree); 
    299     } else if (exprType == VecExpressionType.X87Expression) { 
    300         return invokeX87(tree); 
    301     } else { 
    302 //        return "pragma(msg, " ~ wrapInQuotes(DCodeGenerator(tree)) ~ ");" ~ DCodeGenerator(tree); 
    303         return DCodeGenerator(tree); 
    304     }     
    305 } 
    306  
    307 char [] vectorize(char [] expr) 
    308 { 
    309     return "mixin(makeVectorCode(" ~ syntaxtreeof(expr) ~ "));"; 
    310 }         
  • trunk/blade/BladeDemo.d

    r127 r133  
    3333 
    3434    mixin(vectorize(" a   += d*2.01-z")); 
    35     mixin(vectorize(" a   += r*2.01")); 
    36     mixin(vectorize(" q   += q*2.01")); 
     35 //   mixin(vectorize(" a   += r*2.01")); 
     36 //   mixin(vectorize(" q   += q*2.01")); 
    3737     
    3838    writefln("a=", a); 
  • trunk/blade/BladeRank.d

    r132 r133  
    325325} 
    326326 
    327 char [] subexprSimplify(char [] expr, int [] rank, char [] mulexpr, char [] indexexpr) 
    328 { 
    329     if (expr.length==1) { 
    330         int r = subexprRank(expr, rank); 
    331         char [] e = expr; 
    332         if(indexexpr.length>0) { 
    333             assert(r>0, "BLADE BUG: MISMATCHED INDEX " ~ expr ~ " " ~ indexexpr ~ " " ~ mulexpr); 
    334             e = " {" ~ expr ~ indexexpr ~ "} "; 
    335         } 
    336         if (mulexpr.length>1) { 
    337             // in this case, it's worth creating a new variable 
    338             return "( {" ~ mulexpr ~ "} *" ~ e ~ ")"; 
    339         } 
    340         if (mulexpr.length>0) return "(" ~ mulexpr ~ "*" ~ e ~ ")"; 
    341         return e; 
    342     } 
    343     // strip off the parentheses before simplifying 
    344     return exprSimplify(expr[1..$-1], rank, mulexpr, indexexpr); 
    345 } 
    346  
    347 /** 
    348  * Rewrite the expression, taking advantage of distributivity of [] and 
    349  * associativity of *. Additionally, we group all scalars together, whenever 
    350  * possible. 
    351  * 
    352  * This process creates compound scalars and vectors, delineated by " {" and "} ". 
    353  * They will be removed in a subsequent step. 
    354  */ 
    355 char [] exprSimplify(char [] expr, int [] rank, char [] mulexpr, char [] indexexpr) 
    356 {            
    357     // Deal with ++ and --. Only for scalars 
    358     if (expr.length>2 && (expr[0..2]=="++" || expr[0..2]=="--")) { 
    359         return expr[0..2] ~ subexprSimplify(expr[2..$], rank, mulexpr, indexexpr); 
    360     } 
    361     if (expr.length>2 && (expr[$-2..$]=="++" || expr[$-2..$]=="--")) { 
    362         return subexprSimplify(expr[0..$-2], rank, mulexpr, indexexpr) ~ expr[$-2..$]; 
    363     } 
    364     // Deal with unary operators 
    365     if (expr[0]=='-') { 
    366         // Fold unary minus into a multiply, if possible. 
    367         if (mulexpr.length>0) return subexprSimplify(expr[1..$], rank, "-" ~ mulexpr, indexexpr); 
    368         return "-" ~ subexprSimplify(expr[1..$], rank, mulexpr, indexexpr); 
    369     } 
    370     // Just remove unary plus. 
    371     if (expr[0]=='+') return subexprSimplify(expr[1..$], rank, mulexpr, indexexpr); 
    372     
    373     int x = exprLength(expr); 
    374     int y = x+1; 
    375     assert(y < expr.length, expr); 
    376     // Deal with shifts, op=, and NCEG operators 
    377     while (expr[y+1]=='<' || expr[y+1]=='>' || expr[y+1]=='=') ++y;     
    378  
    379     char [] op = expr[x+1..y+1];     
    380     char [] left = expr[0..x+1]; 
    381     char [] right = expr[y+1..$]; 
    382     if (expr[x+1]=='[') right = expr[y+1..$-1]; // drop off the ']'. 
    383     int lrank = subexprRank(left, rank); 
    384     if (op=="[") { 
    385         // accumulate indexing and slicing operations 
    386         return subexprSimplify(left, rank, mulexpr, "[" ~ right ~ "]" ~ indexexpr); 
    387     } 
    388     int rrank = subexprRank(right, rank); 
    389     // Fold scalars together 
    390     if (op=="*") { 
    391         if (lrank==0) { 
    392             char [] m = left; 
    393             if (mulexpr.length>0) m = "(" ~ left ~ "*" ~ mulexpr ~ ")"; 
    394             if (right.length > 1 && hasScalarMultiply(right[1..$-1], rank)) { 
    395                 // opportunity for scalar folding 
    396                 return subexprSimplify(right[1..$-1], rank, m, indexexpr); 
    397             } else { 
    398                 if (m.length>1) m = " {" ~ m ~ "} "; 
    399                 return "(" ~ m ~ "*" ~ subexprSimplify(right, rank, "", indexexpr) ~ ")";                 
    400             } 
    401              
    402         } else if (rrank==0) { 
    403             char [] m = right; 
    404             if (mulexpr.length>0) m = "(" ~ mulexpr ~ "*" ~ right ~ ")"; 
    405             if (left.length> 1 && hasScalarMultiply(left[1..$-1], rank)) { 
    406                 return subexprSimplify(left, rank, m, indexexpr); 
    407             } else {                 
    408                 if (m.length>1) m = " {" ~ m ~ "} "; 
    409                 return "(" ~ m ~ "*" ~ subexprSimplify(left, rank, "", indexexpr) ~ ")";                 
    410             } 
    411         } // else it's matrix * matrix 
    412     } 
    413     if (op=="*=") { 
    414         if (rrank==0) { 
    415             char [] m = right; //subexprSimplify(right, rank, "", ""); 
    416             if (mulexpr.length>0) m ~= "*" ~ mulexpr; 
    417             if (m.length>1)  m= " {" ~ m ~ "} "; 
    418             return "(" ~ subexprSimplify(left, rank, "", indexexpr)~ "*=" ~ m ~ ")"; 
    419         } 
    420     } 
    421     return "(" ~ subexprSimplify(left, rank, mulexpr, indexexpr) ~ op ~ subexprSimplify(right, rank, mulexpr, indexexpr) ~ ")"; 
    422 } 
    423  
    424 struct RevisedExpression { 
    425     char [] expr; // the revised expression using original variable names 
    426     char [][] compounds; // the compound variables 
    427     int [] rank; // rank of all of the compound variables 
    428     char [] used; // which of the original variables are used. 
    429 } 
    430  
    431 // revised expression, with unused symbols collapsed. 
    432 // (so, for example, B+=D*F becomes A+=B*C). 
    433 char [] removedUnusedSymbolsFromExpression(RevisedExpression e) 
    434 { 
    435     char [] m = ""; 
    436     char knt = 'A'; 
    437     for (int i=0; i<e.used.length; ++i) { 
    438         if (e.used[i]!='-') { 
    439             m~=knt; 
    440             ++knt; 
    441         } else m~="@"; 
    442     } 
    443     for (int i=0; i<e.compounds.length;++i) { 
    444         m~=knt; ++knt; 
    445     } 
    446     
    447     char [] f = ""; 
    448     for (int i=0; i<e.expr.length; ++i) { 
    449         char c = e.expr[i]; 
    450         if (c>='A' && c<='Z') { 
    451             f ~= m[c-'A']; 
    452         } else f ~= c; 
    453     } 
    454     return f; 
    455 } 
    456  
    457 int indexRank(char [] s) 
    458 { 
    459    int r=0; 
    460    int numbrack=0; 
    461    for(int i=1; i<s.length; ++i) { 
    462         if (s[i]==']') numbrack--; 
    463         if (s[i]=='[') { 
    464             if (numbrack==0) ++r; 
    465             numbrack++; 
    466         } 
    467         if (numbrack==0 && s[i]=='.' && s[i-1]=='.') { 
    468             // if it's a slice, it does not increase the rank 
    469              r--; 
    470         } 
    471    } 
    472    return r; 
    473 } 
    474  
    475  
    476 RevisedExpression simplifyVectorExpression(char [] expr, int [] rank) 
    477 { 
    478     char [] s = exprSimplify(expr, rank, "", ""); 
    479     if (s.length>1) s = s[1..$-1]; // strip off () 
    480     char [][] comp; 
    481     char [] used=""; 
    482     for (int i=0; i<rank.length; ++i) used~="-"; 
    483     int [] r; 
    484     char next = cast(char)('A' + rank.length); 
    485     char [] e = ""; 
    486     for (int i=0; i<s.length; ++i) { 
    487         if (s[i]==' ') { 
    488             int k; 
    489             for (k=i+1; s[k]!=' '; ++k) {} 
    490             comp ~= s[i+2..k-1]; 
    491             if (s[k-2]==']') {                 
    492                 // it's a vector/matrix of some kind, with rank reduced 
    493                 // by indices. Can't just use exprRank, because the [] 
    494                 // aren't wrapped by (). 
    495                 r ~= rank[s[i+2]-'A'] - indexRank(s[i+2..k-1]); 
    496             } else { 
    497                 // it's a scalar expression. Note that it could involve 
    498                 // a vector expression. 
    499                 r~=0;  
    500             } 
    501             e ~= next; 
    502             ++next; 
    503             i = k; 
    504         } else { 
    505             e ~= s[i]; 
    506             if (s[i]>='A' && s[i]<='Z') used[s[i]-'A']=s[i]; 
    507         } 
    508     } 
    509     return RevisedExpression(e, comp, r, used); 
    510 } 
    511  
    512 unittest { 
    513     
    514     assert(exprSimplify("A+=(((D[E])*B)[E])", [1,0,3,3,0,0],"","")=="(A+=(B* {D[E][E]} ))"); 
    515     assert(exprSimplify("A+=(B*((C[B])[B..E]))", [1,0,3,3,0,0],"","")=="(A+=(B* {C[B][B..E]} ))"); 
    516     assert(exprSimplify("A*=(B*C)", [1,0,0],"","")== "(A*= {(B*C)} )"); 
    517     assert(exprSimplify("A=((B*C)-D)", [1,1,0,1],"","")=="(A=((C*B)-D))"); 
    518  
    519     RevisedExpression e = simplifyVectorExpression("A+=(((D[B])*C)[B])", [2,0,0,4]); 
    520     assert(e.expr == "A+=(C*E)"); 
    521     assert(e.rank[0]==2); 
    522     assert(e.compounds[0]=="D[B][B]"); 
    523     assert(e.used=="A-C-");  
    524     assert(removedUnusedSymbolsFromExpression(e)=="A+=(B*C)"); 
    525 }