topN: define behavior for nth >= r.length, improve speed

This commit is contained in:
Andrei Alexandrescu 2016-01-14 19:29:58 -05:00
parent ac96a3d3c0
commit f259c22987

View file

@ -2163,7 +2163,7 @@ $(D !less(e2, r[nth])). Effectively, it finds the nth smallest
$(BIGOH r.length) (if unstable) or $(BIGOH r.length * log(r.length)) $(BIGOH r.length) (if unstable) or $(BIGOH r.length * log(r.length))
(if stable) evaluations of $(D less) and $(D swap). (if stable) evaluations of $(D less) and $(D swap).
If $(D n >= r.length), the algorithm has no effect. If $(D n >= r.length), the algorithm has no effect and returns `r[0 .. $]`.
Params: Params:
less = The predicate to sort by. less = The predicate to sort by.
@ -2185,21 +2185,30 @@ auto topN(alias less = "a < b",
Range)(Range r, size_t nth) Range)(Range r, size_t nth)
if (isRandomAccessRange!(Range) && hasLength!Range) if (isRandomAccessRange!(Range) && hasLength!Range)
{ {
import std.algorithm : swap; // FIXME
import std.random : uniform;
static assert(ss == SwapStrategy.unstable, static assert(ss == SwapStrategy.unstable,
"Stable topN not yet implemented"); "Stable topN not yet implemented");
if (nth >= r.length) return r[0 .. $];
auto ret = r[0 .. nth]; auto ret = r[0 .. nth];
while (r.length > nth) for (;;)
{ {
assert(nth < r.length);
import std.algorithm.mutation : swap;
import std.algorithm.searching : minPos;
if (nth == 0) if (nth == 0)
{ {
// Special-case "min" // Special-case "min"
import std.algorithm.searching : minPos;
swap(r.front, r.minPos!less.front); swap(r.front, r.minPos!less.front);
break; break;
} }
if (nth + 1 == r.length)
{
// Special-case "max"
swap(r.back, r.minPos!((a, b) => binaryFun!less(b, a)).front);
break;
}
auto pivot = r.getPivot!less; auto pivot = r.getPivot!less;
assert(!binaryFun!less(r[pivot], r[pivot])); assert(!binaryFun!less(r[pivot], r[pivot]));
swap(r[pivot], r.back); swap(r[pivot], r.back);
@ -2231,6 +2240,8 @@ auto topN(alias less = "a < b",
@safe unittest @safe unittest
{ {
int[] v = [ 25, 7, 9, 2, 0, 5, 21 ]; int[] v = [ 25, 7, 9, 2, 0, 5, 21 ];
topN!"a < b"(v, 100);
assert(v == [ 25, 7, 9, 2, 0, 5, 21 ]);
auto n = 4; auto n = 4;
topN!"a < b"(v, n); topN!"a < b"(v, n);
assert(v[n] == 9); assert(v[n] == 9);