Refactor code and fix bugs

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-26 21:48:48 -03:00
commit d54e4504f9
13 changed files with 33 additions and 79 deletions

View file

@@ -9,9 +9,9 @@ from typing import Optional
import httpx
import typer
from dotenv import load_dotenv
from multiprocess import (
Process, # type: ignore
cpu_count, # type: ignore
from multiprocess import ( # noqa
Process, # noqa
cpu_count, # noqa; type: ignore
)
from rich import box
from rich import print as rprint

View file

@@ -69,6 +69,8 @@ async def run_flow(
flow_name: Optional[str] = None,
user_id: Optional[str] = None,
) -> Any:
if not user_id:
raise ValueError("Session is invalid")
graph = await load_flow(user_id, flow_id, flow_name, tweaks)
if inputs is None:

View file

@@ -1,6 +1,6 @@
from typing import List, Union
from typing import List, Optional, Union
from langchain.agents import AgentExecutor, BaseMultiActionAgent, BaseSingleActionAgent
from langchain.agents.agent import AgentExecutor, BaseMultiActionAgent, BaseSingleActionAgent, RunnableMultiActionAgent
from langchain_core.runnables import Runnable
from langflow.field_typing import BaseMemory, Text, Tool
@@ -44,12 +44,14 @@ class LCAgentComponent(CustomComponent):
inputs: str,
input_variables: list[str],
tools: List[Tool],
memory: BaseMemory = None,
memory: Optional[BaseMemory] = None,
handle_parsing_errors: bool = True,
output_key: str = "output",
) -> Text:
if isinstance(agent, AgentExecutor):
runnable = agent
elif isinstance(agent, Runnable):
runnable = RunnableMultiActionAgent(runnable=agent, stream_runnable=False)
else:
runnable = AgentExecutor.from_agent_and_tools(
agent=agent, tools=tools, verbose=True, memory=memory, handle_parsing_errors=handle_parsing_errors

View file

@@ -4,7 +4,7 @@ from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import LLM
from langchain_core.messages import HumanMessage, SystemMessage
from langflow.interface.custom.custom_component import CustomComponent
from langflow.custom import CustomComponent
class LCModelComponent(CustomComponent):
@@ -35,10 +35,10 @@ class LCModelComponent(CustomComponent):
self, runnable: BaseChatModel, stream: bool, input_value: str, system_message: Optional[str] = None
):
messages = []
if system_message:
messages.append(SystemMessage(system_message))
if input_value:
messages.append(HumanMessage(input_value))
if system_message:
messages.append(SystemMessage(system_message))
if stream:
result = runnable.stream(messages)
else:

View file

@@ -1,6 +1,6 @@
import asyncio
from collections import defaultdict
from typing import TYPE_CHECKING, Coroutine, List
from typing import TYPE_CHECKING, Awaitable, Callable, List
if TYPE_CHECKING:
from langflow.graph.graph.base import Graph
@@ -55,7 +55,7 @@ class RunnableVerticesManager:
async def get_next_runnable_vertices(
self,
lock: asyncio.Lock,
set_cache_coro: Coroutine,
set_cache_coro: Callable[["Graph", asyncio.Lock], Awaitable[None]],
graph: "Graph",
vertex: "Vertex",
):

View file

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Callable, Coroutine, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union
from pydantic.v1 import BaseModel, Field, create_model
from sqlmodel import select
@@ -69,6 +69,8 @@ async def run_flow(
flow_name: Optional[str] = None,
user_id: Optional[str] = None,
) -> Any:
if not user_id:
raise ValueError("Session is invalid")
graph = await load_flow(user_id, flow_id, flow_name, tweaks)
if inputs is None:
@@ -77,7 +79,7 @@ async def run_flow(
inputs_components = []
types = []
for input_dict in inputs:
inputs_list.append({INPUT_FIELD_NAME: input_dict.get("input_value")})
inputs_list.append({INPUT_FIELD_NAME: input_dict.get("input_value", "")})
inputs_components.append(input_dict.get("components", []))
types.append(input_dict.get("type", []))
@@ -138,12 +140,12 @@ async def flow_function({func_args}):
"""
compiled_func = compile(func_body, "<string>", "exec")
local_scope = {}
local_scope: dict = {}
exec(compiled_func, globals(), local_scope)
return local_scope["flow_function"]
def build_function_and_schema(flow_record: Record, graph: "Graph") -> Tuple[Callable, BaseModel]:
def build_function_and_schema(flow_record: Record, graph: "Graph") -> Tuple[Coroutine, BaseModel]:
"""
Builds a dynamic function and schema for a given flow.
@@ -178,7 +180,7 @@ def get_flow_inputs(graph: "Graph") -> List["Vertex"]:
return inputs
def build_schema_from_inputs(name: str, inputs: List[tuple[str, str, str]]) -> BaseModel:
def build_schema_from_inputs(name: str, inputs: List["Vertex"]) -> BaseModel:
"""
Builds a schema from the given inputs.

View file

@@ -23,7 +23,7 @@ class ServiceFactory:
raise self.service_class(*args, **kwargs)
def hash_factory(factory: ServiceFactory) -> str:
def hash_factory(factory: Type[ServiceFactory]) -> str:
return factory.service_class.__name__
@@ -38,7 +38,7 @@ def hash_infer_service_types_args(factory_class: Type[ServiceFactory], available
@cached(cache=LRUCache(maxsize=10), key=hash_infer_service_types_args)
def infer_service_types(factory_class: Type[ServiceFactory], available_services=None) -> "ServiceType":
def infer_service_types(factory_class: Type[ServiceFactory], available_services=None) -> list["ServiceType"]:
create_method = factory_class.create
type_hints = get_type_hints(create_method, globalns=available_services)
service_types = []

View file

@@ -1,5 +1,6 @@
import secrets
from pathlib import Path
from typing import Literal
from loguru import logger
from passlib.context import CryptContext
@@ -14,7 +15,7 @@ class AuthSettings(BaseSettings):
# Login settings
CONFIG_DIR: str
SECRET_KEY: SecretStr = Field(
default=None,
default="",
description="Secret key for JWT. If not provided, a random one will be generated.",
frozen=False,
)
@@ -33,13 +34,13 @@ class AuthSettings(BaseSettings):
SUPERUSER: str = DEFAULT_SUPERUSER
SUPERUSER_PASSWORD: str = DEFAULT_SUPERUSER_PASSWORD
REFRESH_SAME_SITE: str = "none"
REFRESH_SAME_SITE: Literal["lax", "strict", "none"] = "none"
"""The SameSite attribute of the refresh token cookie."""
REFRESH_SECURE: bool = True
"""The Secure attribute of the refresh token cookie."""
REFRESH_HTTPONLY: bool = True
"""The HttpOnly attribute of the refresh token cookie."""
ACCESS_SAME_SITE: str = "none"
ACCESS_SAME_SITE: Literal["lax", "strict", "none"] = "none"
"""The SameSite attribute of the access token cookie."""
ACCESS_SECURE: bool = True
"""The Secure attribute of the access token cookie."""

View file

@@ -1,8 +1,8 @@
from typing import List, Union
from langchain.agents import AgentExecutor, BaseMultiActionAgent, BaseSingleActionAgent
from langchain.agents import (AgentExecutor, BaseMultiActionAgent,
from langflow import CustomComponent
from langflow.custom import CustomComponent
from langflow.field_typing import BaseMemory, Text, Tool

View file

@@ -1,3 +0,0 @@
from .model import LCModelComponent
__all__ = ["LCModelComponent"]

View file

@@ -1,48 +0,0 @@
from typing import Optional
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import LLM
from langchain_core.messages import HumanMessage, SystemMessage
from langflow import CustomComponent
class LCModelComponent(CustomComponent):
display_name: str = "Model Name"
description: str = "Model Description"
def get_result(self, runnable: LLM, stream: bool, input_value: str):
"""
Retrieves the result from the output of a Runnable object.
Args:
output (Runnable): The output object to retrieve the result from.
stream (bool): Indicates whether to use streaming or invocation mode.
input_value (str): The input value to pass to the output object.
Returns:
The result obtained from the output object.
"""
if stream:
result = runnable.stream(input_value)
else:
message = runnable.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
return result
def get_chat_result(
self, runnable: BaseChatModel, stream: bool, input_value: str, system_message: Optional[str] = None
):
messages = []
if input_value:
messages.append(HumanMessage(input_value))
if system_message:
messages.append(SystemMessage(system_message))
if stream:
result = runnable.stream(messages)
else:
message = runnable.invoke(messages)
result = message.content
self.status = result
return result

View file

@@ -1,13 +1,12 @@
from typing import Any, List, Optional, Text
from langchain_core.tools import StructuredTool
from loguru import logger
from langflow import CustomComponent
from langflow.custom import CustomComponent
from langflow.field_typing import Tool
from langflow.graph.graph.base import Graph
from langflow.helpers.flow import build_function_and_schema
from langflow.schema.dotdict import dotdict
from loguru import logger
class FlowToolComponent(CustomComponent):

View file

@@ -1,7 +1,6 @@
from langchain_community.tools.searchapi import SearchAPIRun
from langchain_community.utilities.searchapi import SearchApiAPIWrapper
from langflow import CustomComponent
from langflow.custom import CustomComponent
from langflow.field_typing import Tool