Show
Ignore:
Timestamp:
04/30/08 16:05:32 (5 months ago)
Author:
Don Clugston
Message:

Added prod(). Use .ptr to get raw data, so it works with Bill Baxter's ArrayView?.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/blade/BladeSimplify.d

    r175 r187  
    1414*      be moved to every vector inside A. 
    1515*    - Use associativity of *: A*(B*C[]) == (A*B)*C[] (Not strictly true for 
    16 *      floating point; results may differ by 1ulp,  
     16*      floating point; results may differ by 1ulp, 
    1717*       eg (1.3L*3.1L)*4.7L < 1.3L*(3.1L*4.7L) 
    1818*      Note that floating point addition is not associative at all). 
    1919*    - Remove unary minus where possible, eg A-(-B) => A+B, abs(-A) => abs(A). 
    20 *    - Use associativity of * in intrinsics:  
     20*    - Use associativity of * in intrinsics: 
    2121*         sum(A*V) => A*sum(V), abs(A*B) => abs(A)*abs(B) 
    22 * (D) Expression standardisation  
     22* (D) Expression standardisation 
    2323*    - Move multiplies to left: Convert A[]*B into B*A[] (assumes * is commutative, 
    2424*      not valid for quaternions). 
     
    5454{ 
    5555    return str=="dot" || str=="sum" || str=="max" || str=="min" 
    56            || str=="abs" || str=="sqrt"
     56           || str=="abs" || str=="sqrt" || str=="prod"
    5757} 
    5858 
     
    6868    } 
    6969    // Check for undefined symbols 
    70     if (err.length > 0)  
     70    if (err.length > 0) 
    7171        return RevisedExpression(tree.expression, "", tree.symbolTable, [""], "","", "Undefined symbols:" ~ err); 
    7272    else { 
     
    119119        } else e~=c; 
    120120    } 
    121     return e;     
     121    return e; 
    122122} 
    123123 
     
    171171            } 
    172172            --k; 
    173             char [] newexpr = expr[i+1..k]; // strip off the {}             
     173            char [] newexpr = expr[i+1..k]; // strip off the {} 
    174174            int newi = k; 
    175175            if (i>0 && k<expr.length-1 && expr[i-1]=='(' && expr[k+1]==')') { 
     
    184184                ++next; 
    185185                comp ~= expr[i+1..k]; // strip off the {} 
    186                 if (expr[k-1]==']') {                 
     186                if (expr[k-1]==']') { 
    187187                    // it's a vector/matrix of some kind, with rank reduced 
    188188                    // by indices. Can't just use exprRank, because the [] 
     
    192192                    // it's a scalar expression. Note that it could involve 
    193193                    // a vector expression. 
    194                     r~='0';  
    195                 }                 
     194                    r~='0'; 
     195                } 
    196196            } else e ~= cast(char)('A'+z+rank.length); 
    197197            i = newi; 
     
    202202    } 
    203203    // Create a mapping from old to new variable names 
    204          
     204 
    205205    char [] old_ranks = ""; 
    206206    char [] mapping=""; 
     
    235235} 
    236236 
    237 unittest {     
     237unittest { 
    238238    RevisedExpression e = simplifyVectorExpression("A+=(((D[B])*C)[B])", "2004",[]); 
    239239    assert(e.rank=="202"); 
     
    281281           assert(sym!="$" && this_.rank[sym[0]-'A']>'0', "Rank error " ~ sym); 
    282282           // Note: Later, we'll want this to be a new terminal. 
    283            return sym ~ createMultiSlice(this_.slicing);            
     283           return sym ~ createMultiSlice(this_.slicing); 
    284284       } 
    285285    } 
     
    301301        return wrapInParens(doVisit(this_, expr)) ~ op; 
    302302    } 
    303     // Includes multi-dimensional slicing and indexing.     
     303    // Includes multi-dimensional slicing and indexing. 
    304304    ReturnType onVisitIndex(This this_, char [] base, char [][2][] slices) { 
    305305        if (slices.length==0) { // []  -- has no effect. 
     
    311311            // with the earliest existing dimension. 
    312312            // * If the existing dimension is an index, 
    313             //   it might contain a dollar, which we need to replace.  
     313            //   it might contain a dollar, which we need to replace. 
    314314            // * If the existing dimension is a slice, the two slices will combine. 
    315315            // 
     
    331331                newslice ~= [a ~ "+" ~ c, ""]; 
    332332            } 
    333             if (slices.length>1) {                 
     333            if (slices.length>1) { 
    334334                // append other slices, if any. 
    335335                return doVisit(IndexFoldingVisitor(this_.rank, "$", slices[0..$-1] ~ newslice ~ this_.slicing[1..$]), base); 
     
    360360                assert(lrank>0 && rrank>0 && lrank<=2 && rrank<=2, "BLADE ICE: Tensor*tensor is unsupported"); 
    361361                bool isDotProduct = false; // was it reduced to a dot product? 
    362                  
     362 
    363363                // In the case of chained matrix multiplies, we can end up with an empty slice. 
    364364                if (this_.slicing.length>0 && this_.slicing[$-1][0]=="") { 
     
    368368                    // First dimension applies to rows of the left operand 
    369369                    // If it's a slice, it will be a strided slice -- unless 
    370                     // it comes from another matrix multiply, in which case the                     
     370                    // it comes from another matrix multiply, in which case the 
    371371                    // stride will drop out. (A[x]*B is strided). 
    372372                    char [][2][] newslice=[]; 
     
    390390                        second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right)); 
    391391                    } 
    392                 } else if (lrank==1 && rrank==2) {                     
     392                } else if (lrank==1 && rrank==2) { 
    393393                    // vector * matrix, Matrix uses all the slicing 
    394394                    second = wrapInParens(doVisit(this_, right)); 
     
    403403                } 
    404404            } 
    405         } else { 
    406             // in DMD1.024, nasty compiler bug if you save the first & second results into a local variable 
     405        } else { // not a multiplication 
    407406            return wrapInParens(doVisit(this_, left)) ~ op ~ wrapInParens(doVisit(this_, right)); 
    408407        } 
     
    416415} 
    417416 
    418 unittest {    
     417unittest { 
    419418    assert(foldIndices("((A[C..D])+B)[($-E)]", "21000")=="(A[C+((D-C)-E)])+(B[($-E)])"); 
    420419    assert(foldIndices("(A[C])[D]", "3100")=="A[C,D]"); 
     
    427426    assert(foldIndices("A[,B..$,C]", "300")=="A[,B..$,C]"); 
    428427    // Multidimensional slicing 
    429     assert(foldIndices("(C*((A*B)[C]))[D]", "2200")=="C*dot((A[C,]),(B[D]))");     
     428    assert(foldIndices("(C*((A*B)[C]))[D]", "2200")=="C*dot((A[C,]),(B[D]))"); 
    430429    assert(foldIndices("(A*B)[C..D,D]", "2200")=="(A[C..D,])*(B[D])"); 
    431430    assert(foldIndices("(A*B)[C..D]", "2200")=="(A[C..D,])*B"); 
     
    433432    assert(foldIndices("(A*B)[C..D]", "1200")=="A*(B[C..D])"); 
    434433    assert(foldIndices("(A*B)[C]", "120")=="dot(A,(B[C]))"); 
    435      
     434 
    436435    assert(foldIndices("((A*B)*C)[D]", "2220")=="((A[D,])*B)*C"); 
    437436    assert(foldIndices("((A+B)*C)[D]", "2220")=="((A[D,])+(B[D,]))*C"); 
    438437    assert(foldIndices("((D*A)*B)[C]", "2100")=="dot((D*(A[C,])),B)"); 
    439     assert(foldIndices("(((A*B)*C)[D..E])[D]", "12200")=="dot((A*B),(C[D+D]))");  
     438    assert(foldIndices("(((A*B)*C)[D..E])[D]", "12200")=="dot((A*B),(C[D+D]))"); 
    440439    assert(foldIndices("A+=(((D[B])*C)[B])", "2004")=="A+=((D[B,B])*C)"); 
    441440    assert(foldIndices("dot(A,(A*dot(A,A)))","1")=="dot(A,(A*dot(A,A)))"); 
     
    466465            ScalarFold right = doVisit(this_, args[1]); 
    467466            return ScalarFold("", combineMul(combineMul(left.multiplier, right.multiplier), "{" ~ func ~ "(" ~ wrapInParens(left.expr) ~ "," ~ wrapInParens(right.expr) ~ ")}")); 
    468         case "sum":  
     467        case "sum": 
     468        case "prod": 
    469469            //  sum(A*V) = A*sum(V) is a scalar. 
     470            //  prod(A*V) = A*prod(V) is a scalar. 
    470471            return ScalarFold("", combineMul(left.multiplier, "{" ~ func ~ "(" ~ wrapInParens(left.expr) ~ ")}")); 
    471472        case "abs": 
     
    483484        case "max": 
    484485        case "min": // max(A*B) can't be simplified unless we know that they are not negative. 
    485             return ScalarFold("", "{" ~ func ~ "(" ~ combineMulWithCompound(left.expr, left.multiplier) ~ ")}");  
    486 //            return ScalarFold("", "{" ~ func ~ "(@>"  ~ left.expr ~ "@" ~ left.multiplier ~ "<@)}");  
     486            return ScalarFold("", "{" ~ func ~ "(" ~ combineMulWithCompound(left.expr, left.multiplier) ~ ")}"); 
     487//            return ScalarFold("", "{" ~ func ~ "(@>"  ~ left.expr ~ "@" ~ left.multiplier ~ "<@)}"); 
    487488        default: 
    488489            assert(0, "BLADE: Unsupported function"); 
    489490            return ScalarFold("",""); 
    490491        } 
    491     }     
     492    } 
    492493    ReturnType onVisitPrefix(This this_, char [] op, char [] expr) { 
    493494        if (op=="-") { 
     
    498499            else return ScalarFold(left.expr, "-"); 
    499500        } else if (op=="+") { // just ignore unary plus 
    500             return doVisit(this_, expr);         
     501            return doVisit(this_, expr); 
    501502        } else { 
    502503            ScalarFold f = doVisit(this_, expr); 
     
    530531            assert(first.multiplier=="" && second.expr=="", "BLADE ICE"); 
    531532            assert(second.multiplier!="-", "BLADE ICE"); // this would be a*=-b, where b is a vector 
    532             if (second.multiplier.length>1)  return ScalarFold(wrapInParens(first.expr) ~ op ~ "{" ~ wrapInParens(second.multiplier) ~ "}","");  
    533             else return ScalarFold(wrapInParens(first.expr) ~ op ~ wrapInParens(second.multiplier),"");  
     533            if (second.multiplier.length>1)  return ScalarFold(wrapInParens(first.expr) ~ op ~ "{" ~ wrapInParens(second.multiplier) ~ "}",""); 
     534            else return ScalarFold(wrapInParens(first.expr) ~ op ~ wrapInParens(second.multiplier),""); 
    534535        } 
    535536        if (op=="*") { 
     
    588589    assert(left.length>0); 
    589590    if (right.length==0) return left; 
    590     if (right=="-") return "-" ~ wrapInParens(left);     
     591    if (right=="-") return "-" ~ wrapInParens(left); 
    591592    if (right.length==1) return wrapInParens(left) ~ "*" ~ right; 
    592     return wrapInParens(left) ~ "*{" ~ wrapInParens(right) ~ "}";     
     593    return wrapInParens(left) ~ "*{" ~ wrapInParens(right) ~ "}"; 
    593594} 
    594595