Changeset 87

Show
Ignore:
Timestamp:
03/15/07 03:19:47 (2 years ago)
Author:
Don Clugston
Message:

BLADE: Support real vectors. (vec[]*=scalar is still not working).

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/mathextra/Blade.d

    r86 r87  
    11/** 
    2 * BLADE 0.2 -- Basic Linear Algebra D Expressions 
     2* BLADE 0.2Alpha -- Basic Linear Algebra D Expressions 
    33* 
    44* Generate near-optimal x87 asm code for BLAS1 basic vector operations at compile time. 
     5* 32, 64 and 80 bit vectors are all supported. 
    56* 
    67* FEATURES: 
    78*  - Supports any mix of vector addition, subtraction, dot product, and multiplication 
    8 *    by a real scalar. 
    9 *  - Supports mixed-length operations (eg, double[] + float[]). 
     9*    by a scalar. 
     10*  - Supports mixed-length operations (eg, real[] + double[] + float[]). 
    1011*  - When static arrays are used, mismatches in array length are detected at compile time. 
    1112* 
     
    1617* 
    1718* BUGS: 
    18 *  - 80-bit vectors don't yet work (vector variables needs to add 10 each iteration) 
     19*  - Not well tested. 
     20*  - Not optimal for 80-bit vectors (can save an instruction by changing float + real into real+float). 
     21*  - Not optimal for the case of multiple real vectors (they could share a counter). 
     22*  - Doesn't support multiplication by a scalar of type real (need to keep track of the FP stack to do this). 
    1923*  - Doesn't yet support all combinations (eg vector*=scalar). 
    2024*  - Doesn't take advantage of length being known at compile time (loop unrolling 
    2125*     is possible). 
     26*  - Doesn't use EBP register -- this would allow an extra vector in expressions. 
    2227*  - Doesn't support imaginary and complex numbers. 
    2328*  - Need a D and an SSE2 version. 
     29* 
     30* THEORY: 
     31* Expression templates are used to create an expression string of the form "(a+b*c)+d" 
     32* and a tuple, the entries of which correspond to a, b, c, d, ... 
     33* This string is converted to postfix. The postfix string is converted to 
     34* a string containing x87 asm, which is then mixed into a function which accepts the tuple. 
     35* 
    2436*/ 
    2537module Blade; 
     
    156168        performOperation!(expr.ops, "-=", len==0? expr.len : len, expr.ValueTuple, X[])(expr.values, values); 
    157169    } 
    158 //    void opMulAssign(A)(A w) { // We don't want to generate this code unless it's actually used. 
    159 //        static assert(is(A: double)); 
    160 //        performOperation!("a", "*=", double, X[])(w, values); 
    161 //    } 
     170    void opMulAssign(A)(A w) { // We don't want to generate this code unless it's actually used. 
     171        static assert(is(A: double)); 
     172        performOperation!("a", "*=", knownlength, double, X[])(w, values); 
     173    } 
    162174    VectorExpr!("a*b", knownlength, X[], double) opMul(double w) { 
    163175        return VectorExpr!("a*b", knownlength, X[], double)(values, w); 
     
    193205real performOperation(char [] operations, char [] finaloperation, int knownlength, X...)(X expr) 
    194206{ 
    195      pragma(msg, operations); 
     207     pragma(msg, finaloperation ~ operations); 
    196208     const char [] post = makePostfix(operations); 
    197209     const char [] tupstr = vectorTupleToString!(X); 
     
    313325    if (vartype=='D') return "8"; 
    314326    else if (vartype=='F') return "4"; 
    315     else assert(0)
    316 
    317  
    318  
    319 char [] vectorRegister(char [] typelist, char var) 
     327    else if (vartype=='R') return REALSIZE
     328    assert(0); 
     329
     330 
     331int vectorNum(char [] typelist, char var) 
    320332{ 
    321333    int numVecs=0; 
     
    323335        if (typelist[i]=='R' || typelist[i]=='D' || typelist[i]=='F') ++numVecs; 
    324336    } 
    325     assert(numVecs<4, "BLADE: Too many vectors!"); 
    326 //    if (numVecs==4) return "EBP"; 
    327     return "E" ~ cast(char)('A'+numVecs) ~ "X"; 
     337    return numVecs; 
     338
     339 
     340    // First, use the registers that don't need to be preserved. 
     341const char [][5] VectorRegisters = ["EAX", "ECX", "EDX", "EBX", "EDI"]; 
     342 
     343char [] vectorRegister(int vecnum) 
     344
     345    return VectorRegisters[vecnum]; 
     346
     347 
     348char [] pushRegisters(int numVectors) 
     349
     350    char [] result = "  push ESI;"; 
     351    for (int i=3; i<numVectors; ++i) result~= " push " ~ VectorRegisters[i] ~ ";"; 
     352    return result ~ "\n"; 
     353
     354 
     355char [] popRegisters(int numVectors) 
     356
     357    char [] result = "  "; 
     358    for (int i=numVectors-1; i>=3; --i) result~= "pop " ~ VectorRegisters[i] ~ "; "; 
     359    return result ~ "pop ESI;\n"; 
    328360} 
    329361 
    330362char [] indexedVector(char [] typelist, char var) 
    331363{ 
     364    if (typelist[var-'a']=='R') return " real ptr [" ~ vectorRegister(vectorNum(typelist, var)) ~ "]"; 
    332365    return operandSize(typelist[var-'a']) ~ "[" ~ 
    333             vectorRegister(typelist, var) ~ " + " ~ vectorSize(typelist[var-'a']) ~ "*ESI]"; 
    334 
     366            vectorRegister(vectorNum(typelist, var)) ~ " + " ~ vectorSize(typelist[var-'a']) ~ "*ESI]"; 
     367
     368 
     369static if (real.sizeof==10) const char [] REALSIZE="10"; 
     370else static if (real.sizeof==12) const char [] REALSIZE="12"; 
     371else static if (real.sizeof==16) const char [] REALSIZE="16"; 
    335372 
    336373// Generate asm code which is optimal for x87 CPUs without SSE2 
     
    342379{ 
    343380    char [] result=""; 
     381    char [] incrementRealVectors=""; 
    344382 
    345383    // Create local variables for everything (avoid bug #1028) 
     
    352390        } else { 
    353391            result~= "  auto vec" ~ itoa(vecnum) ~ " = expr[" ~itoa(i) ~"].ptr;\n"; 
     392            if (typelist[i]=='R') { 
     393                incrementRealVectors ~= "  add " ~ vectorRegister(vecnum) ~ ", " ~ REALSIZE ~ ";\n"; 
     394            } 
    354395            if (firstvec) { 
    355396                result~= "  int veclength = expr[" ~itoa(i) ~"].length;\n"; 
     
    359400        } 
    360401    } 
     402    assert(vecnum-1 < VectorRegisters.length, "Too many vectors!"); 
    361403 
    362404    bool isDotProduct = (operations[$-1]=='.'); 
    363405 
    364     result~= \n"asm {"\n 
    365         "  push EBP; push EBX; push ESI; push EDI;"\n 
     406    result~= \n"asm {"\n ~ pushRegisters(vecnum) ~ 
    366407        "  mov ESI, veclength;"\n; // ESI will be the counter 
    367408 
    368409        // Load all the vector pointers into registers. 
    369410    int numvecs=0; 
    370     int typelength = typelist.length; 
    371     if (!isDotProduct) { 
    372          --typelength; 
    373         result ~= "  lea EDI, [" ~ vectorSize_LEA(typelist[typelength]) ~ "];"\n 
    374             ~ "  add EDI, vec" ~ itoa(vecnum-1) ~";"\n; 
    375      } 
    376     for (int i=0; i<typelength; ++i) { 
     411    for (int i=0; i<typelist.length; ++i) { 
    377412      if (isVector(typelist, i+'a')) { 
    378         result ~= "  lea " ~ vectorRegister(typelist, i+'a') ~ ", [" 
    379           ~ vectorSize_LEA(typelist[i]) ~ "];"\n 
    380           ~ "  add " ~ vectorRegister(typelist, i+'a') ~ ", vec" ~ itoa(numvecs) ~ ";"\n; 
     413          if (isRealVector(typelist, i+'a')) { 
     414              result ~= "  mov " ~ vectorRegister(numvecs) ~ ", vec" ~ itoa(numvecs) ~ ";"\n; 
     415          } else  { 
     416            result ~= "  lea " ~ vectorRegister(numvecs) ~ ", [" 
     417              ~ vectorSize_LEA(typelist[i]) ~ "];"\n 
     418              ~ "  add " ~ vectorRegister(numvecs) ~ ", vec" ~ itoa(numvecs) ~ ";"\n; 
     419         } 
    381420        ++numvecs; 
    382421      } 
     
    391430    char [] mainbody = ""; 
    392431 
     432if (operations.length>1) { 
    393433    while(done<operations.length) { 
    394434      if (isInstruction(operations[done])) { 
     
    402442         if (isRealVector(typelist, operations[done])) { 
    403443             // 80-bit vectors must be loaded onto the FPU stack first 
    404             mainbody ~= "  fld "  ~ indexedVector(typelist, operations[done] ) ~ ";\n"; 
     444            mainbody ~= "  fld real ptr ["  ~ vectorRegister(vectorNum(typelist, operations[done])) ~ "];\n"; 
    405445            mainbody ~= "  " ~ opToX87(operations[done+1]) ~ "p ST(1), ST;\n"; 
    406446         } else { 
     
    414454      } 
    415455    } 
     456} 
    416457    if (!isDotProduct && finaloperation.length>1) { 
    417         char [] finalop = "fadd"; 
    418         if (finaloperation[0]=='-') finalop="fsubr"; 
    419         mainbody ~= "  " ~ finalop ~ " " ~ operandSize(typelist[typelength]) ~ " [EDI + " 
    420             ~ vectorSize_LEA(typelist[typelength]) ~ "];"\n; 
     458        if (finaloperation[0]=='*') mainbody ="  fmul ST, ST(2);"\n; 
     459        else { 
     460            char [] finalop = "fadd"; 
     461            if (finaloperation[0]=='-') finalop="fsubr"; 
     462             if (typelist[$-1]=='R') { 
     463                 // 80-bit vectors must be loaded onto the FPU stack first 
     464                mainbody ~= "  fld real ptr ["  ~ vectorRegister(numvecs-1) ~ " + " 
     465                ~ vectorSize_LEA(typelist[$-1]) ~ "];"\n; 
     466                mainbody ~= "  " ~ finalop ~ "p ST(1), ST;\n"; 
     467             } else { 
     468                mainbody ~= "  " ~ finalop ~ " " ~ operandSize(typelist[$-1]) ~ " [" ~ vectorRegister(numvecs-1) ~ " + " 
     469                ~ vectorSize_LEA(typelist[$-1]) ~ "];"\n; 
     470            } 
     471        } 
    421472    } 
    422473    result ~= \n ~  mainbody  ~ "  jmp short L2;\n" 
     
    425476    result ~= "  fxch ST(1), ST;\n"; // get previous result 
    426477    if (isDotProduct) result ~= "  faddp ST(2), ST;"\n; 
    427     else result ~= "  fstp " ~ operandSize(typelist[typelength]) ~ " [EDI + " ~ vectorSize_LEA(typelist[typelength]) ~ " - " ~ vectorSize(typelist[$-1]) ~ "];"\n; 
     478    else result ~= "  fstp " ~ operandSize(typelist[$-1]) ~ " [" ~ vectorRegister(numvecs-1) ~ " + " ~ vectorSize_LEA(typelist[$-1]) ~ " - " ~ vectorSize(typelist[$-1]) ~ "];"\n; 
    428479 
    429480    result ~= "L2: \n"; 
    430481 
    431     result~= "  inc ESI;\n  jnz L1;\n"; 
     482    result~= incrementRealVectors ~ "  inc ESI;\n  jnz L1;\n"; 
    432483    if (isDotProduct) result ~= "  faddp ST(1), ST;"\n; 
    433     else result ~= "  fstp " ~ operandSize(typelist[typelength]) ~ " [EDI + " ~ vectorSize_LEA(typelist[typelength]) ~ "];"\n; 
    434     result~= "L3:" \n 
    435         ~ "  pop EDI; pop ESI; pop EBX; pop EBP; "\n 
    436         ~ "}\r\n"; 
     484    else result ~= "  fstp " ~ operandSize(typelist[$-1]) ~ " [" ~ vectorRegister(numvecs-1) ~ " + " ~ vectorSize_LEA(typelist[$-1]) ~ "];"\n; 
     485    result~= "L3:" \n ~ popRegisters(vecnum) ~ "}\r\n"; 
    437486    return result; 
    438487} 
    439488 
     489// ------------------------------- 
     490//   PART 5 -- Example 
     491// ------------------------------- 
     492 
    440493void main() 
    441494{ 
    442     auto p = Vec([1.0, 2, 18]); 
    443     auto q = Vec([3.5, 1.1, 3.8]); 
     495    auto p = Vec([1.0L, 2, 18]); 
     496    auto q = Vec([3.5L, 1.1, 3.8]); 
    444497    auto r = Vec([17.0f, 28.1, 1]); 
    445     q += ((p+r)*18.0L*314.1L - (r-p))* 35; 
     498    q -= ((p+r)*18.0L*314.1L - (p-r))* 35; 
    446499    real d = dot(r, p+r+r); 
    447500    writefln(d); 
    448 
     501    p*=2; 
     502