root/trunk/cor.d

Revision 293, 27.9 kB (checked in by dsimcha, 1 year ago)

Better loop unrolling: Speed up ILP optimized functions ~10%.

Line 
1 /**Pearson, Spearman and Kendall correlations, covariance.
2  *
3  * Author:  David Simcha*/
4  /*
5  * License:
6  * Boost Software License - Version 1.0 - August 17th, 2003
7  *
8  * Permission is hereby granted, free of charge, to any person or organization
9  * obtaining a copy of the software and accompanying documentation covered by
10  * this license (the "Software") to use, reproduce, display, distribute,
11  * execute, and transmit the Software, and to prepare derivative works of the
12  * Software, and to permit third-parties to whom the Software is furnished to
13  * do so, all subject to the following:
14  *
15  * The copyright notices in the Software and this entire statement, including
16  * the above license grant, this restriction and the following disclaimer,
17  * must be included in all copies of the Software, in whole or in part, and
18  * all derivative works of the Software, unless such copies or derivative
19  * works are solely in the form of machine-executable object code generated by
20  * a source language processor.
21  *
22  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24  * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
25  * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
26  * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
27  * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
28  * DEALINGS IN THE SOFTWARE.
29  */
30
31 module dstats.cor;
32
33 import std.conv, std.range, std.typecons, std.exception, std.math,
34     std.traits, std.typetuple;
35
36 import dstats.sort, dstats.base, dstats.alloc, dstats.regress : invert;
37
// Test-only scaffolding: extra imports used by the unittest blocks, a
// module-level random generator, and an entry point so the module can be
// built as a stand-alone test program.
version(unittest) {
    import std.stdio;
    import dstats.random;
    import std.algorithm : map, swap, reduce;

    // Module-level generator; seeded in main() below.
    Random gen;

    void main() {
        gen.seed(unpredictableSeed);
    }
}
47
/**Computes the Pearson correlation of two ranges and returns the full
 * PearsonCor summary (means, variances, covariance, correlation).
 *
 * "Correlation" used without qualification usually means this quantity.  It
 * is a parametric metric and should not be used with extremely ill-behaved
 * data.  Any pair of ranges satisfying doubleInput is accepted.
 *
 * Note:  PearsonCor declares cor as alias this, so the result can be used
 * directly where a floating point correlation coefficient is expected.
 *
 * Both ranges must contain the same number of elements; this is checked with
 * dstatsEnforce (up front for random access ranges with length, after
 * consumption otherwise).
 *
 * References: Computing Higher-Order Moments Online.
 * http://people.xiph.org/~tterribe/notes/homs.html
 */
PearsonCor pearsonCor(T, U)(T input1, U input2)
if(doubleInput!(T) && doubleInput!(U)) {
    PearsonCor result;

    static if(isRandomAccessRange!T && isRandomAccessRange!U &&
        dstats.base.hasLength!T && dstats.base.hasLength!U) {

        dstatsEnforce(input1.length == input2.length,
            "Ranges must be same length for Pearson correlation.");

        // Instruction-level-parallelism optimization (per the original
        // author's notes): nILP independent accumulators share one element
        // count, which saves divisions and shortens dependency chains.  The
        // accumulators are merged into a single PearsonCor afterwards.
        enum nILP = 8;
        immutable len = input1.length;
        size_t pos = 0;

        if(len > 2 * nILP) {
            double count = 0;
            double[nILP] meanA = 0, meanB = 0, ssA = 0, ssB = 0, sAB = 0;

            while(pos + nILP < len) {
                immutable prevCount = count;
                immutable invCount = 1 / ++count;

                // Unrolled at compile time: lane j consumes element pos + j.
                foreach(lane; StaticIota!nILP) {
                    immutable double dA = input1[pos + lane] - meanA[lane];
                    immutable double dB = input2[pos + lane] - meanB[lane];
                    immutable dAScaled = dA * invCount;
                    immutable dBScaled = dB * invCount;

                    meanA[lane] += dAScaled;
                    meanB[lane] += dBScaled;
                    ssA[lane]   += prevCount * dAScaled * dA;
                    ssB[lane]   += prevCount * dBScaled * dB;
                    sAB[lane]   += prevCount * dAScaled * dB;
                }

                pos += nILP;
            }

            // Lane 0 seeds the result; the remaining lanes are merged in.
            result = PearsonCor(count, meanA[0], meanB[0], ssA[0], ssB[0], sAB[0]);

            foreach(lane; 1..nILP) {
                result.put( PearsonCor(count, meanA[lane], meanB[lane],
                    ssA[lane], ssB[lane], sAB[lane]));
            }
        }

        // Whatever the unrolled loop did not consume goes in one at a time.
        foreach(idx; pos..len) {
            result.put(input1[idx], input2[idx]);
        }

    } else {
        // Generic path: walk both ranges in lock step.
        for(; !input1.empty && !input2.empty;
            input1.popFront(), input2.popFront()) {
            result.put(input1.front, input2.front);
        }

        // If only one range ran dry, the lengths differed.
        dstatsEnforce(input1.empty && input2.empty,
            "Ranges must be same length for Pearson correlation.");
    }

    return result;
}
132
unittest {
    // Hand-checked values: perfect positive, perfect negative and one
    // intermediate correlation.
    assert(approxEqual(pearsonCor([1,2,3,4,5][], [1,2,3,4,5][]).cor, 1));
    assert(approxEqual(pearsonCor([1,2,3,4,5][], [10.0, 8.0, 6.0, 4.0, 2.0][]).cor, -1));
    assert(approxEqual(pearsonCor([2, 4, 1, 6, 19][], [4, 5, 1, 3, 2][]).cor, -.2382314));

    // Bare-bones input range (front/popFront/empty only) to exercise the
    // non-random-access code path.
    static struct Count {
        uint num;
        uint upTo;

        @property size_t front() { return num; }

        void popFront() { num++; }

        @property bool empty() { return num >= upTo; }
    }

    auto a = Count(0, 100);
    auto b = Count(0, 100);
    assert(approxEqual(pearsonCor(a, b).cor, 1));

    // Merging two summaries must match summarising the concatenated data.
    PearsonCor firstPart = pearsonCor([1,2,4][], [2,3,5][]);
    PearsonCor secondPart = pearsonCor([4,2,9][], [2,8,7][]);
    PearsonCor whole = pearsonCor([1,2,4,4,2,9][], [2,3,5,2,8,7][]);

    firstPart.put(secondPart);

    foreach(ti, elem; firstPart.tupleof) {
        assert(approxEqual(elem, whole.tupleof[ti]));
    }

    assert(approxEqual(pearsonCor([1,2,3,4,5,6,7,8,9,10][],
        [8,6,7,5,3,0,9,3,6,2][]).cor, -0.4190758));

    foreach(iter; 0..1000) {
        // The array overload (ILP-optimized for long enough inputs) and the
        // element-by-element accumulation must agree.
        auto xs = randArray!(rNorm)(uniform(5, 100), 0, 1);
        auto ys = randArray!(rNorm)(xs.length, 0, 1);
        auto viaFunction = pearsonCor(xs, ys);

        PearsonCor viaPut;
        foreach(i, x; xs) {
            viaPut.put(x, ys[i]);
        }

        foreach(ti, elem; viaFunction.tupleof) {
            assert(approxEqual(elem, viaPut.tupleof[ti]));
        }

        // Corner case: merging when one side has N == 0 must copy the other
        // side bit for bit.
        PearsonCor merged;
        merged.put(viaFunction);
        PearsonCor nothing;
        merged.put(nothing);
        foreach(ti, elem; viaFunction.tupleof) {
            assert(isIdentical(merged.tupleof[ti], elem));
        }
    }
}
195
/**Online (single pass, O(1) space) accumulator for the means, variances,
 * standard deviations, covariance and Pearson correlation of paired data.
 *
 * The getters for var, stdev, cov and cor perform floating point divisions on
 * each call.  The mean getters only branch on N == 0.  Getters return NaN
 * when too few observations have been seen (N < 2 for var/cov and anything
 * derived from them, N == 0 for the means).
 *
 * cor is declared as alias this, so a PearsonCor can be used where a plain
 * correlation coefficient is expected.
 * NOTE(review): an older version of this comment listed alias this as
 * disabled because of compiler bugs; the declaration below is active.
 *
 * References: Computing Higher-Order Moments Online.
 * http://people.xiph.org/~tterribe/notes/homs.html
 */
struct PearsonCor {
package:
    // _k is the observation count; _var1, _var2 and _cov hold un-normalised
    // sums that the getters divide by (_k - 1).
    double _k = 0, _mean1 = 0, _mean2 = 0, _var1 = 0, _var2 = 0, _cov = 0;

public:
    alias cor this;

    /// Adds one (elem1, elem2) observation.
    void put(double elem1, double elem2) nothrow @safe {
        immutable nBefore = _k;
        immutable invN = 1 / ++_k;
        immutable d1 = elem1 - _mean1;
        immutable d2 = elem2 - _mean2;
        immutable d1OverN = d1 * invN;
        immutable d2OverN = d2 * invN;

        _mean1 += d1OverN;
        _mean2 += d2OverN;
        _var1  += nBefore * d1OverN * d1;
        _var2  += nBefore * d2OverN * d2;
        _cov   += nBefore * d1OverN * d2;
    }

    /// Merges another PearsonCor into this one, as if all of its
    /// observations had been put() here.
    void put(const ref typeof(this) rhs) nothrow @safe {
        // Nothing accumulated here yet: take rhs verbatim.
        if(_k == 0) {
            foreach(idx, field; rhs.tupleof) {
                this.tupleof[idx] = field;
            }
            return;
        }

        // Nothing to merge in.
        if(rhs._k == 0) {
            return;
        }

        immutable totalN = _k + rhs._k;
        immutable delta1 = rhs.mean1 - mean1;
        immutable delta2 = rhs.mean2 - mean2;

        // Means are combined as count-weighted averages.
        _mean1 = _mean1 * (_k / totalN) + rhs._mean1 * (rhs._k / totalN);
        _mean2 = _mean2 * (_k / totalN) + rhs._mean2 * (rhs._k / totalN);

        // Sums pick up a correction term for the difference in means.
        _var1 = _var1 + rhs._var1 + (_k / totalN * rhs._k * delta1 * delta1 );
        _var2 = _var2 + rhs._var2 + (_k / totalN * rhs._k * delta2 * delta2 );
        _cov  =  _cov + rhs._cov  + (_k / totalN * rhs._k * delta1 * delta2 );
        _k = totalN;
    }

    const pure nothrow @property @safe {

        /// Sample variance of the first variable; NaN when N < 2.
        double var1() {
            if(_k < 2) {
                return double.nan;
            }
            return _var1 / (_k - 1);
        }

        /// Sample variance of the second variable; NaN when N < 2.
        double var2() {
            if(_k < 2) {
                return double.nan;
            }
            return _var2 / (_k - 1);
        }

        /// Square root of var1.
        double stdev1() {
            return sqrt(var1);
        }

        /// Square root of var2.
        double stdev2() {
            return sqrt(var2);
        }

        /// Pearson correlation: cov divided by both standard deviations.
        double cor() {
            return cov / stdev1 / stdev2;
        }

        /// Sample covariance; NaN when N < 2.
        double cov() {
            if(_k < 2) {
                return double.nan;
            }
            return _cov / (_k - 1);
        }

        /// Mean of the first variable; NaN when N == 0.
        double mean1() {
            if(_k == 0) {
                return double.nan;
            }
            return _mean1;
        }

        /// Mean of the second variable; NaN when N == 0.
        double mean2() {
            if(_k == 0) {
                return double.nan;
            }
            return _mean2;
        }

        /// Number of observations accumulated so far.
        double N() {
            return _k;
        }

    }
}
305
/**Covariance of two ranges: the cov property of the PearsonCor summary that
 * pearsonCor builds for the same inputs.
 */
double covariance(T, U)(T input1, U input2)
if(doubleInput!(T) && doubleInput!(U)) {
    auto summary = pearsonCor(input1, input2);
    return summary.cov;
}
311
unittest {
    // Single known value.
    immutable observed = covariance([1,4,2,6,3].dup, [3,1,2,6,2].dup);
    assert(approxEqual(observed, 2.05));
}
315
/**Spearman's rank correlation.  Non-parametric.  This is essentially the
 * Pearson correlation of the ranks of the data, with ties dealt with by
 * averaging.
 *
 * Both inputs must contain the same number of elements; this is checked with
 * dstatsEnforce.  Returns NaN when fewer than two observations are supplied
 * (checked up front when both ranges have a length) and when ranking raises
 * a SortException.
 */
double spearmanCor(R, S)(R input1, S input2)
if(isInputRange!(R) && isInputRange!(S) &&
is(typeof(input1.front < input1.front) == bool) &&
is(typeof(input2.front < input2.front) == bool)) {

    static if(dstats.base.hasLength!S && dstats.base.hasLength!R) {
        // Validate lengths before the short-input shortcut.  Previously only
        // input1.length was inspected here, so mismatched inputs produced
        // either NaN or an enforcement failure depending on argument order.
        dstatsEnforce(input1.length == input2.length,
            "Ranges must be same length for Spearman correlation.");

        if(input1.length < 2) {
            return double.nan;
        }
    }

    mixin(newFrame);

    // Produces the ranks of someRange in a newStack-allocated buffer.
    // Random access ranges with length are ranked directly; anything else is
    // copied with tempdup first and ranked via rankSort.
    static double[] spearmanCorRank(T)(T someRange) {
        static if(dstats.base.hasLength!(T) && isRandomAccessRange!(T)) {
            double[] ret = newStack!(double)(someRange.length);
            rank(someRange, ret);
        } else {
            auto iDup = tempdup(someRange);
            double[] ret = newStack!(double)(iDup.length);
            rankSort(iDup, ret);
        }
        return ret;
    }

    try {
        auto ranks1 = spearmanCorRank(input1);
        auto ranks2 = spearmanCorRank(input2);

        // Still required for ranges without a length, whose sizes are only
        // known once they have been ranked.
        dstatsEnforce(ranks1.length == ranks2.length,
            "Ranges must be same length for Spearman correlation.");

        return pearsonCor(ranks1, ranks2).cor;
    } catch(SortException) {
        return double.nan;
    }
}
355
unittest {
    //Test against a few known values.
    assert(approxEqual(spearmanCor([1,2,3,4,5,6].dup, [3,1,2,5,4,6].dup), 0.77143));
    assert(approxEqual(spearmanCor([3,1,2,5,4,6].dup, [1,2,3,4,5,6].dup ), 0.77143));
    assert(approxEqual(spearmanCor([3,6,7,35,75].dup, [1,63,53,67,3].dup), 0.3));
    assert(approxEqual(spearmanCor([1,63,53,67,3].dup, [3,6,7,35,75].dup), 0.3));
    assert(approxEqual(spearmanCor([1.5,6.3,7.8,4.2,1.5].dup, [1,63,53,67,3].dup), .56429));
    assert(approxEqual(spearmanCor([1,63,53,67,3].dup, [1.5,6.3,7.8,4.2,1.5].dup), .56429));
    assert(approxEqual(spearmanCor([1.5,6.3,7.8,7.8,1.5].dup, [1,63,53,67,3].dup), .79057));
    assert(approxEqual(spearmanCor([1,63,53,67,3].dup, [1.5,6.3,7.8,7.8,1.5].dup), .79057));
    assert(approxEqual(spearmanCor([1.5,6.3,7.8,6.3,1.5].dup, [1,63,53,67,3].dup), .63246));
    assert(approxEqual(spearmanCor([1,63,53,67,3].dup, [1.5,6.3,7.8,6.3,1.5].dup), .63246));
    assert(approxEqual(spearmanCor([3,4,1,5,2,1,6,4].dup, [1,3,2,6,4,2,6,7].dup), .6829268));
    assert(approxEqual(spearmanCor([1,3,2,6,4,2,6,7].dup, [3,4,1,5,2,1,6,4].dup), .6829268));

    uint[] one = new uint[1000], two = new uint[1000];
    foreach(i; 0..100) {  //Further sanity checks for things like commutativity.
        size_t lowerBound = uniform(0, one.length);
        size_t upperBound = uniform(0, one.length);
        if(lowerBound > upperBound) swap(lowerBound, upperBound);
        foreach(ref o; one) {
            o = uniform(1, 10);  //Generate lots of ties.
        }
        foreach(ref o; two) {
             o = uniform(1, 10);  //Generate lots of ties.
        }

        // Swapping the arguments must not change the result.
        double sOne =
             spearmanCor(one[lowerBound..upperBound], two[lowerBound..upperBound]);
        double sTwo =
             spearmanCor(two[lowerBound..upperBound], one[lowerBound..upperBound]);

        // Negating one input (modulo unsigned wrap-around, which reverses the
        // ordering of these small values) must flip the sign.
        foreach(ref o; one)
            o*=-1;
        double sThree =
             -spearmanCor(one[lowerBound..upperBound], two[lowerBound..upperBound]);
        double sFour =
             -spearmanCor(two[lowerBound..upperBound], one[lowerBound..upperBound]);

        // Negating both inputs and reversing both slices must restore it.
        foreach(ref o; two) o*=-1;
        one[lowerBound..upperBound].reverse;
        two[lowerBound..upperBound].reverse;
        double sFive =
             spearmanCor(one[lowerBound..upperBound], two[lowerBound..upperBound]);

        // isNaN (rather than the lowercase isnan spelling) matches the
        // Kendall unittest in this module.
        assert(approxEqual(sOne, sTwo) || (isNaN(sOne) && isNaN(sTwo)));
        assert(approxEqual(sTwo, sThree) || (isNaN(sThree) && isNaN(sTwo)));
        assert(approxEqual(sThree, sFour) || (isNaN(sThree) && isNaN(sFour)));
        assert(approxEqual(sFour, sFive) || (isNaN(sFour) && isNaN(sFive)));
    }

    // Test input ranges.
    static struct Count {
        uint num;
        uint upTo;
        @property size_t front() {
            return num;
        }
        void popFront() {
            num++;
        }
        @property bool empty() {
            return num >= upTo;
        }
    }

    Count a, b;
    a.upTo = 100;
    b.upTo = 100;
    assert(approxEqual(spearmanCor(a, b), 1));
}
422
// Largest input length that kendallCor hands to the O(N^2) kendallCorSmallN
// implementation.  Under unittest it is lowered to 1 so that calls to
// kendallCor exercise the large-N implementation.
version(unittest)
    private enum kendallSmallN = 1;
else
    private enum kendallSmallN = 15;
429
/**Kendall's Tau-b, O(N log N) version.  Tau can be defined through the bubble
 * sort distance: the number of swaps a bubble sort would need to bring
 * input2 into the same order as input1.  It is a robust, non-parametric
 * correlation metric.
 *
 * The inputs are copied because they have to be sorted, so any input range
 * is accepted.  Both ranges must have the same length; this is checked with
 * dstatsEnforce.  Inputs no longer than kendallSmallN are delegated to the
 * O(N^2) implementation.
 *
 * References:
 * A Computer Method for Calculating Kendall's Tau with Ungrouped Data,
 * William R. Knight, Journal of the American Statistical Association, Vol.
 * 61, No. 314, Part 1 (Jun., 1966), pp. 436-439
 *
 * The Variance of Tau When Both Rankings Contain Ties.  M.G. Kendall.
 * Biometrika, Vol 34, No. 3/4 (Dec., 1947), pp. 297-298
 */
double kendallCor(T, U)(T input1, U input2)
if(isInputRange!(T) && isInputRange!(U)) {
    enum lengthMessage = "Ranges must be same length for Kendall correlation.";

    // Arrays can be checked, and short ones answered, without copying.
    static if(isArray!(T) && isArray!(U)) {
        dstatsEnforce(input1.length == input2.length, lengthMessage);
        if(input1.length <= kendallSmallN) {
            return kendallCorSmallN(input1, input2);
        }
    }

    // Each temporary copy is paired with its own TempAlloc.free on exit.
    auto copy1 = tempdup(input1);
    scope(exit) TempAlloc.free;
    auto copy2 = tempdup(input2);
    scope(exit) TempAlloc.free;

    dstatsEnforce(copy1.length == copy2.length, lengthMessage);

    return (copy1.length <= kendallSmallN)
        ? kendallCorSmallN(copy1, copy2)
        : kendallCorDestructive(copy1, copy2);
}
471
/**Kendall's Tau-b in O(N log N) that works in place: the input arrays are
 * overwritten with undefined data, but only O(log N) stack space is used for
 * sorting instead of O(N) space for duplicating the input.  Only works on
 * arrays.
 *
 * The arrays must have equal length (checked with dstatsEnforce).  Returns
 * NaN if a SortException is raised while computing the result.
 */
double kendallCorDestructive(T, U)(T[] input1, U[] input2) {
    dstatsEnforce(input1.length == input2.length,
        "Ranges must be same length for Kendall correlation.");

    double tau = double.nan;
    try {
        tau = kendallCorDestructiveLowLevel(input1, input2, false).tau;
    } catch(SortException) {
        // tau keeps its NaN initial value.
    }
    return tau;
}
485
// Ordering predicate, in string form, that this module passes to qsort and
// mergeSortTemp.  The function-template spelling is kept for reference:
//bool compFun(T)(T lhs, T rhs) { return lhs < rhs; }
private enum string compFun = "a < b";
488
// Overload for T.sizeof < U.sizeof: forwards with the arguments exchanged so
// that the overload constrained on T.sizeof >= U.sizeof does the work (that
// one recycles the first array's storage as scratch space).
// NOTE(review): because the arguments are exchanged, the first-argument and
// second-argument tie fields of the result presumably describe input2 and
// input1 respectively -- confirm against callers that read them.
auto kendallCorDestructiveLowLevel(T, U)(T[] input1, U[] input2, bool needTies)
if(T.sizeof < U.sizeof) {
    auto swappedResult = kendallCorDestructiveLowLevel(input2, input1, needTies);
    return swappedResult;
}
494
// Result bundle produced by kendallCorDestructiveLowLevel.
struct KendallLowLevel {
    // Correlation estimate: s divided by the square roots of the two
    // tie-adjusted pair counts.
    double tau;

    // Numerator of tau.
    long s;

    // Tie corrections, notation as in Kendall, 1947, Biometrika.  They are
    // only filled in when the caller asks for them (needTies).

    // First input:  sum{t(t - 1)(2t + 5)}, sum{t(t - 1)(t - 2)}, sum{t(t - 1)}
    ulong tieCorrectT1, tieCorrectT2, tieCorrectT3;

    // Second input: sum{u(u - 1)(2u + 5)}, sum{u(u - 1)(u - 2)}, sum{u(u - 1)}
    ulong tieCorrectU1, tieCorrectU2, tieCorrectU3;
}
509
// Used internally in dstats.tests.kendallCorTest.
//
// Core of the O(N log N) Kendall computation.  Sorts both arrays in place
// (their contents are destroyed) and returns tau, its numerator s and, when
// needTies is true, the tie-correction sums for both inputs.  Requires
// T.sizeof >= U.sizeof because input1's storage is reused as merge-sort
// scratch space for input2.
KendallLowLevel kendallCorDestructiveLowLevel
(T, U)(T[] input1, U[] input2, bool needTies)
if(T.sizeof >= U.sizeof)
in {
    assert(input1.length == input2.length);
} body {
    // For sorted data: sum over runs of equal values of (run length choose 2).
    static ulong getMs(V)(const V[] data) {
        // extra counts the members of the current run beyond its first.
        ulong total = 0, extra = 0;
        foreach(i; 1..data.length) {
            if(data[i] == data[i - 1]) {
                extra++;
                continue;
            }
            if(extra) {
                total += (extra * (extra + 1)) / 2;
                extra = 0;
            }
        }
        if(extra) {
            total += (extra * (extra + 1)) / 2;
        }
        return total;
    }

    // Accumulates the three tie-correction sums for sorted arr.  Does nothing
    // unless needTies is set, since tau itself does not need them.
    void computeTies(V)
    (V[] arr, ref ulong tie1, ref ulong tie2, ref ulong tie3) {
        if(!needTies) {
            return;
        }

        void addRun(ulong len) {
            tie1 += len * (len - 1) * (2 * len + 5);
            tie2 += len * (len - 1) * (len - 2);
            tie3 += len * (len - 1);
        }

        ulong runLength = 1;
        foreach(i; 1..arr.length) {
            if(arr[i] == arr[i - 1]) {
                runLength++;
            } else if(runLength > 1) {
                addRun(runLength);
                runLength = 1;
            }
        }

        // The final run has no following element to close it.
        if(runLength > 1) {
            addRun(runLength);
        }
    }

    immutable n = input1.length;
    ulong m1 = 0;
    ulong nPair = (cast(ulong) n * ( cast(ulong) n - 1UL)) / 2UL;
    KendallLowLevel ret;
    ret.s = to!long(nPair);

    // Sort by input1, carrying input2 along.
    qsort!(compFun)(input1, input2);

    // Handles a finished run of equal input1 values occupying
    // [end - extra - 1, end): sorts the matching input2 slice, counts the
    // input1 ties, and adds the pairs tied in both inputs back to s.
    void closeRun(size_t end, uint extra) {
        qsort!(compFun)(input2[end - extra - 1..end]);
        m1 += extra * (extra + 1UL) / 2UL;
        ret.s += getMs(input2[end - extra - 1..end]);
    }

    uint tieCount = 0;
    foreach(i; 1..n) {
        if(input1[i] == input1[i - 1]) {
            tieCount++;
        } else if(tieCount > 0) {
            closeRun(i, tieCount);
            tieCount = 0;
        }
    }
    if(tieCount > 0) {
        closeRun(n, tieCount);
    }

    computeTies(input1, ret.tieCorrectT1, ret.tieCorrectT2, ret.tieCorrectT3);

    // T.sizeof >= U.sizeof is guaranteed by the constraint, we own these
    // arrays, and input1 is never read again, so its memory can serve as the
    // merge sort's temporary buffer.
    ulong swapCount = 0;
    U[] scratch = (cast(U*) input1.ptr)[0..input2.length];
    mergeSortTemp!(compFun)(input2, scratch, &swapCount);

    immutable m2 = getMs(input2);
    computeTies(input2, ret.tieCorrectU1, ret.tieCorrectU2, ret.tieCorrectU3);

    ret.s -= (m1 + m2) + 2 * swapCount;
    immutable double denominator1 = nPair - m1;
    immutable double denominator2 = nPair - m2;
    ret.tau = ret.s / sqrt(denominator1) / sqrt(denominator2);
    return ret;
}
601
/* Kendall's Tau correlation, O(N^2) version.  Per the original author this
 * beats the asymptotically better implementation for N <= about 15 (the
 * large-N sorts do fall back on insertion sort for small inputs, but constant
 * and O(N) overheads still make this one faster for very small N), and it
 * doubles as a simple reference implementation for testing.
 *
 * Returns NaN as soon as any pair of elements is unordered (neither <, > nor
 * ==), e.g. when a floating point NaN is present.
 */
private double kendallCorSmallN(T, U)(const T[] input1, const U[] input2)
in {
    assert(input1.length == input2.length);

    // This is a small-N routine; the counters below are plain ints, so for
    // inputs anywhere near this large the result would be wrong because of
    // overflow.  Use the large-N implementation in this module instead.
    assert(input1.length < 1 << 15);
} body {
    // Three-way comparison.  Stores +1, -1 or 0 in result and returns true,
    // or returns false when the operands are unordered.
    static bool sign(V)(ref const V lhs, ref const V rhs, out int result) {
        if(lhs > rhs) {
            result = 1;
        } else if(lhs < rhs) {
            result = -1;
        } else if(lhs == rhs) {
            result = 0;
        } else {
            return false;
        }
        return true;
    }

    // m1 / m2: pairs tied in input1 / input2.  s: concordant minus discordant.
    int m1 = 0, m2 = 0;
    int s = 0;
    immutable n = input2.length;

    foreach(i; 0..n) {
        foreach(j; i + 1..n) {
            int dir1, dir2;
            if(!sign(input2[i], input2[j], dir2) ||
               !sign(input1[i], input1[j], dir1)) {
                return double.nan;
            }

            if(dir1 == 0) {
                m1++;
            }
            if(dir2 == 0) {
                m2++;
            }

            // +1 when both inputs order the pair the same way, -1 when they
            // disagree, 0 when either is tied.
            s += dir1 * dir2;
        }
    }

    immutable nCombination = n * (n - 1) / 2;
    immutable double denominator1 = nCombination - m1;
    immutable double denominator2 = nCombination - m2;
    return s / sqrt(denominator1) / sqrt(denominator2);
}
666
667
unittest {
    // Known values.
    assert(approxEqual(kendallCor([1,2,3,4,5].dup, [3,1,7,4,3].dup), 0.1054093));
    assert(approxEqual(kendallCor([3,6,7,35,75].dup,[1,63,53,67,3].dup), 0.2));
    assert(approxEqual(kendallCor([1.5,6.3,7.8,4.2,1.5].dup, [1,63,53,67,3].dup), .3162287));

    // Cross-checks kendallCor against the straightforward O(N^2)
    // implementation on random slices full of ties.
    static void doKendallTest(T)() {
        T[] one = new T[1000], two = new T[1000];

        foreach(round; 0..100) {
            size_t lowerBound = uniform(0, 1000);
            size_t upperBound = uniform(0, 1000);
            if(lowerBound > upperBound) {
                swap(lowerBound, upperBound);
            }

            foreach(ref o; one) {
                o = uniform(cast(T) -10, cast(T) 10);
            }
            foreach(ref o; two) {
                o = uniform(cast(T) -10, cast(T) 10);
            }

            immutable double viaFast =
                kendallCor(one[lowerBound..upperBound], two[lowerBound..upperBound]);
            immutable double viaSlow =
                kendallCorSmallN(one[lowerBound..upperBound], two[lowerBound..upperBound]);

            assert(approxEqual(viaFast, viaSlow) || (isNaN(viaFast) && isNaN(viaSlow)));
        }
    }

    doKendallTest!int();
    doKendallTest!float();
    doKendallTest!double();

    // Bare-bones input range (front/popFront/empty only).
    static struct Count {
        uint num;
        uint upTo;

        @property size_t front() { return num; }

        void popFront() { num++; }

        @property bool empty() { return num >= upTo; }
    }

    auto a = Count(0, 100);
    auto b = Count(0, 100);
    assert(approxEqual(kendallCor(a, b), 1));

    // Two huge tie groups: fails if the tie bookkeeping overflows.
    auto rng = chain(replicate(0, 100_000), replicate(1, 100_000));
    assert(approxEqual(kendallCor(rng, rng), 1));
}
725
// Short legacy names for the correlation API.  Deliberately left out of the
// generated documentation; they are slated for eventual deprecation.
alias PearsonCor Pcor;
alias pearsonCor pcor;
alias spearmanCor scor;
alias kendallCor kcor;
alias kendallCorDestructive kcorDestructive;
733
/**Computes the partial correlation between vec1, vec2 given
 * conditions.  conditions can be either a tuple of ranges, a range of ranges,
 * or (for a single condition) a single range.
 *
 * cor is the correlation metric to use.  It can be either pearsonCor,
 * spearmanCor, kendallCor, or any custom correlation metric you can come up
 * with.  Its result must be castable to double.
 *
 * The result is read off the matrix that invert() produces from the pairwise
 * correlation matrix of (vec1, vec2, conditions...):
 * -inv[0][1] / sqrt(inv[0][0] * inv[1][1]).
 *
 * Examples:
 * ---
 * uint[] stock1Price = [8, 6, 7, 5, 3, 0, 9];
 * uint[] stock2Price = [3, 1, 4, 1, 5, 9, 2];
 * uint[] economicHealth = [2, 7, 1, 8, 2, 8, 1];
 * uint[] consumerFear = [1, 2, 3, 4, 5, 6, 7];
 *
 * // See whether the prices of stock 1 and stock 2 are correlated even
 * // after adjusting for the overall condition of the economy and consumer
 * // fear.
 * double partialCor =
 *   partial!pearsonCor(stock1Price, stock2Price, economicHealth, consumerFear);
 * ---
 */
double partial(alias cor, T, U, V...)(T vec1, U vec2, V conditionsIn)
if(isInputRange!T && isInputRange!U && allSatisfy!(isInputRange, V)) {
    mixin(newFrame);
    static if(V.length == 1 && isInputRange!(ElementType!(V[0]))) {
        // Range of ranges.
        static if(isArray!(V[0])) {
            alias conditionsIn[0] cond;
        } else {
            // Copy the outer range into an array so it can be indexed and
            // sliced below.  (This used to read tempdup(cond[0]), which
            // referred to cond inside its own initializer.)
            auto cond = tempdup(conditionsIn[0]);
        }
    } else {
        // Tuple of ranges (or a single plain range).
        alias conditionsIn cond;
    }

    // Rows/columns 0 and 1 are vec1 and vec2; condition i sits at i + 2.
    // Start from the identity: every series has correlation 1 with itself.
    auto corMatrix = newStack!(double[])(cond.length + 2);
    foreach(i, ref elem; corMatrix) {
        elem = newStack!double((cond.length + 2));
        elem[] = 0;
        elem[i] = 1;
    }

    corMatrix[0][1] = corMatrix[1][0] = cast(double) cor(vec1, vec2);

    // vec1 and vec2 against every condition, mirrored across the diagonal.
    foreach(i, condition; cond) {
        immutable conditionIndex = i + 2;
        corMatrix[0][conditionIndex] = cast(double) cor(vec1, condition);
        corMatrix[conditionIndex][0] =  corMatrix[0][conditionIndex];
        corMatrix[1][conditionIndex] = cast(double) cor(vec2, condition);
        corMatrix[conditionIndex][1] = corMatrix[1][conditionIndex];
    }

    // Every unordered pair of conditions.
    foreach(i, condition1; cond) {
        foreach(j, condition2; cond[i + 1..$]) {
            immutable index1 = i + 2;
            immutable index2 = index1 + j + 1;
            corMatrix[index1][index2] = cast(double) cor(condition1, condition2);
            corMatrix[index2][index1] = corMatrix[index1][index2];
        }
    }

    // Output buffer for invert(), initialised to the identity.
    auto invMatrix = newStack!(double[])(cond.length + 2);
    foreach(i, ref elem; invMatrix) {
        elem = newStack!double((cond.length + 2));
        elem[] = 0;
        elem[i] = 1;
    }

    invert(corMatrix, invMatrix);
    return -invMatrix[0][1] / sqrt(invMatrix[0][0] * invMatrix[1][1]);
}
805
unittest {
    // Reference values from Matlab.
    uint[] stock1Price = [8, 6, 7, 5, 3, 0, 9];
    uint[] stock2Price = [3, 1, 4, 1, 5, 9, 2];
    uint[] economicHealth = [2, 7, 1, 8, 2, 8, 1];
    uint[] consumerFear = [1, 2, 3, 4, 5, 6, 7];

    // Conditions passed as a range of ranges.
    immutable double pearsonPartial = partial!pearsonCor(
        stock1Price, stock2Price, [economicHealth, consumerFear][]);
    assert(approxEqual(pearsonPartial, -0.857818));

    // Conditions passed as separate arguments.
    immutable double spearmanPartial = partial!spearmanCor(
        stock1Price, stock2Price, economicHealth, consumerFear);
    assert(approxEqual(spearmanPartial, -0.7252));
}
820
// Leak check: after all preceding unittests, TempAlloc must report no memory
// in use and fewer than two blocks.  Keep this as the last unittest in the
// module so it covers everything above it.
unittest {
    auto allocState = TempAlloc.getState;
    assert(allocState.used == 0);
    assert(allocState.nblocks < 2);
}
Note: See TracBrowser for help on using the browser.