FEAT: NEW WORKFLOW ENGINE (#3160)
Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
parent
2fb9850af5
commit
7753ba2d37
1161 changed files with 103836 additions and 10327 deletions
|
|
@ -178,7 +178,7 @@ class AccountService:
|
|||
|
||||
@staticmethod
|
||||
def close_account(account: Account) -> None:
|
||||
"""todo: Close account"""
|
||||
"""Close account"""
|
||||
account.status = AccountStatus.CLOSED.value
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -445,7 +445,7 @@ class RegisterService:
|
|||
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.rollback() # todo: do not work
|
||||
db.session.rollback()
|
||||
logging.error(f'Register failed: {e}')
|
||||
raise AccountRegisterError(f'Registration failed: {e}') from e
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
import copy
|
||||
|
||||
from core.prompt.advanced_prompt_templates import (
|
||||
from core.prompt.prompt_templates.advanced_prompt_templates import (
|
||||
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
|
||||
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
|
||||
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
|
||||
|
|
@ -13,7 +13,7 @@ from core.prompt.advanced_prompt_templates import (
|
|||
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
|
||||
CONTEXT,
|
||||
)
|
||||
from core.prompt.prompt_transform import AppMode
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class AdvancedPromptTemplateService:
|
||||
|
|
|
|||
123
api/services/agent_service.py
Normal file
123
api/services/agent_service.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
|
||||
|
||||
|
||||
class AgentService:
|
||||
@classmethod
|
||||
def get_agent_logs(cls, app_model: App,
|
||||
conversation_id: str,
|
||||
message_id: str) -> dict:
|
||||
"""
|
||||
Service to get agent logs
|
||||
"""
|
||||
conversation: Conversation = db.session.query(Conversation).filter(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_model.id,
|
||||
).first()
|
||||
|
||||
if not conversation:
|
||||
raise ValueError(f"Conversation not found: {conversation_id}")
|
||||
|
||||
message: Message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.conversation_id == conversation_id,
|
||||
).first()
|
||||
|
||||
if not message:
|
||||
raise ValueError(f"Message not found: {message_id}")
|
||||
|
||||
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
|
||||
|
||||
if conversation.from_end_user_id:
|
||||
# only select name field
|
||||
executor = db.session.query(EndUser, EndUser.name).filter(
|
||||
EndUser.id == conversation.from_end_user_id
|
||||
).first()
|
||||
else:
|
||||
executor = db.session.query(Account, Account.name).filter(
|
||||
Account.id == conversation.from_account_id
|
||||
).first()
|
||||
|
||||
if executor:
|
||||
executor = executor.name
|
||||
else:
|
||||
executor = 'Unknown'
|
||||
|
||||
result = {
|
||||
'meta': {
|
||||
'status': 'success',
|
||||
'executor': executor,
|
||||
'start_time': message.created_at.isoformat(),
|
||||
'elapsed_time': message.provider_response_latency,
|
||||
'total_tokens': message.answer_tokens + message.message_tokens,
|
||||
'agent_mode': app_model.app_model_config.agent_mode_dict.get('strategy', 'react'),
|
||||
'iterations': len(agent_thoughts),
|
||||
},
|
||||
'iterations': [],
|
||||
'files': message.files,
|
||||
}
|
||||
|
||||
agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict())
|
||||
agent_tools = agent_config.tools
|
||||
|
||||
def find_agent_tool(tool_name: str):
|
||||
for agent_tool in agent_tools:
|
||||
if agent_tool.tool_name == tool_name:
|
||||
return agent_tool
|
||||
|
||||
for agent_thought in agent_thoughts:
|
||||
tools = agent_thought.tools
|
||||
tool_labels = agent_thought.tool_labels
|
||||
tool_meta = agent_thought.tool_meta
|
||||
tool_inputs = agent_thought.tool_inputs_dict
|
||||
tool_outputs = agent_thought.tool_outputs_dict
|
||||
tool_calls = []
|
||||
for tool in tools:
|
||||
tool_name = tool
|
||||
tool_label = tool_labels.get(tool_name, tool_name)
|
||||
tool_input = tool_inputs.get(tool_name, {})
|
||||
tool_output = tool_outputs.get(tool_name, {})
|
||||
tool_meta_data = tool_meta.get(tool_name, {})
|
||||
tool_config = tool_meta_data.get('tool_config', {})
|
||||
tool_icon = ToolManager.get_tool_icon(
|
||||
tenant_id=app_model.tenant_id,
|
||||
provider_type=tool_config.get('tool_provider_type', ''),
|
||||
provider_id=tool_config.get('tool_provider', ''),
|
||||
)
|
||||
if not tool_icon:
|
||||
tool_entity = find_agent_tool(tool_name)
|
||||
if tool_entity:
|
||||
tool_icon = ToolManager.get_tool_icon(
|
||||
tenant_id=app_model.tenant_id,
|
||||
provider_type=tool_entity.provider_type,
|
||||
provider_id=tool_entity.provider_id,
|
||||
)
|
||||
|
||||
tool_calls.append({
|
||||
'status': 'success' if not tool_meta_data.get('error') else 'error',
|
||||
'error': tool_meta_data.get('error'),
|
||||
'time_cost': tool_meta_data.get('time_cost', 0),
|
||||
'tool_name': tool_name,
|
||||
'tool_label': tool_label,
|
||||
'tool_input': tool_input,
|
||||
'tool_output': tool_output,
|
||||
'tool_parameters': tool_meta_data.get('tool_parameters', {}),
|
||||
'tool_icon': tool_icon,
|
||||
})
|
||||
|
||||
result['iterations'].append({
|
||||
'tokens': agent_thought.tokens,
|
||||
'tool_calls': tool_calls,
|
||||
'tool_raw': {
|
||||
'inputs': agent_thought.tool_input,
|
||||
'outputs': agent_thought.observation,
|
||||
},
|
||||
'thought': agent_thought.thought,
|
||||
'created_at': agent_thought.created_at.isoformat(),
|
||||
'files': agent_thought.files,
|
||||
})
|
||||
|
||||
return result
|
||||
121
api/services/app_generate_service.py
Normal file
121
api/services/app_generate_service.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
|
||||
from core.app.apps.chat.app_generator import ChatAppGenerator
|
||||
from core.app.apps.completion.app_generator import CompletionAppGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.model import Account, App, AppMode, EndUser
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class AppGenerateService:
|
||||
|
||||
@classmethod
|
||||
def generate(cls, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True) -> Union[dict, Generator[dict, None, None]]:
|
||||
"""
|
||||
App Content Generate
|
||||
:param app_model: app model
|
||||
:param user: user
|
||||
:param args: args
|
||||
:param invoke_from: invoke from
|
||||
:param streaming: streaming
|
||||
:return:
|
||||
"""
|
||||
if app_model.mode == AppMode.COMPLETION.value:
|
||||
return CompletionAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
)
|
||||
elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
|
||||
return AgentChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
)
|
||||
elif app_model.mode == AppMode.CHAT.value:
|
||||
return ChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
)
|
||||
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
workflow = cls._get_workflow(app_model, invoke_from)
|
||||
return AdvancedChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
workflow = cls._get_workflow(app_model, invoke_from)
|
||||
return WorkflowAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Invalid app mode {app_model.mode}')
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
|
||||
message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
|
||||
-> Union[dict, Generator]:
|
||||
"""
|
||||
Generate more like this
|
||||
:param app_model: app model
|
||||
:param user: user
|
||||
:param message_id: message id
|
||||
:param invoke_from: invoke from
|
||||
:param streaming: streaming
|
||||
:return:
|
||||
"""
|
||||
return CompletionAppGenerator().generate_more_like_this(
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=user,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Any:
|
||||
"""
|
||||
Get workflow
|
||||
:param app_model: app model
|
||||
:param invoke_from: invoke from
|
||||
:return:
|
||||
"""
|
||||
workflow_service = WorkflowService()
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
# fetch draft workflow by app_model
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
else:
|
||||
# fetch published workflow by app_model
|
||||
workflow = workflow_service.get_published_workflow(app_model=app_model)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError('Workflow not published')
|
||||
|
||||
return workflow
|
||||
|
|
@ -1,527 +1,18 @@
|
|||
import re
|
||||
import uuid
|
||||
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.moderation.factory import ModerationFactory
|
||||
from core.prompt.prompt_transform import AppMode
|
||||
from core.provider_manager import ProviderManager
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class AppModelConfigService:
|
||||
@classmethod
|
||||
def is_dataset_exists(cls, account: Account, dataset_id: str) -> bool:
|
||||
# verify if the dataset ID exists
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
||||
if not dataset:
|
||||
return False
|
||||
|
||||
if dataset.tenant_id != account.current_tenant_id:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def validate_model_completion_params(cls, cp: dict, model_name: str) -> dict:
|
||||
# 6. model.completion_params
|
||||
if not isinstance(cp, dict):
|
||||
raise ValueError("model.completion_params must be of object type")
|
||||
|
||||
# stop
|
||||
if 'stop' not in cp:
|
||||
cp["stop"] = []
|
||||
elif not isinstance(cp["stop"], list):
|
||||
raise ValueError("stop in model.completion_params must be of list type")
|
||||
|
||||
if len(cp["stop"]) > 4:
|
||||
raise ValueError("stop sequences must be less than 4")
|
||||
|
||||
return cp
|
||||
|
||||
@classmethod
|
||||
def validate_configuration(cls, tenant_id: str, account: Account, config: dict, app_mode: str) -> dict:
|
||||
# opening_statement
|
||||
if 'opening_statement' not in config or not config["opening_statement"]:
|
||||
config["opening_statement"] = ""
|
||||
|
||||
if not isinstance(config["opening_statement"], str):
|
||||
raise ValueError("opening_statement must be of string type")
|
||||
|
||||
# suggested_questions
|
||||
if 'suggested_questions' not in config or not config["suggested_questions"]:
|
||||
config["suggested_questions"] = []
|
||||
|
||||
if not isinstance(config["suggested_questions"], list):
|
||||
raise ValueError("suggested_questions must be of list type")
|
||||
|
||||
for question in config["suggested_questions"]:
|
||||
if not isinstance(question, str):
|
||||
raise ValueError("Elements in suggested_questions list must be of string type")
|
||||
|
||||
# suggested_questions_after_answer
|
||||
if 'suggested_questions_after_answer' not in config or not config["suggested_questions_after_answer"]:
|
||||
config["suggested_questions_after_answer"] = {
|
||||
"enabled": False
|
||||
}
|
||||
|
||||
if not isinstance(config["suggested_questions_after_answer"], dict):
|
||||
raise ValueError("suggested_questions_after_answer must be of dict type")
|
||||
|
||||
if "enabled" not in config["suggested_questions_after_answer"] or not config["suggested_questions_after_answer"]["enabled"]:
|
||||
config["suggested_questions_after_answer"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
|
||||
raise ValueError("enabled in suggested_questions_after_answer must be of boolean type")
|
||||
|
||||
# speech_to_text
|
||||
if 'speech_to_text' not in config or not config["speech_to_text"]:
|
||||
config["speech_to_text"] = {
|
||||
"enabled": False
|
||||
}
|
||||
|
||||
if not isinstance(config["speech_to_text"], dict):
|
||||
raise ValueError("speech_to_text must be of dict type")
|
||||
|
||||
if "enabled" not in config["speech_to_text"] or not config["speech_to_text"]["enabled"]:
|
||||
config["speech_to_text"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["speech_to_text"]["enabled"], bool):
|
||||
raise ValueError("enabled in speech_to_text must be of boolean type")
|
||||
|
||||
# text_to_speech
|
||||
if 'text_to_speech' not in config or not config["text_to_speech"]:
|
||||
config["text_to_speech"] = {
|
||||
"enabled": False,
|
||||
"voice": "",
|
||||
"language": ""
|
||||
}
|
||||
|
||||
if not isinstance(config["text_to_speech"], dict):
|
||||
raise ValueError("text_to_speech must be of dict type")
|
||||
|
||||
if "enabled" not in config["text_to_speech"] or not config["text_to_speech"]["enabled"]:
|
||||
config["text_to_speech"]["enabled"] = False
|
||||
config["text_to_speech"]["voice"] = ""
|
||||
config["text_to_speech"]["language"] = ""
|
||||
|
||||
if not isinstance(config["text_to_speech"]["enabled"], bool):
|
||||
raise ValueError("enabled in text_to_speech must be of boolean type")
|
||||
|
||||
# return retriever resource
|
||||
if 'retriever_resource' not in config or not config["retriever_resource"]:
|
||||
config["retriever_resource"] = {
|
||||
"enabled": False
|
||||
}
|
||||
|
||||
if not isinstance(config["retriever_resource"], dict):
|
||||
raise ValueError("retriever_resource must be of dict type")
|
||||
|
||||
if "enabled" not in config["retriever_resource"] or not config["retriever_resource"]["enabled"]:
|
||||
config["retriever_resource"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["retriever_resource"]["enabled"], bool):
|
||||
raise ValueError("enabled in retriever_resource must be of boolean type")
|
||||
|
||||
# more_like_this
|
||||
if 'more_like_this' not in config or not config["more_like_this"]:
|
||||
config["more_like_this"] = {
|
||||
"enabled": False
|
||||
}
|
||||
|
||||
if not isinstance(config["more_like_this"], dict):
|
||||
raise ValueError("more_like_this must be of dict type")
|
||||
|
||||
if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]:
|
||||
config["more_like_this"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["more_like_this"]["enabled"], bool):
|
||||
raise ValueError("enabled in more_like_this must be of boolean type")
|
||||
|
||||
# model
|
||||
if 'model' not in config:
|
||||
raise ValueError("model is required")
|
||||
|
||||
if not isinstance(config["model"], dict):
|
||||
raise ValueError("model must be of object type")
|
||||
|
||||
# model.provider
|
||||
provider_entities = model_provider_factory.get_providers()
|
||||
model_provider_names = [provider.provider for provider in provider_entities]
|
||||
if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
# model.name
|
||||
if 'name' not in config["model"]:
|
||||
raise ValueError("model.name is required")
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
models = provider_manager.get_configurations(tenant_id).get_models(
|
||||
provider=config["model"]["provider"],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
if not models:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
model_ids = [m.model for m in models]
|
||||
if config["model"]["name"] not in model_ids:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
model_mode = None
|
||||
for model in models:
|
||||
if model.model == config["model"]["name"]:
|
||||
model_mode = model.model_properties.get(ModelPropertyKey.MODE)
|
||||
break
|
||||
|
||||
# model.mode
|
||||
if model_mode:
|
||||
config['model']["mode"] = model_mode
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict:
|
||||
if app_mode == AppMode.CHAT:
|
||||
return ChatAppConfigManager.config_validate(tenant_id, config)
|
||||
elif app_mode == AppMode.AGENT_CHAT:
|
||||
return AgentChatAppConfigManager.config_validate(tenant_id, config)
|
||||
elif app_mode == AppMode.COMPLETION:
|
||||
return CompletionAppConfigManager.config_validate(tenant_id, config)
|
||||
else:
|
||||
config['model']["mode"] = "completion"
|
||||
|
||||
# model.completion_params
|
||||
if 'completion_params' not in config["model"]:
|
||||
raise ValueError("model.completion_params is required")
|
||||
|
||||
config["model"]["completion_params"] = cls.validate_model_completion_params(
|
||||
config["model"]["completion_params"],
|
||||
config["model"]["name"]
|
||||
)
|
||||
|
||||
# user_input_form
|
||||
if "user_input_form" not in config or not config["user_input_form"]:
|
||||
config["user_input_form"] = []
|
||||
|
||||
if not isinstance(config["user_input_form"], list):
|
||||
raise ValueError("user_input_form must be a list of objects")
|
||||
|
||||
variables = []
|
||||
for item in config["user_input_form"]:
|
||||
key = list(item.keys())[0]
|
||||
if key not in ["text-input", "select", "paragraph", "external_data_tool"]:
|
||||
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
|
||||
|
||||
form_item = item[key]
|
||||
if 'label' not in form_item:
|
||||
raise ValueError("label is required in user_input_form")
|
||||
|
||||
if not isinstance(form_item["label"], str):
|
||||
raise ValueError("label in user_input_form must be of string type")
|
||||
|
||||
if 'variable' not in form_item:
|
||||
raise ValueError("variable is required in user_input_form")
|
||||
|
||||
if not isinstance(form_item["variable"], str):
|
||||
raise ValueError("variable in user_input_form must be of string type")
|
||||
|
||||
pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
|
||||
if pattern.match(form_item["variable"]) is None:
|
||||
raise ValueError("variable in user_input_form must be a string, "
|
||||
"and cannot start with a number")
|
||||
|
||||
variables.append(form_item["variable"])
|
||||
|
||||
if 'required' not in form_item or not form_item["required"]:
|
||||
form_item["required"] = False
|
||||
|
||||
if not isinstance(form_item["required"], bool):
|
||||
raise ValueError("required in user_input_form must be of boolean type")
|
||||
|
||||
if key == "select":
|
||||
if 'options' not in form_item or not form_item["options"]:
|
||||
form_item["options"] = []
|
||||
|
||||
if not isinstance(form_item["options"], list):
|
||||
raise ValueError("options in user_input_form must be a list of strings")
|
||||
|
||||
if "default" in form_item and form_item['default'] \
|
||||
and form_item["default"] not in form_item["options"]:
|
||||
raise ValueError("default value in user_input_form must be in the options list")
|
||||
|
||||
# pre_prompt
|
||||
if "pre_prompt" not in config or not config["pre_prompt"]:
|
||||
config["pre_prompt"] = ""
|
||||
|
||||
if not isinstance(config["pre_prompt"], str):
|
||||
raise ValueError("pre_prompt must be of string type")
|
||||
|
||||
# agent_mode
|
||||
if "agent_mode" not in config or not config["agent_mode"]:
|
||||
config["agent_mode"] = {
|
||||
"enabled": False,
|
||||
"tools": []
|
||||
}
|
||||
|
||||
if not isinstance(config["agent_mode"], dict):
|
||||
raise ValueError("agent_mode must be of object type")
|
||||
|
||||
if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
|
||||
config["agent_mode"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["agent_mode"]["enabled"], bool):
|
||||
raise ValueError("enabled in agent_mode must be of boolean type")
|
||||
|
||||
if "strategy" not in config["agent_mode"] or not config["agent_mode"]["strategy"]:
|
||||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
||||
|
||||
if config["agent_mode"]["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
|
||||
raise ValueError("strategy in agent_mode must be in the specified strategy list")
|
||||
|
||||
if "tools" not in config["agent_mode"] or not config["agent_mode"]["tools"]:
|
||||
config["agent_mode"]["tools"] = []
|
||||
|
||||
if not isinstance(config["agent_mode"]["tools"], list):
|
||||
raise ValueError("tools in agent_mode must be a list of objects")
|
||||
|
||||
for tool in config["agent_mode"]["tools"]:
|
||||
key = list(tool.keys())[0]
|
||||
if key in SUPPORT_TOOLS:
|
||||
# old style, use tool name as key
|
||||
tool_item = tool[key]
|
||||
|
||||
if "enabled" not in tool_item or not tool_item["enabled"]:
|
||||
tool_item["enabled"] = False
|
||||
|
||||
if not isinstance(tool_item["enabled"], bool):
|
||||
raise ValueError("enabled in agent_mode.tools must be of boolean type")
|
||||
|
||||
if key == "dataset":
|
||||
if 'id' not in tool_item:
|
||||
raise ValueError("id is required in dataset")
|
||||
|
||||
try:
|
||||
uuid.UUID(tool_item["id"])
|
||||
except ValueError:
|
||||
raise ValueError("id in dataset must be of UUID type")
|
||||
|
||||
if not cls.is_dataset_exists(account, tool_item["id"]):
|
||||
raise ValueError("Dataset ID does not exist, please check your permission.")
|
||||
else:
|
||||
# latest style, use key-value pair
|
||||
if "enabled" not in tool or not tool["enabled"]:
|
||||
tool["enabled"] = False
|
||||
if "provider_type" not in tool:
|
||||
raise ValueError("provider_type is required in agent_mode.tools")
|
||||
if "provider_id" not in tool:
|
||||
raise ValueError("provider_id is required in agent_mode.tools")
|
||||
if "tool_name" not in tool:
|
||||
raise ValueError("tool_name is required in agent_mode.tools")
|
||||
if "tool_parameters" not in tool:
|
||||
raise ValueError("tool_parameters is required in agent_mode.tools")
|
||||
|
||||
# dataset_query_variable
|
||||
cls.is_dataset_query_variable_valid(config, app_mode)
|
||||
|
||||
# advanced prompt validation
|
||||
cls.is_advanced_prompt_valid(config, app_mode)
|
||||
|
||||
# external data tools validation
|
||||
cls.is_external_data_tools_valid(tenant_id, config)
|
||||
|
||||
# moderation validation
|
||||
cls.is_moderation_valid(tenant_id, config)
|
||||
|
||||
# file upload validation
|
||||
cls.is_file_upload_valid(config)
|
||||
|
||||
# Filter out extra parameters
|
||||
filtered_config = {
|
||||
"opening_statement": config["opening_statement"],
|
||||
"suggested_questions": config["suggested_questions"],
|
||||
"suggested_questions_after_answer": config["suggested_questions_after_answer"],
|
||||
"speech_to_text": config["speech_to_text"],
|
||||
"text_to_speech": config["text_to_speech"],
|
||||
"retriever_resource": config["retriever_resource"],
|
||||
"more_like_this": config["more_like_this"],
|
||||
"sensitive_word_avoidance": config["sensitive_word_avoidance"],
|
||||
"external_data_tools": config["external_data_tools"],
|
||||
"model": {
|
||||
"provider": config["model"]["provider"],
|
||||
"name": config["model"]["name"],
|
||||
"mode": config['model']["mode"],
|
||||
"completion_params": config["model"]["completion_params"]
|
||||
},
|
||||
"user_input_form": config["user_input_form"],
|
||||
"dataset_query_variable": config.get('dataset_query_variable'),
|
||||
"pre_prompt": config["pre_prompt"],
|
||||
"agent_mode": config["agent_mode"],
|
||||
"prompt_type": config["prompt_type"],
|
||||
"chat_prompt_config": config["chat_prompt_config"],
|
||||
"completion_prompt_config": config["completion_prompt_config"],
|
||||
"dataset_configs": config["dataset_configs"],
|
||||
"file_upload": config["file_upload"]
|
||||
}
|
||||
|
||||
return filtered_config
|
||||
|
||||
@classmethod
|
||||
def is_moderation_valid(cls, tenant_id: str, config: dict):
|
||||
if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]:
|
||||
config["sensitive_word_avoidance"] = {
|
||||
"enabled": False
|
||||
}
|
||||
|
||||
if not isinstance(config["sensitive_word_avoidance"], dict):
|
||||
raise ValueError("sensitive_word_avoidance must be of dict type")
|
||||
|
||||
if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
|
||||
config["sensitive_word_avoidance"]["enabled"] = False
|
||||
|
||||
if not config["sensitive_word_avoidance"]["enabled"]:
|
||||
return
|
||||
|
||||
if "type" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["type"]:
|
||||
raise ValueError("sensitive_word_avoidance.type is required")
|
||||
|
||||
type = config["sensitive_word_avoidance"]["type"]
|
||||
config = config["sensitive_word_avoidance"]["config"]
|
||||
|
||||
ModerationFactory.validate_config(
|
||||
name=type,
|
||||
tenant_id=tenant_id,
|
||||
config=config
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_file_upload_valid(cls, config: dict):
|
||||
if 'file_upload' not in config or not config["file_upload"]:
|
||||
config["file_upload"] = {}
|
||||
|
||||
if not isinstance(config["file_upload"], dict):
|
||||
raise ValueError("file_upload must be of dict type")
|
||||
|
||||
# check image config
|
||||
if 'image' not in config["file_upload"] or not config["file_upload"]["image"]:
|
||||
config["file_upload"]["image"] = {"enabled": False}
|
||||
|
||||
if config['file_upload']['image']['enabled']:
|
||||
number_limits = config['file_upload']['image']['number_limits']
|
||||
if number_limits < 1 or number_limits > 6:
|
||||
raise ValueError("number_limits must be in [1, 6]")
|
||||
|
||||
detail = config['file_upload']['image']['detail']
|
||||
if detail not in ['high', 'low']:
|
||||
raise ValueError("detail must be in ['high', 'low']")
|
||||
|
||||
transfer_methods = config['file_upload']['image']['transfer_methods']
|
||||
if not isinstance(transfer_methods, list):
|
||||
raise ValueError("transfer_methods must be of list type")
|
||||
for method in transfer_methods:
|
||||
if method not in ['remote_url', 'local_file']:
|
||||
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
|
||||
|
||||
@classmethod
|
||||
def is_external_data_tools_valid(cls, tenant_id: str, config: dict):
|
||||
if 'external_data_tools' not in config or not config["external_data_tools"]:
|
||||
config["external_data_tools"] = []
|
||||
|
||||
if not isinstance(config["external_data_tools"], list):
|
||||
raise ValueError("external_data_tools must be of list type")
|
||||
|
||||
for tool in config["external_data_tools"]:
|
||||
if "enabled" not in tool or not tool["enabled"]:
|
||||
tool["enabled"] = False
|
||||
|
||||
if not tool["enabled"]:
|
||||
continue
|
||||
|
||||
if "type" not in tool or not tool["type"]:
|
||||
raise ValueError("external_data_tools[].type is required")
|
||||
|
||||
type = tool["type"]
|
||||
config = tool["config"]
|
||||
|
||||
ExternalDataToolFactory.validate_config(
|
||||
name=type,
|
||||
tenant_id=tenant_id,
|
||||
config=config
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None:
|
||||
# Only check when mode is completion
|
||||
if mode != 'completion':
|
||||
return
|
||||
|
||||
agent_mode = config.get("agent_mode", {})
|
||||
tools = agent_mode.get("tools", [])
|
||||
dataset_exists = "dataset" in str(tools)
|
||||
|
||||
dataset_query_variable = config.get("dataset_query_variable")
|
||||
|
||||
if dataset_exists and not dataset_query_variable:
|
||||
raise ValueError("Dataset query variable is required when dataset is exist")
|
||||
|
||||
@classmethod
|
||||
def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None:
|
||||
# prompt_type
|
||||
if 'prompt_type' not in config or not config["prompt_type"]:
|
||||
config["prompt_type"] = "simple"
|
||||
|
||||
if config['prompt_type'] not in ['simple', 'advanced']:
|
||||
raise ValueError("prompt_type must be in ['simple', 'advanced']")
|
||||
|
||||
# chat_prompt_config
|
||||
if 'chat_prompt_config' not in config or not config["chat_prompt_config"]:
|
||||
config["chat_prompt_config"] = {}
|
||||
|
||||
if not isinstance(config["chat_prompt_config"], dict):
|
||||
raise ValueError("chat_prompt_config must be of object type")
|
||||
|
||||
# completion_prompt_config
|
||||
if 'completion_prompt_config' not in config or not config["completion_prompt_config"]:
|
||||
config["completion_prompt_config"] = {}
|
||||
|
||||
if not isinstance(config["completion_prompt_config"], dict):
|
||||
raise ValueError("completion_prompt_config must be of object type")
|
||||
|
||||
# dataset_configs
|
||||
if 'dataset_configs' not in config or not config["dataset_configs"]:
|
||||
config["dataset_configs"] = {'retrieval_model': 'single'}
|
||||
|
||||
if 'datasets' not in config["dataset_configs"] or not config["dataset_configs"]["datasets"]:
|
||||
config["dataset_configs"]["datasets"] = {
|
||||
"strategy": "router",
|
||||
"datasets": []
|
||||
}
|
||||
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
||||
if config["dataset_configs"]['retrieval_model'] == 'multiple':
|
||||
if not config["dataset_configs"]['reranking_model']:
|
||||
raise ValueError("reranking_model has not been set")
|
||||
if not isinstance(config["dataset_configs"]['reranking_model'], dict):
|
||||
raise ValueError("reranking_model must be of object type")
|
||||
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
||||
if config['prompt_type'] == 'advanced':
|
||||
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
|
||||
raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced")
|
||||
|
||||
if config['model']["mode"] not in ['chat', 'completion']:
|
||||
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
|
||||
|
||||
if app_mode == AppMode.CHAT.value and config['model']["mode"] == "completion":
|
||||
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
|
||||
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
|
||||
|
||||
if not user_prefix:
|
||||
config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
|
||||
|
||||
if not assistant_prefix:
|
||||
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
|
||||
|
||||
if config['model']["mode"] == "chat":
|
||||
prompt_list = config['chat_prompt_config']['prompt']
|
||||
|
||||
if len(prompt_list) > 10:
|
||||
raise ValueError("prompt messages must be less than 10")
|
||||
raise ValueError(f"Invalid app mode: {app_mode}")
|
||||
|
|
|
|||
400
api/services/app_service.py
Normal file
400
api/services/app_service.py
Normal file
|
|
@ -0,0 +1,400 @@
|
|||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
import yaml
|
||||
from flask import current_app
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
|
||||
from constants.model_template import default_app_templates
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from events.app_event import app_model_config_was_updated, app_was_created, app_was_deleted
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from models.tools import ApiToolProvider
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class AppService:
|
||||
def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination:
|
||||
"""
|
||||
Get app list with pagination
|
||||
:param tenant_id: tenant id
|
||||
:param args: request args
|
||||
:return:
|
||||
"""
|
||||
filters = [
|
||||
App.tenant_id == tenant_id,
|
||||
App.is_universal == False
|
||||
]
|
||||
|
||||
if args['mode'] == 'workflow':
|
||||
filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value]))
|
||||
elif args['mode'] == 'chat':
|
||||
filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value]))
|
||||
elif args['mode'] == 'agent-chat':
|
||||
filters.append(App.mode == AppMode.AGENT_CHAT.value)
|
||||
elif args['mode'] == 'channel':
|
||||
filters.append(App.mode == AppMode.CHANNEL.value)
|
||||
|
||||
if 'name' in args and args['name']:
|
||||
name = args['name'][:30]
|
||||
filters.append(App.name.ilike(f'%{name}%'))
|
||||
|
||||
app_models = db.paginate(
|
||||
db.select(App).where(*filters).order_by(App.created_at.desc()),
|
||||
page=args['page'],
|
||||
per_page=args['limit'],
|
||||
error_out=False
|
||||
)
|
||||
|
||||
return app_models
|
||||
|
||||
def create_app(self, tenant_id: str, args: dict, account: Account) -> App:
|
||||
"""
|
||||
Create app
|
||||
:param tenant_id: tenant id
|
||||
:param args: request args
|
||||
:param account: Account instance
|
||||
"""
|
||||
app_mode = AppMode.value_of(args['mode'])
|
||||
app_template = default_app_templates[app_mode]
|
||||
|
||||
# get model config
|
||||
default_model_config = app_template.get('model_config')
|
||||
default_model_config = default_model_config.copy() if default_model_config else None
|
||||
if default_model_config and 'model' in default_model_config:
|
||||
# get model provider
|
||||
model_manager = ModelManager()
|
||||
|
||||
# get default model instance
|
||||
try:
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=account.current_tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
except (ProviderTokenNotInitError, LLMBadRequestError):
|
||||
model_instance = None
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
model_instance = None
|
||||
|
||||
if model_instance:
|
||||
if model_instance.model == default_model_config['model']['name']:
|
||||
default_model_dict = default_model_config['model']
|
||||
else:
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
|
||||
default_model_dict = {
|
||||
'provider': model_instance.provider,
|
||||
'name': model_instance.model,
|
||||
'mode': model_schema.model_properties.get(ModelPropertyKey.MODE),
|
||||
'completion_params': {}
|
||||
}
|
||||
else:
|
||||
default_model_dict = default_model_config['model']
|
||||
|
||||
default_model_config['model'] = json.dumps(default_model_dict)
|
||||
|
||||
app = App(**app_template['app'])
|
||||
app.name = args['name']
|
||||
app.description = args.get('description', '')
|
||||
app.mode = args['mode']
|
||||
app.icon = args['icon']
|
||||
app.icon_background = args['icon_background']
|
||||
app.tenant_id = tenant_id
|
||||
|
||||
db.session.add(app)
|
||||
db.session.flush()
|
||||
|
||||
if default_model_config:
|
||||
app_model_config = AppModelConfig(**default_model_config)
|
||||
app_model_config.app_id = app.id
|
||||
db.session.add(app_model_config)
|
||||
db.session.flush()
|
||||
|
||||
app.app_model_config_id = app_model_config.id
|
||||
|
||||
db.session.commit()
|
||||
|
||||
app_was_created.send(app, account=account)
|
||||
|
||||
return app
|
||||
|
||||
def import_app(self, tenant_id: str, data: str, args: dict, account: Account) -> App:
|
||||
"""
|
||||
Import app
|
||||
:param tenant_id: tenant id
|
||||
:param data: import data
|
||||
:param args: request args
|
||||
:param account: Account instance
|
||||
"""
|
||||
try:
|
||||
import_data = yaml.safe_load(data)
|
||||
except yaml.YAMLError as e:
|
||||
raise ValueError("Invalid YAML format in data argument.")
|
||||
|
||||
app_data = import_data.get('app')
|
||||
model_config_data = import_data.get('model_config')
|
||||
workflow = import_data.get('workflow')
|
||||
|
||||
if not app_data:
|
||||
raise ValueError("Missing app in data argument")
|
||||
|
||||
app_mode = AppMode.value_of(app_data.get('mode'))
|
||||
if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
|
||||
if not workflow:
|
||||
raise ValueError("Missing workflow in data argument "
|
||||
"when app mode is advanced-chat or workflow")
|
||||
elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]:
|
||||
if not model_config_data:
|
||||
raise ValueError("Missing model_config in data argument "
|
||||
"when app mode is chat, agent-chat or completion")
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant_id,
|
||||
mode=app_data.get('mode'),
|
||||
name=args.get("name") if args.get("name") else app_data.get('name'),
|
||||
description=args.get("description") if args.get("description") else app_data.get('description', ''),
|
||||
icon=args.get("icon") if args.get("icon") else app_data.get('icon'),
|
||||
icon_background=args.get("icon_background") if args.get("icon_background") \
|
||||
else app_data.get('icon_background'),
|
||||
enable_site=True,
|
||||
enable_api=True
|
||||
)
|
||||
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
|
||||
app_was_created.send(app, account=account)
|
||||
|
||||
if workflow:
|
||||
# init draft workflow
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app,
|
||||
graph=workflow.get('graph'),
|
||||
features=workflow.get('features'),
|
||||
account=account
|
||||
)
|
||||
workflow_service.publish_workflow(
|
||||
app_model=app,
|
||||
account=account,
|
||||
draft_workflow=draft_workflow
|
||||
)
|
||||
|
||||
if model_config_data:
|
||||
app_model_config = AppModelConfig()
|
||||
app_model_config = app_model_config.from_model_config_dict(model_config_data)
|
||||
app_model_config.app_id = app.id
|
||||
|
||||
db.session.add(app_model_config)
|
||||
db.session.commit()
|
||||
|
||||
app.app_model_config_id = app_model_config.id
|
||||
|
||||
app_model_config_was_updated.send(
|
||||
app,
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
def export_app(self, app: App) -> str:
|
||||
"""
|
||||
Export app
|
||||
:param app: App instance
|
||||
:return:
|
||||
"""
|
||||
app_mode = AppMode.value_of(app.mode)
|
||||
|
||||
export_data = {
|
||||
"app": {
|
||||
"name": app.name,
|
||||
"mode": app.mode,
|
||||
"icon": app.icon,
|
||||
"icon_background": app.icon_background
|
||||
}
|
||||
}
|
||||
|
||||
if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_draft_workflow(app)
|
||||
export_data['workflow'] = {
|
||||
"graph": workflow.graph_dict,
|
||||
"features": workflow.features_dict
|
||||
}
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
|
||||
export_data['model_config'] = app_model_config.to_dict()
|
||||
|
||||
return yaml.dump(export_data)
|
||||
|
||||
def update_app(self, app: App, args: dict) -> App:
|
||||
"""
|
||||
Update app
|
||||
:param app: App instance
|
||||
:param args: request args
|
||||
:return: App instance
|
||||
"""
|
||||
app.name = args.get('name')
|
||||
app.description = args.get('description', '')
|
||||
app.icon = args.get('icon')
|
||||
app.icon_background = args.get('icon_background')
|
||||
app.updated_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
return app
|
||||
|
||||
def update_app_name(self, app: App, name: str) -> App:
|
||||
"""
|
||||
Update app name
|
||||
:param app: App instance
|
||||
:param name: new name
|
||||
:return: App instance
|
||||
"""
|
||||
app.name = name
|
||||
app.updated_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
return app
|
||||
|
||||
def update_app_icon(self, app: App, icon: str, icon_background: str) -> App:
|
||||
"""
|
||||
Update app icon
|
||||
:param app: App instance
|
||||
:param icon: new icon
|
||||
:param icon_background: new icon_background
|
||||
:return: App instance
|
||||
"""
|
||||
app.icon = icon
|
||||
app.icon_background = icon_background
|
||||
app.updated_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
return app
|
||||
|
||||
def update_app_site_status(self, app: App, enable_site: bool) -> App:
|
||||
"""
|
||||
Update app site status
|
||||
:param app: App instance
|
||||
:param enable_site: enable site status
|
||||
:return: App instance
|
||||
"""
|
||||
if enable_site == app.enable_site:
|
||||
return app
|
||||
|
||||
app.enable_site = enable_site
|
||||
app.updated_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
return app
|
||||
|
||||
def update_app_api_status(self, app: App, enable_api: bool) -> App:
|
||||
"""
|
||||
Update app api status
|
||||
:param app: App instance
|
||||
:param enable_api: enable api status
|
||||
:return: App instance
|
||||
"""
|
||||
if enable_api == app.enable_api:
|
||||
return app
|
||||
|
||||
app.enable_api = enable_api
|
||||
app.updated_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
return app
|
||||
|
||||
def delete_app(self, app: App) -> None:
|
||||
"""
|
||||
Delete app
|
||||
:param app: App instance
|
||||
"""
|
||||
db.session.delete(app)
|
||||
db.session.commit()
|
||||
|
||||
app_was_deleted.send(app)
|
||||
|
||||
# todo async delete related data by event
|
||||
# app_model_configs, site, api_tokens, installed_apps, recommended_apps BY app
|
||||
# app_annotation_hit_histories, app_annotation_settings, app_dataset_joins BY app
|
||||
# workflows, workflow_runs, workflow_node_executions, workflow_app_logs BY app
|
||||
# conversations, pinned_conversations, messages BY app
|
||||
# message_feedbacks, message_annotations, message_chains BY message
|
||||
# message_agent_thoughts, message_files, saved_messages BY message
|
||||
|
||||
def get_app_meta(self, app_model: App) -> dict:
|
||||
"""
|
||||
Get app meta info
|
||||
:param app_model: app model
|
||||
:return:
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
||||
meta = {
|
||||
'tool_icons': {}
|
||||
}
|
||||
|
||||
if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
return meta
|
||||
|
||||
graph = workflow.graph_dict
|
||||
nodes = graph.get('nodes', [])
|
||||
tools = []
|
||||
for node in nodes:
|
||||
if node.get('data', {}).get('type') == 'tool':
|
||||
node_data = node.get('data', {})
|
||||
tools.append({
|
||||
'provider_type': node_data.get('provider_type'),
|
||||
'provider_id': node_data.get('provider_id'),
|
||||
'tool_name': node_data.get('tool_name'),
|
||||
'tool_parameters': {}
|
||||
})
|
||||
else:
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config:
|
||||
return meta
|
||||
|
||||
agent_config = app_model_config.agent_mode_dict or {}
|
||||
|
||||
# get all tools
|
||||
tools = agent_config.get('tools', [])
|
||||
|
||||
url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/")
|
||||
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
if len(keys) >= 4:
|
||||
# current tool standard
|
||||
provider_type = tool.get('provider_type')
|
||||
provider_id = tool.get('provider_id')
|
||||
tool_name = tool.get('tool_name')
|
||||
if provider_type == 'builtin':
|
||||
meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon'
|
||||
elif provider_type == 'api':
|
||||
try:
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.id == provider_id
|
||||
)
|
||||
meta['tool_icons'][tool_name] = json.loads(provider.icon)
|
||||
except:
|
||||
meta['tool_icons'][tool_name] = {
|
||||
"background": "#252525",
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
|
||||
return meta
|
||||
|
|
@ -5,6 +5,7 @@ from werkzeug.datastructures import FileStorage
|
|||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
|
|
@ -20,7 +21,21 @@ ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm', 'amr']
|
|||
|
||||
class AudioService:
|
||||
@classmethod
|
||||
def transcript_asr(cls, tenant_id: str, file: FileStorage, end_user: Optional[str] = None):
|
||||
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None):
|
||||
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
if 'speech_to_text' not in features_dict or not features_dict['speech_to_text'].get('enabled'):
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
else:
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config.speech_to_text_dict['enabled']:
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
|
||||
if file is None:
|
||||
raise NoAudioUploadedServiceError()
|
||||
|
||||
|
|
@ -37,7 +52,7 @@ class AudioService:
|
|||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
model_type=ModelType.SPEECH2TEXT
|
||||
)
|
||||
if model_instance is None:
|
||||
|
|
@ -49,17 +64,42 @@ class AudioService:
|
|||
return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)}
|
||||
|
||||
@classmethod
|
||||
def transcript_tts(cls, tenant_id: str, text: str, voice: str, streaming: bool, end_user: Optional[str] = None):
|
||||
def transcript_tts(cls, app_model: App, text: str, streaming: bool,
|
||||
voice: Optional[str] = None, end_user: Optional[str] = None):
|
||||
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise ValueError("TTS is not enabled")
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'):
|
||||
raise ValueError("TTS is not enabled")
|
||||
|
||||
voice = features_dict['text_to_speech'].get('voice') if voice is None else voice
|
||||
else:
|
||||
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
|
||||
|
||||
if not text_to_speech_dict.get('enabled'):
|
||||
raise ValueError("TTS is not enabled")
|
||||
|
||||
voice = text_to_speech_dict.get('voice') if voice is None else voice
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
model_type=ModelType.TTS
|
||||
)
|
||||
if model_instance is None:
|
||||
raise ProviderNotSupportTextToSpeechServiceError()
|
||||
|
||||
try:
|
||||
return model_instance.invoke_tts(content_text=text.strip(), user=end_user, streaming=streaming, tenant_id=tenant_id, voice=voice)
|
||||
return model_instance.invoke_tts(
|
||||
content_text=text.strip(),
|
||||
user=end_user,
|
||||
streaming=streaming,
|
||||
tenant_id=app_model.tenant_id,
|
||||
voice=voice
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
|
|
|||
|
|
@ -1,258 +0,0 @@
|
|||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
|
||||
from sqlalchemy import and_
|
||||
|
||||
from core.application_manager import ApplicationManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from extensions.ext_database import db
|
||||
from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
class CompletionService:
|
||||
|
||||
@classmethod
|
||||
def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
|
||||
invoke_from: InvokeFrom, streaming: bool = True,
|
||||
is_model_config_override: bool = False) -> Union[dict, Generator]:
|
||||
# is streaming mode
|
||||
inputs = args['inputs']
|
||||
query = args['query']
|
||||
files = args['files'] if 'files' in args and args['files'] else []
|
||||
auto_generate_name = args['auto_generate_name'] \
|
||||
if 'auto_generate_name' in args else True
|
||||
|
||||
if app_model.mode != 'completion':
|
||||
if not query:
|
||||
raise ValueError('query is required')
|
||||
|
||||
if query:
|
||||
if not isinstance(query, str):
|
||||
raise ValueError('query must be a string')
|
||||
|
||||
query = query.replace('\x00', '')
|
||||
|
||||
conversation_id = args['conversation_id'] if 'conversation_id' in args else None
|
||||
|
||||
conversation = None
|
||||
if conversation_id:
|
||||
conversation_filter = [
|
||||
Conversation.id == args['conversation_id'],
|
||||
Conversation.app_id == app_model.id,
|
||||
Conversation.status == 'normal'
|
||||
]
|
||||
|
||||
if isinstance(user, Account):
|
||||
conversation_filter.append(Conversation.from_account_id == user.id)
|
||||
else:
|
||||
conversation_filter.append(Conversation.from_end_user_id == user.id if user else None)
|
||||
|
||||
conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError()
|
||||
|
||||
if conversation.status != 'normal':
|
||||
raise ConversationCompletedError()
|
||||
|
||||
if not conversation.override_model_configs:
|
||||
app_model_config = db.session.query(AppModelConfig).filter(
|
||||
AppModelConfig.id == conversation.app_model_config_id,
|
||||
AppModelConfig.app_id == app_model.id
|
||||
).first()
|
||||
|
||||
if not app_model_config:
|
||||
raise AppModelConfigBrokenError()
|
||||
else:
|
||||
conversation_override_model_configs = json.loads(conversation.override_model_configs)
|
||||
|
||||
app_model_config = AppModelConfig(
|
||||
id=conversation.app_model_config_id,
|
||||
app_id=app_model.id,
|
||||
)
|
||||
|
||||
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
|
||||
|
||||
if is_model_config_override:
|
||||
# build new app model config
|
||||
if 'model' not in args['model_config']:
|
||||
raise ValueError('model_config.model is required')
|
||||
|
||||
if 'completion_params' not in args['model_config']['model']:
|
||||
raise ValueError('model_config.model.completion_params is required')
|
||||
|
||||
completion_params = AppModelConfigService.validate_model_completion_params(
|
||||
cp=args['model_config']['model']['completion_params'],
|
||||
model_name=app_model_config.model_dict["name"]
|
||||
)
|
||||
|
||||
app_model_config_model = app_model_config.model_dict
|
||||
app_model_config_model['completion_params'] = completion_params
|
||||
app_model_config.retriever_resource = json.dumps({'enabled': True})
|
||||
|
||||
app_model_config = app_model_config.copy()
|
||||
app_model_config.model = json.dumps(app_model_config_model)
|
||||
else:
|
||||
if app_model.app_model_config_id is None:
|
||||
raise AppModelConfigBrokenError()
|
||||
|
||||
app_model_config = app_model.app_model_config
|
||||
|
||||
if not app_model_config:
|
||||
raise AppModelConfigBrokenError()
|
||||
|
||||
if is_model_config_override:
|
||||
if not isinstance(user, Account):
|
||||
raise Exception("Only account can override model config")
|
||||
|
||||
# validate config
|
||||
model_config = AppModelConfigService.validate_configuration(
|
||||
tenant_id=app_model.tenant_id,
|
||||
account=user,
|
||||
config=args['model_config'],
|
||||
app_mode=app_model.mode
|
||||
)
|
||||
|
||||
app_model_config = AppModelConfig(
|
||||
id=app_model_config.id,
|
||||
app_id=app_model.id,
|
||||
)
|
||||
|
||||
app_model_config = app_model_config.from_model_config_dict(model_config)
|
||||
|
||||
# clean input by app_model_config form rules
|
||||
inputs = cls.get_cleaned_inputs(inputs, app_model_config)
|
||||
|
||||
# parse files
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
files,
|
||||
app_model_config,
|
||||
user
|
||||
)
|
||||
|
||||
application_manager = ApplicationManager()
|
||||
return application_manager.generate(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=app_model_config.to_dict(),
|
||||
app_model_config_override=is_model_config_override,
|
||||
user=user,
|
||||
invoke_from=invoke_from,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=file_objs,
|
||||
conversation=conversation,
|
||||
stream=streaming,
|
||||
extras={
|
||||
"auto_generate_conversation_name": auto_generate_name
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
|
||||
message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
|
||||
-> Union[dict, Generator]:
|
||||
if not user:
|
||||
raise ValueError('user cannot be None')
|
||||
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
|
||||
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Message.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
).first()
|
||||
|
||||
if not message:
|
||||
raise MessageNotExistsError()
|
||||
|
||||
current_app_model_config = app_model.app_model_config
|
||||
more_like_this = current_app_model_config.more_like_this_dict
|
||||
|
||||
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
|
||||
raise MoreLikeThisDisabledError()
|
||||
|
||||
app_model_config = message.app_model_config
|
||||
model_dict = app_model_config.model_dict
|
||||
completion_params = model_dict.get('completion_params')
|
||||
completion_params['temperature'] = 0.9
|
||||
model_dict['completion_params'] = completion_params
|
||||
app_model_config.model = json.dumps(model_dict)
|
||||
|
||||
# parse files
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_objs = message_file_parser.transform_message_files(
|
||||
message.files, app_model_config
|
||||
)
|
||||
|
||||
application_manager = ApplicationManager()
|
||||
return application_manager.generate(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=app_model_config.to_dict(),
|
||||
app_model_config_override=True,
|
||||
user=user,
|
||||
invoke_from=invoke_from,
|
||||
inputs=message.inputs,
|
||||
query=message.query,
|
||||
files=file_objs,
|
||||
conversation=None,
|
||||
stream=streaming,
|
||||
extras={
|
||||
"auto_generate_conversation_name": False
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
|
||||
if user_inputs is None:
|
||||
user_inputs = {}
|
||||
|
||||
filtered_inputs = {}
|
||||
|
||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||
input_form_config = app_model_config.user_input_form_list
|
||||
for config in input_form_config:
|
||||
input_config = list(config.values())[0]
|
||||
variable = input_config["variable"]
|
||||
|
||||
input_type = list(config.keys())[0]
|
||||
|
||||
if variable not in user_inputs or not user_inputs[variable]:
|
||||
if input_type == "external_data_tool":
|
||||
continue
|
||||
if "required" in input_config and input_config["required"]:
|
||||
raise ValueError(f"{variable} is required in input form")
|
||||
else:
|
||||
filtered_inputs[variable] = input_config["default"] if "default" in input_config else ""
|
||||
continue
|
||||
|
||||
value = user_inputs[variable]
|
||||
|
||||
if value:
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(f"{variable} in input form must be a string")
|
||||
|
||||
if input_type == "select":
|
||||
options = input_config["options"] if "options" in input_config else []
|
||||
if value not in options:
|
||||
raise ValueError(f"{variable} in input form must be one of the following: {options}")
|
||||
else:
|
||||
if 'max_length' in input_config:
|
||||
max_length = input_config['max_length']
|
||||
if len(value) > max_length:
|
||||
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
|
||||
|
||||
filtered_inputs[variable] = value.replace('\x00', '') if value else None
|
||||
|
||||
return filtered_inputs
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional, Union
|
||||
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.account import Account
|
||||
|
|
|
|||
|
|
@ -1,16 +1,17 @@
|
|||
import json
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.account import Account
|
||||
from models.model import App, AppModelConfig, EndUser, Message, MessageFeedback
|
||||
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
|
|
@ -18,6 +19,7 @@ from services.errors.message import (
|
|||
MessageNotExistsError,
|
||||
SuggestedQuestionsAfterAnswerDisabledError,
|
||||
)
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class MessageService:
|
||||
|
|
@ -177,7 +179,7 @@ class MessageService:
|
|||
|
||||
@classmethod
|
||||
def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account, EndUser]],
|
||||
message_id: str, check_enabled: bool = True) -> list[Message]:
|
||||
message_id: str, invoke_from: InvokeFrom) -> list[Message]:
|
||||
if not user:
|
||||
raise ValueError('user cannot be None')
|
||||
|
||||
|
|
@ -199,37 +201,57 @@ class MessageService:
|
|||
if conversation.status != 'normal':
|
||||
raise ConversationCompletedError()
|
||||
|
||||
if not conversation.override_model_configs:
|
||||
app_model_config = db.session.query(AppModelConfig).filter(
|
||||
AppModelConfig.id == conversation.app_model_config_id,
|
||||
AppModelConfig.app_id == app_model.id
|
||||
).first()
|
||||
model_manager = ModelManager()
|
||||
|
||||
if not app_model_config:
|
||||
raise AppModelConfigBrokenError()
|
||||
else:
|
||||
conversation_override_model_configs = json.loads(conversation.override_model_configs)
|
||||
app_model_config = AppModelConfig(
|
||||
id=conversation.app_model_config_id,
|
||||
app_id=app_model.id,
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
workflow_service = WorkflowService()
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
else:
|
||||
workflow = workflow_service.get_published_workflow(app_model=app_model)
|
||||
|
||||
if workflow is None:
|
||||
return []
|
||||
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
workflow=workflow
|
||||
)
|
||||
|
||||
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
|
||||
if not app_config.additional_features.suggested_questions_after_answer:
|
||||
raise SuggestedQuestionsAfterAnswerDisabledError()
|
||||
|
||||
suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=app_model.tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
else:
|
||||
if not conversation.override_model_configs:
|
||||
app_model_config = db.session.query(AppModelConfig).filter(
|
||||
AppModelConfig.id == conversation.app_model_config_id,
|
||||
AppModelConfig.app_id == app_model.id
|
||||
).first()
|
||||
else:
|
||||
conversation_override_model_configs = json.loads(conversation.override_model_configs)
|
||||
app_model_config = AppModelConfig(
|
||||
id=conversation.app_model_config_id,
|
||||
app_id=app_model.id,
|
||||
)
|
||||
|
||||
if check_enabled and suggested_questions_after_answer.get("enabled", False) is False:
|
||||
raise SuggestedQuestionsAfterAnswerDisabledError()
|
||||
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
|
||||
|
||||
suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
|
||||
if suggested_questions_after_answer.get("enabled", False) is False:
|
||||
raise SuggestedQuestionsAfterAnswerDisabledError()
|
||||
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=app_model.tenant_id,
|
||||
provider=app_model_config.model_dict['provider'],
|
||||
model_type=ModelType.LLM,
|
||||
model=app_model_config.model_dict['name']
|
||||
)
|
||||
|
||||
# get memory of conversation (read-only)
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=app_model.tenant_id,
|
||||
provider=app_model_config.model_dict['provider'],
|
||||
model_type=ModelType.LLM,
|
||||
model=app_model_config.model_dict['name']
|
||||
)
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=conversation,
|
||||
model_instance=model_instance
|
||||
|
|
|
|||
251
api/services/recommended_app_service.py
Normal file
251
api/services/recommended_app_service.py
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
import json
|
||||
import logging
|
||||
from os import path
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from flask import current_app
|
||||
|
||||
from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, RecommendedApp
|
||||
from services.app_service import AppService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RecommendedAppService:
|
||||
|
||||
builtin_data: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
def get_recommended_apps_and_categories(cls, language: str) -> dict:
|
||||
"""
|
||||
Get recommended apps and categories.
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
mode = current_app.config.get('HOSTED_FETCH_APP_TEMPLATES_MODE', 'remote')
|
||||
if mode == 'remote':
|
||||
try:
|
||||
result = cls._fetch_recommended_apps_from_dify_official(language)
|
||||
except Exception as e:
|
||||
logger.warning(f'fetch recommended apps from dify official failed: {e}, switch to built-in.')
|
||||
result = cls._fetch_recommended_apps_from_builtin(language)
|
||||
elif mode == 'db':
|
||||
result = cls._fetch_recommended_apps_from_db(language)
|
||||
elif mode == 'builtin':
|
||||
result = cls._fetch_recommended_apps_from_builtin(language)
|
||||
else:
|
||||
raise ValueError(f'invalid fetch recommended apps mode: {mode}')
|
||||
|
||||
if not result.get('recommended_apps') and language != 'en-US':
|
||||
result = cls._fetch_recommended_apps_from_builtin('en-US')
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _fetch_recommended_apps_from_db(cls, language: str) -> dict:
|
||||
"""
|
||||
Fetch recommended apps from db.
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
recommended_apps = db.session.query(RecommendedApp).filter(
|
||||
RecommendedApp.is_listed == True,
|
||||
RecommendedApp.language == language
|
||||
).all()
|
||||
|
||||
if len(recommended_apps) == 0:
|
||||
recommended_apps = db.session.query(RecommendedApp).filter(
|
||||
RecommendedApp.is_listed == True,
|
||||
RecommendedApp.language == languages[0]
|
||||
).all()
|
||||
|
||||
categories = set()
|
||||
recommended_apps_result = []
|
||||
for recommended_app in recommended_apps:
|
||||
app = recommended_app.app
|
||||
if not app or not app.is_public:
|
||||
continue
|
||||
|
||||
site = app.site
|
||||
if not site:
|
||||
continue
|
||||
|
||||
recommended_app_result = {
|
||||
'id': recommended_app.id,
|
||||
'app': {
|
||||
'id': app.id,
|
||||
'name': app.name,
|
||||
'mode': app.mode,
|
||||
'icon': app.icon,
|
||||
'icon_background': app.icon_background
|
||||
},
|
||||
'app_id': recommended_app.app_id,
|
||||
'description': site.description,
|
||||
'copyright': site.copyright,
|
||||
'privacy_policy': site.privacy_policy,
|
||||
'category': recommended_app.category,
|
||||
'position': recommended_app.position,
|
||||
'is_listed': recommended_app.is_listed
|
||||
}
|
||||
recommended_apps_result.append(recommended_app_result)
|
||||
|
||||
categories.add(recommended_app.category) # add category to categories
|
||||
|
||||
return {'recommended_apps': recommended_apps_result, 'categories': list(categories)}
|
||||
|
||||
@classmethod
|
||||
def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
|
||||
"""
|
||||
Fetch recommended apps from dify official.
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
domain = current_app.config.get('HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN', 'https://tmpl.dify.ai')
|
||||
url = f'{domain}/apps?language={language}'
|
||||
response = requests.get(url, timeout=(3, 10))
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f'fetch recommended apps failed, status code: {response.status_code}')
|
||||
|
||||
return response.json()
|
||||
|
||||
@classmethod
|
||||
def _fetch_recommended_apps_from_builtin(cls, language: str) -> dict:
|
||||
"""
|
||||
Fetch recommended apps from builtin.
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
builtin_data = cls._get_builtin_data()
|
||||
return builtin_data.get('recommended_apps', {}).get(language)
|
||||
|
||||
@classmethod
|
||||
def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Get recommend app detail.
|
||||
:param app_id: app id
|
||||
:return:
|
||||
"""
|
||||
mode = current_app.config.get('HOSTED_FETCH_APP_TEMPLATES_MODE', 'remote')
|
||||
if mode == 'remote':
|
||||
try:
|
||||
result = cls._fetch_recommended_app_detail_from_dify_official(app_id)
|
||||
except Exception as e:
|
||||
logger.warning(f'fetch recommended app detail from dify official failed: {e}, switch to built-in.')
|
||||
result = cls._fetch_recommended_app_detail_from_builtin(app_id)
|
||||
elif mode == 'db':
|
||||
result = cls._fetch_recommended_app_detail_from_db(app_id)
|
||||
elif mode == 'builtin':
|
||||
result = cls._fetch_recommended_app_detail_from_builtin(app_id)
|
||||
else:
|
||||
raise ValueError(f'invalid fetch recommended app detail mode: {mode}')
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Fetch recommended app detail from dify official.
|
||||
:param app_id: App ID
|
||||
:return:
|
||||
"""
|
||||
domain = current_app.config.get('HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN', 'https://tmpl.dify.ai')
|
||||
url = f'{domain}/apps/{app_id}'
|
||||
response = requests.get(url, timeout=(3, 10))
|
||||
if response.status_code != 200:
|
||||
return None
|
||||
|
||||
return response.json()
|
||||
|
||||
@classmethod
|
||||
def _fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Fetch recommended app detail from db.
|
||||
:param app_id: App ID
|
||||
:return:
|
||||
"""
|
||||
# is in public recommended list
|
||||
recommended_app = db.session.query(RecommendedApp).filter(
|
||||
RecommendedApp.is_listed == True,
|
||||
RecommendedApp.app_id == app_id
|
||||
).first()
|
||||
|
||||
if not recommended_app:
|
||||
return None
|
||||
|
||||
# get app detail
|
||||
app_model = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app_model or not app_model.is_public:
|
||||
return None
|
||||
|
||||
app_service = AppService()
|
||||
export_str = app_service.export_app(app_model)
|
||||
|
||||
return {
|
||||
'id': app_model.id,
|
||||
'name': app_model.name,
|
||||
'icon': app_model.icon,
|
||||
'icon_background': app_model.icon_background,
|
||||
'mode': app_model.mode,
|
||||
'export_data': export_str
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Fetch recommended app detail from builtin.
|
||||
:param app_id: App ID
|
||||
:return:
|
||||
"""
|
||||
builtin_data = cls._get_builtin_data()
|
||||
return builtin_data.get('app_details', {}).get(app_id)
|
||||
|
||||
@classmethod
|
||||
def _get_builtin_data(cls) -> dict:
|
||||
"""
|
||||
Get builtin data.
|
||||
:return:
|
||||
"""
|
||||
if cls.builtin_data:
|
||||
return cls.builtin_data
|
||||
|
||||
root_path = current_app.root_path
|
||||
with open(path.join(root_path, 'constants', 'recommended_apps.json'), encoding='utf-8') as f:
|
||||
json_data = f.read()
|
||||
data = json.loads(json_data)
|
||||
cls.builtin_data = data
|
||||
|
||||
return cls.builtin_data
|
||||
|
||||
@classmethod
|
||||
def fetch_all_recommended_apps_and_export_datas(cls):
|
||||
"""
|
||||
Fetch all recommended apps and export datas
|
||||
:return:
|
||||
"""
|
||||
templates = {
|
||||
"recommended_apps": {},
|
||||
"app_details": {}
|
||||
}
|
||||
for language in languages:
|
||||
try:
|
||||
result = cls._fetch_recommended_apps_from_dify_official(language)
|
||||
except Exception as e:
|
||||
logger.warning(f'fetch recommended apps from dify official failed: {e}, skip.')
|
||||
continue
|
||||
|
||||
templates['recommended_apps'][language] = result
|
||||
|
||||
for recommended_app in result.get('recommended_apps'):
|
||||
app_id = recommended_app.get('app_id')
|
||||
|
||||
# get app detail
|
||||
app_detail = cls._fetch_recommended_app_detail_from_dify_official(app_id)
|
||||
if not app_detail:
|
||||
continue
|
||||
|
||||
templates['app_details'][app_id] = app_detail
|
||||
|
||||
return templates
|
||||
|
|
@ -1,29 +1,29 @@
|
|||
import json
|
||||
import logging
|
||||
|
||||
from flask import current_app
|
||||
from httpx import get
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ApiProviderSchemaType,
|
||||
ToolCredentialsOption,
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
)
|
||||
from core.tools.entities.user_entities import UserTool, UserToolProvider
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
|
||||
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolConfigurationManager
|
||||
from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict
|
||||
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider
|
||||
from services.model_provider_service import ModelProviderService
|
||||
from services.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -36,44 +36,22 @@ class ToolManageService:
|
|||
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
result = [provider.to_dict() for provider in ToolManager.user_list_providers(
|
||||
providers = ToolManager.user_list_providers(
|
||||
user_id, tenant_id
|
||||
)]
|
||||
)
|
||||
|
||||
# add icon url prefix
|
||||
for provider in result:
|
||||
ToolManageService.repack_provider(provider)
|
||||
# add icon
|
||||
for provider in providers:
|
||||
ToolTransformService.repack_provider(provider)
|
||||
|
||||
result = [provider.to_dict() for provider in providers]
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(provider: dict):
|
||||
"""
|
||||
repack provider
|
||||
|
||||
:param provider: the provider dict
|
||||
"""
|
||||
url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||
+ "/console/api/workspaces/current/tool-provider/")
|
||||
|
||||
if 'icon' in provider:
|
||||
if provider['type'] == UserToolProvider.ProviderType.BUILTIN.value:
|
||||
provider['icon'] = url_prefix + 'builtin/' + provider['name'] + '/icon'
|
||||
elif provider['type'] == UserToolProvider.ProviderType.MODEL.value:
|
||||
provider['icon'] = url_prefix + 'model/' + provider['name'] + '/icon'
|
||||
elif provider['type'] == UserToolProvider.ProviderType.API.value:
|
||||
try:
|
||||
provider['icon'] = json.loads(provider['icon'])
|
||||
except:
|
||||
provider['icon'] = {
|
||||
"background": "#252525",
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tool_provider_tools(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
):
|
||||
) -> list[UserTool]:
|
||||
"""
|
||||
list builtin tool provider tools
|
||||
"""
|
||||
|
|
@ -95,41 +73,11 @@ class ToolManageService:
|
|||
|
||||
result = []
|
||||
for tool in tools:
|
||||
# fork tool runtime
|
||||
tool = tool.fork_tool_runtime(meta={
|
||||
'credentials': credentials,
|
||||
'tenant_id': tenant_id,
|
||||
})
|
||||
result.append(ToolTransformService.tool_to_user_tool(
|
||||
tool=tool, credentials=credentials, tenant_id=tenant_id
|
||||
))
|
||||
|
||||
# get tool parameters
|
||||
parameters = tool.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = tool.get_runtime_parameters()
|
||||
# override parameters
|
||||
current_parameters = parameters.copy()
|
||||
for runtime_parameter in runtime_parameters:
|
||||
found = False
|
||||
for index, parameter in enumerate(current_parameters):
|
||||
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
|
||||
current_parameters[index] = runtime_parameter
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
user_tool = UserTool(
|
||||
author=tool.identity.author,
|
||||
name=tool.identity.name,
|
||||
label=tool.identity.label,
|
||||
description=tool.description.human,
|
||||
parameters=current_parameters
|
||||
)
|
||||
result.append(user_tool)
|
||||
|
||||
return json.loads(
|
||||
serialize_base_model_array(result)
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(
|
||||
|
|
@ -141,9 +89,9 @@ class ToolManageService:
|
|||
:return: the list of tool providers
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name)
|
||||
return json.loads(serialize_base_model_array([
|
||||
return jsonable_encoder([
|
||||
v for _, v in (provider.credentials_schema or {}).items()
|
||||
]))
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]:
|
||||
|
|
@ -204,14 +152,12 @@ class ToolManageService:
|
|||
),
|
||||
]
|
||||
|
||||
return json.loads(serialize_base_model_dict(
|
||||
{
|
||||
'schema_type': schema_type,
|
||||
'parameters_schema': tool_bundles,
|
||||
'credentials_schema': credentials_schema,
|
||||
'warning': warnings
|
||||
}
|
||||
))
|
||||
return jsonable_encoder({
|
||||
'schema_type': schema_type,
|
||||
'parameters_schema': tool_bundles,
|
||||
'credentials_schema': credentials_schema,
|
||||
'warning': warnings
|
||||
})
|
||||
except Exception as e:
|
||||
raise ValueError(f'invalid schema: {str(e)}')
|
||||
|
||||
|
|
@ -265,7 +211,7 @@ class ToolManageService:
|
|||
schema=schema,
|
||||
description=extra_info.get('description', ''),
|
||||
schema_type_str=schema_type,
|
||||
tools_str=serialize_base_model_array(tool_bundles),
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str={},
|
||||
privacy_policy=privacy_policy
|
||||
)
|
||||
|
|
@ -322,7 +268,7 @@ class ToolManageService:
|
|||
@staticmethod
|
||||
def list_api_tool_provider_tools(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
):
|
||||
) -> list[UserTool]:
|
||||
"""
|
||||
list api tool provider tools
|
||||
"""
|
||||
|
|
@ -334,23 +280,9 @@ class ToolManageService:
|
|||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider}')
|
||||
|
||||
return json.loads(
|
||||
serialize_base_model_array([
|
||||
UserTool(
|
||||
author=tool_bundle.author,
|
||||
name=tool_bundle.operation_id,
|
||||
label=I18nObject(
|
||||
en_US=tool_bundle.operation_id,
|
||||
zh_Hans=tool_bundle.operation_id
|
||||
),
|
||||
description=I18nObject(
|
||||
en_US=tool_bundle.summary or '',
|
||||
zh_Hans=tool_bundle.summary or ''
|
||||
),
|
||||
parameters=tool_bundle.parameters
|
||||
) for tool_bundle in provider.tools
|
||||
])
|
||||
)
|
||||
return [
|
||||
ToolTransformService.tool_to_user_tool(tool_bundle) for tool_bundle in provider.tools
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def update_builtin_tool_provider(
|
||||
|
|
@ -408,6 +340,27 @@ class ToolManageService:
|
|||
|
||||
return { 'result': 'success' }
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_credentials(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
):
|
||||
"""
|
||||
get builtin tool provider credentials
|
||||
"""
|
||||
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
).first()
|
||||
|
||||
if provider is None:
|
||||
return {}
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider.provider)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
credentials = tool_configuration.mask_tool_credentials(credentials)
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def update_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict,
|
||||
|
|
@ -439,7 +392,7 @@ class ToolManageService:
|
|||
provider.schema = schema
|
||||
provider.description = extra_info.get('description', '')
|
||||
provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
|
||||
provider.tools_str = serialize_base_model_array(tool_bundles)
|
||||
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
|
||||
provider.privacy_policy = privacy_policy
|
||||
|
||||
if 'auth_type' not in credentials:
|
||||
|
|
@ -531,7 +484,7 @@ class ToolManageService:
|
|||
@staticmethod
|
||||
def list_model_tool_provider_tools(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
):
|
||||
) -> list[UserTool]:
|
||||
"""
|
||||
list model tool provider tools
|
||||
"""
|
||||
|
|
@ -548,9 +501,7 @@ class ToolManageService:
|
|||
) for tool in tools
|
||||
]
|
||||
|
||||
return json.loads(
|
||||
serialize_base_model_array(result)
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@staticmethod
|
||||
def delete_api_tool_provider(
|
||||
|
|
@ -619,7 +570,7 @@ class ToolManageService:
|
|||
schema=schema,
|
||||
description='',
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI.value,
|
||||
tools_str=serialize_base_model_array(tool_bundles),
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str=json.dumps(credentials),
|
||||
)
|
||||
|
||||
|
|
@ -660,3 +611,87 @@ class ToolManageService:
|
|||
return { 'error': str(e) }
|
||||
|
||||
return { 'result': result or 'empty response' }
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tools(
|
||||
user_id: str, tenant_id: str
|
||||
) -> list[UserToolProvider]:
|
||||
"""
|
||||
list builtin tools
|
||||
"""
|
||||
# get all builtin providers
|
||||
provider_controllers = ToolManager.list_builtin_providers()
|
||||
|
||||
# get all user added providers
|
||||
db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id
|
||||
).all() or []
|
||||
|
||||
# find provider
|
||||
find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
|
||||
for provider_controller in provider_controllers:
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=find_provider(provider_controller.identity.name),
|
||||
decrypt_credentials=True
|
||||
)
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(user_builtin_provider)
|
||||
|
||||
tools = provider_controller.get_tools()
|
||||
for tool in tools:
|
||||
user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
credentials=user_builtin_provider.original_credentials,
|
||||
))
|
||||
|
||||
result.append(user_builtin_provider)
|
||||
|
||||
return BuiltinToolProviderSort.sort(result)
|
||||
|
||||
@staticmethod
|
||||
def list_api_tools(
|
||||
user_id: str, tenant_id: str
|
||||
) -> list[UserToolProvider]:
|
||||
"""
|
||||
list api tools
|
||||
"""
|
||||
# get all api providers
|
||||
db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id
|
||||
).all() or []
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
|
||||
for provider in db_providers:
|
||||
# convert provider controller to user provider
|
||||
provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller,
|
||||
db_provider=provider,
|
||||
decrypt_credentials=True
|
||||
)
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(user_provider)
|
||||
|
||||
tools = provider_controller.get_tools(
|
||||
user_id=user_id, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
for tool in tools:
|
||||
user_provider.tools.append(ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
credentials=user_provider.original_credentials,
|
||||
))
|
||||
|
||||
result.append(user_provider)
|
||||
|
||||
return result
|
||||
|
|
|
|||
267
api/services/tools_transform_service.py
Normal file
267
api/services/tools_transform_service.py
Normal file
|
|
@ -0,0 +1,267 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolParameter, ToolProviderCredentials
|
||||
from core.tools.entities.user_entities import UserTool, UserToolProvider
|
||||
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.provider.model_tool_provider import ModelToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.configuration import ToolConfigurationManager
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolTransformService:
|
||||
@staticmethod
|
||||
def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
|
||||
"""
|
||||
get tool provider icon url
|
||||
"""
|
||||
url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||
+ "/console/api/workspaces/current/tool-provider/")
|
||||
|
||||
if provider_type == UserToolProvider.ProviderType.BUILTIN.value:
|
||||
return url_prefix + 'builtin/' + provider_name + '/icon'
|
||||
elif provider_type == UserToolProvider.ProviderType.MODEL.value:
|
||||
return url_prefix + 'model/' + provider_name + '/icon'
|
||||
elif provider_type == UserToolProvider.ProviderType.API.value:
|
||||
try:
|
||||
return json.loads(icon)
|
||||
except:
|
||||
return {
|
||||
"background": "#252525",
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
|
||||
return ''
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(provider: Union[dict, UserToolProvider]):
|
||||
"""
|
||||
repack provider
|
||||
|
||||
:param provider: the provider dict
|
||||
"""
|
||||
if isinstance(provider, dict) and 'icon' in provider:
|
||||
provider['icon'] = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider['type'],
|
||||
provider_name=provider['name'],
|
||||
icon=provider['icon']
|
||||
)
|
||||
elif isinstance(provider, UserToolProvider):
|
||||
provider.icon = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider.type.value,
|
||||
provider_name=provider.name,
|
||||
icon=provider.icon
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def builtin_provider_to_user_provider(
|
||||
provider_controller: BuiltinToolProviderController,
|
||||
db_provider: Optional[BuiltinToolProvider],
|
||||
decrypt_credentials: bool = True
|
||||
) -> UserToolProvider:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
result = UserToolProvider(
|
||||
id=provider_controller.identity.name,
|
||||
author=provider_controller.identity.author,
|
||||
name=provider_controller.identity.name,
|
||||
description=I18nObject(
|
||||
en_US=provider_controller.identity.description.en_US,
|
||||
zh_Hans=provider_controller.identity.description.zh_Hans,
|
||||
),
|
||||
icon=provider_controller.identity.icon,
|
||||
label=I18nObject(
|
||||
en_US=provider_controller.identity.label.en_US,
|
||||
zh_Hans=provider_controller.identity.label.zh_Hans,
|
||||
),
|
||||
type=UserToolProvider.ProviderType.BUILTIN,
|
||||
masked_credentials={},
|
||||
is_team_authorization=False,
|
||||
tools=[]
|
||||
)
|
||||
|
||||
# get credentials schema
|
||||
schema = provider_controller.get_credentials_schema()
|
||||
for name, value in schema.items():
|
||||
result.masked_credentials[name] = \
|
||||
ToolProviderCredentials.CredentialsType.default(value.type)
|
||||
|
||||
# check if the provider need credentials
|
||||
if not provider_controller.need_credentials:
|
||||
result.is_team_authorization = True
|
||||
result.allow_delete = False
|
||||
elif db_provider:
|
||||
result.is_team_authorization = True
|
||||
|
||||
if decrypt_credentials:
|
||||
credentials = db_provider.credentials
|
||||
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
provider_controller=provider_controller
|
||||
)
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
|
||||
|
||||
result.masked_credentials = masked_credentials
|
||||
result.original_credentials = decrypted_credentials
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def api_provider_to_controller(
|
||||
db_provider: ApiToolProvider,
|
||||
) -> ApiBasedToolProviderController:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
# package tool provider controller
|
||||
controller = ApiBasedToolProviderController.from_db(
|
||||
db_provider=db_provider,
|
||||
auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else
|
||||
ApiProviderAuthType.NONE
|
||||
)
|
||||
|
||||
return controller
|
||||
|
||||
@staticmethod
|
||||
def api_provider_to_user_provider(
|
||||
provider_controller: ApiBasedToolProviderController,
|
||||
db_provider: ApiToolProvider,
|
||||
decrypt_credentials: bool = True
|
||||
) -> UserToolProvider:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
username = 'Anonymous'
|
||||
try:
|
||||
username = db_provider.user.name
|
||||
except Exception as e:
|
||||
logger.error(f'failed to get user name for api provider {db_provider.id}: {str(e)}')
|
||||
# add provider into providers
|
||||
credentials = db_provider.credentials
|
||||
result = UserToolProvider(
|
||||
id=db_provider.id,
|
||||
author=username,
|
||||
name=db_provider.name,
|
||||
description=I18nObject(
|
||||
en_US=db_provider.description,
|
||||
zh_Hans=db_provider.description,
|
||||
),
|
||||
icon=db_provider.icon,
|
||||
label=I18nObject(
|
||||
en_US=db_provider.name,
|
||||
zh_Hans=db_provider.name,
|
||||
),
|
||||
type=UserToolProvider.ProviderType.API,
|
||||
masked_credentials={},
|
||||
is_team_authorization=True,
|
||||
tools=[]
|
||||
)
|
||||
|
||||
if decrypt_credentials:
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
provider_controller=provider_controller
|
||||
)
|
||||
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
|
||||
|
||||
result.masked_credentials = masked_credentials
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def model_provider_to_user_provider(
|
||||
db_provider: ModelToolProviderController,
|
||||
) -> UserToolProvider:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
return UserToolProvider(
|
||||
id=db_provider.identity.name,
|
||||
author=db_provider.identity.author,
|
||||
name=db_provider.identity.name,
|
||||
description=I18nObject(
|
||||
en_US=db_provider.identity.description.en_US,
|
||||
zh_Hans=db_provider.identity.description.zh_Hans,
|
||||
),
|
||||
icon=db_provider.identity.icon,
|
||||
label=I18nObject(
|
||||
en_US=db_provider.identity.label.en_US,
|
||||
zh_Hans=db_provider.identity.label.zh_Hans,
|
||||
),
|
||||
type=UserToolProvider.ProviderType.MODEL,
|
||||
masked_credentials={},
|
||||
is_team_authorization=db_provider.is_active,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def tool_to_user_tool(
|
||||
tool: Union[ApiBasedToolBundle, Tool], credentials: dict = None, tenant_id: str = None
|
||||
) -> UserTool:
|
||||
"""
|
||||
convert tool to user tool
|
||||
"""
|
||||
if isinstance(tool, Tool):
|
||||
# fork tool runtime
|
||||
tool = tool.fork_tool_runtime(meta={
|
||||
'credentials': credentials,
|
||||
'tenant_id': tenant_id,
|
||||
})
|
||||
|
||||
# get tool parameters
|
||||
parameters = tool.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = tool.get_runtime_parameters() or []
|
||||
# override parameters
|
||||
current_parameters = parameters.copy()
|
||||
for runtime_parameter in runtime_parameters:
|
||||
found = False
|
||||
for index, parameter in enumerate(current_parameters):
|
||||
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
|
||||
current_parameters[index] = runtime_parameter
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
user_tool = UserTool(
|
||||
author=tool.identity.author,
|
||||
name=tool.identity.name,
|
||||
label=tool.identity.label,
|
||||
description=tool.description.human,
|
||||
parameters=current_parameters
|
||||
)
|
||||
|
||||
return user_tool
|
||||
|
||||
if isinstance(tool, ApiBasedToolBundle):
|
||||
return UserTool(
|
||||
author=tool.author,
|
||||
name=tool.operation_id,
|
||||
label=I18nObject(
|
||||
en_US=tool.operation_id,
|
||||
zh_Hans=tool.operation_id
|
||||
),
|
||||
description=I18nObject(
|
||||
en_US=tool.summary or '',
|
||||
zh_Hans=tool.summary or ''
|
||||
),
|
||||
parameters=tool.parameters
|
||||
)
|
||||
0
api/services/workflow/__init__.py
Normal file
0
api/services/workflow/__init__.py
Normal file
685
api/services/workflow/workflow_converter.py
Normal file
685
api/services/workflow/workflow_converter.py
Normal file
|
|
@ -0,0 +1,685 @@
|
|||
import json
|
||||
from typing import Optional
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
DatasetRetrieveConfigEntity,
|
||||
EasyUIBasedAppConfig,
|
||||
ExternalDataVariableEntity,
|
||||
FileExtraConfig,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
VariableEntity,
|
||||
)
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from core.helper import encrypter
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.simple_prompt_transform import SimplePromptTransform
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
|
||||
|
||||
class WorkflowConverter:
|
||||
"""
|
||||
App Convert to Workflow Mode
|
||||
"""
|
||||
|
||||
def convert_to_workflow(self, app_model: App,
|
||||
account: Account,
|
||||
name: str,
|
||||
icon: str,
|
||||
icon_background: str) -> App:
|
||||
"""
|
||||
Convert app to workflow
|
||||
|
||||
- basic mode of chatbot app
|
||||
|
||||
- expert mode of chatbot app
|
||||
|
||||
- completion app
|
||||
|
||||
:param app_model: App instance
|
||||
:param account: Account
|
||||
:param name: new app name
|
||||
:param icon: new app icon
|
||||
:param icon_background: new app icon background
|
||||
:return: new App instance
|
||||
"""
|
||||
# convert app model config
|
||||
workflow = self.convert_app_model_config_to_workflow(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model.app_model_config,
|
||||
account_id=account.id
|
||||
)
|
||||
|
||||
# create new app
|
||||
new_app = App()
|
||||
new_app.tenant_id = app_model.tenant_id
|
||||
new_app.name = name if name else app_model.name + '(workflow)'
|
||||
new_app.mode = AppMode.ADVANCED_CHAT.value \
|
||||
if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
|
||||
new_app.icon = icon if icon else app_model.icon
|
||||
new_app.icon_background = icon_background if icon_background else app_model.icon_background
|
||||
new_app.enable_site = app_model.enable_site
|
||||
new_app.enable_api = app_model.enable_api
|
||||
new_app.api_rpm = app_model.api_rpm
|
||||
new_app.api_rph = app_model.api_rph
|
||||
new_app.is_demo = False
|
||||
new_app.is_public = app_model.is_public
|
||||
db.session.add(new_app)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
|
||||
workflow.app_id = new_app.id
|
||||
db.session.commit()
|
||||
|
||||
app_was_created.send(new_app, account=account)
|
||||
|
||||
return new_app
|
||||
|
||||
def convert_app_model_config_to_workflow(self, app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
account_id: str) -> Workflow:
|
||||
"""
|
||||
Convert app model config to workflow mode
|
||||
:param app_model: App instance
|
||||
:param app_model_config: AppModelConfig instance
|
||||
:param account_id: Account ID
|
||||
:return:
|
||||
"""
|
||||
# get new app mode
|
||||
new_app_mode = self._get_new_app_mode(app_model)
|
||||
|
||||
# convert app model config
|
||||
app_config = self._convert_to_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
|
||||
# init workflow graph
|
||||
graph = {
|
||||
"nodes": [],
|
||||
"edges": []
|
||||
}
|
||||
|
||||
# Convert list:
|
||||
# - variables -> start
|
||||
# - model_config -> llm
|
||||
# - prompt_template -> llm
|
||||
# - file_upload -> llm
|
||||
# - external_data_variables -> http-request
|
||||
# - dataset -> knowledge-retrieval
|
||||
# - show_retrieve_source -> knowledge-retrieval
|
||||
|
||||
# convert to start node
|
||||
start_node = self._convert_to_start_node(
|
||||
variables=app_config.variables
|
||||
)
|
||||
|
||||
graph['nodes'].append(start_node)
|
||||
|
||||
# convert to http request node
|
||||
external_data_variable_node_mapping = {}
|
||||
if app_config.external_data_variables:
|
||||
http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node(
|
||||
app_model=app_model,
|
||||
variables=app_config.variables,
|
||||
external_data_variables=app_config.external_data_variables
|
||||
)
|
||||
|
||||
for http_request_node in http_request_nodes:
|
||||
graph = self._append_node(graph, http_request_node)
|
||||
|
||||
# convert to knowledge retrieval node
|
||||
if app_config.dataset:
|
||||
knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node(
|
||||
new_app_mode=new_app_mode,
|
||||
dataset_config=app_config.dataset,
|
||||
model_config=app_config.model
|
||||
)
|
||||
|
||||
if knowledge_retrieval_node:
|
||||
graph = self._append_node(graph, knowledge_retrieval_node)
|
||||
|
||||
# convert to llm node
|
||||
llm_node = self._convert_to_llm_node(
|
||||
original_app_mode=AppMode.value_of(app_model.mode),
|
||||
new_app_mode=new_app_mode,
|
||||
graph=graph,
|
||||
model_config=app_config.model,
|
||||
prompt_template=app_config.prompt_template,
|
||||
file_upload=app_config.additional_features.file_upload,
|
||||
external_data_variable_node_mapping=external_data_variable_node_mapping
|
||||
)
|
||||
|
||||
graph = self._append_node(graph, llm_node)
|
||||
|
||||
if new_app_mode == AppMode.WORKFLOW:
|
||||
# convert to end node by app mode
|
||||
end_node = self._convert_to_end_node()
|
||||
graph = self._append_node(graph, end_node)
|
||||
else:
|
||||
answer_node = self._convert_to_answer_node()
|
||||
graph = self._append_node(graph, answer_node)
|
||||
|
||||
app_model_config_dict = app_config.app_model_config_dict
|
||||
|
||||
# features
|
||||
if new_app_mode == AppMode.ADVANCED_CHAT:
|
||||
features = {
|
||||
"opening_statement": app_model_config_dict.get("opening_statement"),
|
||||
"suggested_questions": app_model_config_dict.get("suggested_questions"),
|
||||
"suggested_questions_after_answer": app_model_config_dict.get("suggested_questions_after_answer"),
|
||||
"speech_to_text": app_model_config_dict.get("speech_to_text"),
|
||||
"text_to_speech": app_model_config_dict.get("text_to_speech"),
|
||||
"file_upload": app_model_config_dict.get("file_upload"),
|
||||
"sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
|
||||
"retriever_resource": app_model_config_dict.get("retriever_resource"),
|
||||
}
|
||||
else:
|
||||
features = {
|
||||
"text_to_speech": app_model_config_dict.get("text_to_speech"),
|
||||
"file_upload": app_model_config_dict.get("file_upload"),
|
||||
"sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
|
||||
}
|
||||
|
||||
# create workflow record
|
||||
workflow = Workflow(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type=WorkflowType.from_app_mode(new_app_mode).value,
|
||||
version='draft',
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps(features),
|
||||
created_by=account_id
|
||||
)
|
||||
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
|
||||
return workflow
|
||||
|
||||
def _convert_to_app_config(self, app_model: App,
|
||||
app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
app_model.mode = AppMode.AGENT_CHAT.value
|
||||
app_config = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
elif app_mode == AppMode.CHAT:
|
||||
app_config = ChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
elif app_mode == AppMode.COMPLETION:
|
||||
app_config = CompletionAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
|
||||
return app_config
|
||||
|
||||
def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict:
|
||||
"""
|
||||
Convert to Start Node
|
||||
:param variables: list of variables
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
"id": "start",
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": "START",
|
||||
"type": NodeType.START.value,
|
||||
"variables": [jsonable_encoder(v) for v in variables]
|
||||
}
|
||||
}
|
||||
|
||||
def _convert_to_http_request_node(self, app_model: App,
|
||||
variables: list[VariableEntity],
|
||||
external_data_variables: list[ExternalDataVariableEntity]) \
|
||||
-> tuple[list[dict], dict[str, str]]:
|
||||
"""
|
||||
Convert API Based Extension to HTTP Request Node
|
||||
:param app_model: App instance
|
||||
:param variables: list of variables
|
||||
:param external_data_variables: list of external data variables
|
||||
:return:
|
||||
"""
|
||||
index = 1
|
||||
nodes = []
|
||||
external_data_variable_node_mapping = {}
|
||||
tenant_id = app_model.tenant_id
|
||||
for external_data_variable in external_data_variables:
|
||||
tool_type = external_data_variable.type
|
||||
if tool_type != "api":
|
||||
continue
|
||||
|
||||
tool_variable = external_data_variable.variable
|
||||
tool_config = external_data_variable.config
|
||||
|
||||
# get params from config
|
||||
api_based_extension_id = tool_config.get("api_based_extension_id")
|
||||
|
||||
# get api_based_extension
|
||||
api_based_extension = self._get_api_based_extension(
|
||||
tenant_id=tenant_id,
|
||||
api_based_extension_id=api_based_extension_id
|
||||
)
|
||||
|
||||
if not api_based_extension:
|
||||
raise ValueError("[External data tool] API query failed, variable: {}, "
|
||||
"error: api_based_extension_id is invalid"
|
||||
.format(tool_variable))
|
||||
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=tenant_id,
|
||||
token=api_based_extension.api_key
|
||||
)
|
||||
|
||||
inputs = {}
|
||||
for v in variables:
|
||||
inputs[v.variable] = '{{#start.' + v.variable + '#}}'
|
||||
|
||||
request_body = {
|
||||
'point': APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value,
|
||||
'params': {
|
||||
'app_id': app_model.id,
|
||||
'tool_variable': tool_variable,
|
||||
'inputs': inputs,
|
||||
'query': '{{#sys.query#}}' if app_model.mode == AppMode.CHAT.value else ''
|
||||
}
|
||||
}
|
||||
|
||||
request_body_json = json.dumps(request_body)
|
||||
request_body_json = request_body_json.replace('\{\{', '{{').replace('\}\}', '}}')
|
||||
|
||||
http_request_node = {
|
||||
"id": f"http_request_{index}",
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": f"HTTP REQUEST {api_based_extension.name}",
|
||||
"type": NodeType.HTTP_REQUEST.value,
|
||||
"method": "post",
|
||||
"url": api_based_extension.api_endpoint,
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "bearer",
|
||||
"api_key": api_key
|
||||
}
|
||||
},
|
||||
"headers": "",
|
||||
"params": "",
|
||||
"body": {
|
||||
"type": "json",
|
||||
"data": request_body_json
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nodes.append(http_request_node)
|
||||
|
||||
# append code node for response body parsing
|
||||
code_node = {
|
||||
"id": f"code_{index}",
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": f"Parse {api_based_extension.name} Response",
|
||||
"type": NodeType.CODE.value,
|
||||
"variables": [{
|
||||
"variable": "response_json",
|
||||
"value_selector": [http_request_node['id'], "body"]
|
||||
}],
|
||||
"code_language": "python3",
|
||||
"code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads("
|
||||
"response_json)\n return {\n \"result\": response_body[\"result\"]\n }",
|
||||
"outputs": {
|
||||
"result": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nodes.append(code_node)
|
||||
|
||||
external_data_variable_node_mapping[external_data_variable.variable] = code_node['id']
|
||||
index += 1
|
||||
|
||||
return nodes, external_data_variable_node_mapping
|
||||
|
||||
def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode,
|
||||
dataset_config: DatasetEntity,
|
||||
model_config: ModelConfigEntity) \
|
||||
-> Optional[dict]:
|
||||
"""
|
||||
Convert datasets to Knowledge Retrieval Node
|
||||
:param new_app_mode: new app mode
|
||||
:param dataset_config: dataset
|
||||
:param model_config: model config
|
||||
:return:
|
||||
"""
|
||||
retrieve_config = dataset_config.retrieve_config
|
||||
if new_app_mode == AppMode.ADVANCED_CHAT:
|
||||
query_variable_selector = ["sys", "query"]
|
||||
elif retrieve_config.query_variable:
|
||||
# fetch query variable
|
||||
query_variable_selector = ["start", retrieve_config.query_variable]
|
||||
else:
|
||||
return None
|
||||
|
||||
return {
|
||||
"id": "knowledge_retrieval",
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": "KNOWLEDGE RETRIEVAL",
|
||||
"type": NodeType.KNOWLEDGE_RETRIEVAL.value,
|
||||
"query_variable_selector": query_variable_selector,
|
||||
"dataset_ids": dataset_config.dataset_ids,
|
||||
"retrieval_mode": retrieve_config.retrieve_strategy.value,
|
||||
"single_retrieval_config": {
|
||||
"model": {
|
||||
"provider": model_config.provider,
|
||||
"name": model_config.model,
|
||||
"mode": model_config.mode,
|
||||
"completion_params": {
|
||||
**model_config.parameters,
|
||||
"stop": model_config.stop,
|
||||
}
|
||||
}
|
||||
}
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
|
||||
else None,
|
||||
"multiple_retrieval_config": {
|
||||
"top_k": retrieve_config.top_k,
|
||||
"score_threshold": retrieve_config.score_threshold,
|
||||
"reranking_model": retrieve_config.reranking_model
|
||||
}
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE
|
||||
else None,
|
||||
}
|
||||
}
|
||||
|
||||
def _convert_to_llm_node(self, original_app_mode: AppMode,
|
||||
new_app_mode: AppMode,
|
||||
graph: dict,
|
||||
model_config: ModelConfigEntity,
|
||||
prompt_template: PromptTemplateEntity,
|
||||
file_upload: Optional[FileExtraConfig] = None,
|
||||
external_data_variable_node_mapping: dict[str, str] = None) -> dict:
|
||||
"""
|
||||
Convert to LLM Node
|
||||
:param original_app_mode: original app mode
|
||||
:param new_app_mode: new app mode
|
||||
:param graph: graph
|
||||
:param model_config: model config
|
||||
:param prompt_template: prompt template
|
||||
:param file_upload: file upload config (optional)
|
||||
:param external_data_variable_node_mapping: external data variable node mapping
|
||||
"""
|
||||
# fetch start and knowledge retrieval node
|
||||
start_node = next(filter(lambda n: n['data']['type'] == NodeType.START.value, graph['nodes']))
|
||||
knowledge_retrieval_node = next(filter(
|
||||
lambda n: n['data']['type'] == NodeType.KNOWLEDGE_RETRIEVAL.value,
|
||||
graph['nodes']
|
||||
), None)
|
||||
|
||||
role_prefix = None
|
||||
|
||||
# Chat Model
|
||||
if model_config.mode == LLMMode.CHAT.value:
|
||||
if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
# get prompt template
|
||||
prompt_transform = SimplePromptTransform()
|
||||
prompt_template_config = prompt_transform.get_prompt_template(
|
||||
app_mode=original_app_mode,
|
||||
provider=model_config.provider,
|
||||
model=model_config.model,
|
||||
pre_prompt=prompt_template.simple_prompt_template,
|
||||
has_context=knowledge_retrieval_node is not None,
|
||||
query_in_prompt=False
|
||||
)
|
||||
|
||||
template = prompt_template_config['prompt_template'].template
|
||||
if not template:
|
||||
prompts = []
|
||||
else:
|
||||
template = self._replace_template_variables(
|
||||
template,
|
||||
start_node['data']['variables'],
|
||||
external_data_variable_node_mapping
|
||||
)
|
||||
|
||||
prompts = [
|
||||
{
|
||||
"role": 'user',
|
||||
"text": template
|
||||
}
|
||||
]
|
||||
else:
|
||||
advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template
|
||||
|
||||
prompts = []
|
||||
for m in advanced_chat_prompt_template.messages:
|
||||
if advanced_chat_prompt_template:
|
||||
text = m.text
|
||||
text = self._replace_template_variables(
|
||||
text,
|
||||
start_node['data']['variables'],
|
||||
external_data_variable_node_mapping
|
||||
)
|
||||
|
||||
prompts.append({
|
||||
"role": m.role.value,
|
||||
"text": text
|
||||
})
|
||||
# Completion Model
|
||||
else:
|
||||
if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
# get prompt template
|
||||
prompt_transform = SimplePromptTransform()
|
||||
prompt_template_config = prompt_transform.get_prompt_template(
|
||||
app_mode=original_app_mode,
|
||||
provider=model_config.provider,
|
||||
model=model_config.model,
|
||||
pre_prompt=prompt_template.simple_prompt_template,
|
||||
has_context=knowledge_retrieval_node is not None,
|
||||
query_in_prompt=False
|
||||
)
|
||||
|
||||
template = prompt_template_config['prompt_template'].template
|
||||
template = self._replace_template_variables(
|
||||
template,
|
||||
start_node['data']['variables'],
|
||||
external_data_variable_node_mapping
|
||||
)
|
||||
|
||||
prompts = {
|
||||
"text": template
|
||||
}
|
||||
|
||||
prompt_rules = prompt_template_config['prompt_rules']
|
||||
role_prefix = {
|
||||
"user": prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
|
||||
"assistant": prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
|
||||
}
|
||||
else:
|
||||
advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template
|
||||
if advanced_completion_prompt_template:
|
||||
text = advanced_completion_prompt_template.prompt
|
||||
text = self._replace_template_variables(
|
||||
text,
|
||||
start_node['data']['variables'],
|
||||
external_data_variable_node_mapping
|
||||
)
|
||||
else:
|
||||
text = ""
|
||||
|
||||
text = text.replace('{{#query#}}', '{{#sys.query#}}')
|
||||
|
||||
prompts = {
|
||||
"text": text,
|
||||
}
|
||||
|
||||
if advanced_completion_prompt_template.role_prefix:
|
||||
role_prefix = {
|
||||
"user": advanced_completion_prompt_template.role_prefix.user,
|
||||
"assistant": advanced_completion_prompt_template.role_prefix.assistant
|
||||
}
|
||||
|
||||
memory = None
|
||||
if new_app_mode == AppMode.ADVANCED_CHAT:
|
||||
memory = {
|
||||
"role_prefix": role_prefix,
|
||||
"window": {
|
||||
"enabled": False
|
||||
}
|
||||
}
|
||||
|
||||
completion_params = model_config.parameters
|
||||
completion_params.update({"stop": model_config.stop})
|
||||
return {
|
||||
"id": "llm",
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": "LLM",
|
||||
"type": NodeType.LLM.value,
|
||||
"model": {
|
||||
"provider": model_config.provider,
|
||||
"name": model_config.model,
|
||||
"mode": model_config.mode,
|
||||
"completion_params": completion_params
|
||||
},
|
||||
"prompt_template": prompts,
|
||||
"memory": memory,
|
||||
"context": {
|
||||
"enabled": knowledge_retrieval_node is not None,
|
||||
"variable_selector": ["knowledge_retrieval", "result"]
|
||||
if knowledge_retrieval_node is not None else None
|
||||
},
|
||||
"vision": {
|
||||
"enabled": file_upload is not None,
|
||||
"variable_selector": ["sys", "files"] if file_upload is not None else None,
|
||||
"configs": {
|
||||
"detail": file_upload.image_config['detail']
|
||||
} if file_upload is not None else None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def _replace_template_variables(self, template: str,
|
||||
variables: list[dict],
|
||||
external_data_variable_node_mapping: dict[str, str] = None) -> str:
|
||||
"""
|
||||
Replace Template Variables
|
||||
:param template: template
|
||||
:param variables: list of variables
|
||||
:return:
|
||||
"""
|
||||
for v in variables:
|
||||
template = template.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}')
|
||||
|
||||
if external_data_variable_node_mapping:
|
||||
for variable, code_node_id in external_data_variable_node_mapping.items():
|
||||
template = template.replace('{{' + variable + '}}',
|
||||
'{{#' + code_node_id + '.result#}}')
|
||||
|
||||
return template
|
||||
|
||||
def _convert_to_end_node(self) -> dict:
|
||||
"""
|
||||
Convert to End Node
|
||||
:return:
|
||||
"""
|
||||
# for original completion app
|
||||
return {
|
||||
"id": "end",
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": "END",
|
||||
"type": NodeType.END.value,
|
||||
"outputs": [{
|
||||
"variable": "result",
|
||||
"value_selector": ["llm", "text"]
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
def _convert_to_answer_node(self) -> dict:
|
||||
"""
|
||||
Convert to Answer Node
|
||||
:return:
|
||||
"""
|
||||
# for original chat app
|
||||
return {
|
||||
"id": "answer",
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": "ANSWER",
|
||||
"type": NodeType.ANSWER.value,
|
||||
"answer": "{{#llm.text#}}"
|
||||
}
|
||||
}
|
||||
|
||||
def _create_edge(self, source: str, target: str) -> dict:
|
||||
"""
|
||||
Create Edge
|
||||
:param source: source node id
|
||||
:param target: target node id
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
"id": f"{source}-{target}",
|
||||
"source": source,
|
||||
"target": target
|
||||
}
|
||||
|
||||
def _append_node(self, graph: dict, node: dict) -> dict:
|
||||
"""
|
||||
Append Node to Graph
|
||||
|
||||
:param graph: Graph, include: nodes, edges
|
||||
:param node: Node to append
|
||||
:return:
|
||||
"""
|
||||
previous_node = graph['nodes'][-1]
|
||||
graph['nodes'].append(node)
|
||||
graph['edges'].append(self._create_edge(previous_node['id'], node['id']))
|
||||
return graph
|
||||
|
||||
def _get_new_app_mode(self, app_model: App) -> AppMode:
|
||||
"""
|
||||
Get new app mode
|
||||
:param app_model: App instance
|
||||
:return: AppMode
|
||||
"""
|
||||
if app_model.mode == AppMode.COMPLETION.value:
|
||||
return AppMode.WORKFLOW
|
||||
else:
|
||||
return AppMode.ADVANCED_CHAT
|
||||
|
||||
def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
|
||||
"""
|
||||
Get API Based Extension
|
||||
:param tenant_id: tenant id
|
||||
:param api_based_extension_id: api based extension id
|
||||
:return:
|
||||
"""
|
||||
return db.session.query(APIBasedExtension).filter(
|
||||
APIBasedExtension.tenant_id == tenant_id,
|
||||
APIBasedExtension.id == api_based_extension_id
|
||||
).first()
|
||||
62
api/services/workflow_app_service.py
Normal file
62
api/services/workflow_app_service.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
from flask_sqlalchemy.pagination import Pagination
|
||||
from sqlalchemy import and_, or_
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import CreatedByRole
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import WorkflowAppLog, WorkflowRun, WorkflowRunStatus
|
||||
|
||||
|
||||
class WorkflowAppService:
|
||||
|
||||
def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Pagination:
|
||||
"""
|
||||
Get paginate workflow app logs
|
||||
:param app: app model
|
||||
:param args: request args
|
||||
:return:
|
||||
"""
|
||||
query = (
|
||||
db.select(WorkflowAppLog)
|
||||
.where(
|
||||
WorkflowAppLog.tenant_id == app_model.tenant_id,
|
||||
WorkflowAppLog.app_id == app_model.id
|
||||
)
|
||||
)
|
||||
|
||||
status = WorkflowRunStatus.value_of(args.get('status')) if args.get('status') else None
|
||||
if args['keyword'] or status:
|
||||
query = query.join(
|
||||
WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id
|
||||
)
|
||||
|
||||
if args['keyword']:
|
||||
keyword_val = f"%{args['keyword'][:30]}%"
|
||||
keyword_conditions = [
|
||||
WorkflowRun.inputs.ilike(keyword_val),
|
||||
WorkflowRun.outputs.ilike(keyword_val),
|
||||
# filter keyword by end user session id if created by end user role
|
||||
and_(WorkflowRun.created_by_role == 'end_user', EndUser.session_id.ilike(keyword_val))
|
||||
]
|
||||
|
||||
query = query.outerjoin(
|
||||
EndUser,
|
||||
and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value)
|
||||
).filter(or_(*keyword_conditions))
|
||||
|
||||
if status:
|
||||
# join with workflow_run and filter by status
|
||||
query = query.filter(
|
||||
WorkflowRun.status == status.value
|
||||
)
|
||||
|
||||
query = query.order_by(WorkflowAppLog.created_at.desc())
|
||||
|
||||
pagination = db.paginate(
|
||||
query,
|
||||
page=args['page'],
|
||||
per_page=args['limit'],
|
||||
error_out=False
|
||||
)
|
||||
|
||||
return pagination
|
||||
128
api/services/workflow_run_service.py
Normal file
128
api/services/workflow_run_service.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.model import App
|
||||
from models.workflow import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowRun,
|
||||
WorkflowRunTriggeredFrom,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowRunService:
|
||||
def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
|
||||
"""
|
||||
Get advanced chat app workflow run list
|
||||
Only return triggered_from == advanced_chat
|
||||
|
||||
:param app_model: app model
|
||||
:param args: request args
|
||||
"""
|
||||
class WorkflowWithMessage:
|
||||
message_id: str
|
||||
conversation_id: str
|
||||
|
||||
def __init__(self, workflow_run: WorkflowRun):
|
||||
self._workflow_run = workflow_run
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self._workflow_run, item)
|
||||
|
||||
pagination = self.get_paginate_workflow_runs(app_model, args)
|
||||
|
||||
with_message_workflow_runs = []
|
||||
for workflow_run in pagination.data:
|
||||
message = workflow_run.message
|
||||
with_message_workflow_run = WorkflowWithMessage(
|
||||
workflow_run=workflow_run
|
||||
)
|
||||
if message:
|
||||
with_message_workflow_run.message_id = message.id
|
||||
with_message_workflow_run.conversation_id = message.conversation_id
|
||||
|
||||
with_message_workflow_runs.append(with_message_workflow_run)
|
||||
|
||||
pagination.data = with_message_workflow_runs
|
||||
return pagination
|
||||
|
||||
def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
|
||||
"""
|
||||
Get debug workflow run list
|
||||
Only return triggered_from == debugging
|
||||
|
||||
:param app_model: app model
|
||||
:param args: request args
|
||||
"""
|
||||
limit = int(args.get('limit', 20))
|
||||
|
||||
base_query = db.session.query(WorkflowRun).filter(
|
||||
WorkflowRun.tenant_id == app_model.tenant_id,
|
||||
WorkflowRun.app_id == app_model.id,
|
||||
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value
|
||||
)
|
||||
|
||||
if args.get('last_id'):
|
||||
last_workflow_run = base_query.filter(
|
||||
WorkflowRun.id == args.get('last_id'),
|
||||
).first()
|
||||
|
||||
if not last_workflow_run:
|
||||
raise ValueError('Last workflow run not exists')
|
||||
|
||||
workflow_runs = base_query.filter(
|
||||
WorkflowRun.created_at < last_workflow_run.created_at,
|
||||
WorkflowRun.id != last_workflow_run.id
|
||||
).order_by(WorkflowRun.created_at.desc()).limit(limit).all()
|
||||
else:
|
||||
workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
|
||||
|
||||
has_more = False
|
||||
if len(workflow_runs) == limit:
|
||||
current_page_first_workflow_run = workflow_runs[-1]
|
||||
rest_count = base_query.filter(
|
||||
WorkflowRun.created_at < current_page_first_workflow_run.created_at,
|
||||
WorkflowRun.id != current_page_first_workflow_run.id
|
||||
).count()
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
return InfiniteScrollPagination(
|
||||
data=workflow_runs,
|
||||
limit=limit,
|
||||
has_more=has_more
|
||||
)
|
||||
|
||||
def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun:
|
||||
"""
|
||||
Get workflow run detail
|
||||
|
||||
:param app_model: app model
|
||||
:param run_id: workflow run id
|
||||
"""
|
||||
workflow_run = db.session.query(WorkflowRun).filter(
|
||||
WorkflowRun.tenant_id == app_model.tenant_id,
|
||||
WorkflowRun.app_id == app_model.id,
|
||||
WorkflowRun.id == run_id,
|
||||
).first()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[WorkflowNodeExecution]:
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
"""
|
||||
workflow_run = self.get_workflow_run(app_model, run_id)
|
||||
|
||||
if not workflow_run:
|
||||
return []
|
||||
|
||||
node_executions = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.tenant_id == app_model.tenant_id,
|
||||
WorkflowNodeExecution.app_id == app_model.id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.workflow_run_id == run_id,
|
||||
).order_by(WorkflowNodeExecution.index.desc()).all()
|
||||
|
||||
return node_executions
|
||||
302
api/services/workflow_service.py
Normal file
302
api/services/workflow_service.py
Normal file
|
|
@ -0,0 +1,302 @@
|
|||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from events.app_event import app_published_workflow_was_updated
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode
|
||||
from models.workflow import (
|
||||
CreatedByRole,
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowType,
|
||||
)
|
||||
from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
|
||||
class WorkflowService:
|
||||
"""
|
||||
Workflow Service
|
||||
"""
|
||||
|
||||
def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.version == 'draft'
|
||||
).first()
|
||||
|
||||
# return draft workflow
|
||||
return workflow
|
||||
|
||||
def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
|
||||
"""
|
||||
Get published workflow
|
||||
"""
|
||||
|
||||
if not app_model.workflow_id:
|
||||
return None
|
||||
|
||||
# fetch published workflow by workflow_id
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.id == app_model.workflow_id
|
||||
).first()
|
||||
|
||||
return workflow
|
||||
|
||||
def sync_draft_workflow(self, app_model: App,
|
||||
graph: dict,
|
||||
features: dict,
|
||||
account: Account) -> Workflow:
|
||||
"""
|
||||
Sync draft workflow
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow = self.get_draft_workflow(app_model=app_model)
|
||||
|
||||
# validate features structure
|
||||
self.validate_features_structure(
|
||||
app_model=app_model,
|
||||
features=features
|
||||
)
|
||||
|
||||
# create draft workflow if not found
|
||||
if not workflow:
|
||||
workflow = Workflow(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type=WorkflowType.from_app_mode(app_model.mode).value,
|
||||
version='draft',
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps(features),
|
||||
created_by=account.id
|
||||
)
|
||||
db.session.add(workflow)
|
||||
# update draft workflow if found
|
||||
else:
|
||||
workflow.graph = json.dumps(graph)
|
||||
workflow.features = json.dumps(features)
|
||||
workflow.updated_by = account.id
|
||||
workflow.updated_at = datetime.utcnow()
|
||||
|
||||
# commit db session changes
|
||||
db.session.commit()
|
||||
|
||||
# return draft workflow
|
||||
return workflow
|
||||
|
||||
def publish_workflow(self, app_model: App,
|
||||
account: Account,
|
||||
draft_workflow: Optional[Workflow] = None) -> Workflow:
|
||||
"""
|
||||
Publish workflow from draft
|
||||
|
||||
:param app_model: App instance
|
||||
:param account: Account instance
|
||||
:param draft_workflow: Workflow instance
|
||||
"""
|
||||
if not draft_workflow:
|
||||
# fetch draft workflow by app_model
|
||||
draft_workflow = self.get_draft_workflow(app_model=app_model)
|
||||
|
||||
if not draft_workflow:
|
||||
raise ValueError('No valid workflow found.')
|
||||
|
||||
# create new workflow
|
||||
workflow = Workflow(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type=draft_workflow.type,
|
||||
version=str(datetime.utcnow()),
|
||||
graph=draft_workflow.graph,
|
||||
features=draft_workflow.features,
|
||||
created_by=account.id
|
||||
)
|
||||
|
||||
# commit db session changes
|
||||
db.session.add(workflow)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
|
||||
app_model.workflow_id = workflow.id
|
||||
db.session.commit()
|
||||
|
||||
# trigger app workflow events
|
||||
app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
|
||||
|
||||
# return new workflow
|
||||
return workflow
|
||||
|
||||
def get_default_block_configs(self) -> list[dict]:
|
||||
"""
|
||||
Get default block configs
|
||||
"""
|
||||
# return default block config
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
return workflow_engine_manager.get_default_configs()
|
||||
|
||||
def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param node_type: node type
|
||||
:param filters: filter by node config parameters.
|
||||
:return:
|
||||
"""
|
||||
node_type = NodeType.value_of(node_type)
|
||||
|
||||
# return default block config
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
return workflow_engine_manager.get_default_config(node_type, filters)
|
||||
|
||||
def run_draft_workflow_node(self, app_model: App,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
account: Account) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
draft_workflow = self.get_draft_workflow(app_model=app_model)
|
||||
if not draft_workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
# run draft workflow node
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
node_instance, node_run_result = workflow_engine_manager.single_step_run_workflow_node(
|
||||
workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
user_id=account.id,
|
||||
)
|
||||
except WorkflowNodeRunFailedError as e:
|
||||
workflow_node_execution = WorkflowNodeExecution(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
workflow_id=draft_workflow.id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
|
||||
index=1,
|
||||
node_id=e.node_id,
|
||||
node_type=e.node_type.value,
|
||||
title=e.node_title,
|
||||
status=WorkflowNodeExecutionStatus.FAILED.value,
|
||||
error=e.error,
|
||||
elapsed_time=time.perf_counter() - start_at,
|
||||
created_by_role=CreatedByRole.ACCOUNT.value,
|
||||
created_by=account.id,
|
||||
created_at=datetime.utcnow(),
|
||||
finished_at=datetime.utcnow()
|
||||
)
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
# create workflow node execution
|
||||
workflow_node_execution = WorkflowNodeExecution(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
workflow_id=draft_workflow.id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
|
||||
index=1,
|
||||
node_id=node_id,
|
||||
node_type=node_instance.node_type.value,
|
||||
title=node_instance.node_data.title,
|
||||
inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None,
|
||||
process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None,
|
||||
outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None,
|
||||
execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata))
|
||||
if node_run_result.metadata else None),
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
elapsed_time=time.perf_counter() - start_at,
|
||||
created_by_role=CreatedByRole.ACCOUNT.value,
|
||||
created_by=account.id,
|
||||
created_at=datetime.utcnow(),
|
||||
finished_at=datetime.utcnow()
|
||||
)
|
||||
else:
|
||||
# create workflow node execution
|
||||
workflow_node_execution = WorkflowNodeExecution(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
workflow_id=draft_workflow.id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
|
||||
index=1,
|
||||
node_id=node_id,
|
||||
node_type=node_instance.node_type.value,
|
||||
title=node_instance.node_data.title,
|
||||
status=node_run_result.status.value,
|
||||
error=node_run_result.error,
|
||||
elapsed_time=time.perf_counter() - start_at,
|
||||
created_by_role=CreatedByRole.ACCOUNT.value,
|
||||
created_by=account.id,
|
||||
created_at=datetime.utcnow(),
|
||||
finished_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
|
||||
"""
|
||||
Basic mode of chatbot app(expert mode) to workflow
|
||||
Completion App to Workflow App
|
||||
|
||||
:param app_model: App instance
|
||||
:param account: Account instance
|
||||
:param args: dict
|
||||
:return:
|
||||
"""
|
||||
# chatbot convert to workflow mode
|
||||
workflow_converter = WorkflowConverter()
|
||||
|
||||
if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]:
|
||||
raise ValueError(f'Current App mode: {app_model.mode} is not supported convert to workflow.')
|
||||
|
||||
# convert to workflow
|
||||
new_app = workflow_converter.convert_to_workflow(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
name=args.get('name'),
|
||||
icon=args.get('icon'),
|
||||
icon_background=args.get('icon_background'),
|
||||
)
|
||||
|
||||
return new_app
|
||||
|
||||
def validate_features_structure(self, app_model: App, features: dict) -> dict:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
return AdvancedChatAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=features,
|
||||
only_structure_validate=True
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
return WorkflowAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=features,
|
||||
only_structure_validate=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid app mode: {app_model.mode}")
|
||||
Loading…
Add table
Add a link
Reference in a new issue