Merge remote-tracking branch 'origin/dev' into db

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-06 12:57:46 -03:00
commit f5f0983116
45 changed files with 519 additions and 373 deletions

View file

@ -1,4 +1,4 @@
from langflow.cache import cache_manager
from langflow.interface.loading import load_flow_from_json
from langflow.processing.process import load_flow_from_json
__all__ = ["load_flow_from_json", "cache_manager"]

View file

@ -0,0 +1,3 @@
from langflow.api.router import router
__all__ = ["router"]

View file

@ -0,0 +1,8 @@
# Router for base api
from fastapi import APIRouter
from langflow.api.v1 import chat_router, endpoints_router, validate_router
router = APIRouter(prefix="/api/v1", tags=["api"])
router.include_router(chat_router)
router.include_router(endpoints_router)
router.include_router(validate_router)

View file

@ -0,0 +1,5 @@
from langflow.api.v1.endpoints import router as endpoints_router
from langflow.api.v1.validate import router as validate_router
from langflow.api.v1.chat import router as chat_router
__all__ = ["chat_router", "endpoints_router", "validate_router"]

View file

@ -1,6 +1,6 @@
from pydantic import BaseModel, validator
from langflow.graph.utils import extract_input_variables_from_prompt
from langflow.interface.utils import extract_input_variables_from_prompt
class CacheResponse(BaseModel):

View file

@ -3,7 +3,7 @@ from typing import Any
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langflow.api.schemas import ChatResponse
from langflow.api.v1.schemas import ChatResponse
# https://github.com/hwchase17/chat-langchain/blob/master/callback.py

View file

@ -6,7 +6,7 @@ from fastapi import (
status,
)
from langflow.api.chat_manager import ChatManager
from langflow.chat.manager import ChatManager
from langflow.utils.logger import logger
router = APIRouter(tags=["Chat"])

View file

@ -1,16 +1,15 @@
import json
from langflow.database.models.flow import Flow
from langflow.processing.process import process_graph_cached
from langflow.utils.logger import logger
from importlib.metadata import version
from fastapi import APIRouter, Depends, HTTPException
from langflow.api.schemas import (
GraphData,
from langflow.api.v1.schemas import (
PredictRequest,
PredictResponse,
)
from langflow.interface.run import process_graph_cached
from langflow.interface.types import build_langchain_types_dict
from langflow.database.base import get_session
from sqlmodel import Session
@ -34,8 +33,8 @@ async def get_load(
flow_obj = session.get(Flow, flow_id)
if flow_obj is None:
raise ValueError(f"Flow {flow_id} not found")
graph_data: GraphData = json.loads(flow_obj.flow)
data = graph_data.get("data")
graph_data = flow_obj.flow
data: dict = graph_data.get("data", {})
response = process_graph_cached(data, predict_request.message)
return PredictResponse(
result=response.get("result", ""),
@ -51,8 +50,3 @@ async def get_load(
@router.get("/version")
def get_version():
return {"version": version("langflow")}
@router.get("/health")
def get_health():
return {"status": "OK"}

View file

@ -2,7 +2,7 @@ import json
from fastapi import APIRouter, HTTPException
from langflow.api.base import (
from langflow.api.v1.base import (
Code,
CodeValidationResponse,
Prompt,
@ -10,7 +10,7 @@ from langflow.api.base import (
validate_prompt,
)
from langflow.graph.vertex.types import VectorStoreVertex
from langflow.interface.run import build_graph
from langflow.graph import Graph
from langflow.utils.logger import logger
from langflow.utils.validate import validate_code
@ -44,7 +44,7 @@ def post_validate_prompt(prompt: Prompt):
def post_validate_node(node_id: str, data: dict):
try:
# build graph
graph = build_graph(data)
graph = Graph.from_payload(data)
# validate node
node = graph.get_node(node_id)
if node is None:

View file

View file

@ -1,21 +1,18 @@
import asyncio
import json
from collections import defaultdict
from typing import Dict, List
from fastapi import WebSocket, status
from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
from langflow.api.v1.schemas import ChatMessage, ChatResponse, FileResponse
from langflow.cache import cache_manager
from langflow.cache.manager import Subject
from langflow.interface.run import (
get_result_and_steps,
load_or_build_langchain_object,
)
from langflow.interface.utils import pil_to_base64, try_setting_streaming_options
from langflow.chat.utils import process_graph
from langflow.interface.utils import pil_to_base64
from langflow.utils.logger import logger
import asyncio
import json
from typing import Dict, List
class ChatHistory(Subject):
def __init__(self):
super().__init__()
@ -191,33 +188,3 @@ class ChatManager:
except Exception as e:
logger.exception(e)
self.disconnect(client_id)
async def process_graph(
graph_data: Dict,
is_first_message: bool,
chat_message: ChatMessage,
websocket: WebSocket,
):
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
langchain_object = try_setting_streaming_options(langchain_object, websocket)
logger.debug("Loaded langchain object")
if langchain_object is None:
# Raise user facing error
raise ValueError(
"There was an error loading the langchain_object. Please, check all the nodes and try again."
)
# Generate result and thought
try:
logger.debug("Generating result and thought")
result, intermediate_steps = await get_result_and_steps(
langchain_object, chat_message.message or "", websocket=websocket
)
logger.debug("Generated result and intermediate_steps")
return result, intermediate_steps
except Exception as e:
# Log stack trace
logger.exception(e)
raise e

View file

@ -0,0 +1,41 @@
from fastapi import WebSocket
from langflow.api.v1.schemas import ChatMessage
from langflow.processing.process import (
load_or_build_langchain_object,
)
from langflow.processing.base import get_result_and_steps
from langflow.interface.utils import try_setting_streaming_options
from langflow.utils.logger import logger
from typing import Dict
async def process_graph(
graph_data: Dict,
is_first_message: bool,
chat_message: ChatMessage,
websocket: WebSocket,
):
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
langchain_object = try_setting_streaming_options(langchain_object, websocket)
logger.debug("Loaded langchain object")
if langchain_object is None:
# Raise user facing error
raise ValueError(
"There was an error loading the langchain_object. Please, check all the nodes and try again."
)
# Generate result and thought
try:
logger.debug("Generating result and thought")
result, intermediate_steps = await get_result_and_steps(
langchain_object, chat_message.message or "", websocket=websocket
)
logger.debug("Generated result and intermediate_steps")
return result, intermediate_steps
except Exception as e:
# Log stack trace
logger.exception(e)
raise e

View file

@ -51,6 +51,7 @@ embeddings:
llms:
- OpenAI
# - AzureOpenAI
# - AzureChatOpenAI
- ChatOpenAI
- LlamaCpp
- CTransformers

View file

@ -24,6 +24,27 @@ class Graph:
self._edges = edges
self._build_graph()
@classmethod
@classmethod
def from_payload(cls, payload: Dict) -> "Graph":
"""
Creates a graph from a payload.
Args:
payload (Dict): The payload to create the graph from.
Returns:
Graph: The created graph.
"""
if "data" in payload:
payload = payload["data"]
try:
nodes = payload["nodes"]
edges = payload["edges"]
return cls(nodes, edges)
except KeyError as exc:
raise ValueError("Invalid payload") from exc
def _build_graph(self) -> None:
"""Builds the graph from the nodes and edges."""
self.nodes = self._build_vertices()

View file

@ -1,6 +1,7 @@
import re
from typing import Any, Union
from langflow.interface.utils import extract_input_variables_from_prompt
def validate_prompt(prompt: str):
"""Validate prompt."""
@ -15,11 +16,6 @@ def fix_prompt(prompt: str):
return prompt + " {input}"
def extract_input_variables_from_prompt(prompt: str) -> list[str]:
"""Extract input variables from prompt."""
return re.findall(r"{(.*?)}", prompt)
def flatten_list(list_of_lists: list[Union[list, Any]]) -> list:
"""Flatten list of lists."""
new_list = []

View file

@ -1,7 +1,8 @@
from typing import Any, Dict, List, Optional, Union
from langflow.graph.vertex.base import Vertex
from langflow.graph.utils import extract_input_variables_from_prompt, flatten_list
from langflow.graph.utils import flatten_list
from langflow.interface.utils import extract_input_variables_from_prompt
class AgentVertex(Vertex):

View file

@ -5,7 +5,7 @@ from langchain.memory.buffer import ConversationBufferMemory
from langchain.schema import BaseMemory
from pydantic import Field, root_validator
from langflow.graph.utils import extract_input_variables_from_prompt
from langflow.interface.utils import extract_input_variables_from_prompt
DEFAULT_SUFFIX = """"
Current conversation:

View file

@ -11,14 +11,15 @@ from langchain import (
text_splitter,
)
from langchain.agents import agent_toolkits
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.chat_models import ChatAnthropic
from langchain.chat_models import ChatOpenAI
from langflow.interface.importing.utils import import_class
## LLMs
llm_type_to_cls_dict = llms.type_to_cls_dict
llm_type_to_cls_dict["anthropic-chat"] = ChatAnthropic # type: ignore
llm_type_to_cls_dict["azure-chat"] = AzureChatOpenAI # type: ignore
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore
## Chains

View file

@ -12,7 +12,6 @@ from langchain.agents.load_tools import (
_LLM_TOOLS,
)
from langchain.agents.loading import load_agent_from_config
from langflow.graph import Graph
from langchain.agents.tools import Tool
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
@ -22,7 +21,6 @@ from pydantic import ValidationError
from langflow.interface.agents.custom import CUSTOM_AGENTS
from langflow.interface.importing.utils import get_function, import_by_type
from langflow.interface.run import fix_memory_inputs
from langflow.interface.toolkits.base import toolkits_creator
from langflow.interface.types import get_type_list
from langflow.interface.utils import load_file_into_dict
@ -163,37 +161,6 @@ def instantiate_utility(node_type, class_object, params):
return class_object(**params)
def load_flow_from_json(path: str, build=True):
"""Load flow from json file"""
# This is done to avoid circular imports
with open(path, "r", encoding="utf-8") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
# Substitute ZeroShotPrompt with PromptTemplate
# nodes = replace_zero_shot_prompt_with_prompt_template(nodes)
# Add input variables
# nodes = payload.extract_input_variables(nodes)
# Nodes, edges and root node
edges = data_graph["edges"]
graph = Graph(nodes, edges)
if build:
langchain_object = graph.build()
if hasattr(langchain_object, "verbose"):
langchain_object.verbose = True
if hasattr(langchain_object, "return_intermediate_steps"):
# https://github.com/hwchase17/langchain/issues/2068
# Deactivating until we have a frontend solution
# to display intermediate steps
langchain_object.return_intermediate_steps = False
fix_memory_inputs(langchain_object)
return langchain_object
return graph
def replace_zero_shot_prompt_with_prompt_template(nodes):
"""Replace ZeroShotPrompt with PromptTemplate"""
for node in nodes:

View file

@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Type
from langchain.prompts import PromptTemplate
from pydantic import root_validator
from langflow.graph.utils import extract_input_variables_from_prompt
from langflow.interface.utils import extract_input_variables_from_prompt
# Steps to create a BaseCustomPrompt:
# 1. Create a prompt template that endes with:

View file

@ -1,10 +1,3 @@
import contextlib
import io
from typing import Any, Dict, List, Tuple
from langchain.schema import AgentAction
from langflow.api.callback import AsyncStreamingLLMCallbackHandler, StreamingLLMCallbackHandler # type: ignore
from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict
from langflow.graph import Graph
from langflow.utils.logger import logger
@ -24,15 +17,6 @@ def load_langchain_object(data_graph, is_first_message=False):
return computed_hash, langchain_object
def load_or_build_langchain_object(data_graph, is_first_message=False):
"""
Load langchain object from cache if it exists, otherwise build it.
"""
if is_first_message:
build_langchain_object_with_caching.clear_cache()
return build_langchain_object_with_caching(data_graph)
@memoize_dict(maxsize=10)
def build_langchain_object_with_caching(data_graph):
"""
@ -40,16 +24,10 @@ def build_langchain_object_with_caching(data_graph):
"""
logger.debug("Building langchain object")
graph = build_graph(data_graph)
graph = Graph.from_payload(data_graph)
return graph.build()
def build_graph(data_graph):
nodes = data_graph["nodes"]
edges = data_graph["edges"]
return Graph(nodes, edges)
def build_langchain_object(data_graph):
"""
Build langchain object from data_graph.
@ -66,29 +44,6 @@ def build_langchain_object(data_graph):
return graph.build()
def process_graph_cached(data_graph: Dict[str, Any], message: str):
"""
Process graph by extracting input variables and replacing ZeroShotPrompt
with PromptTemplate,then run the graph and return the result and thought.
"""
# Load langchain object
is_first_message = len(data_graph.get("chatHistory", [])) == 0
langchain_object = load_or_build_langchain_object(data_graph, is_first_message)
logger.debug("Loaded langchain object")
if langchain_object is None:
# Raise user facing error
raise ValueError(
"There was an error loading the langchain_object. Please, check all the nodes and try again."
)
# Generate result and thought
logger.debug("Generating result and thought")
result, thought = get_result_and_thought(langchain_object, message)
logger.debug("Generated result and thought")
return {"result": str(result), "thought": thought.strip()}
def get_memory_key(langchain_object):
"""
Given a LangChain object, this function retrieves the current memory key from the object's memory attribute.
@ -124,147 +79,3 @@ def update_memory_keys(langchain_object, possible_new_mem_key):
langchain_object.memory.input_key = input_key
langchain_object.memory.output_key = output_key
langchain_object.memory.memory_key = possible_new_mem_key
def fix_memory_inputs(langchain_object):
"""
Given a LangChain object, this function checks if it has a memory attribute and if that memory key exists in the
object's input variables. If so, it does nothing. Otherwise, it gets a possible new memory key using the
get_memory_key function and updates the memory keys using the update_memory_keys function.
"""
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
try:
if langchain_object.memory.memory_key in langchain_object.input_variables:
return
except AttributeError:
input_variables = (
langchain_object.prompt.input_variables
if hasattr(langchain_object, "prompt")
else langchain_object.input_keys
)
if langchain_object.memory.memory_key in input_variables:
return
possible_new_mem_key = get_memory_key(langchain_object)
if possible_new_mem_key is not None:
update_memory_keys(langchain_object, possible_new_mem_key)
async def get_result_and_steps(langchain_object, message: str, **kwargs):
"""Get result and thought from extracted json"""
try:
if hasattr(langchain_object, "verbose"):
langchain_object.verbose = True
chat_input = None
memory_key = ""
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
memory_key = langchain_object.memory.memory_key
if hasattr(langchain_object, "input_keys"):
for key in langchain_object.input_keys:
if key not in [memory_key, "chat_history"]:
chat_input = {key: message}
else:
chat_input = message # type: ignore
if hasattr(langchain_object, "return_intermediate_steps"):
# https://github.com/hwchase17/langchain/issues/2068
# Deactivating until we have a frontend solution
# to display intermediate steps
langchain_object.return_intermediate_steps = True
fix_memory_inputs(langchain_object)
try:
async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)]
output = await langchain_object.acall(chat_input, callbacks=async_callbacks)
except Exception as exc:
# make the error message more informative
logger.debug(f"Error: {str(exc)}")
sync_callbacks = [StreamingLLMCallbackHandler(**kwargs)]
output = langchain_object(chat_input, callbacks=sync_callbacks)
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []
)
result = (
output.get(langchain_object.output_keys[0])
if isinstance(output, dict)
else output
)
thought = format_actions(intermediate_steps) if intermediate_steps else ""
except Exception as exc:
raise ValueError(f"Error: {str(exc)}") from exc
return result, thought
def get_result_and_thought(langchain_object, message: str):
"""Get result and thought from extracted json"""
try:
if hasattr(langchain_object, "verbose"):
langchain_object.verbose = True
chat_input = None
memory_key = ""
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
memory_key = langchain_object.memory.memory_key
if hasattr(langchain_object, "input_keys"):
for key in langchain_object.input_keys:
if key not in [memory_key, "chat_history"]:
chat_input = {key: message}
else:
chat_input = message # type: ignore
if hasattr(langchain_object, "return_intermediate_steps"):
# https://github.com/hwchase17/langchain/issues/2068
# Deactivating until we have a frontend solution
# to display intermediate steps
langchain_object.return_intermediate_steps = False
fix_memory_inputs(langchain_object)
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
try:
# if hasattr(langchain_object, "acall"):
# output = await langchain_object.acall(chat_input)
# else:
output = langchain_object(chat_input)
except ValueError as exc:
# make the error message more informative
logger.debug(f"Error: {str(exc)}")
output = langchain_object.run(chat_input)
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []
)
result = (
output.get(langchain_object.output_keys[0])
if isinstance(output, dict)
else output
)
if intermediate_steps:
thought = format_actions(intermediate_steps)
else:
thought = output_buffer.getvalue()
except Exception as exc:
raise ValueError(f"Error: {str(exc)}") from exc
return result, thought
def format_actions(actions: List[Tuple[AgentAction, str]]) -> str:
"""Format a list of (AgentAction, answer) tuples into a string."""
output = []
for action, answer in actions:
log = action.log
tool = action.tool
tool_input = action.tool_input
output.append(f"Log: {log}")
if "Action" not in log and "Action Input" not in log:
output.append(f"Tool: {tool}")
output.append(f"Tool Input: {tool_input}")
output.append(f"Answer: {answer}")
output.append("") # Add a blank line
return "\n".join(output)

View file

@ -2,6 +2,7 @@ import base64
import json
import os
from io import BytesIO
import re
import yaml
from langchain.base_language import BaseLanguageModel
@ -48,3 +49,8 @@ def try_setting_streaming_options(langchain_object, websocket):
llm.streaming = True
return langchain_object
def extract_input_variables_from_prompt(prompt: str) -> list[str]:
"""Extract input variables from prompt."""
return re.findall(r"{(.*?)}", prompt)

View file

@ -1,16 +1,7 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langflow.api.chat import router as chat_router
from langflow.api.endpoints import router as endpoints_router
from langflow.api.validate import router as validate_router
from langflow.api.database import router as database_router
from langflow.database.base import create_db_and_tables
from fastapi import APIRouter
# build router
router = APIRouter()
from langflow.api import router
def create_app():
@ -21,6 +12,10 @@ def create_app():
"*",
]
@app.get("/health")
def get_health():
return {"status": "OK"}
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
@ -29,13 +24,7 @@ def create_app():
allow_headers=["*"],
)
app.include_router(endpoints_router)
app.include_router(validate_router)
app.include_router(chat_router)
app.include_router(database_router)
app.on_event("startup")(create_db_and_tables)
app.include_router(router)
return app

View file

@ -0,0 +1,55 @@
from langflow.api.v1.callback import (
AsyncStreamingLLMCallbackHandler,
StreamingLLMCallbackHandler,
)
from langflow.processing.process import fix_memory_inputs, format_actions
from langflow.utils.logger import logger
async def get_result_and_steps(langchain_object, message: str, **kwargs):
"""Get result and thought from extracted json"""
try:
if hasattr(langchain_object, "verbose"):
langchain_object.verbose = True
chat_input = None
memory_key = ""
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
memory_key = langchain_object.memory.memory_key
if hasattr(langchain_object, "input_keys"):
for key in langchain_object.input_keys:
if key not in [memory_key, "chat_history"]:
chat_input = {key: message}
else:
chat_input = message # type: ignore
if hasattr(langchain_object, "return_intermediate_steps"):
# https://github.com/hwchase17/langchain/issues/2068
# Deactivating until we have a frontend solution
# to display intermediate steps
langchain_object.return_intermediate_steps = True
fix_memory_inputs(langchain_object)
try:
async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)]
output = await langchain_object.acall(chat_input, callbacks=async_callbacks)
except Exception as exc:
# make the error message more informative
logger.debug(f"Error: {str(exc)}")
sync_callbacks = [StreamingLLMCallbackHandler(**kwargs)]
output = langchain_object(chat_input, callbacks=sync_callbacks)
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []
)
result = (
output.get(langchain_object.output_keys[0])
if isinstance(output, dict)
else output
)
thought = format_actions(intermediate_steps) if intermediate_steps else ""
except Exception as exc:
raise ValueError(f"Error: {str(exc)}") from exc
return result, thought

View file

@ -0,0 +1,172 @@
import contextlib
import io
from langchain.schema import AgentAction
import json
from langflow.interface.run import (
build_langchain_object_with_caching,
get_memory_key,
update_memory_keys,
)
from langflow.utils.logger import logger
from langflow.graph import Graph
from typing import Any, Dict, List, Tuple
def fix_memory_inputs(langchain_object):
"""
Given a LangChain object, this function checks if it has a memory attribute and if that memory key exists in the
object's input variables. If so, it does nothing. Otherwise, it gets a possible new memory key using the
get_memory_key function and updates the memory keys using the update_memory_keys function.
"""
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
try:
if langchain_object.memory.memory_key in langchain_object.input_variables:
return
except AttributeError:
input_variables = (
langchain_object.prompt.input_variables
if hasattr(langchain_object, "prompt")
else langchain_object.input_keys
)
if langchain_object.memory.memory_key in input_variables:
return
possible_new_mem_key = get_memory_key(langchain_object)
if possible_new_mem_key is not None:
update_memory_keys(langchain_object, possible_new_mem_key)
def format_actions(actions: List[Tuple[AgentAction, str]]) -> str:
"""Format a list of (AgentAction, answer) tuples into a string."""
output = []
for action, answer in actions:
log = action.log
tool = action.tool
tool_input = action.tool_input
output.append(f"Log: {log}")
if "Action" not in log and "Action Input" not in log:
output.append(f"Tool: {tool}")
output.append(f"Tool Input: {tool_input}")
output.append(f"Answer: {answer}")
output.append("") # Add a blank line
return "\n".join(output)
def get_result_and_thought(langchain_object, message: str):
"""Get result and thought from extracted json"""
try:
if hasattr(langchain_object, "verbose"):
langchain_object.verbose = True
chat_input = None
memory_key = ""
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
memory_key = langchain_object.memory.memory_key
if hasattr(langchain_object, "input_keys"):
for key in langchain_object.input_keys:
if key not in [memory_key, "chat_history"]:
chat_input = {key: message}
else:
chat_input = message # type: ignore
if hasattr(langchain_object, "return_intermediate_steps"):
# https://github.com/hwchase17/langchain/issues/2068
# Deactivating until we have a frontend solution
# to display intermediate steps
langchain_object.return_intermediate_steps = False
fix_memory_inputs(langchain_object)
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
try:
# if hasattr(langchain_object, "acall"):
# output = await langchain_object.acall(chat_input)
# else:
output = langchain_object(chat_input)
except ValueError as exc:
# make the error message more informative
logger.debug(f"Error: {str(exc)}")
output = langchain_object.run(chat_input)
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []
)
result = (
output.get(langchain_object.output_keys[0])
if isinstance(output, dict)
else output
)
if intermediate_steps:
thought = format_actions(intermediate_steps)
else:
thought = output_buffer.getvalue()
except Exception as exc:
raise ValueError(f"Error: {str(exc)}") from exc
return result, thought
def load_or_build_langchain_object(data_graph, is_first_message=False):
"""
Load langchain object from cache if it exists, otherwise build it.
"""
if is_first_message:
build_langchain_object_with_caching.clear_cache()
return build_langchain_object_with_caching(data_graph)
def process_graph_cached(data_graph: Dict[str, Any], message: str):
"""
Process graph by extracting input variables and replacing ZeroShotPrompt
with PromptTemplate,then run the graph and return the result and thought.
"""
# Load langchain object
is_first_message = len(data_graph.get("chatHistory", [])) == 0
langchain_object = load_or_build_langchain_object(data_graph, is_first_message)
logger.debug("Loaded langchain object")
if langchain_object is None:
# Raise user facing error
raise ValueError(
"There was an error loading the langchain_object. Please, check all the nodes and try again."
)
# Generate result and thought
logger.debug("Generating result and thought")
result, thought = get_result_and_thought(langchain_object, message)
logger.debug("Generated result and thought")
return {"result": str(result), "thought": thought.strip()}
def load_flow_from_json(path: str, build=True):
"""Load flow from json file"""
# This is done to avoid circular imports
with open(path, "r", encoding="utf-8") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
# Substitute ZeroShotPrompt with PromptTemplate
# nodes = replace_zero_shot_prompt_with_prompt_template(nodes)
# Add input variables
# nodes = payload.extract_input_variables(nodes)
# Nodes, edges and root node
edges = data_graph["edges"]
graph = Graph(nodes, edges)
if build:
langchain_object = graph.build()
if hasattr(langchain_object, "verbose"):
langchain_object.verbose = True
if hasattr(langchain_object, "return_intermediate_steps"):
# https://github.com/hwchase17/langchain/issues/2068
# Deactivating until we have a frontend solution
# to display intermediate steps
langchain_object.return_intermediate_steps = False
fix_memory_inputs(langchain_object)
return langchain_object
return graph

View file

@ -12,6 +12,18 @@ class LLMFrontendNode(FrontendNode):
field.name.title().replace("Openai", "OpenAI").replace("_", " ")
).replace("Api", "API")
@staticmethod
def format_azure_field(field: TemplateField):
if field.name == "model_name":
field.show = False # Azure uses deployment_name instead of model_name.
if field.name == "openai_api_type":
field.show = False
field.password = False
field.value = "azure"
if field.name == "openai_api_version":
field.password = False
field.value = "2023-03-15-preview"
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
display_names_dict = {
@ -43,8 +55,16 @@ class LLMFrontendNode(FrontendNode):
field.field_type = "code"
field.advanced = True
field.show = True
elif field.name in ["model_name", "temperature", "model_file", "model_type"]:
elif field.name in [
"model_name",
"temperature",
"model_file",
"model_type",
"deployment_name",
]:
field.advanced = False
field.show = True
LLMFrontendNode.format_openai_field(field)
if "azure" in name.lower():
LLMFrontendNode.format_azure_field(field)

View file

@ -47,7 +47,7 @@ export const TabsContext = createContext<TabsContextType>(
);
export function TabsProvider({ children }: { children: ReactNode }) {
const { setNoticeData } = useContext(alertContext);
const { setErrorData, setNoticeData } = useContext(alertContext);
const [tabIndex, setTabIndex] = useState(0);
const [flows, setFlows] = useState<Array<FlowType>>([]);
const [id, setId] = useState(uuidv4());
@ -95,25 +95,25 @@ export function TabsProvider({ children }: { children: ReactNode }) {
edge.style = { stroke: "#555555" };
});
flow.data.nodes.forEach((node) => {
if (Object.keys(templates[node.data.type]["template"]).length > 0) {
node.data.node.base_classes =
templates[node.data.type]["base_classes"];
const template = templates[node.data.type];
if (!template) {
setErrorData({ title: `Unknown node type: ${node.data.type}` });
return;
}
if (Object.keys(template["template"]).length > 0) {
node.data.node.base_classes = template["base_classes"];
flow.data.edges.forEach((edge) => {
if (edge.source === node.id) {
edge.sourceHandle = edge.sourceHandle
.split("|")
.slice(0, 2)
.concat(templates[node.data.type]["base_classes"])
.concat(template["base_classes"])
.join("|");
}
});
node.data.node.description =
templates[node.data.type]["description"];
node.data.node.description = template["description"];
node.data.node.template = updateTemplate(
templates[node.data.type][
"template"
] as unknown as APITemplateType,
template["template"] as unknown as APITemplateType,
node.data.node.template as APITemplateType
);
}
@ -316,21 +316,25 @@ export function TabsProvider({ children }: { children: ReactNode }) {
edge.animated = edge.targetHandle.split("|")[0] === "Text";
});
data.nodes.forEach((node) => {
if (Object.keys(templates[node.data.type]["template"]).length > 0) {
node.data.node.base_classes =
templates[node.data.type]["base_classes"];
const template = templates[node.data.type];
if (!template) {
setErrorData({ title: `Unknown node type: ${node.data.type}` });
return;
}
if (Object.keys(template["template"]).length > 0) {
node.data.node.base_classes = template["base_classes"];
flow.data.edges.forEach((edge) => {
if (edge.source === node.id) {
edge.sourceHandle = edge.sourceHandle
.split("|")
.slice(0, 2)
.concat(templates[node.data.type]["base_classes"])
.concat(template["base_classes"])
.join("|");
}
});
node.data.node.description = templates[node.data.type]["description"];
node.data.node.description = template["description"];
node.data.node.template = updateTemplate(
templates[node.data.type]["template"] as unknown as APITemplateType,
template["template"] as unknown as APITemplateType,
node.data.node.template as APITemplateType
);
}

View file

@ -14,13 +14,13 @@ export async function sendAll(data: sendAllProps) {
export async function checkCode(
code: string
): Promise<AxiosResponse<errorsTypeAPI>> {
return await axios.post("/validate/code", { code });
return await axios.post("api/v1/validate/code", { code });
}
export async function checkPrompt(
template: string
): Promise<AxiosResponse<PromptTypeAPI>> {
return await axios.post("/validate/prompt", { template });
return await axios.post("api/v1/validate/prompt", { template });
}
export async function getExamples(): Promise<FlowType[]> {

View file

@ -182,10 +182,10 @@ export default function ChatModal({
try {
const urlWs =
process.env.NODE_ENV === "development"
? `ws://localhost:7860/chat/${id.current}`
? `ws://localhost:7860/api/v1/chat/${id.current}`
: `${window.location.protocol === "https:" ? "wss" : "ws"}://${
window.location.host
}/chat/${id.current}`;
}api/v1/chat/${id.current}`;
const newWs = new WebSocket(urlWs);
newWs.onopen = () => {
console.log("WebSocket connection established!");

View file

@ -11,7 +11,7 @@ const apiRoutes = [
];
// Use environment variable to determine the target.
const target = process.env.VITE_PROXY_TARGET || "http://127.0.0.1:7860";
const target = process.env.VITE_PROXY_TARGET || "http://127.0.0.1:7860/api/v1";
const proxyTargets = apiRoutes.reduce((proxyObj, route) => {
proxyObj[route] = {