Add model preselection in GUI

This commit is contained in:
Heiner Lohaus 2024-03-13 17:52:48 +01:00
parent 9c381f2906
commit 13f1275ca3
5 changed files with 118 additions and 46 deletions

View File

@ -12,7 +12,7 @@ from aiohttp import ClientSession, ClientTimeout, BaseConnector, WSMsgType
from ..typing import AsyncResult, Messages, ImageType, Cookies
from ..image import ImageRequest
from ..errors import ResponseStatusError
from .base_provider import AsyncGeneratorProvider
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import get_connector, get_random_hex
from .bing.upload_image import upload_image
from .bing.conversation import Conversation, create_conversation, delete_conversation
@ -27,7 +27,7 @@ class Tones:
balanced = "Balanced"
precise = "Precise"
class Bing(AsyncGeneratorProvider):
class Bing(AsyncGeneratorProvider, ProviderModelMixin):
"""
Bing provider for generating responses using the Bing API.
"""
@ -35,16 +35,21 @@ class Bing(AsyncGeneratorProvider):
working = True
supports_message_history = True
supports_gpt_4 = True
default_model = Tones.balanced
models = [
getattr(Tones, key) for key in dir(Tones) if not key.startswith("__")
]
@staticmethod
@classmethod
def create_async_generator(
cls,
model: str,
messages: Messages,
proxy: str = None,
timeout: int = 900,
cookies: Cookies = None,
connector: BaseConnector = None,
tone: str = Tones.balanced,
tone: str = None,
image: ImageType = None,
web_search: bool = False,
**kwargs
@ -62,13 +67,11 @@ class Bing(AsyncGeneratorProvider):
:param web_search: Flag to enable or disable web search.
:return: An asynchronous result object.
"""
if len(messages) < 2:
prompt = messages[0]["content"]
context = None
else:
prompt = messages[-1]["content"]
context = create_context(messages[:-1])
prompt = messages[-1]["content"]
context = create_context(messages[:-1]) if len(messages) > 1 else None
if tone is None:
tone = tone if model.startswith("gpt-4") else model
tone = cls.get_model(tone)
gpt4_turbo = True if model.startswith("gpt-4-turbo") else False
return stream_generate(
@ -86,7 +89,9 @@ def create_context(messages: Messages) -> str:
:return: A string representing the context created from the messages.
"""
return "".join(
f"[{message['role']}]" + ("(#message)" if message['role'] != "system" else "(#additional_instructions)") + f"\n{message['content']}"
f"[{message['role']}]" + ("(#message)"
if message['role'] != "system"
else "(#additional_instructions)") + f"\n{message['content']}"
for message in messages
) + "\n\n"
@ -403,7 +408,7 @@ async def stream_generate(
do_read = False
if response_txt.startswith(returned_text):
new = response_txt[len(returned_text):]
if new != "\n":
if new not in ("", "\n"):
yield new
returned_text = response_txt
if image_response:

View File

@ -106,6 +106,10 @@ body {
border: 1px solid var(--blur-border);
}
/* Utility class: removes an element from layout entirely.
   Toggled from JS (classList.add/remove) to swap the generic and
   provider-specific model selects. */
.hidden {
display: none;
}
.conversations {
max-width: 260px;
padding: var(--section-gap);

View File

@ -162,6 +162,10 @@
<option value="">----</option>
</select>
</div>
<!-- Provider-specific model list. Starts hidden; load_provider_models()
     populates it and swaps its visibility with the generic #model select
     depending on the chosen provider. -->
<div class="field">
<select name="model2" id="model2" class="hidden">
</select>
</div>
<div class="field">
<select name="jailbreak" id="jailbreak" style="display: none;">
<option value="default" selected>Set Jailbreak</option>

View File

@ -12,7 +12,9 @@ const imageInput = document.getElementById("image");
const cameraInput = document.getElementById("camera");
const fileInput = document.getElementById("file");
const inputCount = document.getElementById("input-count")
const providerSelect = document.getElementById("provider");
const modelSelect = document.getElementById("model");
const modelProvider = document.getElementById("model2");
const systemPrompt = document.getElementById("systemPrompt")
let prompt_lock = false;
@ -44,17 +46,21 @@ const markdown_render = (content) => {
}
let typesetPromise = Promise.resolve();
let timeoutHighlightId;
const highlight = (container) => {
container.querySelectorAll('code:not(.hljs').forEach((el) => {
if (el.className != "hljs") {
hljs.highlightElement(el);
}
});
typesetPromise = typesetPromise.then(
() => MathJax.typesetPromise([container])
).catch(
(err) => console.log('Typeset failed: ' + err.message)
);
if (timeoutHighlightId) clearTimeout(timeoutHighlightId);
timeoutHighlightId = setTimeout(() => {
container.querySelectorAll('code:not(.hljs').forEach((el) => {
if (el.className != "hljs") {
hljs.highlightElement(el);
}
});
typesetPromise = typesetPromise.then(
() => MathJax.typesetPromise([container])
).catch(
(err) => console.log('Typeset failed: ' + err.message)
);
}, 100);
}
const register_remove_message = async () => {
@ -108,7 +114,6 @@ const handle_ask = async () => {
if (input.files.length > 0) imageInput.dataset.src = URL.createObjectURL(input.files[0]);
else delete imageInput.dataset.src
model = modelSelect.options[modelSelect.selectedIndex].value
message_box.innerHTML += `
<div class="message" data-index="${message_index}">
<div class="user">
@ -124,7 +129,7 @@ const handle_ask = async () => {
: ''
}
</div>
<div class="count">${count_words_and_tokens(message, model)}</div>
<div class="count">${count_words_and_tokens(message, get_selected_model())}</div>
</div>
</div>
`;
@ -204,7 +209,6 @@ const ask_gpt = async () => {
window.controller = new AbortController();
jailbreak = document.getElementById("jailbreak");
provider = document.getElementById("provider");
window.text = '';
stop_generating.classList.remove(`stop_generating-hidden`);
@ -241,10 +245,10 @@ const ask_gpt = async () => {
let body = JSON.stringify({
id: window.token,
conversation_id: window.conversation_id,
model: modelSelect.options[modelSelect.selectedIndex].value,
model: get_selected_model(),
jailbreak: jailbreak.options[jailbreak.selectedIndex].value,
web_search: document.getElementById(`switch`).checked,
provider: provider.options[provider.selectedIndex].value,
provider: providerSelect.options[providerSelect.selectedIndex].value,
patch_provider: document.getElementById('patch')?.checked,
messages: messages
});
@ -666,11 +670,13 @@ sidebar_button.addEventListener("click", (event) => {
window.scrollTo(0, 0);
});
const options = ["switch", "model", "model2", "jailbreak", "patch", "provider", "history"];
const register_settings_localstorage = async () => {
for (id of ["switch", "model", "jailbreak", "patch", "provider", "history"]) {
options.forEach((id) => {
element = document.getElementById(id);
if (!element) {
continue;
return;
}
element.addEventListener('change', async (event) => {
switch (event.target.type) {
@ -684,14 +690,14 @@ const register_settings_localstorage = async () => {
console.warn("Unresolved element type");
}
});
}
});
}
const load_settings_localstorage = async () => {
for (id of ["switch", "model", "jailbreak", "patch", "provider", "history"]) {
options.forEach((id) => {
element = document.getElementById(id);
if (!element || !(value = appStorage.getItem(element.id))) {
continue;
return;
}
if (value) {
switch (element.type) {
@ -705,7 +711,7 @@ const load_settings_localstorage = async () => {
console.warn("Unresolved element type");
}
}
}
});
}
const say_hello = async () => {
@ -780,13 +786,16 @@ function count_words_and_tokens(text, model) {
}
let countFocus = messageInput;
let timeoutId;
const count_input = async () => {
if (countFocus.value) {
model = modelSelect.options[modelSelect.selectedIndex].value;
inputCount.innerText = count_words_and_tokens(countFocus.value, model);
} else {
inputCount.innerHTML = "&nbsp;"
}
if (timeoutId) clearTimeout(timeoutId);
timeoutId = setTimeout(() => {
if (countFocus.value) {
inputCount.innerText = count_words_and_tokens(countFocus.value, get_selected_model());
} else {
inputCount.innerHTML = "&nbsp;"
}
}, 100);
};
messageInput.addEventListener("keyup", count_input);
systemPrompt.addEventListener("keyup", count_input);
@ -850,11 +859,13 @@ window.onload = async () => {
providers = await response.json()
select = document.getElementById('provider');
for (provider of providers) {
providers.forEach((provider) => {
let option = document.createElement('option');
option.value = option.text = provider;
select.appendChild(option);
}
})
await load_provider_models();
await load_settings_localstorage()
})();
@ -914,4 +925,33 @@ fileInput.addEventListener('change', async (event) => {
systemPrompt?.addEventListener("blur", async () => {
await save_system_message();
});
});
function get_selected_model() {
    // Returns the model name chosen by the user, or undefined when nothing
    // is selected in either dropdown.
    //
    // Prefer the provider-specific list (#model2) only while it is actually
    // shown: load_provider_models() empties it only in the "has models"
    // branch, so a hidden #model2 can still hold stale options from a
    // previously selected provider, and its selectedIndex would otherwise
    // silently override the visible generic #model selection.
    if (!modelProvider.classList.contains("hidden") && modelProvider.selectedIndex >= 0) {
        return modelProvider.options[modelProvider.selectedIndex].value;
    } else if (modelSelect.selectedIndex >= 0) {
        return modelSelect.options[modelSelect.selectedIndex].value;
    }
}
async function load_provider_models() {
    // Fetch the model list for the currently selected provider and show the
    // provider-specific select (#model2) when the provider declares models;
    // otherwise fall back to the generic #model select.
    const provider = providerSelect.options[providerSelect.selectedIndex].value;
    if (!provider) {
        // "----" placeholder selected: don't fetch a bogus URL, just restore
        // the generic model list.
        modelProvider.classList.add("hidden");
        modelSelect.classList.remove("hidden");
        return;
    }
    const response = await fetch('/backend-api/v2/models/' + provider);
    const models = await response.json();
    if (models.length > 0) {
        modelSelect.classList.add("hidden");
        modelProvider.classList.remove("hidden");
        // Rebuild the option list from scratch for this provider.
        modelProvider.innerHTML = '';
        models.forEach((model) => {
            let option = document.createElement('option');
            option.value = option.text = model.model;
            // Backend marks the provider's default model; preselect it.
            option.selected = model.default;
            modelProvider.appendChild(option);
        });
    } else {
        modelProvider.classList.add("hidden");
        modelSelect.classList.remove("hidden");
    }
}
// Reload the provider-specific model list whenever the provider selection changes.
providerSelect.addEventListener("change", load_provider_models)

View File

@ -6,10 +6,11 @@ from g4f import version, models
from g4f import get_last_provider, ChatCompletion
from g4f.image import is_allowed_extension, to_image
from g4f.errors import VersionNotFoundError
from g4f.Provider import __providers__
from g4f.Provider import ProviderType, __providers__, __map__
from g4f.providers.base_provider import ProviderModelMixin
from g4f.Provider.bing.create_images import patch_provider
class Backend_Api:
class Backend_Api:
"""
Handles various endpoints in a Flask application for backend operations.
@ -33,6 +34,10 @@ class Backend_Api:
'function': self.get_models,
'methods': ['GET']
},
'/backend-api/v2/models/<provider>': {
'function': self.get_provider_models,
'methods': ['GET']
},
'/backend-api/v2/providers': {
'function': self.get_providers,
'methods': ['GET']
@ -75,7 +80,21 @@ class Backend_Api:
List[str]: A list of model names.
"""
return models._all_models
def get_provider_models(self, provider: str):
    """
    Return the selectable models for a single provider.

    Args:
        provider: Provider name, as registered in ``__map__``.

    Returns:
        A list of ``{"model": <name>, "default": <bool>}`` dicts (possibly
        empty), or a ``("Provider not found", 404)`` Flask response tuple
        for an unknown provider name.
    """
    if provider in __map__:
        provider: ProviderType = __map__[provider]
        if issubclass(provider, ProviderModelMixin):
            return [
                {"model": model, "default": model == provider.default_model}
                for model in provider.get_models()
            ]
        if provider.supports_gpt_35_turbo or provider.supports_gpt_4:
            # gpt-4 is the default whenever the provider supports it,
            # otherwise gpt-3.5-turbo is. (The original negated the gpt-4
            # flag, so gpt-4 was never preselected.)
            return [
                *([{"model": "gpt-3.5-turbo", "default": not provider.supports_gpt_4}]
                  if provider.supports_gpt_35_turbo else []),
                *([{"model": "gpt-4", "default": provider.supports_gpt_4}]
                  if provider.supports_gpt_4 else [])
            ]
        return []
    # Flask view tuples are (body, status); the original had them reversed.
    return "Provider not found", 404
def get_providers(self):
"""
Return a list of all working providers.