使用FastAPI + Sqlalchemy正确刷新数据库连接池

Jen*_*obi 5 python database sqlalchemy python-3.x fastapi

目前正在使用 sqlalchemy 和 fastapi 作为 AWS 中托管的生产微服务。问题是我们的生产数据库机密每 30 天刷新一次。尝试从机密管理器自动获取新机密,以便在 sqlalchemy 出现错误或操作错误时重新初始化数据库引擎和会话。我的问题是这种“重新初始化”应该发生在哪里?

utils/secret_mgr.py

import json
import logging

import boto3
from botocore.exceptions import ClientError


def get_secret(secret_id):
    session = boto3.client('secretsmanager', region_name='us-east-1')
    try:
        response = session.get_secret_value(SecretId=secret_id)
    except ClientError as e:
        code = e.response['Error']['Code']
        logging.exception(f'error:get_secret error_code:{code}')
        raise e
    else:
        secret_str = response['SecretString']
        secret = json.loads(secret_str)
    return secret

Run Code Online (Sandbox Code Playgroud)

实用程序/db.py

import logging
import os

from sqlalchemy.pool import QueuePool
from sqlalchemy.sql import text
from sqlmodel import SQLModel, Session, create_engine
from sqlalchemy.exc import OperationalError
from api.utils.scemgr import get_secret

engine = None
SECRET_NAME = os.environ.get('DB_SECRET_NAME')
SQLALCHEMY_DATABASE_URL = 'postgresql+psycopg2://{username}:{password}@{host}:{port}/{dbname}'


def get_database_uri():
    secret = get_secret(SECRET_NAME)
    return SQLALCHEMY_DATABASE_URL.format(
        username=secret['username'],
        password=secret['password'],
        host=secret['host'],
        port=secret['port'],
        dbname=secret['dbname'],
    )


def get_engine():
    global engine
    if engine:
        return engine
    conn_str = get_database_uri()
    engine = create_engine(
        conn_str,
        echo=True,
        poolclass=QueuePool,
        pool_pre_ping=True,
        # pool_size=15,
        # max_overflow=5,
        echo_pool="debug"
    )
    return engine


engine = get_engine()


#
# class SessionManager:
#     def __init__(self):
#         self.db = sessionmaker(bind=engine, autocommit=True, expire_on_commit=False)
#
#     def __enter__(self):
#         return self.db
#
#     def __exit__(self, exc_type, exc_val, exc_tb):
#         self.db.close()


def get_session():
    with Session(engine) as session:
        try:
            yield session
            session.commit()
        except Exception as exc:
            session.rollback()
            raise exc
        finally:
            session.close()


def init_db_sqlalchemy():
    SQLModel.metadata.create_all(engine)


def fetch(db: Session, query, *args, **kwargs):
    try:
        stmt = text(query)
        result = db.execute(stmt, *args, **kwargs)
        db.commit()
        return result
    except (Exception, OperationalError) as err:
        logging.exception(f"error_code={err} function_name={fetch.__name__}")
    finally:
        db.close()

Run Code Online (Sandbox Code Playgroud)
import os
import time
from uuid import uuid4

import uvicorn
from fastapi import FastAPI, Request, Depends
from fastapi.encoders import jsonable_encoder
from api.routes.info import info, health
from api.utils.db import init_db_sqlalchemy, get_session, fetch

app = FastAPI(
)
app.include_router(info, prefix="/info")
app.include_router(health, prefix="/health")


@app.middleware("http")
async def add_logging_and_process_time(request: Request, call_next):
    start_time = time.time()
    request_id = str(uuid4().hex)
    response = await call_next(request)
    process_time = time.time() - start_time
    process_time = str(round(process_time * 1000))
    response.headers["X-Process-Time-MS"] = process_time
    log_msg = f"request_id={request_id} service=my-svc url={request.url} host={request.client.host} " \
              f"port={request.client.port} processing_time_ms={process_time} env={os.environ.get('APP_ENV')} " \
              f"version=v1 pid={os.getpid()} region={os.environ.get('REGION')} "
    logger.info(log_msg)
    return response

@app.on_event('startup')
def startup():
    init_db_sqlalchemy()

@app.get('/getDatabaseInfo')
def get_db_data_example(db: Session = Depends(get_session)):
    try:
        records = fetch(db, DATABASE_QUERY).all()
        return jsonable_encoder(records)
    except Exception as err:
        logger.exception(f"function_name=getDatabaseInfo error={err}")



if __name__ == '__main__':
    uvicorn.run(app, host="0.0.0.0", port=8050)
Run Code Online (Sandbox Code Playgroud)

我目前在启动时初始化数据库。当数据库连接出现错误时,我们应该在哪里重新初始化数据库呢?即从 utils/secret_mgr.py 中提取新凭据并在 utils/db.py 中重新创建数据库引擎。

鉴于以上信息,有几个问题:

  • 如果我们需要每 30 天重新初始化一次,引擎应该是全局对象吗?
  • 如果 get_session 失败,则会对 get 请求进行依赖注入,失败后会关闭会话并将其添加回连接池。如果我们使用连接池并且其中一个连接失效,那么池中的所有连接也会失效。这没关系,应该发生。我们应该在哪里重新创建数据库引擎?

考虑到上述限制,执行此操作的正确方法是什么?