Changeset 180

Show
Ignore:
Timestamp:
03/11/08 04:48:03 (9 months ago)
Author:
Don Clugston
Message:

Added most of the test cases from Alefeld's paper. This clearly shows that the remaining instances of poor performance occur when a successful cubic interpolation forces secant interpolation to be used from that point on.

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/mathextra/Bracket.d

    r179 r180  
    11/** Algorithms for finding roots and extrema of functions using bracketing. 
    22 * 
    3  * Uses an algorithm based on TOMS748, which uses inverse cubic interpolation, 
     3 * Uses an algorithm based on alefeld's TOMS748, which uses inverse cubic interpolation, 
    44 * reverting to parabolic or secant interpolation where necessary. 
    55 * This implementation improves worst-case performance by a factor of more than 
    66 * 100, by performing bisection in _binary space_ (instead of a simple bisection) 
    77 * when slow convergence is encountered. 
     8 * 
     9 *  The alefeld algorithm performs badly when: 
     10 *    - the range in binary space is so large that a naive bisection causes 
     11 *       catastrophic cancellation; or 
     12 *    - after a cubic fit, two points have the same y value, so that it then uses secant every time. 
    813 */ 
    914module Bracket; 
     
    8085unittest{ 
    8186    int numCalls; 
     87    int numProblems=0; 
    8288     
    8389    void testbinaryChop(real delegate(real) f, real x1, real x2) { 
     
    101107    void testToms(real delegate(real) f, real x1, real x2) { 
    102108        numCalls=0; 
     109        ++numProblems; 
    103110        auto result = toms_solve!(real)(f, x1, x2, f(x1), f(x2),(real a, real b){  
    104111         return b==nextUp(a) || a==nextUp(b); }, 300 ); 
    105112        printf("TOMS Num calls = %d\n",numCalls); 
    106 /*         
    107         printf("Plus=%La %Lg\n", result.xlo, f(result.xlo)); 
    108         printf("Minus=%La %Lg\n", result.xhi, f(result.xhi)); 
    109         if (f(result.xlo)!=0) { 
    110             assert(oppositeSigns(f(result.xlo), f(result.xhi))); 
     113         
     114        auto flo = f(result.xlo); 
     115        auto fhi = f(result.xhi); 
     116//        printf("Plus = %La %Lg\n", result.xlo, flo); 
     117//        printf("Minus= %La %Lg\n", result.xhi, fhi); 
     118        if (flo!=0) { 
     119            assert(oppositeSigns(flo, fhi)); 
    111120        } 
    112 */       
    113121    } 
    114122    void testBoth(real delegate(real) f, real x1, real x2) { 
     
    158166    // IE this is 3X faster in the worst case 
    159167    n=3; 
     168    numProblems=0; 
    160169    foreach(k; nvals) { 
    161170        n=k; 
    162171        testToms(&power, -1, 10); 
    163172    } 
    164     printf("Total power calls=%d", powercalls); 
     173     
     174    int powerProblems=numProblems; 
     175 
     176    // Tests from Alefeld paper 
     177         
     178    int [9] alefeldSums; 
     179     
     180    real alefeld0(real x){ 
     181        ++alefeldSums[0]; 
     182        ++numCalls; 
     183        real q =  sin(x) - x/2; 
     184        for (int i=1; i<20; ++i) q+=(2*i-5.0)*(2*i-5.0)/((x-i*i)*(x-i*i)*(x-i*i)); 
     185        return q; 
     186    } 
     187   real ale_a, ale_b; 
     188   real alefeld1(real x) { 
     189        ++numCalls; 
     190       ++alefeldSums[1]; 
     191       return ale_a*x + exp(ale_b * x); 
     192   } 
     193   real alefeld2(real x) { 
     194        ++numCalls; 
     195       ++alefeldSums[2]; 
     196       return pow(x, n) - ale_a; 
     197   } 
     198   real alefeld3(real x) { 
     199        ++numCalls; 
     200       ++alefeldSums[3]; 
     201       return (1+pow(1.0L-n, 2))*x - pow(1.0L-n*x, 2); 
     202   } 
     203   real alefeld4(real x) { 
     204        ++numCalls; 
     205       ++alefeldSums[4]; 
     206       return x*x - pow(1-x, n); 
     207   } 
     208    
     209   real alefeld5(real x) { 
     210        ++numCalls; 
     211       ++alefeldSums[5]; 
     212       return (1+pow(1.0L-n, 4))*x - pow(1.0L-n*x, 4); 
     213   } 
     214    
     215   real alefeld6(real x) { 
     216        ++numCalls; 
     217       ++alefeldSums[6]; 
     218       return exp(-n*x)*(x-1.0L) + pow(x, n); 
     219   } 
     220    
     221   real alefeld7(real x) { 
     222        ++numCalls; 
     223       ++alefeldSums[7]; 
     224       return (n*x-1)/((n-1)*x); 
     225   } 
     226   printf("\nALEFELD TESTS\n"); 
     227   numProblems=0; 
     228   testToms(&alefeld0, PI_2, PI); 
     229   for (n=1; n<=10; ++n) { 
     230    testToms(&alefeld0, n*n+1e-9L, (n+1)*(n+1)-1e-9L); 
     231   } 
     232printf("ALEFELD 1\n"); 
     233   ale_a = -40; ale_b = -1; 
     234   testToms(&alefeld1, -9, 31); 
     235   ale_a = -100; ale_b = -2; 
     236   testToms(&alefeld1, -9, 31); 
     237   ale_a = -200; ale_b = -3; 
     238   testToms(&alefeld1, -9, 31); 
     239   int [] nvals_3 = [1, 2, 5, 10, 15, 20]; 
     240   int [] nvals_5 = [1, 2, 4, 5, 8, 15, 20]; 
     241   int [] nvals_6 = [1, 5, 10, 15, 20]; 
     242   int [] nvals_7 = [2, 5, 15, 20]; 
     243printf("ALEFELD 2\n"); 
     244    
     245   for(int i=4; i<12; i+=2) { 
     246       n=i; 
     247   ale_a=0.2; 
     248       testToms(&alefeld2, 0, 5); 
     249   ale_a=1; 
     250       testToms(&alefeld2, 0.95, 4.05); 
     251       testToms(&alefeld2, 0, 1.5);        
     252   } 
     253printf("ALEFELD 3\n"); 
     254   foreach(i; nvals_3) { 
     255       n=i; 
     256   testToms(&alefeld3, 0, 1); 
     257
     258printf("ALEFELD 4\n"); 
     259   foreach(i; nvals_3) { 
     260       n=i; 
     261   testToms(&alefeld4, 0, 1); 
     262
     263printf("ALEFELD 5\n"); 
     264   foreach(i; nvals_5) { 
     265       n=i; 
     266   testToms(&alefeld5, 0, 1); 
     267
     268printf("ALEFELD 6\n"); 
     269   foreach(i; nvals_6) { 
     270       n=i; 
     271   testToms(&alefeld6, 0, 1); 
     272
     273printf("ALEFELD 7\n"); 
     274   foreach(i; nvals_7) { 
     275       n=i; 
     276       testToms(&alefeld7, 0.01L, 1); 
     277    }    
     278   printf("\nSUMMARY\n"); 
     279   int grandtotal=0; 
     280   foreach(calls; alefeldSums) { 
     281       grandtotal+=calls; 
     282       printf("%d ", calls); 
     283   } 
     284   grandtotal-=2*numProblems; 
     285   printf("\nALEFELD TOTAL = %d avg = %f\n", grandtotal, (1.0*grandtotal)/numProblems); 
     286   powercalls -= 2*powerProblems; 
     287 
     288   printf("POWER TOTAL = %d avg = %f ", powercalls, (1.0*powercalls)/powerProblems); 
     289     
    165290} 
    166291 
     
    168293/* 
    169294 
    170 (real x)( 
    171 real q =  sin(x) - x/2; 
    172 for (int i=1; i<20; ++i) q+=(2*i-5.0)*(2*i-5.0)/((x-i*i)*(x-i*i)*(x-i*i)); 
    173 return q; 
    174 } 
    175295range PI/2, PI 
    176296n*n + 1e-9 .. (n+1)*(n+1) - 1e-9. n=1(1)19. 
     
    320440/+   
    321441SolveResult!(T) toms_solve(T)(T delegate(T) f, T ax, T bx, T fax, T fbx, bool delegate(T,T) tol, uint max_iter) 
    322          SUBROUTINE RROOT(NPROB,NEPS,EPS,A,B,ROOT) 
    323          INTEGER NPROB,ITNUM,ISIGN,NEPS 
    324          DOUBLE PRECISION A,B,FA,FB,C,U,FU,MU,A0,B0,TOL,D,FD 
    325          DOUBLE PRECISION PROF,E,FE,EPS,ROOT 
    326          EXTERNAL ISIGN 
    327          PARAMETER (MU=0.5D0) 
    328442 
    329443// Initialization. set the number of iteration as 0. call subroutine 
     
    331445// dumb values for the variables "e" and "fe". 
    332446 
    333          itnum=0 
    334          call func(nprob,a,fa) 
    335          call func(nprob,b,fb) 
    336          e=1.0d5 
    337          fe=1.0d5 
     447         int itnum=0; 
     448         T fa = f(a); 
     449         T fb = f(b); 
     450         T e = fe = 1.0e5; 
    338451          
     452 
     453    for (;;) { 
    339454// iteration starts. the enclosing interval before executing the 
    340455// iteration is recorded as [a0, b0]. 
    341 // 
    342  10      a0=a 
    343          b0=b 
     456        a0=a; b0=b; 
    344457 
    345458// updates the number of iteration. 
    346  
    347          itnum=itnum+1 
    348 // calculates the termination criterion. stops the procedure if the 
    349 // criterion is satisfied. 
    350  
    351          if(dabs(fb) .le. dabs(fa)) then 
    352           call tole(b,tol,neps,eps) 
    353          else 
    354           call tole(a,tol,neps,eps) 
    355          endif 
    356          if((b-a).le.tol)goto 400 
     459         ++itnum; 
     460// Stops if termination criterion is satisfied. 
     461        if (tol(a, b)) return a; 
    357462          
    358463// for the first iteration, secant step is taken. 
    359464 
    360          if(itnum .eq. 1)then 
    361           c=a-(fa/(fb-fa))*(b-a) 
    362            
    363 // Call subroutine "brackt" to get a shrinked enclosing interval as 
    364 // well as to update the termination criterion. Stop the procedure 
    365 // if the criterion is satisfied or the exact solution is obtained. 
    366  
    367           call brackt(nprob,a,b,c,fa,fb,tol,neps,eps,d,fd) 
    368           if((fa.eq.0.0d0).or.((b-a).le.tol))goto 400 
    369           goto 10 
    370          endif 
     465        if(itnum==1) { 
     466          c = a-(fa/(fb-fa))*(b-a); 
     467          bracket(f, a,b,c, fa,fb, d,fd); 
     468          if((fa == 0)|| (tol(a, b))) return a; 
     469          continue; 
     470        } 
    371471          
    372472// starting with the second iteration, in the first two steps, either 
    373 // quadratic interpolation is used by calling the subroutine "newqua" 
    374 // or the cubic inverse interpolation is used by calling the subroutine 
    375 // "pzero". in the following, if "prof" is not equal to 0, then the 
    376 // four function values "fa", "fb", "fd", and "fe" are distinct, and 
    377 // hence "pzero" will be called. 
    378  
    379          prof=(fa-fb)*(fa-fd)*(fa-fe)*(fb-fd)*(fb-fe)*(fd-fe)           
    380          if((itnum .eq. 2) .or. (prof .eq. 0.0d0)) then 
    381           call newqua(a,b,d,fa,fb,fd,c,2) 
    382          else 
    383           call pzero(a,b,d,e,fa,fb,fd,fe,c) 
    384           if((c-a)*(c-b) .ge. 0.0d0)then 
    385            call newqua(a,b,d,fa,fb,fd,c,2) 
    386           endif 
    387          endif 
    388          e=d 
    389          fe=fd 
     473// quadratic interpolation or cubic inverse interpolation is used. 
     474// In the following, if "prof" is not equal to 0, then the 
     475// four function values "fa", "fb", "fd", and "fe" are distinct, and the 
     476// cubic interpolation is used. 
     477 
     478         T prof = (fa-fb)*(fa-fd)*(fa-fe)*(fb-fd)*(fb-fe)*(fd-fe); 
     479         if((itnum==2) || (prof==0)) { 
     480             c = newtonQuadratic(a,b,d,fa,fb,fd,2); 
     481         } else { 
     482             c = cubicInverseInterpolate(a,b,d,e,fa,fb,fd,fe); 
     483             if((c-a)*(c-b) >= 0) { 
     484                  c = newtonQuadratic(a,b,d,fa,fb,fd,2); 
     485             } 
     486         } 
     487         e=d; 
     488         fe=fd; 
     489         bracket(f, a,b,c,fa,fb,d,fd); 
     490         if((fa == 0) ||  tol(a, b)) return a; 
     491         prof = (fa-fb)*(fa-fd)*(fa-fe)*(fb-fd)*(fb-fe)*(fd-fe); 
     492         if (prof == 0) { 
     493             c = newtonQuadratic(a,b,d,fa,fb,fd,3); 
     494         } else { 
     495             c = cubicInverseInterpolate(a,b,d,e,fa,fb,fd,fe); 
     496             if((c-a)*(c-b) >= 0) { 
     497                  c = newtonQuadratic(a,b,d,fa,fb,fd,3); 
     498             } 
     499         } 
     500         bracket(f, a,b,c,fa,fb,d,fd); 
     501         if((fa == 0) ||  tol(a, b)) return a; 
     502         e=d; 
     503         fe=fd; 
     504 
     505// takes the double-size secant step. 
     506 
     507         if (fabs(fa) < fabs(fb)) { 
     508              u=a; 
     509              fu=fa; 
     510         } else{ 
     511              u=b; 
     512              fu=fb; 
     513         } 
     514         c=u-2.0d0*(fu/(fb-fa))*(b-a); 
     515         if(fabs(c-u) > (0.5d0*(b-a))) { 
     516            c=a+0.5d0*(b-a); 
     517         } 
     518         bracket(f, a,b,c,fa,fb,d,fd); 
     519         if((fa == 0) ||  tol(a, b)) return a; 
     520 
     521// determines whether an additional bisection step is needed. and takes 
     522// it if necessary. 
     523 
     524         if((b-a) < (0.5*(b0-a0))) continue; 
     525         e=d; 
     526         fe=fd; 
    390527 
    391528// call subroutine "brackt" to get a shrinked enclosing interval as 
     
    393530// if the criterion is satisfied or the exact solution is obtained. 
    394531 
    395          call brackt(nprob,a,b,c,fa,fb,tol,neps,eps,d,fd) 
    396          if((fa.eq.0.0d0).or.((b-a).le.tol))goto 400 
    397          prof=(fa-fb)*(fa-fd)*(fa-fe)*(fb-fd)*(fb-fe)*(fd-fe)           
    398          if(prof .eq. 0.0d0) then 
    399           call newqua(a,b,d,fa,fb,fd,c,3) 
    400          else 
    401           call pzero(a,b,d,e,fa,fb,fd,fe,c) 
    402           if((c-a)*(c-b) .ge. 0.0d0)then 
    403            call newqua(a,b,d,fa,fb,fd,c,3) 
    404           endif 
    405          endif 
    406  
    407 // call subroutine "brackt" to get a shrinked enclosing interval as 
    408 // well as to update the termination criterion. stop the procedure 
    409 // if the criterion is satisfied or the exact solution is obtained. 
    410  
    411          call brackt(nprob,a,b,c,fa,fb,tol,neps,eps,d,fd) 
    412          if((fa.eq.0.0d0).or.((b-a).le.tol))goto 400 
    413          e=d 
    414          fe=fd 
    415  
    416 // takes the double-size secant step. 
    417  
    418          if (dabs(fa) .lt. dabs(fb))then 
    419           u=a 
    420           fu=fa 
    421          else 
    422           u=b 
    423           fu=fb 
    424          endif 
    425          c=u-2.0d0*(fu/(fb-fa))*(b-a) 
    426          if(dabs(c-u) .gt. (0.5d0*(b-a)))then 
    427           c=a+0.5d0*(b-a) 
    428          endif 
    429  
    430 // call subroutine "brackt" to get a shrinked enclosing interval as 
    431 // well as to update the termination criterion. stop the procedure 
    432 // if the criterion is satisfied or the exact solution is obtained. 
    433  
    434          call brackt(nprob,a,b,c,fa,fb,tol,neps,eps,d,fd) 
    435          if((fa.eq.0.0d0).or.((b-a).le.tol))goto 400 
    436  
    437 // determines whether an additional bisection step is needed. and takes 
    438 // it if necessary. 
    439  
    440          if((b-a) .lt. (mu*(b0-a0)))then 
    441            goto 10 
    442          endif 
    443          e=d 
    444          fe=fd 
    445  
    446 // call subroutine "brackt" to get a shrinked enclosing interval as 
    447 // well as to update the termination criterion. stop the procedure 
    448 // if the criterion is satisfied or the exact solution is obtained. 
    449  
    450          call brackt(nprob,a,b,a+0.5d0*(b-a),fa,fb,tol,neps,eps,d,fd)          
    451          if((fa .eq. 0.0d0).or.((b-a).le.tol))goto 400 
    452          goto 10 
    453  
    454 // terminates the procedure and return the "root". 
    455  
    456  400     continue 
    457          root=a 
    458          return; 
     532         c = a + 0.5 * (b-a); 
     533         bracket(f, a,b,c,fa,fb,d,fd); 
     534         if((fa == 0) ||  tol(a, b)) return a; 
     535     } 
    459536} 
    460537+/ 
     
    487564{ 
    488565   if ((((a-b)==a)) || ((b-a)==b)) { 
     566       printf("x"); 
    489567       // Catastrophic cancellation 
    490568//       assert(a!=0 && b!=0); 
     
    500578        return c; 
    501579 } 
     580       printf("s"); 
    502581   T tol = T.epsilon * 5; 
    503582   T c = a - (fa / (fb - fa)) * (b - a); 
     
    550629      // Oops, failure, try a secant step: 
    551630      c = secant_interpolate(a, b, fa, fb); 
    552    } 
     631   } else printf("p"); 
    553632   return c; 
    554633} 
     
    612691             // Out of bounds step, fall back to quadratic interpolation:           
    613692             c = quadratic_interpolate!(T)(a, b, d, fa, fb, fd, 3); 
    614            } 
     693           }else printf("c"); 
    615694      } else { 
     695          printf("="); 
    616696         c = quadratic_interpolate!(T)(a, b, d, fa, fb, fd, 2); 
    617697      } 
     
    633713             // Out of bounds step, fall back to quadratic interpolation:           
    634714             c = quadratic_interpolate!(T)(a, b, d, fa, fb, fd, 3); 
    635            } 
     715           }else printf("c"); 
    636716      } else { 
     717          printf("="); 
    637718         c = quadratic_interpolate!(T)(a, b, d, fa, fb, fd, 2); 
    638719      } 
     
    661742//         c = a + (b - a) / 2; 
    662743       if ((a-b)==a || (b-a)==b) { 
     744      printf("b"); 
    663745        // DAC: Using ieeeMean here improves worst-case performance by a factor of ~300. 
    664746        // (from >32800 to ~110 for 80-bit reals). 
     
    673755            } 
    674756       } else { 
     757      printf("m"); 
    675758            c = a + (b - a) / 2; 
    676759       }        
     
    691774      // DAC: Also do a binary chop if we're not within a factor of 2 yet -- ie 
    692775      // if we don't yet know what the exponent is. 
    693             
    694       if( (a==0 || b==0 || (fabs(a)>0.5*fabs(b) && fabs(b)>0.5*fabs(a))) &&  (b - a) < 0.5 * (b0 - a0)) 
     776 
     777      real CLOSENESS = 0.5;            
     778      if( (a==0 || b==0 || (fabs(a)>=CLOSENESS*fabs(b) && fabs(b)>=CLOSENESS*fabs(a))) &&  (b - a) < 0.5 * (b0 - a0)) 
    695779         continue; 
     780      printf("B"); 
     781//      if( (a==0 || b==0 || (fabs(a)>0.5*fabs(b) && fabs(b)>0.5*fabs(a))) &&  (b - a) < 0.5 * (b0 - a0)) 
     782//         continue; 
    696783 
    697784      // There is a very nasty case where the range contains zero.