Merge pull request #1925 from hlohaus/worker

Add Ollama provider, Add vision support to Openai
This commit is contained in:
H Lohaus 2024-05-05 23:51:40 +02:00 committed by GitHub
commit 1d02a06456
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 72 additions and 32 deletions

View File

@ -457,10 +457,13 @@ async def stream_generate(
returned_text = '' returned_text = ''
message_id = None message_id = None
while do_read: while do_read:
msg = await wss.receive_str() try:
msg = await wss.receive_str()
except TypeError:
continue
objects = msg.split(Defaults.delimiter) objects = msg.split(Defaults.delimiter)
for obj in objects: for obj in objects:
if obj is None or not obj: if not obj:
continue continue
try: try:
response = json.loads(obj) response = json.loads(obj)

View File

@ -1,8 +1,7 @@
from __future__ import annotations from __future__ import annotations
import requests import requests
from ..typing import AsyncResult, Messages, ImageType from ..typing import AsyncResult, Messages
from ..image import to_data_uri
from .needs_auth.Openai import Openai from .needs_auth.Openai import Openai
class DeepInfra(Openai): class DeepInfra(Openai):
@ -33,7 +32,6 @@ class DeepInfra(Openai):
model: str, model: str,
messages: Messages, messages: Messages,
stream: bool, stream: bool,
image: ImageType = None,
api_base: str = "https://api.deepinfra.com/v1/openai", api_base: str = "https://api.deepinfra.com/v1/openai",
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: int = 1028, max_tokens: int = 1028,
@ -54,19 +52,6 @@ class DeepInfra(Openai):
'sec-ch-ua-mobile': '?0', 'sec-ch-ua-mobile': '?0',
'sec-ch-ua-platform': '"macOS"', 'sec-ch-ua-platform': '"macOS"',
} }
if image is not None:
if not model:
model = cls.default_vision_model
messages[-1]["content"] = [
{
"type": "image_url",
"image_url": {"url": to_data_uri(image)}
},
{
"type": "text",
"text": messages[-1]["content"]
}
]
return super().create_async_generator( return super().create_async_generator(
model, messages, model, messages,
stream=stream, stream=stream,

33
g4f/Provider/Ollama.py Normal file
View File

@ -0,0 +1,33 @@
from __future__ import annotations
import requests
from .needs_auth.Openai import Openai
from ..typing import AsyncResult, Messages
class Ollama(Openai):
    """Provider for a locally running Ollama server.

    Ollama exposes an OpenAI-compatible chat-completions API under ``/v1``,
    so this class only supplies model discovery and the local default
    ``api_base``; all request handling is inherited from ``Openai``.
    """
    label = "Ollama"
    url = "https://ollama.com"
    needs_auth = False  # local daemon, no API key required
    working = True

    @classmethod
    def get_models(cls):
        """Return the model names installed on the local Ollama instance.

        Fetched once from the native Ollama tags endpoint and cached on the
        class; the first entry becomes ``cls.default_model``.

        Raises:
            requests.RequestException: if the local server is unreachable.
            ValueError: if the server reports no installed models.
        """
        if not cls.models:
            url = 'http://127.0.0.1:11434/api/tags'
            # Timeout so a missing/unresponsive local daemon fails fast
            # instead of hanging the caller indefinitely.
            models = requests.get(url, timeout=10).json()["models"]
            cls.models = [model['name'] for model in models]
            if not cls.models:
                # No models pulled yet: raise a clear error rather than an
                # IndexError from the default-model assignment below.
                raise ValueError("No models found on the Ollama server")
            cls.default_model = cls.models[0]
        return cls.models

    @classmethod
    def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        api_base: str = "http://localhost:11434/v1",
        **kwargs
    ) -> AsyncResult:
        """Delegate to the OpenAI-compatible implementation, pointing it at
        the local Ollama endpoint by default."""
        return super().create_async_generator(
            model, messages, api_base=api_base, **kwargs
        )

View File

@ -43,6 +43,7 @@ from .Llama import Llama
from .Local import Local from .Local import Local
from .MetaAI import MetaAI from .MetaAI import MetaAI
from .MetaAIAccount import MetaAIAccount from .MetaAIAccount import MetaAIAccount
from .Ollama import Ollama
from .PerplexityLabs import PerplexityLabs from .PerplexityLabs import PerplexityLabs
from .Pi import Pi from .Pi import Pi
from .Replicate import Replicate from .Replicate import Replicate

View File

@ -4,9 +4,10 @@ import json
from ..helper import filter_none from ..helper import filter_none
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
from ...typing import Union, Optional, AsyncResult, Messages from ...typing import Union, Optional, AsyncResult, Messages, ImageType
from ...requests import StreamSession, raise_for_status from ...requests import StreamSession, raise_for_status
from ...errors import MissingAuthError, ResponseError from ...errors import MissingAuthError, ResponseError
from ...image import to_data_uri
class Openai(AsyncGeneratorProvider, ProviderModelMixin): class Openai(AsyncGeneratorProvider, ProviderModelMixin):
label = "OpenAI API" label = "OpenAI API"
@ -23,6 +24,7 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin):
messages: Messages, messages: Messages,
proxy: str = None, proxy: str = None,
timeout: int = 120, timeout: int = 120,
image: ImageType = None,
api_key: str = None, api_key: str = None,
api_base: str = "https://api.openai.com/v1", api_base: str = "https://api.openai.com/v1",
temperature: float = None, temperature: float = None,
@ -36,6 +38,19 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin):
) -> AsyncResult: ) -> AsyncResult:
if cls.needs_auth and api_key is None: if cls.needs_auth and api_key is None:
raise MissingAuthError('Add a "api_key"') raise MissingAuthError('Add a "api_key"')
if image is not None:
if not model and hasattr(cls, "default_vision_model"):
model = cls.default_vision_model
messages[-1]["content"] = [
{
"type": "image_url",
"image_url": {"url": to_data_uri(image)}
},
{
"type": "text",
"text": messages[-1]["content"]
}
]
async with StreamSession( async with StreamSession(
proxies={"all": proxy}, proxies={"all": proxy},
headers=cls.get_headers(stream, api_key, headers), headers=cls.get_headers(stream, api_key, headers),
@ -51,7 +66,6 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin):
stream=stream, stream=stream,
**extra_data **extra_data
) )
async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response: async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
await raise_for_status(response) await raise_for_status(response)
if not stream: if not stream:
@ -103,8 +117,7 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin):
"Content-Type": "application/json", "Content-Type": "application/json",
**( **(
{"Authorization": f"Bearer {api_key}"} {"Authorization": f"Bearer {api_key}"}
if cls.needs_auth and api_key is not None if api_key is not None else {}
else {}
), ),
**({} if headers is None else headers) **({} if headers is None else headers)
} }

View File

@ -201,7 +201,7 @@ def run_api(
if bind is not None: if bind is not None:
host, port = bind.split(":") host, port = bind.split(":")
uvicorn.run( uvicorn.run(
f"g4f.api:{'create_app_debug' if debug else 'create_app'}", f"g4f.api:create_app{'_debug' if debug else ''}",
host=host, port=int(port), host=host, port=int(port),
workers=workers, workers=workers,
use_colors=use_colors, use_colors=use_colors,

View File

@ -11,6 +11,10 @@ def main():
api_parser = subparsers.add_parser("api") api_parser = subparsers.add_parser("api")
api_parser.add_argument("--bind", default="0.0.0.0:1337", help="The bind string.") api_parser.add_argument("--bind", default="0.0.0.0:1337", help="The bind string.")
api_parser.add_argument("--debug", action="store_true", help="Enable verbose logging.") api_parser.add_argument("--debug", action="store_true", help="Enable verbose logging.")
api_parser.add_argument("--model", default=None, help="Default model for chat completion. (incompatible with --debug and --workers)")
api_parser.add_argument("--provider", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working],
default=None, help="Default provider for chat completion. (incompatible with --debug and --workers)")
api_parser.add_argument("--proxy", default=None, help="Default used proxy.")
api_parser.add_argument("--workers", type=int, default=None, help="Number of workers.") api_parser.add_argument("--workers", type=int, default=None, help="Number of workers.")
api_parser.add_argument("--disable-colors", action="store_true", help="Don't use colors.") api_parser.add_argument("--disable-colors", action="store_true", help="Don't use colors.")
api_parser.add_argument("--ignore-cookie-files", action="store_true", help="Don't read .har and cookie files.") api_parser.add_argument("--ignore-cookie-files", action="store_true", help="Don't read .har and cookie files.")
@ -31,14 +35,15 @@ def main():
def run_api_args(args): def run_api_args(args):
from g4f.api import AppConfig, run_api from g4f.api import AppConfig, run_api
AppConfig.set_ignore_cookie_files( AppConfig.set_config(
args.ignore_cookie_files ignore_cookie_files=args.ignore_cookie_files,
) ignored_providers=args.ignored_providers,
AppConfig.set_list_ignored_providers( g4f_api_key=args.g4f_api_key,
args.ignored_providers defaults={
) "model": args.model,
AppConfig.set_g4f_api_key( "provider": args.provider,
args.g4f_api_key "proxy": args.proxy
}
) )
run_api( run_api(
bind=args.bind, bind=args.bind,

View File

@ -40,7 +40,7 @@ async def get_args_from_webview(url: str) -> dict:
"Referer": window.real_url "Referer": window.real_url
} }
cookies = [list(*cookie.items()) for cookie in window.get_cookies()] cookies = [list(*cookie.items()) for cookie in window.get_cookies()]
cookies = dict([(name, cookie.value) for name, cookie in cookies]) cookies = {name: cookie.value for name, cookie in cookies}
window.destroy() window.destroy()
return {"headers": headers, "cookies": cookies} return {"headers": headers, "cookies": cookies}