mirror of https://github.com/xtekky/gpt4free.git
Add model preselection in gui
This commit is contained in:
parent
9c381f2906
commit
13f1275ca3
|
@ -12,7 +12,7 @@ from aiohttp import ClientSession, ClientTimeout, BaseConnector, WSMsgType
|
|||
from ..typing import AsyncResult, Messages, ImageType, Cookies
|
||||
from ..image import ImageRequest
|
||||
from ..errors import ResponseStatusError
|
||||
from .base_provider import AsyncGeneratorProvider
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from .helper import get_connector, get_random_hex
|
||||
from .bing.upload_image import upload_image
|
||||
from .bing.conversation import Conversation, create_conversation, delete_conversation
|
||||
|
@ -27,7 +27,7 @@ class Tones:
|
|||
balanced = "Balanced"
|
||||
precise = "Precise"
|
||||
|
||||
class Bing(AsyncGeneratorProvider):
|
||||
class Bing(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
"""
|
||||
Bing provider for generating responses using the Bing API.
|
||||
"""
|
||||
|
@ -35,16 +35,21 @@ class Bing(AsyncGeneratorProvider):
|
|||
working = True
|
||||
supports_message_history = True
|
||||
supports_gpt_4 = True
|
||||
default_model = Tones.balanced
|
||||
models = [
|
||||
getattr(Tones, key) for key in dir(Tones) if not key.startswith("__")
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def create_async_generator(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
proxy: str = None,
|
||||
timeout: int = 900,
|
||||
cookies: Cookies = None,
|
||||
connector: BaseConnector = None,
|
||||
tone: str = Tones.balanced,
|
||||
tone: str = None,
|
||||
image: ImageType = None,
|
||||
web_search: bool = False,
|
||||
**kwargs
|
||||
|
@ -62,13 +67,11 @@ class Bing(AsyncGeneratorProvider):
|
|||
:param web_search: Flag to enable or disable web search.
|
||||
:return: An asynchronous result object.
|
||||
"""
|
||||
if len(messages) < 2:
|
||||
prompt = messages[0]["content"]
|
||||
context = None
|
||||
else:
|
||||
prompt = messages[-1]["content"]
|
||||
context = create_context(messages[:-1])
|
||||
|
||||
prompt = messages[-1]["content"]
|
||||
context = create_context(messages[:-1]) if len(messages) > 1 else None
|
||||
if tone is None:
|
||||
tone = tone if model.startswith("gpt-4") else model
|
||||
tone = cls.get_model(tone)
|
||||
gpt4_turbo = True if model.startswith("gpt-4-turbo") else False
|
||||
|
||||
return stream_generate(
|
||||
|
@ -86,7 +89,9 @@ def create_context(messages: Messages) -> str:
|
|||
:return: A string representing the context created from the messages.
|
||||
"""
|
||||
return "".join(
|
||||
f"[{message['role']}]" + ("(#message)" if message['role'] != "system" else "(#additional_instructions)") + f"\n{message['content']}"
|
||||
f"[{message['role']}]" + ("(#message)"
|
||||
if message['role'] != "system"
|
||||
else "(#additional_instructions)") + f"\n{message['content']}"
|
||||
for message in messages
|
||||
) + "\n\n"
|
||||
|
||||
|
@ -403,7 +408,7 @@ async def stream_generate(
|
|||
do_read = False
|
||||
if response_txt.startswith(returned_text):
|
||||
new = response_txt[len(returned_text):]
|
||||
if new != "\n":
|
||||
if new not in ("", "\n"):
|
||||
yield new
|
||||
returned_text = response_txt
|
||||
if image_response:
|
||||
|
|
|
@ -106,6 +106,10 @@ body {
|
|||
border: 1px solid var(--blur-border);
|
||||
}
|
||||
|
||||
.hidden {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.conversations {
|
||||
max-width: 260px;
|
||||
padding: var(--section-gap);
|
||||
|
|
|
@ -162,6 +162,10 @@
|
|||
<option value="">----</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="field">
|
||||
<select name="model2" id="model2" class="hidden">
|
||||
</select>
|
||||
</div>
|
||||
<div class="field">
|
||||
<select name="jailbreak" id="jailbreak" style="display: none;">
|
||||
<option value="default" selected>Set Jailbreak</option>
|
||||
|
|
|
@ -12,7 +12,9 @@ const imageInput = document.getElementById("image");
|
|||
const cameraInput = document.getElementById("camera");
|
||||
const fileInput = document.getElementById("file");
|
||||
const inputCount = document.getElementById("input-count")
|
||||
const providerSelect = document.getElementById("provider");
|
||||
const modelSelect = document.getElementById("model");
|
||||
const modelProvider = document.getElementById("model2");
|
||||
const systemPrompt = document.getElementById("systemPrompt")
|
||||
|
||||
let prompt_lock = false;
|
||||
|
@ -44,17 +46,21 @@ const markdown_render = (content) => {
|
|||
}
|
||||
|
||||
let typesetPromise = Promise.resolve();
|
||||
let timeoutHighlightId;
|
||||
const highlight = (container) => {
|
||||
container.querySelectorAll('code:not(.hljs').forEach((el) => {
|
||||
if (el.className != "hljs") {
|
||||
hljs.highlightElement(el);
|
||||
}
|
||||
});
|
||||
typesetPromise = typesetPromise.then(
|
||||
() => MathJax.typesetPromise([container])
|
||||
).catch(
|
||||
(err) => console.log('Typeset failed: ' + err.message)
|
||||
);
|
||||
if (timeoutHighlightId) clearTimeout(timeoutHighlightId);
|
||||
timeoutHighlightId = setTimeout(() => {
|
||||
container.querySelectorAll('code:not(.hljs').forEach((el) => {
|
||||
if (el.className != "hljs") {
|
||||
hljs.highlightElement(el);
|
||||
}
|
||||
});
|
||||
typesetPromise = typesetPromise.then(
|
||||
() => MathJax.typesetPromise([container])
|
||||
).catch(
|
||||
(err) => console.log('Typeset failed: ' + err.message)
|
||||
);
|
||||
}, 100);
|
||||
}
|
||||
|
||||
const register_remove_message = async () => {
|
||||
|
@ -108,7 +114,6 @@ const handle_ask = async () => {
|
|||
if (input.files.length > 0) imageInput.dataset.src = URL.createObjectURL(input.files[0]);
|
||||
else delete imageInput.dataset.src
|
||||
|
||||
model = modelSelect.options[modelSelect.selectedIndex].value
|
||||
message_box.innerHTML += `
|
||||
<div class="message" data-index="${message_index}">
|
||||
<div class="user">
|
||||
|
@ -124,7 +129,7 @@ const handle_ask = async () => {
|
|||
: ''
|
||||
}
|
||||
</div>
|
||||
<div class="count">${count_words_and_tokens(message, model)}</div>
|
||||
<div class="count">${count_words_and_tokens(message, get_selected_model())}</div>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
|
@ -204,7 +209,6 @@ const ask_gpt = async () => {
|
|||
window.controller = new AbortController();
|
||||
|
||||
jailbreak = document.getElementById("jailbreak");
|
||||
provider = document.getElementById("provider");
|
||||
window.text = '';
|
||||
|
||||
stop_generating.classList.remove(`stop_generating-hidden`);
|
||||
|
@ -241,10 +245,10 @@ const ask_gpt = async () => {
|
|||
let body = JSON.stringify({
|
||||
id: window.token,
|
||||
conversation_id: window.conversation_id,
|
||||
model: modelSelect.options[modelSelect.selectedIndex].value,
|
||||
model: get_selected_model(),
|
||||
jailbreak: jailbreak.options[jailbreak.selectedIndex].value,
|
||||
web_search: document.getElementById(`switch`).checked,
|
||||
provider: provider.options[provider.selectedIndex].value,
|
||||
provider: providerSelect.options[providerSelect.selectedIndex].value,
|
||||
patch_provider: document.getElementById('patch')?.checked,
|
||||
messages: messages
|
||||
});
|
||||
|
@ -666,11 +670,13 @@ sidebar_button.addEventListener("click", (event) => {
|
|||
window.scrollTo(0, 0);
|
||||
});
|
||||
|
||||
const options = ["switch", "model", "model2", "jailbreak", "patch", "provider", "history"];
|
||||
|
||||
const register_settings_localstorage = async () => {
|
||||
for (id of ["switch", "model", "jailbreak", "patch", "provider", "history"]) {
|
||||
options.forEach((id) => {
|
||||
element = document.getElementById(id);
|
||||
if (!element) {
|
||||
continue;
|
||||
return;
|
||||
}
|
||||
element.addEventListener('change', async (event) => {
|
||||
switch (event.target.type) {
|
||||
|
@ -684,14 +690,14 @@ const register_settings_localstorage = async () => {
|
|||
console.warn("Unresolved element type");
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const load_settings_localstorage = async () => {
|
||||
for (id of ["switch", "model", "jailbreak", "patch", "provider", "history"]) {
|
||||
options.forEach((id) => {
|
||||
element = document.getElementById(id);
|
||||
if (!element || !(value = appStorage.getItem(element.id))) {
|
||||
continue;
|
||||
return;
|
||||
}
|
||||
if (value) {
|
||||
switch (element.type) {
|
||||
|
@ -705,7 +711,7 @@ const load_settings_localstorage = async () => {
|
|||
console.warn("Unresolved element type");
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const say_hello = async () => {
|
||||
|
@ -780,13 +786,16 @@ function count_words_and_tokens(text, model) {
|
|||
}
|
||||
|
||||
let countFocus = messageInput;
|
||||
let timeoutId;
|
||||
const count_input = async () => {
|
||||
if (countFocus.value) {
|
||||
model = modelSelect.options[modelSelect.selectedIndex].value;
|
||||
inputCount.innerText = count_words_and_tokens(countFocus.value, model);
|
||||
} else {
|
||||
inputCount.innerHTML = " "
|
||||
}
|
||||
if (timeoutId) clearTimeout(timeoutId);
|
||||
timeoutId = setTimeout(() => {
|
||||
if (countFocus.value) {
|
||||
inputCount.innerText = count_words_and_tokens(countFocus.value, get_selected_model());
|
||||
} else {
|
||||
inputCount.innerHTML = " "
|
||||
}
|
||||
}, 100);
|
||||
};
|
||||
messageInput.addEventListener("keyup", count_input);
|
||||
systemPrompt.addEventListener("keyup", count_input);
|
||||
|
@ -850,11 +859,13 @@ window.onload = async () => {
|
|||
providers = await response.json()
|
||||
select = document.getElementById('provider');
|
||||
|
||||
for (provider of providers) {
|
||||
providers.forEach((provider) => {
|
||||
let option = document.createElement('option');
|
||||
option.value = option.text = provider;
|
||||
select.appendChild(option);
|
||||
}
|
||||
})
|
||||
|
||||
await load_provider_models();
|
||||
|
||||
await load_settings_localstorage()
|
||||
})();
|
||||
|
@ -914,4 +925,33 @@ fileInput.addEventListener('change', async (event) => {
|
|||
|
||||
systemPrompt?.addEventListener("blur", async () => {
|
||||
await save_system_message();
|
||||
});
|
||||
});
|
||||
|
||||
function get_selected_model() {
|
||||
if (modelProvider.selectedIndex >= 0) {
|
||||
return modelProvider.options[modelProvider.selectedIndex].value;
|
||||
} else if (modelSelect.selectedIndex >= 0) {
|
||||
return modelSelect.options[modelSelect.selectedIndex].value;
|
||||
}
|
||||
}
|
||||
|
||||
async function load_provider_models() {
|
||||
provider = providerSelect.options[providerSelect.selectedIndex].value;
|
||||
response = await fetch('/backend-api/v2/models/' + provider);
|
||||
models = await response.json();
|
||||
if (models.length > 0) {
|
||||
modelSelect.classList.add("hidden");
|
||||
modelProvider.classList.remove("hidden");
|
||||
modelProvider.innerHTML = '';
|
||||
models.forEach((model) => {
|
||||
let option = document.createElement('option');
|
||||
option.value = option.text = model.model;
|
||||
option.selected = model.default;
|
||||
modelProvider.appendChild(option);
|
||||
});
|
||||
} else {
|
||||
modelProvider.classList.add("hidden");
|
||||
modelSelect.classList.remove("hidden");
|
||||
}
|
||||
};
|
||||
providerSelect.addEventListener("change", load_provider_models)
|
|
@ -6,10 +6,11 @@ from g4f import version, models
|
|||
from g4f import get_last_provider, ChatCompletion
|
||||
from g4f.image import is_allowed_extension, to_image
|
||||
from g4f.errors import VersionNotFoundError
|
||||
from g4f.Provider import __providers__
|
||||
from g4f.Provider import ProviderType, __providers__, __map__
|
||||
from g4f.providers.base_provider import ProviderModelMixin
|
||||
from g4f.Provider.bing.create_images import patch_provider
|
||||
|
||||
class Backend_Api:
|
||||
class Backend_Api:
|
||||
"""
|
||||
Handles various endpoints in a Flask application for backend operations.
|
||||
|
||||
|
@ -33,6 +34,10 @@ class Backend_Api:
|
|||
'function': self.get_models,
|
||||
'methods': ['GET']
|
||||
},
|
||||
'/backend-api/v2/models/<provider>': {
|
||||
'function': self.get_provider_models,
|
||||
'methods': ['GET']
|
||||
},
|
||||
'/backend-api/v2/providers': {
|
||||
'function': self.get_providers,
|
||||
'methods': ['GET']
|
||||
|
@ -75,7 +80,21 @@ class Backend_Api:
|
|||
List[str]: A list of model names.
|
||||
"""
|
||||
return models._all_models
|
||||
|
||||
|
||||
def get_provider_models(self, provider: str):
|
||||
if provider in __map__:
|
||||
provider: ProviderType = __map__[provider]
|
||||
if issubclass(provider, ProviderModelMixin):
|
||||
return [{"model": model, "default": model == provider.default_model} for model in provider.get_models()]
|
||||
elif provider.supports_gpt_35_turbo or provider.supports_gpt_4:
|
||||
return [
|
||||
*([{"model": "gpt-3.5-turbo", "default": not provider.supports_gpt_4}] if provider.supports_gpt_35_turbo else []),
|
||||
*([{"model": "gpt-4", "default": not provider.supports_gpt_4}] if provider.supports_gpt_4 else [])
|
||||
]
|
||||
else:
|
||||
return [];
|
||||
return 404, "Provider not found"
|
||||
|
||||
def get_providers(self):
|
||||
"""
|
||||
Return a list of all working providers.
|
||||
|
|
Loading…
Reference in New Issue