Changeset 160

Show
Ignore:
Timestamp:
12/13/07 05:23:00 (9 months ago)
Author:
Don Clugston
Message:

The index folding now supports vector*matrix and matrix*matrix multiplies, converting to dot product if possible.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/blade/BladeSimplify.d

    r159 r160  
    439439 
    440440// Combines all the indexing and slicing operations together (dimension reduction). 
     441// Multiplication of sliced matrices and/or vectors is dimensionally 
     442// reduced where possible (may even be converted into dot product). 
    441443// Returns the new expression. This eliminates all unnecessary slice operations. 
    442444// Furthermore, *any* value followed by '[' should be used as a new compound. 
     
    543545                first = wrapInParens(doVisit(this_, left)); 
    544546                second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right)); 
    545             } else { 
    546                 // vector * matrix, or matrix * matrix 
    547                 assert(lrank<=2 && rrank<=2); 
    548                 // Interesting case -- indices must be distributed between both. 
    549                 assert(0, "Not yet implemented"); 
     547            } else {                
     548                assert(lrank<=2 && rrank<=2, "BLADE ICE: Tensor*tensor is unsupported"); 
     549                bool isDotProduct = false; // was it reduced to a dot product? 
     550                if (lrank==2) { 
     551                    // First dimension applies to rows of the left operand 
     552                    // If it's a slice, it will be a strided slice. 
     553                    char [][2][] newslice=[]; 
     554                    newslice ~= this_.slicing[0]; 
     555                    newslice ~= ["",""]; 
     556                    first = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,newslice), left)); 
     557                } else { 
     558                    assert(this_.slicing.length==1, "BLADE ICE: Rank error"); 
     559                    first = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), left)); 
     560                } 
     561                if (lrank==2 && rrank==2) { 
     562                    // Matrix * matrix 
     563                    if (this_.slicing.length>1) { 
     564                        // Second dimension applies to the right operand. 
     565                        second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,this_.slicing[1..$]), right)); 
     566                        if (this_.slicing[0][1].length==0 && this_.slicing[1][1].length==0) { 
     567                            // It's indices in both cases -- so it's a dot product. 
     568                            isDotProduct = true; 
     569                        } 
     570                    } else { 
     571                        second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right)); 
     572                    } 
     573                } else if (lrank==1 && rrank==2) {                     
     574                    // vector * matrix, Matrix uses all the slicing 
     575                    second = wrapInParens(doVisit(this_, right)); 
     576                    if (this_.slicing[0][1].length==0)  isDotProduct = true; 
     577                } else if (lrank==2 && rrank==1) { 
     578                    // matrix * vector, vector is unsliced. 
     579                    second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right)); 
     580                    if (this_.slicing[0][1].length==0)  isDotProduct = true; 
     581                } 
     582                if (isDotProduct) { 
     583                  return "d(" ~ first ~ "," ~ second ~ ")"; 
     584                } 
    550585            } 
    551586        } else { 
     
    563598 
    564599unittest { 
    565    //BUG: assert(0, foldIndices("((A*B)[C])[D]", "2200")); 
    566600    assert(foldIndices("A+=(((D[B])*C)[B])", "2004")=="A+=((D[B,B])*C)"); 
    567601    assert(foldIndices("((A[C..D])+B)[($-E)]", "21000")=="(A[C+((D-C)-E)])+(B[($-E)])"); 
     
    574608    assert(foldIndices("((A[])[(B[C])..$])[]", "3100")=="A[(B[C])..$]"); 
    575609    assert(foldIndices("A[,B..$,C]", "300")=="A[,B..$,C]"); 
    576 
     610    // Multidimensional slicing 
     611    assert(foldIndices("(C*((A*B)[C]))[D]", "2200")=="C*(d((A[C,]),(B[D])))");     
     612    assert(foldIndices("(A*B)[C..D,D]", "2200")=="(A[C..D,])*(B[D])"); 
     613    assert(foldIndices("(A*B)[C..D]", "2200")=="(A[C..D,])*B"); 
     614    assert(foldIndices("(A*B)[C..D]", "2100")=="(A[C..D,])*B"); 
     615    assert(foldIndices("(A*B)[C..D]", "1200")=="A*(B[C..D])"); 
     616    assert(foldIndices("(A*B)[C]", "120")=="d(A,(B[C]))"); 
     617    assert(foldIndices("((D*A)*B)[C]", "2100")=="d((D*(A[C,])),B)"); 
     618