feat/enhance the multi-modal support (#8818)

This commit is contained in:
-LAN- 2024-10-21 10:43:49 +08:00 committed by GitHub
commit e61752bd3a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
267 changed files with 6263 additions and 3523 deletions

View file

@ -16,13 +16,14 @@ from core.app.entities.app_invoke_entities import (
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file.message_file_parser import MessageFileParser
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMUsage,
PromptMessage,
PromptMessageContent,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
@ -40,9 +41,9 @@ from core.tools.entities.tool_entities import (
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought
from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought, MessageFile
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
@ -66,23 +67,6 @@ class BaseAgentRunner(AppRunner):
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None,
) -> None:
"""
Agent runner
:param tenant_id: tenant id
:param application_generate_entity: application generate entity
:param conversation: conversation
:param app_config: app generate entity
:param model_config: model config
:param config: dataset config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param memory: memory
:param prompt_messages: prompt messages
:param variables_pool: variables pool
:param db_variables: db variables
:param model_instance: model instance
"""
self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity
self.conversation = conversation
@ -180,7 +164,7 @@ class BaseAgentRunner(AppRunner):
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
parameter_type = parameter.type.as_normal_type()
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
@ -265,7 +249,7 @@ class BaseAgentRunner(AppRunner):
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
parameter_type = parameter.type.as_normal_type()
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
@ -511,26 +495,24 @@ class BaseAgentRunner(AppRunner):
return result
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
message_file_parser = MessageFileParser(
tenant_id=self.tenant_id,
app_id=self.app_config.app_id,
)
files = message.message_files
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files:
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.transform_message_files(files, file_extra_config)
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
)
else:
file_objs = []
if not file_objs:
return UserPromptMessage(content=message.query)
else:
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file_obj in file_objs:
prompt_message_contents.append(file_obj.prompt_message_content)
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
return UserPromptMessage(content=prompt_message_contents)
else:

View file

@ -1,9 +1,11 @@
import json
from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import (
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContent,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
@ -32,9 +34,10 @@ class CotChatAgentRunner(CotAgentRunner):
Organize user query
"""
if self.files:
prompt_message_contents = [TextPromptMessageContent(data=query)]
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=query))
for file_obj in self.files:
prompt_message_contents.append(file_obj.prompt_message_content)
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:

View file

@ -7,10 +7,15 @@ from typing import Any, Optional, Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
@ -390,9 +395,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
Organize user query
"""
if self.files:
prompt_message_contents = [TextPromptMessageContent(data=query)]
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=query))
for file_obj in self.files:
prompt_message_contents.append(file_obj.prompt_message_content)
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:

View file

@ -53,12 +53,11 @@ class BasicVariablesConfigManager:
VariableEntity(
type=variable_type,
variable=variable.get("variable"),
description=variable.get("description"),
description=variable.get("description", ""),
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options"),
default=variable.get("default"),
options=variable.get("options", []),
)
)

View file

@ -1,11 +1,12 @@
from collections.abc import Sequence
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.file.file_obj import FileExtraConfig
from core.file import FileExtraConfig, FileTransferMethod, FileType
from core.model_runtime.entities.message_entities import PromptMessageRole
from models import AppMode
from models.model import AppMode
class ModelConfigEntity(BaseModel):
@ -69,7 +70,7 @@ class PromptTemplateEntity(BaseModel):
ADVANCED = "advanced"
@classmethod
def value_of(cls, value: str) -> "PromptType":
def value_of(cls, value: str):
"""
Get value of given mode.
@ -93,6 +94,8 @@ class VariableEntityType(str, Enum):
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external_data_tool"
FILE = "file"
FILE_LIST = "file-list"
class VariableEntity(BaseModel):
@ -102,13 +105,14 @@ class VariableEntity(BaseModel):
variable: str
label: str
description: Optional[str] = None
description: str = ""
type: VariableEntityType
required: bool = False
max_length: Optional[int] = None
options: Optional[list[str]] = None
default: Optional[str] = None
hint: Optional[str] = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
class ExternalDataVariableEntity(BaseModel):
@ -136,7 +140,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
MULTIPLE = "multiple"
@classmethod
def value_of(cls, value: str) -> "RetrieveStrategy":
def value_of(cls, value: str):
"""
Get value of given mode.

View file

@ -1,12 +1,13 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.file.file_obj import FileExtraConfig
from core.file.models import FileExtraConfig
from models import FileUploadConfig
class FileUploadConfigManager:
@classmethod
def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]:
def convert(cls, config: Mapping[str, Any], is_vision: bool = True):
"""
Convert model config to model config
@ -15,19 +16,18 @@ class FileUploadConfigManager:
"""
file_upload_dict = config.get("file_upload")
if file_upload_dict:
if file_upload_dict.get("image"):
if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
image_config = {
"number_limits": file_upload_dict["image"]["number_limits"],
"transfer_methods": file_upload_dict["image"]["transfer_methods"],
if file_upload_dict.get("enabled"):
data = {
"image_config": {
"number_limits": file_upload_dict["number_limits"],
"transfer_methods": file_upload_dict["allowed_file_upload_methods"],
}
}
if is_vision:
image_config["detail"] = file_upload_dict["image"]["detail"]
if is_vision:
data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")
return FileExtraConfig(image_config=image_config)
return None
return FileExtraConfig.model_validate(data)
@classmethod
def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]:
@ -39,29 +39,7 @@ class FileUploadConfigManager:
"""
if not config.get("file_upload"):
config["file_upload"] = {}
if not isinstance(config["file_upload"], dict):
raise ValueError("file_upload must be of dict type")
# check image config
if not config["file_upload"].get("image"):
config["file_upload"]["image"] = {"enabled": False}
if config["file_upload"]["image"]["enabled"]:
number_limits = config["file_upload"]["image"]["number_limits"]
if number_limits < 1 or number_limits > 6:
raise ValueError("number_limits must be in [1, 6]")
if is_vision:
detail = config["file_upload"]["image"]["detail"]
if detail not in {"high", "low"}:
raise ValueError("detail must be in ['high', 'low']")
transfer_methods = config["file_upload"]["image"]["transfer_methods"]
if not isinstance(transfer_methods, list):
raise ValueError("transfer_methods must be of list type")
for method in transfer_methods:
if method not in {"remote_url", "local_file"}:
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
else:
FileUploadConfig.model_validate(config["file_upload"])
return config, ["file_upload"]

View file

@ -17,6 +17,6 @@ class WorkflowVariablesConfigManager:
# variables
for variable in user_input_form:
variables.append(VariableEntity(**variable))
variables.append(VariableEntity.model_validate(variable))
return variables

View file

@ -21,11 +21,12 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models.enums import CreatedByRole
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
@ -96,10 +97,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# parse files
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
file_objs = []
@ -107,8 +114,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id, user_id)
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode
@ -120,7 +128,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
task_id=str(uuid.uuid4()),
app_config=app_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

View file

@ -1,31 +1,27 @@
import logging
import os
from collections.abc import Mapping
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
QueueStopEvent,
QueueTextChunkEvent,
)
from core.moderation.base import ModerationError
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.enums import UserFrom
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, WorkflowType
@ -44,12 +40,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation: Conversation,
message: Message,
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
"""
super().__init__(queue_manager)
self.application_generate_entity = application_generate_entity
@ -57,10 +47,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self.message = message
def run(self) -> None:
"""
Run application
:return:
"""
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
@ -81,7 +67,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_id = self.application_generate_entity.user_id
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
if self.application_generate_entity.single_iteration_run:
@ -201,15 +187,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
query: str,
message_id: str,
) -> bool:
"""
Handle input moderation
:param app_record: app record
:param app_generate_entity: application generate entity
:param inputs: inputs
:param query: query
:param message_id: message id
:return:
"""
try:
# process sensitive_word_avoidance
_, inputs, query = self.moderation_for_inputs(
@ -229,14 +206,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def handle_annotation_reply(
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
) -> bool:
"""
Handle annotation reply
:param app_record: app record
:param message: message
:param query: query
:param app_generate_entity: application generate entity
"""
# annotation reply
annotation_reply = self.query_app_annotations_to_reply(
app_record=app_record,
message=message,
@ -258,8 +227,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
"""
Direct output
:param text: text
:return:
"""
self._publish_event(QueueTextChunkEvent(text=text))

View file

@ -1,7 +1,7 @@
import json
import logging
import time
from collections.abc import Generator
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@ -9,6 +9,7 @@ from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent,
@ -50,10 +51,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from events.message_event import message_was_created
from extensions.ext_database import db
from models import Conversation, EndUser, Message, MessageFile
from models.account import Account
from models.model import Conversation, EndUser, Message
from models.enums import CreatedByRole
from models.workflow import (
Workflow,
WorkflowNodeExecution,
@ -120,6 +123,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._wip_workflow_node_executions = {}
self._conversation_name_generate_thread = None
self._recorded_files: list[Mapping[str, Any]] = []
def process(self):
"""
@ -298,6 +302,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
@ -364,7 +372,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) if event.outputs else None,
outputs=event.outputs,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
@ -490,10 +498,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._conversation_name_generate_thread.join()
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
"""
Save message.
:return:
"""
self._refetch_message()
self._message.answer = self._task_state.answer
@ -501,6 +505,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message_files = [
MessageFile(
message_id=self._message.id,
type=file["type"],
transfer_method=file["transfer_method"],
url=file["remote_url"],
belongs_to="assistant",
upload_file_id=file["related_id"],
created_by_role=CreatedByRole.ACCOUNT
if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatedByRole.END_USER,
created_by=self._message.from_account_id or self._message.from_end_user_id or "",
)
for file in self._recorded_files
]
db.session.add_all(message_files)
if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
@ -540,7 +560,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
del extras["metadata"]["annotation_reply"]
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
)
def _handle_output_moderation_chunk(self, text: str) -> bool:

View file

@ -18,12 +18,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser
from factories import file_factory
from models import Account, App, EndUser
from models.enums import CreatedByRole
logger = logging.getLogger(__name__)
@ -50,7 +50,12 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
) -> dict: ...
def generate(
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
@ -98,12 +103,19 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True}
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
files = args.get("files") or []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
file_objs = []
@ -116,8 +128,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
)
# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id, user_id)
trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
# init application generate entity
application_generate_entity = AgentChatAppGenerateEntity(
@ -125,7 +136,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

View file

@ -1,35 +1,92 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
from core.app.app_config.entities import VariableEntityType
from core.file import File, FileExtraConfig
from factories import file_factory
if TYPE_CHECKING:
from core.app.app_config.entities import AppConfig, VariableEntity
from models.enums import CreatedByRole
class BaseAppGenerator:
def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]:
def _prepare_user_inputs(
self,
*,
user_inputs: Optional[Mapping[str, Any]],
app_config: "AppConfig",
user_id: str,
role: "CreatedByRole",
) -> Mapping[str, Any]:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
return filtered_inputs
user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
# Convert files in inputs to File
entity_dictionary = {item.variable: item for item in app_config.variables}
# Convert single file to File
files_inputs = {
k: file_factory.build_from_mapping(
mapping=v,
tenant_id=app_config.tenant_id,
user_id=user_id,
role=role,
config=FileExtraConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
),
)
for k, v in user_inputs.items()
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
}
# Convert list of files to File
file_list_inputs = {
k: file_factory.build_from_mappings(
mappings=v,
tenant_id=app_config.tenant_id,
user_id=user_id,
role=role,
config=FileExtraConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
),
)
for k, v in user_inputs.items()
if isinstance(v, list)
# Ensure skip List<File>
and all(isinstance(item, dict) for item in v)
and entity_dictionary[k].type == VariableEntityType.FILE_LIST
}
# Merge all inputs
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.variable)
if var.required and not user_input_value:
raise ValueError(f"{var.variable} is required in input form")
if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None?
return var.default or ""
if (
var.type
in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
}
and user_input_value
and not isinstance(user_input_value, str)
# Check if all files are converted to File
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
raise ValueError("Invalid input type")
if any(
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
):
raise ValueError("Invalid input type")
return user_inputs
def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
user_input_value = inputs.get(var.variable)
if not user_input_value:
if var.required:
raise ValueError(f"{var.variable} is required in input form")
else:
return None
if var.type in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
} and not isinstance(user_input_value, str):
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number
@ -41,12 +98,24 @@ class BaseAppGenerator:
except ValueError:
raise ValueError(f"{var.variable} in input form must be a valid number")
if var.type == VariableEntityType.SELECT:
options = var.options or []
options = var.options
if user_input_value not in options:
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
if var.max_length and len(user_input_value) > var.max_length:
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
elif var.type == VariableEntityType.FILE:
if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
raise ValueError(f"{var.variable} in input form must be a file")
elif var.type == VariableEntityType.FILE_LIST:
if not (
isinstance(user_input_value, list)
and (
all(isinstance(item, dict) for item in user_input_value)
or all(isinstance(item, File) for item in user_input_value)
)
):
raise ValueError(f"{var.variable} in input form must be a list of files")
return user_input_value

View file

@ -27,7 +27,7 @@ from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from models.model import App, AppMode, Message, MessageAnnotation
if TYPE_CHECKING:
from core.file.file_obj import FileVar
from core.file.models import File
class AppRunner:
@ -37,7 +37,7 @@ class AppRunner:
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
files: list["File"],
query: Optional[str] = None,
) -> int:
"""
@ -137,7 +137,7 @@ class AppRunner:
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
files: list["File"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,

View file

@ -18,11 +18,12 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models.enums import CreatedByRole
from models.model import App, EndUser
logger = logging.getLogger(__name__)
@ -100,12 +101,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True}
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
file_objs = []
@ -118,7 +126,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
)
# get tracing instance
trace_manager = TraceQueueManager(app_model.id)
trace_manager = TraceQueueManager(app_id=app_model.id)
# init application generate entity
application_generate_entity = ChatAppGenerateEntity(
@ -126,15 +134,17 @@ class ChatAppGenerator(MessageBasedAppGenerator):
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
stream=stream,
)
# init generate records

View file

@ -17,12 +17,12 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser, Message
from factories import file_factory
from models import Account, App, EndUser, Message
from models.enums import CreatedByRole
from services.errors.app import MoreLikeThisDisabledError
from services.errors.message import MessageNotExistsError
@ -88,12 +88,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
tenant_id=app_model.tenant_id, config=args.get("model_config")
)
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
file_objs = []
@ -103,6 +110,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
)
# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id)
# init application generate entity
@ -110,7 +118,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
inputs=self._get_cleaned_inputs(inputs, app_config),
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
query=query,
files=file_objs,
user_id=user.id,
@ -251,10 +259,16 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
override_model_config_dict["model"] = model_dict
# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user)
file_objs = file_factory.build_from_mappings(
mappings=message.message_files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
else:
file_objs = []

View file

@ -26,7 +26,7 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
from models.account import Account
from models import Account
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
@ -235,13 +235,13 @@ class MessageBasedAppGenerator(BaseAppGenerator):
for file in application_generate_entity.files:
message_file = MessageFile(
message_id=message.id,
type=file.type.value,
transfer_method=file.transfer_method.value,
type=file.type,
transfer_method=file.transfer_method,
belongs_to="user",
url=file.url,
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=("account" if account_id else "end_user"),
created_by=account_id or end_user_id,
created_by=account_id or end_user_id or "",
)
db.session.add(message_file)
db.session.commit()

View file

@ -3,7 +3,7 @@ import logging
import os
import threading
import uuid
from collections.abc import Generator
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
@ -20,13 +20,12 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser
from models.workflow import Workflow
from factories import file_factory
from models import Account, App, EndUser, Workflow
from models.enums import CreatedByRole
logger = logging.getLogger(__name__)
@ -63,49 +62,46 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
stream: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
):
"""
Generate App response.
files: Sequence[Mapping[str, Any]] = args.get("files") or []
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param stream: is stream
:param call_depth: call depth
:param workflow_thread_pool_id: workflow thread pool id
"""
inputs = args["inputs"]
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
system_files = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow,
)
# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id, user_id)
trace_manager = TraceQueueManager(
app_id=app_model.id,
user_id=user.id if isinstance(user, Account) else user.session_id,
)
inputs: Mapping[str, Any] = args["inputs"]
workflow_run_id = str(uuid.uuid4())
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
inputs=self._get_cleaned_inputs(inputs, app_config),
files=file_objs,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
files=system_files,
user_id=user.id,
stream=stream,
invoke_from=invoke_from,

View file

@ -1,21 +1,20 @@
import logging
import os
from typing import Optional, cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.enums import UserFrom
from models.model import App, EndUser
from models.workflow import WorkflowType
@ -71,7 +70,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
db.session.close()
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
# if only single iteration run is requested

View file

@ -1,4 +1,3 @@
import json
import logging
import time
from collections.abc import Generator
@ -334,9 +333,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs)
if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
else None,
outputs=event.outputs,
conversation_id=None,
trace_manager=trace_manager,
)

View file

@ -20,7 +20,6 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
@ -41,9 +40,9 @@ from core.workflow.graph_engine.entities.event import (
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.nodes.node_mapping import node_classes
from core.workflow.nodes import NodeType
from core.workflow.nodes.iteration import IterationNodeData
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
@ -137,9 +136,8 @@ class WorkflowBasedAppRunner(AppRunner):
raise ValueError("iteration node id not found in workflow graph")
# Get node class
node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type"))
node_cls = node_classes.get(node_type)
node_cls = cast(type[BaseNode], node_cls)
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping[node_type]
# init variable pool
variable_pool = VariablePool(

View file

@ -1,4 +1,4 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional
@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
from constants import UUID_NIL
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from core.file.file_obj import FileVar
from core.file.models import File
from core.model_runtime.entities.model_entities import AIModelEntity
from core.ops.ops_trace_manager import TraceQueueManager
@ -23,7 +23,7 @@ class InvokeFrom(Enum):
DEBUGGER = "debugger"
@classmethod
def value_of(cls, value: str) -> "InvokeFrom":
def value_of(cls, value: str):
"""
Get value of given mode.
@ -82,7 +82,7 @@ class AppGenerateEntity(BaseModel):
app_config: AppConfig
inputs: Mapping[str, Any]
files: list[FileVar] = []
files: Sequence[File]
user_id: str
# extras

View file

@ -5,9 +5,10 @@ from typing import Any, Optional
from pydantic import BaseModel, field_validator
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
class QueueEvent(str, Enum):

View file

@ -1,3 +1,4 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional
@ -119,6 +120,7 @@ class MessageEndStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE_END
id: str
metadata: dict = {}
files: Optional[Sequence[Mapping[str, Any]]] = None
class MessageFileStreamResponse(StreamResponse):
@ -211,7 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
created_by: Optional[dict] = None
created_at: int
finished_at: int
files: Optional[list[dict]] = []
files: Optional[Sequence[Mapping[str, Any]]] = []
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
workflow_run_id: str
@ -296,7 +298,7 @@ class NodeFinishStreamResponse(StreamResponse):
execution_metadata: Optional[dict] = None
created_at: int
finished_at: int
files: Optional[list[dict]] = []
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None

View file

@ -1,76 +0,0 @@
from collections.abc import Mapping
from typing import Any
from configs import dify_config
from .exc import VariableError
from .segments import (
ArrayAnySegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
StringSegment,
)
from .types import SegmentType
from .variables import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FloatVariable,
IntegerVariable,
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
)
def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if (value_type := mapping.get("value_type")) is None:
raise VariableError("missing value type")
if not mapping.get("name"):
raise VariableError("missing name")
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
case SegmentType.SECRET:
result = SecretVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, int):
result = IntegerVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, float):
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f"invalid number value {value}")
case SegmentType.OBJECT if isinstance(value, dict):
result = ObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_STRING if isinstance(value, list):
result = ArrayStringVariable.model_validate(mapping)
case SegmentType.ARRAY_NUMBER if isinstance(value, list):
result = ArrayNumberVariable.model_validate(mapping)
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping)
case _:
raise VariableError(f"not supported value type {value_type}")
if result.size > dify_config.MAX_VARIABLE_SIZE:
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
return result
def build_segment(value: Any, /) -> Segment:
if value is None:
return NoneSegment()
if isinstance(value, str):
return StringSegment(value=value)
if isinstance(value, int):
return IntegerSegment(value=value)
if isinstance(value, float):
return FloatSegment(value=value)
if isinstance(value, dict):
return ObjectSegment(value=value)
if isinstance(value, list):
return ArrayAnySegment(value=value)
raise ValueError(f"not supported value {value}")

View file

@ -1,18 +0,0 @@
import re
from core.workflow.entities.variable_pool import VariablePool
from . import SegmentGroup, factory
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
def convert_template(*, template: str, variable_pool: VariablePool):
parts = re.split(VARIABLE_PATTERN, template)
segments = []
for part in filter(lambda x: x, parts):
if "." in part and (value := variable_pool.get(part.split("."))):
segments.append(value)
else:
segments.append(factory.build_segment(part))
return SegmentGroup(value=segments)

View file

@ -1,5 +1,6 @@
import json
import time
from collections.abc import Mapping, Sequence
from datetime import datetime, timezone
from typing import Any, Optional, Union, cast
@ -27,27 +28,26 @@ from core.app.entities.task_entities import (
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.file.file_obj import FileVar
from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
CreatedByRole,
Workflow,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
WorkflowRunStatus,
WorkflowRunTriggeredFrom,
)
@ -117,7 +117,7 @@ class WorkflowCycleManage:
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Optional[str] = None,
outputs: Mapping[str, Any] | None = None,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
@ -133,8 +133,10 @@ class WorkflowCycleManage:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
outputs = WorkflowEntry.handle_special_values(outputs)
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
workflow_run.outputs = outputs
workflow_run.outputs = json.dumps(outputs or {})
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
@ -265,6 +267,7 @@ class WorkflowCycleManage:
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
@ -276,7 +279,7 @@ class WorkflowCycleManage:
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.execution_metadata: execution_metadata,
WorkflowNodeExecution.finished_at: finished_at,
@ -286,10 +289,11 @@ class WorkflowCycleManage:
db.session.commit()
db.session.close()
process_data = WorkflowEntry.handle_special_values(event.process_data)
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.finished_at = finished_at
@ -308,6 +312,7 @@ class WorkflowCycleManage:
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
@ -317,7 +322,7 @@ class WorkflowCycleManage:
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.finished_at: finished_at,
WorkflowNodeExecution.elapsed_time: elapsed_time,
@ -326,11 +331,12 @@ class WorkflowCycleManage:
db.session.commit()
db.session.close()
process_data = WorkflowEntry.handle_special_values(event.process_data)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
@ -637,7 +643,7 @@ class WorkflowCycleManage:
),
)
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
@ -646,15 +652,15 @@ class WorkflowCycleManage:
if not outputs_dict:
return []
files = []
for output_var, output_value in outputs_dict.items():
file_vars = self._fetch_files_from_variable_value(output_value)
if file_vars:
files.extend(file_vars)
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
# Remove None
files = [file for file in files if file]
# Flatten list
files = [file for sublist in files for file in sublist]
return files
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dict]:
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from variable value
:param value: variable value
@ -666,17 +672,17 @@ class WorkflowCycleManage:
files = []
if isinstance(value, list):
for item in value:
file_var = self._get_file_var_from_value(item)
if file_var:
files.append(file_var)
file = self._get_file_var_from_value(item)
if file:
files.append(file)
elif isinstance(value, dict):
file_var = self._get_file_var_from_value(value)
if file_var:
files.append(file_var)
file = self._get_file_var_from_value(value)
if file:
files.append(file)
return files
def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]:
def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
"""
Get file var from value
:param value: variable value
@ -685,14 +691,11 @@ class WorkflowCycleManage:
if not value:
return None
if isinstance(value, dict):
if "__variant" in value and value["__variant"] == FileVar.__name__:
return value
elif isinstance(value, FileVar):
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
return value
elif isinstance(value, File):
return value.to_dict()
return None
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run

View file

@ -1,29 +0,0 @@
import enum
from typing import Any
from pydantic import BaseModel
class PromptMessageFileType(enum.Enum):
IMAGE = "image"
@staticmethod
def value_of(value):
for member in PromptMessageFileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class PromptMessageFile(BaseModel):
type: PromptMessageFileType
data: Any = None
class ImagePromptMessageFile(PromptMessageFile):
class DETAIL(enum.Enum):
LOW = "low"
HIGH = "high"
type: PromptMessageFileType = PromptMessageFileType.IMAGE
detail: DETAIL = DETAIL.LOW

View file

@ -0,0 +1,19 @@
from .constants import FILE_MODEL_IDENTITY
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
from .models import (
File,
FileExtraConfig,
ImageConfig,
)
__all__ = [
"FileType",
"FileExtraConfig",
"FileTransferMethod",
"FileBelongsTo",
"File",
"ImageConfig",
"FileAttribute",
"ArrayFileAttribute",
"FILE_MODEL_IDENTITY",
]

View file

@ -0,0 +1 @@
FILE_MODEL_IDENTITY = "__dify__file__"

55
api/core/file/enums.py Normal file
View file

@ -0,0 +1,55 @@
from enum import Enum
class FileType(str, Enum):
IMAGE = "image"
DOCUMENT = "document"
AUDIO = "audio"
VIDEO = "video"
CUSTOM = "custom"
@staticmethod
def value_of(value):
for member in FileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileTransferMethod(str, Enum):
REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file"
@staticmethod
def value_of(value):
for member in FileTransferMethod:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileBelongsTo(str, Enum):
USER = "user"
ASSISTANT = "assistant"
@staticmethod
def value_of(value):
for member in FileBelongsTo:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileAttribute(str, Enum):
TYPE = "type"
SIZE = "size"
NAME = "name"
MIME_TYPE = "mime_type"
TRANSFER_METHOD = "transfer_method"
URL = "url"
EXTENSION = "extension"
class ArrayFileAttribute(str, Enum):
LENGTH = "length"

View file

@ -0,0 +1,156 @@
import base64
from configs import dify_config
from core.file import file_repository
from core.helper import ssrf_proxy
from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent
from extensions.ext_database import db
from extensions.ext_storage import storage
from . import helpers
from .enums import FileAttribute
from .models import File, FileTransferMethod, FileType
from .tool_file_parser import ToolFileParser
def get_attr(*, file: File, attr: FileAttribute):
match attr:
case FileAttribute.TYPE:
return file.type.value
case FileAttribute.SIZE:
return file.size
case FileAttribute.NAME:
return file.filename
case FileAttribute.MIME_TYPE:
return file.mime_type
case FileAttribute.TRANSFER_METHOD:
return file.transfer_method.value
case FileAttribute.URL:
return file.remote_url
case FileAttribute.EXTENSION:
return file.extension
case _:
raise ValueError(f"Invalid file attribute: {attr}")
def to_prompt_message_content(f: File, /):
"""
Convert a File object to an ImagePromptMessageContent object.
This function takes a File object and converts it to an ImagePromptMessageContent
object, which can be used as a prompt for image-based AI models.
Args:
file (File): The File object to convert. Must be of type FileType.IMAGE.
Returns:
ImagePromptMessageContent: An object containing the image data and detail level.
Raises:
ValueError: If the file is not an image or if the file data is missing.
Note:
The detail level of the image prompt is determined by the file's extra_config.
If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW.
"""
match f.type:
case FileType.IMAGE:
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail:
detail = f._extra_config.image_config.detail
else:
detail = ImagePromptMessageContent.DETAIL.LOW
return ImagePromptMessageContent(data=data, detail=detail)
case FileType.AUDIO:
encoded_string = _file_to_encoded_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
case _:
raise ValueError(f"file type {f.type} is not supported")
def download(f: File, /):
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
return _download_file_content(upload_file.key)
def _download_file_content(path: str, /):
"""
Download and return the contents of a file as bytes.
This function loads the file from storage and ensures it's in bytes format.
Args:
path (str): The path to the file in storage.
Returns:
bytes: The contents of the file as a bytes object.
Raises:
ValueError: If the loaded file is not a bytes object.
"""
data = storage.load(path, stream=False)
if not isinstance(data, bytes):
raise ValueError(f"file {path} is not a bytes object")
return data
def _get_encoded_string(f: File, /):
match f.transfer_method:
case FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url)
response.raise_for_status()
content = response.content
encoded_string = base64.b64encode(content).decode("utf-8")
return encoded_string
case FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
data = _download_file_content(upload_file.key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string
case FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
data = _download_file_content(tool_file.file_key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string
case _:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
def _to_base64_data_string(f: File, /):
encoded_string = _get_encoded_string(f)
return f"data:{f.mime_type};base64,{encoded_string}"
def _file_to_encoded_string(f: File, /):
match f.type:
case FileType.IMAGE:
return _to_base64_data_string(f)
case FileType.AUDIO:
return _get_encoded_string(f)
case _:
raise ValueError(f"file type {f.type} is not supported")
def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
raise ValueError("Missing file remote_url")
return f.remote_url
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
if f.related_id is None:
raise ValueError("Missing file related_id")
return helpers.get_signed_file_url(upload_file_id=f.related_id)
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
# add sign url
if f.related_id is None or f.extension is None:
raise ValueError("Missing file related_id or extension")
return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")

View file

@ -1,145 +0,0 @@
import enum
from typing import Any, Optional
from pydantic import BaseModel
from core.file.tool_file_parser import ToolFileParser
from core.file.upload_file_parser import UploadFileParser
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from extensions.ext_database import db
class FileExtraConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[dict[str, Any]] = None
class FileType(enum.Enum):
IMAGE = "image"
@staticmethod
def value_of(value):
for member in FileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileTransferMethod(enum.Enum):
REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file"
@staticmethod
def value_of(value):
for member in FileTransferMethod:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileBelongsTo(enum.Enum):
USER = "user"
ASSISTANT = "assistant"
@staticmethod
def value_of(value):
for member in FileBelongsTo:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileVar(BaseModel):
id: Optional[str] = None # message file id
tenant_id: str
type: FileType
transfer_method: FileTransferMethod
url: Optional[str] = None # remote url
related_id: Optional[str] = None
extra_config: Optional[FileExtraConfig] = None
filename: Optional[str] = None
extension: Optional[str] = None
mime_type: Optional[str] = None
def to_dict(self) -> dict:
return {
"__variant": self.__class__.__name__,
"tenant_id": self.tenant_id,
"type": self.type.value,
"transfer_method": self.transfer_method.value,
"url": self.preview_url,
"remote_url": self.url,
"related_id": self.related_id,
"filename": self.filename,
"extension": self.extension,
"mime_type": self.mime_type,
}
def to_markdown(self) -> str:
"""
Convert file to markdown
:return:
"""
preview_url = self.preview_url
if self.type == FileType.IMAGE:
text = f'![{self.filename or ""}]({preview_url})'
else:
text = f"[{self.filename or preview_url}]({preview_url})"
return text
@property
def data(self) -> Optional[str]:
"""
Get image data, file signed url or base64 data
depending on config MULTIMODAL_SEND_IMAGE_FORMAT
:return:
"""
return self._get_data()
@property
def preview_url(self) -> Optional[str]:
"""
Get signed preview url
:return:
"""
return self._get_data(force_url=True)
@property
def prompt_message_content(self) -> ImagePromptMessageContent:
if self.type == FileType.IMAGE:
image_config = self.extra_config.image_config
return ImagePromptMessageContent(
data=self.data,
detail=ImagePromptMessageContent.DETAIL.HIGH
if image_config.get("detail") == "high"
else ImagePromptMessageContent.DETAIL.LOW,
)
def _get_data(self, force_url: bool = False) -> Optional[str]:
from models.model import UploadFile
if self.type == FileType.IMAGE:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = (
db.session.query(UploadFile)
.filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id)
.first()
)
return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url)
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
extension = self.extension
# add sign url
return ToolFileParser.get_tool_file_manager().sign_file(
tool_file_id=self.related_id, extension=extension
)
return None

View file

@ -0,0 +1,32 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from models import ToolFile, UploadFile
from .models import File
def get_upload_file(*, session: Session, file: File):
if file.related_id is None:
raise ValueError("Missing file related_id")
stmt = select(UploadFile).filter(
UploadFile.id == file.related_id,
UploadFile.tenant_id == file.tenant_id,
)
record = session.scalar(stmt)
if not record:
raise ValueError(f"upload file {file.related_id} not found")
return record
def get_tool_file(*, session: Session, file: File):
if file.related_id is None:
raise ValueError("Missing file related_id")
stmt = select(ToolFile).filter(
ToolFile.id == file.related_id,
ToolFile.tenant_id == file.tenant_id,
)
record = session.scalar(stmt)
if not record:
raise ValueError(f"tool file {file.related_id} not found")
return record

48
api/core/file/helpers.py Normal file
View file

@ -0,0 +1,48 @@
import base64
import hashlib
import hmac
import os
import time
from configs import dify_config
def get_signed_file_url(upload_file_id: str) -> str:
url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
key = dify_config.SECRET_KEY.encode()
msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT

View file

@ -1,243 +0,0 @@
import re
from collections.abc import Mapping, Sequence
from typing import Any, Union
from urllib.parse import parse_qs, urlparse
import requests
from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser, MessageFile, UploadFile
from services.file_service import IMAGE_EXTENSIONS
class MessageFileParser:
def __init__(self, tenant_id: str, app_id: str) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
def validate_and_transform_files_arg(
self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]
) -> list[FileVar]:
"""
validate and transform files arg
:param files:
:param file_extra_config:
:param user:
:return:
"""
for file in files:
if not isinstance(file, dict):
raise ValueError("Invalid file format, must be dict")
if not file.get("type"):
raise ValueError("Missing file type")
FileType.value_of(file.get("type"))
if not file.get("transfer_method"):
raise ValueError("Missing file transfer method")
FileTransferMethod.value_of(file.get("transfer_method"))
if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value:
if not file.get("url"):
raise ValueError("Missing file url")
if not file.get("url").startswith("http"):
raise ValueError("Invalid file url")
if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"):
raise ValueError("Missing file upload_file_id")
if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"):
raise ValueError("Missing file tool_file_id")
# transform files to file objs
type_file_objs = self._to_file_objs(files, file_extra_config)
# validate files
new_files = []
for file_type, file_objs in type_file_objs.items():
if file_type == FileType.IMAGE:
# parse and validate files
image_config = file_extra_config.image_config
# check if image file feature is enabled
if not image_config:
continue
# Validate number of files
if len(files) > image_config["number_limits"]:
raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
for file_obj in file_objs:
# Validate transfer method
if file_obj.transfer_method.value not in image_config["transfer_methods"]:
raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}")
# Validate file type
if file_obj.type != FileType.IMAGE:
raise ValueError(f"Invalid file type: {file_obj.type}")
if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
# check remote url valid and is image
result, error = self._check_image_remote_url(file_obj.url)
if result is False:
raise ValueError(error)
elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
# get upload file from upload_file_id
upload_file = (
db.session.query(UploadFile)
.filter(
UploadFile.id == file_obj.related_id,
UploadFile.tenant_id == self.tenant_id,
UploadFile.created_by == user.id,
UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
UploadFile.extension.in_(IMAGE_EXTENSIONS),
)
.first()
)
# check upload file is belong to tenant and user
if not upload_file:
raise ValueError("Invalid upload file")
new_files.append(file_obj)
# return all file objs
return new_files
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
"""
transform message files
:param files:
:param file_extra_config:
:return:
"""
# transform files to file objs
type_file_objs = self._to_file_objs(files, file_extra_config)
# return all file objs
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
def _to_file_objs(
self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig
) -> dict[FileType, list[FileVar]]:
"""
transform files to file objs
:param files:
:param file_extra_config:
:return:
"""
type_file_objs: dict[FileType, list[FileVar]] = {
# Currently only support image
FileType.IMAGE: []
}
if not files:
return type_file_objs
# group by file type and convert file args or message files to FileObj
for file in files:
if isinstance(file, MessageFile):
if file.belongs_to == FileBelongsTo.ASSISTANT.value:
continue
file_obj = self._to_file_obj(file, file_extra_config)
if file_obj.type not in type_file_objs:
continue
type_file_objs[file_obj.type].append(file_obj)
return type_file_objs
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
"""
transform file to file obj
:param file:
:return:
"""
if isinstance(file, dict):
transfer_method = FileTransferMethod.value_of(file.get("transfer_method"))
if transfer_method != FileTransferMethod.TOOL_FILE:
return FileVar(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get("type")),
transfer_method=transfer_method,
url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None,
extra_config=file_extra_config,
)
return FileVar(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get("type")),
transfer_method=transfer_method,
url=None,
related_id=file.get("tool_file_id"),
extra_config=file_extra_config,
)
else:
return FileVar(
id=file.id,
tenant_id=self.tenant_id,
type=FileType.value_of(file.type),
transfer_method=FileTransferMethod.value_of(file.transfer_method),
url=file.url,
related_id=file.upload_file_id or None,
extra_config=file_extra_config,
)
def _check_image_remote_url(self, url):
try:
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
" Chrome/91.0.4472.124 Safari/537.36"
}
def is_s3_presigned_url(url):
try:
parsed_url = urlparse(url)
if "amazonaws.com" not in parsed_url.netloc:
return False
query_params = parse_qs(parsed_url.query)
def check_presign_v2(query_params):
required_params = ["Signature", "Expires"]
for param in required_params:
if param not in query_params:
return False
if not query_params["Expires"][0].isdigit():
return False
signature = query_params["Signature"][0]
if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
return False
return True
def check_presign_v4(query_params):
required_params = ["X-Amz-Signature", "X-Amz-Expires"]
for param in required_params:
if param not in query_params:
return False
if not query_params["X-Amz-Expires"][0].isdigit():
return False
signature = query_params["X-Amz-Signature"][0]
if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
return False
return True
return check_presign_v4(query_params) or check_presign_v2(query_params)
except Exception:
return False
if is_s3_presigned_url(url):
response = requests.get(url, headers=headers, allow_redirects=True)
if response.status_code in {200, 304}:
return True, ""
response = requests.head(url, headers=headers, allow_redirects=True)
if response.status_code in {200, 304}:
return True, ""
else:
return False, "URL does not exist."
except requests.RequestException as e:
return False, f"Error checking URL: {e}"

140
api/core/file/models.py Normal file
View file

@ -0,0 +1,140 @@
from collections.abc import Mapping, Sequence
from typing import Optional
from pydantic import BaseModel, Field, model_validator
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from . import helpers
from .constants import FILE_MODEL_IDENTITY
from .enums import FileTransferMethod, FileType
from .tool_file_parser import ToolFileParser
class ImageConfig(BaseModel):
"""
NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
"""
number_limits: int = 0
transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
detail: ImagePromptMessageContent.DETAIL | None = None
class FileExtraConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[ImageConfig] = None
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_extensions: Sequence[str] = Field(default_factory=list)
allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
number_limits: int = 0
class File(BaseModel):
dify_model_identity: str = FILE_MODEL_IDENTITY
id: Optional[str] = None # message file id
tenant_id: str
type: FileType
transfer_method: FileTransferMethod
remote_url: Optional[str] = None # remote url
related_id: Optional[str] = None
filename: Optional[str] = None
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
mime_type: Optional[str] = None
size: int = -1
_extra_config: FileExtraConfig | None = None
def to_dict(self) -> Mapping[str, str | int | None]:
data = self.model_dump(mode="json")
return {
**data,
"url": self.generate_url(),
}
@property
def markdown(self) -> str:
url = self.generate_url()
if self.type == FileType.IMAGE:
text = f'![{self.filename or ""}]({url})'
else:
text = f"[{self.filename or url}]({url})"
return text
def generate_url(self) -> Optional[str]:
if self.type == FileType.IMAGE:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.remote_url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
if self.related_id is None:
raise ValueError("Missing file related_id")
return helpers.get_signed_file_url(upload_file_id=self.related_id)
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
assert self.related_id is not None
assert self.extension is not None
return ToolFileParser.get_tool_file_manager().sign_file(
tool_file_id=self.related_id, extension=self.extension
)
else:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.remote_url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
if self.related_id is None:
raise ValueError("Missing file related_id")
return helpers.get_signed_file_url(upload_file_id=self.related_id)
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
assert self.related_id is not None
assert self.extension is not None
return ToolFileParser.get_tool_file_manager().sign_file(
tool_file_id=self.related_id, extension=self.extension
)
@model_validator(mode="after")
def validate_after(self):
match self.transfer_method:
case FileTransferMethod.REMOTE_URL:
if not self.remote_url:
raise ValueError("Missing file url")
if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"):
raise ValueError("Invalid file url")
case FileTransferMethod.LOCAL_FILE:
if not self.related_id:
raise ValueError("Missing file related_id")
case FileTransferMethod.TOOL_FILE:
if not self.related_id:
raise ValueError("Missing file related_id")
# Validate the extra config.
if not self._extra_config:
return self
if self._extra_config.allowed_file_types:
if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM:
raise ValueError(f"Invalid file type: {self.type}")
if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions:
raise ValueError(f"Invalid file extension: {self.extension}")
if (
self._extra_config.allowed_upload_methods
and self.transfer_method not in self._extra_config.allowed_upload_methods
):
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
match self.type:
case FileType.IMAGE:
# NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
if not self._extra_config.image_config:
return self
# TODO: skip check if transfer_methods is empty, because many test cases are not setting this field
if (
self._extra_config.image_config.transfer_methods
and self.transfer_method not in self._extra_config.image_config.transfer_methods
):
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
return self

View file

@ -1,4 +1,9 @@
tool_file_manager = {"manager": None}
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from core.tools.tool_file_manager import ToolFileManager
tool_file_manager: dict[str, Any] = {"manager": None}
class ToolFileParser:

View file

@ -1,79 +0,0 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from typing import Optional
from configs import dify_config
from extensions.ext_storage import storage
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
class UploadFileParser:
@classmethod
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
if not upload_file:
return None
if upload_file.extension not in IMAGE_EXTENSIONS:
return None
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
return cls.get_signed_temp_image_url(upload_file.id)
else:
# get image file base64
try:
data = storage.load(upload_file.key)
except FileNotFoundError:
logging.error(f"File not found: {upload_file.key}")
return None
encoded_string = base64.b64encode(data).decode("utf-8")
return f"data:{upload_file.mime_type};base64,{encoded_string}"
@classmethod
def get_signed_temp_image_url(cls, upload_file_id) -> str:
"""
get signed url from upload file
:param upload_file: UploadFile object
:return:
"""
base_url = dify_config.FILES_URL
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode()
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@classmethod
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
:param upload_file_id: file id
:param timestamp: timestamp
:param nonce: nonce
:param sign: signature
:return:
"""
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT

View file

@ -13,8 +13,11 @@ SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "")
SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "")
SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3"))
proxies = (
{"http://": SSRF_PROXY_HTTP_URL, "https://": SSRF_PROXY_HTTPS_URL}
proxy_mounts = (
{
"http://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTP_URL),
"https://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTPS_URL),
}
if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL
else None
)
@ -33,11 +36,14 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
while retries <= max_retries:
try:
if SSRF_PROXY_ALL_URL:
response = httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs)
elif proxies:
response = httpx.request(method=method, url=url, proxies=proxies, **kwargs)
with httpx.Client(proxy=SSRF_PROXY_ALL_URL) as client:
response = client.request(method=method, url=url, **kwargs)
elif proxy_mounts:
with httpx.Client(mounts=proxy_mounts) as client:
response = client.request(method=method, url=url, **kwargs)
else:
response = httpx.request(method=method, url=url, **kwargs)
with httpx.Client() as client:
response = client.request(method=method, url=url, **kwargs)
if response.status_code not in STATUS_FORCELIST:
return response

View file

@ -1,18 +1,20 @@
from typing import Optional
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file.message_file_parser import MessageFileParser
from core.file import file_manager
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageRole,
TextPromptMessageContent,
UserPromptMessage,
)
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from factories import file_factory
from models.model import AppMode, Conversation, Message, MessageFile
from models.workflow import WorkflowRun
@ -65,13 +67,12 @@ class TokenBufferMemory:
messages = list(reversed(thread_messages))
message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id)
prompt_messages = []
for message in messages:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files:
file_extra_config = None
if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
else:
if message.workflow_run_id:
@ -84,17 +85,21 @@ class TokenBufferMemory:
workflow_run.workflow.features_dict, is_vision=False
)
if file_extra_config:
file_objs = message_file_parser.transform_message_files(files, file_extra_config)
if file_extra_config and app_record:
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
)
else:
file_objs = []
if not file_objs:
prompt_messages.append(UserPromptMessage(content=message.query))
else:
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file_obj in file_objs:
prompt_message_contents.append(file_obj.prompt_message_content)
prompt_message = file_manager.to_prompt_message_content(file_obj)
prompt_message_contents.append(prompt_message)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:

View file

@ -1,7 +1,7 @@
import logging
import os
from collections.abc import Callable, Generator, Sequence
from typing import IO, Optional, Union, cast
from collections.abc import Callable, Generator, Iterable, Sequence
from typing import IO, Any, Optional, Union, cast
from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@ -274,7 +274,7 @@ class ModelInstance:
user=user,
)
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str:
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]:
"""
Invoke large language tts model
@ -298,7 +298,7 @@ class ModelInstance:
voice=voice,
)
def _round_robin_invoke(self, function: Callable, *args, **kwargs):
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
"""
Round-robin invoke
:param function: function to invoke

View file

@ -0,0 +1,38 @@
from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from .message_entities import (
AssistantPromptMessage,
AudioPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageRole,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from .model_entities import ModelPropertyKey
__all__ = [
"ImagePromptMessageContent",
"PromptMessage",
"PromptMessageRole",
"LLMUsage",
"ModelPropertyKey",
"AssistantPromptMessage",
"PromptMessage",
"PromptMessageContent",
"PromptMessageRole",
"SystemPromptMessage",
"TextPromptMessageContent",
"UserPromptMessage",
"PromptMessageTool",
"ToolPromptMessage",
"PromptMessageContentType",
"LLMResult",
"LLMResultChunk",
"LLMResultChunkDelta",
"AudioPromptMessageContent",
]

View file

@ -2,7 +2,7 @@ from abc import ABC
from enum import Enum
from typing import Optional
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, Field, field_validator
class PromptMessageRole(Enum):
@ -55,6 +55,7 @@ class PromptMessageContentType(Enum):
TEXT = "text"
IMAGE = "image"
AUDIO = "audio"
class PromptMessageContent(BaseModel):
@ -74,12 +75,18 @@ class TextPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.TEXT
class AudioPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.AUDIO
data: str = Field(..., description="Base64 encoded audio data")
format: str = Field(..., description="Audio format")
class ImagePromptMessageContent(PromptMessageContent):
"""
Model class for image prompt message content.
"""
class DETAIL(Enum):
class DETAIL(str, Enum):
LOW = "low"
HIGH = "high"

View file

@ -1,5 +1,4 @@
import logging
import os
import re
import time
from abc import abstractmethod
@ -8,6 +7,7 @@ from typing import Optional, Union
from pydantic import ConfigDict
from configs import dify_config
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.callbacks.logging_callback import LoggingCallback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@ -77,7 +77,7 @@ class LargeLanguageModel(AIModel):
callbacks = callbacks or []
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
if dify_config.DEBUG:
callbacks.append(LoggingCallback())
# trigger before invoke callbacks
@ -107,7 +107,16 @@ class LargeLanguageModel(AIModel):
callbacks=callbacks,
)
else:
result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
result = self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
)
except Exception as e:
self._trigger_invoke_error_callbacks(
model=model,

View file

@ -1,6 +1,7 @@
import logging
import re
from abc import abstractmethod
from collections.abc import Iterable
from typing import Any, Optional
from pydantic import ConfigDict
@ -22,8 +23,14 @@ class TTSModel(AIModel):
model_config = ConfigDict(protected_namespaces=())
def invoke(
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
):
self,
model: str,
tenant_id: str,
credentials: dict,
content_text: str,
voice: str,
user: Optional[str] = None,
) -> Iterable[bytes]:
"""
Invoke large language model
@ -50,8 +57,14 @@ class TTSModel(AIModel):
@abstractmethod
def _invoke(
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
):
self,
model: str,
tenant_id: str,
credentials: dict,
content_text: str,
voice: str,
user: Optional[str] = None,
) -> Iterable[bytes]:
"""
Invoke large language model
@ -68,25 +81,25 @@ class TTSModel(AIModel):
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
"""
Get voice for given tts model voices
Retrieves the list of voices supported by a given text-to-speech (TTS) model.
:param language: tts language
:param model: model name
:param credentials: model credentials
:return: voices lists
:param language: The language for which the voices are requested.
:param model: The name of the TTS model.
:param credentials: The credentials required to access the TTS model.
:return: A list of voices supported by the TTS model.
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties:
voices = model_schema.model_properties[ModelPropertyKey.VOICES]
if language:
return [
{"name": d["name"], "value": d["mode"]}
for d in voices
if language and language in d.get("language")
]
else:
return [{"name": d["name"], "value": d["mode"]} for d in voices]
if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
raise ValueError("this model does not support voice")
voices = model_schema.model_properties[ModelPropertyKey.VOICES]
if language:
return [
{"name": d["name"], "value": d["mode"]} for d in voices if language and language in d.get("language")
]
else:
return [{"name": d["name"], "value": d["mode"]} for d in voices]
def _get_model_default_voice(self, model: str, credentials: dict) -> Any:
"""
@ -111,8 +124,10 @@ class TTSModel(AIModel):
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties:
raise ValueError("this model does not support audio type")
return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
"""
@ -121,8 +136,10 @@ class TTSModel(AIModel):
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties:
raise ValueError("this model does not support word limit")
return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
"""
@ -131,8 +148,10 @@ class TTSModel(AIModel):
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties:
raise ValueError("this model does not support max workers")
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
@staticmethod
def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):

View file

@ -1,3 +1,4 @@
- gpt-4o-audio-preview
- gpt-4
- gpt-4o
- gpt-4o-2024-05-13

View file

@ -0,0 +1,44 @@
model: gpt-4o-audio-preview
label:
zh_Hans: gpt-4o-audio-preview
en_US: gpt-4o-audio-preview
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 4096
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '5.00'
output: '15.00'
unit: '0.000001'
currency: USD

View file

@ -1,7 +1,7 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import Any, Optional, Union, cast
import tiktoken
from openai import OpenAI, Stream
@ -11,9 +11,9 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, Cho
from openai.types.chat.chat_completion_message import FunctionCall
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities import (
AssistantPromptMessage,
AudioPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
@ -23,6 +23,7 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType, PriceConfig
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@ -613,6 +614,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
# clear illegal prompt messages
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
# o1 compatibility
block_as_stream = False
if model.startswith("o1"):
if stream:
@ -626,8 +628,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
del extra_model_kwargs["stop"]
# chat model
messages: Any = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
response = client.chat.completions.create(
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
messages=messages,
model=model,
stream=stream,
**model_parameters,
@ -946,23 +949,29 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
Convert PromptMessage to dict for OpenAI API
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
elif isinstance(message.content, list):
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
if isinstance(message_content, TextPromptMessageContent):
sub_message_dict = {"type": "text", "text": message_content.data}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
elif isinstance(message_content, ImagePromptMessageContent):
sub_message_dict = {
"type": "image_url",
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
}
sub_messages.append(sub_message_dict)
elif isinstance(message_content, AudioPromptMessageContent):
sub_message_dict = {
"type": "input_audio",
"input_audio": {
"data": message_content.data,
"format": message_content.format,
},
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):

View file

@ -358,8 +358,8 @@ class TraceTask:
workflow_run_id = workflow_run.id
workflow_run_elapsed_time = workflow_run.elapsed_time
workflow_run_status = workflow_run.status
workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {}
workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {}
workflow_run_inputs = workflow_run.inputs_dict
workflow_run_outputs = workflow_run.outputs_dict
workflow_run_version = workflow_run.version
error = workflow_run.error or ""

View file

@ -1,12 +1,15 @@
from typing import Optional, Union
from collections.abc import Sequence
from typing import Optional
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file.file_obj import FileVar
from core.file import file_manager
from core.file.models import File
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContent,
PromptMessageRole,
SystemPromptMessage,
TextPromptMessageContent,
@ -14,8 +17,8 @@ from core.model_runtime.entities.message_entities import (
)
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.variable_pool import VariablePool
class AdvancedPromptTransform(PromptTransform):
@ -28,22 +31,19 @@ class AdvancedPromptTransform(PromptTransform):
def get_prompt(
self,
prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate],
inputs: dict,
*,
prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate,
inputs: dict[str, str],
query: str,
files: list[FileVar],
files: Sequence[File],
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
query_prompt_template: Optional[str] = None,
) -> list[PromptMessage]:
inputs = {key: str(value) for key, value in inputs.items()}
prompt_messages = []
model_mode = ModelMode.value_of(model_config.mode)
if model_mode == ModelMode.COMPLETION:
if isinstance(prompt_template, CompletionModelPromptTemplate):
prompt_messages = self._get_completion_model_prompt_messages(
prompt_template=prompt_template,
inputs=inputs,
@ -54,12 +54,11 @@ class AdvancedPromptTransform(PromptTransform):
memory=memory,
model_config=model_config,
)
elif model_mode == ModelMode.CHAT:
elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
prompt_messages = self._get_chat_model_prompt_messages(
prompt_template=prompt_template,
inputs=inputs,
query=query,
query_prompt_template=query_prompt_template,
files=files,
context=context,
memory_config=memory_config,
@ -74,7 +73,7 @@ class AdvancedPromptTransform(PromptTransform):
prompt_template: CompletionModelPromptTemplate,
inputs: dict,
query: Optional[str],
files: list[FileVar],
files: Sequence[File],
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
@ -88,10 +87,10 @@ class AdvancedPromptTransform(PromptTransform):
prompt_messages = []
if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
if memory and memory_config:
role_prefix = memory_config.role_prefix
@ -100,15 +99,15 @@ class AdvancedPromptTransform(PromptTransform):
memory_config=memory_config,
raw_prompt=raw_prompt,
role_prefix=role_prefix,
prompt_template=prompt_template,
parser=parser,
prompt_inputs=prompt_inputs,
model_config=model_config,
)
if query:
prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
prompt_inputs = self._set_query_variable(query, parser, prompt_inputs)
prompt = prompt_template.format(prompt_inputs)
prompt = parser.format(prompt_inputs)
else:
prompt = raw_prompt
prompt_inputs = inputs
@ -116,9 +115,10 @@ class AdvancedPromptTransform(PromptTransform):
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
if files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
for file in files:
prompt_message_contents.append(file.prompt_message_content)
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
@ -131,35 +131,38 @@ class AdvancedPromptTransform(PromptTransform):
prompt_template: list[ChatModelMessage],
inputs: dict,
query: Optional[str],
files: list[FileVar],
files: Sequence[File],
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
query_prompt_template: Optional[str] = None,
) -> list[PromptMessage]:
"""
Get chat model prompt messages.
"""
raw_prompt_list = prompt_template
prompt_messages = []
for prompt_item in raw_prompt_list:
for prompt_item in prompt_template:
raw_prompt = prompt_item.text
if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
prompt = prompt_template.format(prompt_inputs)
if self.with_variable_tmpl:
vp = VariablePool()
for k, v in inputs.items():
if k.startswith("#"):
vp.add(k[1:-1].split("."), v)
raw_prompt = raw_prompt.replace("{{#context#}}", context or "")
prompt = vp.convert_template(raw_prompt).text
else:
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
prompt_inputs = self._set_context_variable(
context=context, parser=parser, prompt_inputs=prompt_inputs
)
prompt = parser.format(prompt_inputs)
elif prompt_item.edition_type == "jinja2":
prompt = raw_prompt
prompt_inputs = inputs
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
prompt = Jinja2Formatter.format(template=prompt, inputs=prompt_inputs)
else:
raise ValueError(f"Invalid edition type: {prompt_item.edition_type}")
@ -170,25 +173,25 @@ class AdvancedPromptTransform(PromptTransform):
elif prompt_item.role == PromptMessageRole.ASSISTANT:
prompt_messages.append(AssistantPromptMessage(content=prompt))
if query and query_prompt_template:
prompt_template = PromptTemplateParser(
template=query_prompt_template, with_variable_tmpl=self.with_variable_tmpl
if query and memory_config and memory_config.query_prompt_template:
parser = PromptTemplateParser(
template=memory_config.query_prompt_template, with_variable_tmpl=self.with_variable_tmpl
)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
prompt_inputs["#sys.query#"] = query
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
query = prompt_template.format(prompt_inputs)
query = parser.format(prompt_inputs)
if memory and memory_config:
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
if files:
prompt_message_contents = [TextPromptMessageContent(data=query)]
if files and query is not None:
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=query))
for file in files:
prompt_message_contents.append(file.prompt_message_content)
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
@ -200,19 +203,19 @@ class AdvancedPromptTransform(PromptTransform):
# get last user message content and add files
prompt_message_contents = [TextPromptMessageContent(data=last_message.content)]
for file in files:
prompt_message_contents.append(file.prompt_message_content)
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
last_message.content = prompt_message_contents
else:
prompt_message_contents = [TextPromptMessageContent(data="")] # not for query
for file in files:
prompt_message_contents.append(file.prompt_message_content)
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_message_contents = [TextPromptMessageContent(data=query)]
for file in files:
prompt_message_contents.append(file.prompt_message_content)
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
elif query:
@ -220,8 +223,8 @@ class AdvancedPromptTransform(PromptTransform):
return prompt_messages
def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
if "#context#" in prompt_template.variable_keys:
def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
if "#context#" in parser.variable_keys:
if context:
prompt_inputs["#context#"] = context
else:
@ -229,8 +232,8 @@ class AdvancedPromptTransform(PromptTransform):
return prompt_inputs
def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
if "#query#" in prompt_template.variable_keys:
def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
if "#query#" in parser.variable_keys:
if query:
prompt_inputs["#query#"] = query
else:
@ -244,16 +247,16 @@ class AdvancedPromptTransform(PromptTransform):
memory_config: MemoryConfig,
raw_prompt: str,
role_prefix: MemoryConfig.RolePrefix,
prompt_template: PromptTemplateParser,
parser: PromptTemplateParser,
prompt_inputs: dict,
model_config: ModelConfigWithCredentialsEntity,
) -> dict:
if "#histories#" in prompt_template.variable_keys:
if "#histories#" in parser.variable_keys:
if memory:
inputs = {"#histories#": "", **prompt_inputs}
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
tmp_human_message = UserPromptMessage(content=prompt_template.format(prompt_inputs))
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)

View file

@ -5,9 +5,11 @@ from typing import TYPE_CHECKING, Optional
from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageContent,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
@ -18,10 +20,10 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.model import AppMode
if TYPE_CHECKING:
from core.file.file_obj import FileVar
from core.file.models import File
class ModelMode(enum.Enum):
class ModelMode(str, enum.Enum):
COMPLETION = "completion"
CHAT = "chat"
@ -53,7 +55,7 @@ class SimplePromptTransform(PromptTransform):
prompt_template_entity: PromptTemplateEntity,
inputs: dict,
query: str,
files: list["FileVar"],
files: list["File"],
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
@ -169,7 +171,7 @@ class SimplePromptTransform(PromptTransform):
inputs: dict,
query: str,
context: Optional[str],
files: list["FileVar"],
files: list["File"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
@ -214,7 +216,7 @@ class SimplePromptTransform(PromptTransform):
inputs: dict,
query: str,
context: Optional[str],
files: list["FileVar"],
files: list["File"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
@ -261,11 +263,12 @@ class SimplePromptTransform(PromptTransform):
return [self.get_last_user_message(prompt, files)], stops
def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage:
def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage:
if files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
for file in files:
prompt_message_contents.append(file.prompt_message_content)
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
prompt_message = UserPromptMessage(content=prompt_message_contents)
else:

View file

@ -1,7 +1,9 @@
from typing import Any
from constants import UUID_NIL
def extract_thread_messages(messages: list[dict]) -> list[dict]:
def extract_thread_messages(messages: list[Any]):
thread_messages = []
next_message = None

View file

@ -1,7 +1,8 @@
from typing import cast
from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities import (
AssistantPromptMessage,
AudioPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
@ -21,7 +22,7 @@ class PromptMessageUtil:
:return:
"""
prompts = []
if model_mode == ModelMode.CHAT.value:
if model_mode == ModelMode.CHAT:
tool_calls = []
for prompt_message in prompt_messages:
if prompt_message.role == PromptMessageRole.USER:
@ -51,11 +52,9 @@ class PromptMessageUtil:
files = []
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT:
content = cast(TextPromptMessageContent, content)
if isinstance(content, TextPromptMessageContent):
text += content.data
else:
content = cast(ImagePromptMessageContent, content)
elif isinstance(content, ImagePromptMessageContent):
files.append(
{
"type": "image",
@ -63,6 +62,14 @@ class PromptMessageUtil:
"detail": content.detail.value,
}
)
elif isinstance(content, AudioPromptMessageContent):
files.append(
{
"type": "audio",
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
"format": content.format,
}
)
else:
text = prompt_message.content

View file

@ -121,7 +121,7 @@ class WordExtractor(BaseExtractor):
db.session.add(upload_file)
db.session.commit()
image_map[rel.target_part] = (
f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)"
f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/file-preview)"
)
return image_map

View file

@ -9,7 +9,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.llm import LLMNode
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""

View file

@ -32,8 +32,8 @@ class UserToolProvider(BaseModel):
original_credentials: Optional[dict] = None
is_team_authorization: bool = False
allow_delete: bool = True
tools: list[UserTool] = None
labels: list[str] = None
tools: list[UserTool] | None = None
labels: list[str] | None = None
def to_dict(self) -> dict:
# -------------
@ -42,7 +42,7 @@ class UserToolProvider(BaseModel):
for tool in tools:
if tool.get("parameters"):
for parameter in tool.get("parameters"):
if parameter.get("type") == ToolParameter.ToolParameterType.FILE.value:
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
parameter["type"] = "files"
# -------------

View file

@ -104,14 +104,15 @@ class ToolInvokeMessage(BaseModel):
BLOB = "blob"
JSON = "json"
IMAGE_LINK = "image_link"
FILE_VAR = "file_var"
FILE = "file"
type: MessageType = MessageType.TEXT
"""
plain text, image url or link url
"""
message: str | bytes | dict | None = None
meta: dict[str, Any] | None = None
# TODO: Use a BaseModel for meta
meta: dict[str, Any] = Field(default_factory=dict)
save_as: str = ""
@ -143,6 +144,67 @@ class ToolParameter(BaseModel):
SELECT = "select"
SECRET_INPUT = "secret-input"
FILE = "file"
FILES = "files"
# deprecated, should not use.
SYSTEM_FILES = "systme-files"
def as_normal_type(self):
if self in {
ToolParameter.ToolParameterType.SECRET_INPUT,
ToolParameter.ToolParameterType.SELECT,
}:
return "string"
return self.value
def cast_value(self, value: Any, /):
try:
match self:
case (
ToolParameter.ToolParameterType.STRING
| ToolParameter.ToolParameterType.SECRET_INPUT
| ToolParameter.ToolParameterType.SELECT
):
if value is None:
return ""
else:
return value if isinstance(value, str) else str(value)
case ToolParameter.ToolParameterType.BOOLEAN:
if value is None:
return False
elif isinstance(value, str):
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
# and also '0' for False and '1' for True
match value.lower():
case "true" | "yes" | "y" | "1":
return True
case "false" | "no" | "n" | "0":
return False
case _:
return bool(value)
else:
return value if isinstance(value, bool) else bool(value)
case ToolParameter.ToolParameterType.NUMBER:
if isinstance(value, int | float):
return value
elif isinstance(value, str) and value:
if "." in value:
return float(value)
else:
return int(value)
case (
ToolParameter.ToolParameterType.SYSTEM_FILES
| ToolParameter.ToolParameterType.FILE
| ToolParameter.ToolParameterType.FILES
):
return value
case _:
return str(value)
except Exception:
raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.")
class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool

View file

@ -66,7 +66,7 @@ class DallE3Tool(BuiltinTool):
for image in response.data:
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
blob_message = self.create_blob_message(
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE
)
result.append(blob_message)
return result

View file

@ -2,7 +2,7 @@ from typing import Any
from duckduckgo_search import DDGS
from core.file.file_obj import FileTransferMethod
from core.file.models import FileTransferMethod
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View file

@ -0,0 +1,24 @@
<svg width="100" height="100" viewBox="0 0 100 100" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="100" height="100" rx="20" fill="#4A90E2" />
<path
d="M50 25C40.6 25 33 32.6 33 42V58C33 67.4 40.6 75 50 75C59.4 75 67 67.4 67 58V42C67 32.6 59.4 25 50 25ZM61 58C61 64.1 56.1 69 50 69C43.9 69 39 64.1 39 58V42C39 35.9 43.9 31 50 31C56.1 31 61 35.9 61 42V58Z"
fill="white" />
<path d="M50 37C47.2 37 45 39.2 45 42V58C45 60.8 47.2 63 50 63C52.8 63 55 60.8 55 58V42C55 39.2 52.8 37 50 37Z"
fill="white" />
<path
d="M73 49H69V58C69 68.5 60.5 77 50 77C39.5 77 31 68.5 31 58V49H27V58C27 70.7 37.3 81 50 81C62.7 81 73 70.7 73 58V49Z"
fill="white" />
<path d="M50 85C51.1 85 52 84.1 52 83V81H48V83C48 84.1 48.9 85 50 85Z" fill="white" />
<path
d="M35 45C36.1046 45 37 44.1046 37 43C37 41.8954 36.1046 41 35 41C33.8954 41 33 41.8954 33 43C33 44.1046 33.8954 45 35 45Z"
fill="white" />
<path
d="M35 55C36.1046 55 37 54.1046 37 53C37 51.8954 36.1046 51 35 51C33.8954 51 33 51.8954 33 53C33 54.1046 33.8954 55 35 55Z"
fill="white" />
<path
d="M65 45C66.1046 45 67 44.1046 67 43C67 41.8954 66.1046 41 65 41C63.8954 41 63 41.8954 63 43C63 44.1046 63.8954 45 65 45Z"
fill="white" />
<path
d="M65 55C66.1046 55 67 54.1046 67 53C67 51.8954 66.1046 51 65 51C63.8954 51 63 51.8954 63 53C63 54.1046 63.8954 55 65 55Z"
fill="white" />
</svg>

After

Width:  |  Height:  |  Size: 1.4 KiB

View file

@ -0,0 +1,33 @@
from typing import Any
import openai
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class PodcastGeneratorProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
tts_service = credentials.get("tts_service")
api_key = credentials.get("api_key")
if not tts_service:
raise ToolProviderCredentialValidationError("TTS service is not specified")
if not api_key:
raise ToolProviderCredentialValidationError("API key is missing")
if tts_service == "openai":
self._validate_openai_credentials(api_key)
else:
raise ToolProviderCredentialValidationError(f"Unsupported TTS service: {tts_service}")
def _validate_openai_credentials(self, api_key: str) -> None:
client = openai.OpenAI(api_key=api_key)
try:
# We're using a simple API call to validate the credentials
client.models.list()
except openai.AuthenticationError:
raise ToolProviderCredentialValidationError("Invalid OpenAI API key")
except Exception as e:
raise ToolProviderCredentialValidationError(f"Error validating OpenAI API key: {str(e)}")

View file

@ -0,0 +1,34 @@
identity:
author: Dify
name: podcast_generator
label:
en_US: Podcast Generator
zh_Hans: 播客生成器
description:
en_US: Generate podcast audio using Text-to-Speech services
zh_Hans: 使用文字转语音服务生成播客音频
icon: icon.svg
credentials_for_provider:
tts_service:
type: select
required: true
label:
en_US: TTS Service
zh_Hans: TTS 服务
placeholder:
en_US: Select a TTS service
zh_Hans: 选择一个 TTS 服务
options:
- label:
en_US: OpenAI TTS
zh_Hans: OpenAI TTS
value: openai
api_key:
type: secret-input
required: true
label:
en_US: API Key
zh_Hans: API 密钥
placeholder:
en_US: Enter your TTS service API key
zh_Hans: 输入您的 TTS 服务 API 密钥

View file

@ -0,0 +1,100 @@
import concurrent.futures
import io
import random
from typing import Any, Literal, Optional, Union
import openai
from pydub import AudioSegment
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolParameterValidationError, ToolProviderCredentialValidationError
from core.tools.tool.builtin_tool import BuiltinTool
class PodcastAudioGeneratorTool(BuiltinTool):
@staticmethod
def _generate_silence(duration: float):
# Generate silent WAV data using pydub
silence = AudioSegment.silent(duration=int(duration * 1000)) # pydub uses milliseconds
return silence
@staticmethod
def _generate_audio_segment(
client: openai.OpenAI,
line: str,
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"],
index: int,
) -> tuple[int, Union[AudioSegment, str], Optional[AudioSegment]]:
try:
response = client.audio.speech.create(model="tts-1", voice=voice, input=line.strip(), response_format="wav")
audio = AudioSegment.from_wav(io.BytesIO(response.content))
silence_duration = random.uniform(0.1, 1.5)
silence = PodcastAudioGeneratorTool._generate_silence(silence_duration)
return index, audio, silence
except Exception as e:
return index, f"Error generating audio: {str(e)}", None
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
# Extract parameters
script = tool_parameters.get("script", "")
host1_voice = tool_parameters.get("host1_voice")
host2_voice = tool_parameters.get("host2_voice")
# Split the script into lines
script_lines = [line for line in script.split("\n") if line.strip()]
# Ensure voices are provided
if not host1_voice or not host2_voice:
raise ToolParameterValidationError("Host voices are required")
# Get OpenAI API key from credentials
if not self.runtime or not self.runtime.credentials:
raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing")
api_key = self.runtime.credentials.get("api_key")
if not api_key:
raise ToolProviderCredentialValidationError("OpenAI API key is missing")
# Initialize OpenAI client
client = openai.OpenAI(api_key=api_key)
# Create a thread pool
max_workers = 5
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for i, line in enumerate(script_lines):
voice = host1_voice if i % 2 == 0 else host2_voice
future = executor.submit(self._generate_audio_segment, client, line, voice, i)
futures.append(future)
# Collect results
audio_segments: list[Any] = [None] * len(script_lines)
for future in concurrent.futures.as_completed(futures):
index, audio, silence = future.result()
if isinstance(audio, str): # Error occurred
return self.create_text_message(audio)
audio_segments[index] = (audio, silence)
# Combine audio segments in the correct order
combined_audio = AudioSegment.empty()
for i, (audio, silence) in enumerate(audio_segments):
if audio:
combined_audio += audio
if i < len(audio_segments) - 1 and silence:
combined_audio += silence
# Export the combined audio to a WAV file in memory
buffer = io.BytesIO()
combined_audio.export(buffer, format="wav")
wav_bytes = buffer.getvalue()
# Create a blob message with the combined audio
return [
self.create_text_message("Audio generated successfully"),
self.create_blob_message(
blob=wav_bytes,
meta={"mime_type": "audio/x-wav"},
save_as=self.VariableKey.AUDIO,
),
]

View file

@ -0,0 +1,95 @@
identity:
name: podcast_audio_generator
author: Dify
label:
en_US: Podcast Audio Generator
zh_Hans: 播客音频生成器
description:
human:
en_US: Generate a podcast audio file from a script with two alternating voices using OpenAI's TTS service.
zh_Hans: 使用 OpenAI 的 TTS 服务,从包含两个交替声音的脚本生成播客音频文件。
llm: This tool converts a prepared podcast script into an audio file using OpenAI's Text-to-Speech service, with two specified voices for alternating hosts.
parameters:
- name: script
type: string
required: true
label:
en_US: Podcast Script
zh_Hans: 播客脚本
human_description:
en_US: A string containing alternating lines for two hosts, separated by newline characters.
zh_Hans: 包含两位主持人交替台词的字符串,每行用换行符分隔。
llm_description: A string representing the script, with alternating lines for two hosts separated by newline characters.
form: llm
- name: host1_voice
type: select
required: true
label:
en_US: Host 1 Voice
zh_Hans: 主持人1 音色
human_description:
en_US: The voice for the first host.
zh_Hans: 第一位主持人的音色。
llm_description: The voice identifier for the first host's voice.
options:
- label:
en_US: Alloy
zh_Hans: Alloy
value: alloy
- label:
en_US: Echo
zh_Hans: Echo
value: echo
- label:
en_US: Fable
zh_Hans: Fable
value: fable
- label:
en_US: Onyx
zh_Hans: Onyx
value: onyx
- label:
en_US: Nova
zh_Hans: Nova
value: nova
- label:
en_US: Shimmer
zh_Hans: Shimmer
value: shimmer
form: form
- name: host2_voice
type: select
required: true
label:
en_US: Host 2 Voice
zh_Hans: 主持人2 音色
human_description:
en_US: The voice for the second host.
zh_Hans: 第二位主持人的音色。
llm_description: The voice identifier for the second host's voice.
options:
- label:
en_US: Alloy
zh_Hans: Alloy
value: alloy
- label:
en_US: Echo
zh_Hans: Echo
value: echo
- label:
en_US: Fable
zh_Hans: Fable
value: fable
- label:
en_US: Onyx
zh_Hans: Onyx
value: onyx
- label:
en_US: Nova
zh_Hans: Nova
value: nova
- label:
en_US: Shimmer
zh_Hans: Shimmer
value: shimmer
form: form

View file

@ -13,7 +13,6 @@ from core.tools.errors import (
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from core.tools.utils.yaml_utils import load_yaml_file
@ -208,9 +207,7 @@ class BuiltinToolProviderController(ToolProviderController):
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
default_value = ToolParameterConverter.cast_parameter_by_type(
parameter_schema.default, parameter_schema.type
)
default_value = parameter_schema.type.cast_value(parameter_schema.default)
tool_parameters[parameter] = default_value
def validate_credentials(self, credentials: dict[str, Any]) -> None:

View file

@ -11,7 +11,6 @@ from core.tools.entities.tool_entities import (
)
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
from core.tools.tool.tool import Tool
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
class ToolProviderController(BaseModel, ABC):
@ -127,9 +126,7 @@ class ToolProviderController(BaseModel, ABC):
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(
parameter_schema.default, parameter_schema.type
)
tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default)
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
"""

View file

@ -1,6 +1,6 @@
from typing import Optional
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.app_config.entities import VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
@ -23,6 +23,8 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
}
@ -36,8 +38,8 @@ class WorkflowToolProviderController(ToolProviderController):
if not app:
raise ValueError("app not found")
controller = WorkflowToolProviderController(
**{
controller = WorkflowToolProviderController.model_validate(
{
"identity": {
"author": db_provider.user.name if db_provider.user_id and db_provider.user else "",
"name": db_provider.label,
@ -67,7 +69,7 @@ class WorkflowToolProviderController(ToolProviderController):
:param app: the app
:return: the tool
"""
workflow: Workflow = (
workflow = (
db.session.query(Workflow)
.filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.first()
@ -76,14 +78,14 @@ class WorkflowToolProviderController(ToolProviderController):
raise ValueError("workflow not found")
# fetch start node
graph: dict = workflow.graph_dict
features_dict: dict = workflow.features_dict
graph = workflow.graph_dict
features_dict = workflow.features_dict
features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW)
parameters = db_provider.parameter_configurations
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
def fetch_workflow_variable(variable_name: str) -> VariableEntity:
def fetch_workflow_variable(variable_name: str):
return next(filter(lambda x: x.variable == variable_name, variables), None)
user = db_provider.user
@ -114,7 +116,6 @@ class WorkflowToolProviderController(ToolProviderController):
llm_description=parameter.description,
required=variable.required,
options=options,
default=variable.default,
)
)
elif features.file_upload:
@ -123,7 +124,7 @@ class WorkflowToolProviderController(ToolProviderController):
name=parameter.name,
label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name),
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
type=ToolParameter.ToolParameterType.FILE,
type=ToolParameter.ToolParameterType.SYSTEM_FILES,
llm_description=parameter.description,
required=False,
form=parameter.form,

View file

@ -20,10 +20,9 @@ from core.tools.entities.tool_entities import (
ToolRuntimeVariablePool,
)
from core.tools.tool_file_manager import ToolFileManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
if TYPE_CHECKING:
from core.file.file_obj import FileVar
from core.file.models import File
class Tool(BaseModel, ABC):
@ -63,8 +62,12 @@ class Tool(BaseModel, ABC):
def __init__(self, **data: Any):
super().__init__(**data)
class VariableKey(Enum):
class VariableKey(str, Enum):
IMAGE = "image"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
CUSTOM = "custom"
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
"""
@ -221,9 +224,7 @@ class Tool(BaseModel, ABC):
result = deepcopy(tool_parameters)
for parameter in self.parameters or []:
if parameter.name in tool_parameters:
result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(
tool_parameters[parameter.name], parameter.type
)
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
return result
@ -295,10 +296,8 @@ class Tool(BaseModel, ABC):
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as)
def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.FILE_VAR, message="", meta={"file_var": file_var}, save_as=""
)
def create_file_message(self, file: "File") -> ToolInvokeMessage:
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE, message="", meta={"file": file}, save_as="")
def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage:
"""

View file

@ -3,7 +3,7 @@ import logging
from copy import deepcopy
from typing import Any, Optional, Union
from core.file.file_obj import FileTransferMethod, FileVar
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
from core.tools.tool.tool import Tool
from extensions.ext_database import db
@ -45,11 +45,13 @@ class WorkflowTool(Tool):
workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
# transform the tool parameters
tool_parameters, files = self._transform_args(tool_parameters)
tool_parameters, files = self._transform_args(tool_parameters=tool_parameters)
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
generator = WorkflowAppGenerator()
assert self.runtime is not None
assert self.runtime.invoke_from is not None
result = generator.generate(
app_model=app,
workflow=workflow,
@ -74,7 +76,7 @@ class WorkflowTool(Tool):
else:
outputs, files = self._extract_files(outputs)
for file in files:
result.append(self.create_file_var_message(file))
result.append(self.create_file_message(file))
result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
result.append(self.create_json_message(outputs))
@ -154,22 +156,22 @@ class WorkflowTool(Tool):
parameters_result = {}
files = []
for parameter in parameter_rules:
if parameter.type == ToolParameter.ToolParameterType.FILE:
if parameter.type == ToolParameter.ToolParameterType.SYSTEM_FILES:
file = tool_parameters.get(parameter.name)
if file:
try:
file_var_list = [FileVar(**f) for f in file]
for file_var in file_var_list:
file_dict = {
"transfer_method": file_var.transfer_method.value,
"type": file_var.type.value,
file_var_list = [File.model_validate(f) for f in file]
for file in file_var_list:
file_dict: dict[str, str | None] = {
"transfer_method": file.transfer_method.value,
"type": file.type.value,
}
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file_var.related_id
elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
file_dict["upload_file_id"] = file_var.related_id
elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
file_dict["url"] = file_var.preview_url
if file.transfer_method == FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file.related_id
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
file_dict["upload_file_id"] = file.related_id
elif file.transfer_method == FileTransferMethod.REMOTE_URL:
file_dict["url"] = file.generate_url()
files.append(file_dict)
except Exception as e:
@ -179,7 +181,7 @@ class WorkflowTool(Tool):
return parameters_result, files
def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
"""
extract files from the result
@ -190,17 +192,13 @@ class WorkflowTool(Tool):
result = {}
for key, value in outputs.items():
if isinstance(value, list):
has_file = False
for item in value:
if isinstance(item, dict) and item.get("__variant") == "FileVar":
try:
files.append(FileVar(**item))
has_file = True
except Exception as e:
pass
if has_file:
continue
if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
file = File.model_validate(item)
files.append(file)
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
file = File.model_validate(value)
files.append(file)
result[key] = value
return result, files

View file

@ -10,7 +10,8 @@ from yarl import URL
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod
from core.file import FileType
from core.file.models import FileTransferMethod
from core.ops.ops_trace_manager import TraceQueueManager
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
from core.tools.errors import (
@ -26,6 +27,7 @@ from core.tools.tool.tool import Tool
from core.tools.tool.workflow_tool import WorkflowTool
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from extensions.ext_database import db
from models.enums import CreatedByRole
from models.model import Message, MessageFile
@ -128,6 +130,7 @@ class ToolEngine:
"""
try:
# hit the callback handler
assert tool.identity is not None
workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)
if isinstance(tool, WorkflowTool):
@ -258,7 +261,10 @@ class ToolEngine:
@staticmethod
def _create_message_files(
tool_messages: list[ToolInvokeMessageBinary], agent_message: Message, invoke_from: InvokeFrom, user_id: str
tool_messages: list[ToolInvokeMessageBinary],
agent_message: Message,
invoke_from: InvokeFrom,
user_id: str,
) -> list[tuple[Any, str]]:
"""
Create message file
@ -269,29 +275,31 @@ class ToolEngine:
result = []
for message in tool_messages:
file_type = "bin"
if "image" in message.mimetype:
file_type = "image"
file_type = FileType.IMAGE
elif "video" in message.mimetype:
file_type = "video"
file_type = FileType.VIDEO
elif "audio" in message.mimetype:
file_type = "audio"
elif "text" in message.mimetype:
file_type = "text"
elif "pdf" in message.mimetype:
file_type = "pdf"
elif "zip" in message.mimetype:
file_type = "archive"
# ...
file_type = FileType.AUDIO
elif "text" in message.mimetype or "pdf" in message.mimetype:
file_type = FileType.DOCUMENT
else:
file_type = FileType.CUSTOM
# extract tool file id from url
tool_file_id = message.url.split("/")[-1].split(".")[0]
message_file = MessageFile(
message_id=agent_message.id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE.value,
transfer_method=FileTransferMethod.TOOL_FILE,
belongs_to="assistant",
url=message.url,
upload_file_id=None,
created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"),
upload_file_id=tool_file_id,
created_by_role=(
CreatedByRole.ACCOUNT
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatedByRole.END_USER
),
created_by=user_id,
)

View file

@ -4,7 +4,6 @@ import hmac
import logging
import os
import time
from collections.abc import Generator
from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4
@ -57,22 +56,32 @@ class ToolFileManager:
@staticmethod
def create_file_by_raw(
user_id: str, tenant_id: str, conversation_id: Optional[str], file_binary: bytes, mimetype: str
*,
user_id: str,
tenant_id: str,
conversation_id: Optional[str],
file_binary: bytes,
mimetype: str,
) -> ToolFile:
"""
create file
"""
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
filename = f"tools/{tenant_id}/{unique_name}{extension}"
storage.save(filename, file_binary)
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, file_binary)
tool_file = ToolFile(
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=filename, mimetype=mimetype
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_key=filepath,
mimetype=mimetype,
name=filename,
size=len(file_binary),
)
db.session.add(tool_file)
db.session.commit()
db.session.refresh(tool_file)
return tool_file
@ -80,29 +89,34 @@ class ToolFileManager:
def create_file_by_url(
user_id: str,
tenant_id: str,
conversation_id: str,
conversation_id: str | None,
file_url: str,
) -> ToolFile:
"""
create file
"""
# try to download image
response = get(file_url)
response.raise_for_status()
blob = response.content
try:
response = get(file_url)
response.raise_for_status()
blob = response.content
except Exception as e:
logger.error(f"Failed to download file from {file_url}: {e}")
raise
mimetype = guess_type(file_url)[0] or "octet/stream"
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
filename = f"tools/{tenant_id}/{unique_name}{extension}"
storage.save(filename, blob)
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, blob)
tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_key=filename,
file_key=filepath,
mimetype=mimetype,
original_url=file_url,
name=filename,
size=len(blob),
)
db.session.add(tool_file)
@ -110,18 +124,6 @@ class ToolFileManager:
return tool_file
@staticmethod
def create_file_by_key(
user_id: str, tenant_id: str, conversation_id: str, file_key: str, mimetype: str
) -> ToolFile:
"""
create file
"""
tool_file = ToolFile(
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=file_key, mimetype=mimetype
)
return tool_file
@staticmethod
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
"""
@ -131,7 +133,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
tool_file: ToolFile = (
tool_file = (
db.session.query(ToolFile)
.filter(
ToolFile.id == id,
@ -155,7 +157,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
message_file: MessageFile = (
message_file = (
db.session.query(MessageFile)
.filter(
MessageFile.id == id,
@ -166,13 +168,16 @@ class ToolFileManager:
# Check if message_file is not None
if message_file is not None:
# get tool file id
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split(".")[0]
if message_file.url is not None:
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split(".")[0]
else:
tool_file_id = None
else:
tool_file_id = None
tool_file: ToolFile = (
tool_file = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
@ -188,7 +193,7 @@ class ToolFileManager:
return blob, tool_file.mimetype
@staticmethod
def get_file_generator_by_tool_file_id(tool_file_id: str) -> Union[tuple[Generator, str], None]:
def get_file_generator_by_tool_file_id(tool_file_id: str):
"""
get file binary
@ -196,7 +201,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
tool_file: ToolFile = (
tool_file = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
@ -205,11 +210,11 @@ class ToolFileManager:
)
if not tool_file:
return None
return None, None
generator = storage.load_stream(tool_file.file_key)
stream = storage.load_stream(tool_file.file_key)
return generator, tool_file.mimetype
return stream, tool_file
# init tool_file_parser

View file

@ -24,7 +24,6 @@ from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
@ -203,7 +202,7 @@ class ToolManager:
raise ToolProviderNotFoundError(f"provider type {provider_type} not found")
@classmethod
def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict):
"""
init runtime parameter
"""
@ -222,7 +221,7 @@ class ToolManager:
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
)
return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type)
return parameter_rule.type.cast_value(parameter_value)
@classmethod
def get_agent_tool_runtime(
@ -243,7 +242,11 @@ class ToolManager:
parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
# check file types
if parameter.type == ToolParameter.ToolParameterType.FILE:
if parameter.type in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}:
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
if parameter.form == ToolParameter.ToolParameterForm.FORM:

View file

@ -1,7 +1,8 @@
import logging
from mimetypes import guess_extension
from typing import Optional
from core.file.file_obj import FileTransferMethod, FileType
from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
@ -11,7 +12,7 @@ logger = logging.getLogger(__name__)
class ToolFileMessageTransformer:
@classmethod
def transform_tool_invoke_messages(
cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str
cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str | None
) -> list[ToolInvokeMessage]:
"""
Transform tool message and handle file download
@ -21,7 +22,7 @@ class ToolFileMessageTransformer:
for message in messages:
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
result.append(message)
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(message.message, str):
# try to download image
try:
file = ToolFileManager.create_file_by_url(
@ -50,11 +51,14 @@ class ToolFileMessageTransformer:
)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get mime type and save blob to storage
assert message.meta is not None
mimetype = message.meta.get("mime_type", "octet/stream")
# if message is str, encode it to bytes
if isinstance(message.message, str):
message.message = message.message.encode("utf-8")
# FIXME: should do a type check here.
assert isinstance(message.message, bytes)
file = ToolFileManager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
@ -63,7 +67,7 @@ class ToolFileMessageTransformer:
mimetype=mimetype,
)
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
# check if file is image
if "image" in mimetype:
@ -84,12 +88,14 @@ class ToolFileMessageTransformer:
meta=message.meta.copy() if message.meta is not None else {},
)
)
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
file_var = message.meta.get("file_var")
if file_var:
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
if file_var.type == FileType.IMAGE:
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
file = message.meta.get("file")
if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None
url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
if file.type == FileType.IMAGE:
result.append(
ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
@ -107,11 +113,13 @@ class ToolFileMessageTransformer:
meta=message.meta.copy() if message.meta is not None else {},
)
)
else:
result.append(message)
else:
result.append(message)
return result
@classmethod
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
return f'/files/tools/{tool_file_id}{extension or ".bin"}'

View file

@ -1,71 +0,0 @@
from typing import Any
from core.tools.entities.tool_entities import ToolParameter
class ToolParameterConverter:
@staticmethod
def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str:
match parameter_type:
case (
ToolParameter.ToolParameterType.STRING
| ToolParameter.ToolParameterType.SECRET_INPUT
| ToolParameter.ToolParameterType.SELECT
):
return "string"
case ToolParameter.ToolParameterType.BOOLEAN:
return "boolean"
case ToolParameter.ToolParameterType.NUMBER:
return "number"
case _:
raise ValueError(f"Unsupported parameter type {parameter_type}")
@staticmethod
def cast_parameter_by_type(value: Any, parameter_type: str) -> Any:
# convert tool parameter config to correct type
try:
match parameter_type:
case (
ToolParameter.ToolParameterType.STRING
| ToolParameter.ToolParameterType.SECRET_INPUT
| ToolParameter.ToolParameterType.SELECT
):
if value is None:
return ""
else:
return value if isinstance(value, str) else str(value)
case ToolParameter.ToolParameterType.BOOLEAN:
if value is None:
return False
elif isinstance(value, str):
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
# and also '0' for False and '1' for True
match value.lower():
case "true" | "yes" | "y" | "1":
return True
case "false" | "no" | "n" | "0":
return False
case _:
return bool(value)
else:
return value if isinstance(value, bool) else bool(value)
case ToolParameter.ToolParameterType.NUMBER:
if isinstance(value, int) | isinstance(value, float):
return value
elif isinstance(value, str) and value != "":
if "." in value:
return float(value)
else:
return int(value)
case ToolParameter.ToolParameterType.FILE:
return value
case _:
return str(value)
except Exception:
raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.")

View file

@ -1,19 +1,18 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[dict]):
"""
check parameter configurations
"""
def check_parameter_configurations(cls, configurations: Mapping[str, Any]):
for configuration in configurations:
if not WorkflowToolParameterConfiguration(**configuration):
raise ValueError("invalid parameter configuration")
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
"""
get workflow graph variables
"""

View file

@ -1,4 +1,5 @@
import logging
from pathlib import Path
from typing import Any
import yaml
@ -17,15 +18,18 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any
:param default_value: the value returned when errors ignored
:return: an object of the YAML content
"""
try:
with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content or default_value
except Exception as e:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}")
except Exception as e:
if not file_path or not Path(file_path).exists():
if ignore_error:
return default_value
else:
raise e
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content or default_value
except Exception as e:
if ignore_error:
return default_value
else:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e

View file

@ -1,7 +1,12 @@
from .segment_group import SegmentGroup
from .segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArraySegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
@ -15,6 +20,7 @@ from .variables import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
@ -46,4 +52,10 @@ __all__ = [
"ArrayNumberVariable",
"ArrayObjectVariable",
"ArraySegment",
"ArrayFileSegment",
"ArrayNumberSegment",
"ArrayObjectSegment",
"ArrayStringSegment",
"FileSegment",
"FileVariable",
]

View file

@ -5,6 +5,8 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, field_validator
from core.file import File
from .types import SegmentType
@ -39,6 +41,9 @@ class Segment(BaseModel):
@property
def size(self) -> int:
"""
Return the size of the value in bytes.
"""
return sys.getsizeof(self.value)
def to_object(self) -> Any:
@ -99,13 +104,27 @@ class ArraySegment(Segment):
def markdown(self) -> str:
items = []
for item in self.value:
if hasattr(item, "to_markdown"):
items.append(item.to_markdown())
else:
items.append(str(item))
items.append(str(item))
return "\n".join(items)
class FileSegment(Segment):
value_type: SegmentType = SegmentType.FILE
value: File
@property
def markdown(self) -> str:
return self.value.markdown
@property
def log(self) -> str:
return str(self.value)
@property
def text(self) -> str:
return str(self.value)
class ArrayAnySegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_ANY
value: Sequence[Any]
@ -124,3 +143,15 @@ class ArrayNumberSegment(ArraySegment):
class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]]
class ArrayFileSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_FILE
value: Sequence[File]
@property
def markdown(self) -> str:
items = []
for item in self.value:
items.append(item.markdown)
return "\n".join(items)

View file

@ -11,5 +11,7 @@ class SegmentType(str, Enum):
ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = "array[object]"
OBJECT = "object"
FILE = "file"
ARRAY_FILE = "array[file]"
GROUP = "group"

View file

@ -7,6 +7,7 @@ from .segments import (
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
@ -73,3 +74,7 @@ class SecretVariable(StringVariable):
class NoneVariable(NoneSegment, Variable):
value_type: SegmentType = SegmentType.NONE
value: None = None
class FileVariable(FileSegment, Variable):
pass

View file

@ -0,0 +1,7 @@
from .base_workflow_callback import WorkflowCallback
from .workflow_logging_callback import WorkflowLoggingCallback
__all__ = [
"WorkflowLoggingCallback",
"WorkflowCallback",
]

View file

@ -1,7 +1,6 @@
from typing import Optional
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
@ -20,6 +19,8 @@ from core.workflow.graph_engine.entities.event import (
ParallelBranchRunSucceededEvent,
)
from .base_workflow_callback import WorkflowCallback
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",

View file

@ -0,0 +1,3 @@
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"

View file

@ -1,52 +1,14 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMUsage
from models import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionStatus
class NodeType(Enum):
"""
Node Types.
"""
START = "start"
END = "end"
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
IF_ELSE = "if-else"
CODE = "code"
TEMPLATE_TRANSFORM = "template-transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
VARIABLE_AGGREGATOR = "variable-aggregator"
# TODO: merge this into VARIABLE_AGGREGATOR
VARIABLE_ASSIGNER = "variable-assigner"
LOOP = "loop"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # fake start node for iteration
PARAMETER_EXTRACTOR = "parameter-extractor"
CONVERSATION_VARIABLE_ASSIGNER = "assigner"
@classmethod
def value_of(cls, value: str) -> "NodeType":
"""
Get value of given node type.
:param value: node type value
:return: node type
"""
for node_type in cls:
if node_type.value == value:
return node_type
raise ValueError(f"invalid node type value {value}")
class NodeRunMetadataKey(Enum):
class NodeRunMetadataKey(str, Enum):
"""
Node Run Metadata Key.
"""
@ -70,7 +32,7 @@ class NodeRunResult(BaseModel):
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[dict[str, Any]] = None # node inputs
inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[dict[str, Any]] = None # process data
outputs: Optional[dict[str, Any]] = None # node outputs
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
@ -79,24 +41,3 @@ class NodeRunResult(BaseModel):
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
error: Optional[str] = None # error message if status is failed
class UserFrom(Enum):
"""
User from
"""
ACCOUNT = "account"
END_USER = "end-user"
@classmethod
def value_of(cls, value: str) -> "UserFrom":
"""
Value of
:param value: value
:return:
"""
for item in cls:
if item.value == value:
return item
raise ValueError(f"Invalid value: {value}")

View file

@ -1,3 +1,5 @@
from collections.abc import Sequence
from pydantic import BaseModel
@ -7,4 +9,4 @@ class VariableSelector(BaseModel):
"""
variable: str
value_selector: list[str]
value_selector: Sequence[str]

View file

@ -1,20 +1,23 @@
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field
from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.enums import SystemVariableKey
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.segments import FileSegment
from factories import variable_factory
VariableValue = Union[str, int, float, dict, list, FileVar]
from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from ..enums import SystemVariableKey
VariableValue = Union[str, int, float, dict, list, File]
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
class VariablePool(BaseModel):
@ -23,46 +26,63 @@ class VariablePool(BaseModel):
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: dict[str, dict[int, Segment]] = Field(
description="Variables mapping", default=defaultdict(dict)
description="Variables mapping",
default=defaultdict(dict),
)
# TODO: This user inputs is not used for pool.
user_inputs: Mapping[str, Any] = Field(
description="User inputs",
)
system_variables: Mapping[SystemVariableKey, Any] = Field(
description="System variables",
)
environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
default_factory=list,
)
conversation_variables: Sequence[Variable] = Field(
description="Conversation variables.",
default_factory=list,
)
environment_variables: Sequence[Variable] = Field(description="Environment variables.", default_factory=list)
def __init__(
self,
*,
system_variables: Mapping[SystemVariableKey, Any] | None = None,
user_inputs: Mapping[str, Any] | None = None,
environment_variables: Sequence[Variable] | None = None,
conversation_variables: Sequence[Variable] | None = None,
**kwargs,
):
environment_variables = environment_variables or []
conversation_variables = conversation_variables or []
user_inputs = user_inputs or {}
system_variables = system_variables or {}
conversation_variables: Sequence[Variable] | None = None
super().__init__(
system_variables=system_variables,
user_inputs=user_inputs,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
**kwargs,
)
@model_validator(mode="after")
def val_model_after(self):
"""
Append system variables
:return:
"""
# Add system variables to the variable pool
for key, value in self.system_variables.items():
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
# Add environment variables to the variable pool
for var in self.environment_variables or []:
for var in self.environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
# Add conversation variables to the variable pool
for var in self.conversation_variables or []:
for var in self.conversation_variables:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
return self
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""
Adds a variable to the variable pool.
NOTE: You should not add a non-Segment value to the variable pool
even if it is allowed now.
Args:
selector (Sequence[str]): The selector for the variable.
value (VariableValue): The value of the variable.
@ -82,7 +102,7 @@ class VariablePool(BaseModel):
if isinstance(value, Segment):
v = value
else:
v = factory.build_segment(value)
v = variable_factory.build_segment(value)
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]][hash_key] = v
@ -101,10 +121,19 @@ class VariablePool(BaseModel):
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
raise ValueError("Invalid selector")
return None
hash_key = hash(tuple(selector[1:]))
value = self.variable_dictionary[selector[0]].get(hash_key)
if value is None:
selector, attr = selector[:-1], selector[-1]
value = self.get(selector)
if isinstance(value, FileSegment):
attr = FileAttribute(attr)
attr_value = file_manager.get_attr(file=value.value, attr=attr)
return variable_factory.build_segment(attr_value)
return value
@deprecated("This method is deprecated, use `get` instead.")
@ -145,14 +174,18 @@ class VariablePool(BaseModel):
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]].pop(hash_key, None)
def remove_node(self, node_id: str, /):
"""
Remove all variables associated with a given node id.
def convert_template(self, template: str, /):
parts = VARIABLE_PATTERN.split(template)
segments = []
for part in filter(lambda x: x, parts):
if "." in part and (variable := self.get(part.split("."))):
segments.append(variable)
else:
segments.append(variable_factory.build_segment(part))
return SegmentGroup(value=segments)
Args:
node_id (str): The node id to remove.
Returns:
None
"""
self.variable_dictionary.pop(node_id, None)
def get_file(self, selector: Sequence[str], /) -> FileSegment | None:
segment = self.get(selector)
if isinstance(segment, FileSegment):
return segment
return None

View file

@ -3,12 +3,13 @@ from typing import Optional
from pydantic import BaseModel
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.base_node_data_entities import BaseIterationState
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode, UserFrom
from core.workflow.nodes.base import BaseIterationState, BaseNode
from models.enums import UserFrom
from models.workflow import Workflow, WorkflowType
from .node_entities import NodeRunResult
from .variable_pool import VariablePool
class WorkflowNodeAndResult:
node: BaseNode

View file

@ -1,4 +1,4 @@
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.base import BaseNode
class WorkflowNodeRunFailedError(Exception):

View file

@ -0,0 +1,3 @@
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]

View file

@ -18,11 +18,10 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler):
# process condition
condition_processor = ConditionProcessor()
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=graph_runtime_state.variable_pool, conditions=self.condition.conditions
_, _, final_result = condition_processor.process_conditions(
variable_pool=graph_runtime_state.variable_pool,
conditions=self.condition.conditions,
operator="and",
)
# Apply the logical operator for the current case
compare_result = all(group_result)
return compare_result
return final_result

View file

@ -0,0 +1,6 @@
from .graph import Graph
from .graph_init_params import GraphInitParams
from .graph_runtime_state import GraphRuntimeState
from .runtime_route_state import RuntimeRouteState
__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]

View file

@ -3,9 +3,9 @@ from typing import Any, Optional
from pydantic import BaseModel, Field
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
class GraphEngineEvent(BaseModel):

View file

@ -4,8 +4,8 @@ from typing import Any, Optional, cast
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter

View file

@ -4,7 +4,7 @@ from typing import Any
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from models.enums import UserFrom
from models.workflow import WorkflowType

View file

@ -10,11 +10,7 @@ from flask import Flask, current_app
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeType,
UserFrom,
)
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
@ -36,12 +32,14 @@ from core.workflow.graph_engine.entities.graph import Graph, GraphEdge
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_classes
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
logger = logging.getLogger(__name__)
@ -229,10 +227,8 @@ class GraphEngine:
raise GraphRunFailedError(f"Node {node_id} config not found.")
# convert to specific node
node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
node_cls = node_classes.get(node_type)
if not node_cls:
raise GraphRunFailedError(f"Node {node_id} type {node_type} not found.")
node_type = NodeType(node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping[node_type]
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None

View file

@ -0,0 +1,3 @@
from .enums import NodeType
__all__ = ["NodeType"]

View file

@ -0,0 +1,4 @@
from .answer_node import AnswerNode
from .entities import AnswerStreamGenerateRoute
__all__ = ["AnswerStreamGenerateRoute", "AnswerNode"]

View file

@ -1,7 +1,8 @@
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.variables import ArrayFileSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
@ -9,12 +10,13 @@ from core.workflow.nodes.answer.entities import (
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus
class AnswerNode(BaseNode):
class AnswerNode(BaseNode[AnswerNodeData]):
_node_data_cls = AnswerNodeData
_node_type: NodeType = NodeType.ANSWER
@ -23,30 +25,35 @@ class AnswerNode(BaseNode):
Run node
:return:
"""
node_data = self.node_data
node_data = cast(AnswerNodeData, node_data)
# generate routes
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data)
answer = ""
files = []
for part in generate_routes:
if part.type == GenerateRouteChunk.ChunkType.VAR:
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
value = self.graph_runtime_state.variable_pool.get(value_selector)
if value:
answer += value.markdown
variable = self.graph_runtime_state.variable_pool.get(value_selector)
if variable:
if isinstance(variable, FileSegment):
files.append(variable.value)
elif isinstance(variable, ArrayFileSegment):
files.extend(variable.value)
answer += variable.markdown
else:
part = cast(TextGenerateRouteChunk, part)
answer += part.text
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer})
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files})
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: AnswerNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: AnswerNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -55,9 +62,6 @@ class AnswerNode(BaseNode):
:param node_data: node data
:return:
"""
node_data = node_data
node_data = cast(AnswerNodeData, node_data)
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()

Some files were not shown because too many files have changed in this diff Show more