root/trunk/sort.d

Revision 249, 41.1 kB (checked in by dsimcha, 1 year ago)

Add ridge regression, a few small misc, cleanups.

Line 
1 /**A comprehensive sorting library for statistical functions.  Each function
2  * takes N arguments, which are arrays or array-like objects, sorts the first
3  * and sorts the rest in lockstep.  For merge and insertion sort, if the last
4  * argument is a ulong*, increments the dereference of this ulong* by the bubble
5  * sort distance between the first argument and the sorted version of the first
6  * argument.  This is useful for some statistical calculations.
7  *
8  * All sorting functions have the precondition that all parallel input arrays
9  * must have the same length.
10  *
11  * Notes:
12  *
13  * Comparison functions must be written such that compFun(x, x) == false.
14  * For example, "a < b" is good, "a <= b" is not.
15  *
16  * These functions are heavily optimized for sorting arrays of
17  * ints and floats (by far the most common case when doing statistical
18  * calculations).  In these cases, they can be several times faster than the
19  * equivalent functions in std.algorithm.  Since sorting is extremely important
20  * for non-parametric statistics, this results in important real-world
21  * performance gains.  However, it comes at a price in terms of generality:
22  *
23  * 1.  They assume that what they are sorting is cheap to copy via normal
24  *     assignment.
25  * 2.  They don't work at all with general ranges, only arrays and maybe
26  *     ranges very similar to arrays.
27  * 3.  All tuning and micro-optimization is done with ints and floats, not
28  *     classes, large structs, strings, etc.
29  *
30  * Examples:
31  * ---
32  * auto foo = [3, 1, 2, 4, 5].dup;
33  * auto bar = [8, 6, 7, 5, 3].dup;
34  * qsort(foo, bar);
35  * assert(foo == [1, 2, 3, 4, 5]);
36  * assert(bar == [6, 7, 8, 5, 3]);
37  * auto baz = [1.0, 0, -1, -2, -3].dup;
38  * mergeSort!("a > b")(bar, foo, baz);
39  * assert(bar == [8, 7, 6, 5, 3]);
40  * assert(foo == [3, 2, 1, 4, 5]);
41  * assert(baz == [-1.0, 0, 1, -2, -3]);
42  * ---
43  *
44  * Author:  David Simcha
45  */
46  /*
47  * License:
48  * Boost Software License - Version 1.0 - August 17th, 2003
49  *
50  * Permission is hereby granted, free of charge, to any person or organization
51  * obtaining a copy of the software and accompanying documentation covered by
52  * this license (the "Software") to use, reproduce, display, distribute,
53  * execute, and transmit the Software, and to prepare derivative works of the
54  * Software, and to permit third-parties to whom the Software is furnished to
55  * do so, all subject to the following:
56  *
57  * The copyright notices in the Software and this entire statement, including
58  * the above license grant, this restriction and the following disclaimer,
59  * must be included in all copies of the Software, in whole or in part, and
60  * all derivative works of the Software, unless such copies or derivative
61  * works are solely in the form of machine-executable object code generated by
62  * a source language processor.
63  *
64  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
65  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
66  * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
67  * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
68  * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
69  * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
70  * DEALINGS IN THE SOFTWARE.
71  */
72
73 module dstats.sort;
74
75 import std.traits, std.algorithm, std.math, std.functional, std.math, std.typecons,
76        std.typetuple, std.range, std.array, std.traits, std.string : whitespace;
77
78 import dstats.alloc;
79
version(unittest) {
    import std.random;
    import std.stdio;

    // Test builds get an empty entry point (D runs unittests before main),
    // presumably so this module can be compiled stand-alone with -unittest.
    void main() {}
}
86
/**Thrown when input cannot be sorted, e.g. because the keys contain NaNs or
 * because the comparison function is invalid (compFun(x, x) is true).
 *
 * file and line default to the throw site, following the usual D Exception
 * convention, so error reports point at the code that threw rather than at
 * this constructor.
 */
class SortException : Exception {
    this(string msg, string file = __FILE__, size_t line = __LINE__) {
        super(msg, file, line);
    }
}
92
/* CTFE-friendly helper used by isSimpleComparison:  returns a copy of input
 * with every character listed in std.string.whitespace removed.
 */
/*private*/ string removeWhitespace(string input) pure nothrow {
    string stripped;
    nextChar:
    foreach(c; input) {
        foreach(ws; whitespace) {
            if(c == ws) {
                continue nextChar;
            }
        }
        stripped ~= c;
    }
    return stripped;
}
111
/* Conservatively tests whether the comparison function is simple enough that
 * we can get away with comparing floats as if they were ints.  Only the
 * string lambdas "a<b" and "a>b" qualify (whitespace is ignored); anything
 * else, including non-string comparison functions, does not.
 */
/*private*/ template isSimpleComparison(alias comp) {
    static if(isSomeString!(typeof(comp))) {
        enum bool isSimpleComparison =
            removeWhitespace(comp) == "a>b" ||
            removeWhitespace(comp) == "a<b";
    } else {
        enum bool isSimpleComparison = false;
    }
}
124
/* Tests whether i, reinterpreted bit-for-bit as an IEEE 754 float (for
 * int/uint) or double (for long/ulong), is a NaN:  all exponent bits set and
 * a nonzero significand.  Any other integer type is a compile-time error.
 */
/*private*/ bool intIsNaN(I)(I i) {
    static if(is(I == int) || is(I == uint)) {
        // Single precision: 23-bit significand in the low bits, then an
        // 8-bit exponent.
        enum uint significandMask = 0x007F_FFFF;
        enum uint exponentMask    = 0x7F80_0000;
    } else static if(is(I == long) || is(I == ulong)) {
        // Double precision: 52-bit significand in the low bits, then an
        // 11-bit exponent.
        enum ulong significandMask = 0x000F_FFFF_FFFF_FFFFUL;
        enum ulong exponentMask    = 0x7FF0_0000_0000_0000UL;
    } else {
        static assert(0);
    }

    // Finite values never have an all-ones exponent.
    if((i & exponentMask) != exponentMask) {
        return false;
    }
    // All-ones exponent:  NaN if the significand is nonzero, infinity if zero.
    return (i & significandMask) != 0;
}
145
unittest {
    // Random 32-bit patterns punned to floats (about 1 in 256 should be NaNs);
    // intIsNaN must agree with std.math.isNaN on every one of them.
    foreach(trial; 0..10_000) {
        uint bits = uniform(0U, uint.max);
        immutable bool expected = std.math.isNaN(*(cast(float*) &bits));
        assert(intIsNaN(bits) == expected);
    }

    // Random 64-bit patterns punned to doubles (about 1 in 2048 should be
    // NaNs).  The high word is drawn before the low word.
    foreach(trial; 0..1_000_000) {
        immutable ulong hi = uniform(0U, uint.max);
        immutable ulong lo = uniform(0U, uint.max);
        ulong bits = (hi << 32) + lo;
        immutable bool expected = std.math.isNaN(*(cast(double*) &bits));
        assert(intIsNaN(bits) == expected);
    }
}
161
/* Identity overload for keys whose elements are not floating point:  they
 * need neither a NaN check nor bit twiddling, so arr is returned unchanged.
 */
/*private*/ T prepareForSorting(alias comp, T)(T arr)
if(!isFloatingPoint!(ElementType!(T)))
{
    return arr;
}
166
/* Overload for real[] keys:  reals are sorted as-is (no bit twiddling), but
 * the keys are still scanned for NaNs, which cannot be sorted.
 *
 * Throws:  SortException if arr contains a NaN.
 */
real[] prepareForSorting(alias comp, F)(F arr)
if(is(F == real[])) {
    for(size_t idx = 0; idx < arr.length; idx++) {
        if(isNaN(arr[idx])) {
            throw new SortException("Can't sort NaNs.");
        }
    }
    return arr;
}
178
/* Check for NaN and do some bit twiddling so that a float or double can be
 * compared as an integer.  This results in approximately a 40% speedup
 * compared to just sorting as floats.
 *
 * The twiddling is only done when comp is a simple comparison ("a < b" or
 * "a > b", see isSimpleComparison):  negative values get their sign bit
 * cleared and then all bits inverted, so the resulting signed integers order
 * like the original floats (with -0.0 placed just before +0.0).  For any
 * other comparison, arr is only checked for NaNs.
 *
 * Returns:  arr reinterpreted as int[]/long[] if it was twiddled, else arr.
 * Throws:   SortException if arr contains a NaN.  Elements twiddled before
 *           the NaN was found are restored before throwing.
 */
auto prepareForSorting(alias comp, F)(F arr)
if(is(F == double[]) || is(F == float[])) {
    static if(is(F == double[])) {
        alias long Int;
        enum signMask = 1UL << 63;
    } else {
        alias int Int;
        enum signMask = 1U << 31;
    }

    Int[] intArr = cast(Int[]) arr;
    foreach(i, ref elem; intArr) {
        if(intIsNaN(elem)) {
            // Roll back the bit twiddling in case someone catches the
            // exception, so that they don't see corrupted values.  The prefix
            // must be passed with its floating point type F:  postProcess()
            // only undoes the transform for float[]/double[], and its Int[]
            // overload is a no-op, so passing intArr[0..i] restored nothing.
            postProcess!comp(arr[0..i]);

            throw new SortException("Can't sort NaNs.");
        }

        static if(isSimpleComparison!comp) {
            if(elem & signMask) {
                // Negative.
                elem ^= signMask;
                elem = ~elem;
            }
        }
    }

    static if(isSimpleComparison!comp) {
        return intArr;
    } else {
        return arr;
    }
}
218
/* No-op overload:  there is nothing to undo unless the keys are
 * float[]/double[] AND the comparison is simple, because only then did
 * prepareForSorting() twiddle their bits.
 */
/*private*/ void postProcess(alias comp, T)(T arr)
if(!(isSimpleComparison!comp && (is(T == double[]) || is(T == float[])))) {
}
221
/* Undo bit twiddling from prepareForSorting() to get back the original
 * floating point numbers.  Selected only for float[]/double[] keys with a
 * simple comparison, which are the only ones prepareForSorting() twiddles.
 */
/*private*/ void postProcess(alias comp, F)(F arr)
if((is(F == double[]) || is(F == float[])) && isSimpleComparison!comp) {
    static if(is(F == float[])) {
        alias int Int;
        enum mask = 1U << 31;
    } else {
        alias long Int;
        enum mask = 1UL << 63;
    }

    // Negative floats were stored as ~(bits ^ mask), which is a negative
    // integer (sign bit set); invert that transform in place.
    Int[] bitsArr = cast(Int[]) arr;
    foreach(ref bits; bitsArr) {
        if(bits < 0) {
            bits = ~bits;
            bits ^= mask;
        }
    }
}
243
version(unittest) {
    /* Sorts random floating point keys of type F in descending order with fun,
     * sorting an identical copy in lockstep, and checks that the keys come out
     * sorted and that the two arrays still match.  mergeSortTemp also needs
     * caller-supplied scratch arrays, so it is special-cased.
     */
    static void testFloating(alias fun, F)() {
        enum bool needsTemp = __traits(isSame, fun, mergeSortTemp);

        auto keys = new F[1_000];
        foreach(ref k; keys) {
            k = uniform(-1_000_000, 1_000_000);
        }
        auto mirror = keys.dup;

        static if(needsTemp) {
            auto scratch1 = keys.dup;
            auto scratch2 = keys.dup;
        }

        foreach(trial; 0..200) {
            randomShuffle(zip(keys, mirror));
            immutable uint len = uniform(0, 1_000);

            static if(needsTemp) {
                fun!"a > b"(keys[0..len], mirror[0..len],
                            scratch1[0..len], scratch2[0..len]);
            } else {
                fun!("a > b")(keys[0..len], mirror[0..len]);
            }

            assert(isSorted!("a > b")(keys[0..len]));
            assert(keys == mirror, fun.stringof ~ '\t' ~ F.stringof);
        }
    }
}
272
/**Rotates input one position to the left in place:  input[0] moves to the
 * end and every other element shifts down by one.  Inputs shorter than two
 * elements are left untouched.
 */
void rotateLeft(T)(T input)
if(isRandomAccessRange!(T)) {
    if(input.length >= 2) {
        ElementType!(T) first = input[0];
        for(size_t dst = 0; dst + 1 < input.length; ++dst) {
            input[dst] = input[dst + 1];
        }
        input[$ - 1] = first;
    }
}
282
/**Rotates input one position to the right in place:  the last element moves
 * to the front and every other element shifts up by one.  Inputs shorter than
 * two elements are left untouched.
 */
void rotateRight(T)(T input)
if(isRandomAccessRange!(T)) {
    if(input.length >= 2) {
        ElementType!(T) last = input[$ - 1];
        foreach_reverse(dst; 1 .. input.length) {
            input[dst] = input[dst - 1];
        }
        input[0] = last;
    }
}
292
/* Returns the index, NOT the value, of the median of the first, middle and
 * last elements of data under compFun.
 *
 * The three pairwise comparisons (first<mid, first<last, mid<last) are packed
 * into a 3-bit code.  Codes 2 (010) and 5 (101) are impossible for a
 * consistent strict comparison.
 */
size_t medianOf3(alias compFun, T)(T[] data) {
    alias binaryFun!(compFun) less;
    immutable size_t mid = data.length / 2;
    immutable size_t last = data.length - 1;

    immutable uint code =
        ((cast(uint) (less(data[0], data[mid]))) << 2) |
        ((cast(uint) (less(data[0], data[last]))) << 1) |
        (cast(uint) (less(data[mid], data[last])));

    assert(code != 2 && code != 5 && code < 8); // Cases 2, 5 can't happen.

    if(code == 0 || code == 7) {         // 000, 111:  middle is the median.
        return mid;
    } else if(code == 3 || code == 4) {  // 011, 100:  first is the median.
        return 0;
    } else if(code == 1 || code == 6) {  // 001, 110:  last is the median.
        return last;
    }
    assert(0);
}
318
unittest {
    // Each case:  the input and the index medianOf3 should pick.
    static struct Case {
        int[] data;
        size_t expected;
    }

    auto cases = [
        Case([1, 2, 3, 4, 5], 2),
        Case([1, 2, 5, 4, 3], 4),
        Case([3, 2, 1, 4, 5], 0),
        Case([5, 2, 3, 4, 1], 2),
        Case([5, 2, 1, 4, 3], 4),
        Case([3, 2, 5, 4, 1], 0)
    ];

    foreach(c; cases) {
        assert(medianOf3!("a < b")(c.data) == c.expected);
    }
}
327
328
/**Quick sort.  Unstable, O(N log N) time on average and in the worst
 * case, O(log N) space, small constant term in time complexity.
 *
 * In this implementation, the following steps are taken to avoid the
 * O(N<sup>2</sup>) worst case of naive quick sorts:
 *
 * 1.  At each recursion, the median of the first, middle and last elements of
 *     the array is used as the pivot.
 *
 * 2.  To handle the case of few unique elements, the "Fit Pivot" technique
 *     previously described by Andrei Alexandrescu is used.  This allows
 *     reasonable performance with few unique elements, with zero overhead
 *     in other cases.
 *
 * 3.  After a much larger than expected amount of recursion has occurred,
 *     this function transitions to a heap sort.  This guarantees an O(N log N)
 *     worst case.
 *
 * Inputs of fewer than 25 elements are delegated to insertionSort().  For
 * larger inputs, SortException is thrown if the keys contain NaNs or if
 * compFun(x, x) is true for the chosen pivot.*/
T[0] qsort(alias compFun = "a < b", T...)(T data)
in {
    assert(data.length > 0);
    immutable size_t len = data[0].length;
    foreach(array; data[1..$]) {
        assert(array.length == len);
    }
} body {
    immutable size_t n = data[0].length;

    // Small inputs go straight to insertion sort, which also avoids computing
    // the logarithm below.
    if(n < 25) {
        return insertionSort!compFun(data);
    }

    // Recursion budget before qsortImpl switches to heap sort:  2 * log2(N).
    immutable uint timeToLive = cast(uint) (log2(cast(real) n) * 2);

    auto keys = prepareForSorting!compFun(data[0]);

    // qsortImpl throws if an invalid comparison function is passed.  Undo the
    // float bit twiddling on every exit path so data[0] is never left holding
    // twiddled values.
    scope(exit) postProcess!compFun(data[0]);
    qsortImpl!(compFun)(keys, data[1..$], timeToLive);

    return data[0];
}
377
//TTL = time to live, before transitioning to heap sort.
/* Recursive worker for qsort().  Sorts data[0] under compFun and permutes
 * the other arrays in data in lockstep.  data[0] is expected to have already
 * been passed through prepareForSorting().
 *
 * TTL:     remaining recursion budget; when it reaches 0 the current slice is
 *          handed to heapSortImpl() instead of being partitioned further.
 * Throws:  SortException if compFun(pivot, pivot) is true.
 */
void qsortImpl(alias compFun, T...)(T data, uint TTL) {
    alias binaryFun!(compFun) comp;
    // Small slices are finished with insertion sort.
    if(data[0].length < 25) {
         insertionSortImpl!(compFun)(data);
         return;
    }
    // Recursion budget exhausted:  fall back to heap sort.
    if(TTL == 0) {
        heapSortImpl!(compFun)(data);
        return;
    }
    TTL--;

    // Swap the median-of-three into the last position (in every array) so it
    // serves as the pivot.
    {
        immutable size_t med3 = medianOf3!(comp)(data[0]);
        foreach(array; data) {
            auto temp = array[med3];
            array[med3] = array[$ - 1];
            array[$ - 1] = temp;
        }
    }

    T less, greater;
    // lessI starts at size_t.max because it is pre-incremented (wrapping to
    // 0 on first use); greaterI starts on the pivot and is pre-decremented.
    size_t lessI = size_t.max, greaterI = data[0].length - 1;

    auto pivot = data[0][$ - 1];
    // The scan from the left relies on comp(pivot, pivot) being false to stop
    // at the pivot; reject comparison functions that violate this.
    if(comp(pivot, pivot)) {
        throw new SortException
            ("Comparison function must be such that compFun(x, x) == false.");
    }

    // Hoare-style partition:  advance lessI past elements less than the
    // pivot, retreat greaterI past elements greater than it, and swap the
    // out-of-place pair until the two indices cross.
    while(true) {
        while(comp(data[0][++lessI], pivot)) {}
        while(greaterI > 0 && comp(pivot, data[0][--greaterI])) {}

        if(lessI < greaterI) {
            foreach(array; data) {
                auto temp = array[lessI];
                array[lessI] = array[greaterI];
                array[greaterI] = temp;
            }
        } else break;
    }

    // Move the pivot from the end into its final slot (lessI), then split
    // every array into the part before the pivot and the part after it.
    foreach(ti, array; data) {
        auto temp = array[$ - 1];
        array[$ - 1] = array[lessI];
        array[lessI] = temp;
        less[ti] = array[0..min(lessI, greaterI + 1)];
        greater[ti] = array[lessI + 1..$];
    }
    // Allow tail recursion optimization for larger block.  This guarantees
    // that, given a reasonable amount of stack space, no stack overflow will
    // occur even in pathological cases.
    if(greater[0].length > less[0].length) {
        qsortImpl!(compFun)(less, TTL);
        qsortImpl!(compFun)(greater, TTL);
        return;
    } else {
        qsortImpl!(compFun)(greater, TTL);
        qsortImpl!(compFun)(less, TTL);
    }
}
441
unittest {
    // Integer keys with many ties; test2 is sorted in lockstep and must end up
    // identical to test.
    {
        auto test = new uint[1_000];
        foreach(ref e; test) {
            e = uniform(0, 100);
        }
        auto test2 = test.dup;

        foreach(trial; 0..1_000) {
            randomShuffle(zip(test, test2));
            immutable uint len = uniform(0, 1_000);
            qsort(test[0..len], test2[0..len]);
            assert(isSorted(test[0..len]));
            assert(test == test2);
        }
    }

    // Floating point keys, exercising the bit twiddling path.
    testFloating!(qsort, float)();
    testFloating!(qsort, double)();
    testFloating!(qsort, real)();

    // NaN keys must be rejected with a SortException.
    auto nanArr = [double.nan, 1.0];
    bool threw = false;
    try {
        qsort(nanArr);
    } catch(SortException) {
        threw = true;
    }
    assert(threw);
}
468
/* Tags recording which array a merge sort result currently lives in:  the
 * caller's data arrays (DATA) or the scratch temp arrays (TEMP).  Tracking
 * this is a speed hack so that data is copied back and forth less.
 */
/*private*/ enum {
    DATA = 0,
    TEMP = 1
}
475
/**Merge sort.  O(N log N) time, O(N) space, small constant.  Stable sort.
 * If last argument is a ulong* instead of an array-like type,
 * the dereference of the ulong* will be incremented by the bubble sort
 * distance between the input array and the sorted version.  This is useful
 * in some statistics functions such as Kendall's tau.
 *
 * Inputs shorter than 65 elements are handed straight to insertionSortImpl()
 * so that no scratch memory is allocated for them.*/
T[0] mergeSort(alias compFun = "a < b", T...)(T data)
in {
    assert(data.length > 0);
    size_t len = data[0].length;
    foreach(array; data[1..$]) {
        static if(!is(typeof(array) == ulong*))
            assert(array.length == len);
    }
} body {
    if(data[0].length < 65) {  //Avoid mem allocation.
        return insertionSortImpl!(compFun)(data);
    }
    static if(is(T[$ - 1] == ulong*)) {
        enum dl = data.length - 1;  // Number of arrays, excluding swapCount.
        alias data[$ - 1] swapCount;
    } else {
        enum dl = data.length;
        alias TypeTuple!() swapCount; // Place holder.
    }

    auto keyArr = prepareForSorting!compFun(data[0]);
    // Like qsort(), undo the float bit twiddling on every exit path,
    // including when compFun throws.
    scope(exit) postProcess!compFun(data[0]);

    auto toSort = TypeTuple!(keyArr, data[1..dl]);

    // One TempAlloc scratch array per array being sorted.
    auto stateCache = TempAlloc.getState;
    typeof(toSort) temp;
    foreach(i, array; temp) {
        temp[i] = newStack!(typeof(temp[i][0]))(data[i].length, stateCache);
    }
    // Free the scratch arrays (one free() per allocation) on every exit path,
    // so an exception from compFun does not leave them allocated.
    scope(exit) {
        foreach(array; temp) {
            TempAlloc.free(stateCache);
        }
    }

    uint res = mergeSortImpl!(compFun)(toSort, temp, swapCount);
    if(res == TEMP) {
        // The sorted result ended up in the scratch arrays; copy it back.
        foreach(ti, array; temp) {
            toSort[ti][0..$] = temp[ti][0..$];
        }
    }

    return data[0];
}
524
unittest {
    enum n = 1_000;
    uint[] test = new uint[n], stability = new uint[n];
    uint[] temp1 = new uint[n], temp2 = new uint[n];
    foreach(ref e; test) {
        e = uniform(0, 100);  // Lots of ties.
    }

    // stability[j] == j before each sort, so afterwards it records each key's
    // original position.
    void resetStability() {
        foreach(j, ref e; stability) {
            e = cast(uint) j;
        }
    }

    // Keys must be sorted, and equal keys must keep their original order.
    void checkSortedAndStable(uint len) {
        assert(isSorted(test[0..len]));
        foreach(j; 1..len) {
            if(test[j - 1] == test[j]) {
                assert(stability[j - 1] < stability[j]);
            }
        }
    }

    // With swap counts:  the bubble sort distance reported by the merge sorts
    // is checked against bubbleSort, whose counting is straightforward and
    // unlikely to contain subtle bugs.
    foreach(i; 0..100) {
        ulong mergeCount = 0, bubbleCount = 0;
        resetStability();
        randomShuffle(test);
        uint len = uniform(0, 1_000);
        bubbleSort(test[0..len].dup, &bubbleCount);

        // Alternate between the allocating and caller-provided-temp versions.
        if(i & 1) {
            mergeSort(test[0..len], stability[0..len], &mergeCount);
        } else {
            mergeSortTemp(test[0..len], stability[0..len], temp1[0..len],
                          temp2[0..len], &mergeCount);
        }
        assert(bubbleCount == mergeCount);
        checkSortedAndStable(len);
    }

    // Without swap counts.
    foreach(i; 0..1000) {
        resetStability();
        randomShuffle(test);
        uint len = uniform(0, 1_000);
        if(i & 1) {
            mergeSort(test[0..len], stability[0..len]);
        } else {
            mergeSortTemp(test[0..len], stability[0..len], temp1[0..len],
                          temp2[0..len]);
        }
        checkSortedAndStable(len);
    }

    testFloating!(mergeSort, float)();
    testFloating!(mergeSort, double)();
    testFloating!(mergeSort, real)();

    testFloating!(mergeSortTemp, float)();
    testFloating!(mergeSortTemp, double)();
    testFloating!(mergeSortTemp, real)();
}
583
/**Merge sort, allowing caller to provide a temp variable.  This allows
 * recycling instead of repeated allocations.  If D is data, T is temp,
 * and U is a ulong* for calculating bubble sort distance, this can be called
 * as mergeSortTemp(D, D, D, T, T, T, U) or mergeSortTemp(D, D, D, T, T, T)
 * where each D has a T of corresponding type.
 *
 * Examples:
 * ---
 * int[] foo = [3, 1, 2, 4, 5].dup;
 * int[] temp = new int[5];
 * mergeSortTemp!("a < b")(foo, temp);
 * assert(foo == [1, 2, 3, 4, 5]); // The contents of temp will be undefined.
 * foo = [3, 1, 2, 4, 5].dup;
 * real[] bar = [3.14L, 15.9L, 26.5L, 35.8L, 97.9L];
 * real[] temp2 = new real[5];
 * mergeSortTemp(foo, bar, temp, temp2);
 * assert(foo == [1, 2, 3, 4, 5]);
 * assert(bar == [15.9L, 26.5L, 3.14L, 35.8L, 97.9L]);
 * // The contents of both temp and temp2 will be undefined.
 * ---
 */
T[0] mergeSortTemp(alias compFun = "a < b", T...)(T data)
in {
    assert(data.length > 0);
    size_t len = data[0].length;
    foreach(array; data[1..$]) {
        static if(!is(typeof(array) == ulong*))
            assert(array.length == len);
    }
} body {
    static if(is(T[$ - 1] == ulong*)) {
        enum dl = data.length - 1;
    } else {
        enum dl = data.length;
    }
    // The first dl / 2 arguments are data arrays and the next dl / 2 are
    // their matching temp arrays; reject a missing or extra temp at compile
    // time instead of failing deep inside mergeSortImpl.
    static assert(dl % 2 == 0,
        "mergeSortTemp needs exactly one temp array per data array.");

    auto keyArr = prepareForSorting!compFun(data[0]);
    // Like qsort(), undo the float bit twiddling on every exit path.
    scope(exit) postProcess!compFun(data[0]);

    // The key's temp array is reinterpreted the same way as the key array
    // (e.g. float[] -> int[] when the floats were bit-twiddled).
    auto keyTemp = cast(typeof(keyArr)) data[dl / 2];
    auto toSort = TypeTuple!(
        keyArr,
        data[1..dl / 2],
        keyTemp,
        data[dl / 2 + 1..$]
    );

    uint res = mergeSortImpl!(compFun)(toSort);

    if(res == TEMP) {
        // The sorted result ended up in the temp arrays; copy it back.
        foreach(ti, array; toSort[0..$ / 2]) {
            toSort[ti][0..$] = toSort[ti + dl / 2][0..$];
        }
    }

    return data[0];
}
640
/* Recursive worker for mergeSort() and mergeSortTemp().  dataIn is laid out
 * as (data arrays..., temp arrays..., [ulong* swapCount]), with one temp
 * array per data array.  Sorts by data[0], permuting the other data arrays in
 * lockstep and using the temp arrays as scratch space.  If swapCount is
 * present, merge() adds the number of inversions it resolves to *swapCount.
 *
 * Returns:  DATA if the sorted result ended up in the data arrays, TEMP if it
 *           ended up in the temp arrays (the caller must then copy it back).
 */
/*private*/ uint mergeSortImpl(alias compFun = "a < b", T...)(T dataIn) {
    static if(is(T[$ - 1] == ulong*)) {
        alias dataIn[$ - 1] swapCount;
        alias dataIn[0..dataIn.length / 2] data;
        alias dataIn[dataIn.length / 2..$ - 1] temp;
    } else {  // Make empty dummy tuple.
        alias TypeTuple!() swapCount;
        alias dataIn[0..dataIn.length / 2] data;
        alias dataIn[dataIn.length / 2..$] temp;
    }

    // Small slices are sorted in place, so the result is in the data arrays.
    if(data[0].length < 50) {
        insertionSortImpl!(compFun)(data, swapCount);
        return DATA;
    }
    size_t half = data[0].length / 2;
    typeof(data) left, right, tempLeft, tempRight;
    foreach(ti, array; data) {
        left[ti] = array[0..half];
        right[ti] = array[half..$];
        tempLeft[ti] = temp[ti][0..half];
        tempRight[ti] = temp[ti][half..$];
    }

    /* Implementation note:  The lloc, rloc stuff is a hack to avoid constantly
     * copying data back and forth between the data and temp arrays.
     * Instead of copying every time, I keep track of which array the last merge
     * went into, and only copy at the end or if the two sides ended up in
     * different arrays.*/
    uint lloc = mergeSortImpl!(compFun)(left, tempLeft, swapCount);
    uint rloc = mergeSortImpl!(compFun)(right, tempRight, swapCount);
    if(lloc == DATA && rloc == TEMP) {
        // Left half is in data, right half in temp:  move left into temp.
        foreach(ti, array; tempLeft) {
            array[] = left[ti][];
        }
        lloc = TEMP;
    } else if(lloc == TEMP && rloc == DATA) {
        // Left half is in temp, right half in data:  move right into temp.
        foreach(ti, array; tempRight) {
            array[] = right[ti][];
        }
    }
    // Both halves now live where lloc says; merge them into the other one.
    if(lloc == DATA) {
        merge!(compFun)(left, right, temp, swapCount);
        return TEMP;
    } else {
        merge!(compFun)(tempLeft, tempRight, data, swapCount);
        return DATA;
    }
}
690
/* Merges the sorted runs `left` and `right` into `result`, moving every
 * parallel array the same way.  data is laid out as
 * (left arrays..., right arrays..., result arrays..., [ulong* swapCount]).
 * Ties are taken from the left run first.  If a swapCount pointer is present,
 * it is incremented each time a right element is placed, by the number of
 * left elements it overtakes.
 */
/*private*/ void merge(alias compFun, T...)(T data) {
    alias binaryFun!(compFun) comp;

    static if(is(T[$ - 1] == ulong*)) {
        enum dl = data.length - 1;  // Length after removing swapCount.
        alias data[$ - 1] swapCount;
    } else {
        enum dl = data.length;
    }

    static assert(dl % 3 == 0);
    enum third = dl / 3;
    alias data[0 .. third] left;
    alias data[third .. 2 * third] right;
    alias data[2 * third .. dl] result;
    static assert(left.length == right.length && right.length == result.length);

    immutable size_t leftLen = left[0].length;
    immutable size_t rightLen = right[0].length;
    size_t pos = 0, li = 0, ri = 0;

    for(; li < leftLen && ri < rightLen; pos++) {
        if(!comp(right[0][ri], left[0][li])) {
            // The left element wins ties.
            foreach(ti, array; result) {
                result[ti][pos] = left[ti][li];
            }
            li++;
        } else {
            static if(is(T[$ - 1] == ulong*)) {
                // right[ri] jumps over every remaining left element.
                *swapCount += leftLen - li;
            }
            foreach(ti, array; result) {
                result[ti][pos] = right[ti][ri];
            }
            ri++;
        }
    }

    // At most one run still has elements; append its tail.
    if(ri < rightLen) {
        foreach(ti, array; result) {
            result[ti][pos..$] = right[ti][ri..$];
        }
    } else {
        foreach(ti, array; result) {
            result[ti][pos..$] = left[ti][li..$];
        }
    }
}
736
/**In-place merge sort, based on C++ STL's stable_sort().  O(N log<sup>2</sup> N)
 * time complexity, O(1) space complexity, stable.  Much slower than plain
 * old mergeSort(), so only use it if you really need the O(1) space.
 *
 * Sorts data[0] under compFun, permutes the remaining arrays in lockstep,
 * and returns data[0].*/
T[0] mergeSortInPlace(alias compFun = "a < b", T...)(T data)
in {
    assert(data.length > 0);
    foreach(other; data[1..$]) {
        assert(other.length == data[0].length);
    }
} body {
    auto keys = prepareForSorting!compFun(data[0]);
    mergeSortInPlaceImpl!compFun(keys, data[1..$]);
    postProcess!compFun(data[0]);

    return data[0];
}
753
/* Recursive worker for mergeSortInPlace().  Expects data[0] to have already
 * been passed through prepareForSorting(); sorts every array in data in
 * lockstep by data[0] and returns data[0].
 */
/*private*/ T[0] mergeSortInPlaceImpl(alias compFun, T...)(T data) {
    // Slices of up to 100 elements are finished with insertion sort.
    if (data[0].length <= 100)
        return insertionSortImpl!(compFun)(data);

    T left, right;
    foreach(ti, array; data) {
        left[ti] = array[0..$ / 2];
        right[ti] = array[$ / 2..$];
    }

    // Recurse into the worker directly.  Going back through the public
    // mergeSortInPlace() re-ran its contract, NaN scan and
    // prepareForSorting()/postProcess() at every level of recursion.
    mergeSortInPlaceImpl!(compFun, T)(right);
    mergeSortInPlaceImpl!(compFun, T)(left);
    mergeInPlace!(compFun)(data, data[0].length / 2);
    return data[0];
}
769
unittest {
    uint[] keys = new uint[1_000], order = new uint[1_000];
    foreach(ref k; keys) {
        k = uniform(0, 100);  // Lots of ties.
    }
    uint[] mirror = keys.dup;

    foreach(trial; 0..1000) {
        // order[j] == j before sorting, so afterwards it records where each
        // key came from; equal keys must stay in their original order.
        foreach(j, ref o; order) {
            o = cast(uint) j;
        }
        randomShuffle(zip(keys, mirror));
        uint len = uniform(0, 1_000);
        mergeSortInPlace(keys[0..len], mirror[0..len], order[0..len]);

        assert(isSorted(keys[0..len]));
        assert(keys == mirror);
        foreach(j; 1..len) {
            if(keys[j - 1] == keys[j]) {
                assert(order[j - 1] < order[j]);
            }
        }
    }

    testFloating!(mergeSortInPlace, float)();
    testFloating!(mergeSortInPlace, double)();
    testFloating!(mergeSortInPlace, real)();
}
796
// Loosely based on C++ STL's __merge_without_buffer().
/* Merges the adjacent sorted runs data[0][0..middle] and data[0][middle..$]
 * in place without a scratch buffer, applying the same element moves to
 * every array in data.  On ties, elements of the first run stay in front.
 *
 * Strategy:  cut the longer run in half, binary-search the matching cut
 * point in the other run, rotate the two inner pieces past each other with
 * bringToFront(), then recursively merge the two resulting halves.
 */
/*private*/ void mergeInPlace(alias compFun = "a < b", T...)(T data, size_t middle) {
    // Returns the index of the first element of data that is not less than
    // value under compFun (same as C++ std::lower_bound).
    static size_t largestLess(alias compFun, T)(T[] data, T value) {
        alias binaryFun!(compFun) comp;
        size_t len = data.length, first, half, middle;

        while (len > 0) {
            half = len / 2;
            middle = first + half;
            if (comp(data[middle], value)) {
                first = middle + 1;
                len = len - half - 1;
            } else
                len = half;
        }
        return first;
    }

    // Returns the index of the first element of data that value is less than
    // under compFun (same as C++ std::upper_bound).
    static size_t smallestGr(alias compFun, T)(T[] data, T value) {
        alias binaryFun!(compFun) comp;
        size_t len = data.length, first, half, middle;

        while (len > 0) {
            half = len / 2;
            middle = first + half;
            if (comp(value, data[middle]))
                len = half;
            else {
                first = middle + 1;
                len = len - half - 1;
            }
        }
        return first;
    }


    alias binaryFun!(compFun) comp;
    // Nothing to merge if either run is empty.
    if (data[0].length < 2 || middle == 0 || middle == data[0].length)
        return;
    // Two elements, one per run:  a single conditional swap suffices.
    if (data[0].length == 2) {
        if(comp(data[0][1], data[0][0])) {
            foreach(array; data) {
                auto temp = array[0];
                array[0] = array[1];
                array[1] = temp;
            }
        }
        return;
    }

    // half1:  length of the prefix of the first run that stays in place.
    // half2:  length of the prefix of the second run that moves in front of
    //         the rest of the first run.
    size_t half1, half2;

    if (middle > data[0].length - middle) {
        // First run is longer:  cut it in half and locate its pivot in the
        // second run.
        half1 = middle / 2;
        auto pivot = data[0][half1];
        half2 = largestLess!(compFun)(data[0][middle..$], pivot);
    } else {
        // Second run is at least as long:  cut it in half instead.
        half2 = (data[0].length - middle) / 2;
        auto pivot = data[0][half2 + middle];
        half1 = smallestGr!(compFun)(data[0][0..middle], pivot);
    }

    // Swap the blocks [half1..middle) and [middle..middle + half2).
    foreach(array; data) {
        bringToFront(array[half1..middle], array[middle..middle + half2]);
    }
    size_t newMiddle = half1 + half2;

    T left, right;
    foreach(ti, array; data) {
        left[ti] = array[0..newMiddle];
        right[ti] = array[newMiddle..$];
    }

    mergeInPlace!(compFun, T)(left, half1);
    mergeInPlace!(compFun, T)(right, half2 + middle - newMiddle);
}
873
874
/**Heap sort.  Unstable, O(N log N) time in both the average and the worst
 * case, O(1) space, large constant term in time complexity.  Sorts data[0]
 * under compFun, permutes the remaining arrays in lockstep, and returns
 * data[0].*/
T[0] heapSort(alias compFun = "a < b", T...)(T data)
in {
    assert(data.length > 0);
    foreach(other; data[1..$]) {
        assert(other.length == data[0].length);
    }
} body {
    auto keys = prepareForSorting!compFun(data[0]);
    heapSortImpl!compFun(keys, data[1..$]);
    postProcess!compFun(data[0]);

    return data[0];
}
890
/* Worker for heapSort() (and qsortImpl()'s fallback).  Sorts input[0] under
 * compFun, permuting the other arrays in lockstep, and returns input[0].
 */
/*private*/ T[0] heapSortImpl(alias compFun, T...)(T input) {
    // Heap sort has such a huge constant that insertion sort's faster for N <
    // 100 (for reals; even larger for smaller types).  This early return also
    // covers the trivial 0- and 1-element cases, so no separate check for
    // them is needed below.
    if(input[0].length <= 100) {
        return insertionSortImpl!(compFun)(input);
    }

    makeMultiHeap!(compFun)(input);
    // Repeatedly swap the heap's root into the last slot of the shrinking
    // unsorted prefix, then sift the new root down to restore the heap.
    for(size_t end = input[0].length - 1; end > 0; end--) {
        foreach(ti, ia; input) {
            auto temp = ia[end];
            ia[end] = ia[0];
            ia[0] = temp;
        }
        multiSiftDown!(compFun)(input, 0, end);
    }
    return input[0];
}
911
unittest {
    // Random keys, random prefix lengths.  The second array starts as a copy of
    // the first, so correct lockstep reordering must keep the two identical.
    auto keys = new uint[1_000];
    foreach(ref k; keys)
        k = uniform(0, 100_000);
    auto shadow = keys.dup;

    foreach(trial; 0..1_000) {
        randomShuffle(zip(keys, shadow));
        uint n = uniform(0, 1_000);
        heapSort(keys[0..n], shadow[0..n]);
        assert(isSorted(keys[0..n]));
        assert(keys == shadow);
    }

    testFloating!(heapSort, float)();
    testFloating!(heapSort, double)();
    testFloating!(heapSort, real)();
}
930
/**Rearranges input[0] into a binary heap under compFun, so that the root
 * (index 0) is a maximum according to compFun.  Every other array in input is
 * permuted in lockstep.  Arrays with fewer than two elements are left alone.*/
void makeMultiHeap(alias compFun = "a < b", T...)(T input) {
    immutable len = input[0].length;
    if(len < 2)
        return;

    // Bottom-up construction:  sift down every index from (len - 1) / 2 back
    // to 0.  The post-decrement lets an unsigned counter reach 0 safely.
    size_t start = (len - 1) / 2 + 1;
    while(start-- > 0) {
        multiSiftDown!(compFun)(input, start, len);
    }
}
939
/**Sifts input[0][root] down the binary heap formed by input[0][0..end].  At
 * each level it is swapped with its higher-ranking child (under compFun)
 * until no child ranks above it.  Every array in input receives the same swaps.
 * As with any sift-down, the result is a valid heap only if both subtrees of
 * root already satisfy the heap property.*/
void multiSiftDown(alias compFun = "a < b", T...)
     (T input, size_t root, size_t end) {
    alias binaryFun!(compFun) comp;
    alias input[0] keys;

    for(size_t child = 2 * root + 1; child < end; child = 2 * root + 1) {
        // Use whichever child ranks higher under comp.
        if(child + 1 < end && comp(keys[child], keys[child + 1])) {
            ++child;
        }
        // Heap property holds at this node; nothing further to do.
        if(!comp(keys[root], keys[child])) {
            return;
        }
        foreach(arr; input) {
            auto tmp = arr[root];
            arr[root] = arr[child];
            arr[child] = tmp;
        }
        root = child;
    }
}
960
/**Insertion sort.  O(N<sup>2</sup>) time in the worst and average case, O(1)
 * space and a VERY small constant factor, which is why it is useful for
 * sorting the small subarrays produced by divide and conquer algorithms.
 * Stable.  Additional arrays are permuted in lockstep with data[0].  If the
 * last argument is a ulong*, the ulong it points to is incremented by the
 * bubble sort distance between the input array and its sorted version.
 *
 * Returns:  data[0], now sorted.*/
T[0] insertionSort(alias compFun = "a < b", T...)(T data)
in {
    assert(data.length > 0);
    immutable expectedLen = data[0].length;
    foreach(other; data[1..$]) {
        // A trailing swap counter is a pointer, not an array:  skip it.
        static if(!is(typeof(other) == ulong*))
            assert(other.length == expectedLen);
    }
} body {
    insertionSortImpl!compFun(prepareForSorting!compFun(data[0]), data[1..$]);
    postProcess!compFun(data[0]);
    return data[0];
}
980
/* Yields the type obtained by indexing a value of type T, i.e.
 * typeof(T.init[0]).  For an array type this is its element type.  Used by
 * insertionSortImpl to declare one temporary per array being sorted. */
private template IndexType(T) {
    alias typeof(T.init[0]) IndexType;
}
984
/* Lockstep insertion sort kernel.  Sorts data[0] by compFun and applies the
 * same moves to the other arrays.  If the last element of data is a ulong*,
 * its pointee is incremented by the total number of positions elements were
 * shifted, which is the bubble sort distance.  Stable:  an element only moves
 * past neighbors that compare strictly less than it.  Returns data[0]. */
/*private*/ T[0] insertionSortImpl(alias compFun, T...)(T data) {
    alias binaryFun!(compFun) comp;
    // dl = number of arrays to permute (excludes a trailing swap counter).
    static if(is(T[$ - 1] == ulong*)) {
        enum dl = data.length - 1;
        alias data[$ - 1] swapCount;
    } else {
        enum dl = data.length;
    }

    alias data[0] keyArray;
    if(keyArray.length < 2) {
        return keyArray;
    }

    // Yes, I measured this, caching this value is actually faster on DMD.
    immutable maxJ = keyArray.length - 1;
    // i runs from the second-to-last index down to 0.  The loop ends when i
    // wraps around to size_t.max.  Invariant:  keyArray[i + 1 .. $] is sorted.
    for(size_t i = keyArray.length - 2; i != size_t.max; --i) {
        size_t j = i;

        // Save element i of every array before the shift loop overwrites it.
        // temp is void-initialized, so element types with elaborate assignment
        // must be constructed with emplace rather than assigned.  The check
        // must use the *element* type (IndexType!Type):  Type itself is the
        // array type, for which hasElaborateAssign is always false.
        Tuple!(staticMap!(IndexType, typeof(data[0..dl]))) temp = void;
        foreach(ti, Type; typeof(data[0..dl])) {
            static if(hasElaborateAssign!(IndexType!Type)) {
                emplace(&(temp.field[ti]), data[ti][i]);
            } else {
                temp.field[ti] = data[ti][i];
            }
        }

        // Shift smaller elements one slot left until the saved key's final
        // position j is found.
        for(; j < maxJ && comp(keyArray[j + 1], temp.field[0]); ++j) {
            // It's faster to do all copying here than to call rotateLeft()
            // later, probably due to better ILP.
            foreach(array; data[0..dl]) {
                array[j] = array[j + 1];
            }
        }

        // Drop the saved element into its final slot in every array.
        foreach(ti, Unused; typeof(temp.field)) {
            data[ti][j] = temp.field[ti];
        }

        static if(is(typeof(swapCount))) {
            *swapCount += (j - i);  //Increment swapCount variable.
        }
    }

    return keyArray;
}
1032
unittest {
    // Values in [0, 100) over 100 slots guarantee plenty of ties, so the
    // stability check below is actually exercised.
    auto keys = new uint[100];
    auto origIndex = new uint[100];
    foreach(ref k; keys)
        k = uniform(0, 100);

    foreach(trial; 0..1_000) {
        ulong insertDist = 0, bubbleDist = 0;
        foreach(idx, ref slot; origIndex)
            slot = cast(uint) idx;
        randomShuffle(keys);
        uint n = uniform(0, 100);

        // bubbleSort's distance is simple enough to serve as the reference
        // for insertionSort's swap count.
        bubbleSort(keys[0..n].dup, &bubbleDist);
        insertionSort(keys[0..n], origIndex[0..n], &insertDist);
        assert(bubbleDist == insertDist);
        assert(isSorted(keys[0..n]));

        // Equal keys must keep their original relative order.
        foreach(pos; 1..n) {
            if(keys[pos - 1] == keys[pos])
                assert(origIndex[pos - 1] < origIndex[pos]);
        }
    }
}
1059
// Reference implementation for the unit tests only.  Bubble sort is easy to
// get right, which makes it a trustworthy oracle for the more complex sorts.
// It is especially useful for bubble sort distance, which falls straight out
// of a bubble sort but not out of a merge sort or insertion sort.
version(unittest) {
    /* Bubble sorts data[0] by compFun, permuting the other arrays in lockstep.
     * If the last argument is a ulong*, its pointee is incremented once per
     * adjacent swap, i.e. by the bubble sort distance.  Returns data[0]. */
    T[0] bubbleSort(alias compFun = "a < b", T...)(T data) {
        alias binaryFun!(compFun) comp;
        enum hasCounter = is(T[$ - 1] == ulong*);
        static if(hasCounter) {
            enum nArrays = data.length - 1;
        } else {
            enum nArrays = data.length;
        }

        alias data[0] keys;
        immutable len = keys.length;
        if(len < 2) {
            return keys;
        }

        // At most len passes; stop as soon as a pass makes no swaps.
        for(size_t pass = 0; pass < len; ++pass) {
            bool anySwap = false;
            for(size_t j = 1; j < len; ++j) {
                if(!comp(keys[j], keys[j - 1])) {
                    continue;
                }
                anySwap = true;
                static if(hasCounter) {
                    ++*data[$ - 1];
                }
                foreach(arr; data[0..nArrays]) {
                    swap(arr[j - 1], arr[j]);
                }
            }
            if(!anySwap) {
                break;
            }
        }
        return keys;
    }
}
1089
unittest {
    // Sanity check:  bubble sort distance equals the number of inversions.
    ulong inversions = 0;
    uint[] arr = [4, 5, 3, 2, 1];
    bubbleSort(arr, &inversions);
    assert(inversions == 9);  // 3 + 3 + 2 + 1

    inversions = 0;
    arr = [6, 1, 2, 4, 5, 3];
    bubbleSort(arr, &inversions);
    assert(inversions == 7);  // 6 precedes five smaller values; 4 and 5 precede 3
}
1101
/**Returns the kth element (0-indexed) of data in compFun order.  That is the
 * kth smallest for the default "a < b", or the kth largest for "a > b".
 * Selection runs on a TempAlloc-allocated copy, so data itself is never
 * modified.  Built on partitionK (a quickselect), so it is O(N) on typical
 * input rather than in the guaranteed worst case.*/
T quickSelect(alias compFun = "a < b", T)(T[] data, sizediff_t k) {
    auto scratch = data.tempdup;
    scope(exit) TempAlloc.free;
    auto kth = partitionK!(compFun)(scratch, k);
    return kth;
}
1110
/**Partitions the input data according to compFun, such that position k contains
 * the kth largest/smallest element according to compFun (0-indexed), i.e. the
 * element that would occupy position k if data[0] were sorted by compFun.  For
 * all elements e with indices < k, !compFun(data[k], e) is guaranteed to be
 * true.  For all elements e with indices > k, !compFun(e, data[k]) is
 * guaranteed to be true.  For example, if compFun is "a < b", all elements
 * with indices < k will be <= data[k], and all elements with indices larger
 * than k will be >= data[k].  Reorders any additional input arrays in lockstep.
 * Works in place.
 *
 * Preconditions:  all arrays have the same length, and
 * 0 <= k < data[0].length.
 *
 * Examples:
 * ---
 * auto foo = [3, 1, 5, 4, 2].dup;
 * auto secondSmallest = partitionK(foo, 1);
 * assert(secondSmallest == 2);
 * foreach(elem; foo[0..1]) {
 *     assert(elem <= foo[1]);
 * }
 * foreach(elem; foo[2..$]) {
 *     assert(elem >= foo[1]);
 * }
 * ---
 *
 * Returns:  The kth element of the array.
 */
ElementType!(T[0]) partitionK(alias compFun = "a < b", T...)(T data, ptrdiff_t k)
in {
    assert(data.length > 0);
    size_t len = data[0].length;
    foreach(array; data[1..$]) {
        assert(array.length == len);
    }
    // An out-of-range k (including any k on empty input) would otherwise fail
    // or misbehave deep inside the recursion instead of here.
    assert(k >= 0 && cast(size_t) k < len, "partitionK:  k out of range.");
} body {
    // Don't use the float-to-int trick because it's actually slower here
    // because the main part of the algorithm is O(N), not O(N log N).
    return partitionKImpl!compFun(data, k);
}
1146
/* Quickselect kernel behind partitionK (see there for the contract).  Each
 * level does the following:
 *   - parks a pivot (index chosen by medianOf3) at the last index;
 *   - runs a two-pointer partition around it;
 *   - recurses, via partitionK, into the side that contains index k.
 * All arrays in data are permuted in lockstep with data[0]. */
/*private*/ ElementType!(T[0]) partitionKImpl(alias compFun, T...)(T data, ptrdiff_t k) {
    alias binaryFun!(compFun) comp;

    // Move the medianOf3-selected pivot to the end of every array.
    {
        immutable size_t med3 = medianOf3!(comp)(data[0]);
        foreach(array; data) {
            auto temp = array[med3];
            array[med3] = array[$ - 1];
            array[$ - 1] = temp;
        }
    }

    ptrdiff_t lessI = -1, greaterI = data[0].length - 1;
    auto pivot = data[0][$ - 1];
    while(true) {
        // lessI stops at the first element that is not less than the pivot.
        // The pivot itself at $ - 1 acts as a sentinel, so no bounds check.
        while(comp(data[0][++lessI], pivot)) {}
        // greaterI stops at the first element not greater than the pivot,
        // or at index 0.
        while(greaterI > 0 && comp(pivot, data[0][--greaterI])) {}

        if(lessI < greaterI) {
            foreach(array; data) {
                auto temp = array[lessI];
                array[lessI] = array[greaterI];
                array[greaterI] = temp;
            }
        } else break;
    }
    // Put the pivot into its final position, lessI.
    foreach(array; data) {
        auto temp = array[lessI];
        array[lessI] = array[$ - 1];
        array[$ - 1] = temp;
    }

    // Elements strictly between greaterI and lessI compare equivalent to the
    // pivot.  If k falls there, or on the pivot itself, data[0][k] is final.
    if((greaterI < k && lessI >= k) || lessI == k) {
        return data[0][k];
    } else if(lessI < k) {
        // Answer lies right of the pivot:  recurse on the suffix, rebasing k.
        foreach(ti, array; data) {
            data[ti] = array[lessI + 1..$];
        }
        return partitionK!(compFun, T)(data, k - lessI - 1);
    } else {
        // Answer lies left of the pivot block:  recurse on the prefix.
        foreach(ti, array; data) {
            data[ti] = array[0..min(greaterI + 1, lessI)];
        }
        return partitionK!(compFun, T)(data, k);
    }
}
1193
/* Element type of an array type, e.g. ArrayElemType!(int[]) is int.  The
 * specialization pattern `T : T[]` makes the compiler deduce T as the element
 * type of the argument. */
template ArrayElemType(T : T[]) {
    alias T ArrayElemType;
}
1197
unittest {
    enum n = 1000;
    auto data = new uint[n];
    auto sortedRef = new uint[n];
    auto follower = new uint[n];
    foreach(ref x; data)
        x = uniform(0, 1000);

    foreach(trial; 0..1_000) {
        sortedRef[] = data[];
        follower[] = data[];
        uint len = uniform(0, n - 1) + 1;   // len is in [1, n - 1]
        qsort!("a > b")(sortedRef[0..len]);
        int k = uniform(0, len);

        auto kth = partitionK!("a > b")(data[0..len], follower[0..len], k);
        assert(kth == sortedRef[k]);
        foreach(before; data[0..k])
            assert(before >= data[k]);
        foreach(after; data[k + 1..len])
            assert(after <= data[k]);
        // follower started as a copy, so lockstep moves must keep it equal.
        assert(data == follower);
    }
}
1223
/**Given a set of data points entered through the put function, this output range
 * maintains the invariant that the top N according to compFun will be
 * contained in the data structure.  Uses a heap internally, O(log N) insertion
 * time.  Good for finding the largest/smallest N elements of a very large
 * dataset that cannot be sorted quickly in its entirety, and may not even fit
 * in memory.  If fewer than N datapoints have been entered, all are contained
 * in the structure.  With N == 0 (including a default-initialized TopN), put
 * is a no-op and nothing is retained.
 *
 * Examples:
 * ---
 * Random gen;
 * gen.seed(unpredictableSeed);
 * uint[] nums = seq(0U, 100U);
 * auto less = TopN!(uint, "a < b")(10);
 * auto more = TopN!(uint, "a > b")(10);
 * randomShuffle(nums, gen);
 * foreach(n; nums) {
 *     less.put(n);
 *     more.put(n);
 * }
 *  assert(less.getSorted == [0U, 1,2,3,4,5,6,7,8,9]);
 *  assert(more.getSorted == [99U, 98, 97, 96, 95, 94, 93, 92, 91, 90]);
 *  ---
 */
struct TopN(T, alias compFun = "a > b") {
private:
    alias binaryFun!(compFun) comp;
    uint n;        // Capacity:  how many elements are retained.
    uint nAdded;   // Elements stored so far; never exceeds n.

    // Once nAdded == n, nodes is a heap whose root nodes[0] is the retained
    // element that ranks last under compFun, i.e. the next one to evict.
    T[] nodes;
public:
    /** The variable ntop controls how many elements are retained.*/
    this(uint ntop) {
        n = ntop;
        nodes = new T[n];
    }

    /** Insert an element into the topN struct.*/
    void put(T elem) {
        if(n == 0) {
            // Nothing can be retained.  This also avoids indexing nodes[0]
            // below on an empty (or, for TopN.init, null) array.
            return;
        }

        if(nAdded < n) {
            nodes[nAdded] = elem;
            if(nAdded == n - 1) {
                // Buffer just filled up:  heapify it once.
                makeMultiHeap!(comp)(nodes);
            }
            nAdded++;
        } else if(comp(elem, nodes[0])) {
            // elem outranks the weakest retained element:  replace it.
            nodes[0] = elem;
            multiSiftDown!(comp)(nodes, 0, nodes.length);
        }
    }

    /**Get the elements currently in the struct.  Returns a reference to
     * internal state, elements will be in an arbitrary order.  Cheap.*/
    T[] getElements() {
        return nodes[0..min(n, nAdded)];
    }

    /**Returns the retained elements sorted by compFun.  The array returned is
     * a newly allocated copy, so internal state is not modified.  Not cheap.*/
    T[] getSorted() {
        return qsort!(comp)(nodes[0..min(n, nAdded)].dup);
    }
}
1290
unittest {
    alias TopN!(uint, "a < b") TopNLess;
    alias TopN!(uint, "a > b") TopNGreater;
    Random rng;
    rng.seed(unpredictableSeed);

    auto nums = new uint[100];
    foreach(idx, ref x; nums)
        x = cast(uint) idx;

    // Full stream:  exactly the 10 smallest / 10 largest must survive.
    foreach(trial; 0..100) {
        auto smallest = TopNLess(10);
        auto largest = TopNGreater(10);
        randomShuffle(nums, rng);
        foreach(x; nums) {
            smallest.put(x);
            largest.put(x);
        }
        assert(smallest.getSorted == [0U, 1,2,3,4,5,6,7,8,9]);
        assert(largest.getSorted == [99U, 98, 97, 96, 95, 94, 93, 92, 91, 90]);
    }

    // Fewer than N elements entered:  all of them must be retained.
    foreach(trial; 0..100) {
        auto smallest = TopNLess(10);
        auto largest = TopNGreater(10);
        randomShuffle(nums, rng);
        foreach(x; nums[0..5]) {
            smallest.put(x);
            largest.put(x);
        }
        assert(smallest.getSorted == qsort!("a < b")(nums[0..5]));
        assert(largest.getSorted == qsort!("a > b")(nums[0..5]));
    }
}
1323
// TempAlloc leak check:  once every unittest above has run, nothing allocated
// through TempAlloc may still be outstanding.  This must remain the module's
// final unittest.
unittest {
    auto allocState = TempAlloc.getState;
    assert(allocState.used == 0);
    assert(allocState.nblocks < 2);
}
Note: See TracBrowser for help on using the browser.