Changeset 112

Show
Ignore:
Timestamp:
09/11/07 15:19:45 (1 year ago)
Author:
Don Clugston
Message:

Split codegeneration off completely.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/blade/Blade.d

    r111 r112  
    5050 
    5151module Blade; 
     52 
    5253public import SyntaxTree : mixin_SymbolTable, AbstractSyntaxTree, syntaxtreeof, AST, Symbol; 
    5354private import BladeUtil; 
    5455private import BladeRank; 
    55  
    56 private: 
    57 // ------------------------------------------------ 
    58 //   Convert infix string to postfix 
    59 // ------------------------------------------------ 
    60  
    61  
    62 /// Return the length of a sub-expression 
    63 int exprLength(char [] s) 
    64 
    65     if (s[0]>='A' && s[0]<='Z') 
    66         return 0; 
    67     int numParens = 0; 
    68     for (int i=0; i<s.length; ++i) { 
    69         if (s[i]=='(') { 
    70             numParens++; 
    71         } 
    72         if (s[i]==')') { 
    73             numParens--; 
    74         } 
    75         if (numParens == 0) { 
    76             return i; 
    77         } 
    78     } 
    79 
    80  
    81 /** Returns the (tensor) rank of the expression expr. 
    82  * 
    83  * Params: 
    84  * expr   Placeholder expression (A,B,... correspond to tuple[0],[1],...) 
    85  * rank   The rank of each tuple member A, B, C, ... 
    86  */ 
    87 int exprRank(char [] expr, int [] rank) 
    88 
    89     int x = exprLength(expr); 
    90      
    91     char [] op = expr[x+1..x+2];     
    92     char [] left = expr[0..x+1]; 
    93     char [] right = expr[x+2..$]; 
    94     int lrank = (left.length==1)?  rank[left[0]-'A'] : exprRank(left[1..$-1], rank); 
    95     int rrank = (right.length==1)?  rank[right[0]-'A'] : exprRank(right[1..$-1], rank); 
    96     if (op=="+" || op=="-" || op=="=") { 
    97         assert(lrank==rrank, "Rank error in expression"); 
    98         return lrank; 
    99     } 
    100     if (lrank==0) return rrank; 
    101     if (rrank==0) return lrank; 
    102     assert(0, "Unsupported operation"); 
    103     return 0; 
    104 
    105  
    106 unittest { 
    107     assert(exprRank("A+(B*C)", [1,1,0])==1); 
    108     assert(exprRank("A+(B*C)", [0,0,0])==0); 
    109     assert(exprRank("A+(B*C)", [2,0,2])==2); 
    110 
    111  
    112 /** Returns the resultant element type of the tensor expression expr. 
    113  *  
    114  * Note that since D doesn't have array operations, the expression is not 
    115  * normally valid D code. 
    116  * 
    117  * Params: 
    118  *  expr   Placeholder expression (A,B,... correspond to tuple[0],[1],...) 
    119  * T   Every type in the expression 
    120  */ 
    121 template exprElementType(char [] expr, T...) 
    122 
    123     const int x = exprLength(expr); 
    124      
    125     const char [] op = expr[x+1..x+2];     
    126     const char [] left = expr[0..x+1]; 
    127     const char [] right = expr[x+2..$]; 
    128     static if (left.length==1) 
    129         alias ElementType!(T[left[0]-'A']) LeftElemType; 
    130     else alias typeof(exprElementType!(left[1..$-1], T).ElemType) LeftElemType; 
    131     static if (right.length==1) 
    132         alias ElementType!(T[right[0]-'A']) RightElemType; 
    133     else alias exprElementType!(right[1..$-1], T).ElemType RightElemType; 
    134     static if (op=="+" || op=="-" || op=="=") { 
    135         alias typeof(LeftElemType + RightElemType) ElemType; 
    136     } else { // multiply 
    137         alias typeof(LeftElemType * RightElemType) ElemType; 
    138     } 
    139 
    140  
    141 unittest { 
    142 static assert(is(exprElementType!("A+(B*C)", float[], double[], double).ElemType == double)); 
    143 
    144  
    145 /// Converts an infix string into postfix. 
    146 /// Apply x87-specific optimisations during the conversion. 
    147 char [] makePostfixForX87(char [] operations, char [][] typelist, int[] ranklist) 
    148 
    149     if (operations.length==1) return operations; 
    150     int x = exprLength(operations); 
    151      
    152     char [] op = operations[x+1..x+2];     
    153     char [] first = operations[0..x+1]; 
    154     char [] second = operations[x+2..$]; 
    155     if (operations[x+2]=='=') { // +=, -=, *=, /= 
    156         // Convert "A+=B" into "A=A+B" 
    157         second = makePostfixForX87(operations[0..x+2] ~ operations[x+3..$], typelist, ranklist); 
    158         return second ~ first ~ "="; 
    159     } 
    160     char [] oprvs = op; 
    161     if (op=="-") oprvs="_";  // We use _ to mean reversed subtraction. 
    162      
    163     if (first[0]=='(') { 
    164         first = makePostfixForX87(first[1..first.length-1], typelist, ranklist); 
    165     }else assert(first.length<2, "Missing () in expression: " ~ first); 
    166     if (second[0]=='(') { 
    167         second = makePostfixForX87(second[1..second.length-1], typelist, ranklist); 
    168     }else assert(second.length<2, "Missing () in expression: " ~ second); 
    169     if (op=="=") { 
    170         return second ~ first ~ "="; 
    171     } 
    172  
    173     // x87 OPTIMISATION #1 
    174     // On x87, fmul has a long latency, so we want to delay using the 
    175     // result of a multiply. Since + is commutative, we can achieve this 
    176     // by calculating the value with the multiply, before the other one. 
    177     // We can also do the same thing with -, but we'll need to use fsubr 
    178     // instead of fsub. We use _ to mean reversed subtraction. 
    179     if (op=="+" || op=="-") { 
    180         if (second[second.length-1]=='*'&& first[first.length-1]!='*') { 
    181            return second ~ first ~ oprvs; 
    182         } 
    183         // x87 OPTIMISATION #2 
    184         // When an operation is performed between a real[] and a non-real[], 
    185         // we want to have the real[] being the one which is loaded first. 
    186         if (second.length==1 && typelist[second[0]-'A']=="real" && ranklist[second[0]-'A']==1) { 
    187                return second ~ first ~ oprvs; 
    188         } 
    189     } 
    190     return first ~ second ~ op; 
    191 
    192  
    193  
    194 unittest { 
    195 assert(makePostfixForX87("A=B", elementTupleToString!(double, double),[1,1])=="BA="); 
    196 assert(makePostfixForX87("(B*C)+A", elementTupleToString!(double, float, float),[1,1,1])=="BC*A+"); 
    197 assert(makePostfixForX87("(B*C)+A", elementTupleToString!(real, float, float),[1,1,1])=="ABC*+"); 
    198 assert(makePostfixForX87("A-(B*C)", elementTupleToString!(double, float, float),[1,0,0])=="BC*A_"); 
    199 assert(makePostfixForX87("(B*C)-A", elementTupleToString!(float, float, float),[1,0,0])=="BC*A-"); 
    200 assert(makePostfixForX87("(B*C)-A", elementTupleToString!(real, float, float),[1,0,0])=="ABC*_"); 
    201 assert(makePostfixForX87("C+=((B*C)-A)", elementTupleToString!(real, float, float),[1,0,1])=="CABC*_+C="); 
    202 assert(makePostfixForX87("C-=((B*C)-A)", elementTupleToString!(real, float, float),[1,0,1])=="CABC*_-C="); 
    203 assert(makePostfixForX87("C-=(B*A)", elementTupleToString!(real, float, float),[1,0,1]) =="BA*C_C="); 
    204 assert(makePostfixForX87("C-=(B*A)", elementTupleToString!(real, float, real),[1,0,1]) =="BA*C_C="); 
    205 assert(makePostfixForX87("((A*B)+(C*D))+(E*F)", elementTupleToString!(int, int, int),[0,0,0])=="EF*AB*CD*++"); 
    206  
    207 
    208  
    209  
    210 /// Converts an infix string into postfix. 
    211 /// Apply SSE/SSE2-specific optimisations during the conversion. 
    212 char [] makePostfixForSSE(char [] operations, int[] ranklist) 
    213 
    214     if (operations.length==1) return operations; 
    215     int x = exprLength(operations); 
    216      
    217     char [] op = operations[x+1..x+2];     
    218     char [] first = operations[0..x+1]; 
    219     char [] second = operations[x+2..$]; 
    220     if (operations[x+2]=='=') { // +=, -=, *=, /= 
    221         // Convert "A+=B" into "A=A+B" 
    222         second = makePostfixForSSE(operations[0..x+2] ~ operations[x+3..$], ranklist); 
    223         return second ~ first ~ "="; 
    224     } 
    225      
    226     if (first[0]=='(') { 
    227         first = makePostfixForSSE(first[1..first.length-1], ranklist); 
    228     }else assert(first.length<2, "Missing () in expression: " ~ first); 
    229     if (second[0]=='(') { 
    230         second = makePostfixForSSE(second[1..second.length-1], ranklist); 
    231     }else assert(second.length<2, "Missing () in expression: " ~ second); 
    232     if (op=="=") { 
    233         return second ~ first ~ "="; 
    234     } 
    235  
    236     // On x87, fp multiplies have a long latency, so we want to delay using the 
    237     // result of a multiply. Since + is commutative, we can achieve this 
    238     // by calculating the value with the multiply, before the other one. 
    239     if (op=="+") { 
    240         if (second[second.length-1]=='*'&& first[first.length-1]!='*') { 
    241            return second ~ first ~ op; 
    242         } 
    243     } 
    244     if (op=="*") { 
    245         // SSE OPTIMISATION #2 
    246         // When an operation is performed between a vector and a scalar 
    247         // we want to have the vector being the one which is loaded first. 
    248         if (first.length==1 && ranklist[first[0]-'A']==0) { 
    249                return second ~ first ~ op; 
    250         } 
    251     } 
    252  
    253     return first ~ second ~ op; 
    254 
    255  
    256 unittest { 
    257 assert(makePostfixForSSE("A=B", TupleRank!(double[], double[]))=="BA="); 
    258 assert(makePostfixForSSE("(A*B)+C", TupleRank!(double[], double, double[]))=="AB*C+"); 
    259 assert(makePostfixForSSE("A=(B*C)", TupleRank!(double[], double[], double))=="BC*A="); 
    260 
    261  
    262 // ------------------------------- 
    263 //   Mixins to generate x87 ASM code 
    264 // ------------------------------- 
    265  
    266 /// True if the character is an operation (everything else is an operand) 
    267 bool isInstruction(char op) 
    268 
    269     return (op=='+' || op=='*' || op=='-'|| op=='_' || op=='='); 
    270 
    271  
    272 /// Count the number of temporaries which occur in the postfix expression. 
    273 int countTemporaries(char [] postfix) 
    274 
    275 // A temporary occurs whenever we load two values without an operation performed on the 
    276 // first one. 
    277     int numTemps=0; 
    278     for (int i=1; i<postfix.length; ++i) { 
    279         if (!isInstruction(postfix[i-1]) && !isInstruction(postfix[i])) numTemps++; 
    280     } 
    281     return numTemps; 
    282 
    283  
    284  
    285 /// The maximum number of simultaneous temporary values in the postfix expression. 
    286 int maxActiveTemporaries(char [] postfix) 
    287 
    288     int maxTemps=0; 
    289     int numTemps=0; 
    290     for (int i=1; i<postfix.length; ++i) { 
    291         if (!isInstruction(postfix[i-1]) && !isInstruction(postfix[i])) numTemps++; 
    292         if (isInstruction(postfix[i-1]) && isInstruction(postfix[i])) numTemps--; 
    293         if (maxTemps<numTemps) maxTemps=numTemps; 
    294     } 
    295     return maxTemps; 
    296  
    297 
    298  
    299 unittest { 
    300     assert(countTemporaries("AB*BC*+DE*+")==3); 
    301     assert(maxActiveTemporaries("AB*BC*+DE*+")==2); 
    302 
    303  
    304 char [] operandSize(char [] typestr) 
    305 
    306     switch(typestr) { 
    307         case "real":   return "real ptr "; 
    308         case "double": return "double ptr "; 
    309         case "float":  return "float ptr "; 
    310         default: 
    311         assert(0, typestr); 
    312     } 
    313 
    314  
    315 char [][char] opToX87() { 
    316     return ['*':"fmul"[], '+': "fadd", '-': "fsub", '_': "fsubr"]; } 
    317  
    318 char [][char] opToSSE2() { 
    319     return ['*':"mulpd"[], '+': "addpd", '-': "subpd", '/': "divpd"]; } 
    320  
    321 char [][char] opToSSE() { 
    322     return ['*':"mulps"[], '+': "addps", '-': "subps", '/': "divps"]; } 
    323  
    324  
    325 static if (real.sizeof==10)      const char [] REALSIZE = "10"; 
    326 else static if (real.sizeof==12) const char [] REALSIZE = "12"; 
    327 else static if (real.sizeof==16) const char [] REALSIZE = "16"; 
    328  
    329 char [] vectorSize(char [] typestr) 
    330 
    331     switch (typestr) { 
    332         case "double": return "8"; 
    333         case "float": return "4"; 
    334         case "real": return REALSIZE; 
    335     } 
    336 
    337  
    338  
    339 // First, use the scratch registers (EAX, ECX, EDX). EAX is always used as 
    340 // an index register. If there are more than 2 vectors, use EBX, ESI, and EDI, 
    341 // which need to be pushed and popped. 
    342 // TODO: Finally, use the frame register EBP. 
    343 const char [][5] vectorRegister = ["ECX", "EDX", "EBX", "ESI", "EDI"]; 
    344  
    345  
    346 // Is this expression simple enough for the x87 code generator? 
    347 bool isX87AsmPossible(char [][] typelist, int [] ranklist, char [] operations) { 
    348   version (D_InlineAsm_X86) { 
    349         // Are there enough index registers? 
    350         if (countVectors(ranklist) > vectorRegister.length) return false; 
    351         // Does it contain any types we can't deal with? 
    352         foreach(r; ranklist) { 
    353             if (r>1) return false; 
    354         } 
    355         foreach(ch; typelist) { 
    356             // can only do float, double, and 80-bit vectors, and scalars. 
    357             if (ch!="real" && ch!="double" && ch!="float") return false; 
    358         } 
    359         // BUG: should also check if it will overflow the FPU stack 
    360         return true; 
    361   } else { 
    362       // Without an assembler, there's no chance! 
    363       return false; 
    364   } 
    365 
    366  
    367 // Is this expression simple enough for the SSE2 code generator? 
    368 bool isSSE2AsmPossible(char [][] typelist, char [] operations) 
    369 
    370   version (D_InlineAsm_X86) { 
    371         // Does it contain any types we can't deal with? 
    372         foreach(ch; typelist) { 
    373             // can only do double vectors and double scalars. 
    374             if (ch!="double[]" && ch!="double") return false; 
    375         } 
    376         return true; 
    377   } else { 
    378       // Without an assembler, there's no chance! 
    379       return false; 
    380   } 
    381 
    382  
    383 // Create code to push all used vector registers. 
    384 char [] pushRegisters(int numVectors) 
    385 
    386     char [] result = ""; 
    387     for (int i=2; i<numVectors; ++i) result~= " push " ~ vectorRegister[i] ~ ";"; 
    388     return result ~ "\n"; 
    389 
    390  
    391 // Create code to pop all used vector registers. 
    392 char [] popRegisters(int numVectors) 
    393 
    394     char [] result = ";  "; 
    395     for (int i=numVectors-1; i>=2; --i) result~= "pop " ~ vectorRegister[i] ~ "; "; 
    396     return result ~ \n; 
    397 
    398  
    399 // indexed by i. 
    400 char [] indexedVector(char [][] typelist, int [] ranklist, char var) 
    401 
    402     if (typelist[var-'A']=="real") return " real ptr [" ~ vectorRegister[vectorNum(ranklist, var)] ~ "]"; 
    403     return operandSize(typelist[var-'A']) ~ "[" ~ 
    404             vectorRegister[vectorNum(ranklist, var)] ~ " + " ~ vectorSize(typelist[var-'A']) ~ "*EAX]"; 
    405 
    406  
    407 // indexed by i-1 
    408 char [] indexedVectorPrev(char [][] typelist, int [] ranklist, char var) 
    409 
    410     char [] stride = " - " ~ vectorSize(typelist[var-'A']); 
    411     if (typelist[var-'A'] == "real") return " real ptr [" ~ vectorRegister[vectorNum(ranklist, var)] ~ stride ~ "]"; 
    412     return operandSize(typelist[var-'A']) ~ "[" ~ 
    413             vectorRegister[vectorNum(ranklist, var)] ~ " + " ~ vectorSize(typelist[var-'A']) ~ "*EAX" ~ stride ~ "]"; 
    414 
    415  
    416 char [] indexedSSEVector(int [] ranklist, char var) 
    417 
    418     return "[" ~ vectorRegister[vectorNum(ranklist, var)] ~ " + 8* EAX]"; 
    419 
    420  
    421 char [] indexedVectorWithStride(char [][] typelist, int [] ranklist, char var, int stride) 
    422 
    423     char [] stridestr = " - " ~ vectorSize(typelist[var-'A']) ~ "*" ~ itoa(stride); 
    424     if (typelist[var-'A'] == "real") return " real ptr [" ~ vectorRegister[vectorNum(ranklist, var)] ~ stridestr ~ "]"; 
    425     return operandSize(typelist[var-'A']) ~ "[" ~ 
    426             vectorRegister[vectorNum(ranklist, var)] ~ " + " ~ vectorSize(typelist[var-'A']) ~ "*EAX" ~ stridestr ~ "]"; 
    427 
    428  
    429 // Pop N values from the FPU stack 
    430 char [] discardFromStack(int n) 
    431 
    432     char [] result=""; 
    433     while (n>1) { 
    434         result~= "  fcompp ST(0), ST;"\n; // pop two values at once 
    435         n-=2; 
    436     } 
    437     if (n==1) result~= "  fstp ST(0), ST;"\n; 
    438     return result; 
    439 
    440  
    441  
    442 /** Generate asm code which is optimal for x87 CPUs without SSE2. 
    443  (Pentium, PMMX, PII, PIII). It is also optimal for recent x86 CPUs 
    444  where vector sizes are mixed. 
    445 The key optimisation rules are: 
    446  1. keep the loop overhead to one clock cycle if possible. 
    447  2. (FMUL latency) don't use the result of a multiply immediately 
    448  3. (FST latency) don't save a value to memory immediately after it's calculated. 
    449  4. (AGI stall) don't use the counter variable immediately after it's modified. 
    450 Techniques to address these are: 
    451  1. Use EAX as a counter and index variable, which begins negative and counts UP to zero. 
    452  2. The latency of fmul is avoided by swapping fadd/fsub with fmul whenever possible. 
    453  3. The latency of fstp is avoided by calculating a result in one iteration, 
    454      but not storing it to memory until the subsequent iteration. 
    455  4. (NOT YET IMPLEMENTED): first operation in the loop should be loading a scalar (for a multiply), 
    456     if possible, otherwise load an 80-bit vector, if possible. 
    457  
    458 The generated code is of the form: 
    459 ---- 
    460  load scalars onto FPU stack 
    461  load vector pointers into ECX, EDX, EBX, ... (EAX is the counter/index) 
    462  calculate result[0] into ST(0) 
    463  goto L2 
    464 L1: 
    465  calculate result[i+1] into ST(0) 
    466  swap so that result[i] is in ST(0) 
    467 L2: 
    468  store result[i] 
    469  increment pointers, goto L1 if i<n-1 
    470  store result[n-1] 
    471  pop scalars off FPU stack 
    472 ---- 
    473  
    474 */ 
    475 char [] generateCodeForAsmX87(char [][] typelist, int [] ranklist, char [] infixOperations, char cumulatingOp=0) 
    476 
    477     char [] operations = makePostfixForX87(infixOperations, typelist, ranklist); 
    478     char [] result=""; 
    479     char [] incrementRealVectors=""; 
    480      
    481     result ~= "// Operation : " ~  operations ~ \n; 
    482  
    483     // Create local variables for pointers to vectors (avoid bug #1125) 
    484     int vecnum = 0; 
    485     for (int i=0; i< ranklist.length;++i) { 
    486         if (ranklist[i]==1){ 
    487             result~= "  auto vec" ~ itoa(i) ~ " = values[" ~itoa(i) ~"].ptr; // " ~ cast(char)('A'+i)~ \n; 
    488             if (typelist[i]=="real") { 
    489                 incrementRealVectors ~= "  add " ~ vectorRegister[vecnum] ~ ", " ~ REALSIZE ~ ";\n"; 
    490             } 
    491             ++vecnum; 
    492         } else result~= " alias values["~itoa(i)~"] val" ~ itoa(i) ~ "; // " ~ cast(char)('A'+i)~ \n; 
    493     } 
    494  
    495     result ~= "  int veclength = values[" ~itoa(findFirstVector(ranklist)) ~"].length;\n"; 
    496   
    497     int numScalarsOnStack=0; 
    498  
    499     result~= \n"asm {"\n ~ pushRegisters(vecnum); 
    500     // EAX will be the counter 
    501     result ~= "  mov EAX, veclength;"\n; 
    502  
    503     // Load all the vector pointers into registers, and push all the scalars onto the stack 
    504  
    505     int numvecs=0; 
    506     int numconsts=0; 
    507     for (int i=0; i<ranklist.length; ++i) { 
    508       if (ranklist[i]==1) { 
    509           if (typelist[i]=="real") { 
    510               result ~= "  mov " ~ vectorRegister[numvecs] ~ ", vec" ~ itoa(i) ~ ";"; 
    511           } else  { 
    512             result ~= "  lea " ~ vectorRegister[numvecs] 
    513               ~ ", [" ~ vectorSize(typelist[i]) ~ "*EAX];   " 
    514               ~ "  add " ~ vectorRegister[numvecs] ~ ", vec" ~ itoa(i) ~ ";"; 
    515          } 
    516          result ~= "  //" ~ cast(char)('A'+i) ~ \n;  
    517         ++numvecs; 
    518       } else if (typelist[i]=="real") { 
    519           result ~= "  fld real ptr values["~ itoa(i) ~"];"; 
    520           ++numconsts; 
    521           ++numScalarsOnStack; 
    522          result ~= "  //" ~ cast(char)('A'+i) ~ \n;  
    523       } 
    524     } 
    525     if (cumulatingOp=='+') { 
    526         result ~= "  fldz;"\n; // dot product 
    527     } else if (cumulatingOp=='*') { // trace 
    528         result ~= "fld1;"\n; 
    529     } 
    530     result ~= "  xor EAX, EAX; "\n 
    531         "  sub EAX, veclength; // counter=-length"\n 
    532         "  jz short L3; // test for length==0"\n; 
    533     int done=0; 
    534  
    535     // Construct the main body of the loop (the main body does not include 
    536     // the final storage instruction, because of the FST latency). 
    537     char [] mainbody = ""; 
    538     char [] firstbody = ""; 
    539     char [] storage = ""; 
    540  
    541     // We need to keep track of how many things are on the FPU stack. 
    542     // Every time something is pushed, the indices of our variables change! 
    543     int numOnStack = 0; // How much of the FP stack is being used? 
    544      
    545     if (operations.length>2 && operations[$-1]=='=') { 
    546         storage ~= "  fstp " ~ indexedVectorPrev(typelist, ranklist, operations[$-2] ) ~ ";  // " ~ operations[done..done+2] ~ \n; 
    547         operations=operations[0..$-2]; 
    548     } 
    549  
    550     while(done<operations.length) { 
    551         char [] next; 
    552       if (isInstruction(operations[done])) { 
    553             // Perform an arithmetic operation on the top two FPU stack items. 
    554             next = "  " ~ opToX87[operations[done]] ~ "p ST(1), ST;  //" ~ operations[done] ~ \n; 
    555             mainbody ~= next; firstbody ~= next; 
    556             ++done; 
    557             numOnStack--; 
    558       } else if (!isInstruction(operations[done+1])){ 
    559             // load a vector onto the FPU stack, to begin a new subexpression. 
    560             int u  = operations[done]-'A'; 
    561             next = "  fld "  ~ indexedVector(typelist, ranklist, operations[done] ) ~ ";  //" ~ operations[done] ~\n; 
    562             mainbody ~= next; firstbody ~= next; 
    563             ++done; 
    564             numOnStack++; 
    565       } else if (ranklist[operations[done]-'A']==1) { 
    566              // An operation will be performed between the stack top and a vector. 
    567              // If it's a float or double, we can combine the load+arithmetic op 
    568              // into a single instruction. 
    569              if (typelist[operations[done]-'A']=="real") { 
    570                  // 80-bit vectors must be loaded onto the FPU stack first 
    571                 next = "  fld real ptr ["  ~ vectorRegister[vectorNum(ranklist, operations[done])] ~ "]; //" ~ operations[done] ~ \n 
    572                     ~ "  " ~ opToX87[operations[done+1]] ~ "p ST(1), ST; //" ~ operations[done+1] ~\n; 
    573              } else { // floats and doubles can be used directly 
    574                 next = "  " ~ opToX87[operations[done+1]] ~ " " 
    575                   ~ indexedVector(typelist, ranklist, operations[done] ) ~ "; //" ~ operations[done..done+2] ~ \n; 
    576             } 
    577             mainbody ~= next; firstbody ~= next; 
    578             done+=2; 
    579       } else { // multiply by scalar. 
    580         if (typelist[operations[done]-'A']=="real") { 
    581              // Multiply by real scalar, which is already on the stack. Note that there's an extra item on the stack when we're in the body of the loop. 
    582             firstbody ~= "  fmul ST, ST(" ~ itoa(numOnStack + numScalarsOnStack - realScalarNum(typelist, ranklist, operations[done]-'A')-1) ~ "); // * " ~ operations[done] ~ \n; 
    583             mainbody ~= "  fmul ST, ST(" ~ itoa(1 + numOnStack + numScalarsOnStack - realScalarNum(typelist, ranklist, operations[done]-'A')-1) ~ "); // * " ~ operations[done] ~ \n; 
    584         } else { 
    585             // For scalar float or double values, we can multiply directly, saving one slot on the FP stack. 
    586             next = "  fmul " ~ operandSize(typelist[operations[done]-'A']) ~ "val" ~ itoa(operations[done]-'A') ~";\n"; 
    587             mainbody ~= next; firstbody ~= next; 
    588         } 
    589         done +=2; 
    590       }       
    591     } 
    592          
    593     result ~= \n  
    594         ~  firstbody 
    595         ~ "  jmp short L2;\n" 
    596         ~ "  align 4;\n"  
    597         ~ "L1:\n" ~ mainbody 
    598         ~ "  fxch ST(1), ST;\n"; // get previous result 
    599          
    600     if (cumulatingOp) result ~= "  " ~ opToX87[cumulatingOp] ~ "p ST(2), ST;"\n; 
    601     else result ~= storage; 
    602  
    603     result ~= "L2: \n"             
    604            ~  incrementRealVectors // Update the counters 
    605            ~ "  inc EAX;\n  jnz L1;\n"; 
    606  
    607     // Store the result from the final iteration 
    608     if (cumulatingOp) result ~= "  " ~ opToX87[cumulatingOp] ~ "p ST(1), ST;"\n; 
    609     else result ~= storage; 
    610  
    611     // Discard any scalars that are left on the stack 
    612     if (cumulatingOp!=0 && numScalarsOnStack>0) { 
    613         // Preserve the result of the dot product 
    614         result ~= "  fxch ST(" ~ itoa(numScalarsOnStack) ~ "), ST;"\n; 
    615     } 
    616     result ~= discardFromStack(numScalarsOnStack); 
    617  
    618     result~= "L3:" \n ~ popRegisters(vecnum) ~ "}\r\n"; 
    619     
    620     return result; 
    621 
    622  
    623 //----------------------------- 
    624  
    625 char [] XMM(int k) { return "XMM"~ itoa(k); } 
    626  
    627  
    628 // We don't need types for SSE2, everything is a double. 
    629  
    630 char [] generateCodeForSSE2(int [] ranklist, char [] infixOperations, char cumulatingOp=0) 
    631 
    632     char [] operations = makePostfixForSSE(infixOperations, ranklist); 
    633     char [] result=""; 
    634      
    635     result ~= "// Operation : " ~  operations ~ \n; 
    636  
    637     // Create local variables for pointers to vectors (avoid bug #1125) 
    638     // Bad code is also generated for loading scalars. 
    639     int vecnum = 0; 
    640     for (int i=0; i< ranklist.length;++i) { 
    641         if (ranklist[i]==1){ 
    642             result~= "  auto vec" ~ itoa(i) ~ " = values[" ~itoa(i) ~"].ptr; // " ~ cast(char)('A'+i)~ \n; 
    643             ++vecnum; 
    644         } else result~= " auto val" ~ itoa(i) ~ " = values["~itoa(i)~"]; // " ~ cast(char)('A'+i)~ \n; 
    645     } 
    646     result ~= "  int veclength = values[" ~itoa(findFirstVector(ranklist)) ~"].length;\n"; 
    647   
    648     int numScalarsOnStack=0; 
    649  
    650     result~= \n"asm {"\n ~ pushRegisters(vecnum); 
    651     // EAX will be the counter 
    652     result ~= "  mov EAX, veclength;"\n; 
    653     // Load all the vector pointers into registers 
    654  
    655     const char [] vectorsize = "8"; // size of a double 
    656     int numvecs=0; 
    657     int numconsts=0; 
    658     for (int i=0; i<ranklist.length; ++i) { 
    659       if (ranklist[i]==1) { 
    660         result ~= "  lea " ~ vectorRegister[numvecs] 
    661           ~ ", [" ~ vectorsize ~ "*EAX];   " 
    662           ~ "  add " ~ vectorRegister[numvecs] ~ ", vec" ~ itoa(i) ~ ";"; 
    663          result ~= "  //" ~ cast(char)('A'+i) ~ \n;  
    664         ++numvecs; 
    665       } else if (ranklist[i]==0) { 
    666 //          result ~= "  movsd " ~ XMM(numconsts) ~ ", double ptr values["~ itoa(i) ~"]; " 
    667           result ~= "  movsd " ~ XMM(numconsts) ~ ", double ptr val"~ itoa(i) ~"; " 
    668             "  shufpd " ~ XMM(numconsts) ~", " ~ XMM(numconsts) ~ ",0; //" ~ cast(char)('A'+i) ~ \n; 
    669           ++numconsts; 
    670           ++numScalarsOnStack; 
    671       } 
    672     } 
    673     result ~= "  xor EAX, EAX; "\n 
    674         "  sub EAX, veclength; // counter=-length"\n 
    675         "  jz short L3; // test for length==0"\n; 
    676     int done=0; 
    677  
    678     char [] mainbody = ""; 
    679  
    680     // The SSE implementation mimics the x87 version. Instead of keeping track of 
    681     // 'top of stack', we keep track of the next unused register. It's OK to keep 
    682     // reusing the same registers, because the CPUs which support SSE 
    683     // also have extensive support for register renaming. 
    684      
    685     int numOnStack = numScalarsOnStack; // How much of the FP stack is being used? 
    686     while(done<operations.length) { 
    687         char [] next; 
    688       if (isInstruction(operations[done])) { 
    689             // Perform an arithmetic operation on the top two items. 
    690             next = "  " ~ opToSSE2[operations[done]] ~ XMM(numOnStack-1) ~ ", " ~ XMM(numOnStack) ~ ";  //" ~ operations[done] ~ \n; 
    691             mainbody ~= next; 
    692             ++done; 
    693             numOnStack--; 
    694       } else if (!isInstruction(operations[done+1])){ 
    695             // load a vector onto the FPU stack, to begin a new subexpression. 
    696             int u  = operations[done]-'A'; 
    697             next = "  movapd " ~ XMM(numOnStack) ~ ", " ~ indexedSSEVector(ranklist, operations[done] ) ~ ";  // " ~ operations[done..done+1] ~ \n; 
    698             mainbody ~= next; 
    699             ++done; 
    700             numOnStack++; 
    701       } else if (ranklist[operations[done]-'A']==1) { 
    702              // An operation will be performed between the stack top and a vector. 
    703              // If it's a float or double, we can combine the load+arithmetic op 
    704              // into a single instruction. 
    705             if (operations[done+1]=='=') mainbody ~= "  movapd " ~ indexedSSEVector(ranklist, operations[$-2] ) ~ ",XMM" ~ itoa(numOnStack-1) ~";  // " ~ operations[$-2..$] ~ \n; 
    706             else mainbody ~= "  " ~ opToSSE2[operations[done+1]] ~ " " ~ XMM(numOnStack-1) ~ ", " 
    707               ~ indexedSSEVector(ranklist, operations[done] ) ~ "; // " ~ operations[done..done+2] ~ \n; 
    708             done+=2; 
    709       } else { // multiply by scalar. 
    710             next = "  " ~ opToSSE2[operations[done+1]] ~ " " ~ XMM(numOnStack-1) ~ ", " ~ XMM(scalarNum(ranklist, operations[done]-'A')) ~"; // " ~operations[done..done+2] ~ \n; 
    711             mainbody ~= next;        
    712         done +=2; 
    713       }       
    714     } 
    715          
    716     result ~= \n  
    717         ~ "  align 16;\n"  
    718         ~ "L1:\n" ~ mainbody; 
    719     result ~= "  add EAX,2;\n" ~ "  js L1;\n"; 
    720     result~= "L3:" \n ~ popRegisters(vecnum) ~ "}\r\n"; 
    721     
    722     return result; 
    723 
     56public import CodegenX86 : generateCodeForAsmX87, generateCodeForSSE2, isSSE2AsmPossible; 
    72457 
    72558//----------------------------- 
  • trunk/blade/BladeRank.d

    r111 r112  
    119119    return k; 
    120120} 
     121 
     122 
     123 
     124/// Return the length of a sub-expression 
     125int exprLength(char [] s) 
     126{ 
     127    if (s[0]>='A' && s[0]<='Z') 
     128        return 0; 
     129    int numParens = 0; 
     130    for (int i=0; i<s.length; ++i) { 
     131        if (s[i]=='(') { 
     132            numParens++; 
     133        } 
     134        if (s[i]==')') { 
     135            numParens--; 
     136        } 
     137        if (numParens == 0) { 
     138            return i; 
     139        } 
     140    } 
     141} 
     142 
     143/** Returns the (tensor) rank of the expression expr. 
     144 * 
     145 * Params: 
     146 * expr   Placeholder expression (A,B,... correspond to tuple[0],[1],...) 
     147 * rank   The rank of each tuple member A, B, C, ... 
     148 */ 
     149int exprRank(char [] expr, int [] rank) 
     150{ 
     151    int x = exprLength(expr); 
     152     
     153    char [] op = expr[x+1..x+2];     
     154    char [] left = expr[0..x+1]; 
     155    char [] right = expr[x+2..$]; 
     156    int lrank = (left.length==1)?  rank[left[0]-'A'] : exprRank(left[1..$-1], rank); 
     157    int rrank = (right.length==1)?  rank[right[0]-'A'] : exprRank(right[1..$-1], rank); 
     158    if (op=="+" || op=="-" || op=="=") { 
     159        assert(lrank==rrank, "Rank error in expression"); 
     160        return lrank; 
     161    } 
     162    if (lrank==0) return rrank; 
     163    if (rrank==0) return lrank; 
     164    assert(0, "Unsupported operation"); 
     165    return 0; 
     166} 
     167 
     168unittest { 
     169    assert(exprRank("A+(B*C)", [1,1,0])==1); 
     170    assert(exprRank("A+(B*C)", [0,0,0])==0); 
     171    assert(exprRank("A+(B*C)", [2,0,2])==2); 
     172} 
     173 
     174// Rank functions also using placeholder expressions 
     175 
     176/** Returns the resultant element type of the tensor expression expr. 
     177 *  
     178 * Note that since D doesn't have array operations, the expression is not 
     179 * normally valid D code. 
     180 * 
     181 * Params: 
     182 *  expr   Placeholder expression (A,B,... correspond to tuple[0],[1],...) 
     183 * T   Every type in the expression 
     184 */ 
     185template exprElementType(char [] expr, T...) 
     186{ 
     187    const int x = exprLength(expr); 
     188     
     189    const char [] op = expr[x+1..x+2];     
     190    const char [] left = expr[0..x+1]; 
     191    const char [] right = expr[x+2..$]; 
     192    static if (left.length==1) 
     193        alias ElementType!(T[left[0]-'A']) LeftElemType; 
     194    else alias typeof(exprElementType!(left[1..$-1], T).ElemType) LeftElemType; 
     195    static if (right.length==1) 
     196        alias ElementType!(T[right[0]-'A']) RightElemType; 
     197    else alias exprElementType!(right[1..$-1], T).ElemType RightElemType; 
     198    static if (op=="+" || op=="-" || op=="=") { 
     199        alias typeof(LeftElemType + RightElemType) ElemType; 
     200    } else { // multiply 
     201        alias typeof(LeftElemType * RightElemType) ElemType; 
     202    } 
     203} 
     204 
     205unittest { 
     206static assert(is(exprElementType!("A+(B*C)", float[], double[], double).ElemType == double)); 
     207}