需要检查代码是否包含某些标识符

San*_*ndy 8 c# runtime compilation roslyn .net-core

我将使用 Roslyn 动态编译和执行代码,如下例所示。我想确保代码不违反我的一些规则,例如:

  • 不使用反射
  • 不使用 HttpClient 或 WebClient
  • 不使用 System.IO 命名空间中的文件或目录类
  • 不使用源生成器
  • 不调用非托管代码

我应该在以下代码中的什么位置插入这些规则/检查？又该如何实现它们？

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Emit;
using System.Reflection;
using System.Runtime.CompilerServices;

// Sample script to sandbox; it calls File.Delete, which the rules are meant to forbid.
string code = @"using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.IO;

namespace Customization
{
    public class Script
    {
        public async Task<object?> RunAsync(object? data)
        {
            //The following should not be allowed
            File.Delete(@""C:\Temp\log.txt"");

            return await Task.FromResult(data);
        }
    }
}";

// Compile the sample script, then emit it to an in-memory assembly (Build throws on failure).
CSharpCompilation compilation = Compile(code);
byte[] bytes = Build(compilation);

Console.WriteLine("Done");

// Parses the script source and builds an in-memory class-library compilation.
// Nothing is emitted here (that happens in Build), so this is the place to inspect the
// syntax tree and semantic model for disallowed code before any IL is produced.
//
// code:    C# source of the script to compile.
// returns: a CSharpCompilation over the parsed tree with the framework references below.
// throws:  InvalidOperationException if the runtime directory cannot be determined.
CSharpCompilation Compile(string code)
{
    SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(code);

    // System.Runtime.dll is expected next to the assembly that defines System.Object.
    string? dotNetCoreDirectoryPath = Path.GetDirectoryName(typeof(object).GetTypeInfo().Assembly.Location);
    if (String.IsNullOrWhiteSpace(dotNetCoreDirectoryPath))
    {
        // ArgumentNullException(string) would treat this text as a parameter name, and no
        // argument is null here: this is an environment problem, not a bad argument.
        throw new InvalidOperationException("Cannot determine path to current assembly.");
    }

    string assemblyName = Path.GetRandomFileName();
    List<MetadataReference> references = new();
    references.Add(MetadataReference.CreateFromFile(typeof(object).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Console).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Dictionary<,>).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Task).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(Path.Combine(dotNetCoreDirectoryPath, "System.Runtime.dll")));

    CSharpCompilation compilation = CSharpCompilation.Create(
        assemblyName,
        syntaxTrees: new[] { syntaxTree },
        references: references,
        options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));

    SemanticModel model = compilation.GetSemanticModel(syntaxTree);
    CompilationUnitSyntax root = (CompilationUnitSyntax)syntaxTree.GetRoot();

    //TODO: Check the code for use classes that are not allowed such as File in the System.IO namespace.
    //Not exactly sure how to walk through identifiers.
    IEnumerable<IdentifierNameSyntax> identifiers = root.DescendantNodes()
        .OfType<IdentifierNameSyntax>();

    return compilation;
}

// Emits the compilation to an in-memory assembly image.
//
// compilation: the compilation produced by Compile.
// returns:     the raw bytes of the emitted assembly.
// throws:      InvalidOperationException listing every error diagnostic if emission fails.
[MethodImpl(MethodImplOptions.NoInlining)]
byte[] Build(CSharpCompilation compilation)
{
    using MemoryStream ms = new();

    // Emitting also surfaces the script's compile errors.
    EmitResult emitResult = compilation.Emit(ms);

    if (!emitResult.Success)
    {
        // Report all errors (warnings promoted to errors included), not just the first,
        // and include each diagnostic's id and location.
        string[] errors = emitResult.Diagnostics
            .Where(diagnostic => diagnostic.IsWarningAsError ||
                diagnostic.Severity == DiagnosticSeverity.Error)
            .Select(diagnostic => diagnostic.ToString())
            .ToArray();

        // Guard against an empty list so the exception never carries a null message.
        throw new InvalidOperationException(errors.Length > 0
            ? string.Join(Environment.NewLine, errors)
            : "Compilation failed.");
    }

    return ms.ToArray();
}
Run Code Online (Sandbox Code Playgroud)

Aqu*_*nds 4

在检查特定类的使用时，您可以使用 `OfType<>()` 方法查找 `IdentifierNameSyntax` 类型的节点，并按类名过滤结果：

// Every identifier in the tree whose text equals the class name (case-insensitive).
var names = from identifier in root.DescendantNodes().OfType<IdentifierNameSyntax>()
            where string.Equals(identifier.Identifier.ValueText, className, StringComparison.OrdinalIgnoreCase)
            select identifier;

Run Code Online (Sandbox Code Playgroud)

然后您可以使用 `SemanticModel` 检查该类所在的命名空间：

foreach (var candidate in names)
{
    // Ask the semantic model what type the identifier denotes and compare its namespace.
    var candidateNamespace = model.GetTypeInfo(candidate).Type?.ContainingNamespace?.ToString();
    if (!string.Equals(candidateNamespace, containingNamespace, StringComparison.OrdinalIgnoreCase))
    {
        continue;
    }

    throw new Exception($"Class {containingNamespace}.{className} is not allowed.");
}
Run Code Online (Sandbox Code Playgroud)

要检查反射或非托管代码的使用，您可以检查是否存在相关的 using 指令，即 `System.Reflection` 和 `System.Runtime.InteropServices`：

// Search every using directive in the file, including ones nested inside namespace
// declarations. Match the namespace itself or anything beneath it, e.g.
// `using System.Reflection.Emit;` or `using static System.Runtime.InteropServices.Marshal;`.
// `Name` can be null for C# 12 type aliases such as `using P = (int, int);`.
if (root.DescendantNodes()
        .OfType<UsingDirectiveSyntax>()
        .Select(u => u.Name?.ToString())
        .Any(name => name is not null &&
                     (string.Equals(name, disallowedNamespace, StringComparison.OrdinalIgnoreCase) ||
                      name.StartsWith(disallowedNamespace + ".", StringComparison.OrdinalIgnoreCase))))
{
    throw new Exception($"Namespace {disallowedNamespace} is not allowed.");
}
Run Code Online (Sandbox Code Playgroud)

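这种做法也会把仅包含 using 指令、但实际上并未使用反射或非托管代码的情况判为违规，这似乎是一个可以接受的权衡。反过来，它也存在漏报：例如 `typeof(object).GetMethod("ToString")` 这类调用，或完全限定名 `System.Reflection.Assembly`，都不需要任何 using 指令，因此不会被检测到。要覆盖这些情况，需要借助 `SemanticModel` 检查每个表达式解析到的符号或类型所在的命名空间。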

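我原本不确定如何检查源生成器，因为它们通常以项目引用（分析器引用）的形式加入，我不清楚它们是否会针对动态编译的代码运行。实际上，源生成器只有在宿主代码显式创建并运行 `GeneratorDriver`（例如 `CSharpGeneratorDriver.Create(...).RunGeneratorsAndUpdateCompilation(...)`）时才会执行。直接调用 `CSharpCompilation.Emit` 不会运行任何生成器，而被编译的脚本本身也无法添加生成器引用。因此，只要宿主代码不运行生成器，这条规则就自动满足。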

把检查放在原来的位置（`Compile` 方法中）并更新代码后，得到：

using System.Reflection;
using System.Runtime.CompilerServices;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Emit;

// Sample script that deliberately breaks several rules (System.IO.File, HttpClient,
// System.Reflection, System.Runtime.InteropServices) to exercise the checks in Compile.
string code = @"using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.IO;
using System.Net.Http;
using System.Reflection;
using System.Runtime.InteropServices;

namespace Customization
{
    public class Script
    {
        static readonly HttpClient client = new HttpClient();

        public async Task<object?> RunAsync(object? data)
        {
            //The following should not be allowed
            File.Delete(@""C:\Temp\log.txt"");

            return await Task.FromResult(data);
        }
    }
}";

// Compile runs the rule checks and throws on the first violation it finds, so with this
// sample Build and the final message are not expected to be reached.
var compilation = Compile(code);

var bytes = Build(compilation);
Console.WriteLine("Done");


// Parses the script, builds an in-memory class-library compilation and enforces the
// sandbox rules before anything is emitted.
//
// code:    C# source of the script.
// returns: the compilation, ready to be passed to Build.
// throws:  InvalidOperationException if the runtime directory cannot be determined;
//          Exception (from the ThrowOnDisallowed* helpers) on the first rule violation.
CSharpCompilation Compile(string code)
{
    SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(code);

    string? dotNetCoreDirectoryPath = Path.GetDirectoryName(typeof(object).GetTypeInfo().Assembly.Location);
    if (String.IsNullOrWhiteSpace(dotNetCoreDirectoryPath))
    {
        throw new InvalidOperationException("Cannot determine path to current assembly.");
    }

    string assemblyName = Path.GetRandomFileName();
    List<MetadataReference> references = new();
    references.Add(MetadataReference.CreateFromFile(typeof(object).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Console).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Dictionary<,>).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Task).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(HttpClient).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(Path.Combine(dotNetCoreDirectoryPath, "System.Runtime.dll")));

    // Default CSharpCompilationOptions leave unsafe code disabled, and no GeneratorDriver is
    // created anywhere here, so source generators never run for the script.
    CSharpCompilation compilation = CSharpCompilation.Create(
        assemblyName,
        syntaxTrees: new[] { syntaxTree },
        references: references,
        options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));

    SemanticModel model = compilation.GetSemanticModel(syntaxTree);
    CompilationUnitSyntax root = (CompilationUnitSyntax)syntaxTree.GetRoot();

    // Rule: no File or Directory classes from System.IO.
    ThrowOnDisallowedClass("File", "System.IO", root, model);
    ThrowOnDisallowedClass("Directory", "System.IO", root, model);

    // Rule: no HttpClient or WebClient. WebClient's assembly is presumably not covered by the
    // references above, so such scripts should already fail in Build; the explicit check
    // keeps the rule enforced if the reference list grows.
    ThrowOnDisallowedClass("HttpClient", "System.Net.Http", root, model);
    ThrowOnDisallowedClass("WebClient", "System.Net", root, model);

    // Rules: no reflection, no calls into unmanaged code.
    ThrowOnDisallowedNamespace("System.Reflection", root);
    ThrowOnDisallowedNamespace("System.Runtime.InteropServices", root);

    return compilation;
}

// Emits the compilation to an in-memory assembly image.
//
// compilation: the compilation produced by Compile.
// returns:     the raw bytes of the emitted assembly.
// throws:      InvalidOperationException listing every error diagnostic if emission fails.
[MethodImpl(MethodImplOptions.NoInlining)]
byte[] Build(CSharpCompilation compilation)
{
    using MemoryStream ms = new();

    // Emitting also surfaces the script's compile errors.
    EmitResult emitResult = compilation.Emit(ms);

    if (!emitResult.Success)
    {
        // Report all errors (warnings promoted to errors included), not just the first,
        // and include each diagnostic's id and location.
        string[] errors = emitResult.Diagnostics
            .Where(diagnostic => diagnostic.IsWarningAsError ||
                diagnostic.Severity == DiagnosticSeverity.Error)
            .Select(diagnostic => diagnostic.ToString())
            .ToArray();

        // Guard against an empty list so the exception never carries a null message.
        throw new InvalidOperationException(errors.Length > 0
            ? string.Join(Environment.NewLine, errors)
            : "Compilation failed.");
    }

    return ms.ToArray();
}

// Throws if the script refers to the type `containingNamespace.className` in any way the
// semantic model resolves. This covers:
// - naming the type, including generic forms such as List<int>;
// - using one of its members, also through `using static` or an alias;
// - using an expression whose type is that class.
//
// className:           simple type name without namespace or generic arity, e.g. "File".
// containingNamespace: namespace the type is declared in, e.g. "System.IO".
// root, model:         syntax root and semantic model of the same syntax tree.
// throws:              Exception on the first reference found.
void ThrowOnDisallowedClass(string className, string containingNamespace, CompilationUnitSyntax root, SemanticModel model)
{
    bool IsDisallowed(ITypeSymbol? type) =>
        type is not null &&
        string.Equals(type.Name, className, StringComparison.OrdinalIgnoreCase) &&
        string.Equals(type.ContainingNamespace?.ToString(), containingNamespace, StringComparison.OrdinalIgnoreCase);

    // Inspect every simple name rather than only identifiers spelled like the class.
    // `Delete` after `using static System.IO.File;`, or `F.Delete` with an alias, never
    // spells "File". Generic usages are GenericNameSyntax, not IdentifierNameSyntax.
    foreach (SimpleNameSyntax name in root.DescendantNodes().OfType<SimpleNameSyntax>())
    {
        SymbolInfo symbolInfo = model.GetSymbolInfo(name);
        // When binding fails (e.g. bad overload), Symbol is null but candidates are listed.
        ISymbol? symbol = symbolInfo.Symbol ?? symbolInfo.CandidateSymbols.FirstOrDefault();

        // Either the name is the type itself, or it is a member declared on that type.
        ITypeSymbol? referencedType = symbol as ITypeSymbol ?? symbol?.ContainingType;

        if (IsDisallowed(referencedType) || IsDisallowed(model.GetTypeInfo(name).Type))
        {
            throw new Exception($"Class {containingNamespace}.{className} is not allowed.");
        }
    }
}

// Throws if the script imports `disallowedNamespace` or anything beneath it.
//
// Using directives are searched everywhere in the file, including inside namespace
// declarations. `using static` and alias directives that point into the namespace (e.g.
// `using static System.Runtime.InteropServices.Marshal;`) match too.
//
// A using check alone misses code that reaches the namespace without one, e.g.
// `typeof(object).GetMethod("ToString")` or fully qualified names. Pass the optional
// `model` to also reject any expression whose symbol or type lives in the namespace.
//
// disallowedNamespace: namespace to reject, e.g. "System.Reflection".
// root:                syntax root of the script.
// model:               optional semantic model of the same tree; enables the semantic check.
// throws:              Exception on the first match.
void ThrowOnDisallowedNamespace(string disallowedNamespace, CompilationUnitSyntax root, SemanticModel? model = null)
{
    bool IsDisallowed(string? candidate) =>
        candidate is not null &&
        (string.Equals(candidate, disallowedNamespace, StringComparison.OrdinalIgnoreCase) ||
         candidate.StartsWith(disallowedNamespace + ".", StringComparison.OrdinalIgnoreCase));

    // Rebuild the imported name from its tokens, so whitespace or comments inside it
    // (`System . Reflection`) cannot hide it, and drop a leading `global::`.
    // `Name` can be null for C# 12 type aliases such as `using P = (int, int);`.
    string? ImportedName(UsingDirectiveSyntax directive)
    {
        if (directive.Name is null)
        {
            return null;
        }

        string text = string.Concat(directive.Name.DescendantTokens().Select(token => token.Text));
        return text.StartsWith("global::", StringComparison.Ordinal)
            ? text.Substring("global::".Length)
            : text;
    }

    bool importsNamespace = root.DescendantNodes()
        .OfType<UsingDirectiveSyntax>()
        .Any(directive => IsDisallowed(ImportedName(directive)));

    if (importsNamespace)
    {
        throw new Exception($"Namespace {disallowedNamespace} is not allowed.");
    }

    if (model is null)
    {
        return;
    }

    // Semantic pass: flag any expression that binds to a symbol or has a type in the namespace.
    foreach (ExpressionSyntax expression in root.DescendantNodes().OfType<ExpressionSyntax>())
    {
        ISymbol? symbol = model.GetSymbolInfo(expression).Symbol;
        string? symbolNamespace = symbol is INamespaceSymbol namespaceSymbol
            ? namespaceSymbol.ToString()
            : symbol?.ContainingNamespace?.ToString();
        string? typeNamespace = model.GetTypeInfo(expression).Type?.ContainingNamespace?.ToString();

        if (IsDisallowed(symbolNamespace) || IsDisallowed(typeNamespace))
        {
            throw new Exception($"Namespace {disallowedNamespace} is not allowed.");
        }
    }
}
Run Code Online (Sandbox Code Playgroud)

这里在遇到违规时直接 `throw`，这意味着一次只会报告一个违规。如果希望一次性报告所有违规，可以改为先收集所有违规，再统一报告。