2023-12-02 10:53:31 +08:00
|
|
|
from contextlib import contextmanager
from functools import wraps
from typing import Iterator

from sqlalchemy.orm import Session

from db.base import SessionLocal
|
2023-12-07 17:49:36 +08:00
|
|
|
|
2023-12-02 10:53:31 +08:00
|
|
|
|
|
|
|
@contextmanager
def session_scope() -> Iterator[Session]:
    """Provide a transactional scope around a block of database work.

    A new session is created from ``SessionLocal`` and yielded to the
    ``with`` body. When the body finishes normally the session is
    committed; if the body (or the commit itself) raises, the session is
    rolled back and the exception is re-raised. The session is closed in
    every case.

    Yields:
        The session created for this scope.

    Raises:
        BaseException: whatever the ``with`` body or the commit raised,
            re-raised unchanged after the rollback.
    """
    session = SessionLocal()
    try:
        yield session
        session.commit()
    except BaseException:
        # Deliberately as broad as the original bare ``except``: roll back
        # on anything (including KeyboardInterrupt) and then re-raise.
        session.rollback()
        raise
    finally:
        session.close()
|
|
|
|
|
|
|
|
|
|
|
|
def with_session(f):
    """Decorate ``f`` so it runs inside a ``session_scope`` transaction.

    The wrapper opens a ``session_scope`` and calls ``f`` with the session
    as the first positional argument, followed by the caller's own
    positional and keyword arguments.

    Commit, rollback and close are handled by ``session_scope``: it commits
    when ``f`` returns, and rolls back and re-raises when ``f`` raises. The
    wrapper therefore does not repeat those calls itself.

    Args:
        f: callable whose first parameter is the session.

    Returns:
        The wrapped callable; calling it returns whatever ``f`` returns.
    """

    @wraps(f)
    def wrapper(*args, **kwargs):
        with session_scope() as session:
            # Leaving the ``with`` block via ``return`` triggers the commit
            # in session_scope; an exception triggers its rollback.
            return f(session, *args, **kwargs)

    return wrapper
|
|
|
|
|
|
|
|
|
|
|
|
def get_db() -> Iterator[Session]:
    """Yield a new session and close it once the consumer is done.

    This generator creates a session from ``SessionLocal``, yields it, and
    closes it when the generator is resumed or finalized. It never commits
    or rolls back explicitly; that is left to the consumer.

    NOTE(review): the shape matches a generator-style dependency provider
    (e.g. for a web framework) -- confirm against the callers.

    Yields:
        The newly created session.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
|
|
|
|
|
|
|
|
|
|
|
|
def get_db0() -> Session:
    """Create and return a new session from ``SessionLocal``.

    Nothing here commits, rolls back or closes the session, so the caller
    owns it and is responsible for closing it.

    Returns:
        The newly created session.
    """
    return SessionLocal()
|