Limit SQL query to the fields/columns defined in Graphene-SQLAlchemy

som*_*141 6 python sql sqlalchemy graphql graphene-python

This question has been posted as a GH issue at https://github.com/graphql-python/graphene-sqlalchemy/issues/134, but I thought I'd post it here too to reach a wider audience.

A full working demo can be found at https://github.com/somada141/demo-graphql-sqlalchemy-falcon.

Consider the following SQLAlchemy ORM class:

class Author(Base, OrmBaseMixin):
    __tablename__ = "authors"

    author_id = sqlalchemy.Column(
        sqlalchemy.types.Integer(),
        primary_key=True,
    )

    name_first = sqlalchemy.Column(
        sqlalchemy.types.Unicode(length=80),
        nullable=False,
    )

    name_last = sqlalchemy.Column(
        sqlalchemy.types.Unicode(length=80),
        nullable=False,
    )

which is simply wrapped in a SQLAlchemyObjectType as such:

class TypeAuthor(SQLAlchemyObjectType):
    class Meta:
        model = Author

and exposed through:

author = graphene.Field(
    TypeAuthor,
    author_id=graphene.Argument(type=graphene.Int, required=False),
    name_first=graphene.Argument(type=graphene.String, required=False),
    name_last=graphene.Argument(type=graphene.String, required=False),
)

@staticmethod
def resolve_author(
    args,
    info,
    author_id: Union[int, None] = None,
    name_first: Union[str, None] = None,
    name_last: Union[str, None] = None,
):
    query = TypeAuthor.get_query(info=info)

    if author_id:
        query = query.filter(Author.author_id == author_id)

    if name_first:
        query = query.filter(Author.name_first == name_first)

    if name_last:
        query = query.filter(Author.name_last == name_last)

    author = query.first()

    return author

A GraphQL query such as:

query GetAuthor{
  author(authorId: 1) {
    nameFirst
  }
}

will cause the following raw SQL to be emitted (taken from the echo logs of the SQLA engine):

SELECT authors.author_id AS authors_author_id, authors.name_first AS authors_name_first, authors.name_last AS authors_name_last
FROM authors
WHERE authors.author_id = ?
 LIMIT ? OFFSET ?
2018-05-24 16:23:03,669 INFO sqlalchemy.engine.base.Engine (1, 1, 0)

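As one can see, we may only want the nameFirst field, i.e., the name_first column, but the entire row is fetched. Of course, the GraphQL response only contains the requested fields, i.e.,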

{
  "data": {
    "author": {
      "nameFirst": "Robert"
    }
  }
}

but we have still fetched the entire row, which becomes a major issue when dealing with wide tables.

Is there a way to automatically communicate to SQLAlchemy which columns are needed, so as to preclude this form of over-fetching?

som*_*141 5

My question was answered on the GitHub issue (https://github.com/graphql-python/graphene-sqlalchemy/issues/134).

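The idea is to identify the requested fields from the info argument (of type graphql.execution.base.ResolveInfo) that gets passed to the resolver function, through a get_field_names function such as the one below: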

from graphql.language.ast import FragmentSpread


def get_field_names(info):
    """
    Parses a query info into a list of composite field names.
    For example the following query:
        {
          carts {
            edges {
              node {
                id
                name
                ...cartInfo
              }
            }
          }
        }
        fragment cartInfo on CartType { whatever }

    Will result in an array:
        [
            'carts',
            'carts.edges',
            'carts.edges.node',
            'carts.edges.node.id',
            'carts.edges.node.name',
            'carts.edges.node.whatever'
        ]
    """

    fragments = info.fragments

    def iterate_field_names(prefix, field):
        name = field.name.value

        if isinstance(field, FragmentSpread):
            _results = []
            new_prefix = prefix
            sub_selection = fragments[field.name.value].selection_set.selections
        else:
            _results = [prefix + name]
            new_prefix = prefix + name + "."
            if field.selection_set:
                sub_selection = field.selection_set.selections
            else:
                sub_selection = []

        for sub_field in sub_selection:
            _results += iterate_field_names(new_prefix, sub_field)

        return _results

    results = iterate_field_names('', info.field_asts[0])

    return results

The above function was taken from https://github.com/graphql-python/graphene/issues/348#issuecomment-267717809. That issue contains other versions of this function, but I felt this was the most complete.

The identified fields are then used to limit the columns retrieved in the SQLAlchemy query as such:

from sqlalchemy.orm import load_only

fields = get_field_names(info=info)
query = TypeAuthor.get_query(info=info).options(load_only(*fields))

When applied to the example query above:

query GetAuthor{
  author(authorId: 1) {
    nameFirst
  }
}

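the get_field_names function would return ['author', 'author.nameFirst']. However, as the "raw" SQLAlchemy ORM fields are snake-cased, the result of get_field_names needs to be post-processed to remove the author prefix and to convert the field names via the graphene.utils.str_converters.to_snake_case function.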
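
The post-processing step described above could look roughly like the sketch below. The helper names `camel_to_snake` and `to_column_names` are hypothetical, and `camel_to_snake` is only a simplified stand-in for `graphene.utils.str_converters.to_snake_case` so that the snippet runs without graphene installed. It assumes the dotted paths produced by `get_field_names` and keeps only the leaf fields that sit directly under the root field.

```python
import re
from typing import List


def camel_to_snake(name: str) -> str:
    """Simplified stand-in for graphene's to_snake_case (assumption)."""
    # Insert an underscore before every capital letter except a leading one.
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()


def to_column_names(field_names: List[str], root: str) -> List[str]:
    """Turn dotted GraphQL paths into snake_case names for the root table."""
    prefix = root + "."
    columns = []
    for field_name in field_names:
        # Skip the root entry itself (e.g. 'author') and unrelated paths.
        if not field_name.startswith(prefix):
            continue
        remainder = field_name[len(prefix):]
        # Deeper paths (e.g. 'books.title') belong to a relationship.
        if "." in remainder:
            continue
        # A field that has children is a relationship, not a column.
        has_children = any(
            other.startswith(field_name + ".") for other in field_names
        )
        if has_children:
            continue
        columns.append(camel_to_snake(remainder))
    return columns


print(to_column_names(["author", "author.nameFirst"], root="author"))
```

The resulting names would still need to exist as columns on the ORM class before being handed to `load_only`.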

Long story short, the above approach yields a raw SQL query like this:

INFO:sqlalchemy.engine.base.Engine:SELECT authors.author_id AS authors_author_id, authors.name_first AS authors_name_first
FROM authors
WHERE authors.author_id = ?
 LIMIT ? OFFSET ?
2018-06-09 13:22:16,396 INFO sqlalchemy.engine.base.Engine (1, 1, 0)

UPDATE

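Should anyone land here wondering about the implementation, I got around to implementing my own version of the get_query_fields function, named extract_requested_fields, as such: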

from typing import List, Dict, Union, Type

import graphql
from graphql.language.ast import FragmentSpread
from graphql.language.ast import Field
from graphene.utils.str_converters import to_snake_case
import sqlalchemy.orm

from demo.orm_base import OrmBaseMixin

def extract_requested_fields(
    info: graphql.execution.base.ResolveInfo,
    fields: List[Union[Field, FragmentSpread]],
    do_convert_to_snake_case: bool = True,
) -> Dict:
    """Extracts the fields requested in a GraphQL query by processing the AST
    and returns a nested dictionary representing the requested fields.

    Note:
        This function should support arbitrarily nested field structures
        including fragments.

    Example:
        Consider the following query passed to a resolver and running this
        function with the `ResolveInfo` object passed to the resolver.

        >>> query = "query getAuthor{author(authorId: 1){nameFirst, nameLast}}"
        >>> extract_requested_fields(info, info.field_asts, True)
        {'author': {'name_first': None, 'name_last': None}}

    Args:
        info (graphql.execution.base.ResolveInfo): The GraphQL query info passed
            to the resolver function.
        fields (List[Union[Field, FragmentSpread]]): The list of `Field` or
            `FragmentSpread` objects parsed out of the GraphQL query and stored
            in the AST.
        do_convert_to_snake_case (bool): Whether to convert the fields as they
            appear in the GraphQL query (typically in camel-case) back to
            snake-case (which is how they typically appear in ORM classes).

    Returns:
        Dict: The nested dictionary containing all the requested fields.
    """

    result = {}
    for field in fields:

        # Set the `key` as the field name.
        key = field.name.value

        # Convert the key from camel-case to snake-case (if required).
        if do_convert_to_snake_case:
            key = to_snake_case(name=key)

        # Initialize `val` to `None`. Fields without nested-fields under them
        # will have a dictionary value of `None`.
        val = None

        # If the field is of type `Field` then extract the nested fields under
        # the `selection_set` (if defined). These nested fields will be
        # extracted recursively and placed in a dictionary under the field
        # name in the `result` dictionary.
        if isinstance(field, Field):
            if (
                hasattr(field, "selection_set") and
                field.selection_set is not None
            ):
                # Extract field names out of the field selections.
                val = extract_requested_fields(
                    info=info,
                    fields=field.selection_set.selections,
                    do_convert_to_snake_case=do_convert_to_snake_case,
                )
            result[key] = val
        # If the field is of type `FragmentSpread` then retrieve the fragment
        # from `info.fragments` and recursively extract the nested fields but
        # as we don't want the name of the fragment appearing in the result
        # dictionary (since it does not match anything in the ORM classes) the
        # extracted fields are merged directly into the result.
        elif isinstance(field, FragmentSpread):
            # Retrieve referenced fragment.
            fragment = info.fragments[field.name.value]
            # Extract field names out of the fragment selections.
            val = extract_requested_fields(
                info=info,
                fields=fragment.selection_set.selections,
                do_convert_to_snake_case=do_convert_to_snake_case,
            )
            # Merge rather than overwrite, so that fields collected before
            # the fragment spread are kept.
            result.update(val)

    return result

which parses the AST into a dict, preserving the structure of the query and (hopefully) matching that of the ORM.

Running it on the info object of a query like:

query getAuthor{
  author(authorId: 1) {
    nameFirst,
    nameLast
  }
}

produces:

{'author': {'name_first': None, 'name_last': None}}

while a more complex query like this:

query getAuthor{
  author(nameFirst: "Brandon") {
    ...authorFields
    books {
      ...bookFields
    }
  }
}

fragment authorFields on TypeAuthor {
  nameFirst,
  nameLast
}

fragment bookFields on TypeBook {
  title,
  year
}

produces:

{'author': {'books': {'title': None, 'year': None},
  'name_first': None,
  'name_last': None}}

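Now, these dictionaries can be used to tell apart a field on the primary table (Author in this case), which will have a value of None (such as name_first), from a field on a relationship of that primary table (such as the title field on the books relationship).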
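
To make that distinction concrete, here is a small sketch that walks such a dictionary and groups the scalar fields by the relationship path they sit under, with the empty string standing for the primary table. The function name `group_by_relationship` is hypothetical and not part of graphene or SQLAlchemy. It only assumes the nested-dict shape shown above.

```python
from typing import Dict, List, Optional


def group_by_relationship(
    requested: Dict[str, Optional[dict]],
    path: str = "",
) -> Dict[str, List[str]]:
    """Group scalar fields by the relationship path that contains them."""
    grouped: Dict[str, List[str]] = {path: []}
    for name, nested in requested.items():
        if nested is None:
            # A value of `None` marks a plain field on the current table.
            grouped[path].append(name)
        else:
            # A nested dict marks a relationship; recurse with a dotted path.
            child_path = f"{path}.{name}" if path else name
            grouped.update(group_by_relationship(nested, child_path))
    return grouped


requested = {
    "author": {
        "books": {"title": None, "year": None},
        "name_first": None,
        "name_last": None,
    }
}

# Start below the root field so that "" refers to the `authors` table.
print(group_by_relationship(requested["author"]))
```

The list under the empty key holds the fields of the primary table, and every other key names a relationship whose fields could be restricted separately.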

A simplistic approach to auto-applying those fields can take the form of a function that reads the dictionary produced by extract_requested_fields and restricts the query to the requested columns via load_only, as in the earlier snippet.
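
As a minimal sketch of applying the extracted dictionary to a query, the hypothetical helpers below handle the primary table only and ignore relationships. The names `columns_to_load` and `apply_requested_fields` are assumptions made for illustration, not the function from the original answer. The sketch relies on `sqlalchemy.inspect(orm_class).column_attrs` to list the mapped columns and passes class-bound attributes to `load_only`.

```python
from typing import Dict, Iterable, List, Optional


def columns_to_load(
    requested: Dict[str, Optional[dict]],
    mapped_columns: Iterable[str],
) -> List[str]:
    """Return the requested scalar fields that are mapped columns."""
    mapped = set(mapped_columns)
    return [
        name
        for name, nested in requested.items()
        # `None` marks a scalar field; nested dicts are relationships.
        if nested is None and name in mapped
    ]


def apply_requested_fields(query, orm_class, requested):
    """Restrict `query` to the requested columns of `orm_class`."""
    # Imported lazily so `columns_to_load` stays usable without SQLAlchemy.
    from sqlalchemy import inspect
    from sqlalchemy.orm import load_only

    mapped = [attr.key for attr in inspect(orm_class).column_attrs]
    names = columns_to_load(requested, mapped)
    if not names:
        # Nothing recognisable was requested: leave the query untouched.
        return query
    attributes = [getattr(orm_class, name) for name in names]
    return query.options(load_only(*attributes))


print(
    columns_to_load(
        {"name_first": None, "books": {"title": None}},
        ["author_id", "name_first", "name_last"],
    )
)
```

Inside the resolver this might be called as `apply_requested_fields(TypeAuthor.get_query(info=info), Author, extract_requested_fields(info, info.field_asts)["author"])`, under the assumption that the root field is named `author`.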