From cb31d2501e613e0212d8bc412908c7bb93f50363 Mon Sep 17 00:00:00 2001 From: BBasile Date: Mon, 2 Apr 2018 17:29:36 +0200 Subject: [PATCH] add IZ safeAccess util and refactor several && chains with it (#577) * add IZ safeAccess util and refactor several && chains with it * show how to make inference working --- src/dscanner/analysis/assert_without_msg.d | 8 +- src/dscanner/analysis/mismatched_args.d | 20 +- src/dscanner/analysis/run.d | 2 +- src/dscanner/analysis/static_if_else.d | 13 +- src/dscanner/analysis/unmodified.d | 8 +- src/dscanner/analysis/useless_initializer.d | 36 ++-- src/dscanner/analysis/vcall_in_ctor.d | 13 +- src/dscanner/imports.d | 2 +- src/dscanner/main.d | 2 +- src/dscanner/readers.d | 71 ------- src/dscanner/utils.d | 194 ++++++++++++++++++++ 11 files changed, 240 insertions(+), 129 deletions(-) delete mode 100644 src/dscanner/readers.d create mode 100644 src/dscanner/utils.d diff --git a/src/dscanner/analysis/assert_without_msg.d b/src/dscanner/analysis/assert_without_msg.d index 40c4863..476746d 100644 --- a/src/dscanner/analysis/assert_without_msg.d +++ b/src/dscanner/analysis/assert_without_msg.d @@ -5,6 +5,7 @@ module dscanner.analysis.assert_without_msg; import dscanner.analysis.base : BaseAnalyzer; +import dscanner.utils : safeAccess; import dsymbol.scope_ : Scope; import dparse.lexer; import dparse.ast; @@ -37,11 +38,10 @@ class AssertWithoutMessageCheck : BaseAnalyzer if (!isStdExceptionImported) return; - if (expr.unaryExpression !is null && - expr.unaryExpression.primaryExpression !is null && - expr.unaryExpression.primaryExpression.identifierOrTemplateInstance !is null) + if (const IdentifierOrTemplateInstance iot = safeAccess(expr) + .unaryExpression.primaryExpression.identifierOrTemplateInstance) { - auto ident = expr.unaryExpression.primaryExpression.identifierOrTemplateInstance.identifier; + auto ident = iot.identifier; if (ident.text == "enforce" && expr.arguments !is null && expr.arguments.argumentList !is null && expr.arguments.argumentList.items.length < 2) addErrorMessage(ident.line, ident.column, KEY, MESSAGE); diff --git a/src/dscanner/analysis/mismatched_args.d b/src/dscanner/analysis/mismatched_args.d index e3e1b37..718b34d 100644 --- a/src/dscanner/analysis/mismatched_args.d +++ b/src/dscanner/analysis/mismatched_args.d @@ -1,6 +1,7 @@ module dscanner.analysis.mismatched_args; import dscanner.analysis.base : BaseAnalyzer; +import dscanner.utils : safeAccess; import dsymbol.scope_; import dsymbol.symbol; import dparse.ast; @@ -126,16 +127,15 @@ final class ArgVisitor : ASTVisitor { import dsymbol.string_interning : internString; - if (unary.primaryExpression is null) - return; - if (unary.primaryExpression.identifierOrTemplateInstance is null) - return; - if (unary.primaryExpression.identifierOrTemplateInstance.identifier == tok!"") - return; - immutable t = unary.primaryExpression.identifierOrTemplateInstance.identifier; - lines ~= t.line; - columns ~= t.column; - args ~= internString(t.text); + if (auto iot = unary.safeAccess.primaryExpression.identifierOrTemplateInstance.unwrap) + { + if (iot.identifier == tok!"") + return; + immutable t = iot.identifier; + lines ~= t.line; + columns ~= t.column; + args ~= internString(t.text); + } } alias visit = ASTVisitor.visit; diff --git a/src/dscanner/analysis/run.d b/src/dscanner/analysis/run.d index 0442820..653d775 100644 --- a/src/dscanner/analysis/run.d +++ b/src/dscanner/analysis/run.d @@ -82,7 +82,7 @@ import dsymbol.conversion.first; import dsymbol.conversion.second; import dsymbol.modulecache : ModuleCache; -import dscanner.readers; +import dscanner.utils; bool first = true; diff --git a/src/dscanner/analysis/static_if_else.d b/src/dscanner/analysis/static_if_else.d index 2b29b49..f23199d 100644 --- a/src/dscanner/analysis/static_if_else.d +++ b/src/dscanner/analysis/static_if_else.d @@ -8,6 +8,7 @@ module dscanner.analysis.static_if_else; import dparse.ast; import dparse.lexer; import dscanner.analysis.base; +import dscanner.utils : safeAccess; /** * Checks for potentially mistaken static if / else if. @@ -47,17 +48,7 @@ class StaticIfElse : BaseAnalyzer const(IfStatement) getIfStatement(const ConditionalStatement cc) { - if (cc.falseStatement.statement) - { - if (cc.falseStatement.statement.statementNoCaseNoDefault) - { - if (cc.falseStatement.statement.statementNoCaseNoDefault.ifStatement) - { - return cc.falseStatement.statement.statementNoCaseNoDefault.ifStatement; - } - } - } - return null; + return safeAccess(cc).falseStatement.statement.statementNoCaseNoDefault.ifStatement; } enum KEY = "dscanner.suspicious.static_if_else"; diff --git a/src/dscanner/analysis/unmodified.d b/src/dscanner/analysis/unmodified.d index 01a1c1f..f93f09a 100644 --- a/src/dscanner/analysis/unmodified.d +++ b/src/dscanner/analysis/unmodified.d @@ -5,6 +5,7 @@ module dscanner.analysis.unmodified; import dscanner.analysis.base; +import dscanner.utils : safeAccess; import dsymbol.scope_ : Scope; import std.container; import dparse.ast; @@ -217,12 +218,9 @@ private: bool initializedFromNew(const Initializer initializer) { - if (initializer && initializer.nonVoidInitializer && - initializer.nonVoidInitializer.assignExpression && - cast(UnaryExpression) initializer.nonVoidInitializer.assignExpression) + if (const UnaryExpression ue = cast(UnaryExpression) safeAccess(initializer) + .nonVoidInitializer.assignExpression) { - const UnaryExpression ue = - cast(UnaryExpression) initializer.nonVoidInitializer.assignExpression; return ue.newExpression !is null; } return false; diff --git a/src/dscanner/analysis/useless_initializer.d b/src/dscanner/analysis/useless_initializer.d index 648d995..ccc9077 100644 --- a/src/dscanner/analysis/useless_initializer.d +++ b/src/dscanner/analysis/useless_initializer.d @@ -5,6 +5,7 @@ module dscanner.analysis.useless_initializer; import dscanner.analysis.base; +import dscanner.utils : safeAccess; import containers.dynamicarray; import containers.hashmap; import dparse.ast; @@ -147,7 +148,7 @@ public: !declarator.initializer.nonVoidInitializer || declarator.comment !is null) { - continue; + continue; } version(unittest) @@ -171,15 +172,14 @@ public: bool isStr, isSzInt; Token customType; - if (decl.type.type2.typeIdentifierPart && - decl.type.type2.typeIdentifierPart.typeIdentifierPart is null) + if (const TypeIdentifierPart tip = safeAccess(decl).type.type2.typeIdentifierPart) { - const IdentifierOrTemplateInstance idt = - decl.type.type2.typeIdentifierPart.identifierOrTemplateInstance; - - customType = idt.identifier; - isStr = customType.text.among("string", "wstring", "dstring") != 0; - isSzInt = customType.text.among("size_t", "ptrdiff_t") != 0; + if (!tip.typeIdentifierPart) + { + customType = tip.identifierOrTemplateInstance.identifier; + isStr = customType.text.among("string", "wstring", "dstring") != 0; + isSzInt = customType.text.among("size_t", "ptrdiff_t") != 0; + } } // --- 'BasicType/Symbol AssignExpression' ---// @@ -230,16 +230,18 @@ public: } } - // Symbol s = Symbol.init - else if (ue && customType != tok!"" && ue.unaryExpression && ue.unaryExpression.primaryExpression && - ue.unaryExpression.primaryExpression.identifierOrTemplateInstance && - ue.unaryExpression.primaryExpression.identifierOrTemplateInstance.identifier == customType && - ue.identifierOrTemplateInstance && ue.identifierOrTemplateInstance.identifier.text == "init") + else if (const IdentifierOrTemplateInstance iot = safeAccess(ue) + .unaryExpression.primaryExpression.identifierOrTemplateInstance) { - if (customType.text in _structCanBeInit) + // Symbol s = Symbol.init + if (ue && customType != tok!"" && iot.identifier == customType && + ue.identifierOrTemplateInstance && ue.identifierOrTemplateInstance.identifier.text == "init") { - if (!_structCanBeInit[customType.text]) - mixin(warn); + if (customType.text in _structCanBeInit) + { + if (!_structCanBeInit[customType.text]) + mixin(warn); + } } } diff --git a/src/dscanner/analysis/vcall_in_ctor.d b/src/dscanner/analysis/vcall_in_ctor.d index 753672f..efd1a44 100644 --- a/src/dscanner/analysis/vcall_in_ctor.d +++ b/src/dscanner/analysis/vcall_in_ctor.d @@ -5,6 +5,7 @@ module dscanner.analysis.vcall_in_ctor; import dscanner.analysis.base; +import dscanner.utils; import dparse.ast, dparse.lexer; import std.algorithm: among; import std.algorithm.iteration : filter; @@ -220,16 +221,12 @@ public: override void visit(const(UnaryExpression) exp) { + if (isInCtor) // get function identifier for a call, only for this member (so no ident chain) - if (isInCtor && exp.functionCallExpression && - exp.functionCallExpression.unaryExpression && - exp.functionCallExpression.unaryExpression.primaryExpression && - exp.functionCallExpression.unaryExpression.primaryExpression - .identifierOrTemplateInstance) + if (const IdentifierOrTemplateInstance iot = safeAccess(exp) + .functionCallExpression.unaryExpression.primaryExpression.identifierOrTemplateInstance) { - const Token t = exp.functionCallExpression.unaryExpression - .primaryExpression.identifierOrTemplateInstance.identifier; - + const Token t = iot.identifier; if (t != tok!"") { _ctorCalls[$-1] ~= t; diff --git a/src/dscanner/imports.d b/src/dscanner/imports.d index a46b660..b2b6fcc 100644 --- a/src/dscanner/imports.d +++ b/src/dscanner/imports.d @@ -12,7 +12,7 @@ import dparse.rollback_allocator; import std.stdio; import std.container.rbtree; import std.functional : toDelegate; -import dscanner.readers; +import dscanner.utils; /** * AST visitor that collects modules imported to an R-B tree. diff --git a/src/dscanner/main.d b/src/dscanner/main.d index d635176..8ab02cc 100644 --- a/src/dscanner/main.d +++ b/src/dscanner/main.d @@ -31,7 +31,7 @@ import dscanner.symbol_finder; import dscanner.analysis.run; import dscanner.analysis.config; import dscanner.dscanner_version; -import dscanner.readers; +import dscanner.utils; import inifiled; diff --git a/src/dscanner/readers.d b/src/dscanner/readers.d deleted file mode 100644 index 931ed17..0000000 --- a/src/dscanner/readers.d +++ /dev/null @@ -1,71 +0,0 @@ -module dscanner.readers; - -import std.array : appender, uninitializedArray; -import std.stdio : stdin, stderr, File; -import std.conv : to; -import std.file : exists; - -ubyte[] readStdin() -{ - auto sourceCode = appender!(ubyte[])(); - ubyte[4096] buf; - while (true) - { - auto b = stdin.rawRead(buf); - if (b.length == 0) - break; - sourceCode.put(b); - } - return sourceCode.data; -} - -ubyte[] readFile(string fileName) -{ - if (fileName == "stdin") - return readStdin(); - if (!exists(fileName)) - { - stderr.writefln("%s does not exist", fileName); - return []; - } - File f = File(fileName); - if (f.size == 0) - return []; - ubyte[] sourceCode = uninitializedArray!(ubyte[])(to!size_t(f.size)); - f.rawRead(sourceCode); - return sourceCode; -} - -string[] expandArgs(string[] args) -{ - import std.file : isFile, FileException, dirEntries, SpanMode; - import std.algorithm.iteration : map; - import std.algorithm.searching : endsWith; - - // isFile can throw if it's a broken symlink. - bool isFileSafe(T)(T a) - { - try - return isFile(a); - catch (FileException) - return false; - } - - string[] rVal; - if (args.length == 1) - args ~= "."; - foreach (arg; args[1 .. $]) - { - if (arg == "stdin" || isFileSafe(arg)) - rVal ~= arg; - else - foreach (item; dirEntries(arg, SpanMode.breadth).map!(a => a.name)) - { - if (isFileSafe(item) && (item.endsWith(`.d`) || item.endsWith(`.di`))) - rVal ~= item; - else - continue; - } - } - return rVal; -} diff --git a/src/dscanner/utils.d b/src/dscanner/utils.d new file mode 100644 index 0000000..7cbcbcc --- /dev/null +++ b/src/dscanner/utils.d @@ -0,0 +1,194 @@ +module dscanner.utils; + +import std.array : appender, uninitializedArray; +import std.stdio : stdin, stderr, File; +import std.conv : to; +import std.file : exists; + +ubyte[] readStdin() +{ + auto sourceCode = appender!(ubyte[])(); + ubyte[4096] buf; + while (true) + { + auto b = stdin.rawRead(buf); + if (b.length == 0) + break; + sourceCode.put(b); + } + return sourceCode.data; +} + +ubyte[] readFile(string fileName) +{ + if (fileName == "stdin") + return readStdin(); + if (!exists(fileName)) + { + stderr.writefln("%s does not exist", fileName); + return []; + } + File f = File(fileName); + if (f.size == 0) + return []; + ubyte[] sourceCode = uninitializedArray!(ubyte[])(to!size_t(f.size)); + f.rawRead(sourceCode); + return sourceCode; +} + +string[] expandArgs(string[] args) +{ + import std.file : isFile, FileException, dirEntries, SpanMode; + import std.algorithm.iteration : map; + import std.algorithm.searching : endsWith; + + // isFile can throw if it's a broken symlink. + bool isFileSafe(T)(T a) + { + try + return isFile(a); + catch (FileException) + return false; + } + + string[] rVal; + if (args.length == 1) + args ~= "."; + foreach (arg; args[1 .. $]) + { + if (arg == "stdin" || isFileSafe(arg)) + rVal ~= arg; + else + foreach (item; dirEntries(arg, SpanMode.breadth).map!(a => a.name)) + { + if (isFileSafe(item) && (item.endsWith(`.d`) || item.endsWith(`.di`))) + rVal ~= item; + else + continue; + } + } + return rVal; +} + +/** + * Allows to build access chains of class members as done with the $(D ?.) operator + * in other languages. In the chain, any $(D null) member that is a class instance + * or that returns one, has for effect to shortcut the complete evaluation. + * + * This function is copied from https://github.com/BBasile/iz to avoid a new submodule. + * Any change made to this copy should also be applied to the origin. + * + * Params: + * M = The class type of the chain entry point. + * + * Bugs: + * Assigning a member only works with $(D unwrap). + * + */ +struct SafeAccess(M) +if (is(M == class)) +{ + M m; + + @disable this(); + + /** + * Instantiate. + * + * Params: + * m = An instance of the entry point type. It is usually only + * $(D null) when the constructor is used internally, to build + * the chain. + */ + this(M m) + { + this.m = m; + } + + alias m this; + /// Unprotect the class instance. + alias unwrap = m; + + /// Handles safe access. + auto ref opDispatch(string member, A...)(auto ref A a) + { + import std.traits : ReturnType; + alias T = typeof(__traits(getMember, m, member)); + static if (is(T == class)) + { + return (!m || !__traits(getMember, m, member)) + ? SafeAccess!T(null) + : SafeAccess!T(__traits(getMember, m, member)); + } + else + { + import std.traits : ReturnType, Parameters, isFunction; + static if (isFunction!T) + { + // otherwise there's a missing return statement. + alias R = ReturnType!T; + static if (!is(R == void) && + !(is(R == class) && Parameters!T.length == 0)) + pragma(msg, __FILE__ ~ "(" ~ __LINE__.stringof ~ "): error, " ~ + "only `void function`s or `class` getters can be called without unwrap"); + + static if (is(R == class)) + { + return (m is null) + ? SafeAccess!R(null) + : SafeAccess!R(__traits(getMember, m, member)(a)); + } + else + { + if (m) + __traits(getMember, m, member)(a); + } + } + else + { + if (m) + __traits(getMember, m, member) = a; + } + } + } +} +/// General usage +@safe unittest +{ + class LongLineOfIdent3{int foo; void setFoo(int v) @safe{foo = v;}} + class LongLineOfIdent2{LongLineOfIdent3 longLineOfIdent3;} + class LongLineOfIdent1{LongLineOfIdent2 longLineOfIdent2;} + class Root {LongLineOfIdent1 longLineOfIdent1;} + + SafeAccess!Root sar = SafeAccess!Root(new Root); + // without the SafeAccess we would receive a SIGSEGV here + sar.longLineOfIdent1.longLineOfIdent2.longLineOfIdent3.setFoo(0xDEADBEEF); + + bool notAccessed = true; + // the same with `&&` whould be much longer + if (LongLineOfIdent3 a = sar.longLineOfIdent1.longLineOfIdent2.longLineOfIdent3) + { + notAccessed = false; + } + assert(notAccessed); + + // checks that forwarding actually works + sar.m.longLineOfIdent1 = new LongLineOfIdent1; + sar.m.longLineOfIdent1.longLineOfIdent2 = new LongLineOfIdent2; + sar.m.longLineOfIdent1.longLineOfIdent2.longLineOfIdent3 = new LongLineOfIdent3; + + sar.longLineOfIdent1.longLineOfIdent2.longLineOfIdent3.setFoo(42); + assert(sar.longLineOfIdent1.longLineOfIdent2.longLineOfIdent3.unwrap.foo == 42); +} + +/** + * IFTI helper for $(D SafeAccess). + * + * Returns: + * $(D m) with the ability to safely access its members that are class + * instances. + */ +auto ref safeAccess(M)(M m) +{ + return SafeAccess!M(m); +}