~ | Merge pull request #1064 from Lin-jun-xiang/ignore_providers

feat: ignore providers
This commit is contained in:
Tekky 2023-10-13 11:33:44 +01:00 committed by GitHub
commit 5f0cadd0ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,13 +1,14 @@
from __future__ import annotations
from requests import get
from g4f.models import Model, ModelUtils
from .Provider import BaseProvider
from .typing import Messages, CreateResult, Union
from .Provider import BaseProvider, RetryProvider
from .typing import Messages, CreateResult, Union, List
from .debug import logging
version = '0.1.6.2'
version_check = True
def check_pypi_version() -> None:
try:
response = get("https://pypi.org/pypi/g4f/json").json()
@ -19,9 +20,11 @@ def check_pypi_version() -> None:
except Exception as e:
print(f'Failed to check g4f pypi version: {e}')
def get_model_and_provider(model : Union[Model, str],
provider : Union[type[BaseProvider], None],
stream : bool) -> tuple[Model, type[BaseProvider]]:
stream : bool,
ignored : List[str] = None) -> tuple[Model, type[BaseProvider]]:
if isinstance(model, str):
if model in ModelUtils.convert:
@ -32,6 +35,9 @@ def get_model_and_provider(model : Union[Model, str],
if not provider:
provider = model.best_provider
if isinstance(provider, RetryProvider) and ignored:
provider.providers = [p for p in provider.providers if p.__name__ not in ignored]
if not provider:
raise RuntimeError(f'No provider found for model: {model}')
@ -46,15 +52,17 @@ def get_model_and_provider(model : Union[Model, str],
return model, provider
class ChatCompletion:
@staticmethod
def create(model: Union[Model, str],
messages : Messages,
provider : Union[type[BaseProvider], None] = None,
stream : bool = False,
auth : Union[str, None] = None, **kwargs) -> Union[CreateResult, str]:
auth : Union[str, None] = None,
ignored : List[str] = None, **kwargs) -> Union[CreateResult, str]:
model, provider = get_model_and_provider(model, provider, stream)
model, provider = get_model_and_provider(model, provider, stream, ignored)
if provider.needs_auth and not auth:
raise ValueError(
@ -71,15 +79,17 @@ class ChatCompletion:
model : Union[Model, str],
messages: Messages,
provider: Union[type[BaseProvider], None] = None,
stream : bool = False, **kwargs) -> str:
stream : bool = False,
ignored : List[str] = None, **kwargs) -> str:
if stream:
raise ValueError(f'"create_async" does not support "stream" argument')
model, provider = get_model_and_provider(model, provider, False)
model, provider = get_model_and_provider(model, provider, False, ignored)
return await provider.create_async(model.name, messages, **kwargs)
class Completion:
@staticmethod
def create(
@ -87,6 +97,7 @@ class Completion:
prompt: str,
provider: Union[type[BaseProvider], None] = None,
stream: bool = False,
ignored : List[str] = None,
**kwargs
) -> Union[CreateResult, str]:
@ -102,7 +113,7 @@ class Completion:
if model not in allowed_models:
raise Exception(f'ValueError: Can\'t use {model} with Completion.create()')
model, provider = get_model_and_provider(model, provider, stream)
model, provider = get_model_and_provider(model, provider, stream, ignored)
result = provider.create_completion(model.name, [{"role": "user", "content": prompt}], stream, **kwargs)