Generating a binary search with CTFE results in some speed gains compared to using a switch statement on both GDC and DMD

This commit is contained in:
Hackerpilot 2014-06-09 16:22:51 -07:00
parent b29a7d8994
commit 1fd6a7d932
1 changed files with 111 additions and 52 deletions

View File

@ -423,7 +423,36 @@ mixin template Lexer(Token, alias defaultTokenFunction,
return format("0x%016x", ulong.max >> ((8 - l) * 8));
}
private static string generateCaseStatements()
private static size_t calcSplitCount(size_t a, size_t b) pure nothrow
{
int i;
while (true)
{
i++;
a /= 2;
if (a < b)
break;
}
return i;
}
private static char[] getBeginningChars(string[] allTokens)
{
char[] beginningChars;
for (size_t i = 0; i < allTokens.length; i++)
{
if (allTokens[i].length == 0)
continue;
beginningChars ~= allTokens[i][0];
size_t j = i + 1;
while (j < allTokens.length && allTokens[i][0] == allTokens[j][0])
j++;
i = j - 1;
}
return beginningChars;
}
private static string generateStatements()
{
import std.algorithm;
import std.conv;
@ -432,44 +461,76 @@ mixin template Lexer(Token, alias defaultTokenFunction,
string[] pseudoTokens = stupidToArray(tokenHandlers.stride(2));
string[] allTokens = stupidToArray(sort(staticTokens ~ possibleDefaultTokens ~ pseudoTokens).uniq);
// Array consisting of a sorted list of the first characters of the
// tokens.
char[] beginningChars = getBeginningChars(allTokens);
size_t i = calcSplitCount(beginningChars.length, 8);
return generateStatementsStep(allTokens, pseudoTokens, beginningChars, i);
}
private static string generateStatementsStep(string[] allTokens,
string[] pseudoTokens, char[] chars, size_t i, string indent = "")
{
import std.string;
string code;
for (size_t i = 0; i < allTokens.length; i++)
if (i > 0)
{
if (allTokens[i].length == 0)
continue;
size_t j = i + 1;
while (j < allTokens.length && allTokens[i][0] == allTokens[j][0])
j++;
code ~= format("case 0x%02x:\n", cast(ubyte) allTokens[i][0]);
code ~= printCase(allTokens[i .. j], pseudoTokens);
i = j - 1;
size_t p = chars.length / 2;
code ~= indent ~ format("if (f < 0x%02x) // %s \n%s{\n", chars[p], chars[p], indent);
code ~= generateStatementsStep(allTokens, pseudoTokens, chars[0 .. p], i - 1, indent ~ " ");
code ~= indent ~ "}\n" ~ indent ~ "else\n" ~ indent ~ "{\n";
code ~= generateStatementsStep(allTokens, pseudoTokens, chars[p .. $], i - 1, indent ~ " ");
code ~= indent ~ "}\n";
}
else
{
code ~= indent ~ "switch (f)\n" ~ indent ~ "{\n";
foreach (char c; chars)
{
size_t begin;
size_t end;
for (size_t j = 0; j < allTokens.length; j++)
{
if (allTokens[j].length == 0 || allTokens[j][0] != c)
continue;
begin = j;
end = j + 1;
while (end < allTokens.length && allTokens[begin][0] == allTokens[end][0])
end++;
break;
}
code ~= format("%scase 0x%02x:\n", indent, c);
code ~= printCase(allTokens[begin .. end], pseudoTokens, indent ~ " ");
}
code ~= indent ~ "default: goto _defaultTokenFunction;\n";
code ~= indent ~ "}\n";
}
return code;
}
private static string printCase(string[] tokens, string[] pseudoTokens)
private static string printCase(string[] tokens, string[] pseudoTokens, string indent)
{
import std.algorithm;
string[] t = tokens;
string[] sortedTokens = stupidToArray(sort!"a.length > b.length"(t));
string[] sortedTokens = stupidToArray(sort!"a.length > b.length"(tokens));
import std.conv;
if (tokens.length == 1 && tokens[0].length == 1)
{
if (pseudoTokens.countUntil(tokens[0]) >= 0)
{
return " return "
return indent ~ "return "
~ tokenHandlers[tokenHandlers.countUntil(tokens[0]) + 1]
~ "();\n";
}
else if (staticTokens.countUntil(tokens[0]) >= 0)
{
return " range.popFront();\n"
~ " return Token(_tok!\"" ~ escape(tokens[0]) ~ "\", null, line, column, index);\n";
return indent ~ "range.popFront();\n"
~ indent ~ "return Token(_tok!\"" ~ escape(tokens[0]) ~ "\", null, line, column, index);\n";
}
else if (pseudoTokens.countUntil(tokens[0]) >= 0)
{
return " return "
return indent ~ "return "
~ tokenHandlers[tokenHandlers.countUntil(tokens[0]) + 1]
~ "();\n";
}
@ -481,22 +542,22 @@ mixin template Lexer(Token, alias defaultTokenFunction,
{
immutable mask = generateMask(cast (const ubyte[]) token);
if (token.length >= 8)
code ~= " if (frontBytes == " ~ mask ~ ")\n";
code ~= indent ~ "if (frontBytes == " ~ mask ~ ")\n";
else
code ~= " if ((frontBytes & " ~ generateByteMask(token.length) ~ ") == " ~ mask ~ ")\n";
code ~= " {\n";
code ~= indent ~ "if ((frontBytes & " ~ generateByteMask(token.length) ~ ") == " ~ mask ~ ")\n";
code ~= indent ~ "{\n";
if (pseudoTokens.countUntil(token) >= 0)
{
if (token.length <= 8)
{
code ~= " return "
code ~= indent ~ " return "
~ tokenHandlers[tokenHandlers.countUntil(token) + 1]
~ "();\n";
}
else
{
code ~= " if (range.peek(" ~ text(token.length - 1) ~ ") == \"" ~ escape(token) ~"\")\n";
code ~= " return "
code ~= indent ~ " if (range.peek(" ~ text(token.length - 1) ~ ") == \"" ~ escape(token) ~"\")\n";
code ~= indent ~ " return "
~ tokenHandlers[tokenHandlers.countUntil(token) + 1]
~ "();\n";
}
@ -505,12 +566,12 @@ mixin template Lexer(Token, alias defaultTokenFunction,
{
if (token.length <= 8)
{
code ~= " range.popFrontN(" ~ text(token.length) ~ ");\n";
code ~= " return Token(_tok!\"" ~ escape(token) ~ "\", null, line, column, index);\n";
code ~= indent ~ " range.popFrontN(" ~ text(token.length) ~ ");\n";
code ~= indent ~ " return Token(_tok!\"" ~ escape(token) ~ "\", null, line, column, index);\n";
}
else
{
code ~= " pragma(msg, \"long static tokens not supported\"); // " ~ escape(token) ~ "\n";
code ~= indent ~ " pragma(msg, \"long static tokens not supported\"); // " ~ escape(token) ~ "\n";
}
}
else
@ -518,29 +579,29 @@ mixin template Lexer(Token, alias defaultTokenFunction,
// possible default
if (token.length <= 8)
{
code ~= " if (tokenSeparatingFunction(" ~ text(token.length) ~ "))\n";
code ~= " {\n";
code ~= " range.popFrontN(" ~ text(token.length) ~ ");\n";
code ~= " return Token(_tok!\"" ~ escape(token) ~ "\", null, line, column, index);\n";
code ~= " }\n";
code ~= " else\n";
code ~= " goto default;\n";
code ~= indent ~ " if (tokenSeparatingFunction(" ~ text(token.length) ~ "))\n";
code ~= indent ~ " {\n";
code ~= indent ~ " range.popFrontN(" ~ text(token.length) ~ ");\n";
code ~= indent ~ " return Token(_tok!\"" ~ escape(token) ~ "\", null, line, column, index);\n";
code ~= indent ~ " }\n";
code ~= indent ~ " else\n";
code ~= indent ~ " goto _defaultTokenFunction;\n";
}
else
{
code ~= " if (range.peek(" ~ text(token.length - 1) ~ ") == \"" ~ escape(token) ~"\" && isSeparating(" ~ text(token.length) ~ "))\n";
code ~= " {\n";
code ~= " range.popFrontN(" ~ text(token.length) ~ ");\n";
code ~= " return Token(_tok!\"" ~ escape(token) ~ "\", null, line, column, index);\n";
code ~= " }\n";
code ~= " else\n";
code ~= " goto default;\n";
code ~= indent ~ " if (range.peek(" ~ text(token.length - 1) ~ ") == \"" ~ escape(token) ~"\" && isSeparating(" ~ text(token.length) ~ "))\n";
code ~= indent ~ " {\n";
code ~= indent ~ " range.popFrontN(" ~ text(token.length) ~ ");\n";
code ~= indent ~ " return Token(_tok!\"" ~ escape(token) ~ "\", null, line, column, index);\n";
code ~= indent ~ " }\n";
code ~= indent ~ " else\n";
code ~= indent ~ " goto _defaultTokenFunction;\n";
}
}
code ~= " }\n";
code ~= indent ~ "}\n";
}
code ~= " else\n";
code ~= " goto default;\n";
code ~= indent ~ "else\n";
code ~= indent ~ " goto _defaultTokenFunction;\n";
return code;
}
@ -597,7 +658,7 @@ mixin template Lexer(Token, alias defaultTokenFunction,
return retVal;
}
enum tokenSearch = generateCaseStatements();
enum tokenSearch = generateStatements();
static ulong getFront(const ubyte[] arr) pure nothrow @trusted
{
@ -615,13 +676,11 @@ mixin template Lexer(Token, alias defaultTokenFunction,
immutable size_t column = range.column;
immutable size_t line = range.line;
immutable ulong frontBytes = getFront(range.peek(7));
switch (frontBytes & 0x00000000_000000ff)
{
mixin(tokenSearch);
ubyte f = frontBytes & 0xff;
// pragma(msg, tokenSearch);
default:
return defaultTokenFunction();
}
mixin(tokenSearch);
_defaultTokenFunction:
return defaultTokenFunction();
}
/**
@ -746,10 +805,10 @@ struct LexerRange
/**
* Increments the range's line number and resets the column counter.
*/
void incrementLine() pure nothrow @safe
void incrementLine(size_t i = 1) pure nothrow @safe
{
column = 1;
line++;
line += i;
}
/**