Add std.algorithm.iteration : cumulativeSum

This commit is contained in:
e-y-e 2016-10-28 15:40:38 +01:00
parent 027839fd79
commit aa74a62937
2 changed files with 308 additions and 0 deletions

View file

@ -18,6 +18,8 @@ $(T2 cumulativeFold,
$(D cumulativeFold!((a, b) => a + b)([1, 2, 3, 4])) returns a
lazily-evaluated range containing the successive reduced values `1`,
`3`, `6`, `10`.)
$(T2 cumulativeSum,
Same as $(D cumulativeFold), but specialized for accurate summation.)
$(T2 each,
$(D each!writeln([1, 2, 3])) eagerly prints the numbers $(D 1), $(D 2)
and $(D 3) on their own lines.)
@ -3483,6 +3485,311 @@ The number of seeds must be correspondingly increased.
}
}
/++
Performs a summation of the given $(REF_ALTTEXT input range, isInputRange,
std, range, primitives) r, and provides the intermediate results of the
summation as an $(REF_ALTTEXT input range, isInputRange, std, range,
primitives). `cumulativeSum` is conceptually equivalent to
`cumulativeFold!((a, b) => a + b)`, but for floating point summations the
$(HTTP en.wikipedia.org/wiki/Kahan_summation, Kahan summation) algorithm is
used to reduce accuracy loss from cancellation errors.
When called without a seed, the seed type is deduced from the $(REF_ALTTEXT
element type, ElementType, std,range,primitives) of r and the seed value is 0.
If the $(REF_ALTTEXT element type, ElementType, std, range, primitives) of r is
a $(REF_ALTTEXT floating point type, isFloatingPoint, std, traits), then the
seed type will be deduced to be the most precise type available from either
double or real.
Params:
r = any $(REF_ALTTEXT input range, isInputRange, std, range, primitives)
s = a seed value that gives the initial value of the summation
Returns:
an $(REF_ALTTEXT input range, isInputRange, std, range, primitives)
containing the intermediate results of the summation of r.
See_Also:
$(HTTP en.wikipedia.org/wiki/Prefix_sum, Prefix Sum)
$(HREF cumulativeFold) provides the intermediate results of generic
reduction operations on ranges.
$(HREF sum) performs a summation of a range without providing intermediate
results.
+/
auto cumulativeSum(Range)(Range r)
if (isInputRange!Range && __traits(compiles, r.front + r.front))
{
static if (isFloatingPoint!(ElementType!Range))
{
// Most precise seed type available from either double or real.
alias Seed = typeof(0.0 + r.front);
}
else
{
// Deduce seed type from the result of a single addition.
alias Seed = typeof(r.front + r.front);
}
return r.cumulativeSum(Seed(0));
}
/// Ditto
auto cumulativeSum(Range, Seed)(Range r, Seed s)
if (isInputRange!Range && __traits(compiles, r.front + r.front))
{
static if (isFloatingPoint!Seed)
{
static struct Result
{
this(Range r, Seed s)
{
_r = r;
if (_r.empty) return;
_s = s;
sumFront;
}
@property
auto front()
in
{
assert(!empty,
"Attempting to fetch the front of an empty cumulativeSum");
}
body
{
return _s;
}
void popFront()
in
{
assert(!empty,
"Attempting to popFront an empty cumulativeSum");
}
body
{
_r.popFront;
if (_r.empty) return;
sumFront;
}
static if (isInfinite!Range)
{
enum empty = false;
}
else
{
@property
bool empty()
{
return _r.empty;
}
}
static if (isForwardRange!Range)
{
@property
auto save()
{
auto result = this;
result._r = _r.save;
return result;
}
}
static if (hasLength!Range)
{
@property
size_t length()
{
return _r.length;
}
}
private:
Range _r;
Seed _s;
Seed _c = 0;
void sumFront()
{
// One iteration of Kahan summation.
immutable y = _r.front - _c;
immutable t = _s + y;
_c = (t - _s) - y;
_s = t;
}
}
return Result(r, s);
}
else
{
// Default to naive summation for integral values.
return r.cumulativeFold!((a, b) => a + b)(s);
}
}
///
@safe pure nothrow
unittest
{
import std.algorithm.comparison : equal;
import std.range : iota, repeat;
// Partial sum of integral values:
assert(cumulativeSum([1, 2, 3, 4, 5]).equal([1, 3, 6, 10, 15]));
// Using ranges and UFCS:
assert(iota(1, 6).cumulativeSum.equal([1, 3, 6, 10, 15]));
// With seed value:
assert(iota(1, 6).cumulativeSum(-15).equal([-14, -12, -9, -5, 0]));
// Partial sum of floating point values:
assert(cumulativeSum([1.0, 2.0, 3.0, 4.0, 5.0])
.equal([1.0, 3.0, 6.0, 10.0, 15.0]));
// With seed value:
assert(cumulativeSum([1.0, 2.0, 3.0, 4.0, 5.0], -15.0)
.equal([-14.0, -12.0, -9.0, -5.0, 0.0]));
// Partial sum with integral promotion:
assert(cumulativeSum([false, true, true, false, true])
.equal([0, 1, 2, 2, 3]));
// The result may overflow:
assert(uint.max.repeat(3).cumulativeSum
.equal([4294967295U, 4294967294U, 4294967293U]));
// But a seed can be used to change the sumation primitive:
assert(uint.max.repeat(3).cumulativeSum(ulong.init)
.equal([4294967295UL, 8589934590UL, 12884901885UL]));
}
/++
$(D cumulativeSum) uses Kahan summation to give more accurate results than
naive summation for ranges of floating point values.
+/
@safe pure nothrow
unittest
{
import std.math : approxEqual;
// Despite summing 'large' and 'small' numbers the loss of significance is
// a non-issue.
assert(cumulativeSum([10000, 3.14159, 2.71828, 1.41421, 1.61803, -10000])
.approxEqual([10000, 10003.1, 10005.9, 10007.3, 10008.9, 8.89211]));
// Another example with a 'large' seed value.
assert(cumulativeSum([6.28318, 1.73205, 3.33333, 2.23606, -10000], 10000.0)
.approxEqual([10006.3, 10008.0, 10011.3, 10013.6, 13.5846]));
// A more extreme example.
assert(cumulativeSum([71850, 1.594e-11, 7.91182e-11, 2.36169e-11, -71850])
.approxEqual([71850, 71850, 71850, 71850, 1.18675e-10]));
}
@safe pure nothrow
unittest
{
import std.range.primitives : ElementType;
import std.algorithm.comparison : equal;
// Integral types:
static assert(is(ElementType!(typeof(cumulativeSum([cast(byte)1]))) == int));
static assert(is(ElementType!(typeof(cumulativeSum([cast(ubyte)1]))) == int));
static assert(is(ElementType!(typeof(cumulativeSum([1, 2, 3, 4]))) == int));
static assert(is(ElementType!(typeof(cumulativeSum([1U, 2U, 3U, 4U]))) == uint));
static assert(is(ElementType!(typeof(cumulativeSum([1L, 2L, 3L, 4L]))) == long));
static assert(is(ElementType!(typeof(cumulativeSum([1UL, 2UL, 3UL, 4UL]))) == ulong));
int[] empty;
assert(cumulativeSum(empty).empty);
assert(cumulativeSum([42]).equal([42]));
assert(cumulativeSum([42, 43]).equal([42, 42 + 43]));
assert(cumulativeSum([42, 43, 44]).equal([42, 42 + 43, 42 + 43 + 44]));
assert(cumulativeSum([42, 43, 44, 45])
.equal([42, 42 + 43, 42 + 43 + 44, 42 + 43 + 44 + 45]));
}
@safe pure nothrow
unittest
{
import std.range.primitives : ElementType;
import std.algorithm.comparison : equal;
// Floating point types:
static assert(is(ElementType!(typeof(cumulativeSum([1.0, 2.0, 3.0, 4.0]))) == double));
static assert(is(ElementType!(typeof(cumulativeSum([1F, 2F, 3F, 4F]))) == double));
static assert(is(ElementType!(typeof(cumulativeSum([1.0L, 2.0L, 3.0L, 4.0L]))) == real));
const(float[]) a = [1F, 2F, 3F, 4F];
static assert(is(ElementType!(typeof(cumulativeSum(a))) == double));
const(float)[] b = [1F, 2F, 3F, 4F];
static assert(is(ElementType!(typeof(cumulativeSum(b))) == double));
double[] empty;
assert(cumulativeSum(empty).empty);
assert(cumulativeSum([42.0]).equal([42]));
assert(cumulativeSum([42.0, 43.0]).equal([42, 42 + 43]));
assert(cumulativeSum([42.0, 43.0, 44.0])
.equal([42, 42 + 43, 42 + 43 + 44]));
assert(cumulativeSum([42.0, 43.0, 44.0, 45.5])
.equal([42, 42 + 43, 42 + 43 + 44, 42 + 43 + 44 + 45.5]));
}
@safe @nogc pure nothrow
unittest
{
import std.algorithm.comparison : equal;
import std.range : iota, repeat;
foreach (n; iota(50))
{
assert(repeat(1, n).cumulativeSum(-1.0).equal(iota(n)));
}
}
@safe pure nothrow
unittest
{
import std.algorithm.comparison : equal;
import std.internal.test.dummyrange : AllDummyRanges, propagatesLength,
propagatesRangeType, RangeType;
import std.range : chunks;
import std.algorithm.iteration : map, joiner;
foreach (DummyType; AllDummyRanges)
{
DummyType d;
// Test floating point values as integral values are handled by
// cumulativeFold.
auto f = d.map!(n => cast(double)n);
static if (DummyType.rt == RangeType.Forward)
{
assert(f.chunks(1).map!cumulativeSum.joiner.equal(f));
}
auto s = f.cumulativeSum;
static assert(propagatesLength!(typeof(s), DummyType));
static if (DummyType.rt <= RangeType.Forward)
{
static assert(propagatesRangeType!(typeof(s), DummyType));
}
assert(s.equal([1, 3, 6, 10, 15, 21, 28, 36, 45, 55]));
}
}
// splitter
/**
Lazily splits a range using an element as a separator. This can be used with