faster pairwise summation

This commit is contained in:
John Colvin 2016-03-10 18:29:31 +00:00
parent f69a7d9536
commit 678a511ff2

View file

@ -3597,7 +3597,7 @@ private struct SplitterResult(alias isTerminator, Range)
private Range _input; private Range _input;
private size_t _end = 0; private size_t _end = 0;
static if(!fullSlicing) static if (!fullSlicing)
private Range _next; private Range _next;
private void findTerminator() private void findTerminator()
@ -3619,7 +3619,7 @@ private struct SplitterResult(alias isTerminator, Range)
this(Range input) this(Range input)
{ {
_input = input; _input = input;
static if(!fullSlicing) static if (!fullSlicing)
_next = _input.save; _next = _input.save;
if (!_input.empty) if (!_input.empty)
@ -3869,7 +3869,7 @@ if (isSomeChar!C)
{ {
import std.algorithm.comparison : equal; import std.algorithm.comparison : equal;
import std.meta : AliasSeq; import std.meta : AliasSeq;
foreach(S; AliasSeq!(string, wstring, dstring)) foreach (S; AliasSeq!(string, wstring, dstring))
{ {
import std.conv : to; import std.conv : to;
S a = " a bcd ef gh "; S a = " a bcd ef gh ";
@ -4045,7 +4045,10 @@ if (isInputRange!R && !isInfinite!R && is(typeof(seed = seed + r.front)))
static if (isFloatingPoint!E) static if (isFloatingPoint!E)
{ {
static if (hasLength!R && hasSlicing!R) static if (hasLength!R && hasSlicing!R)
{
if (r.empty) return seed;
return seed + sumPairwise!E(r); return seed + sumPairwise!E(r);
}
else else
return sumKahan!E(seed, r); return sumKahan!E(seed, r);
} }
@ -4056,16 +4059,97 @@ if (isInputRange!R && !isInfinite!R && is(typeof(seed = seed + r.front)))
} }
// Pairwise summation http://en.wikipedia.org/wiki/Pairwise_summation // Pairwise summation http://en.wikipedia.org/wiki/Pairwise_summation
private auto sumPairwise(Result, R)(R r) private auto sumPairwise(F, R)(R data)
if (isInputRange!R && !isInfinite!R)
{ {
static assert (isFloatingPoint!Result); import core.bitop : bsf;
switch (r.length) // Works for r with at least length < 2^^(64 + log2(16)), in keeping with the use of size_t
// elsewhere in std.algorithm and std.range on 64 bit platforms. The 16 in log2(16) comes
// from the manual unrolling in sumPairWise16
F[64] store = void;
size_t idx = 0;
auto collapseStore(T)(T k)
{ {
case 0: return cast(Result) 0; auto lastToKeep = idx - cast(uint)bsf(k+1);
case 1: return cast(Result) r.front; while (idx > lastToKeep)
case 2: return cast(Result) r[0] + cast(Result) r[1]; {
default: return sumPairwise!Result(r[0 .. $ / 2]) + sumPairwise!Result(r[$ / 2 .. $]); store[idx - 1] += store[idx];
--idx;
}
} }
static if (hasLength!R)
{
foreach (k; 0 .. data.length / 16)
{
static if (isRandomAccessRange!R && hasSlicing!R)
{
store[idx] = sumPairwise16!F(data);
data = data[16 .. $];
}
else store[idx] = sumPairwiseN!(16, false, F)(data);
collapseStore(k);
++idx;
}
size_t i = 0;
foreach (el; data)
{
store[idx] = el;
collapseStore(i);
++idx;
++i;
}
}
else
{
size_t k = 0;
while (!data.empty)
{
store[idx] = sumPairwiseN!(16, true, F)(data);
collapseStore(k);
++idx;
++k;
}
}
F s = store[idx - 1];
foreach_reverse (j; 0 .. idx - 1)
s += store[j];
return s;
}
private auto sumPairwise16(F, R)(R r)
if (isRandomAccessRange!R)
{
return (((cast(F)r[ 0] + r[ 1]) + (cast(F)r[ 2] + r[ 3]))
+ ((cast(F)r[ 4] + r[ 5]) + (cast(F)r[ 6] + r[ 7])))
+ (((cast(F)r[ 8] + r[ 9]) + (cast(F)r[10] + r[11]))
+ ((cast(F)r[12] + r[13]) + (cast(F)r[14] + r[15])));
}
private auto sumPair(bool needEmptyChecks, F, R)(ref R r)
if (isForwardRange!R && !isRandomAccessRange!R)
{
static if (needEmptyChecks) if (r.empty) return F(0);
F s0 = r.front;
r.popFront();
static if (needEmptyChecks) if (r.empty) return s0;
s0 += r.front;
r.popFront();
return s0;
}
private auto sumPairwiseN(size_t N, bool needEmptyChecks, F, R)(ref R r)
if (isForwardRange!R && !isRandomAccessRange!R)
{
static assert(!(N & (N-1))); //isPow2
static if (N == 2) return sumPair!(needEmptyChecks, F)(r);
else return sumPairwiseN!(N/2, needEmptyChecks, F)(r)
+ sumPairwiseN!(N/2, needEmptyChecks, F)(r);
} }
// Kahan algo http://en.wikipedia.org/wiki/Kahan_summation_algorithm // Kahan algo http://en.wikipedia.org/wiki/Kahan_summation_algorithm
@ -4181,6 +4265,13 @@ unittest
assert(sb == (BigInt(ulong.max/2) * 10)); assert(sb == (BigInt(ulong.max/2) * 10));
} }
@safe pure nothrow @nogc unittest
{
import std.range;
foreach(n; iota(50))
assert(repeat(1.0, n).sum == n);
}
// uniq // uniq
/** /**
Lazily iterates unique consecutive elements of the given range (functionality Lazily iterates unique consecutive elements of the given range (functionality