feat: ignore providers(#1014)

This commit is contained in:
hs_junxiang 2023-10-13 13:45:29 +08:00
parent 5a64e238c8
commit c84ff59145
1 changed files with 19 additions and 8 deletions

View File

@ -1,13 +1,14 @@
from __future__ import annotations
from requests import get
from g4f.models import Model, ModelUtils
from .Provider import BaseProvider
from .typing import Messages, CreateResult, Union
from .Provider import BaseProvider, RetryProvider
from .typing import Messages, CreateResult, Union, List
from .debug import logging
version = '0.1.6.2'
version_check = True
def check_pypi_version() -> None:
try:
response = get("https://pypi.org/pypi/g4f/json").json()
@ -19,9 +20,11 @@ def check_pypi_version() -> None:
except Exception as e:
print(f'Failed to check g4f pypi version: {e}')
def get_model_and_provider(model : Union[Model, str],
provider : Union[type[BaseProvider], None],
stream : bool) -> tuple[Model, type[BaseProvider]]:
stream : bool,
ignored : List[str] = None) -> tuple[Model, type[BaseProvider]]:
if isinstance(model, str):
if model in ModelUtils.convert:
@ -32,6 +35,9 @@ def get_model_and_provider(model : Union[Model, str],
if not provider:
provider = model.best_provider
if isinstance(provider, RetryProvider) and ignored:
provider.providers = [p for p in provider.providers if p.__name__ not in ignored]
if not provider:
raise RuntimeError(f'No provider found for model: {model}')
@ -46,15 +52,17 @@ def get_model_and_provider(model : Union[Model, str],
return model, provider
class ChatCompletion:
@staticmethod
def create(model: Union[Model, str],
messages : Messages,
provider : Union[type[BaseProvider], None] = None,
stream : bool = False,
auth : Union[str, None] = None, **kwargs) -> Union[CreateResult, str]:
auth : Union[str, None] = None,
ignored : List[str] = None, **kwargs) -> Union[CreateResult, str]:
model, provider = get_model_and_provider(model, provider, stream)
model, provider = get_model_and_provider(model, provider, stream, ignored)
if provider.needs_auth and not auth:
raise ValueError(
@ -71,15 +79,17 @@ class ChatCompletion:
model : Union[Model, str],
messages: Messages,
provider: Union[type[BaseProvider], None] = None,
stream : bool = False, **kwargs) -> str:
stream : bool = False,
ignored : List[str] = None, **kwargs) -> str:
if stream:
raise ValueError(f'"create_async" does not support "stream" argument')
model, provider = get_model_and_provider(model, provider, False)
model, provider = get_model_and_provider(model, provider, False, ignored)
return await provider.create_async(model.name, messages, **kwargs)
class Completion:
@staticmethod
def create(
@ -87,6 +97,7 @@ class Completion:
prompt: str,
provider: Union[type[BaseProvider], None] = None,
stream: bool = False,
ignored : List[str] = None,
**kwargs
) -> Union[CreateResult, str]:
@ -102,7 +113,7 @@ class Completion:
if model not in allowed_models:
raise Exception(f'ValueError: Can\'t use {model} with Completion.create()')
model, provider = get_model_and_provider(model, provider, stream)
model, provider = get_model_and_provider(model, provider, stream, ignored)
result = provider.create_completion(model.name, [{"role": "user", "content": prompt}], stream, **kwargs)