diff --git a/g4f/__init__.py b/g4f/__init__.py index a2eec9e2..8a1cb3cd 100644 --- a/g4f/__init__.py +++ b/g4f/__init__.py @@ -115,4 +115,4 @@ class Completion: return result if stream else ''.join(result) if version_check: - check_pypi_version() \ No newline at end of file + check_pypi_version() diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index fec5606f..43bca2a5 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -1,163 +1,137 @@ -import typing -from .. import BaseProvider -import g4f; g4f.debug.logging = True +from fastapi import FastAPI, Response, Request +from typing import List, Union, Any, Dict, AnyStr +from ._tokenizer import tokenize +from .. import BaseProvider + import time import json import random import string -import logging - -from typing import Union -from loguru import logger -from waitress import serve -from ._logging import hook_logging -from ._tokenizer import tokenize -from flask_cors import CORS -from werkzeug.serving import WSGIRequestHandler -from werkzeug.exceptions import default_exceptions -from werkzeug.middleware.proxy_fix import ProxyFix - -from flask import ( - Flask, - jsonify, - make_response, - request, -) +import uvicorn +import nest_asyncio +import g4f class Api: - __default_ip = '127.0.0.1' - __default_port = 1337 - def __init__(self, engine: g4f, debug: bool = True, sentry: bool = False, - list_ignored_providers:typing.List[typing.Union[str, BaseProvider]]=None) -> None: - self.engine = engine - self.debug = debug - self.sentry = sentry - self.list_ignored_providers = list_ignored_providers - self.log_level = logging.DEBUG if debug else logging.WARN - - hook_logging(level=self.log_level, format='[%(asctime)s] %(levelname)s in %(module)s: %(message)s') - self.logger = logging.getLogger('waitress') - - self.app = Flask(__name__) - self.app.wsgi_app = ProxyFix(self.app.wsgi_app, x_port=1) - self.app.after_request(self.__after_request) - - def run(self, bind_str, threads=8): - host, port = self.__parse_bind(bind_str) + list_ignored_providers: List[Union[str, BaseProvider]] = None) -> None: + self.engine = engine + self.debug = debug + self.sentry = sentry + self.list_ignored_providers = list_ignored_providers - CORS(self.app, resources={r'/v1/*': {'supports_credentials': True, 'expose_headers': [ - 'Content-Type', - 'Authorization', - 'X-Requested-With', - 'Accept', - 'Origin', - 'Access-Control-Request-Method', - 'Access-Control-Request-Headers', - 'Content-Disposition'], 'max_age': 600}}) + self.app = FastAPI() + nest_asyncio.apply() - self.app.route('/v1/models', methods=['GET'])(self.models) - self.app.route('/v1/models/', methods=['GET'])(self.model_info) + JSONObject = Dict[AnyStr, Any] + JSONArray = List[Any] + JSONStructure = Union[JSONArray, JSONObject] - self.app.route('/v1/chat/completions', methods=['POST'])(self.chat_completions) - self.app.route('/v1/completions', methods=['POST'])(self.completions) + @self.app.get("/") + async def read_root(): + return Response(content=json.dumps({"info": "g4f API"}, indent=4), media_type="application/json") - for ex in default_exceptions: - self.app.register_error_handler(ex, self.__handle_error) + @self.app.get("/v1") + async def read_root_v1(): + return Response(content=json.dumps({"info": "Go to /v1/chat/completions or /v1/models."}, indent=4), media_type="application/json") - if not self.debug: - self.logger.warning(f'Serving on http://{host}:{port}') + @self.app.get("/v1/models") + async def models(): + model_list = [{ + 'id': model, + 'object': 'model', + 'created': 0, + 'owned_by': 'g4f'} for model in g4f.Model.__all__()] - WSGIRequestHandler.protocol_version = 'HTTP/1.1' - serve(self.app, host=host, port=port, ident=None, threads=threads) - - def __handle_error(self, e: Exception): - self.logger.error(e) + return Response(content=json.dumps({ + 'object': 'list', + 'data': model_list}, indent=4), media_type="application/json") - return make_response(jsonify({ - 'code': e.code, - 'message': str(e.original_exception if self.debug and hasattr(e, 'original_exception') else e.name)}), 500) - - @staticmethod - def __after_request(resp): - resp.headers['X-Server'] = f'g4f/{g4f.version}' - - return resp - - def __parse_bind(self, bind_str): - sections = bind_str.split(':', 2) - if len(sections) < 2: + @self.app.get("/v1/models/{model_name}") + async def model_info(model_name: str): try: - port = int(sections[0]) - return self.__default_ip, port - except ValueError: - return sections[0], self.__default_port + model_info = (g4f.ModelUtils.convert[model_name]) - return sections[0], int(sections[1]) - - async def home(self): - return 'Hello world | https://127.0.0.1:1337/v1' - - async def chat_completions(self): - model = request.json.get('model', 'gpt-3.5-turbo') - stream = request.json.get('stream', False) - messages = request.json.get('messages') - - logger.info(f'model: {model}, stream: {stream}, request: {messages[-1]["content"]}') + return Response(content=json.dumps({ + 'id': model_name, + 'object': 'model', + 'created': 0, + 'owned_by': model_info.base_provider + }, indent=4), media_type="application/json") + except: + return Response(content=json.dumps({"error": "The model does not exist."}, indent=4), media_type="application/json") - config = None - proxy = None - - try: - config = json.load(open("config.json","r",encoding="utf-8")) - proxy = config["proxy"] - - except Exception: - pass - - if proxy != None: - response = self.engine.ChatCompletion.create(model=model, - stream=stream, messages=messages, - ignored=self.list_ignored_providers, - proxy=proxy) - else: - response = self.engine.ChatCompletion.create(model=model, - stream=stream, messages=messages, - ignored=self.list_ignored_providers) - - completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28)) - completion_timestamp = int(time.time()) - - if not stream: - prompt_tokens, _ = tokenize(''.join([message['content'] for message in messages])) - completion_tokens, _ = tokenize(response) - - return { - 'id': f'chatcmpl-{completion_id}', - 'object': 'chat.completion', - 'created': completion_timestamp, - 'model': model, - 'choices': [ - { - 'index': 0, - 'message': { - 'role': 'assistant', - 'content': response, - }, - 'finish_reason': 'stop', - } - ], - 'usage': { - 'prompt_tokens': prompt_tokens, - 'completion_tokens': completion_tokens, - 'total_tokens': prompt_tokens + completion_tokens, - }, + @self.app.post("/v1/chat/completions") + async def chat_completions(request: Request, item: JSONStructure = None): + item_data = { + 'model': 'gpt-3.5-turbo', + 'stream': False, } - def streaming(): + item_data.update(item or {}) + model = item_data.get('model') + stream = item_data.get('stream') + messages = item_data.get('messages') + try: - for chunk in response: - completion_data = { + response = g4f.ChatCompletion.create(model=model, stream=stream, messages=messages) + except: + return Response(content=json.dumps({"error": "An error occurred while generating the response."}, indent=4), media_type="application/json") + + completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28)) + completion_timestamp = int(time.time()) + + if not stream: + prompt_tokens, _ = tokenize(''.join([message['content'] for message in messages])) + completion_tokens, _ = tokenize(response) + + json_data = { + 'id': f'chatcmpl-{completion_id}', + 'object': 'chat.completion', + 'created': completion_timestamp, + 'model': model, + 'choices': [ + { + 'index': 0, + 'message': { + 'role': 'assistant', + 'content': response, + }, + 'finish_reason': 'stop', + } + ], + 'usage': { + 'prompt_tokens': prompt_tokens, + 'completion_tokens': completion_tokens, + 'total_tokens': prompt_tokens + completion_tokens, + }, + } + + return Response(content=json.dumps(json_data, indent=4), media_type="application/json") + + def streaming(): + try: + for chunk in response: + completion_data = { + 'id': f'chatcmpl-{completion_id}', + 'object': 'chat.completion.chunk', + 'created': completion_timestamp, + 'model': model, + 'choices': [ + { + 'index': 0, + 'delta': { + 'content': chunk, + }, + 'finish_reason': None, + } + ], + } + + content = json.dumps(completion_data, separators=(',', ':')) + yield f'data: {content}\n\n' + time.sleep(0.03) + + end_completion_data = { 'id': f'chatcmpl-{completion_id}', 'object': 'chat.completion.chunk', 'created': completion_timestamp, @@ -165,63 +139,24 @@ class Api: 'choices': [ { 'index': 0, - 'delta': { - 'content': chunk, - }, - 'finish_reason': None, + 'delta': {}, + 'finish_reason': 'stop', } ], } - content = json.dumps(completion_data, separators=(',', ':')) + content = json.dumps(end_completion_data, separators=(',', ':')) yield f'data: {content}\n\n' - time.sleep(0.03) - end_completion_data = { - 'id': f'chatcmpl-{completion_id}', - 'object': 'chat.completion.chunk', - 'created': completion_timestamp, - 'model': model, - 'choices': [ - { - 'index': 0, - 'delta': {}, - 'finish_reason': 'stop', - } - ], - } - - content = json.dumps(end_completion_data, separators=(',', ':')) - yield f'data: {content}\n\n' - - logger.success(f'model: {model}, stream: {stream}') - - except GeneratorExit: - pass + except GeneratorExit: + pass - return self.app.response_class(streaming(), mimetype='text/event-stream') - - async def completions(self): - return 'not working yet', 500 - - async def model_info(self, model_name): - model_info = (g4f.ModelUtils.convert[model_name]) - - return jsonify({ - 'id' : model_name, - 'object' : 'model', - 'created' : 0, - 'owned_by' : model_info.base_provider - }) - - async def models(self): - model_list = [{ - 'id' : model, - 'object' : 'model', - 'created' : 0, - 'owned_by' : 'g4f'} for model in g4f.Model.__all__()] - - return jsonify({ - 'object': 'list', - 'data': model_list}) - \ No newline at end of file + return Response(content=json.dumps(streaming(), indent=4), media_type="application/json") + + @self.app.post("/v1/completions") + async def completions(): + return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json") + + def run(self, ip): + split_ip = ip.split(":") + uvicorn.run(app=self.app, host=split_ip[0], port=int(split_ip[1]), use_colors=False) diff --git a/g4f/api/run.py b/g4f/api/run.py index 12bf9eed..88f34741 100644 --- a/g4f/api/run.py +++ b/g4f/api/run.py @@ -3,4 +3,4 @@ import g4f.api if __name__ == "__main__": print(f'Starting server... [g4f v-{g4f.version}]') - g4f.api.Api(g4f).run('127.0.0.1:1337', 8) \ No newline at end of file + g4f.api.Api(engine = g4f, debug = True).run(ip = "127.0.0.1:1337") diff --git a/g4f/cli.py b/g4f/cli.py index cb19dde1..20131e5d 100644 --- a/g4f/cli.py +++ b/g4f/cli.py @@ -7,11 +7,9 @@ from g4f import Provider from g4f.api import Api from g4f.gui.run import gui_parser, run_gui_args - def run_gui(args): print("Running GUI...") - def main(): IgnoredProviders = Enum("ignore_providers", {key: key for key in Provider.__all__}) parser = argparse.ArgumentParser(description="Run gpt4free") @@ -19,22 +17,19 @@ def main(): api_parser=subparsers.add_parser("api") api_parser.add_argument("--bind", default="127.0.0.1:1337", help="The bind string.") api_parser.add_argument("--debug", type=bool, default=False, help="Enable verbose logging") - api_parser.add_argument("--num-threads", type=int, default=8, help="The number of threads.") api_parser.add_argument("--ignored-providers", nargs="+", choices=[provider.name for provider in IgnoredProviders], default=[], help="List of providers to ignore when processing request.") subparsers.add_parser("gui", parents=[gui_parser()], add_help=False) args = parser.parse_args() if args.mode == "api": - controller=Api(g4f, debug=args.debug) - controller.list_ignored_providers=args.ignored_providers - controller.run(args.bind, args.num_threads) + controller=Api(engine=g4f, debug=args.debug, list_ignored_providers=args.ignored_providers) + controller.run(args.bind) elif args.mode == "gui": run_gui_args(args) else: parser.print_help() exit(1) - if __name__ == "__main__": main() diff --git a/requirements.txt b/requirements.txt index 3ef9b32e..ffadf62a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,8 +6,6 @@ certifi browser_cookie3 websockets js2py -flask[async] -flask-cors typing-extensions PyExecJS duckduckgo-search @@ -20,3 +18,5 @@ pillow platformdirs numpy asgiref +fastapi +uvicorn