Fix Issue 18280 - std.algorithm.comparison.cmp for non-strings should call opCmp only once per item pair

split cmp into two overloads per @andralex

https://github.com/dlang/phobos/pull/6056#pullrequestreview-90687092

Minor adjustments, again

cmp should return auto and let opCmp drive

https://github.com/dlang/phobos/pull/6056#issuecomment-359665184

Fix Issue 18285 - std.algorithm.comparison.cmp for strings with custom predicate compares lengths wrong

Test std.algorithm.comparison.cmp when opCmp returns float

Promotions should not use cast

Optimize cmp's endgame

There are some redundant tests when the end of the ranges is reached. Eliminated that, and improved threeWayByPred.

Fix Issue 18286 - std.algorithm.comparison.cmp for string with custom predicate fails if distinct chars can compare equal

Fix Issue 18288 - std.algorithm.comparison.cmp for wide strings should be @safe

re-apply remove cast in promotions
This commit is contained in:
Nathan Sashihara 2018-01-22 17:30:46 -05:00
parent d11a3d0a5d
commit 7f59e5ad52

View file

@ -580,43 +580,64 @@ do
// cmp // cmp
/********************************** /**********************************
Performs three-way lexicographical comparison on two Performs a lexicographical comparison on two
$(REF_ALTTEXT input ranges, isInputRange, std,range,primitives) $(REF_ALTTEXT input ranges, isInputRange, std,range,primitives).
according to predicate `pred`. Iterating `r1` and `r2` in Iterating `r1` and `r2` in lockstep, `cmp` compares each element
lockstep, `cmp` compares each element `e1` of `r1` with the `e1` of `r1` with the corresponding element `e2` in `r2`. If one
corresponding element `e2` in `r2`. If one of the ranges has been of the ranges has been finished, `cmp` returns a negative value
finished, `cmp` returns a negative value if `r1` has fewer if `r1` has fewer elements than `r2`, a positive value if `r1`
elements than `r2`, a positive value if `r1` has more elements has more elements than `r2`, and `0` if the ranges have the same
than `r2`, and `0` if the ranges have the same number of number of elements.
elements.
If the ranges are strings, `cmp` performs UTF decoding If the ranges are strings, `cmp` performs UTF decoding
appropriately and compares the ranges one code point at a time. appropriately and compares the ranges one code point at a time.
A custom predicate may be specified, in which case `cmp` performs
a three-way lexicographical comparison using `pred`. Otherwise
the elements are compared using `opCmp`.
Params: Params:
pred = The predicate used for comparison. pred = Predicate used for comparison. Without a predicate
specified the ordering implied by `opCmp` is used.
r1 = The first range. r1 = The first range.
r2 = The second range. r2 = The second range.
Returns: Returns:
0 if both ranges compare equal. -1 if the first differing element of $(D `0` if the ranges compare equal. A negative value if `r1` is a prefix of `r2` or
r1) is less than the corresponding element of `r2` according to $(D the first differing element of `r1` is less than the corresponding element of `r2`
pred). 1 if the first differing element of `r2` is less than the according to `pred`. A positive value if `r2` is a prefix of `r1` or the first
corresponding element of `r1` according to `pred`. differing element of `r2` is less than the corresponding element of `r1`
according to `pred`.
Note:
An earlier version of the documentation incorrectly stated that `-1` is the
only negative value returned and `1` is the only positive value returned.
Whether that is true depends on the types being compared.
*/ */
int cmp(alias pred = "a < b", R1, R2)(R1 r1, R2 r2) auto cmp(R1, R2)(R1 r1, R2 r2)
if (isInputRange!R1 && isInputRange!R2) if (isInputRange!R1 && isInputRange!R2)
{ {
static if (!(isSomeString!R1 && isSomeString!R2)) static if (!(isSomeString!R1 && isSomeString!R2))
{ {
for (;; r1.popFront(), r2.popFront()) for (;; r1.popFront(), r2.popFront())
{ {
if (r1.empty) return -cast(int)!r2.empty; static if (is(typeof(r1.front.opCmp(r2.front)) R))
if (r2.empty) return !r1.empty; alias Result = R;
auto a = r1.front, b = r2.front; else
if (binaryFun!pred(a, b)) return -1; alias Result = int;
if (binaryFun!pred(b, a)) return 1; if (r2.empty) return Result(!r1.empty);
if (r1.empty) return Result(-1);
static if (is(typeof(r1.front.opCmp(r2.front))))
{
auto c = r1.front.opCmp(r2.front);
if (c != 0) return c;
}
else
{
auto a = r1.front, b = r2.front;
if (a < b) return -1;
if (b < a) return 1;
}
} }
} }
else else
@ -624,35 +645,28 @@ if (isInputRange!R1 && isInputRange!R2)
import core.stdc.string : memcmp; import core.stdc.string : memcmp;
import std.utf : decode; import std.utf : decode;
static if (is(typeof(pred) : string))
enum isLessThan = pred == "a < b";
else
enum isLessThan = false;
// For speed only // For speed only
static int threeWay(size_t a, size_t b) static int threeWay(size_t a, size_t b)
{ {
static if (size_t.sizeof == int.sizeof && isLessThan) static if (size_t.sizeof == int.sizeof)
return a - b; return a - b;
else else
return binaryFun!pred(b, a) ? 1 : binaryFun!pred(a, b) ? -1 : 0; // Faster than return b < a ? 1 : a < b ? -1 : 0;
return (a > b) - (a < b);
} }
// For speed only // For speed only
// @@@BUG@@@ overloading should be allowed for nested functions // @@@BUG@@@ overloading should be allowed for nested functions
static int threeWayInt(int a, int b) static int threeWayInt(int a, int b)
{ {
static if (isLessThan) return a - b;
return a - b;
else
return binaryFun!pred(b, a) ? 1 : binaryFun!pred(a, b) ? -1 : 0;
} }
static if (typeof(r1[0]).sizeof == typeof(r2[0]).sizeof && isLessThan) static if (typeof(r1[0]).sizeof == typeof(r2[0]).sizeof)
{ {
static if (typeof(r1[0]).sizeof == 1) static if (typeof(r1[0]).sizeof == 1)
{ {
immutable len = min(r1.length, r2.length); immutable len = min(r1.length, r2.length);
immutable result = __ctfe ? int result = __ctfe ?
{ {
foreach (i; 0 .. len) foreach (i; 0 .. len)
{ {
@ -663,17 +677,21 @@ if (isInputRange!R1 && isInputRange!R2)
}() }()
: () @trusted { return memcmp(r1.ptr, r2.ptr, len); }(); : () @trusted { return memcmp(r1.ptr, r2.ptr, len); }();
if (result) return result; if (result) return result;
return threeWay(r1.length, r2.length);
} }
else else
{ {
auto p1 = r1.ptr, p2 = r2.ptr, return () @trusted
pEnd = p1 + min(r1.length, r2.length);
for (; p1 != pEnd; ++p1, ++p2)
{ {
if (*p1 != *p2) return threeWayInt(cast(int) *p1, cast(int) *p2); auto p1 = r1.ptr, p2 = r2.ptr,
} pEnd = p1 + min(r1.length, r2.length);
for (; p1 != pEnd; ++p1, ++p2)
{
if (*p1 != *p2) return threeWayInt(int(*p1), int(*p2));
}
return threeWay(r1.length, r2.length);
}();
} }
return threeWay(r1.length, r2.length);
} }
else else
{ {
@ -683,14 +701,58 @@ if (isInputRange!R1 && isInputRange!R2)
if (i2 == r2.length) return threeWay(r1.length, i1); if (i2 == r2.length) return threeWay(r1.length, i1);
immutable c1 = decode(r1, i1), immutable c1 = decode(r1, i1),
c2 = decode(r2, i2); c2 = decode(r2, i2);
if (c1 != c2) return threeWayInt(cast(int) c1, cast(int) c2); if (c1 != c2) return threeWayInt(int(c1), int(c2));
}
}
}
}
/// ditto
int cmp(alias pred, R1, R2)(R1 r1, R2 r2)
if (isInputRange!R1 && isInputRange!R2)
{
static if (!(isSomeString!R1 && isSomeString!R2))
{
for (;; r1.popFront(), r2.popFront())
{
if (r2.empty) return !r1.empty;
if (r1.empty) return -1;
auto a = r1.front, b = r2.front;
if (binaryFun!pred(a, b)) return -1;
if (binaryFun!pred(b, a)) return 1;
}
}
else
{
import std.utf : decode;
// For speed only
static int threeWayCompareLength(size_t a, size_t b)
{
static if (size_t.sizeof == int.sizeof)
return a - b;
else
// Faster than return b < a ? 1 : a < b ? -1 : 0;
return (a > b) - (a < b);
}
for (size_t i1, i2;;)
{
if (i1 == r1.length) return threeWayCompareLength(i2, r2.length);
if (i2 == r2.length) return threeWayCompareLength(r1.length, i1);
immutable c1 = decode(r1, i1),
c2 = decode(r2, i2);
if (c1 != c2)
{
if (binaryFun!pred(c2, c1)) return 1;
if (binaryFun!pred(c1, c2)) return -1;
} }
} }
} }
} }
/// ///
@safe unittest pure @safe unittest
{ {
int result; int result;
@ -712,6 +774,8 @@ if (isInputRange!R1 && isInputRange!R2)
assert(result > 0); assert(result > 0);
result = cmp("aaa", "aaa"d); result = cmp("aaa", "aaa"d);
assert(result == 0); assert(result == 0);
result = cmp("aaa"d, "aaa"d);
assert(result == 0);
result = cmp(cast(int[])[], cast(int[])[]); result = cmp(cast(int[])[], cast(int[])[]);
assert(result == 0); assert(result == 0);
result = cmp([1, 2, 3], [1, 2, 3]); result = cmp([1, 2, 3], [1, 2, 3]);
@ -724,6 +788,106 @@ if (isInputRange!R1 && isInputRange!R2)
assert(result > 0); assert(result > 0);
} }
/// Example predicate that compares individual elements in reverse lexical order
pure @safe unittest
{
int result;
result = cmp!"a > b"("abc", "abc");
assert(result == 0);
result = cmp!"a > b"("", "");
assert(result == 0);
result = cmp!"a > b"("abc", "abcd");
assert(result < 0);
result = cmp!"a > b"("abcd", "abc");
assert(result > 0);
result = cmp!"a > b"("abc"d, "abd");
assert(result > 0);
result = cmp!"a > b"("bbc", "abc"w);
assert(result < 0);
result = cmp!"a > b"("aaa", "aaaa"d);
assert(result < 0);
result = cmp!"a > b"("aaaa", "aaa"d);
assert(result > 0);
result = cmp!"a > b"("aaa", "aaa"d);
assert(result == 0);
result = cmp("aaa"d, "aaa"d);
assert(result == 0);
result = cmp!"a > b"(cast(int[])[], cast(int[])[]);
assert(result == 0);
result = cmp!"a > b"([1, 2, 3], [1, 2, 3]);
assert(result == 0);
result = cmp!"a > b"([1, 3, 2], [1, 2, 3]);
assert(result < 0);
result = cmp!"a > b"([1, 2, 3], [1L, 2, 3, 4]);
assert(result < 0);
result = cmp!"a > b"([1L, 2, 3], [1, 2]);
assert(result > 0);
}
@nogc nothrow pure @safe unittest
{
// Issue 18286: cmp for string with custom predicate fails if distinct chars can compare equal
static bool ltCi(dchar a, dchar b)// less than, case insensitive
{
import std.ascii : toUpper;
return toUpper(a) < toUpper(b);
}
static assert(cmp!ltCi("apple2", "APPLE1") > 0);
static assert(cmp!ltCi("apple1", "APPLE2") < 0);
static assert(cmp!ltCi("apple", "APPLE1") < 0);
static assert(cmp!ltCi("APPLE", "apple1") < 0);
static assert(cmp!ltCi("apple", "APPLE") == 0);
}
@nogc nothrow @safe unittest
{
// Issue 18280: for non-string ranges check that opCmp is evaluated only once per pair.
static int ctr = 0;
struct S
{
int opCmp(ref const S rhs) const
{
++ctr;
return 0;
}
}
immutable S[4] a;
immutable S[4] b;
immutable result = cmp(a[], b[]);
assert(result == 0, "neither should compare greater than the other!");
assert(ctr == a.length, "opCmp should be called exactly once per pair of items!");
}
nothrow pure @safe unittest
{
// Test cmp when opCmp returns float.
struct F
{
float value;
float opCmp(const ref F rhs) const
{
return value - rhs.value;
}
}
auto result = cmp([F(1), F(2), F(3)], [F(1), F(2), F(3)]);
assert(result == 0);
assert(is(typeof(result) == float));
result = cmp([F(1), F(3), F(2)], [F(1), F(2), F(3)]);
assert(result > 0);
result = cmp([F(1), F(2), F(3)], [F(1), F(2), F(3), F(4)]);
assert(result < 0);
result = cmp([F(1), F(2), F(3)], [F(1), F(2)]);
assert(result > 0);
}
nothrow pure @safe unittest
{
// Parallelism (was broken by inferred return type "immutable int")
import std.parallelism : task;
auto t = task!cmp("foo", "bar");
}
// equal // equal
/** /**
Compares two ranges for equality, as defined by predicate `pred` Compares two ranges for equality, as defined by predicate `pred`