feat: ignore providers(#1014)

This commit is contained in:
hs_junxiang 2023-10-13 13:45:29 +08:00
parent 5a64e238c8
commit c84ff59145

View File

@ -1,13 +1,14 @@
from __future__ import annotations from __future__ import annotations
from requests import get from requests import get
from g4f.models import Model, ModelUtils from g4f.models import Model, ModelUtils
from .Provider import BaseProvider from .Provider import BaseProvider, RetryProvider
from .typing import Messages, CreateResult, Union from .typing import Messages, CreateResult, Union, List
from .debug import logging from .debug import logging
version = '0.1.6.2' version = '0.1.6.2'
version_check = True version_check = True
def check_pypi_version() -> None: def check_pypi_version() -> None:
try: try:
response = get("https://pypi.org/pypi/g4f/json").json() response = get("https://pypi.org/pypi/g4f/json").json()
@ -19,9 +20,11 @@ def check_pypi_version() -> None:
except Exception as e: except Exception as e:
print(f'Failed to check g4f pypi version: {e}') print(f'Failed to check g4f pypi version: {e}')
def get_model_and_provider(model : Union[Model, str], def get_model_and_provider(model : Union[Model, str],
provider : Union[type[BaseProvider], None], provider : Union[type[BaseProvider], None],
stream : bool) -> tuple[Model, type[BaseProvider]]: stream : bool,
ignored : List[str] = None) -> tuple[Model, type[BaseProvider]]:
if isinstance(model, str): if isinstance(model, str):
if model in ModelUtils.convert: if model in ModelUtils.convert:
@ -32,6 +35,9 @@ def get_model_and_provider(model : Union[Model, str],
if not provider: if not provider:
provider = model.best_provider provider = model.best_provider
if isinstance(provider, RetryProvider) and ignored:
provider.providers = [p for p in provider.providers if p.__name__ not in ignored]
if not provider: if not provider:
raise RuntimeError(f'No provider found for model: {model}') raise RuntimeError(f'No provider found for model: {model}')
@ -46,15 +52,17 @@ def get_model_and_provider(model : Union[Model, str],
return model, provider return model, provider
class ChatCompletion: class ChatCompletion:
@staticmethod @staticmethod
def create(model: Union[Model, str], def create(model: Union[Model, str],
messages : Messages, messages : Messages,
provider : Union[type[BaseProvider], None] = None, provider : Union[type[BaseProvider], None] = None,
stream : bool = False, stream : bool = False,
auth : Union[str, None] = None, **kwargs) -> Union[CreateResult, str]: auth : Union[str, None] = None,
ignored : List[str] = None, **kwargs) -> Union[CreateResult, str]:
model, provider = get_model_and_provider(model, provider, stream) model, provider = get_model_and_provider(model, provider, stream, ignored)
if provider.needs_auth and not auth: if provider.needs_auth and not auth:
raise ValueError( raise ValueError(
@ -71,15 +79,17 @@ class ChatCompletion:
model : Union[Model, str], model : Union[Model, str],
messages: Messages, messages: Messages,
provider: Union[type[BaseProvider], None] = None, provider: Union[type[BaseProvider], None] = None,
stream : bool = False, **kwargs) -> str: stream : bool = False,
ignored : List[str] = None, **kwargs) -> str:
if stream: if stream:
raise ValueError(f'"create_async" does not support "stream" argument') raise ValueError(f'"create_async" does not support "stream" argument')
model, provider = get_model_and_provider(model, provider, False) model, provider = get_model_and_provider(model, provider, False, ignored)
return await provider.create_async(model.name, messages, **kwargs) return await provider.create_async(model.name, messages, **kwargs)
class Completion: class Completion:
@staticmethod @staticmethod
def create( def create(
@ -87,6 +97,7 @@ class Completion:
prompt: str, prompt: str,
provider: Union[type[BaseProvider], None] = None, provider: Union[type[BaseProvider], None] = None,
stream: bool = False, stream: bool = False,
ignored : List[str] = None,
**kwargs **kwargs
) -> Union[CreateResult, str]: ) -> Union[CreateResult, str]:
@ -102,7 +113,7 @@ class Completion:
if model not in allowed_models: if model not in allowed_models:
raise Exception(f'ValueError: Can\'t use {model} with Completion.create()') raise Exception(f'ValueError: Can\'t use {model} with Completion.create()')
model, provider = get_model_and_provider(model, provider, stream) model, provider = get_model_and_provider(model, provider, stream, ignored)
result = provider.create_completion(model.name, [{"role": "user", "content": prompt}], stream, **kwargs) result = provider.create_completion(model.name, [{"role": "user", "content": prompt}], stream, **kwargs)