Merge pull request #1905 from kafmws/main

add authorization for g4f API
This commit is contained in:
H Lohaus 2024-04-28 22:47:42 +02:00 committed by GitHub
commit f47b7a2a9b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 6 deletions

View File

@ -3,11 +3,14 @@ from __future__ import annotations
import logging
import json
import uvicorn
import secrets
from fastapi import FastAPI, Response, Request
from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from fastapi.security import APIKeyHeader
from starlette.exceptions import HTTPException
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from typing import Union, Optional
@ -17,10 +20,11 @@ import g4f.debug
from g4f.client import AsyncClient
from g4f.typing import Messages
def create_app() -> FastAPI:
def create_app(g4f_api_key:str = None):
app = FastAPI()
api = Api(app)
api = Api(app, g4f_api_key=g4f_api_key)
api.register_routes()
api.register_authorization()
api.register_validation_exception_handler()
return app
@ -43,9 +47,32 @@ def set_list_ignored_providers(ignored: list[str]):
list_ignored_providers = ignored
class Api:
def __init__(self, app: FastAPI) -> None:
def __init__(self, app: FastAPI, g4f_api_key=None) -> None:
self.app = app
self.client = AsyncClient()
self.g4f_api_key = g4f_api_key
self.get_g4f_api_key = APIKeyHeader(name="g4f-api-key")
def register_authorization(self):
@self.app.middleware("http")
async def authorization(request: Request, call_next):
if self.g4f_api_key and request.url.path in ["/v1/chat/completions", "/v1/completions"]:
try:
user_g4f_api_key = await self.get_g4f_api_key(request)
except HTTPException as e:
if e.status_code == 403:
return JSONResponse(
status_code=HTTP_401_UNAUTHORIZED,
content=jsonable_encoder({"detail": "G4F API key required"}),
)
if not secrets.compare_digest(self.g4f_api_key, user_g4f_api_key):
return JSONResponse(
status_code=HTTP_403_FORBIDDEN,
content=jsonable_encoder({"detail": "Invalid G4F API key"}),
)
response = await call_next(request)
return response
def register_validation_exception_handler(self):
@self.app.exception_handler(RequestValidationError)
@ -153,7 +180,8 @@ def run_api(
bind: str = None,
debug: bool = False,
workers: int = None,
use_colors: bool = None
use_colors: bool = None,
g4f_api_key: str = None
) -> None:
print(f'Starting server... [g4f v-{g4f.version.utils.current_version}]' + (" (debug)" if debug else ""))
if use_colors is None:
@ -162,4 +190,4 @@ def run_api(
host, port = bind.split(":")
if debug:
g4f.debug.logging = True
uvicorn.run("g4f.api:create_app", host=host, port=int(port), workers=workers, use_colors=use_colors, factory=True)#
uvicorn.run(create_app(g4f_api_key), host=host, port=int(port), workers=workers, use_colors=use_colors)

View File

@ -16,6 +16,7 @@ def main():
api_parser.add_argument("--workers", type=int, default=None, help="Number of workers.")
api_parser.add_argument("--disable-colors", action="store_true", help="Don't use colors.")
api_parser.add_argument("--ignore-cookie-files", action="store_true", help="Don't read .har and cookie files.")
api_parser.add_argument("--g4f-api-key", type=str, default=None, help="Sets an authentication key for your API.")
api_parser.add_argument("--ignored-providers", nargs="+", choices=[provider for provider in Provider.__map__],
default=[], help="List of providers to ignore when processing request.")
subparsers.add_parser("gui", parents=[gui_parser()], add_help=False)
@ -42,6 +43,7 @@ def run_api_args(args):
bind=args.bind,
debug=args.debug,
workers=args.workers,
g4f_api_key=args.g4f_api_key,
use_colors=not args.disable_colors
)