How to use the Arrow type in a FastAPI response schema?

Asked by qew*_*jhb · score 7 · tags: arrow-python, openapi, pydantic, fastapi

I want to use the Arrow type in a FastAPI response because I already use it in my SQLAlchemy model (thanks to sqlalchemy_utils).

I prepared a small, self-contained example with a minimal FastAPI app. I expect this app to return the data of product1 from the database.

Unfortunately, the code below raises the following exception:

Exception has occurred: FastAPIError
Invalid args for response field! Hint: check that <class 'arrow.arrow.Arrow'> is a valid pydantic field type

requirements.txt:

sqlalchemy==1.4.23
sqlalchemy_utils==0.37.8
arrow==1.1.1
fastapi==0.68.1
uvicorn==0.15.0

This error has already been discussed in these FastAPI issues:

  1. https://github.com/tiangolo/fastapi/issues/1186
  2. https://github.com/tiangolo/fastapi/issues/2382

Here is the minimal app:

import sqlalchemy
import uvicorn
from arrow import Arrow
from fastapi import FastAPI
from pydantic import BaseModel
from sqlalchemy import Column, Integer, Text, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import ArrowType

app = FastAPI()

engine = sqlalchemy.create_engine('sqlite:///db.db')
Base = declarative_base()

class Product(Base):
    __tablename__ = "product"
    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(Text, nullable=True)
    # ArrowType (sqlalchemy_utils) loads values as Arrow objects; the DB fills it in on insert
    created_at = Column(ArrowType(timezone=True), nullable=False, server_default=func.now())

Base.metadata.create_all(engine)


Session = sessionmaker(bind=engine)
session = Session()

# Seed a few rows so that product 1 exists
product1 = Product(name="ice cream")
product2 = Product(name="donut")
product3 = Product(name="apple pie")

session.add_all([product1, product2, product3])
session.commit()


class ProductResponse(BaseModel):
    id: int
    name: str
    created_at: Arrow  # this Arrow annotation is what triggers the FastAPIError above

    class Config:
        orm_mode = True
        arbitrary_types_allowed = True


@app.get('/', response_model=ProductResponse)
async def return_product():

    product = session.query(Product).filter(Product.id == 1).first()

    return product

if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=8000)

One possible workaround is to add this code (source):

from pydantic import BaseConfig
BaseConfig.arbitrary_types_allowed = True

It is enough to put it above @app.get('/'..., and it can even be placed before app = FastAPI().

The problem with this solution is that the GET endpoint then does not return created_at as a timestamp string: the Arrow object is serialized as a nested object of its internal attributes.

What I want instead is created_at as a plain ISO 8601 string, for example:

{
  "id": 1,
  "name": "ice cream",
  "created_at": "2022-05-14T00:20:11+00:00"
}
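To see why allowing arbitrary types is not enough on its own, here is a simplified, stdlib-only model of a JSON encoder that looks a value's type up in an encoder table and, for unknown objects, falls back to dumping their instance attributes with vars(). It is only loosely modeled on the fallback behavior of FastAPI's jsonable_encoder, not a copy of it. FakeArrow and to_jsonable are hypothetical names; FakeArrow stands in for arrow.Arrow and is assumed to keep its value in a private _datetime attribute.

```python
import json
from datetime import datetime, timezone


class FakeArrow:
    """Hypothetical stand-in for arrow.Arrow: wraps a datetime in a private attribute."""

    def __init__(self, dt: datetime) -> None:
        self._datetime = dt

    def __str__(self) -> str:
        # Mimic Arrow's str(), which is an ISO 8601 string
        return self._datetime.isoformat()


# Type-keyed encoder table, in the spirit of pydantic's ENCODERS_BY_TYPE
ENCODERS = {datetime: datetime.isoformat}


def to_jsonable(obj):
    """Simplified encoder: primitives pass through, known types use the table,
    dicts are encoded recursively, anything else falls back to vars(obj)."""
    if obj is None or isinstance(obj, (str, int, float, bool)):
        return obj
    if type(obj) in ENCODERS:
        return ENCODERS[type(obj)](obj)
    if isinstance(obj, dict):
        return {key: to_jsonable(value) for key, value in obj.items()}
    # Unknown type: dump its instance attributes, private ones included
    return to_jsonable(vars(obj))


row = {
    "id": 1,
    "name": "ice cream",
    "created_at": FakeArrow(datetime(2022, 5, 14, 0, 20, 11, tzinfo=timezone.utc)),
}

print(json.dumps(to_jsonable(row)))
# prints {"id": 1, "name": "ice cream", "created_at": {"_datetime": "2022-05-14T00:20:11+00:00"}}
```

Registering an entry for the type, e.g. ENCODERS[FakeArrow] = str, makes the same call return the plain ISO string; the answer below does the equivalent for the real Arrow class.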

Answered by qew*_*jhb · score 2

The solution is to monkeypatch pydantic's ENCODERS_BY_TYPE so that it knows how to convert Arrow objects into a value that JSON accepts:

from arrow import Arrow
from pydantic.json import ENCODERS_BY_TYPE
# Encode Arrow values as str(arrow), an ISO 8601 string (dict |= needs Python 3.9+)
ENCODERS_BY_TYPE |= {Arrow: str}
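The patch relies on |= updating the existing dict object in place (Python 3.9+), so every module that imported a reference to that same dict sees the new Arrow entry. Rebinding the name instead, as in `ENCODERS_BY_TYPE = ENCODERS_BY_TYPE | {Arrow: str}`, would create a new dict that other modules never see. A stdlib-only sketch, where registry and alias are hypothetical stand-ins for pydantic's dict and for a reference to it held by another module:

```python
# Requires Python 3.9+ for the dict | and |= operators.
registry = {int: str}   # stands in for pydantic.json.ENCODERS_BY_TYPE
alias = registry        # stands in for the same dict imported by another module

# In-place update: the shared dict object itself gains the entry
registry |= {float: repr}
print(float in alias)   # True

# Rebinding: `registry` now names a brand-new dict; `alias` still holds the old one
registry = registry | {bytes: bytes.hex}
print(bytes in alias)   # False
print(registry is alias)  # False
```

Plain item assignment, `ENCODERS_BY_TYPE[Arrow] = str`, also mutates the dict in place and works on Python versions older than 3.9.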

Setting BaseConfig.arbitrary_types_allowed = True is also necessary.
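The model's own Config already sets arbitrary_types_allowed = True, so the global patch presumably matters because FastAPI builds some of its internal response fields with pydantic's plain BaseConfig rather than the model's config; that is an assumption drawn from the error message, not something the answer states. The stdlib-only sketch below shows the underlying Python behavior: a class attribute set on a base class is seen through the base class itself and through every subclass that does not override it. BaseConfig, ProductConfig and OtherConfig here are hypothetical stand-ins, not pydantic's classes.

```python
class BaseConfig:
    """Hypothetical stand-in for pydantic's BaseConfig."""
    arbitrary_types_allowed = False


class ProductConfig(BaseConfig):
    # Overrides the flag, like ProductResponse.Config in the question
    orm_mode = True
    arbitrary_types_allowed = True


class OtherConfig(BaseConfig):
    # Does not mention the flag, so it inherits whatever BaseConfig says
    orm_mode = True


print(BaseConfig.arbitrary_types_allowed, OtherConfig.arbitrary_types_allowed)  # False False

# The global workaround: patch the attribute on the shared base class
BaseConfig.arbitrary_types_allowed = True

print(BaseConfig.arbitrary_types_allowed, OtherConfig.arbitrary_types_allowed)  # True True
```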

Result of GET http://localhost:8000/:

{
  "id": 1,
  "name": "ice cream",
  "created_at": "2022-05-14T00:20:11+00:00"
}
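Because created_at now arrives as a plain ISO 8601 string, a client can parse it with the standard library alone. A minimal sketch, with the response body copied from the result above rather than fetched over HTTP:

```python
import json
from datetime import datetime

# Response body copied from the result above
body = '{"id": 1, "name": "ice cream", "created_at": "2022-05-14T00:20:11+00:00"}'
product = json.loads(body)

# The ISO 8601 string, including its UTC offset, parses directly into an aware datetime
created_at = datetime.fromisoformat(product["created_at"])
print(created_at.utcoffset())  # 0:00:00
```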

Full code:

import sqlalchemy
import uvicorn
from arrow import Arrow
from fastapi import FastAPI
from pydantic import BaseModel
from sqlalchemy import Column, Integer, Text, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import ArrowType

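# Tell pydantic how to JSON-encode Arrow: str(arrow) is an ISO 8601 string (dict |= needs Python 3.9+)
from pydantic.json import ENCODERS_BY_TYPE
ENCODERS_BY_TYPE |= {Arrow: str}

# Allow non-pydantic types such as Arrow in model fields, globally
from pydantic import BaseConfig
BaseConfig.arbitrary_types_allowed = True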

app = FastAPI()

engine = sqlalchemy.create_engine('sqlite:///db.db')
Base = declarative_base()

class Product(Base):
    __tablename__ = "product"
    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(Text, nullable=True)
    created_at = Column(ArrowType(timezone=True), nullable=False, server_default=func.now())

Base.metadata.create_all(engine)


Session = sessionmaker(bind=engine)
session = Session()

product1 = Product(name="ice cream")
product2 = Product(name="donut")
product3 = Product(name="apple pie")

session.add_all([product1, product2, product3])
session.commit()


class ProductResponse(BaseModel):
    id: int
    name: str
    created_at: Arrow

    class Config:
        orm_mode = True
        arbitrary_types_allowed = True


@app.get('/', response_model=ProductResponse)
async def return_product():

    product = session.query(Product).filter(Product.id == 1).first()

    return product

if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=8000)