| 545 | | } else { |
|---|
| 546 | | // vector * matrix, or matrix * matrix |
|---|
| 547 | | assert(lrank<=2 && rrank<=2); |
|---|
| 548 | | // Interesting case -- indices must be distributed between both. |
|---|
| 549 | | assert(0, "Not yet implemented"); |
|---|
| | 547 | } else { |
|---|
| | 548 | assert(lrank<=2 && rrank<=2, "BLADE ICE: Tensor*tensor is unsupported"); |
|---|
| | 549 | bool isDotProduct = false; // was it reduced to a dot product? |
|---|
| | 550 | if (lrank==2) { |
|---|
| | 551 | // First dimension applies to rows of the left operand |
|---|
| | 552 | // If it's a slice, it will be a strided slice. |
|---|
| | 553 | char [][2][] newslice=[]; |
|---|
| | 554 | newslice ~= this_.slicing[0]; |
|---|
| | 555 | newslice ~= ["",""]; |
|---|
| | 556 | first = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,newslice), left)); |
|---|
| | 557 | } else { |
|---|
| | 558 | assert(this_.slicing.length==1, "BLADE ICE: Rank error"); |
|---|
| | 559 | first = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), left)); |
|---|
| | 560 | } |
|---|
| | 561 | if (lrank==2 && rrank==2) { |
|---|
| | 562 | // Matrix * matrix |
|---|
| | 563 | if (this_.slicing.length>1) { |
|---|
| | 564 | // Second dimension applies to the right operand. |
|---|
| | 565 | second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,this_.slicing[1..$]), right)); |
|---|
| | 566 | if (this_.slicing[0][1].length==0 && this_.slicing[1][1].length==0) { |
|---|
| | 567 | // It's indices in both cases -- so it's a dot product. |
|---|
| | 568 | isDotProduct = true; |
|---|
| | 569 | } |
|---|
| | 570 | } else { |
|---|
| | 571 | second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right)); |
|---|
| | 572 | } |
|---|
| | 573 | } else if (lrank==1 && rrank==2) { |
|---|
| | 574 | // vector * matrix, Matrix uses all the slicing |
|---|
| | 575 | second = wrapInParens(doVisit(this_, right)); |
|---|
| | 576 | if (this_.slicing[0][1].length==0) isDotProduct = true; |
|---|
| | 577 | } else if (lrank==2 && rrank==1) { |
|---|
| | 578 | // matrix * vector, vector is unsliced. |
|---|
| | 579 | second = wrapInParens(doVisit(IndexFoldingVisitor(this_.rank, this_.dollar,[]), right)); |
|---|
| | 580 | if (this_.slicing[0][1].length==0) isDotProduct = true; |
|---|
| | 581 | } |
|---|
| | 582 | if (isDotProduct) { |
|---|
| | 583 | return "d(" ~ first ~ "," ~ second ~ ")"; |
|---|
| | 584 | } |
|---|
| 576 | | } |
|---|
| | 610 | // Multidimensional slicing |
|---|
| | 611 | assert(foldIndices("(C*((A*B)[C]))[D]", "2200")=="C*(d((A[C,]),(B[D])))"); |
|---|
| | 612 | assert(foldIndices("(A*B)[C..D,D]", "2200")=="(A[C..D,])*(B[D])"); |
|---|
| | 613 | assert(foldIndices("(A*B)[C..D]", "2200")=="(A[C..D,])*B"); |
|---|
| | 614 | assert(foldIndices("(A*B)[C..D]", "2100")=="(A[C..D,])*B"); |
|---|
| | 615 | assert(foldIndices("(A*B)[C..D]", "1200")=="A*(B[C..D])"); |
|---|
| | 616 | assert(foldIndices("(A*B)[C]", "120")=="d(A,(B[C]))"); |
|---|
| | 617 | assert(foldIndices("((D*A)*B)[C]", "2100")=="d((D*(A[C,])),B)"); |
|---|
| | 618 | } |
|---|