root/trunk/summary.d

Revision 293, 33.6 kB (checked in by dsimcha, 1 year ago)

Better loop unrolling: Speed up ILP optimized functions ~10%.

Line 
1 /**Summary statistics such as mean, median, sum, variance, skewness, kurtosis.
2  * Except for median and median absolute deviation, which cannot be calculated
3  * online, all summary statistics have both an input range interface and an
4  * output range interface.
5  *
6  * Notes: The put method on the structs defined in this module returns this by
7  *        ref.  The use case for returning this is to enable these structs
8  *        to be used with std.algorithm.reduce.  The rationale for returning
9  *        by ref is that the return value usually won't be used, and the
10  *        overhead of returning a large struct by value should be avoided.
11  *
12  * Bugs:  This whole module assumes that input will be doubles or types implicitly
13  *        convertible to double.  No allowances are made for user-defined numeric
14  *        types such as BigInts.  This is necessary for simplicity.  However,
15  *        if you have a function that converts your data to doubles, most of
16  *        these functions work with any input range, so you can simply map
17  *        this function onto your range.
18  *
19  * Author:  David Simcha
20  */
21 /*
22  * Copyright (C) 2008-2010 David Simcha
23  *
24  * License:
25  * Boost Software License - Version 1.0 - August 17th, 2003
26  *
27  * Permission is hereby granted, free of charge, to any person or organization
28  * obtaining a copy of the software and accompanying documentation covered by
29  * this license (the "Software") to use, reproduce, display, distribute,
30  * execute, and transmit the Software, and to prepare derivative works of the
31  * Software, and to permit third-parties to whom the Software is furnished to
32  * do so, all subject to the following:
33  *
34  * The copyright notices in the Software and this entire statement, including
35  * the above license grant, this restriction and the following disclaimer,
36  * must be included in all copies of the Software, in whole or in part, and
37  * all derivative works of the Software, unless such copies or derivative
38  * works are solely in the form of machine-executable object code generated by
39  * a source language processor.
40  *
41  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
42  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
43  * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
44  * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
45  * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
46  * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
47  * DEALINGS IN THE SOFTWARE.
48  */
49
50
51 module dstats.summary;
52
53 import std.algorithm, std.functional, std.conv, std.range, std.array,
54     std.traits, std.math;
55
56 import dstats.sort, dstats.base, dstats.alloc;
57
// Test-only declarations, compiled only when unittests are enabled.
version(unittest) {
    // dstats.random presumably supplies randArray/rNorm used by the Monte
    // Carlo unittests in this module -- TODO confirm.
    import std.stdio, dstats.random;

    // Empty entry point for unittest builds.  Presumably lets this module be
    // compiled and run on its own so that its unittest blocks execute.
    void main() {
    }
}
64
/**Computes the median of an input range in O(N) average time.
 *
 * For an even number of elements, the result is the mean of the two middle
 * elements.  This convenience function is designed specifically for numeric
 * types, where averaging the two middle elements is the desired behavior.  A
 * more general selection algorithm, which works with any type with a total
 * ordering and can select any position in the ordering, can be found at
 * dstats.sort.quickSelect() and dstats.sort.partitionK().
 *
 * Allocates memory; does not reorder the input data.*/
double median(T)(T data)
if(doubleInput!(T)) {
    // Partition a private copy so the caller's data keeps its order.  The copy
    // comes from tempdup() and is released with TempAlloc.free on scope exit.
    auto scratch = tempdup(data);
    scope(exit) TempAlloc.free;

    immutable double result = medianPartition(scratch);
    return result;
}
82
/**Median finding as in median(), but partitions the input data in place such
 * that elements less than the median have smaller indices than the median,
 * and elements larger than the median have larger indices.  Useful both for
 * the partitioning itself and to avoid memory allocations.  Requires a random
 * access range with length and swappable elements.
 *
 * Returns:  The median as a double, or NaN if data is empty.  For an even
 * number of elements, the mean of the two middle elements is returned.*/
double medianPartition(T)(T data)
if(isRandomAccessRange!(T) &&
   is(ElementType!(T) : double) &&
   hasSwappableElements!(T) &&
   dstats.base.hasLength!(T))
{
    immutable len = data.length;
    if(len == 0) {
        return double.nan;
    } else if(len == 1) {
        return data[0];
    } else if(len & 1) {  // Odd length:  a single middle element.
        return cast(double) partitionK(data, len / 2);
    }

    // Even length.  After partitioning around the lower middle element, the
    // upper middle element is the smallest element with a larger index.
    immutable half = len / 2;
    auto lower = partitionK(data, half - 1);

    // Seed the minimum search with an actual element rather than
    // ElementType!(T).max:  for floating point types .max is the largest
    // *finite* value, which would be wrong if the upper half held infinity.
    // Indexing (not slicing) avoids requiring slicing support.
    Unqual!(ElementType!(T)) upper = data[half];
    foreach(i; half + 1..len) {
        if(data[i] < upper) {
            upper = data[i];
        }
    }
    return lower * 0.5 + upper * 0.5;
}
118
unittest {
    // Reference implementation:  fully sort, then take the middle element(s).
    // qsort reorders foo (here a slice of test) in place.
    float brainDeadMedian(float[] foo) {
        qsort(foo);
        if(foo.length & 1)
            return foo[$ / 2];
        return (foo[$ / 2] + foo[$ / 2 - 1]) / 2;
    }

    // Randomized comparison of median() against the reference on 1000
    // random sub-slices of a 1000-element array of uniform floats.
    float[] test = new float[1000];
    size_t upperBound, lowerBound;
    foreach(testNum; 0..1000) {
        foreach(ref e; test) {
            e = uniform(0f, 1000f);
        }
        // Draw two distinct bounds, then order them so the slice is non-empty.
        do {
            upperBound = uniform(0u, test.length);
            lowerBound = uniform(0u, test.length);
        } while(lowerBound == upperBound);
        if(lowerBound > upperBound) {
            swap(lowerBound, upperBound);
        }
        auto quickRes = median(test[lowerBound..upperBound]);
        auto accurateRes = brainDeadMedian(test[lowerBound..upperBound]);

        // Off by some tiny fraction in the even N case.  Presumably because
        // brainDeadMedian averages in float precision while median() works
        // in double; either way the difference is negligible.
        assert(approxEqual(quickRes, accurateRes));
    }

    // Make sure everything works with lowest common denominator range type:
    // an input range with only front/popFront/empty (no length, no random
    // access), yielding num, num + 1, ..., upTo - 1.
    static struct Count {
        uint num;
        uint upTo;
        @property size_t front() {
            return num;
        }
        void popFront() {
            num++;
        }
        @property bool empty() {
            return num >= upTo;
        }
    }

    // 0 .. 99 has median (49 + 50) / 2.
    Count a;
    a.upTo = 100;
    assert(approxEqual(median(a), 49.5));
}
167
/**Plain old data holder struct for the results of medianAbsDev():  the median
 * and the median absolute deviation.  Alias this'd to the median absolute
 * deviation member, so an instance implicitly converts to that value.
 */
struct MedianAbsDev {
    /// The median of the data.
    double median;

    /// The median of the absolute deviations from median.
    double medianAbsDev;

    // Implicit conversion yields the median absolute deviation.
    alias medianAbsDev this;
}
177
/**Calculates the median absolute deviation of a dataset.  This is the median
 * of all absolute differences from the median of the dataset.
 *
 * The input is traversed only once (it is copied to temporary storage), so
 * single-pass input ranges are supported.  The input data is not reordered.
 *
 * Returns:  A MedianAbsDev struct that contains the median (since it is
 * computed anyhow) and the median absolute deviation.  Both members are NaN
 * for empty input.
 *
 * Notes:  No bias correction is used in this implementation, since using
 * one would require assumptions about the underlying distribution of the data.
 */
MedianAbsDev medianAbsDev(T)(T data)
if(doubleInput!(T)) {
    // One TempAlloc frame owns both scratch buffers and is released on every
    // exit path.
    mixin(newFrame);

    auto dataDup = tempdup(data);
    immutable med = medianPartition(dataDup);

    // Compute the deviations from the copy rather than from data:  iterating
    // data a second time would see an already-consumed range when data is a
    // single-pass input range.  The copy's (partitioned) order is irrelevant
    // because only the multiset of deviations matters.
    double[] devs = newStack!double(dataDup.length);
    foreach(i; 0..dataDup.length) {
        devs[i] = abs(med - dataDup[i]);
    }

    return MedianAbsDev(med, medianPartition(devs));
}
204
unittest {
    // Even-length input.
    auto evenCase = medianAbsDev([7,1,8,2,8,1,9,2,8,4,5,9].dup);
    assert(approxEqual(evenCase.medianAbsDev, 2.5L));

    // Odd-length input containing a large outlier.
    auto oddCase = medianAbsDev([8,6,7,5,3,0,999].dup);
    assert(approxEqual(oddCase.medianAbsDev, 2.0L));
}
209
/**Computes the interquantile range of data at the given quantile value in O(N)
 * time complexity.  For example, using a quantile value of either 0.25 or 0.75
 * gives the interquartile range.  (This is the default since it is the most
 * commonly used interquantile range.)  Using a quantile value of 0.2 or 0.8
 * gives the interquintile range.
 *
 * If the quantile point falls between two indices, linear interpolation is
 * used.
 *
 * This function is somewhat more efficient than simply finding the upper and
 * lower quantile and subtracting them.
 *
 * Returns:  The interquantile range.  NaN if data has fewer than two elements,
 * or, for a quantile of 0 or 1, if data is empty.
 *
 * Quantile values outside [0, 1] (or NaN) fail the dstatsEnforce check.
 *
 * Tip:  A quantile of 0 or 1 is handled as a special case and will compute the
 *       plain old range of the data in a single pass.
 */
double interquantileRange(R)(R data, double quantile = 0.25)
if(doubleInput!R) {
    alias quantile q;  // Save typing.
    dstatsEnforce(q >= 0 && q <= 1,
        "Quantile must be between 0, 1 for interquantileRange.");

    mixin(newFrame);
    if(q > 0.5) {
        q = 1.0 - q;
    }

    if(q == 0) {  // Special case:  Compute the plain old range.
        double minElem = double.infinity;
        double maxElem = -double.infinity;
        bool sawElement = false;

        foreach(elem; data) {
            minElem = min(minElem, elem);
            maxElem = max(maxElem, elem);
            sawElement = true;
        }

        // The range of an empty dataset is undefined (was -infinity).
        return sawElement ? maxElem - minElem : double.nan;
    }

    // Linear interpolation between two adjacent order statistics.  When fract
    // is exactly 0 the upper neighbor is not consulted, so an infinite
    // neighbor cannot turn 0 * inf into NaN.
    static double interpolate(double lo, double hi, double fract) pure nothrow {
        return (fract == 0) ? lo : fract * hi + (1 - fract) * lo;
    }

    // Common case.
    auto duped = tempdup(data);
    immutable double N = duped.length;
    if(duped.length < 2) {
        return double.nan;  // Can't do it.
    }

    // Lower quantile.  q <= 0.5 and N >= 2 guarantee lowEnd + 1 < N.
    immutable lowPos = (N - 1) * q;
    immutable lowEnd = to!size_t(lowPos);
    immutable lowFract = lowPos - lowEnd;

    partitionK(duped, lowEnd);
    immutable lowQuantile1 = duped[lowEnd];
    double minAbove = double.infinity;

    foreach(elem; duped[lowEnd + 1..$]) {
        minAbove = min(minAbove, elem);
    }

    immutable lowerQuantile = interpolate(lowQuantile1, minAbove, lowFract);

    // Upper quantile, searched only within the tail starting at lowEnd.
    duped = duped[lowEnd..$];
    immutable highPos = (N - 1) * (1.0 - q) - lowEnd;
    size_t highEnd = to!size_t(highPos);
    double highFract = highPos - highEnd;
    if(highEnd >= duped.length - 1) {
        // Rounding (e.g. a tiny q) can put the upper quantile exactly on the
        // last element, leaving nothing above it to interpolate with.
        highEnd = duped.length - 1;
        highFract = 0;
    }

    partitionK(duped, highEnd);
    immutable minAbove2 = reduce!min(double.infinity, duped[highEnd + 1..$]);
    immutable upperQuantile = interpolate(duped[highEnd], minAbove2, highFract);

    return upperQuantile - lowerQuantile;
}
281
unittest {
    // Cases checked with approxEqual:  (data, quantile, expected range).
    static struct Case {
        int[] data;
        double quantile;
        double expected;
    }

    Case[] approxCases = [
        Case([1,2,3,4,5,6,7,8], 0.25, 3.5),    // Interquartile, even N.
        Case([1,2,3,4,5,6,7,8,9], 0.25, 4),    // Interquartile, odd N.
        Case([8,6,7,5,3,0,9], 0.2, 4.4)        // Sorted:  0 3 5 6 7 8 9.
    ];

    foreach(c; approxCases) {
        assert(approxEqual(interquantileRange(c.data, c.quantile), c.expected));
    }

    // A quantile of 0 degenerates to the plain range, which is exact.
    assert(interquantileRange([1,9,2,4,3,6,8], 0) == 8);
}
289
/**Output range to calculate the mean online.  Getter for mean costs a branch to
 * check for N == 0.  This struct uses O(1) space and does *NOT* store the
 * individual elements.
 *
 * Note:  This struct can implicitly convert to the value of the mean.
 *
 * Examples:
 * ---
 * Mean summ;
 * summ.put(1);
 * summ.put(2);
 * summ.put(3);
 * summ.put(4);
 * summ.put(5);
 * assert(summ.mean == 3);
 * ---*/
struct Mean {
private:
    double result = 0;  // Running mean of the elements put so far.
    double k = 0;       // Number of elements put so far.

public:
    /// Allow implicit casting to double, by returning the current mean.
    alias mean this;

    /// Adds element to the running mean.
    void put(double element) pure nothrow @safe {
        result += (element - result) / ++k;
    }

    /**Adds the contents of rhs to this instance.  Combining with an empty
     * instance is a no-op.
     *
     * Examples:
     * ---
     * Mean mean1, mean2, combined;
     * foreach(i; 0..5) {
     *     mean1.put(i);
     * }
     *
     * foreach(i; 5..10) {
     *     mean2.put(i);
     * }
     *
     * mean1.put(mean2);
     *
     * foreach(i; 0..10) {
     *     combined.put(i);
     * }
     *
     * assert(approxEqual(combined.mean, mean1.mean));
     * ---
     */
    void put(typeof(this) rhs) pure nothrow @safe {
        // If both instances were empty, the weights below would be 0 / 0,
        // making result NaN and poisoning every later put().
        if(rhs.k == 0) {
            return;
        }

        immutable totalN = k + rhs.k;
        result = result * (k / totalN) + rhs.result * (rhs.k / totalN);
        k = totalN;
    }

    const pure nothrow @property @safe {

        /// Sum of all elements put so far.
        double sum() {
            return result * k;
        }

        /// The mean, or NaN if no elements have been put.
        double mean() {
            return (k == 0) ? double.nan : result;
        }

        /// Number of elements put so far.
        double N() {
            return k;
        }

        /**Simply returns this.  Useful in generic programming contexts.*/
        Mean toMean() {
            return this;
        }
    }

    ///
    string toString() const {
        return to!(string)(mean);
    }
}
376
/**Finds the arithmetic mean of any input range whose elements are implicitly
 * convertible to double.  Returns a Mean struct (which implicitly converts to
 * the mean value via alias this) rather than a bare double.*/
Mean mean(T)(T data)
if(doubleIterable!(T)) {

    static if(isRandomAccessRange!T && dstats.base.hasLength!T) {
        // This is optimized for maximum instruction level parallelism:
        // The loop is unrolled such that there are 1 / (nILP)th the data
        // dependencies of the naive algorithm.
        enum nILP = 8;

        Mean ret;
        size_t i = 0;
        // The unrolled path is only used for more than 2 * nILP elements.
        if(data.length > 2 * nILP) {
            // Lane j keeps a running mean of element i + j of every chunk of
            // nILP elements.  All lanes have seen the same count, so one k
            // (and one reciprocal per chunk) serves every lane.
            double k = 0;
            double[nILP] means = 0;
            // Strict <:  a final chunk ending exactly at data.length is left
            // to the remainder loop below.
            for(; i + nILP < data.length; i += nILP) {
                immutable kNeg1 = 1 / ++k;

                // StaticIota presumably yields 0 .. nILP - 1 at compile time,
                // so this inner loop is unrolled -- TODO confirm.
                foreach(j; StaticIota!nILP) {
                    means[j] += (data[i + j] - means[j]) * kNeg1;
                }
            }

            // Seed ret with lane 0, then merge the other lanes (each holding
            // k elements) via Mean.put(Mean).
            ret.k = k;
            ret.result = means[0];
            foreach(m; means[1..$]) {
                ret.put( Mean(m, k));
            }
        }

        // Handle the remainder.
        for(; i < data.length; i++) {
            ret.put(data[i]);
        }
        return ret;

    } else {
        // Just submit everything to a single Mean struct and return it.
        Mean meanCalc;

        foreach(element; data) {
            meanCalc.put(element);
        }
        return meanCalc;
    }
}
424
/**Output range that computes the geometric mean online, by keeping a running
 * arithmetic mean of the base-2 logarithms of the elements.  Uses O(1) space
 * and does not store the individual elements.
 *
 * Note:  This struct can implicitly convert to the value of the geometric
 * mean.
 */
struct GeometricMean {
    // Running arithmetic mean of log2 of every element put so far.
    private Mean m;

    /// Allow implicit casting to double, by returning current geometric mean.
    alias geoMean this;

    ///
    void put(double element) pure nothrow @safe {
        immutable double logValue = log2(element);
        m.put(logValue);
    }

    /// Combine two GeometricMean's.
    void put(typeof(this) rhs) pure nothrow @safe {
        m.put(rhs.m);
    }

    /// The geometric mean:  2 raised to the mean of the log2 values.
    @property double geoMean() const pure nothrow {
        immutable double logMean = m.mean;
        return exp2(logMean);
    }

    /// Number of elements put so far.
    @property double N() const pure nothrow {
        return m.k;
    }

    ///
    string toString() const {
        return to!(string)(geoMean);
    }
}
460
/**Computes the geometric mean of an iterable whose elements implicitly convert
 * to double, by feeding every element to a GeometricMean.*/
double geometricMean(T)(T data)
if(doubleIterable!(T)) {
    // Seldom used, and log2 dominates the cost anyway, so there is no
    // ILP-unrolled path here.
    GeometricMean accumulator;
    foreach(value; data) {
        accumulator.put(value);
    }
    immutable double result = accumulator.geoMean;
    return result;
}
472
unittest {
    // Geometric mean of 1..5, fed through a lazy map from string to uint.
    // (The unused local `foo` that duplicated this map was removed.)
    string[] data = ["1", "2", "3", "4", "5"];

    auto result = geometricMean(map!(to!(uint))(data));
    assert(approxEqual(result, 2.60517));

    // Combining two Means must match a single Mean fed all elements.
    Mean mean1, mean2, combined;
    foreach(i; 0..5) {
      mean1.put(i);
    }

    foreach(i; 5..10) {
      mean2.put(i);
    }

    mean1.put(mean2);

    foreach(i; 0..10) {
      combined.put(i);
    }

    assert(approxEqual(combined.mean, mean1.mean),
        text(combined.mean, "  ", mean1.mean));
    assert(combined.N == mean1.N);
}
499
500
/**Finds the sum of an input range whose elements implicitly convert to double.
 * The accumulator and return type U defaults to Unqual!(IterType!(T)),
 * presumably the unqualified element type of the input, so integer sums can
 * overflow.  The user may make U a wider type (e.g. long) to prevent overflows
 * on large array summing operations.*/
U sum(T, U = Unqual!(IterType!(T)))(T data)
if(doubleIterable!(T)) {

    static if(isRandomAccessRange!T && dstats.base.hasLength!T) {
        // Unrolled for instruction level parallelism:  nILP independent
        // partial sums instead of one serial dependency chain.
        enum nILP = 8;
        // Note:  this local array shadows the function name inside the body.
        U[nILP] sum = 0;

        size_t i = 0;
        // The unrolled path is only used for more than 2 * nILP elements.
        if(data.length > 2 * nILP) {

            // Strict <:  a final chunk ending exactly at data.length is left
            // to the remainder loop below.
            for(; i + nILP < data.length; i += nILP) {
                foreach(j; StaticIota!nILP) {
                    sum[j] += data[i + j];
                }
            }

            // Fold the partial sums into sum[0], in lane order.
            foreach(j; 1..nILP) {
                sum[0] += sum[j];
            }
        }

        // Remainder (or everything, for short inputs).
        for(; i < data.length; i++) {
            sum[0] += data[i];
        }

        return sum[0];
    } else {
        // Plain serial accumulation for other iterables.
        U sum = 0;
        foreach(elem; data) {
            sum += elem;
        }

        return sum;
    }
}
540
unittest {
    // sum() over a random-access array and over filter!, which is not random
    // access and so takes the plain foreach path.
    assert(sum([1,2,3,4,5,6,7,8,9,10][]) == 55);
    assert(sum(filter!"true"([1,2,3,4,5,6,7,8,9,10][])) == 55);
    assert(sum(cast(int[]) [1,2,3,4,5])==15);
    // NOTE(review): casting this double literal to int[] converts the
    // elements to 40, 40, 5 (sum 85); the check against 85.3 presumably
    // passes only because of approxEqual's default tolerance -- verify.
    assert(approxEqual( sum(cast(int[]) [40.0, 40.1, 5.2]), 85.3));
    assert(mean(cast(int[]) [1,2,3]).mean == 2);
    assert(mean(cast(int[]) [1.0, 2.0, 3.0]).mean == 2.0);
    assert(mean([1, 2, 5, 10, 17][]).mean == 7);
    assert(mean([1, 2, 5, 10, 17][]).sum == 35);
    assert(approxEqual(mean([8,6,7,5,3,0,9,3,6,2,4,3,6][]).mean, 4.769231));

    // Test the OO struct a little, since we're using the new ILP algorithm.
    Mean m;
    m.put(1);
    m.put(2);
    m.put(5);
    m.put(10);
    m.put(17);
    assert(m.mean == 7);

    foreach(i; 0..100) {
        // Monte carlo test the unrolled version:  mean() on random normal
        // arrays of random length must match element-by-element Mean.put,
        // field by field.
        auto foo = randArray!rNorm(uniform(5, 100), 0, 1);
        auto res1 = mean(foo);
        Mean res2;
        foreach(elem; foo) {
            res2.put(elem);
        }

        foreach(ti, elem; res1.tupleof) {
            assert(approxEqual(elem, res2.tupleof[ti]));
        }
    }
}
575
576
/**Output range to compute mean, stdev, variance online.  Getter methods
 * for stdev and var cost a few floating point ops.  Getter for mean costs
 * a single branch to check for N == 0.  Updates use relatively expensive
 * floating point ops; if you only need the mean, try Mean.  This struct uses
 * O(1) space and does *NOT* store the individual elements.
 *
 * Note:  This struct can implicitly convert to a Mean struct.
 *
 * References: Computing Higher-Order Moments Online.
 * http://people.xiph.org/~tterribe/notes/homs.html
 *
 * Examples:
 * ---
 * MeanSD summ;
 * summ.put(1);
 * summ.put(2);
 * summ.put(3);
 * summ.put(4);
 * summ.put(5);
 * assert(summ.mean == 3);
 * assert(summ.stdev == sqrt(2.5));
 * assert(summ.var == 2.5);
 * ---*/
struct MeanSD {
private:
    double _mean = 0;  // Running mean.
    double _var = 0;   // Running sum of squared deviations from the mean.
    double _k = 0;     // Number of elements.
public:
    /// Adds element (Welford-style update of mean and squared deviations).
    void put(double element) pure nothrow @safe {
        immutable kMinus1 = _k;
        immutable delta = element - _mean;
        immutable deltaN = delta / ++_k;

        _mean += deltaN;
        _var += kMinus1 * deltaN * delta;
    }

    /// Combine two MeanSD's.
    void put(typeof(this) rhs) pure nothrow @safe {
        // An empty side contributes nothing; copying avoids 0 / 0 weights.
        if(_k == 0) {
            foreach(ti, elem; rhs.tupleof) {
                this.tupleof[ti] = elem;
            }

            return;
        } else if(rhs._k == 0) {
            return;
        }

        immutable totalN = _k + rhs._k;
        immutable delta = rhs._mean - _mean;
        _mean = _mean * (_k / totalN) + rhs._mean * (rhs._k / totalN);

        _var = _var + rhs._var + (_k / totalN * rhs._k * delta * delta);
        _k = totalN;
    }

    const pure nothrow @property @safe {

        ///
        double sum() {
            return _k * _mean;
        }

        ///
        double mean() {
            return (_k == 0) ? double.nan : _mean;
        }

        ///
        double stdev() {
            return sqrt(var);
        }

        /// Sample variance (denominator N - 1); NaN for fewer than 2 elements.
        double var() {
            return (_k < 2) ? double.nan : _var / (_k - 1);
        }

        /**
        Mean squared error.  In other words, a biased estimate of variance.
        */
        double mse() {
            return (_k < 1) ? double.nan : _var / _k;
        }

        ///
        double N() {
            return _k;
        }

        /**Converts this struct to a Mean struct.  Also called when an
         * implicit conversion via alias this takes place.
         */
        Mean toMean() {
            return Mean(_mean, _k);
        }

        /**Simply returns this.  Useful in generic programming contexts.*/
        MeanSD toMeanSD() {
            return this;
        }
    }

    // alias this must name a member; the documented implicit conversion to
    // Mean goes through toMean (aliasing the type Mean names no member).
    alias toMean this;

    ///
    string toString() const {
        return text("N = ", cast(ulong) _k, "\nMean = ", mean, "\nVariance = ",
               var, "\nStdev = ", stdev);
    }
}
692
/**Puts all elements of data into a MeanSD struct,
 * then returns this struct.  This can be faster than doing this manually
 * due to ILP optimizations.*/
MeanSD meanStdev(T)(T data)
if(doubleIterable!(T)) {

    MeanSD ret;

    static if(isRandomAccessRange!T && dstats.base.hasLength!T) {
        // Optimize for instruction level parallelism.
        enum nILP = 6;
        // Lane j keeps its own running mean and sum of squared deviations for
        // element i + j of every chunk; all lanes share one count k.
        double k = 0;
        double[nILP] means = 0;
        double[nILP] variances = 0;
        size_t i = 0;

        // The unrolled path is only used for more than 2 * nILP elements.
        if(data.length > 2 * nILP) {
            // Strict <:  a final chunk ending exactly at data.length is left
            // to the remainder loop below.
            for(; i + nILP < data.length; i += nILP) {
                immutable kMinus1 = k;
                immutable kNeg1 = 1 / ++k;

                // Same update as MeanSD.put(double), with the division by k
                // replaced by one shared reciprocal per chunk.
                foreach(j; StaticIota!nILP) {
                    immutable double delta = data[i + j] - means[j];
                    immutable deltaN = delta * kNeg1;

                    means[j] += deltaN;
                    variances[j] += kMinus1 * deltaN * delta;
                }
            }

            // Seed ret with lane 0, then merge the other lanes (each holding
            // k elements) via MeanSD.put(MeanSD).
            ret._mean = means[0];
            ret._var = variances[0];
            ret._k = k;

            foreach(j; 1..nILP) {
                ret.put( MeanSD(means[j], variances[j], k));
            }
        }

        // Handle remainder.
        for(; i < data.length; i++) {
            ret.put(data[i]);
        }
    } else {
        foreach(elem; data) {
            ret.put(elem);
        }
    }
    return ret;
}
743
/**Computes the sample variance (denominator N - 1) of an input range whose
 * members implicitly convert to double.  NaN for fewer than two elements.*/
double variance(T)(T data)
if(doubleIterable!(T)) {
    auto stats = meanStdev(data);
    return stats.var;
}
750
/**Computes the sample standard deviation (square root of the N - 1
 * variance) of an input range whose members implicitly convert to double.*/
double stdev(T)(T data)
if(doubleIterable!(T)) {
    auto stats = meanStdev(data);
    return stats.stdev;
}
757
unittest {
    // Small fixed inputs (below the unrolling threshold).
    auto res = meanStdev(cast(int[]) [3, 1, 4, 5]);
    assert(approxEqual(res.stdev, 1.7078));
    assert(approxEqual(res.mean, 3.25));
    res = meanStdev(cast(double[]) [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
    assert(approxEqual(res.stdev, 2.160247));
    assert(approxEqual(res.mean, 4));
    assert(approxEqual(res.sum, 28));

    // Combining 0..4 with 5..9 must match a single MeanSD fed 0..9.
    MeanSD mean1, mean2, combined;
    foreach(i; 0..5) {
      mean1.put(i);
    }

    foreach(i; 5..10) {
      mean2.put(i);
    }

    mean1.put(mean2);

    foreach(i; 0..10) {
      combined.put(i);
    }

    assert(approxEqual(combined.mean, mean1.mean));
    assert(approxEqual(combined.stdev, mean1.stdev));
    assert(combined.N == mean1.N);
    assert(approxEqual(combined.mean, 4.5));
    assert(approxEqual(combined.stdev, 3.027650));

    foreach(i; 0..100) {
        // Monte carlo test the unrolled version:  meanStdev() on random
        // normal arrays of random length must match element-by-element
        // MeanSD.put, field by field.
        auto foo = randArray!rNorm(uniform(5, 100), 0, 1);
        auto res1 = meanStdev(foo);
        MeanSD res2;
        foreach(elem; foo) {
            res2.put(elem);
        }

        foreach(ti, elem; res1.tupleof) {
            assert(approxEqual(elem, res2.tupleof[ti]));
        }

        // Putting into an empty MeanSD copies rhs exactly, and putting an
        // empty MeanSD is a no-op, so exact equality is expected here.
        MeanSD resCornerCase;  // Test corner cases where one of the N's is 0.
        resCornerCase.put(res1);
        MeanSD dummy;
        resCornerCase.put(dummy);
        foreach(ti, elem; res1.tupleof) {
            assert(elem == resCornerCase.tupleof[ti]);
        }
    }
}
810
/**Output range to compute mean, stdev, variance, skewness, kurtosis, min, and
 * max online.  Using this struct is relatively expensive, so if you just need
 * mean and/or stdev, try MeanSD or Mean.  Getter methods for stdev and
 * var cost a few floating point ops.  Getter for mean costs a single branch to
 * check for N == 0.  Getters for skewness and kurtosis cost a whole bunch of
 * floating point ops.  This struct uses O(1) space and does *NOT* store the
 * individual elements.
 *
 * Note:  This struct can implicitly convert to a MeanSD.
 *
 * References: Computing Higher-Order Moments Online.
 * http://people.xiph.org/~tterribe/notes/homs.html
 *
 * Examples:
 * ---
 * Summary summ;
 * summ.put(1);
 * summ.put(2);
 * summ.put(3);
 * summ.put(4);
 * summ.put(5);
 * assert(summ.N == 5);
 * assert(summ.mean == 3);
 * assert(summ.stdev == sqrt(2.5));
 * assert(summ.var == 2.5);
 * assert(approxEqual(summ.kurtosis, -1.3));
 * assert(summ.min == 1);
 * assert(summ.max == 5);
 * assert(summ.sum == 15);
 * ---*/
struct Summary {
private:
    double _mean = 0;  // Running mean.
    double _m2 = 0;    // Sum of squared deviations from the mean.
    double _m3 = 0;    // Sum of cubed deviations from the mean.
    double _m4 = 0;    // Sum of fourth-power deviations from the mean.
    double _k = 0;     // Number of elements.
    double _min = double.infinity;
    double _max = -double.infinity;
public:
    /// Adds element, updating all moment sums, min and max.
    void put(double element) pure nothrow @safe {
        immutable kMinus1 = _k;
        immutable kNeg1 = 1.0 / ++_k;
        _min = (element < _min) ? element : _min;
        _max = (element > _max) ? element : _max;

        immutable delta = element - _mean;
        immutable deltaN = delta * kNeg1;
        _mean += deltaN;

        // Highest moment first:  each update must read the *old* lower
        // moment sums.
        _m4 += kMinus1 * deltaN * (_k * _k - 3 * _k + 3) * deltaN * deltaN * delta +
            6 * _m2 * deltaN * deltaN - 4 * deltaN * _m3;
        _m3 += kMinus1 * deltaN * (_k - 2) * deltaN * delta - 3 * delta * _m2 * kNeg1;
        _m2 += kMinus1 * deltaN * delta;
    }

    /// Combine two Summary's.
    void put(typeof(this) rhs) pure nothrow @safe {
        // An empty side contributes nothing; copying avoids 0 / 0 weights.
        if(_k == 0) {
            foreach(ti, elem; rhs.tupleof) {
                this.tupleof[ti] = elem;
            }

            return;
        } else if(rhs._k == 0) {
            return;
        }

        // Pairwise combination of central moment sums (Pebay 2008), with
        // A = this, B = rhs, delta = mean_B - mean_A, n = nA + nB:
        //   M2 = M2a + M2b + delta^2 nA nB / n
        //   M3 = M3a + M3b + delta^3 nA nB (nA - nB) / n^2
        //        + 3 delta (nA M2b - nB M2a) / n
        //   M4 = M4a + M4b + delta^4 nA nB (nA^2 - nA nB + nB^2) / n^3
        //        + 6 delta^2 (nA^2 M2b + nB^2 M2a) / n^2
        //        + 4 delta (nA M3b - nB M3a) / n
        // M4 is updated before M3, and M3 before M2, so each reads old values.
        immutable totalN = _k + rhs._k;
        immutable delta = rhs._mean - _mean;
        immutable deltaN = delta / totalN;
        _mean = _mean * (_k / totalN) + rhs._mean * (rhs._k / totalN);

        _m4 = _m4 + rhs._m4 +
            deltaN * _k * deltaN * rhs._k * deltaN * delta *
            (_k * _k - _k * rhs._k + rhs._k * rhs._k) +
            6 * deltaN * _k * deltaN * _k * rhs._m2 +
            6 * deltaN * rhs._k * deltaN * rhs._k * _m2 +
            4 * deltaN * _k * rhs._m3 -
            4 * deltaN * rhs._k * _m3;

        // The (nA - nB) term needs delta^3 / n^2 == deltaN^2 * delta; it was
        // previously missing the trailing factor of delta.
        _m3 = _m3 + rhs._m3 +
            deltaN * _k * deltaN * rhs._k * delta * (_k - rhs._k) +
            3 * deltaN * _k * rhs._m2 -
            3 * deltaN * rhs._k * _m2;

        _m2 = _m2 + rhs._m2 + (_k / totalN * rhs._k * delta * delta);

        _k = totalN;
        _max = (_max > rhs._max) ? _max : rhs._max;
        _min = (_min < rhs._min) ? _min : rhs._min;
    }

    const pure nothrow @property @safe {

        ///
        double sum() {
            return _mean * _k;
        }

        ///
        double mean() {
            return (_k == 0) ? double.nan : _mean;
        }

        ///
        double stdev() {
            return sqrt(var);
        }

        /// Sample variance (denominator N - 1); NaN for fewer than 2 elements.
        double var() {
            return (_k < 2) ? double.nan : _m2 / (_k - 1);
        }

        /**
        Mean squared error.  In other words, a biased estimate of variance.
        */
        double mse() {
            return (_k < 1) ? double.nan : _m2 / _k;
        }

        /// Skewness:  sqrt(N) * M3 / M2^1.5.
        double skewness() {
            immutable sqM2 = sqrt(_m2);
            return _m3 / (sqM2 * sqM2 * sqM2) * sqrt(_k);
        }

        /// Excess kurtosis:  N * M4 / M2^2 - 3.
        double kurtosis() {
            return _m4 / _m2 * _k  / _m2 - 3;
        }

        ///
        double N() {
            return _k;
        }

        ///
        double min() {
            return _min;
        }

        ///
        double max() {
            return _max;
        }

        /**Converts this struct to a MeanSD.  Called via alias this when an
         * implicit conversion is attempted.
         */
        MeanSD toMeanSD() {
            return MeanSD(_mean, _m2, _k);
        }
    }

    alias toMeanSD this;

    ///
    string toString() const {
        return text("N = ", roundTo!long(_k),
                  "\nMean = ", mean,
                  "\nVariance = ", var,
                  "\nStdev = ", stdev,
                  "\nSkewness = ", skewness,
                  "\nKurtosis = ", kurtosis,
                  "\nMin = ", _min,
                  "\nMax = ", _max);
    }
}
981
unittest {
    // Everything else is tested indirectly through kurtosis, skewness.  Test
    // put(typeof(this)).
    // NOTE(review): both halves below hold 5 elements, so the terms of
    // Summary.put(Summary) proportional to (_k - rhs._k) are zero here and
    // are not exercised; an unequal split would cover them.

    Summary mean1, mean2, combined;
    foreach(i; 0..5) {
      mean1.put(i);
    }

    foreach(i; 5..10) {
      mean2.put(i);
    }

    // NOTE(review): m1_2 is combined but never checked below.
    auto m1_2 = mean1;
    auto m2_2 = mean2;
    m1_2.put(m2_2);

    mean1.put(mean2);

    foreach(i; 0..10) {
      combined.put(i);
    }

    // Combined halves must match a single Summary over 0..9, field by field.
    foreach(ti, elem; mean1.tupleof) {
        assert(approxEqual(elem, combined.tupleof[ti]));
    }

    // Putting into an empty Summary copies rhs exactly, and putting an empty
    // Summary is a no-op, so exact equality is expected.
    Summary summCornerCase;  // Case where one N is zero.
    summCornerCase.put(mean1);
    Summary dummy;
    summCornerCase.put(dummy);
    foreach(ti, elem; summCornerCase.tupleof) {
        assert(elem == mean1.tupleof[ti]);
    }
}
1017
/**Excess kurtosis (kurtosis minus 3) of data, measured relative to the
 * normal distribution, which by this definition has kurtosis 0.  A high
 * value means the variance comes mostly from rare, large deviations from the
 * mean; a low value means it comes from frequent, small ones.
 * data must be an input range whose elements implicitly convert to double.*/
double kurtosis(T)(T data)
if(doubleIterable!(T)) {
    // Left un-ILP-optimized on purpose: it is rarely called, and a single
    // accumulation step already has plenty of instruction-level parallelism.
    // summary() feeds every element into a Summary, exactly as a manual loop
    // would.
    return summary(data).kurtosis;
}
1033
unittest {
    // Reference values computed with Matlab.
    auto spike = [1, 1, 1, 1, 10].dup;
    assert(approxEqual(kurtosis(spike), 0.25));

    auto evenlySpaced = [2.5, 3.5, 4.5, 5.5].dup;
    assert(approxEqual(kurtosis(evenlySpaced), -1.36));

    auto outlier = [1, 2, 2, 2, 2, 2, 100].dup;
    assert(approxEqual(kurtosis(outlier), 2.1657));
}
1040
/**Skewness measures how symmetric a distribution is.  A positive value
 * means the right tail is longer/fatter than the left; a negative value
 * means the left tail is longer/fatter than the right; zero indicates a
 * symmetrical distribution.  data must be an input range whose elements
 * implicitly convert to double.*/
double skewness(T)(T data)
if(doubleIterable!(T)) {
    // Left un-ILP-optimized on purpose: it is rarely called, and a single
    // accumulation step already has plenty of instruction-level parallelism.
    // summary() feeds every element into a Summary, exactly as a manual loop
    // would.
    return summary(data).skewness;
}
1056
unittest {
    // Reference values computed with Octave.
    assert(approxEqual(skewness([1, 2, 3, 4, 5].dup), 0));

    auto piDigits = [3, 1, 4, 1, 5, 9, 2, 6, 5].dup;
    assert(approxEqual(skewness(piDigits), 0.5443));

    auto eDigits = [2, 7, 1, 8, 2, 8, 1, 8, 2, 8, 4, 5, 9].dup;
    assert(approxEqual(skewness(eDigits), -0.0866));

    // Ranges that are not arrays must work too: a lazily mapped range.
    string[] digitStrings = ["3", "1", "4", "1", "5", "9", "2", "6", "5"];
    auto parsed = map!(to!(int))(digitStrings);
    assert(approxEqual(skewness(parsed), 0.5443));
}
1068
/**Convenience wrapper: feeds every element of data into a fresh Summary
 * and returns the populated struct.*/
Summary summary(T)(T data)
if(doubleIterable!(T)) {
    // No ILP optimization here: rarely used, and a single accumulation step
    // already has plenty of instruction-level parallelism.
    Summary accumulator;
    foreach(value; data)
        accumulator.put(value);
    return accumulator;
}
1081 // Just a convenience function for a well-tested struct.  No unittest really
1082 // necessary.  (Famous last words.)
1083
/**Range adaptor yielding the z-score (element - mean) / sd of every
 * element of an underlying forward range.  Usually built through zScore().
 * Random access (opIndex), bidirectional access (back/popBack) and length
 * are provided whenever T provides them.
 *
 * The reciprocal of sd is cached, so each element costs one subtraction and
 * one multiply.  If sd is zero the reciprocal is infinite.  Elements then
 * come out as +/-infinity, or NaN where element - mean is exactly zero.
 */
struct ZScore(T) if(isForwardRange!(T) && is(ElementType!(T) : double)) {
private:
    T range;
    double mean;
    double sdNeg1;  // 1.0 / standard deviation, cached so z() can multiply.

    double z(double elem) {
        return (elem - mean) * sdNeg1;
    }

public:
    /**Computes the mean and standard deviation from range itself.  The
     * statistics pass iterates over range.save, so the range stored here
     * stays at its start even when T is a reference type (e.g. a
     * class-based range) that passing by value would otherwise consume.
     */
    this(T range) {
        this.range = range;
        auto msd = meanStdev(range.save);
        this.mean = msd.mean;
        this.sdNeg1 = 1.0 / msd.stdev;
    }

    /// Uses a precomputed mean and standard deviation instead of a pass over range.
    this(T range, double mean, double sd) {
        this.range = range;
        this.mean = mean;
        this.sdNeg1 = 1.0 / sd;
    }

    ///
    @property double front() {
        return z(range.front);
    }

    ///
    void popFront() {
        range.popFront;
    }

    ///
    @property bool empty() {
        return range.empty;
    }

    // The template constraint guarantees T is a forward range, so save is
    // always available.
    ///
    @property typeof(this) save() {
        auto ret = this;
        ret.range = range.save;
        return ret;
    }

    static if(isRandomAccessRange!(T)) {
        ///
        double opIndex(size_t index) {
            return z(range[index]);
        }
    }

    static if(isBidirectionalRange!(T)) {
        ///
        @property double back() {
            return z(range.back);
        }

        ///
        void popBack() {
            range.popBack;
        }
    }

    // Presumably fully qualified to avoid ambiguity with std.range.hasLength.
    static if(dstats.base.hasLength!(T)) {
        ///
        @property size_t length() {
            return range.length;
        }
    }
}
1159
/**Wraps range in a ZScore range.  It yields (element - mean(range)) /
 * stdev(range) for every element and mirrors the capabilities of T: it is
 * always a forward range, and also offers random access, bidirectional
 * access and length whenever T has them.
 *
 * Notes:
 *
 * When range holds a sample drawn from a larger population rather than the
 * whole population, the values produced are, strictly speaking, T statistics
 * rather than Z statistics.  This is because the sample mean and standard
 * deviation only estimate the population parameters.  The range is used the
 * same way either way; only the interpretation of its output changes.
 *
 * Cost: constructing the range walks the entire input once to obtain the
 * mean and standard deviation.  Every element access then performs a
 * floating point subtraction and multiply.
 */
ZScore!(T) zScore(T)(T range)
if(isForwardRange!(T) && doubleInput!(T)) {
    auto wrapped = ZScore!T(range);
    return wrapped;
}
1184
/**Same as zScore(range), but with a caller-supplied mean and standard
 * deviation (sd).  The input is therefore not traversed up front.
 */
ZScore!(T) zScore(T)(T range, double mean, double sd)
if(isForwardRange!(T) && doubleInput!(T)) {
    auto wrapped = ZScore!T(range, mean, sd);
    return wrapped;
}
1192
unittest {
    int[] data = [1, 2, 3, 4, 5];
    auto expectedMean = mean(data).mean;
    auto expectedSd = stdev(data);
    auto zs = zScore(data);

    // Sequential iteration must match the z-score formula element by element.
    size_t idx = 0;
    foreach(zVal; zs) {
        assert(approxEqual(zVal, (data[idx] - expectedMean) / expectedSd));
        idx++;
    }

    // length and random access are forwarded from int[].
    assert(zs.length == 5);
    for(size_t i = 0; i < zs.length; i++) {
        assert(approxEqual(zs[i], (data[i] - expectedMean) / expectedSd));
    }
}
1209
1210
1211
// TempAlloc leak check: everything allocated by the code covered by the
// unittests above must have been released by now.  Keep this as the final
// unittest of the module, so that it runs after all the others.
unittest {
    auto allocState = TempAlloc.getState;
    assert(allocState.used == 0);
    assert(allocState.nblocks < 2);
}
Note: See TracBrowser for help on using the browser.