root/trunk/infotheory.d

Revision 242, 24.6 kB (checked in by dsimcha, 1 year ago)

Added p-value stuff for DenseInfoTheory/?

Line 
1 /**Basic information theory.  Joint entropy, mutual information, conditional
2  * mutual information.  This module uses the base 2 definition of these
3  * quantities, i.e., entropy, mutual info, etc. are output in bits.
4  *
5  * Author:  David Simcha*/
6  /*
7  * License:
8  * Boost Software License - Version 1.0 - August 17th, 2003
9  *
10  * Permission is hereby granted, free of charge, to any person or organization
11  * obtaining a copy of the software and accompanying documentation covered by
12  * this license (the "Software") to use, reproduce, display, distribute,
13  * execute, and transmit the Software, and to prepare derivative works of the
14  * Software, and to permit third-parties to whom the Software is furnished to
15  * do so, all subject to the following:
16  *
17  * The copyright notices in the Software and this entire statement, including
18  * the above license grant, this restriction and the following disclaimer,
19  * must be included in all copies of the Software, in whole or in part, and
20  * all derivative works of the Software, unless such copies or derivative
21  * works are solely in the form of machine-executable object code generated by
22  * a source language processor.
23  *
24  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26  * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
27  * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
28  * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
29  * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
30  * DEALINGS IN THE SOFTWARE.
31  */
32
33 module dstats.infotheory;
34
35 import std.traits, std.math, std.typetuple, std.functional, std.range,
36        std.array, std.typecons, std.algorithm;
37
38 import dstats.base, dstats.alloc;
39 import dstats.summary : sum;
40 import dstats.distrib : chiSquareCDFR;
41
42 import dstats.tests : toContingencyScore, gTestContingency;
43
version(unittest) {
    // Modules that are only needed by the unittest blocks of this file.
    import std.stdio;
    import std.bigint;
    import dstats.tests : gTestObs;

    // Empty entry point, compiled in only for unittest builds.
    void main() {}
}
49
/**Computes the Shannon entropy, in bits, of a forward range whose elements
 * are the frequency counts of a set of discrete observations.  The total
 * number of observations is obtained by summing the range.
 *
 * Examples:
 * ---
 * double uniform3 = entropyCounts([4, 4, 4]);
 * assert(approxEqual(uniform3, log2(3)));
 * double uniform4 = entropyCounts([5, 5, 5, 5]);
 * assert(approxEqual(uniform4, 2));
 * ---
 */
double entropyCounts(T)(T data)
if(isForwardRange!(T) && doubleInput!(T)) {
    // Take the saved copy first; it feeds the entropy pass while the
    // original range is handed to sum().
    auto counts = data.save();
    immutable double total = sum!(T, double)(data);
    return entropyCounts(counts, total);
}
66
// Overload for callers that already know the total count.
//   data:  any iterable of frequency counts; zero counts contribute nothing.
//   n:     the total by which each count is divided to get a probability.
// Returns the entropy in bits (base 2 logarithm).
double entropyCounts(T)(T data, double n)
if(isIterable!(T)) {
    immutable double invTotal = 1.0 / n;
    double bits = 0;
    foreach(count; data) {
        if(count != 0) {
            immutable double prob = cast(double) count * invTotal;
            bits -= prob * log2(prob);
        }
    }
    return bits;
}
79
unittest {
    // A uniform distribution over k outcomes has entropy log2(k).
    immutable double threeWay = entropyCounts([4, 4, 4].dup);
    assert(approxEqual(threeWay, log2(3)));
    immutable double fourWay = entropyCounts([5, 5, 5, 5].dup);
    assert(approxEqual(fourWay, 2));

    // Cases that come out exactly, for integer and floating point counts.
    assert(entropyCounts([2,2].dup)==1);
    assert(entropyCounts([5.1,5.1,5.1,5.1].dup)==2);

    // Non-uniform counts.
    assert(approxEqual(entropyCounts([1,2,3,4,5].dup), 2.1492553971685));
}
89
// Maps a list of types to a flat type tuple: every type that has a
// _jointRanges member is replaced by the types of that member, every other
// type is kept as is.
template FlattenType(T...) {
    alias FlattenTypeImpl!(T).ret FlattenType;
}

// Recursive worker for FlattenType.  The result is exposed as the member ret.
template FlattenTypeImpl(T...) {
    static if(T.length > 0) {
        alias T[0] Head;
        alias FlattenType!(T[1..$]) Tail;

        static if(is(typeof(Head.init._jointRanges))) {
            // Head bundles several ranges:  splice their types in.
            alias TypeTuple!(typeof(Head.init._jointRanges), Tail) ret;
        } else {
            alias TypeTuple!(Head, Tail) ret;
        }
    } else {
        alias TypeTuple!() ret;
    }
}
106
// Appends the remaining arguments, one at a time, to the Joint built so far.
// An argument that itself has a _jointRanges member contributes its
// component ranges instead of being nested.
private Joint!(FlattenType!(T, U)) flattenImpl(T, U...)(T start, U rest) {
    static if(rest.length > 0) {
        static if(is(typeof(rest[0]._jointRanges))) {
            auto merged = jointImpl(start.tupleof, rest[0]._jointRanges);
        } else {
            auto merged = jointImpl(start.tupleof, rest[0]);
        }
        return flattenImpl(merged, rest[1..$]);
    } else {
        // Nothing left to append.
        return start;
    }
}
116
// Builds a single, non-nested Joint out of a mixed list of plain ranges and
// Joints.  At least one argument is required.
Joint!(FlattenType!(T)) flatten(T...)(T args) {
    static assert(args.length > 0);

    // Seed the result with the first argument, wrapping it unless it already
    // carries a _jointRanges member.
    static if(is(typeof(args[0]._jointRanges))) {
        auto head = args[0];
    } else {
        auto head = jointImpl(args[0]);
    }

    static if(args.length > 1) {
        return flattenImpl(head, args[1..$]);
    } else {
        return head;
    }
}
130
/**Binds a set of ranges together so that they represent a joint probability
 * distribution.  Arguments that are themselves the result of joint() are
 * flattened, so the result never nests.
 *
 * Examples:
 * ---
 * auto foo = [1,2,3,1,1];
 * auto bar = [2,4,6,2,2];
 * auto e = entropy(joint(foo, bar));  // Calculate joint entropy of foo, bar.
 * ---
 */
Joint!(FlattenType!(T)) joint(T...)(T args) {
    auto flat = flatten(args);
    return jointImpl(flat.tupleof);
}
143
// Wraps the given ranges in a Joint exactly as passed, without flattening.
Joint!(T) jointImpl(T...)(T args) {
    return typeof(return)(args);
}
147
/**Iterates over a set of ranges by value in lockstep.  Each step yields an
 * ObsEnt holding the current front of every range; this is what the entropy
 * functions consume internally.*/
struct Joint(T...) {
    T _jointRanges;

    /// An ObsEnt assembled from the fronts of all underlying ranges.
    @property ObsEnt!(ElementsTuple!(T)) front() {
        typeof(return) obs;
        foreach(i, Unused; T) {
            obs.tupleof[i] = _jointRanges[i].front;
        }
        return obs;
    }

    /// Advances every underlying range by one element.
    void popFront() {
        foreach(i, Unused; T) {
            _jointRanges[i].popFront();
        }
    }

    /// True as soon as any one of the underlying ranges is exhausted.
    @property bool empty() {
        foreach(i, Unused; T) {
            if(_jointRanges[i].empty) {
                return true;
            }
        }
        return false;
    }

    // Only available when every underlying range has a length.
    static if(T.length > 0 && allSatisfy!(dstats.base.hasLength, T)) {
        /// The length of the shortest underlying range.
        @property size_t length() {
            size_t shortest = size_t.max;
            foreach(i, Unused; T) {
                auto candidate = _jointRanges[i].length;
                if(candidate < shortest) {
                    shortest = candidate;
                }
            }
            return shortest;
        }
    }
}
191
// Type tuple of the unqualified element types of the ranges T, in order.
template ElementsTuple(T...) {
    static if(T.length > 1) {
        // First element type followed by those of the remaining ranges.
        alias TypeTuple!(Unqual!(ElementType!(T[0])), ElementsTuple!(T[1..$]))
            ElementsTuple;
    } else {
        alias TypeTuple!(Unqual!(ElementType!(T[0]))) ElementsTuple;
    }
}
200
// True if two lvalues of type T can be compared with the < operator.
// Lvalues are used on purpose so that opCmp overloads taking their argument
// by ref (such as ObsEnt.opCmp) are recognized.
private template Comparable(T) {
    enum bool Comparable = is(typeof({
        T a;
        T b;
        return a < b; }));
}

// Compile-time sanity checks.  The same assertion used to appear twice; the
// second one now covers a different built-in type.
static assert(Comparable!ubyte);
static assert(Comparable!double);
210
/**A single joint observation:  one value per bound range, stored as a tuple
 * and accessible like one through alias this.  Supplies the hashing,
 * equality and ordering needed to count observations in the containers used
 * by the entropy functions.*/
struct ObsEnt(T...) {
    T compRep;
    alias compRep this;

    static if(isReferenceType!(typeof(this))) {

        // Then there's indirection involved.  We can't just do all our
        // comparison and hashing operations bitwise.
        hash_t toHash() {
            hash_t sum = 0;
            foreach(i, elem; this.tupleof) {
                sum *= 11;
                // The is-expression needs a type.  The previous form,
                // is(elem : long), tested a variable and was therefore
                // always false, which made this fast path unreachable.
                static if(is(typeof(elem) : long)
                       && elem.sizeof <= hash_t.sizeof) {
                    sum += elem;
                } else static if(__traits(compiles, elem.toHash)) {
                    sum += elem.toHash;
                } else {
                    auto ti = typeid(typeof(elem));
                    sum += ti.getHash(&elem);
                }
            }
            return sum;
        }

        // Element-wise equality.
        bool opEquals(const ref typeof(this) rhs) const {
            foreach(ti, elem; this.tupleof) {
                if(elem != rhs.tupleof[ti])
                    return false;
            }
            return true;
        }
    }
    // Else just use the default runtime functions for hash and equality.


    static if(allSatisfy!(Comparable, T)) {
        // Lexicographic ordering over the components.  Returns a negative
        // value when this < rhs and a positive one when this > rhs.  (The
        // signs used to be swapped, which ordered everything in reverse.)
        // Only < is used, because that is all Comparable guarantees.
        int opCmp(const ref typeof(this) rhs) const {
            foreach(ti, elem; this.tupleof) {
                if(elem < rhs.tupleof[ti]) {
                    return -1;
                } else if(rhs.tupleof[ti] < elem) {
                    return 1;
                }
            }
            return 0;
        }
    }
}
259
// Decides how entropy() counts observations of a T:  false means the
// StackHash / StackTreeAA implementation can be used, true means a regular
// associative array is required.
private template NeedsHeap(T) {
    static if(isReferenceType!(IterType!(T))) {
        static if(isArray!(T)) {
            enum bool NeedsHeap = false;
        } else static if(is(Joint!(typeof(T.init.tupleof)))
               && is(T == Joint!(typeof(T.init.tupleof)))
               && allSatisfy!(isArray, typeof(T.init.tupleof))) {
            // A Joint whose members are all arrays.
            enum bool NeedsHeap = false;
        } else {
            enum bool NeedsHeap = true;
        }
    } else {
        enum bool NeedsHeap = false;
    }
}
275
unittest {
    // A lazy range over arrays, and a lazy range over plain ints.
    auto overArrays = filter!"a"(cast(uint[][]) [[1]]);
    auto overInts = filter!("a")([1,2,3][]);
    alias typeof(overArrays) ArrayElems;
    alias typeof(overInts) IntElems;

    static assert(NeedsHeap!(ArrayElems));
    static assert(!NeedsHeap!(IntElems));

    // The same distinction has to survive being wrapped in a Joint.
    static assert(NeedsHeap!(Joint!(uint[], ArrayElems)));
    static assert(!NeedsHeap!(Joint!(uint[], IntElems)));
    static assert(!NeedsHeap!(Joint!(uint[], uint[])));
}
285
/**Calculates the entropy of a set of observations.  To get a joint entropy,
 * bind the observation vectors together with joint() first; with a single
 * range this is the plain old entropy.  If the input has a length, it is
 * used to pick the narrowest counter type that can hold the counts;
 * otherwise 32-bit counters are used.
 *
 * Note:  A specialized, much faster implementation is selected when the
 * elements are one byte wide (byte, ubyte or char).  When possible, try to
 * provide data in one of these forms.
 *
 * Examples:
 * ---
 * int[] foo = [1, 1, 1, 2, 2, 2, 3, 3, 3];
 * double entropyFoo = entropy(foo);  // Plain old entropy of foo.
 * assert(approxEqual(entropyFoo, log2(3)));
 * int[] bar = [1, 2, 3, 1, 2, 3, 1, 2, 3];
 * double HFooBar = entropy(joint(foo, bar));  // Joint entropy of foo and bar.
 * assert(approxEqual(HFooBar, log2(9)));
 * ---
 */
double entropy(T)(T data)
if(isIterable!(T)) {
    static if(dstats.base.hasLength!(T)) {
        // No count can exceed the number of observations.
        immutable nObs = data.length;
        if(nObs > ushort.max) {
            return entropyImpl!(uint, T)(data);
        }
        if(nObs > ubyte.max) {
            return entropyImpl!(ushort, T)(data);
        }
        return entropyImpl!(ubyte, T)(data);
    } else {
        return entropyImpl!(uint, T)(data);
    }
}
318
// Generic implementation for element types that do not need a heap-allocated
// associative array.  U is the integer type used for the counts.
private double entropyImpl(U, T)(T data)
if((IterType!(T).sizeof > 1 || is(IterType!T == struct)) && !NeedsHeap!(T)) {
    mixin(newFrame);
    alias IterType!(T) Elem;

    // With a known length a hash table can be sized up front; otherwise
    // fall back on the tree-based container.
    static if(dstats.base.hasLength!T) {
        auto tally = StackHash!(Elem, U)(max(20, data.length / 20));
    } else {
        auto tally = StackTreeAA!(Elem, U)();
    }

    uint nObs = 0;
    foreach(obs; data)  {
        tally[obs]++;
        nObs++;
    }

    return entropyCounts(tally.values, nObs);
}
340
// Generic implementation for inputs where NeedsHeap is true:  the counts are
// kept in a built-in associative array.  U is the integer type of the counts.
private double entropyImpl(U, T)(T data)
if(IterType!(T).sizeof > 1 && NeedsHeap!(T)) {  // Generic version.
    alias IterType!(T) Elem;

    U[Elem] tally;
    uint nObs = 0;
    foreach(obs; data) {
        nObs++;
        tally[obs]++;
    }

    // Iterating over the associative array visits its values, i.e. the counts.
    return entropyCounts(tally, nObs);
}
353
// Specialization for one-byte element types (byte, ubyte, char).  The counts
// live in a fixed 256-entry table indexed directly by the element value, so
// no hashing is needed.  U is the integer type used for the counts.
private double entropyImpl(U, T)(T data)  // byte/char specialization
if(IterType!(T).sizeof == 1 && !is(IterType!T == struct)) {
    alias IterType!(T) E;

    U[ubyte.max + 1] counts;

    // lo and hi track the range of table slots that were actually touched,
    // so that only that part of the table has to be scanned afterwards.
    uint lo = ubyte.max, hi = 0, len = 0;
    foreach(elem; data)  {
        len++;
        static if(is(E == byte)) {
            // Keep adjacent elements adjacent.  In real world use cases,
            // probably will have ranges like [-1, 1].
            ubyte e = cast(ubyte) (cast(ubyte) (elem) + byte.max);
        } else {
            ubyte e = cast(ubyte) elem;
        }
        counts[e]++;
        if(e > hi) {
            hi = e;
        }
        if(e < lo) {
            lo = e;
        }
    }

    // With no observations lo > hi, and the pointer slice below (which is
    // not bounds checked) would cover an invalid range.  The entropy of an
    // empty input is 0, as in the generic implementations.
    if(len == 0) {
        return 0;
    }

    return entropyCounts(counts.ptr[lo..hi + 1], len);
}
381
unittest {
    // Generic version.
    {
        int[] xs = [1, 1, 1, 2, 2, 2, 3, 3, 3];
        immutable double hX = entropy(xs);
        assert(approxEqual(hX, log2(3)));

        int[] ys = [1, 2, 3, 1, 2, 3, 1, 2, 3];
        auto bound = joint(xs, ys);
        immutable double hXY = entropy(joint(xs, ys));
        assert(approxEqual(hXY, log2(9)));
    }

    // byte specialization
    {
        byte[] signedObs = [-1, -1, -1, 2, 2, 2, 3, 3, 3];
        immutable double hSigned = entropy(signedObs);
        assert(approxEqual(hSigned, log2(3)));

        string bases = "ACTGGCTA";
        assert(entropy(bases) == 2);
    }

    // NeedsHeap version.
    {
        string[] labels = ["1", "1", "1", "2", "2", "2", "3", "3", "3"];
        auto lazyLabels = map!("a")(labels);
        assert(approxEqual(entropy(lazyLabels), log2(3)));
    }
}
405
/**Calculate the conditional entropy H(data | cond), computed as
 * H(data, cond) - H(cond).*/
double condEntropy(T, U)(T data, U cond)
if(isForwardRange!(T) && isForwardRange!(U)) {
    immutable double hJoint = entropy(joint(data, cond));
    immutable double hCond = entropy(cond);
    return hJoint - hCond;
}
411
unittest {
    // Basic identity check:  H(X) - H(X | Y) must equal I(X; Y).
    int[] xs = [1,2,2,1,1];
    int[] ys = [1,2,3,1,2];
    immutable double viaEntropies = entropy(xs) - condEntropy(xs, ys);
    assert(approxEqual(viaEntropies, mutualInfo(xs, ys)));
}
419
// Per-cell score handed to toContingencyScore by mutualInfo():
// observed * log2(observed / expected), with a cell whose observed count is
// zero scoring 0.
private double miContingency(double observed, double expected) {
    if(observed == 0) {
        return 0;
    }
    return observed * log2(observed / expected);
}
424
425
/**Calculates the mutual information, in bits, of two vectors of discrete
 * observations.  x and y are input ranges that are consumed as paired
 * observations.
 */
double mutualInfo(T, U)(T x, U y)
if(isInputRange!(T) && isInputRange!(U)) {
    // Passed to toContingencyScore; n is used afterwards as the number of
    // observations.  (Presumably all three are filled in by that function;
    // its definition lives in dstats.tests.)
    uint xFreedom, yFreedom, n;
    typeof(return) ret;

    // The narrow counter types can only be chosen when BOTH lengths are
    // known.  This used to test for neither range having a length, so a call
    // where only one range had a length fell through to the branch below
    // and failed to compile on min(x.length, y.length).
    static if(dstats.base.hasLength!T && dstats.base.hasLength!U) {
        immutable minLen = min(x.length, y.length);
        if(minLen <= ubyte.max) {
            ret = toContingencyScore!(T, U, ubyte)
                (x, y, &miContingency, xFreedom, yFreedom, n);
        } else if(minLen <= ushort.max) {
            ret = toContingencyScore!(T, U, ushort)
                (x, y, &miContingency, xFreedom, yFreedom, n);
        } else {
            ret = toContingencyScore!(T, U, uint)
                (x, y, &miContingency, xFreedom, yFreedom, n);
        }
    } else {
        // At least one length is unknown:  use the widest counter type.
        ret = toContingencyScore!(T, U, uint)
            (x, y, &miContingency, xFreedom, yFreedom, n);
    }

    return ret / n;
}
452
unittest {
    // Values from R, but converted from base e to base 2.
    {
        auto xBinned = bin([1,2,3,3,8].dup, 10);
        auto yBinned = bin([8,6,7,5,3].dup, 10);
        assert(approxEqual(mutualInfo(xBinned, yBinned), 1.921928));
    }
    {
        auto xBinned = bin([1,2,1,1,3,4,3,6].dup, 2);
        auto yBinned = bin([2,7,9,6,3,1,7,40].dup, 2);
        assert(approxEqual(mutualInfo(xBinned, yBinned), .2935645));
    }
    {
        auto xBinned = bin([1,2,1,1,3,4,3,6].dup, 4);
        auto yBinned = bin([2,7,9,6,3,1,7,40].dup, 4);
        assert(approxEqual(mutualInfo(xBinned, yBinned), .5435671));
    }
}
463
/**
Calculates the mutual information of a contingency table that represents a
joint discrete probability distribution.  The table is given as a set of
finite forward ranges, one per column, expressed either as a tuple of ranges
or as a range of ranges.
*/
double mutualInfoTable(T...)(T table) {
    // Thin wrapper around gTestContingency, included to give conceptual
    // unity to the infotheory module.
    auto gResult = gTestContingency(table);
    return gResult.mutualInfo;
}
475
/**
Calculates the conditional mutual information I(x, y | z) from a set of
observations, as H(x, z) - H(x, y, z) - H(z) + H(y, z).  The result is
clamped so that it is never negative.
*/
double condMutualInfo(T, U, V)(T x, U y, V z) {
    immutable double hXZ = entropy(joint(x, z));
    immutable double hXYZ = entropy(joint(x, y, z));
    immutable double hZ = entropy(z);
    immutable double hYZ = entropy(joint(y, z));

    auto ret = hXZ - hXYZ - hZ + hYZ;
    return max(ret, 0);
}
485
unittest {
    // Values from Matlab mi package by Hanchuan Peng.
    auto plainZ = condMutualInfo([1,2,1,2,1,2,1,2].dup,
                                 [3,1,2,3,4,2,1,2].dup,
                                 [1,2,3,1,2,3,1,2].dup);
    assert(approxEqual(plainZ, 0.4387));

    // The conditioning variable may itself be a joint().
    auto jointZ = condMutualInfo([1,2,3,1,2].dup, [2,1,3,2,1].dup,
                                 joint([1,1,1,2,2].dup, [2,2,2,1,1].dup));
    assert(approxEqual(jointZ, 1.3510));
}
495
/**Calculates the entropy, in bits, of any old input range of observations
 * more quickly than entropy(), provided that all equal values are adjacent.
 * If the input is sorted by more than one key, i.e. structs, the result will
 * be the joint entropy of all of the keys.  The compFun alias will be used to
 * compare adjacent elements and determine how many instances of each value
 * exist.
 *
 * The range does not need to have a length; if it has none, the
 * observations are counted while iterating.  An empty range has entropy 0.*/
double entropySorted(alias compFun = "a == b", T)(T data)
if(isInputRange!(T)) {
    alias binaryFun!(compFun) comp;

    // Nothing observed.  Without this check the front of an empty range
    // would be read below.
    if(data.empty) {
        return 0;
    }

    enum bool lengthKnown = is(typeof(data.length) : ulong);
    static if(lengthKnown) {
        immutable nrNeg1 = 1.0L / data.length;
    } else {
        ulong n = 1;  // Counts the observations, starting with the front.
    }

    // Known length:  accumulates -sum(p * log2(p)) directly.
    // Unknown length:  accumulates sum(c * log2(c)) over the run lengths c,
    // which is turned into the entropy once n is known.
    double sum = 0.0;

    // Folds one finished run of equal elements into sum.
    void addRun(ulong runLength) {
        static if(lengthKnown) {
            immutable p = runLength * nrNeg1;
            sum -= p * log2(p);
        } else {
            sum += runLength * log2(cast(real) runLength);
        }
    }

    ulong nSame = 1;
    auto last = data.front;
    data.popFront();
    foreach(elem; data) {
        static if(!lengthKnown) {
            n++;
        }
        if(comp(elem, last)) {
            nSame++;
        } else {
            addRun(nSame);
            nSame = 1;
        }
        last = elem;
    }
    // Handle last run.
    addRun(nSame);

    static if(lengthKnown) {
        return sum;
    } else {
        // H = log2(n) - (1 / n) * sum(c * log2(c))
        return log2(cast(real) n) - sum / n;
    }
}
528
unittest {
    // Sorting makes equal values adjacent, so entropySorted() on the sorted
    // copy has to agree with entropy() on the original.
    uint[] unsorted = [1U,2,3,1,3,2,6,3,1,6,3,2,2,1,3,5,2,1].dup;
    auto ordered = unsorted.dup;
    sort(ordered);
    immutable double expected = entropy(unsorted);
    assert(approxEqual(entropySorted(ordered), expected));
}
535
/**
Much faster implementations of the information theory functions for the
special but common case where every observation is an integer on the range
[0, nBin).  This is the case, for example, when the observations have been
binned beforehand with something like dstats.base.frqBin().

Because of the optimizations used, joint() can only be used with entropy();
none of the other member functions of this struct accept it.

For those looking for hard numbers, the original author's quick and dirty
benchmarks put this on the order of 10x faster than the generic
implementations.
*/
struct DenseInfoTheory {
    private uint nBin;

    // Calls fun instantiated with the smallest unsigned integer type that is
    // wide enough for the counts, which saves space and keeps things cache
    // efficient.  The length of the first argument is used as the bound.
    double selectSize(alias fun, T...)(T args) {
        static if(allSatisfy!(dstats.base.hasLength, T)) {
            immutable nObs = args[0].length;

            // For now, assume that no one is going to have more than
            // 4 billion observations.
            if(nObs > ushort.max) {
                return fun!uint(args);
            } else if(nObs > ubyte.max) {
                return fun!ushort(args);
            } else {
                return fun!ubyte(args);
            }
        } else {
            return fun!uint(args);
        }
    }

    /**
    Constructs a DenseInfoTheory object for nBin bins.  Every observation
    passed to the member functions must then lie on the interval [0, nBin).
    */
    this(uint nBin) {
        this.nBin = nBin;
    }

    /**
    Computes the entropy of a set of observations.  For this function, and
    only this one, joint() may be used to compute joint entropies, as long
    as each individual range contains only integers on [0, nBin).
    */
    double entropy(R)(R range) if(isIterable!R) {
        return selectSize!entropyImpl(range);
    }

    // Worker for entropy().  Uint is the integer type used for the counts.
    private double entropyImpl(Uint, R)(R range) {
        mixin(newFrame);
        uint nObs = 0;

        static if(!is(typeof(range._jointRanges))) {
            // Plain entropy:  one counter per bin.
            auto tally = newStack!Uint(nBin);
            tally[] = 0;

            foreach(obs; range) {
                tally[obs]++;
                nObs++;
            }

            return entropyCounts(tally, nObs);
        } else {
            // Joint entropy:  one counter per combination of bins.
            immutable nRanges = range._jointRanges.length;
            auto tally = newStack!Uint(nBin ^^ nRanges);
            tally[] = 0;

            Sampling:
            for(;;) {
                // Treat the fronts as the digits of a base-nBin number;
                // the first range supplies the least significant digit.
                uint placeValue = 1;
                uint slot = 0;

                foreach(ti, Unused; typeof(range._jointRanges)) {
                    if(range._jointRanges[ti].empty) break Sampling;
                    immutable digit = range._jointRanges[ti].front;
                    assert(digit < nBin);  // Enforce is too costly here.

                    slot += placeValue * cast(uint) digit;
                    range._jointRanges[ti].popFront();
                    placeValue *= nBin;
                }

                tally[slot]++;
                nObs++;
            }

            return entropyCounts(tally, nObs);
        }
    }

    /// I(x; y)
    double mutualInfo(R1, R2)(R1 x, R2 y)
    if(isIterable!R1 && isIterable!R2) {
        return selectSize!mutualInfoImpl(x, y);
    }

    // Worker for mutualInfo():  H(x) + H(y) - H(x, y), clamped at zero.
    private double mutualInfoImpl(Uint, R1, R2)(R1 x, R2 y) {
        mixin(newFrame);
        auto jointXY = newStack!Uint(nBin * nBin);
        auto onlyX = newStack!Uint(nBin);
        auto onlyY = newStack!Uint(nBin);
        jointXY[] = 0;
        onlyX[] = 0;
        onlyY[] = 0;
        uint nObs;

        // Walk both ranges in lockstep until either one runs out.
        for(; !x.empty && !y.empty; x.popFront(), y.popFront()) {
            immutable xi = cast(uint) x.front;
            immutable yi = cast(uint) y.front;
            assert(xi < nBin);
            assert(yi < nBin);

            jointXY[xi * nBin + yi]++;
            onlyX[xi]++;
            onlyY[yi]++;
            nObs++;
        }

        immutable double hX = entropyCounts(onlyX, nObs);
        immutable double hY = entropyCounts(onlyY, nObs);
        immutable double hXY = entropyCounts(jointXY, nObs);
        auto ret = hX + hY - hXY;
        return max(0, ret);
    }

    /**
    Calculates the P-value for I(X; Y) assuming x and y both have supports
    of [0, nBin).  The P-value is calculated using a Chi-Square approximation.
    It is asymptotically correct, but is approximate for finite sample size.

    Parameters:
    mutualInfo:  I(x; y), in bits
    n:  The number of samples used to calculate I(x; y)
    */
    double mutualInfoPval(double mutualInfo, double n) {
        immutable degreesOfFreedom = (nBin - 1) ^^ 2;

        // Multiplying by LN2 converts bits to nats before scaling by 2 * n.
        immutable statistic = mutualInfo * 2 * LN2 * n;
        return chiSquareCDFR(statistic, degreesOfFreedom);
    }

    /// H(X | Y)
    double condEntropy(R1, R2)(R1 x, R2 y)
    if(isIterable!R1 && isIterable!R2) {
        return selectSize!condEntropyImpl(x, y);
    }

    // Worker for condEntropy():  H(x, y) - H(y), clamped at zero.
    private double condEntropyImpl(Uint, R1, R2)(R1 x, R2 y) {
        mixin(newFrame);
        auto jointXY = newStack!Uint(nBin * nBin);
        auto onlyY = newStack!Uint(nBin);
        jointXY[] = 0;
        onlyY[] = 0;
        uint nObs;

        for(; !x.empty && !y.empty; x.popFront(), y.popFront()) {
            immutable xi = cast(uint) x.front;
            immutable yi = cast(uint) y.front;
            assert(xi < nBin);
            assert(yi < nBin);

            jointXY[xi * nBin + yi]++;
            onlyY[yi]++;
            nObs++;
        }

        immutable double hXY = entropyCounts(jointXY, nObs);
        immutable double hY = entropyCounts(onlyY, nObs);
        auto ret = hXY - hY;
        return max(0, ret);
    }

    /// I(X; Y | Z)
    double condMutualInfo(R1, R2, R3)(R1 x, R2 y, R3 z)
    if(allSatisfy!(isIterable, R1, R2, R3)) {
        return selectSize!condMutualInfoImpl(x, y, z);
    }

    // Worker for condMutualInfo():
    // H(x, z) - H(x, y, z) - H(z) + H(y, z), clamped at zero.
    private double condMutualInfoImpl(Uint, R1, R2, R3)(R1 x, R2 y, R3 z) {
        mixin(newFrame);
        immutable nBinSq = nBin * nBin;
        auto jointXYZ = newStack!Uint(nBin * nBin * nBin);
        auto jointXZ = newStack!Uint(nBinSq);
        auto jointYZ = newStack!Uint(nBinSq);
        auto onlyZ = newStack!Uint(nBin);
        jointXYZ[] = 0;
        jointXZ[] = 0;
        jointYZ[] = 0;
        onlyZ[] = 0;
        uint nObs = 0;

        for(; !x.empty && !y.empty && !z.empty;
              x.popFront(), y.popFront(), z.popFront()) {
            immutable xi = cast(uint) x.front;
            immutable yi = cast(uint) y.front;
            immutable zi = cast(uint) z.front;
            assert(xi < nBin);
            assert(yi < nBin);
            assert(zi < nBin);

            jointXYZ[xi * nBinSq + yi * nBin + zi]++;
            jointXZ[xi * nBin + zi]++;
            jointYZ[yi * nBin + zi]++;
            onlyZ[zi]++;
            nObs++;
        }

        immutable double hXZ = entropyCounts(jointXZ, nObs);
        immutable double hXYZ = entropyCounts(jointXYZ, nObs);
        immutable double hZ = entropyCounts(onlyZ, nObs);
        immutable double hYZ = entropyCounts(jointYZ, nObs);
        auto ret = hXZ - hXYZ - hZ + hYZ;
        return max(0, ret);
    }
}
758
unittest {
    // Every DenseInfoTheory function has to agree with its generic
    // counterpart on data that lies on [0, 3).
    auto fast = DenseInfoTheory(3);
    auto a = [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2];
    auto b = [1, 2, 2, 2, 0, 0, 1, 1, 1, 1, 0, 0];
    auto c = [1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 0, 0];

    // Entropy, plain and joint.
    assert(entropy(a) == fast.entropy(a));
    assert(entropy(b) == fast.entropy(b));
    assert(entropy(c) == fast.entropy(c));
    assert(entropy(joint(a, c)) == fast.entropy(joint(c, a)));
    assert(entropy(joint(a, b)) == fast.entropy(joint(a, b)));
    assert(entropy(joint(c, b)) == fast.entropy(joint(c, b)));

    // Conditional entropy.
    assert(condEntropy(a, c) == fast.condEntropy(a, c));
    assert(condEntropy(a, b) == fast.condEntropy(a, b));
    assert(condEntropy(c, b) == fast.condEntropy(c, b));

    // Mutual information, plain and conditional.
    alias approxEqual ae;
    assert(ae(mutualInfo(a, c), fast.mutualInfo(c, a)));
    assert(ae(mutualInfo(a, b), fast.mutualInfo(a, b)));
    assert(ae(mutualInfo(c, b), fast.mutualInfo(c, b)));

    assert(ae(condMutualInfo(a, b, c), fast.condMutualInfo(a, b, c)));
    assert(ae(condMutualInfo(a, c, b), fast.condMutualInfo(a, c, b)));
    assert(ae(condMutualInfo(b, c, a), fast.condMutualInfo(b, c, a)));

    // Test P-value stuff.
    immutable miAB = fast.mutualInfo(a, b);
    immutable pFast = fast.mutualInfoPval(miAB, a.length);
    immutable pGeneric = gTestObs(a, b).p;
    assert(approxEqual(pFast, pGeneric));
}
790
// Checks that none of the code exercised by the unittests above leaks
// TempAlloc memory.  This should always be the last unittest of the module.
unittest {
    auto allocState = TempAlloc.getState;
    assert(allocState.used == 0);
    assert(allocState.nblocks < 2);
}
Note: See TracBrowser for help on using the browser.