Merge branch 'zustand/io/migration' of github.com:logspace-ai/langflow into zustand/io/migration

This commit is contained in:
igorrCarvalho 2024-02-22 12:57:37 -03:00
commit f57af2da52
13 changed files with 203 additions and 109 deletions

View file

@ -10,6 +10,9 @@ from fastapi import (
WebSocketException,
status,
)
from loguru import logger
from sqlmodel import Session
from langflow.api.utils import build_and_cache_graph, format_elapsed_time
from langflow.api.v1.schemas import (
ResultData,
@ -24,8 +27,6 @@ from langflow.services.auth.utils import (
from langflow.services.chat.service import ChatService
from langflow.services.deps import get_chat_service, get_session
from langflow.services.monitor.utils import log_vertex_build
from loguru import logger
from sqlmodel import Session
router = APIRouter(tags=["Chat"])
@ -172,7 +173,7 @@ async def build_vertex(
raise ValueError(f"No result found for vertex {vertex_id}")
chat_service.set_cache(flow_id, graph)
except Exception as exc:
params = repr(exc)
params = str(exc)
valid = False
result_dict = ResultData(results={})
artifacts = {}

View file

@ -23,19 +23,25 @@ class ConversationChainComponent(CustomComponent):
def build(
self,
inputs: str,
llm: BaseLanguageModel,
memory: Optional[BaseMemory] = None,
inputs: dict = {},
) -> Union[Chain, Callable, Text]:
if memory is None:
chain = ConversationChain(llm=llm)
chain = ConversationChain(llm=llm, memory=memory)
else:
chain = ConversationChain(llm=llm, memory=memory)
result = chain.invoke(inputs)
# result is an AIMessage which is a subclass of BaseMessage
# We need to check if it is a string or a BaseMessage
if hasattr(result, "content") and isinstance(result.content, str):
return result.content
self.status = "is message"
result = result.content
elif isinstance(result, str):
return result
return str(result)
self.status = "is_string"
result = result
else:
# is dict
result = result.get("response")
self.status = result
return result

View file

@ -1,4 +1,3 @@
import os
from typing import Any, Callable, Dict, Optional, Union
from langchain_community.chat_models.litellm import ChatLiteLLM, ChatLiteLLMException
@ -27,6 +26,18 @@ class ChatLiteLLMComponent(CustomComponent):
"required": False,
"password": True,
},
"provider": {
"display_name": "Provider",
"info": "The provider of the API key.",
"options": [
"OpenAI",
"Azure",
"Anthropic",
"Replicate",
"Cohere",
"OpenRouter",
],
},
"streaming": {
"display_name": "Streaming",
"field_type": "bool",
@ -96,7 +107,8 @@ class ChatLiteLLMComponent(CustomComponent):
def build(
self,
model: str,
api_key: str,
provider: str,
api_key: Optional[str] = None,
streaming: bool = True,
temperature: Optional[float] = 0.7,
model_kwargs: Optional[Dict[str, Any]] = {},
@ -114,13 +126,19 @@ class ChatLiteLLMComponent(CustomComponent):
litellm.set_verbose = verbose
except ImportError:
raise ChatLiteLLMException(
"Could not import litellm python package. " "Please install it with `pip install litellm`"
"Could not import litellm python package. "
"Please install it with `pip install litellm`"
)
if api_key:
if "perplexity" in model:
os.environ["PERPLEXITYAI_API_KEY"] = api_key
elif "replicate" in model:
os.environ["REPLICATE_API_KEY"] = api_key
provider_map = {
"OpenAI": "openai_api_key",
"Azure": "azure_api_key",
"Anthropic": "anthropic_api_key",
"Replicate": "replicate_api_key",
"Cohere": "cohere_api_key",
"OpenRouter": "openrouter_api_key",
}
# Set the API key based on the provider
kwarg = {provider_map[provider]: api_key}
LLM = ChatLiteLLM(
model=model,
@ -133,5 +151,6 @@ class ChatLiteLLMComponent(CustomComponent):
n=n,
max_tokens=max_tokens,
max_retries=max_retries,
**kwarg,
)
return LLM

View file

@ -0,0 +1,31 @@
def validate_icon(value: str, *args, **kwargs):
# we are going to use the emoji library to validate the emoji
# emojis can be defined using the :emoji_name: syntax
if not value.startswith(":") or not value.endswith(":"):
warnings.warn("Invalid emoji. Please use the :emoji_name: syntax.")
return value
emoji_value = emoji.emojize(value, variant="emoji_type")
if value == emoji_value:
warnings.warn(f"Invalid emoji. {value} is not a valid emoji.")
return value
return emoji_value
def getattr_return_str(value):
return str(value) if value else ""
def getattr_return_bool(value):
if isinstance(value, bool):
return value
ATTR_FUNC_MAPPING = {
"display_name": getattr_return_str,
"description": getattr_return_str,
"beta": getattr_return_str,
"documentation": getattr_return_str,
"icon": validate_icon,
"pinned": getattr_return_bool,
}

View file

@ -2,9 +2,10 @@ import operator
import warnings
from typing import Any, ClassVar, Optional
import emoji
from cachetools import TTLCache, cachedmethod
from fastapi import HTTPException
from langflow.interface.custom.attributes import ATTR_FUNC_MAPPING
from langflow.interface.custom.code_parser import CodeParser
from langflow.interface.custom.eval import eval_custom_component_code
from langflow.utils import validate
@ -65,14 +66,6 @@ class Component:
return validate.create_function(self.code, self._function_entrypoint_name)
def getattr_return_str(self, value):
return str(value) if value else ""
def getattr_return_bool(self, value):
if isinstance(value, bool):
return value
def build_template_config(self) -> dict:
if not self.code:
return {}
@ -80,15 +73,8 @@ class Component:
cc_class = eval_custom_component_code(self.code)
component_instance = cc_class()
template_config = {}
attributes_func_mapping = {
"display_name": self.getattr_return_str,
"description": self.getattr_return_str,
"beta": self.getattr_return_str,
"documentation": self.getattr_return_str,
"icon": self.validate_icon,
}
for attribute, func in attributes_func_mapping.items():
for attribute, func in ATTR_FUNC_MAPPING.items():
if hasattr(component_instance, attribute):
value = getattr(component_instance, attribute)
if value is not None:
@ -96,17 +82,5 @@ class Component:
return template_config
def validate_icon(self, value: str, *args, **kwargs):
# we are going to use the emoji library to validate the emoji
# emojis can be defined using the :emoji_name: syntax
if not value.startswith(":") or not value.endswith(":"):
warnings.warn("Invalid emoji. Please use the :emoji_name: syntax.")
return value
emoji_value = emoji.emojize(value, variant="emoji_type")
if value == emoji_value:
warnings.warn(f"Invalid emoji. {value} is not a valid emoji.")
return value
return emoji_value
def build(self, *args: Any, **kwargs: Any) -> Any:
raise NotImplementedError

View file

@ -6,6 +6,7 @@ from uuid import UUID
import yaml
from cachetools import TTLCache, cachedmethod
from fastapi import HTTPException
from langflow.interface.custom.code_parser.utils import (
extract_inner_type_from_generic_alias,
extract_union_types_from_generic_alias,
@ -35,6 +36,8 @@ class CustomComponent(Component):
"""The field configuration of the component. Defaults to an empty dictionary."""
field_order: Optional[List[str]] = None
"""The field order of the component. Defaults to an empty list."""
pinned: Optional[bool] = False
"""The default pinned state of the component. Defaults to False."""
code_class_base_inheritance: ClassVar[str] = "CustomComponent"
function_entrypoint_name: ClassVar[str] = "build"
function: Optional[Callable] = None

View file

@ -7,7 +7,10 @@ from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from fastapi import HTTPException
from loguru import logger
from langflow.field_typing.range_spec import RangeSpec
from langflow.interface.custom.attributes import ATTR_FUNC_MAPPING
from langflow.interface.custom.code_parser.utils import extract_inner_type
from langflow.interface.custom.custom_component import CustomComponent
from langflow.interface.custom.directory_reader.utils import (
@ -22,7 +25,6 @@ from langflow.template.frontend_node.custom_components import (
)
from langflow.utils import validate
from langflow.utils.util import get_base_classes
from loguru import logger
def add_output_types(
@ -263,16 +265,9 @@ def run_build_config(
def sanitize_template_config(template_config):
"""Sanitize the template config"""
attributes = {
"display_name",
"description",
"beta",
"documentation",
"output_types",
"icon",
}
for key in template_config.copy():
if key not in attributes:
if key not in ATTR_FUNC_MAPPING.keys():
template_config.pop(key, None)
return template_config

View file

@ -2,6 +2,8 @@ import re
from collections import defaultdict
from typing import ClassVar, Dict, List, Optional, Union
from pydantic import BaseModel, Field, field_serializer, model_serializer
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.constants import (
CLASSES_TO_REMOVE,
@ -10,7 +12,6 @@ from langflow.template.frontend_node.constants import (
from langflow.template.frontend_node.formatter import field_formatters
from langflow.template.template.base import Template
from langflow.utils import constants
from pydantic import BaseModel, Field, field_serializer, model_serializer
class FieldFormatters(BaseModel):
@ -44,17 +45,31 @@ class FieldFormatters(BaseModel):
class FrontendNode(BaseModel):
_format_template: bool = True
template: Template
"""Template for the frontend node."""
description: Optional[str] = None
"""Description of the frontend node."""
icon: Optional[str] = None
"""Icon of the frontend node."""
is_composition: Optional[bool] = None
"""Whether the frontend node is used for composition."""
base_classes: List[str]
"""List of base classes for the frontend node."""
name: str = ""
"""Name of the frontend node."""
display_name: Optional[str] = ""
"""Display name of the frontend node."""
documentation: str = ""
"""Documentation of the frontend node."""
custom_fields: Optional[Dict] = defaultdict(list)
"""Custom fields of the frontend node."""
output_types: List[str] = []
"""List of output types for the frontend node."""
full_path: Optional[str] = None
"""Full path of the frontend node."""
field_formatters: FieldFormatters = Field(default_factory=FieldFormatters)
"""Field formatters for the frontend node."""
pinned: bool = False
"""Whether the frontend node is pinned."""
beta: bool = False
error: Optional[str] = None

View file

@ -4,7 +4,9 @@ from langchain_community.chat_message_histories.mongodb import (
DEFAULT_COLLECTION_NAME,
DEFAULT_DBNAME,
)
from langchain_community.chat_message_histories.postgres import DEFAULT_CONNECTION_STRING
from langchain_community.chat_message_histories.postgres import (
DEFAULT_CONNECTION_STRING,
)
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.base import FrontendNode
@ -13,7 +15,9 @@ from langflow.template.template.base import Template
class MemoryFrontendNode(FrontendNode):
#! Needs testing
pinned: bool = True
def add_extra_fields(self) -> None:
# chat history should have another way to add common field?
# prevent adding incorect field in ChatMessageHistory
@ -77,7 +81,9 @@ class MemoryFrontendNode(FrontendNode):
field.show = True
field.advanced = False
field.value = ""
field.info = INPUT_KEY_INFO if field.name == "input_key" else OUTPUT_KEY_INFO
field.info = (
INPUT_KEY_INFO if field.name == "input_key" else OUTPUT_KEY_INFO
)
if field.name == "memory_key":
field.value = "chat_history"

View file

@ -62,5 +62,10 @@ export default function IOInputField({
);
}
}
return <div className="h-full w-full">{handleInputType()}</div>;
return (
<div className="font-xl flex h-full w-full flex-col gap-4 p-4 font-semibold">
{inputType}
{handleInputType()}
</div>
);
}

View file

@ -43,5 +43,10 @@ export default function IOOutputView({
);
}
}
return <div className="h-full w-full">{handleOutputType()}</div>;
return (
<div className="font-xl flex h-full w-full flex-col gap-4 p-4 font-semibold">
{outputType}
{handleOutputType()}
</div>
);
}

View file

@ -60,11 +60,17 @@ export default function IOView({ children, open, setOpen }): JSX.Element {
}
function UpdateAccordion() {
return (categories[selectedCategory]?.name ?? "Inputs") === "Inputs" ? inputs : outputs;
return (categories[selectedCategory]?.name ?? "Inputs") === "Inputs"
? inputs
: outputs;
}
return (
<BaseModal size={handleSelectChange() ? "large" : "small"} open={open} setOpen={setOpen}>
<BaseModal
size={handleSelectChange() ? "large" : "small"}
open={open}
setOpen={setOpen}
>
<BaseModal.Trigger>{children}</BaseModal.Trigger>
{/* TODO ADAPT TO ALL TYPES OF INPUTS AND OUTPUTS */}
<BaseModal.Header description={CHAT_FORM_DIALOG_SUBTITLE}>
@ -85,47 +91,59 @@ export default function IOView({ children, open, setOpen }): JSX.Element {
handleSelectChange() ? "w-2/6" : "w-full"
)}
>
<div className="flex items-start gap-4 py-2">
{categories.map((category, index) => {
return (
//hide chat button if chat is alredy on the view
<Button
onClick={() => setSelectedCategory(index)}
variant={
index === selectedCategory ? "primary" : "secondary"
}
key={index}
>
<IconComponent
name={category.icon}
className=" file-component-variable"
/>
<span className="file-component-variables-span text-md">
{category.name}
</span>
</Button>
);
})}
<div className="flex w-full items-center justify-between py-2">
<div className="flex items-start gap-4">
{categories.map((category, index) => {
return (
//hide chat button if chat is alredy on the view
<Button
onClick={() => setSelectedCategory(index)}
variant={
index === selectedCategory ? "primary" : "secondary"
}
key={index}
>
<IconComponent
name={category.icon}
className=" file-component-variable"
/>
<span className="file-component-variables-span text-md">
{category.name}
</span>
</Button>
);
})}
</div>
{(outputs.map((output) => output.type).includes("ChatOutput") ||
inputs.map((output) => output.type).includes("chatInput")) &&
selectedView.type !== "ChatOutput" && (
<button
<Button
onClick={() => setSelectedView({ type: "ChatOutput" })}
className={
"cursor flex items-center rounded-md rounded-b-none px-1 hover:bg-muted-foreground"
}
variant="outline"
key={"chat"}
className="self-end px-2.5"
>
<IconComponent
name="Variable"
className=" file-component-variable"
name="MessageSquareMore"
className="h-5 w-5"
/>
<span className="file-component-variables-span text-md">
Chat
</span>
</button>
</Button>
)}
</div>
<div className="mx-2 mb-2 mt-4 flex items-center gap-2 font-semibold">
{categories[selectedCategory]?.name === "Inputs" && (
<>
<IconComponent name={"FormInput"} />
Text Inputs
</>
)}
{categories[selectedCategory]?.name === "Outputs" && (
<>
<IconComponent name={"ChevronRightSquare"} />
Prompt Outputs
</>
)}
</div>
{UpdateAccordion()
.filter(
(input) =>
@ -164,25 +182,35 @@ export default function IOView({ children, open, setOpen }): JSX.Element {
keyValue={input.id}
>
<div className="file-component-tab-column">
{node &&
(categories[selectedCategory].name === "Inputs" ? (
<IOInputField
inputType={input.type}
inputId={input.id}
/>
) : (
<IOOutputView
outputType={input.type}
outputId={input.id}
/>
))}
<div className="">
{node &&
(categories[selectedCategory]?.name === "Inputs" ? (
<IOInputField
inputType={input.type}
inputId={input.id}
/>
) : (
<IOOutputView
outputType={input.type}
outputId={input.id}
/>
))}
</div>
</div>
</AccordionComponent>
</div>
);
})}
</div>
{handleSelectChange() && handleSelectChange()}
{handleSelectChange() ? (
handleSelectChange()
) : (
<div className="absolute bottom-8 right-8">
<Button className="px-3">
<IconComponent name="Play" className="h-6 w-6" />
</Button>
</div>
)}
</div>
</BaseModal.Content>
</BaseModal>

View file

@ -11,7 +11,7 @@ import {
ChevronDown,
ChevronLeft,
ChevronRight,
Sliders,
ChevronRightSquare,
ChevronUp,
ChevronsLeft,
ChevronsRight,
@ -40,6 +40,7 @@ import {
Fingerprint,
FlaskConical,
FolderPlus,
FormInput,
Forward,
Gift,
GitBranchPlus,
@ -64,6 +65,7 @@ import {
Menu,
MessageCircle,
MessageSquare,
MessageSquareMore,
MessagesSquare,
Minimize2,
Minus,
@ -88,6 +90,7 @@ import {
Share,
Share2,
Shield,
Sliders,
Sparkles,
Square,
Store,
@ -385,7 +388,10 @@ export const nodeIconsLucide: iconsType = {
Square,
Minimize2,
Maximize2,
FormInput,
ChevronRightSquare,
SaveAll,
MessageSquareMore,
Forward,
Share2,
Share,