root/trunk/pca.d

Revision 261, 18.7 kB (checked in by dsimcha, 1 year ago)

--

Line 
1 /**
2 This module contains a basic implementation of principal component analysis,
3 based on the NIPALS algorithm.  This is fast when you only need the first
4 few components (which is usually the case since PCA's main uses are
5 visualization and dimensionality reduction).  However, convergence slows
6 drastically after the first few components have been removed and most of
7 the matrix is just noise.
8
9 References:
10
11 en.wikipedia.org/wiki/Principal_component_analysis#Computing_principal_components_iteratively
12
13 Author:  David Simcha
14 */
15
16 /*
17  * License:
18  * Boost Software License - Version 1.0 - August 17th, 2003
19  *
20  * Permission is hereby granted, free of charge, to any person or organization
21  * obtaining a copy of the software and accompanying documentation covered by
22  * this license (the "Software") to use, reproduce, display, distribute,
23  * execute, and transmit the Software, and to prepare derivative works of the
24  * Software, and to permit third-parties to whom the Software is furnished to
25  * do so, all subject to the following:
26  *
27  * The copyright notices in the Software and this entire statement, including
28  * the above license grant, this restriction and the following disclaimer,
29  * must be included in all copies of the Software, in whole or in part, and
30  * all derivative works of the Software, unless such copies or derivative
31  * works are solely in the form of machine-executable object code generated by
32  * a source language processor.
33  *
34  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
35  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
36  * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
37  * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
38  * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
39  * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
40  * DEALINGS IN THE SOFTWARE.
41  */
42 module dstats.pca;
43
44 import std.range, dstats.base, dstats.alloc, std.numeric, std.stdio, std.math,
45     std.algorithm, std.array, dstats.summary, dstats.random, std.conv,
46     std.exception, dstats.regress, std.traits;
47
/**
Result holder for one principal component, as produced by firstComponent()
and firstNComponents().
*/
struct PrincipalComponent {
    /// Scores:  the projection of each data point onto this component.
    double[] x;

    /// Loadings:  the direction vector that defines this component.
    double[] rotation;
}
56
/**
Sets options for principal component analysis.  The default options are
also the values in PrinCompOptions.init.
*/
struct PrinCompOptions {
    ///  Center each column to zero mean.  Default value:  true.
    bool zeroMean = true;

    /**
    Scale each column to unit variance.  Note that, if this option is set to
    true, zeroMean is ignored and the mean of each column is set to zero even
    if zeroMean is false.  A column whose standard deviation is zero (or
    undefined, e.g. with a single data point) is centered but left unscaled,
    since dividing it would turn the whole analysis into NaNs.
    Default value:  false.
    */
    bool unitVariance = false;

    /**
    Overwrite input matrix instead of copying.  Ignored if the matrix
    passed in does not have assignable, lvalue elements and centering or
    scaling is enabled.  Default value:  false.
    */
    bool destructive = false;

    /**
    Effectively transpose the matrix.  If enabled, treat each column as a
    data point and each row as a dimension.  If disabled, do the opposite.
    Note that, if this is enabled, each row will be scaled and centered,
    not each column.  Default value:  false.
    */
    bool transpose = false;

    /**
    Relative error at which to stop the optimization procedure.  Default: 1e-4
    */
    double relError = 1.0e-4;

    /**
    Absolute error at which to stop the optimization procedure.  Default:  1e-5
    */
    double absError = 1.0e-5;

    /**
    Maximum iterations for the optimization procedure.  After this many
    iterations, the algorithm gives up and calls the solution "good enough"
    no matter what.  For exploratory analyses, "good enough" solutions
    can sometimes be had quickly by making this value small.
    Default:  uint.max
    */
    uint maxIter = uint.max;

    // Centers (and, if unitVariance is set, scales) each ROW of data in
    // place.  Only called from doCenterScale, which has already checked that
    // some normalization was requested.
    private void doCenterScaleTransposed(R)(R data) {
        foreach(row; data.save) {
            immutable msd = meanStdev(row.save);

            foreach(ref elem; row) {
                elem -= msd.mean;

                // stdev > 0 is false for 0 and NaN:  such rows stay centered
                // instead of becoming 0/0 = NaN.
                if(unitVariance && msd.stdev > 0) elem /= msd.stdev;
            }
        }
    }

    // Centers (and, if unitVariance is set, scales) each column of data in
    // place, or each row when transpose is set.  Throws if the rows of data
    // do not all have the same length (non-transposed path only).
    private void doCenterScale(R)(R data) {
        if(!zeroMean && !unitVariance) return;
        if(data.empty) {
            return;
        }

        if(transpose) return doCenterScaleTransposed(data);

        mixin(newFrame);
        immutable rowLen = walkLength(data.front.save);

        // Pass 1:  per-column mean/stdev accumulators, plus a shape check.
        auto summs = newStack!MeanSD(rowLen);
        summs[] = MeanSD.init;
        foreach(row; data.save) {
            size_t i = 0;
            foreach(elem; row) {
                enforce(i < rowLen, "Matrix must be rectangular for PCA.");
                summs[i++].put(elem);
            }

            enforce(i == rowLen, "Matrix must be rectangular for PCA.");
        }

        // Pass 2:  apply the transformation in place.  data.save is used for
        // both passes so that a reference-type range is not consumed by the
        // first one.
        foreach(row; data.save) {
            size_t i = 0;
            foreach(ref elem; row) {
                elem -= summs[i].mean;

                // Skip scaling for constant (stdev == 0) or undefined (NaN)
                // columns; dividing would produce NaN.
                if(unitVariance && summs[i].stdev > 0) elem /= summs[i].stdev;
                i++;
            }
        }
    }
}
150
151
/**
Computes the first principal component of data.  The component is found
iteratively (power iteration on the cross-product matrix, i.e. NIPALS for a
single component).  Behavior is controlled by a PrinCompOptions struct; see
its documentation for the available options and PrinCompOptions.init for
the defaults.  Pass buf to have the result written into pre-allocated
arrays when they are large enough.
*/
PrincipalComponent firstComponent(Ror)(
    Ror data,
    PrinCompOptions opts = PrinCompOptions.init,
    PrincipalComponent buf = PrincipalComponent.init
) {
    mixin(newFrame);

    // Duplicates data as a double[][] (transposing it when opts.transpose is
    // set), centers/scales the duplicate, and solves on it.  The input is
    // left untouched.
    PrincipalComponent runOnCopy() {
        double[][] work;
        if(!opts.transpose) {
            work = tempdup(map!doubleTempdup(data));
        } else {
            work = transposeDup(data);
        }

        // work already has the data points as its rows.
        opts.transpose = false;
        opts.doCenterScale(work);
        return firstComponentImpl(work, buf, opts);
    }

    // Can the elements of each row be written to in place?
    enum inPlaceOk = hasLvalueElements!(ElementType!Ror) &&
                     hasAssignableElements!(ElementType!Ror);

    static if(inPlaceOk) {
        if(opts.destructive) {
            opts.doCenterScale(data);
            return firstComponentImpl(data, buf, opts);
        }
        return runOnCopy();
    } else {
        // Read-only input:  only copy if it actually has to be modified.
        if(!opts.zeroMean && !opts.unitVariance) {
            return firstComponentImpl(data, buf, opts);
        }
        return runOnCopy();
    }
}
196
/*
Power-iteration core shared by firstComponent() and firstNComponents().

data must already have been copied/centered/scaled as requested.
opts.transpose says whether the rows (false) or the columns (true) of data
are the data points; call the resulting effective data matrix M.  Each
iteration computes t = M' * M * p, normalizes it, and stops when t and p
agree within opts.relError / opts.absError, when either holds a non-finite
value, or after opts.maxIter iterations.

The returned arrays reuse buf.rotation / buf.x when they are long enough;
otherwise they are freshly GC-allocated.
*/
private PrincipalComponent firstComponentImpl(Ror)(
    Ror data,
    PrincipalComponent buf,
    PrinCompOptions opts
) {
    mixin(newFrame);

    if(data.empty) return typeof(return).init;
    size_t rowLen = walkLength(data.front.save);
    size_t colLen = walkLength(data.save);

    immutable transposed = opts.transpose;
    if(transposed) swap(rowLen, colLen);

    // From here on rowLen is the length of the loading vector p and colLen
    // is the number of data points (the length of the score vector x).
    auto t = newStack!double(rowLen);
    auto p = (buf.rotation.length >= rowLen) ?
              buf.rotation[0..rowLen] : new double[rowLen];
    p[] = 1;

    // Index of the next standard basis vector to restart from if p ends up
    // in the null space of M' * M (see the tMagnitude == 0 case below).
    size_t nextBasis = 0;

    bool approxEqualOrNotFinite(const double[] a, const double[] b) {
        foreach(i; 0..a.length) {
            if(!isFinite(a[i]) || !isFinite(b[i])) {
                return true;
            } else if(!approxEqual(a[i], b[i], opts.relError, opts.absError)) {
                return false;
            }
        }

        return true;
    }

    foreach(iter; 0..opts.maxIter) {
        t[] = 0;

        if(transposed) {
            // M = data', so t = data * (data' * p).
            auto dps = newStack!double(colLen);
            scope(exit) TempAlloc.free();
            dps[] = 0;

            size_t i = 0;
            foreach(row; data.save) {
                scope(exit) i++;

                static if(is(typeof(row) : const(double)[])) {
                    // Take advantage of array ops.
                    dps[] += p[i] * row[];
                } else {
                    size_t j = 0;
                    foreach(elem; row) {
                        scope(exit) j++;
                        dps[j] += p[i] * elem;
                    }
                }
            }

            i = 0;
            foreach(row; data.save) {
                scope(exit) i++;
                t[i] += dotProduct(row, dps);
            }

        } else {
            // M = data, so t = data' * (data * p), accumulated row by row.
            foreach(row; data.save) {
                immutable dp = dotProduct(p, row);
                static if( is(typeof(row) : const(double)[] )) {
                    // Use array op optimization if possible.
                    t[] += row[] * dp;
                } else {
                    size_t i = 0;
                    foreach(elem; row.save) {
                        t[i++] += elem * dp;
                    }
                }
            }
        }

        immutable tMagnitude = magnitude(t);

        if(tMagnitude == 0) {
            // t is zero (or underflows to zero), so p is orthogonal to every
            // data point.  This happens e.g. with the all-ones start when
            // every centered data point sums to zero.  Normalizing would give
            // 0/0 = NaN, which the convergence test would accept as final.
            // Restart from the next standard basis vector instead.  If all of
            // them are exhausted, M is entirely zero and any unit vector is
            // as good a loading as any other.
            if(nextBasis >= rowLen) break;
            p[] = 0;
            p[nextBasis++] = 1;
            continue;
        }

        t[] /= tMagnitude;

        if(approxEqualOrNotFinite(t, p)) {
            p[] = t[];
            break;
        }

        p[] = t[];
    }

    // Scores:  x = M * p.
    auto x = (buf.x.length >= colLen) ?
              buf.x[0..colLen] : new double[colLen];

    if(transposed) {
        x[] = 0;

        size_t rowIndex = 0;
        foreach(row; data) {
            scope(exit) rowIndex++;
            size_t colIndex = 0;

            foreach(elem; row) {
                scope(exit) colIndex++;
                x[colIndex] += p[rowIndex] * elem;
            }
        }

    } else {
        size_t i = 0;
        foreach(row; data) {
            x[i++] = dotProduct(p, row);
        }
    }

    return PrincipalComponent(x, p);
}
311
/**
Flag for removeComponent():  says whether the rotation vector is a loading
of the transposed matrix (yes) or of the matrix as given (no).
*/
enum Transposed : bool {
    yes = true,  /// Rotation is a loading of the transposed matrix.
    no = false   /// Rotation is a loading of the matrix as given.
}
321
/**
Remove the principal component specified by the given rotation vector from
data, i.e. replace every data point v by v - r * dot(r, v) / |r|^2, which
is its projection onto the subspace orthogonal to the rotation vector r.
data must have assignable elements.  transposed controls whether rotation
is considered a loading for the transposed matrix (each column of data is a
data point) or for the matrix as-is (each row is a data point).

rotation does not need to be normalized.  A zero rotation vector, or empty
data, leaves data unchanged.
*/
void removeComponent(Ror, R)(
    Ror data,
    R rotation,
    Transposed transposed = Transposed.no
) {
    if(data.empty) return;

    immutable rotMag = magnitude(rotation.save);

    // A zero direction has no component to remove (and 1 / 0 would fill
    // data with NaNs).
    if(rotMag == 0) return;

    // The projection needs 1 / |r|^2.  Dividing by |r| alone is only correct
    // when r happens to have unit length.
    immutable invMagSq = 1.0 / (rotMag * rotMag);

    if(transposed) {
        mixin(newFrame);

        // dps[j] = dot(rotation, column j of data) / |rotation|^2
        auto dps = newStack!double(walkLength(data.front.save));
        dps[] = 0;

        auto r2 = rotation.save;
        foreach(row; data.save) {
            scope(exit) r2.popFront();

            size_t j = 0;

            foreach(elem; row) {
                scope(exit) j++;
                dps[j] += r2.front * elem;
            }
        }

        dps[] *= invMagSq;

        // column j -= rotation * dps[j]
        r2 = rotation.save;
        foreach(row; data.save) {
            scope(exit) r2.popFront();

            auto rs = row.save;
            for(size_t j = 0; !rs.empty; rs.popFront, j++) {
                rs.front = rs.front - r2.front * dps[j];
            }
        }

    } else {
        foreach(row; data.save) {
            // row -= rotation * dot(rotation, row) / |rotation|^2
            immutable coeff = dotProduct(rotation.save, row.save) * invMagSq;

            auto rs = row.save;
            auto rots = rotation.save;
            for(; !rs.empty && !rots.empty; rs.popFront(), rots.popFront()) {
                rs.front = rs.front - rots.front * coeff;
            }
        }
    }
}
384
/**
Computes the first n principal components of data.  This is more efficient
than calling firstComponent and removeComponent repeatedly, because copying
and transposing (if enabled) happen only once.  Options work as for
firstComponent.  Pass buf to reuse previously allocated result storage.
*/
PrincipalComponent[] firstNComponents(Ror)(
    Ror data,
    uint n,
    PrinCompOptions opts = PrinCompOptions.init,
    PrincipalComponent[] buf = null
) {
    mixin(newFrame);

    // Deflation between components writes into the matrix, so unless the
    // caller allowed data to be overwritten, work on a double[][] duplicate
    // (transposed if requested) and leave the input untouched.
    PrincipalComponent[] runOnCopy() {
        double[][] work;
        if(!opts.transpose) {
            work = tempdup(map!doubleTempdup(data));
        } else {
            work = transposeDup(data);
        }

        // work already has the data points as its rows.
        opts.transpose = false;
        opts.doCenterScale(work);
        return firstNComponentsImpl(work, n, opts, buf);
    }

    // Can the elements of each row be written to in place?
    enum inPlaceOk = hasLvalueElements!(ElementType!Ror) &&
                     hasAssignableElements!(ElementType!Ror);

    static if(inPlaceOk) {
        if(opts.destructive) {
            opts.doCenterScale(data);
            return firstNComponentsImpl(data, n, opts, buf);
        }
    }

    return runOnCopy();
}
425
// Extracts n components by alternately finding the leading component and
// removing it from data (which is modified in place).  Any copying and
// scaling was already done by firstNComponents.  zeroMean is left as given,
// so each firstComponent call re-centers data before solving.
private PrincipalComponent[] firstNComponentsImpl(Ror)(Ror data, uint n,
    PrinCompOptions opts, PrincipalComponent[] buf = null) {

    // data is already a private copy (or the caller asked for destruction),
    // and it has already been scaled.
    opts.unitVariance = false;
    opts.destructive = true;

    buf.length = n;
    foreach(comp, ref slot; buf) {
        if(comp > 0) {
            // Deflate:  strip the previous component before finding the next.
            removeComponent(data, buf[comp - 1].rotation,
                cast(Transposed) opts.transpose);
        }

        slot = firstComponent(data, opts, slot);
    }

    return buf;
}
444
// Euclidean (L2) norm of the range x.
private double magnitude(R)(R x) {
    double sumSq = 0.0;
    foreach(elem; x) {
        sumSq += elem * elem;
    }
    return sqrt(sumSq);
}
448
// Converts one row (any range whose elements to!double accepts) into a
// double[] allocated by tempdup -- presumably on the TempAlloc stack, so
// the caller's frame owns it (TODO confirm against dstats.alloc).
double[] doubleTempdup(R)(R range) {
    auto asDoubles = map!(to!double)(range);
    return tempdup(asDoubles);
}
453
// Returns a newStack-allocated transpose of data:  result[c][r] == data[r][c].
// Returns null for empty data and throws (via dstatsEnforce) if the rows do
// not all have the same length.
private double[][] transposeDup(Ror)(Ror data) {
    if(data.empty) {
        return null;
    }

    immutable nCols = walkLength(data.front.save);
    immutable nRows = walkLength(data.save);

    auto result = newStack!(double[])(nCols);
    foreach(ref column; result) {
        column = newStack!double(nRows);
    }

    size_t r = 0;
    foreach(row; data) {
        if(r == nRows) {
            // More rows than counted above; the count mismatch fails below.
            r++;
            break;
        }

        size_t c = 0;
        foreach(val; row) {
            if(c == nCols) {
                // Row is too long; the count mismatch fails below.
                c++;
                break;
            }
            result[c++][r] = val;
        }

        dstatsEnforce(c == nCols, "Matrices must be rectangular for PCA.");
        r++;
    }

    dstatsEnforce(r == nRows, "Matrices must be rectangular for PCA.");
    return result;
}
481
version(unittest) {
    // A principal component is only defined up to its sign, so lhs matches
    // if it is approximately equal to either rhs or -rhs.
    bool plusMinusAe(T, U)(T lhs, U rhs) {
        if(approxEqual(lhs, rhs)) {
            return true;
        }
        return approxEqual(lhs, map!"-a"(rhs));
    }

    // Empty entry point, presumably so this module can be built and
    // unit-tested on its own.
    void main() {}
}
491
unittest {
    // Reference values were produced with R's prcomp().  Only the first three
    // components are checked; the fourth is mostly numerical fuzz.

    static double[][] getMat() {
        return [[3,6,2,4], [3,6,8,8], [6,7,5,3], [0,9,3,1]];
    }

    // Asserts that component k of comps has the expected scores and
    // loadings, up to an overall sign.
    static void checkComp(PrincipalComponent[] comps, size_t k,
                          double[] expectedX, double[] expectedRot) {
        assert(plusMinusAe(comps[k].x, expectedX));
        assert(plusMinusAe(comps[k].rotation, expectedRot));
    }

    // Asserts that two lists of components agree pairwise, up to sign.
    static void sameComps(PrincipalComponent[] lhs, PrincipalComponent[] rhs) {
        assert(lhs.length == rhs.length);
        foreach(idx, ref comp; lhs) {
            assert(plusMinusAe(comp.x, rhs[idx].x));
            assert(plusMinusAe(comp.rotation, rhs[idx].rotation));
        }
    }

    // Default options:  centered, unscaled, not transposed.
    auto mat = getMat();
    auto centered = firstNComponents(mat, 3);
    checkComp(centered, 0, [1.19, -5.11, -0.537, 4.45],
                           [-0.314, 0.269, -0.584, -0.698]);
    checkComp(centered, 1, [0.805, -1.779, 2.882, -1.908],
                           [0.912, -0.180, -0.2498, -0.2713]);
    checkComp(centered, 2, [2.277, -0.1055, -1.2867, -0.8849],
                           [-0.1578, -0.5162, -0.704, 0.461]);
    sameComps([firstComponent(mat)], centered[0 .. 1]);

    // Transposed, on read-only input.
    PrinCompOptions opts;
    opts.transpose = true;
    const(double)[][] constMat = mat;
    auto transposed = firstNComponents(constMat, 3, opts);
    checkComp(transposed, 0, [-3.2045, 6.3829695, -0.7227162, -2.455],
                             [0.3025, 0.05657, 0.25142, 0.91763]);
    checkComp(transposed, 1, [-3.46136, -0.6365, 1.75111, 2.3468],
                             [-0.06269096,  0.88643747, -0.4498119, 0.08926183]);
    checkComp(transposed, 2,
        [2.895362e-03,  3.201053e-01, -1.631345e+00,  1.308344e+00],
        [0.87140678, -0.14628160, -0.4409721, -0.15746595]);
    sameComps([firstComponent(constMat, opts)], transposed[0 .. 1]);

    // Scaled to unit variance.
    opts.unitVariance = true;
    opts.transpose = false;
    auto scaled = firstNComponents(mat, 3, opts);
    checkComp(scaled, 0,
        [6.878307e-02, -1.791647e+00, -3.733826e-01,  2.096247e+00],
        [-0.3903603,  0.5398265, -0.4767623, -0.5735014]);
    checkComp(scaled, 1,
        [6.804833e-01, -9.412491e-01,  9.231432e-01, -6.623774e-01],
        [0.7355678, -0.2849885, -0.5068900, -0.3475401]);
    checkComp(scaled, 2,
        [9.618048e-01,  1.428492e-02, -8.120905e-01, -1.639992e-01],
        [-0.4925027, -0.5721616, -0.5897120, 0.2869006]);
    sameComps([firstComponent(constMat, opts)], scaled[0 .. 1]);

    // Scaled and transposed.
    opts.transpose = true;
    auto scaledT = firstNComponents(mat, 3, opts);
    checkComp(scaledT, 0,
        [-1.419319e-01,  2.141908e+00, -8.368606e-01, -1.163116e+00],
        [0.5361711, -0.2270814,  0.5685768,  0.5810981]);
    checkComp(scaledT, 1,
        [-1.692899e+00,  4.929717e-01,  3.049089e-01,  8.950189e-01],
        [0.3026505,  0.7906601, -0.3652524,  0.3871047]);
    checkComp(scaledT, 2,
        [ 2.035977e-01,  2.705193e-02, -9.113051e-01,  6.806556e-01],
        [0.7333168, -0.3396207, -0.4837054, -0.3360555]);
    sameComps([firstComponent(constMat, opts)], scaledT[0 .. 1]);

    // Destructive mode must reproduce every non-destructive result.  mat
    // has not been modified yet, so the first run can use it directly.
    opts.destructive = true;
    auto inPlace = firstNComponents(mat, 3, opts);
    sameComps(scaledT, inPlace);
    sameComps([firstComponent(getMat(), opts)], inPlace[0 .. 1]);

    // Reloads mat (the previous destructive run overwrote it), runs with the
    // current opts and compares against the expected results.
    void checkInPlace(PrincipalComponent[] expected) {
        mat = getMat();
        auto got = firstNComponents(mat, 3, opts);
        sameComps(got, expected);
        sameComps([firstComponent(getMat(), opts)], got[0 .. 1]);
    }

    opts.transpose = false;
    checkInPlace(scaled);

    opts.unitVariance = false;
    checkInPlace(centered);

    opts.transpose = true;
    checkInPlace(transposed);
}
Note: See TracBrowser for help on using the browser.