🐛 fix(loading.py): include "config" in the list of keys to check for *kwargs

 feat(constants.py): add default config for CTransformers
🚀 feat(llms.py): add method to format ctransformers field in LLMFrontendNode
The fix in loading.py ensures that the *kwargs are converted to a dictionary when the key contains "config". The addition of the default config for CTransformers in constants.py provides a default configuration for the CTransformers model. The new method in llms.py formats the ctransformers field in the LLMFrontendNode.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-20 16:14:16 -03:00
commit f0c507a660
3 changed files with 35 additions and 2 deletions

View file

@ -53,8 +53,8 @@ def convert_params_to_sets(params):
def convert_kwargs(params):
# if *kwargs are passed as a string, convert to dict
# first find any key that has kwargs in it
kwargs_keys = [key for key in params.keys() if "kwargs" in key]
# first find any key that has kwargs or config in it
kwargs_keys = [key for key in params.keys() if "kwargs" in key or "config" in key]
for key in kwargs_keys:
if isinstance(params[key], str):
params[key] = json.loads(params[key])
@ -82,10 +82,16 @@ def instantiate_based_on_type(class_object, base_type, node_type, params):
return instantiate_utility(node_type, class_object, params)
elif base_type == "chains":
return instantiate_chains(node_type, class_object, params)
elif base_type == "llms":
return instantiate_llm(node_type, class_object, params)
else:
return class_object(**params)
def instantiate_llm(node_type, class_object, params):
return class_object(**params)
def instantiate_chains(node_type, class_object, params):
if "retriever" in params and hasattr(params["retriever"], "as_retriever"):
params["retriever"] = params["retriever"].as_retriever()

View file

@ -32,3 +32,20 @@ You are a good listener and you can talk about anything.
HUMAN_PROMPT = "{input}"
QA_CHAIN_TYPES = ["stuff", "map_reduce", "map_rerank", "refine"]
CTRANSFORMERS_DEFAULT_CONFIG = {
"top_k": 40,
"top_p": 0.95,
"temperature": 0.8,
"repetition_penalty": 1.1,
"last_n_tokens": 64,
"seed": -1,
"max_new_tokens": 256,
"stop": None,
"stream": False,
"reset": True,
"batch_size": 8,
"threads": -1,
"context_length": -1,
"gpu_layers": 0,
}

View file

@ -1,7 +1,9 @@
import json
from typing import Optional
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.base import FrontendNode
from langflow.template.frontend_node.constants import CTRANSFORMERS_DEFAULT_CONFIG
class LLMFrontendNode(FrontendNode):
@ -31,6 +33,13 @@ class LLMFrontendNode(FrontendNode):
field.show = True
field.advanced = not field.required
@staticmethod
def format_ctransformers_field(field: TemplateField):
if field.name == "config":
field.show = True
field.advanced = True
field.value = json.dumps(CTRANSFORMERS_DEFAULT_CONFIG, indent=2)
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
display_names_dict = {
@ -38,6 +47,7 @@ class LLMFrontendNode(FrontendNode):
}
FrontendNode.format_field(field, name)
LLMFrontendNode.format_openai_field(field)
LLMFrontendNode.format_ctransformers_field(field)
if name and "azure" in name.lower():
LLMFrontendNode.format_azure_field(field)
if name and "llama" in name.lower():