Merge pull request #10650 from pbackus/sumtype-proc-api

Procedural API for SumType
This commit is contained in:
Paul Backus 2025-03-03 23:19:20 -05:00 committed by GitHub
commit 3f990a7e25
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 465 additions and 15 deletions

View file

@ -0,0 +1,39 @@
New procedural API for `std.sumtype`
`std.sumtype` has three new convenience functions for querying and retrieving
the value of a `SumType` object.
* `has!T` returns `true` if the `SumType` object has a value of type `T`.
* `get!T` returns the value if its type is `T`, or asserts if it is not.
* `tryGet!T` returns the value if its type is `T`, or throws an exception if it
is not.
These functions make it easier to write code using `SumType` in a procedural
style, as opposed to the functional style encouraged by `match`.
Example:
---
import std.sumtype;
import std.stdio;
SumType!(string, double) example = "hello";
if (example.has!string)
{
writeln("string: ", example.get!string);
}
else if (example.has!double)
{
writeln("double: ", example.get!double);
}
try
{
writeln("double: ", example.tryGet!double);
}
catch (MatchException e)
{
writeln("Couldn't get a double.");
}
---

View file

@ -323,7 +323,7 @@ private:
@trusted
// Explicit return type omitted
// Workaround for https://github.com/dlang/dmd/issues/20549
ref get(size_t tid)() inout
ref getByIndex(size_t tid)() inout
if (tid < Types.length)
{
assert(tag == tid,
@ -1154,7 +1154,7 @@ version (D_BetterC) {} else
alias MySum = SumType!(ubyte, void*[2]);
MySum x = [null, cast(void*) 0x12345678];
void** p = &x.get!1[1];
void** p = &x.getByIndex!1[1];
x = ubyte(123);
assert(*p != cast(void*) 0x12345678);
@ -1186,8 +1186,8 @@ version (D_BetterC) {} else
catch (Exception e) {}
assert(
(x.tag == 0 && x.get!0.value == 123) ||
(x.tag == 1 && x.get!1.value == 456)
(x.tag == 0 && x.getByIndex!0.value == 123) ||
(x.tag == 1 && x.getByIndex!1.value == 456)
);
}
@ -1246,8 +1246,8 @@ version (D_BetterC) {} else
SumType!(S[1]) x = [S(0)];
SumType!(S[1]) y = x;
auto xval = x.get!0[0].n;
auto yval = y.get!0[0].n;
auto xval = x.getByIndex!0[0].n;
auto yval = y.getByIndex!0[0].n;
assert(xval != yval);
}
@ -1332,8 +1332,8 @@ version (D_BetterC) {} else
SumType!S y;
y = x;
auto xval = x.get!0.n;
auto yval = y.get!0.n;
auto xval = x.getByIndex!0.n;
auto yval = y.getByIndex!0.n;
assert(xval != yval);
}
@ -1407,8 +1407,8 @@ version (D_BetterC) {} else
SumType!S x = S();
SumType!S y = x;
auto xval = x.get!0.n;
auto yval = y.get!0.n;
auto xval = x.getByIndex!0.n;
auto yval = y.getByIndex!0.n;
assert(xval != yval);
}
@ -1902,10 +1902,10 @@ private template matchImpl(Flag!"exhaustive" exhaustive, handlers...)
* argument's tag, so there's no need for TagTuple.
*/
enum handlerArgs(size_t caseId) =
"args[0].get!(" ~ toCtString!caseId ~ ")()";
"args[0].getByIndex!(" ~ toCtString!caseId ~ ")()";
alias valueTypes(size_t caseId) =
typeof(args[0].get!(caseId)());
typeof(args[0].getByIndex!(caseId)());
enum numCases = SumTypes[0].Types.length;
}
@ -1931,7 +1931,7 @@ private template matchImpl(Flag!"exhaustive" exhaustive, handlers...)
template getType(size_t i)
{
alias getType = typeof(args[i].get!(tags[i])());
alias getType = typeof(args[i].getByIndex!(tags[i])());
}
alias valueTypes = Map!(getType, Iota!(tags.length));
@ -2156,7 +2156,7 @@ private template handlerArgs(size_t caseId, typeCounts...)
{
handlerArgs = AliasSeq!(
handlerArgs,
"args[" ~ toCtString!i ~ "].get!(" ~ toCtString!(tags[i]) ~ ")(), "
"args[" ~ toCtString!i ~ "].getByIndex!(" ~ toCtString!(tags[i]) ~ ")(), "
);
}
}
@ -2420,7 +2420,7 @@ version (D_Exceptions)
(ref double d) { d *= 2; }
);
assert(value.get!1.isClose(6.28));
assert(value.getByIndex!1.isClose(6.28));
}
// Unreachable handlers
@ -2642,6 +2642,417 @@ version (D_Exceptions)
}));
}
/**
* Checks whether a `SumType` contains a value of a given type.
*
* The types must match exactly, without implicit conversions.
*
* Params:
* T = the type to check for.
*/
template has(T)
{
/**
* The actual `has` function.
*
* Params:
* self = the `SumType` to check.
*
* Returns: true if `self` contains a `T`, otherwise false.
*/
bool has(Self)(auto ref Self self)
if (isSumType!Self)
{
return self.match!checkType;
}
// Helper to avoid redundant template instantiations
private bool checkType(Value)(ref Value value)
{
return is(Value == T);
}
}
/// Basic usage
@safe unittest
{
SumType!(string, double) example = "hello";
assert( example.has!string);
assert(!example.has!double);
// If T isn't part of the SumType, has!T will always return false.
assert(!example.has!int);
}
/// With type qualifiers
@safe unittest
{
alias Example = SumType!(string, double);
Example m = "mutable";
const Example c = "const";
immutable Example i = "immutable";
assert( m.has!string);
assert(!m.has!(const(string)));
assert(!m.has!(immutable(string)));
assert(!c.has!string);
assert( c.has!(const(string)));
assert(!c.has!(immutable(string)));
assert(!i.has!string);
assert(!i.has!(const(string)));
assert( i.has!(immutable(string)));
}
/// As a predicate
version (D_BetterC) {} else
@safe unittest
{
import std.algorithm.iteration : filter;
import std.algorithm.comparison : equal;
alias Example = SumType!(string, double);
auto arr = [
Example("foo"),
Example(0),
Example("bar"),
Example(1),
Example(2),
Example("baz")
];
auto strings = arr.filter!(has!string);
auto nums = arr.filter!(has!double);
assert(strings.equal([Example("foo"), Example("bar"), Example("baz")]));
assert(nums.equal([Example(0), Example(1), Example(2)]));
}
// Non-copyable types
@safe unittest
{
static struct NoCopy
{
@disable this(this);
}
SumType!NoCopy x;
assert(x.has!NoCopy);
}
/**
* Accesses a `SumType`'s value.
*
* The value must be of the specified type. Use [has] to check.
*
* Params:
* T = the type of the value being accessed.
*/
template get(T)
{
/**
* The actual `get` function.
*
* Params:
* self = the `SumType` whose value is being accessed.
*
* Returns: the `SumType`'s value.
*/
auto ref T get(Self)(auto ref Self self)
if (isSumType!Self)
{
import std.typecons : No;
static if (__traits(isRef, self))
return self.match!(getLvalue!(No.try_, T));
else
return self.match!(getRvalue!(No.try_, T));
}
}
/// Basic usage
@safe unittest
{
SumType!(string, double) example1 = "hello";
SumType!(string, double) example2 = 3.14;
assert(example1.get!string == "hello");
assert(example2.get!double == 3.14);
}
/// With type qualifiers
@safe unittest
{
alias Example = SumType!(string, double);
Example m = "mutable";
const(Example) c = "const";
immutable(Example) i = "immutable";
assert(m.get!string == "mutable");
assert(c.get!(const(string)) == "const");
assert(i.get!(immutable(string)) == "immutable");
}
/// As a predicate
version (D_BetterC) {} else
@safe unittest
{
import std.algorithm.iteration : map;
import std.algorithm.comparison : equal;
alias Example = SumType!(string, double);
auto arr = [Example(0), Example(1), Example(2)];
auto values = arr.map!(get!double);
assert(values.equal([0, 1, 2]));
}
// Non-copyable types
@safe unittest
{
static struct NoCopy
{
@disable this(this);
}
SumType!NoCopy lvalue;
auto rvalue() => SumType!NoCopy();
assert(lvalue.get!NoCopy == NoCopy());
assert(rvalue.get!NoCopy == NoCopy());
}
// Immovable rvalues
@safe unittest
{
auto rvalue() => const(SumType!string)("hello");
assert(rvalue.get!(const(string)) == "hello");
}
// Nontrivial rvalues at compile time
@safe unittest
{
static struct ElaborateCopy
{
this(this) {}
}
enum rvalue = SumType!ElaborateCopy();
enum ctResult = rvalue.get!ElaborateCopy;
assert(ctResult == ElaborateCopy());
}
/**
* Attempt to access a `SumType`'s value.
*
* If the `SumType` does not contain a value of the specified type, an
* exception is thrown.
*
* Params:
* T = the type of the value being accessed.
*/
version (D_Exceptions)
template tryGet(T)
{
/**
* The actual `tryGet` function.
*
* Params:
* self = the `SumType` whose value is being accessed.
*
* Throws: `MatchException` if the value does not have the expected type.
*
* Returns: the `SumType`'s value.
*/
auto ref T tryGet(Self)(auto ref Self self)
if (isSumType!Self)
{
import std.typecons : Yes;
static if (__traits(isRef, self))
return self.match!(getLvalue!(Yes.try_, T));
else
return self.match!(getRvalue!(Yes.try_, T));
}
}
/// Basic usage
version (D_Exceptions)
@safe unittest
{
SumType!(string, double) example = "hello";
assert(example.tryGet!string == "hello");
double result = double.nan;
try
result = example.tryGet!double;
catch (MatchException e)
result = 0;
// Exception was thrown
assert(result == 0);
}
/// With type qualifiers
version (D_Exceptions)
@safe unittest
{
import std.exception : assertThrown;
const(SumType!(string, double)) example = "const";
// Qualifier mismatch; throws exception
assertThrown!MatchException(example.tryGet!string);
// Qualifier matches; no exception
assert(example.tryGet!(const(string)) == "const");
}
/// As a predicate
version (D_BetterC) {} else
@safe unittest
{
import std.algorithm.iteration : map, sum;
import std.functional : pipe;
import std.exception : assertThrown;
alias Example = SumType!(string, double);
auto arr1 = [Example(0), Example(1), Example(2)];
auto arr2 = [Example("foo"), Example("bar"), Example("baz")];
alias trySum = pipe!(map!(tryGet!double), sum);
assert(trySum(arr1) == 0 + 1 + 2);
assertThrown!MatchException(trySum(arr2));
}
// Throws if requested type is impossible
version (D_Exceptions)
@safe unittest
{
import std.exception : assertThrown;
SumType!int x;
assertThrown!MatchException(x.tryGet!string);
}
// Non-copyable types
version (D_Exceptions)
@safe unittest
{
static struct NoCopy
{
@disable this(this);
}
SumType!NoCopy lvalue;
auto rvalue() => SumType!NoCopy();
assert(lvalue.tryGet!NoCopy == NoCopy());
assert(rvalue.tryGet!NoCopy == NoCopy());
}
// Immovable rvalues
version (D_Exceptions)
@safe unittest
{
auto rvalue() => const(SumType!string)("hello");
assert(rvalue.tryGet!(const(string)) == "hello");
}
// Nontrivial rvalues at compile time
version (D_Exceptions)
@safe unittest
{
static struct ElaborateCopy
{
this(this) {}
}
enum rvalue = SumType!ElaborateCopy();
enum ctResult = rvalue.tryGet!ElaborateCopy;
assert(ctResult == ElaborateCopy());
}
private template failedGetMessage(Expected, Actual)
{
static if (Expected.stringof == Actual.stringof)
{
enum expectedStr = __traits(fullyQualifiedName, Expected);
enum actualStr = __traits(fullyQualifiedName, Actual);
}
else
{
enum expectedStr = Expected.stringof;
enum actualStr = Actual.stringof;
}
enum failedGetMessage =
"Tried to get `" ~ expectedStr ~ "`" ~
" but found `" ~ actualStr ~ "`";
}
private template getLvalue(Flag!"try_" try_, T)
{
ref T getLvalue(Value)(ref Value value)
{
static if (is(Value == T))
{
return value;
}
else
{
static if (try_)
throw new MatchException(failedGetMessage!(T, Value));
else
assert(false, failedGetMessage!(T, Value));
}
}
}
private template getRvalue(Flag!"try_" try_, T)
{
T getRvalue(Value)(ref Value value)
{
static if (is(Value == T))
{
import core.lifetime : move;
// Move if possible; otherwise fall back to copy
static if (is(typeof(move(value))))
{
static if (isCopyable!Value)
// Workaround for https://issues.dlang.org/show_bug.cgi?id=21542
return __ctfe ? value : move(value);
else
return move(value);
}
else
return value;
}
else
{
static if (try_)
throw new MatchException(failedGetMessage!(T, Value));
else
assert(false, failedGetMessage!(T, Value));
}
}
}
private void destroyIfOwner(T)(ref T value)
{
static if (hasElaborateDestructor!T)