| 1 |
/**Basic information theory. Joint entropy, mutual information, conditional |
|---|
| 2 |
* mutual information. This module uses the base 2 definition of these |
|---|
| 3 |
* quantities, i.e., entropy, mutual info, etc., are output in bits. |
|---|
| 4 |
* |
|---|
| 5 |
* Author: David Simcha*/ |
|---|
| 6 |
/* |
|---|
| 7 |
* License: |
|---|
| 8 |
* Boost Software License - Version 1.0 - August 17th, 2003 |
|---|
| 9 |
* |
|---|
| 10 |
* Permission is hereby granted, free of charge, to any person or organization |
|---|
| 11 |
* obtaining a copy of the software and accompanying documentation covered by |
|---|
| 12 |
* this license (the "Software") to use, reproduce, display, distribute, |
|---|
| 13 |
* execute, and transmit the Software, and to prepare derivative works of the |
|---|
| 14 |
* Software, and to permit third-parties to whom the Software is furnished to |
|---|
| 15 |
* do so, all subject to the following: |
|---|
| 16 |
* |
|---|
| 17 |
* The copyright notices in the Software and this entire statement, including |
|---|
| 18 |
* the above license grant, this restriction and the following disclaimer, |
|---|
| 19 |
* must be included in all copies of the Software, in whole or in part, and |
|---|
| 20 |
* all derivative works of the Software, unless such copies or derivative |
|---|
| 21 |
* works are solely in the form of machine-executable object code generated by |
|---|
| 22 |
* a source language processor. |
|---|
| 23 |
* |
|---|
| 24 |
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|---|
| 25 |
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|---|
| 26 |
* FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT |
|---|
| 27 |
* SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE |
|---|
| 28 |
* FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, |
|---|
| 29 |
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
|---|
| 30 |
* DEALINGS IN THE SOFTWARE. |
|---|
| 31 |
*/ |
|---|
| 32 |
|
|---|
| 33 |
module dstats.infotheory; |
|---|
| 34 |
|
|---|
| 35 |
import std.traits, std.math, std.typetuple, std.functional, std.range, |
|---|
| 36 |
std.array, std.typecons, std.algorithm; |
|---|
| 37 |
|
|---|
| 38 |
import dstats.base, dstats.alloc; |
|---|
| 39 |
import dstats.summary : sum; |
|---|
| 40 |
import dstats.distrib : chiSquareCDFR; |
|---|
| 41 |
|
|---|
| 42 |
import dstats.tests : toContingencyScore, gTestContingency; |
|---|
| 43 |
|
|---|
| 44 |
// Everything in this block is compiled only when unit tests are enabled.
version(unittest) {
    // Extra imports used by the unittest blocks of this module.
    import std.stdio, std.bigint, dstats.tests : gTestObs;

    // Empty entry point; presumably present so the module can be built and
    // run on its own as a test program -- confirm against the build setup.
    void main() {}
}
|---|
| 49 |
|
|---|
| 50 |
/**Computes the Shannon entropy of a forward range whose elements are
 * interpreted as frequency counts of a set of discrete observations.
 * The total number of observations is obtained by summing the counts.
 *
 * Examples:
 * ---
 * double uniform3 = entropyCounts([4, 4, 4]);
 * assert(approxEqual(uniform3, log2(3)));
 * double uniform4 = entropyCounts([5, 5, 5, 5]);
 * assert(approxEqual(uniform4, 2));
 * ---
 */
double entropyCounts(T)(T data)
if(isForwardRange!(T) && doubleInput!(T)) {
    // Snapshot the range before it is summed so that the counts can be
    // walked a second time by the two-argument overload.
    auto counts = data.save();
    immutable double total = sum!(T, double)(data);
    return entropyCounts(counts, total);
}
|---|
| 66 |
|
|---|
| 67 |
/* Computes the entropy, in bits, of the frequency counts in data, where n is
 * the total number of observations the counts are normalized by.  Counts
 * equal to zero contribute nothing to the result.
 */
double entropyCounts(T)(T data, double n)
if(isIterable!(T)) {
    // Multiply by the reciprocal instead of dividing for every count.
    immutable double invN = 1.0 / n;
    double result = 0;
    foreach(count; data) {
        if(count != 0) {
            immutable double prob = cast(double) count * invN;
            result -= prob * log2(prob);
        }
    }
    return result;
}
|---|
| 79 |
|
|---|
| 80 |
// Tests for entropyCounts(): uniform counts must give log2(number of bins).
unittest {
    double uniform3 = entropyCounts([4, 4, 4].dup);
    assert(approxEqual(uniform3, log2(3)));
    double uniform4 = entropyCounts([5, 5, 5, 5].dup);
    assert(approxEqual(uniform4, 2));
    // Two and four equally likely outcomes are expected to be exact.
    assert(entropyCounts([2,2].dup)==1);
    // Floating point counts are accepted as well.
    assert(entropyCounts([5.1,5.1,5.1,5.1].dup)==2);
    // Non-uniform counts, compared against a precomputed reference value.
    assert(approxEqual(entropyCounts([1,2,3,4,5].dup), 2.1492553971685));
}
|---|
| 89 |
|
|---|
| 90 |
// Yields the type tuple obtained from T by replacing every type that has a
// _jointRanges member with the types of that member, and keeping every other
// type unchanged.
template FlattenType(T...) {
    alias FlattenTypeImpl!(T).ret FlattenType;
}

// Recursive worker for FlattenType.  The resulting type tuple is exposed as
// the member ret.
template FlattenTypeImpl(T...) {
    static if(T.length == 0) {
        // Base case: nothing left to flatten.
        alias TypeTuple!() ret;
    } else {
        // Declared only so that the presence of a _jointRanges member on the
        // first type can be probed with typeof below.
        T[0] j;
        static if(is(typeof(j._jointRanges))) {
            // Splice the member types in place of the first type.
            alias TypeTuple!(typeof(j._jointRanges), FlattenType!(T[1..$])) ret;
        } else {
            alias TypeTuple!(T[0], FlattenType!(T[1..$])) ret;
        }
    }
}
|---|
| 106 |
|
|---|
| 107 |
// Recursive helper for flatten().  start is the Joint built so far; each
// recursion step appends the first element of rest to it, expanding that
// element into its _jointRanges if it has such a member.
private Joint!(FlattenType!(T, U)) flattenImpl(T, U...)(T start, U rest) {
    static if(rest.length == 0) {
        // Nothing left to append.
        return start;
    } else static if(is(typeof(rest[0]._jointRanges))) {
        // rest[0] is itself a joint: append its component ranges.
        return flattenImpl(jointImpl(start.tupleof, rest[0]._jointRanges), rest[1..$]);
    } else {
        // rest[0] is a plain range: append it as is.
        return flattenImpl(jointImpl(start.tupleof, rest[0]), rest[1..$]);
    }
}
|---|
| 116 |
|
|---|
| 117 |
// Combines one or more arguments into a single Joint.  Arguments that
// already have a _jointRanges member are expanded into their component
// ranges, so the result contains no nested joints at the top level.
// At least one argument is required.
Joint!(FlattenType!(T)) flatten(T...)(T args) {
    static assert(args.length > 0);
    // Seed the result with the first argument, wrapping it in a Joint if it
    // is not one already.
    static if(is(typeof(args[0]._jointRanges))) {
        auto myTuple = args[0];
    } else {
        auto myTuple = jointImpl(args[0]);
    }
    static if(args.length == 1) {
        return myTuple;
    } else {
        // Append the remaining arguments one at a time.
        return flattenImpl(myTuple, args[1..$]);
    }
}
|---|
| 130 |
|
|---|
| 131 |
/**Binds a set of ranges together so that they represent a joint probability
 * distribution.  Arguments that are themselves joints are flattened into
 * their component ranges.
 *
 * Examples:
 * ---
 * auto foo = [1,2,3,1,1];
 * auto bar = [2,4,6,2,2];
 * auto e = entropy(joint(foo, bar));  // Calculate joint entropy of foo, bar.
 * ---
 */
Joint!(FlattenType!(T)) joint(T...)(T args) {
    auto flattened = flatten(args);
    return jointImpl(flattened.tupleof);
}
|---|
| 143 |
|
|---|
| 144 |
// Wraps the given ranges in a Joint exactly as passed, without flattening.
Joint!(T) jointImpl(T...)(T args) {
    alias Joint!(T) Bundle;
    return Bundle(args);
}
|---|
| 147 |
|
|---|
| 148 |
/**Iterate over a set of ranges by value in lockstep and return an ObsEnt,
 * which is used internally by entropy functions on each iteration.
 *
 * Iteration stops as soon as any one of the underlying ranges is empty.
 */
struct Joint(T...) {
    // The underlying ranges, one field per range.  Other templates in this
    // module detect a joint by the presence of this member.
    T _jointRanges;

    // Returns an ObsEnt holding the current front of every underlying range.
    @property ObsEnt!(ElementsTuple!(T)) front() {
        alias ElementsTuple!(T) E;
        alias ObsEnt!(E) rt;
        rt ret;
        foreach(ti, elem; _jointRanges) {
            ret.tupleof[ti] = elem.front;
        }
        return ret;
    }

    // Advances every underlying range by one element.
    void popFront() {
        foreach(ti, elem; _jointRanges) {
            // Index into the member rather than using elem, so that the
            // stored range itself is advanced.
            _jointRanges[ti].popFront;
        }
    }

    // True as soon as at least one of the underlying ranges is empty.
    @property bool empty() {
        foreach(elem; _jointRanges) {
            if(elem.empty) {
                return true;
            }
        }
        return false;
    }

    // length is only offered when every underlying range has one.
    static if(T.length > 0 && allSatisfy!(dstats.base.hasLength, T)) {
        // The length of the shortest underlying range.
        @property size_t length() {
            size_t ret = size_t.max;
            foreach(range; _jointRanges) {
                auto len = range.length;
                if(len < ret) {
                    ret = len;
                }
            }
            return ret;
        }
    }
}
|---|
| 191 |
|
|---|
| 192 |
// Maps a tuple of range types to the tuple of their element types, with
// type qualifiers stripped from each element type.
// Note: there is no branch for an empty T; at least one type is required.
template ElementsTuple(T...) {
    static if(T.length == 1) {
        alias TypeTuple!(Unqual!(ElementType!(T[0]))) ElementsTuple;
    } else {
        alias TypeTuple!(Unqual!(ElementType!(T[0])), ElementsTuple!(T[1..$]))
            ElementsTuple;
    }
}
|---|
| 200 |
|
|---|
| 201 |
// True if two values of type T can be compared with the < operator.
private template Comparable(T) {
    enum bool Comparable = is(typeof({
        T a;
        T b;
        return a < b; }));
}

// Compile-time sanity check of the template above.
static assert(Comparable!ubyte);
// NOTE(review): this is an exact duplicate of the previous assertion; it was
// possibly meant to check a different type -- confirm.
static assert(Comparable!ubyte);
|---|
| 210 |
|
|---|
| 211 |
/* A single joint observation: one value per bound range, stored as a field
 * tuple that is also reachable directly through alias this.  Custom hashing
 * and equality are supplied only when the struct is a reference type as
 * reported by isReferenceType; ordering is supplied only when every field
 * type supports the < operator.
 */
struct ObsEnt(T...) {
    T compRep;
    alias compRep this;

    static if(isReferenceType!(typeof(this))) {

        // Then there's indirection involved.  We can't just do all our
        // comparison and hashing operations bitwise.
        hash_t toHash() {
            hash_t sum = 0;
            foreach(i, elem; this.tupleof) {
                sum *= 11;
                // The test must be made on the field's type.  The previous
                // is(elem : long) named a variable instead of a type, which
                // is never true, so integral fields never took this branch.
                static if(is(typeof(elem) : long)
                && elem.sizeof <= hash_t.sizeof) {
                    sum += elem;
                } else static if(__traits(compiles, elem.toHash)) {
                    sum += elem.toHash;
                } else {
                    // Fall back to the runtime type information's hash.
                    auto ti = typeid(typeof(elem));
                    sum += ti.getHash(&elem);
                }
            }
            return sum;
        }

        // Field-by-field equality.
        bool opEquals(const ref typeof(this) rhs) const {
            foreach(ti, elem; this.tupleof) {
                if(elem != rhs.tupleof[ti])
                    return false;
            }
            return true;
        }
    }
    // Else just use the default runtime functions for hash and equality.


    static if(allSatisfy!(Comparable, T)) {
        // Lexicographic comparison, first field most significant.  Returns a
        // negative value when this orders before rhs, a positive value when
        // it orders after rhs, and zero when all fields compare equal.  (The
        // previous implementation returned these signs the wrong way round.)
        int opCmp(const ref typeof(this) rhs) const {
            foreach(ti, elem; this.tupleof) {
                if(elem < rhs.tupleof[ti]) {
                    return -1;
                } else if(elem > rhs.tupleof[ti]) {
                    return 1;
                }
            }
            return 0;
        }
    }
}
|---|
| 259 |
|
|---|
| 260 |
// Whether we can use StackTreeAA, or whether we have to use a regular AA for
// entropy.
//
// NeedsHeap is false when the iterated type of T is not a reference type,
// when T is an array, or when T is a Joint made up exclusively of arrays.
// In every other case it is true.
private template NeedsHeap(T) {
    static if(!isReferenceType!(IterType!(T))) {
        enum bool NeedsHeap = false;
    } else static if(isArray!(T)) {
        enum bool NeedsHeap = false;
    } else static if(is(Joint!(typeof(T.init.tupleof)))
    && is(T == Joint!(typeof(T.init.tupleof)))
    && allSatisfy!(isArray, typeof(T.init.tupleof))) {
        // T is exactly a Joint over its own field types, all of them arrays.
        enum bool NeedsHeap = false;
    } else {
        enum bool NeedsHeap = true;
    }
}
|---|
| 275 |
|
|---|
| 276 |
// Compile-time checks of NeedsHeap for filtered ranges, arrays and joints.
unittest {
    // foo iterates over arrays (elements with indirection); bar over ints.
    auto foo = filter!"a"(cast(uint[][]) [[1]]);
    auto bar = filter!("a")([1,2,3][]);
    static assert(NeedsHeap!(typeof(foo)));
    static assert(!NeedsHeap!(typeof(bar)));
    static assert(NeedsHeap!(Joint!(uint[], typeof(foo))));
    static assert(!NeedsHeap!(Joint!(uint[], typeof(bar))));
    static assert(!NeedsHeap!(Joint!(uint[], uint[])));
}
|---|
| 285 |
|
|---|
| 286 |
/**Calculates the joint entropy of a set of observations.  Each input range
 * represents a vector of observations.  If only one range is given, this
 * reduces to the plain old entropy.
 *
 * When the input has a length, the narrowest unsigned counter type able to
 * hold that many observations (ubyte, ushort or uint) is used for counting.
 * Inputs without a length are counted with uint.
 *
 * Note:  This function specializes if ElementType!(T) is a byte, ubyte, or
 * char, resulting in a much faster entropy calculation.  When possible, try
 * to provide data in the form of a byte, ubyte, or char.
 *
 * Examples:
 * ---
 * int[] foo = [1, 1, 1, 2, 2, 2, 3, 3, 3];
 * double entropyFoo = entropy(foo);  // Plain old entropy of foo.
 * assert(approxEqual(entropyFoo, log2(3)));
 * int[] bar = [1, 2, 3, 1, 2, 3, 1, 2, 3];
 * double HFooBar = entropy(joint(foo, bar));  // Joint entropy of foo and bar.
 * assert(approxEqual(HFooBar, log2(9)));
 * ---
 */
double entropy(T)(T data)
if(isIterable!(T)) {
    static if(dstats.base.hasLength!(T)) {
        immutable nObs = data.length;
        if(nObs > ushort.max) {
            return entropyImpl!(uint, T)(data);
        }
        if(nObs > ubyte.max) {
            return entropyImpl!(ushort, T)(data);
        }
        return entropyImpl!(ubyte, T)(data);
    } else {
        // Unknown number of observations: use the widest counter.
        return entropyImpl!(uint, T)(data);
    }
}
|---|
| 318 |
|
|---|
| 319 |
// Generic entropy implementation for inputs that do not need a heap
// allocated associative array (see NeedsHeap).  U is the integer type used
// for the per-value counters.
private double entropyImpl(U, T)(T data)
if((IterType!(T).sizeof > 1 || is(IterType!T == struct)) && !NeedsHeap!(T)) {
    // Generic version.
    mixin(newFrame);
    alias IterType!(T) E;

    // Pick the counting container depending on whether the number of
    // observations is known up front.
    static if(dstats.base.hasLength!T) {
        // NOTE(review): the constructor argument is presumably a size hint
        // for the hash table -- confirm against dstats.alloc.
        auto counts = StackHash!(E, U)(max(20, data.length / 20));
    } else {
        auto counts = StackTreeAA!(E, U)();
    }
    uint N;

    // Count the occurrences of each distinct value and the total.
    foreach(elem; data)  {
        counts[elem]++;
        N++;
    }

    double ans = entropyCounts(counts.values, N);
    return ans;
}
|---|
| 340 |
|
|---|
| 341 |
// Generic entropy implementation for inputs for which NeedsHeap is true.
// Occurrences are counted in a built-in associative array keyed by the
// observed value; U is the integer type of the counters.
private double entropyImpl(U, T)(T data)
if(IterType!(T).sizeof > 1 && NeedsHeap!(T)) {  // Generic version.
    alias IterType!(T) Key;

    U[Key] tally;
    uint nObs = 0;
    foreach(obs; data) {
        ++nObs;
        tally[obs]++;
    }
    return entropyCounts(tally, nObs);
}
|---|
| 353 |
|
|---|
| 354 |
// Entropy implementation specialized for one-byte, non-struct element types.
// Occurrences are counted in a fixed 256 entry table indexed by the value of
// the element; U is the integer type of the counters.  An input without any
// elements yields an entropy of 0.
private double entropyImpl(U, T)(T data)  // byte/char specialization
if(IterType!(T).sizeof == 1 && !is(IterType!T == struct)) {
    alias IterType!(T) E;

    U[ubyte.max + 1] counts;

    // min and max track the occupied part of the table so that only that
    // part has to be scanned afterwards.
    uint min = ubyte.max, max = 0, len = 0;
    foreach(elem; data) {
        len++;
        static if(is(E == byte)) {
            // Keep adjacent elements adjacent.  In real world use cases,
            // probably will have ranges like [-1, 1].
            ubyte e = cast(ubyte) (cast(ubyte) (elem) + byte.max);
        } else {
            ubyte e = cast(ubyte) elem;
        }
        counts[e]++;
        if(e > max) {
            max = e;
        }
        if(e < min) {
            min = e;
        }
    }

    // With no observations min is still greater than max, and the unchecked
    // pointer slice below would have a wrapped-around length and read far
    // outside the table.  There is nothing to measure in that case.
    if(len == 0) {
        return 0;
    }

    return entropyCounts(counts.ptr[min..max + 1], len);
}
|---|
| 381 |
|
|---|
| 382 |
// Tests for entropy(), one section per entropyImpl overload.
unittest {
    { // Generic version.
        int[] foo = [1, 1, 1, 2, 2, 2, 3, 3, 3];
        double entropyFoo = entropy(foo);
        assert(approxEqual(entropyFoo, log2(3)));
        int[] bar = [1, 2, 3, 1, 2, 3, 1, 2, 3];
        // NOTE(review): stuff is never read afterwards.
        auto stuff = joint(foo, bar);
        // Every (foo, bar) pair is distinct, hence log2(9).
        double jointEntropyFooBar = entropy(joint(foo, bar));
        assert(approxEqual(jointEntropyFooBar, log2(9)));
    }
    { // byte specialization
        byte[] foo = [-1, -1, -1, 2, 2, 2, 3, 3, 3];
        double entropyFoo = entropy(foo);
        assert(approxEqual(entropyFoo, log2(3)));
        // Four letters, each occurring twice: exactly 2 bits.
        string bar = "ACTGGCTA";
        assert(entropy(bar) == 2);
    }
    { // NeedsHeap version.
        string[] arr = ["1", "1", "1", "2", "2", "2", "3", "3", "3"];
        auto m = map!("a")(arr);
        assert(approxEqual(entropy(m), log2(3)));
    }
}
|---|
| 405 |
|
|---|
| 406 |
/**Calculate the conditional entropy H(data | cond), computed as the joint
 * entropy of data and cond minus the entropy of cond.*/
double condEntropy(T, U)(T data, U cond)
if(isForwardRange!(T) && isForwardRange!(U)) {
    immutable double hJoint = entropy(joint(data, cond));
    immutable double hCond = entropy(cond);
    return hJoint - hCond;
}
|---|
| 411 |
|
|---|
| 412 |
// Checks the identity H(foo) - H(foo | bar) == I(foo; bar).
unittest {
    // This shouldn't be easy to screw up.  Just really basic.
    int[] foo = [1,2,2,1,1];
    int[] bar = [1,2,3,1,2];
    assert(approxEqual(entropy(foo) - condEntropy(foo, bar),
           mutualInfo(foo, bar)));
}
|---|
| 419 |
|
|---|
| 420 |
// Contribution of one contingency table cell to the mutual information sum:
// observed * log2(observed / expected), defined as 0 for an empty cell.
private double miContingency(double observed, double expected) {
    if(observed == 0) {
        return 0;
    }
    return observed * log2(observed / expected);
}
|---|
| 424 |
|
|---|
| 425 |
|
|---|
| 426 |
/**Calculates the mutual information of two vectors of discrete observations.
 *
 * The score accumulated by toContingencyScore with miContingency is divided
 * by the number of observations n that it reports.  When both ranges have a
 * length, the narrowest counter type able to hold the shorter length is
 * used; otherwise uint counters are used.
 */
double mutualInfo(T, U)(T x, U y)
if(isInputRange!(T) && isInputRange!(U)) {
    uint xFreedom, yFreedom, n;
    typeof(return) ret;

    // The length based selection below needs the length of BOTH ranges, so
    // fall back to uint counters as soon as either one lacks a length.
    // (Testing with && here made the else branch ask for the length of a
    // range that does not have one.)
    static if(!dstats.base.hasLength!T || !dstats.base.hasLength!U) {
        ret = toContingencyScore!(T, U, uint)
            (x, y, &miContingency, xFreedom, yFreedom, n);
    } else {
        immutable minLen = min(x.length, y.length);
        if(minLen <= ubyte.max) {
            ret = toContingencyScore!(T, U, ubyte)
                (x, y, &miContingency, xFreedom, yFreedom, n);
        } else if(minLen <= ushort.max) {
            ret = toContingencyScore!(T, U, ushort)
                (x, y, &miContingency, xFreedom, yFreedom, n);
        } else {
            ret = toContingencyScore!(T, U, uint)
                (x, y, &miContingency, xFreedom, yFreedom, n);
        }
    }

    // Normalize the accumulated score by the number of observations.
    return ret / n;
}
|---|
| 452 |
|
|---|
| 453 |
// Tests for mutualInfo() on binned data against reference values.
unittest {
    // Values from R, but converted from base e to base 2.
    assert(approxEqual(mutualInfo(bin([1,2,3,3,8].dup, 10),
           bin([8,6,7,5,3].dup, 10)), 1.921928));
    assert(approxEqual(mutualInfo(bin([1,2,1,1,3,4,3,6].dup, 2),
           bin([2,7,9,6,3,1,7,40].dup, 2)), .2935645));
    assert(approxEqual(mutualInfo(bin([1,2,1,1,3,4,3,6].dup, 4),
           bin([2,7,9,6,3,1,7,40].dup, 4)), .5435671));

}
|---|
| 463 |
|
|---|
| 464 |
/**
Calculates the mutual information of a contingency table representing a joint
discrete probability distribution.  Takes a set of finite forward ranges,
one for each column in the contingency table.  These can be expressed either
as a tuple of ranges or a range of ranges.
*/
double mutualInfoTable(T...)(T table) {
    // The computation is delegated to gTestContingency; this function exists
    // to give conceptual unity to the infotheory module.
    auto gResult = gTestContingency(table);
    return gResult.mutualInfo;
}
|---|
| 475 |
|
|---|
| 476 |
/**
Calculates the conditional mutual information I(x, y | z) from a set of
observations, as H(x, z) - H(x, y, z) - H(z) + H(y, z).  The result is
clamped so that it is never negative.
*/
double condMutualInfo(T, U, V)(T x, U y, V z) {
    immutable double hXZ = entropy(joint(x, z));
    immutable double hXYZ = entropy(joint(x, y, z));
    immutable double hZ = entropy(z);
    immutable double hYZ = entropy(joint(y, z));

    immutable double raw = hXZ - hXYZ - hZ + hYZ;
    return max(raw, 0);
}
|---|
| 485 |
|
|---|
| 486 |
// Tests for condMutualInfo() against reference values.
unittest {
    // Values from Matlab mi package by Hanchuan Peng.
    auto res = condMutualInfo([1,2,1,2,1,2,1,2].dup, [3,1,2,3,4,2,1,2].dup,
                              [1,2,3,1,2,3,1,2].dup);
    assert(approxEqual(res, 0.4387));
    // The conditioning variable may itself be a joint of several ranges.
    res = condMutualInfo([1,2,3,1,2].dup, [2,1,3,2,1].dup,
                         joint([1,1,1,2,2].dup, [2,2,2,1,1].dup));
    assert(approxEqual(res, 1.3510));
}
|---|
| 495 |
|
|---|
| 496 |
/**Calculates the entropy of any old input range of observations more quickly
 * than entropy(), provided that all equal values are adjacent.  If the input
 * is sorted by more than one key, i.e. structs, the result will be the joint
 * entropy of all of the keys.  The compFun alias will be used to compare
 * adjacent elements and determine how many instances of each value exist.
 *
 * An empty input yields 0.  Ranges that do not provide a length are
 * supported as well; for those the observations are counted while iterating.
 */
double entropySorted(alias compFun = "a == b", T)(T data)
if(isInputRange!(T)) {
    alias binaryFun!(compFun) comp;

    // No observations: nothing to measure, and front must not be read.
    if(data.empty) {
        return 0;
    }

    static if(is(typeof(data.length) : size_t)) {
        // The number of observations is known up front, so the probability
        // of each run can be accumulated as soon as the run ends.
        immutable n = data.length;
        immutable nrNeg1 = 1.0L / n;

        double result = 0.0;
        int nSame = 1;
        auto last = data.front;
        data.popFront;
        foreach(elem; data) {
            if(comp(elem, last)) {
                nSame++;
            } else {
                immutable p = nSame * nrNeg1;
                nSame = 1;
                result -= p * log2(p);
            }
            last = elem;
        }
        // Handle last run.
        immutable p = nSame * nrNeg1;
        result -= p * log2(p);

        return result;
    } else {
        // The number of observations n is only known once the range is
        // exhausted, so use H = log2(n) - (sum over runs of c * log2(c)) / n,
        // where c is the length of a run.
        double weighted = 0.0;
        double total = 0.0;
        size_t nSame = 1;
        auto last = data.front;
        data.popFront;
        foreach(elem; data) {
            if(comp(elem, last)) {
                nSame++;
            } else {
                immutable double c = cast(double) nSame;
                weighted += c * log2(c);
                total += c;
                nSame = 1;
            }
            last = elem;
        }
        // Handle last run.
        immutable double c = cast(double) nSame;
        weighted += c * log2(c);
        total += c;

        return log2(total) - weighted / total;
    }
}
|---|
| 528 |
|
|---|
| 529 |
// entropySorted() on sorted data must agree with entropy() on the same data
// in arbitrary order.
unittest {
    uint[] foo = [1U,2,3,1,3,2,6,3,1,6,3,2,2,1,3,5,2,1].dup;
    auto sorted = foo.dup;
    sort(sorted);
    assert(approxEqual(entropySorted(sorted), entropy(foo)));
}
|---|
| 535 |
|
|---|
| 536 |
/**
Much faster implementations of information theory functions for the special
but common case where all observations are integers on the range [0, nBin).
This is the case, for example, when the observations have been previously
binned using, for example, dstats.base.frqBin().

Note that, due to the optimizations used, joint() cannot be used with
the member functions of this struct, except entropy().

For those looking for hard numbers, this seems to be on the order of 10x
faster than the generic implementations according to my quick and dirty
benchmarks.
*/
struct DenseInfoTheory {
    // Number of bins; every observation must lie in [0, nBin).
    private uint nBin;

    // Saves space and makes things cache efficient by using the smallest
    // integer width necessary for binning.
    //
    // Calls fun, instantiated with the selected counter type, on args and
    // returns its result.  Only the length of the first argument is
    // consulted when selecting the counter type.
    double selectSize(alias fun, T...)(T args) {
        static if(allSatisfy!(dstats.base.hasLength, T)) {
            immutable len = args[0].length;

            if(len <= ubyte.max) {
                return fun!ubyte(args);
            } else if(len <= ushort.max) {
                return fun!ushort(args);
            } else {
                return fun!uint(args);
            }

            // For now, assume that no one is going to have more than
            // 4 billion observations.
        } else {
            return fun!uint(args);
        }
    }

    /**
    Constructs a DenseInfoTheory object for nBin bins.  The values taken by
    each observation must then be on the interval [0, nBin).
    */
    this(uint nBin) {
        this.nBin = nBin;
    }

    /**
    Computes the entropy of a set of observations.  Note that, for this
    function, the joint() function can be used to compute joint entropies
    as long as each individual range contains only integers on [0, nBin).
    */
    double entropy(R)(R range) if(isIterable!R) {
        return selectSize!entropyImpl(range);
    }

    // Worker for entropy().  Uint is the integer type of the counters.
    private double entropyImpl(Uint, R)(R range) {
        mixin(newFrame);
        uint n = 0;

        static if(is(typeof(range._jointRanges))) {
            // Compute joint entropy.
            // One counter per possible combination of values.
            immutable nRanges = range._jointRanges.length;
            auto counts = newStack!Uint(nBin ^^ nRanges);
            counts[] = 0;

            Outer:
            while(true) {
                uint multiplier = 1;
                uint index = 0;

                // Combine the fronts of all ranges into a single index by
                // treating them as the digits of a base nBin number, the
                // first range being the least significant digit.
                foreach(ti, Unused; typeof(range._jointRanges)) {
                    // Stop as soon as any one range is exhausted.
                    if(range._jointRanges[ti].empty) break Outer;
                    immutable rFront = range._jointRanges[ti].front;
                    assert(rFront < nBin);  // Enforce is too costly here.

                    index += multiplier * cast(uint) rFront;
                    range._jointRanges[ti].popFront();
                    multiplier *= nBin;
                }

                counts[index]++;
                n++;
            }

            return entropyCounts(counts, n);
        } else {
            // Plain entropy: the observation itself is the counter index.
            auto counts = newStack!Uint(nBin);

            counts[] = 0;
            foreach(elem; range) {
                counts[elem]++;
                n++;
            }

            return entropyCounts(counts, n);
        }
    }

    /// I(x; y)
    double mutualInfo(R1, R2)(R1 x, R2 y)
    if(isIterable!R1 && isIterable!R2) {
        return selectSize!mutualInfoImpl(x, y);
    }

    // Worker for mutualInfo().  Computes H(x) + H(y) - H(x, y) from dense
    // count tables; the result is clamped so that it is never negative.
    private double mutualInfoImpl(Uint, R1, R2)(R1 x, R2 y) {
        mixin(newFrame);
        auto joint = newStack!Uint(nBin * nBin);
        auto margx = newStack!Uint(nBin);
        auto margy = newStack!Uint(nBin);
        joint[] = 0;
        margx[] = 0;
        margy[] = 0;
        uint n;

        // Walk both ranges in lockstep until either one is exhausted.
        while(!x.empty && !y.empty) {
            immutable xFront = cast(uint) x.front;
            immutable yFront = cast(uint) y.front;
            assert(xFront < nBin);
            assert(yFront < nBin);

            // The joint table is a flattened nBin by nBin matrix.
            joint[xFront * nBin + yFront]++;
            margx[xFront]++;
            margy[yFront]++;
            n++;
            x.popFront();
            y.popFront();
        }

        auto ret = entropyCounts(margx, n) + entropyCounts(margy, n) -
            entropyCounts(joint, n);
        return max(0, ret);
    }

    /**
    Calculates the P-value for I(X; Y) assuming x and y both have supports
    of [0, nBin).  The P-value is calculated using a Chi-Square approximation.
    It is asymptotically correct, but is approximate for finite sample size.

    Parameters:
    mutualInfo:  I(x; y), in bits
    n:  The number of samples used to calculate I(x; y)
    */
    double mutualInfoPval(double mutualInfo, double n) {
        // Degrees of freedom of an nBin by nBin contingency table.
        immutable df = (nBin - 1) ^^ 2;

        // The factor LN2 converts the mutual information from bits to nats.
        immutable testStat = mutualInfo * 2 * LN2 * n;
        return chiSquareCDFR(testStat, df);
    }

    /// H(X | Y)
    double condEntropy(R1, R2)(R1 x, R2 y)
    if(isIterable!R1 && isIterable!R2) {
        return selectSize!condEntropyImpl(x, y);
    }

    // Worker for condEntropy().  Computes H(x, y) - H(y) from dense count
    // tables; the result is clamped so that it is never negative.
    private double condEntropyImpl(Uint, R1, R2)(R1 x, R2 y) {
        mixin(newFrame);
        auto joint = newStack!Uint(nBin * nBin);
        auto margy = newStack!Uint(nBin);
        joint[] = 0;
        margy[] = 0;
        uint n;

        // Walk both ranges in lockstep until either one is exhausted.
        while(!x.empty && !y.empty) {
            immutable xFront = cast(uint) x.front;
            immutable yFront = cast(uint) y.front;
            assert(xFront < nBin);
            assert(yFront < nBin);

            joint[xFront * nBin + yFront]++;
            margy[yFront]++;
            n++;
            x.popFront();
            y.popFront();
        }

        auto ret = entropyCounts(joint, n) - entropyCounts(margy, n);
        return max(0, ret);
    }

    /// I(X; Y | Z)
    double condMutualInfo(R1, R2, R3)(R1 x, R2 y, R3 z)
    if(allSatisfy!(isIterable, R1, R2, R3)) {
        return selectSize!condMutualInfoImpl(x, y, z);
    }

    // Worker for condMutualInfo().  Computes
    // H(x, z) - H(x, y, z) - H(z) + H(y, z) from dense count tables; the
    // result is clamped so that it is never negative.
    private double condMutualInfoImpl(Uint, R1, R2, R3)(R1 x, R2 y, R3 z) {
        mixin(newFrame);
        immutable nBinSq = nBin * nBin;
        auto jointxyz = newStack!Uint(nBin * nBin * nBin);
        auto jointxz = newStack!Uint(nBinSq);
        auto jointyz = newStack!Uint(nBinSq);
        auto margz = newStack!Uint(nBin);
        jointxyz[] = 0;
        jointxz[] = 0;
        jointyz[] = 0;
        margz[] = 0;
        uint n = 0;

        // Walk all three ranges in lockstep until any one is exhausted.
        while(!x.empty && !y.empty && !z.empty) {
            immutable xFront = cast(uint) x.front;
            immutable yFront = cast(uint) y.front;
            immutable zFront = cast(uint) z.front;
            assert(xFront < nBin);
            assert(yFront < nBin);
            assert(zFront < nBin);

            // Flattened three and two dimensional tables.
            jointxyz[xFront * nBinSq + yFront * nBin + zFront]++;
            jointxz[xFront * nBin + zFront]++;
            jointyz[yFront * nBin + zFront]++;
            margz[zFront]++;
            n++;

            x.popFront();
            y.popFront();
            z.popFront();
        }

        auto ret = entropyCounts(jointxz, n) - entropyCounts(jointxyz, n) -
            entropyCounts(margz, n) + entropyCounts(jointyz, n);
        return max(0, ret);
    }
}
|---|
| 758 |
|
|---|
| 759 |
// Checks that DenseInfoTheory agrees with the generic module-level
// functions on data whose values all lie in [0, 3).
unittest {
    auto dense = DenseInfoTheory(3);
    auto a = [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2];
    auto b = [1, 2, 2, 2, 0, 0, 1, 1, 1, 1, 0, 0];
    auto c = [1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 0, 0];

    // Plain and joint entropies are expected to match exactly.
    assert(entropy(a) == dense.entropy(a));
    assert(entropy(b) == dense.entropy(b));
    assert(entropy(c) == dense.entropy(c));
    // The order of the joined ranges must not matter.
    assert(entropy(joint(a, c)) == dense.entropy(joint(c, a)));
    assert(entropy(joint(a, b)) == dense.entropy(joint(a, b)));
    assert(entropy(joint(c, b)) == dense.entropy(joint(c, b)));

    assert(condEntropy(a, c) == dense.condEntropy(a, c));
    assert(condEntropy(a, b) == dense.condEntropy(a, b));
    assert(condEntropy(c, b) == dense.condEntropy(c, b));

    // The two mutual information implementations are only compared
    // approximately.
    alias approxEqual ae;
    assert(ae(mutualInfo(a, c), dense.mutualInfo(c, a)));
    assert(ae(mutualInfo(a, b), dense.mutualInfo(a, b)));
    assert(ae(mutualInfo(c, b), dense.mutualInfo(c, b)));

    assert(ae(condMutualInfo(a, b, c), dense.condMutualInfo(a, b, c)));
    assert(ae(condMutualInfo(a, c, b), dense.condMutualInfo(a, c, b)));
    assert(ae(condMutualInfo(b, c, a), dense.condMutualInfo(b, c, a)));

    // Test P-value stuff.
    immutable pDense = dense.mutualInfoPval(dense.mutualInfo(a, b), a.length);
    immutable pNotDense = gTestObs(a, b).p;
    assert(approxEqual(pDense, pNotDense));
}
|---|
| 790 |
|
|---|
| 791 |
// Verify that there are no TempAlloc memory leaks anywhere in the code covered
// by the unittest.  This should always be the last unittest of the module.
unittest {
    auto TAState = TempAlloc.getState;
    // Nothing may still be in use, and fewer than two blocks may remain.
    assert(TAState.used == 0);
    assert(TAState.nblocks < 2);
}
|---|