71 lines
2.1 KiB
Python
71 lines
2.1 KiB
Python
from fastapi import Security, APIRouter, HTTPException, Depends
|
|
from fastapi.routing import APIRoute
|
|
from jose import jwt, JWTError
|
|
from pydantic import ValidationError
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
from starlette.routing import BaseRoute
|
|
|
|
import app
|
|
import schema
|
|
from badmin.models.model import AccountAdmin
|
|
from config import conf
|
|
from database import get_db
|
|
|
|
|
|
class BaseRoute(APIRoute):
    """``APIRoute`` subclass used as ``route_class`` by ``generateRouter``.

    The handler it produces delegates to FastAPI's stock handler and returns
    its response unchanged; presumably it is kept as a single place to add
    per-route pre/post processing later.

    NOTE(review): this class name shadows ``starlette.routing.BaseRoute``,
    which is imported at the top of this module.
    """

    def get_route_handler(self):
        """Return an async request handler wrapping the default one."""
        default_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            # Pure pass-through: hand the request to the stock handler and
            # return whatever it produces.
            return await default_handler(request)

        return custom_route_handler
|
|
|
|
|
|
async def current_user(
    token: str = Depends(schema.oauth2),
    db=Depends(get_db),
):
    """Validate the bearer token and return the account id it carries.

    The token is decoded with ``conf.AUTH_SECRET_KEY``, accepting only
    ``conf.AUTH_ALGORITHM``. Its ``"id"`` claim is looked up in
    ``AccountAdmin`` and the resulting row is stored on
    ``app.app.state.user`` (read back by ``getUser``).

    Args:
        token: Token string extracted by the ``schema.oauth2`` dependency.
        db: Database session resolved by FastAPI from ``get_db``. Injecting
            it (instead of calling ``next(get_db())`` by hand) lets FastAPI
            run ``get_db``'s cleanup when the request finishes; the manual
            ``next()`` call never resumed the generator, so that cleanup
            never ran.

    Returns:
        The ``"id"`` claim of the token payload.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when the token
            fails to decode/validate, carries no ``"id"`` claim, or no
            ``AccountAdmin`` row matches that id.
    """
    credentials_exception = HTTPException(
        status_code=401,
        detail="Could not validate credentials.",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Keep the try narrow: only decoding/validation should map to a 401.
    try:
        payload = jwt.decode(
            token, conf.AUTH_SECRET_KEY, algorithms=[conf.AUTH_ALGORITHM]
        )
    except (JWTError, ValidationError) as err:
        raise credentials_exception from err

    username: str = payload.get("id")
    if username is None:
        raise credentials_exception

    user = db.query(AccountAdmin).filter(AccountAdmin.id == username).first()

    # NOTE(review): this is process-wide state shared by every concurrent
    # request, so one request can observe another request's user. Consider
    # request.state or returning the user from this dependency instead.
    app.app.state.user = user

    # A well-formed token for an account that no longer exists must not
    # authenticate.
    if user is None:
        raise credentials_exception

    return username
|
|
|
|
|
|
async def get_current_active_user(
    username=Security(current_user, scopes=["account"]),
):
    """Dependency that resolves ``current_user`` with the ``"account"`` scope declared.

    NOTE(review): ``current_user`` does not take a ``SecurityScopes``
    parameter, so the declared scope is not checked by the code visible
    here; confirm whether ``schema.oauth2`` enforces it.

    Args:
        username: Value resolved by ``current_user`` (the token's ``"id"``
            claim).

    Returns:
        That same value, unchanged. No additional "active" check is made.
    """
    resolved_username = username
    return resolved_username
|
|
|
|
|
|
async def getUser():
    """Return the user row stored by ``current_user``, or ``False``.

    Returns:
        The value of ``app.app.state.user`` when it is set and truthy,
        otherwise ``False``, including when no request has stored a user yet.

    NOTE(review): ``app.app.state`` is shared by all requests, so under
    concurrency this may return another request's user.
    """
    # Before the first authenticated request the attribute does not exist;
    # Starlette's State (presumably what ``app.app.state`` is) raises
    # AttributeError for unset attributes, so read it with a default.
    user = getattr(app.app.state, "user", None)
    return user if user else False
|
|
|
|
|
|
def generateRouter(
    prefix,
    dependencies=None,
    auth_dependencies=False
):
    """Build an ``APIRouter`` mounted under ``conf.PREFIX_API_URL``.

    Args:
        prefix: Path segment appended to ``conf.PREFIX_API_URL`` with a "/".
        dependencies: Optional sequence of router-level dependencies. It is
            never modified; a new list is built from it.
        auth_dependencies: When true, ``Security(get_current_active_user)``
            is appended so every route on the router goes through that
            authentication dependency.

    Returns:
        An ``APIRouter`` that uses ``BaseRoute`` as its route class.
    """
    # Copy instead of extending in place: mutating the caller's list would
    # add a duplicate auth dependency on every call and leak it into any
    # other router built from the same list (and fails outright for tuples).
    router_dependencies = list(dependencies or [])
    if auth_dependencies:
        router_dependencies.append(Security(get_current_active_user))

    return APIRouter(
        prefix=f"{conf.PREFIX_API_URL}/{prefix}",
        route_class=BaseRoute,
        dependencies=router_dependencies,
    )
|