root/trunk/etc/bigint/multiply.d

Revision 12, 8.2 kB (checked in by Arcane Jill, 5 years ago)

--

Line 
1 module etc.bigint.multiply;
2 import etc.bigint.lowlevel;
3
4 /*
5
6 Copyright (c) 2004, Arcane Jill
7
8 All rights reserved. Intellectual Property Me Arse!
9
10 Redistribution and use in source and binary forms, with or without modification, are permitted
11 provided that the following conditions are met:
12
13    * Redistributions of source code must retain the above copyright notice, the phrase
14      "Intellectual Property Me Arse!", this list of conditions, and the following disclaimer.
15    * Redistributions in binary form must reproduce the above copyright notice, the phrase
16      "Intellectual Property Me Arse!", this list of conditions and the following disclaimer
17      in the documentation and/or other materials provided with the distribution.
18    * The name Arcane Jill may not be used to endorse or promote products derived from this
19      software without specific prior written permission.
20
21 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS
22 OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY
23 AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER,
24 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
25 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED, AND ON ANY
27 THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
28 OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
29 OF SUCH DAMAGE.
30
31 */
32
33 version(DisableKaratsuba)
34 {
35 }
36 else
37 {
38     const int KARATSUBA_THRESHOLD = 10;
39 }
40
41 /*  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
42 d = x * y
43 */
44
45 void bigintMul(uint* d, uint D, uint* x, uint X, uint* y, uint Y)
46 in
47 {
48     assert(D >= X + Y);
49 }
50 body
51 {
52     // Zero the destination
53     bigintLLZero(d, D);
54
55     // Do the multiply
56     bigintMulInternal(d, x, X, y, Y);
57 }
58
59 private
60 {
61     void bigintMulInternal(uint* d, uint* x, uint X, uint* y, uint Y)
62     {
63         // Get things as simple as possible
64         X = bigintLLMinimize(x, X);
65         Y = bigintLLMinimize(y, Y);
66
67         // Get the biggest of the two parameters into x
68         if (X == Y)
69         {
70             version(DisableSquare)
71             {
72                 // if squaring algorithm is disabled, do nothing
73             }
74             else
75             {
76                 if (bigintLLEquals(x, y, X))
77                 {
78                     bigintSquareInternal(d, x, X);
79                     return;
80                 }
81             }
82         }
83         else if (X < Y)
84         {
85             uint* t = x;
86             x = y;
87             y = t;
88             uint tLen = X;
89             X = Y;
90             Y = tLen;
91         }
92
93         // Decide which algorithm to use...
94         version(DisableKaratsuba)
95         {
96             bigintMulClassic(d, x, X, y, Y);
97         }
98         else
99         {
100             if (X+Y >= KARATSUBA_THRESHOLD)
101             {
102                 bigintMulKaratsuba(d, x, X, y, Y);
103             }
104             else
105             {
106                 bigintMulClassic(d, x, X, y, Y);
107             }
108         }
109
110     }
111
112     // The classic multiply algorithm
113     void bigintMulClassic(uint* d, uint* x, uint X, uint* y, uint Y)
114     {
115         uint[] t;
116         t.length = Y + 1;
117         for (int i=0; i<X; ++i)
118         {
119             uint k = *x++;
120             if (k != 0)
121             {
122                 t[Y] = bigintLLMul(t, y, k, Y);
123                 bigintLLAdd(d, d, t, Y+1);
124             }
125             ++d;
126         }
127         t[] = 0;
128     }
129
130     version(DisableKaratsuba)
131     {
132         // Don't need to compile the algorithm if we're not going to use it
133     }
134     else
135     {
136         // The Karatsuba multiply algorithm
137         void bigintMulKaratsuba(uint* d, uint* x, uint X, uint* y, uint Y)
138         out
139         {
140             uint[] check;
141             check.length = X+Y;
142             bigintMulClassic(check, x, X, y, Y);
143             assert(bigintLLEquals(d, check, X+Y));
144             check[] = 0;
145         }
146         body
147         {
148             // Find the split point
149             uint L = X >>> 1;
150             if (L > Y) L = Y;
151
152             // Precalculate some numbers
153             uint T0 = L + L;
154             uint T1 = X + Y - L - L;
155             uint TX = X - L > L ? X - L + 1 : L + 1;
156             uint TY = Y - L > L ? Y - L + 1 : L + 1;
157             uint TXY = TX + TY;
158             uint X0 = L;
159             uint Y0 = L;
160             uint X1 = X - X0;
161             uint Y1 = Y - Y0;
162
163             // Make a scratch-buffer and make some pointers into it
164             uint[] t;
165             t.length = T0 + T1 + TX + TY + TXY;
166             uint* t0 = t;
167             uint* t1 = t0 + T0;
168             uint* tx = t1 + T1;
169             uint* ty = tx + TX;
170             uint* txy = ty + TY;
171
172             // Even more pointers
173             uint* x0 = x;
174             uint* x1 = x + L;
175             uint* y0 = y;
176             uint* y1 = y + L;
177
178             // Do the low part
179             bigintMulInternal(t0, x0, X0, y0, Y0);
180
181             // Do the high part
182             bigintMulInternal(t1, x1, X1, y0, Y1);
183
184             // Calculate the temporary results
185             tx[X1] = bigintLLAdd(tx, x1, X1, x0, X0);
186             TX = bigintLLMinimize(tx, TX);
187             if (Y0 > Y1)
188             {
189                 ty[Y0] = bigintLLAdd(ty, y0, Y0, y1, Y1);
190             }
191             else
192             {
193                 ty[Y1] = bigintLLAdd(ty, y1, Y1, y0, Y0);
194             }
195             TY = bigintLLMinimize(ty, TY);
196             bigintMulInternal(txy, tx, TX, ty, TY);
197             TXY = bigintLLMinimize(txy, TXY);
198
199             // Add all the bits together
200             d[0..L+L] = t0[0..L+L];
201             d[L+L..X+Y] = t1[0..X+Y-L-L];
202
203             bigintLLAdd(d+L, d+L, X+Y-L, txy, TXY);
204             bigintLLSub(d+L, d+L, X+Y-L, t0, T0);
205             bigintLLSub(d+L, d+L, X+Y-L, t1, T1);
206
207             // All done. That was easy!
208             t[] = 0;
209         }
210     }
211 }
212
213 /*  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
214 d = x * x
215 */
216
217 void bigintSquare(uint* d, uint D, uint* x, uint X)
218 in
219 {
220     assert(D >= X + X);
221 }
222 body
223 {
224     version(DisableSquare)
225     {
226         // If we can't use the square algorithm, do it the hard way
227         bigintMul(d, D, x, X, x, X);
228     }
229     else
230     {
231         // Zero the destination
232         bigintLLZero(d, D);
233
234         // Do the multiply
235         bigintSquareInternal(d, x, X);
236     }
237 }
238
239 version(DisableSquare)
240 {
241     // Don't compile the algorithm if we don't need it
242 }
243 else
244 {
245     private
246     {
247         void bigintSquareInternal(uint* d, uint* x, uint X)
248         {
249             // Get things as simple as possible
250             X = bigintLLMinimize(x, X);
251
252             // Decide which algorithm to use...
253             version(DisableKaratsuba)
254             {
255                 bigintSquareClassic(d, x, X);
256             }
257             else
258             {
259                 if (X+X >= KARATSUBA_THRESHOLD)
260                 {
261                     bigintSquareKaratsuba(d, x, X);
262                 }
263                 else
264                 {
265                     bigintSquareClassic(d, x, X);
266                 }
267             }
268         }
269
270         void bigintSquareClassic(uint* d, uint* x, uint X)
271         out
272         {
273             uint[] check;
274             check.length = X+X;
275             bigintMulClassic(check, x, X, x, X);
276             assert(bigintLLEquals(d, check, X+X));
277             check[] = 0;
278         }
279         body
280         {
281             uint[] t;
282             t.length = X + X;
283             for (int i=1; i<X; ++i)
284             {
285                 t[i] = bigintLLMul(t, x, x[i], i);
286                 bigintLLAdd(d+i, d+i, t, X+1);
287             }
288             bigintLLAdd(d,d,d,X+X);
289             for (int i=0; i<X; ++i)
290             {
291                 uint dHi, dLo;
292                 version(X86)
293                 {
294                     uint xi = x[i];
295                     asm
296                     {
297                         mov EAX,xi;
298                         mov EBX,EAX;
299                         mul EBX;
300                         mov dHi,EDX;
301                         mov dLo,EAX;
302                     }
303                 }
304                 else
305                 {
306                     ulong xi = x[i];
307                     ulong di = xi * xi;
308                     dHi = cast(uint) (di >> 32);
309                     dLo = cast(uint) di;
310                 }
311                 t[i+i] = dLo;
312                 t[i+i+1] = dHi;
313             }
314             bigintLLAdd(d,d,t,X+X);
315             t[] = 0;
316         }
317
318         void bigintSquareKaratsuba(uint* d, uint* x, uint X)
319         out
320         {
321             uint[] check;
322             check.length = X+X;
323             bigintSquareClassic(check, x, X);
324             assert(bigintLLEquals(d, check, X+X));
325             check[] = 0;
326         }
327         body
328         {
329             // Find the split point
330             uint L = X >>> 1;
331
332             // Precalculate some numbers
333             uint T0 = L + L;
334             uint T1 = X + X - L - L;
335             uint TX = X - L + 1;
336             uint TXX = TX + TX;
337             uint X0 = L;
338             uint X1 = X - X0;
339
340             // Make a scratch-buffer and make some pointers into it
341             uint[] t;
342             t.length = T0 + T1 + TX + TXX;
343             uint* t0 = t;
344             uint* t1 = t0 + T0;
345             uint* tx = t1 + T1;
346             uint* txx = tx + TX;
347
348             // Even more pointers
349             uint* x0 = x;
350             uint* x1 = x + L;
351
352             // Do the low part
353             bigintSquareInternal(t0, x0, X0);
354
355             // Do the high part
356             bigintSquareInternal(t1, x1, X1);
357
358             // Calculate the temporary results
359             tx[X1] = bigintLLAdd(tx, x1, X1, x0, X0);
360             TX = bigintLLMinimize(tx, TX);
361             bigintSquareInternal(txx, tx, TX);
362             TXX = bigintLLMinimize(txx, TXX);
363
364             // Add all the bits together
365             d[0..L+L] = t0[0..L+L];
366             d[L+L..X+X] = t1[0..X+X-L-L];
367
368             bigintLLAdd(d+L, d+L, X+X-L, txx, TXX);
369             bigintLLSub(d+L, d+L, X+X-L, t0, T0);
370             bigintLLSub(d+L, d+L, X+X-L, t1, T1);
371
372             // All done. That was easy!
373             t[] = 0;
374         }
375
376     }
377 }
Note: See TracBrowser for help on using the browser.