FEAT: NEW WORKFLOW ENGINE (#3160)

Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
takatost 2024-04-08 18:51:46 +08:00 committed by GitHub
commit 7753ba2d37
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
1161 changed files with 103836 additions and 10327 deletions

View file

View file

View file

@ -0,0 +1,80 @@
from abc import ABC, abstractmethod
from typing import Optional
from core.app.entities.queue_entities import AppQueueEvent
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
class BaseWorkflowCallback(ABC):
    """
    Abstract listener interface for workflow execution events.

    Every hook is abstract, so a concrete callback must implement all of them.
    """

    @abstractmethod
    def on_workflow_run_started(self) -> None:
        """Hook: the workflow run started."""
        raise NotImplementedError

    @abstractmethod
    def on_workflow_run_succeeded(self) -> None:
        """Hook: the workflow run succeeded."""
        raise NotImplementedError

    @abstractmethod
    def on_workflow_run_failed(self, error: str) -> None:
        """
        Hook: the workflow run failed.

        :param error: error message
        """
        raise NotImplementedError

    @abstractmethod
    def on_workflow_node_execute_started(
        self,
        node_id: str,
        node_type: NodeType,
        node_data: BaseNodeData,
        node_run_index: int = 1,
        predecessor_node_id: Optional[str] = None,
    ) -> None:
        """
        Hook: a workflow node started executing.

        :param node_id: id of the node
        :param node_type: type of the node
        :param node_data: data of the node
        :param node_run_index: run index of the node (defaults to 1)
        :param predecessor_node_id: id of the predecessor node, if any
        """
        raise NotImplementedError

    @abstractmethod
    def on_workflow_node_execute_succeeded(
        self,
        node_id: str,
        node_type: NodeType,
        node_data: BaseNodeData,
        inputs: Optional[dict] = None,
        process_data: Optional[dict] = None,
        outputs: Optional[dict] = None,
        execution_metadata: Optional[dict] = None,
    ) -> None:
        """
        Hook: a workflow node finished executing successfully.

        :param node_id: id of the node
        :param node_type: type of the node
        :param node_data: data of the node
        :param inputs: node inputs, if any
        :param process_data: node process data, if any
        :param outputs: node outputs, if any
        :param execution_metadata: execution metadata, if any
        """
        raise NotImplementedError

    @abstractmethod
    def on_workflow_node_execute_failed(
        self,
        node_id: str,
        node_type: NodeType,
        node_data: BaseNodeData,
        error: str,
        inputs: Optional[dict] = None,
        outputs: Optional[dict] = None,
        process_data: Optional[dict] = None,
    ) -> None:
        """
        Hook: a workflow node failed while executing.

        :param node_id: id of the node
        :param node_type: type of the node
        :param node_data: data of the node
        :param error: error message
        :param inputs: node inputs, if any
        :param outputs: node outputs, if any
        :param process_data: node process data, if any
        """
        raise NotImplementedError

    @abstractmethod
    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
        """
        Hook: publish a text chunk produced by a node.

        :param node_id: id of the node
        :param text: chunk text
        :param metadata: optional metadata accompanying the chunk
        """
        raise NotImplementedError

    @abstractmethod
    def on_event(self, event: AppQueueEvent) -> None:
        """
        Hook: publish an application queue event.

        :param event: event to publish
        """
        raise NotImplementedError

View file

View file

@ -0,0 +1,9 @@
from abc import ABC
from typing import Optional
from pydantic import BaseModel
class BaseNodeData(ABC, BaseModel):
    """Base model for node data: a required title and an optional description."""

    # required node title
    title: str
    # optional description, defaults to None
    desc: Optional[str] = None

View file

@ -0,0 +1,85 @@
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from models.workflow import WorkflowNodeExecutionStatus
class NodeType(Enum):
    """Workflow node types, identified by their string values."""

    START = 'start'
    END = 'end'
    ANSWER = 'answer'
    LLM = 'llm'
    KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval'
    IF_ELSE = 'if-else'
    CODE = 'code'
    TEMPLATE_TRANSFORM = 'template-transform'
    QUESTION_CLASSIFIER = 'question-classifier'
    HTTP_REQUEST = 'http-request'
    TOOL = 'tool'
    VARIABLE_ASSIGNER = 'variable-assigner'

    @classmethod
    def value_of(cls, value: str) -> 'NodeType':
        """
        Resolve a node type from its string value.

        :param value: node type value
        :return: the member whose value equals ``value``
        :raises ValueError: if no member has that value
        """
        matched = next((member for member in cls if member.value == value), None)
        if matched is None:
            raise ValueError(f'invalid node type value {value}')
        return matched
class SystemVariable(Enum):
    """System variable keys, identified by their string values."""

    QUERY = 'query'
    FILES = 'files'
    CONVERSATION = 'conversation'

    @classmethod
    def value_of(cls, value: str) -> 'SystemVariable':
        """
        Resolve a system variable from its string value.

        :param value: system variable value
        :return: the member whose value equals ``value``
        :raises ValueError: if no member has that value
        """
        candidates = [member for member in cls if member.value == value]
        if not candidates:
            raise ValueError(f'invalid system variable value {value}')
        return candidates[0]
class NodeRunMetadataKey(Enum):
    """Keys allowed in the metadata dict of a node run result."""

    TOTAL_TOKENS = 'total_tokens'
    TOTAL_PRICE = 'total_price'
    CURRENCY = 'currency'
    TOOL_INFO = 'tool_info'
class NodeRunResult(BaseModel):
    """Outcome of running a single node; every field except ``status`` defaults to None."""

    # execution status, RUNNING until set otherwise
    status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING

    # node inputs
    inputs: Optional[dict] = None
    # process data
    process_data: Optional[dict] = None
    # node outputs
    outputs: Optional[dict] = None
    # node metadata
    metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

    # source handle id of node with multiple branches
    edge_source_handle: Optional[str] = None

    # error message if status is failed
    error: Optional[str] = None

View file

@ -0,0 +1,9 @@
from pydantic import BaseModel
class VariableSelector(BaseModel):
    """Pairs a variable name with the selector path used to look up its value."""

    # variable name
    variable: str
    # selector path as a list of strings
    value_selector: list[str]

View file

@ -0,0 +1,94 @@
from enum import Enum
from typing import Any, Optional, Union
from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable
# Union of the value types a variable stored in the variable pool may hold.
VariableValue = Union[str, int, float, dict, list, FileVar]
class ValueType(Enum):
    """Target value types that can be requested when reading a variable."""

    STRING = "string"
    NUMBER = "number"
    OBJECT = "object"
    ARRAY_STRING = "array[string]"
    ARRAY_NUMBER = "array[number]"
    ARRAY_OBJECT = "array[object]"
    ARRAY_FILE = "array[file]"
    FILE = "file"
class VariablePool:
    """
    Store of variables grouped by the id of the node that produced them.

    System variables are registered under the node id ``'sys'``, keyed by the
    value of their ``SystemVariable`` member.
    """

    # node_id -> {hash(tuple(variable_key_list)) -> value}; created per instance in __init__
    variables_mapping: dict
    user_inputs: dict
    system_variables: dict[SystemVariable, Any]

    def __init__(self, system_variables: dict[SystemVariable, Any],
                 user_inputs: dict) -> None:
        """
        :param system_variables: system variables, for example
            ``{SystemVariable.QUERY: 'abc', SystemVariable.FILES: []}``
        :param user_inputs: user inputs
        """
        # Must be an instance attribute: a class-level dict would be shared by
        # every pool, leaking variables from one pool into another.
        self.variables_mapping = {}
        self.user_inputs = user_inputs
        self.system_variables = system_variables
        for system_variable, value in system_variables.items():
            self.append_variable('sys', [system_variable.value], value)

    def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
        """
        Store (or overwrite) a variable.

        :param node_id: node id
        :param variable_key_list: variable key list, like: ['result', 'text']
        :param value: value
        :return:
        """
        node_variables = self.variables_mapping.setdefault(node_id, {})
        node_variables[hash(tuple(variable_key_list))] = value

    def get_variable_value(self, variable_selector: list[str],
                           target_value_type: Optional[ValueType] = None) -> Optional[VariableValue]:
        """
        Read a variable, optionally coercing or validating it against a target type.

        :param variable_selector: node id followed by the variable keys
        :param target_value_type: target value type
        :return: the stored value (converted for STRING / NUMBER), or None when
            the node or the variable is unknown
        :raises ValueError: if the selector has fewer than two items, or the
            stored value does not match an OBJECT / ARRAY_* target type
        """
        if len(variable_selector) < 2:
            raise ValueError('Invalid value selector')

        node_id = variable_selector[0]
        if node_id not in self.variables_mapping:
            return None

        # the remaining items after the node id are the variable keys
        variable_key_list_hash = hash(tuple(variable_selector[1:]))
        value = self.variables_mapping[node_id].get(variable_key_list_hash)
        if value is None:
            # Unknown variable: report absence, like the unknown-node case above,
            # instead of returning the string 'None' or crashing in int(None).
            return None

        if target_value_type:
            if target_value_type == ValueType.STRING:
                return str(value)
            elif target_value_type == ValueType.NUMBER:
                return int(value)
            elif target_value_type == ValueType.OBJECT:
                if not isinstance(value, dict):
                    raise ValueError('Invalid value type: object')
            elif target_value_type in [ValueType.ARRAY_STRING,
                                       ValueType.ARRAY_NUMBER,
                                       ValueType.ARRAY_OBJECT,
                                       ValueType.ARRAY_FILE]:
                if not isinstance(value, list):
                    raise ValueError(f'Invalid value type: {target_value_type.value}')

        return value

View file

@ -0,0 +1,49 @@
from typing import Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode, UserFrom
from models.workflow import Workflow, WorkflowType
class WorkflowNodeAndResult:
    """Holds a node together with the result of running it (None until set)."""

    node: BaseNode
    result: Optional[NodeRunResult] = None

    def __init__(self, node: BaseNode, result: Optional[NodeRunResult] = None):
        """
        :param node: the node
        :param result: the node's run result, if available
        """
        self.result = result
        self.node = node
class WorkflowRunState:
    """State of one workflow run: identifiers, the variable pool, token count and node results."""

    tenant_id: str
    app_id: str
    workflow_id: str
    workflow_type: WorkflowType
    user_id: str
    user_from: UserFrom

    start_at: float
    variable_pool: VariablePool

    total_tokens: int = 0

    workflow_nodes_and_results: list[WorkflowNodeAndResult]

    def __init__(self, workflow: Workflow,
                 start_at: float,
                 variable_pool: VariablePool,
                 user_id: str,
                 user_from: UserFrom):
        """
        :param workflow: workflow providing id, tenant_id, app_id and type
        :param start_at: start time of the run
        :param variable_pool: variable pool used by the run
        :param user_id: id of the user
        :param user_from: origin of the user
        """
        # identifiers copied from the workflow
        self.workflow_id = workflow.id
        self.tenant_id = workflow.tenant_id
        self.app_id = workflow.app_id
        self.workflow_type = WorkflowType.value_of(workflow.type)

        # who triggered the run
        self.user_id = user_id
        self.user_from = user_from

        # run bookkeeping; counters and results start empty
        self.start_at = start_at
        self.variable_pool = variable_pool
        self.total_tokens = 0
        self.workflow_nodes_and_results = []

View file

@ -0,0 +1,10 @@
from core.workflow.entities.node_entities import NodeType
class WorkflowNodeRunFailedError(Exception):
    """Raised to report that a node run failed; carries the node's identity and the error text."""

    def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str):
        """
        :param node_id: id of the failed node
        :param node_type: type of the failed node
        :param node_title: title of the failed node (used in the message)
        :param error: error message
        """
        message = f"Node {node_title} run failed: {error}"
        self.node_id, self.node_type = node_id, node_type
        self.node_title, self.error = node_title, error
        super().__init__(message)

View file

View file

@ -0,0 +1,155 @@
import json
from typing import cast
from core.file.file_obj import FileVar
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.base_node import BaseNode
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus
class AnswerNode(BaseNode):
    """Node that renders its answer template into a single answer string."""

    _node_data_cls = AnswerNodeData
    node_type = NodeType.ANSWER

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node

        Walks the generate routes of the answer template, substituting each
        variable chunk with its value from the pool and concatenating the parts.

        :param variable_pool: variable pool
        :return: a SUCCEEDED result whose outputs contain the rendered ``answer``
        """
        node_data = cast(self._node_data_cls, self.node_data)

        # generate routes
        generate_routes = self.extract_generate_route_from_node_data(node_data)

        answer = ''
        for part in generate_routes:
            if part.type == "var":
                part = cast(VarGenerateRouteChunk, part)
                value = variable_pool.get_variable_value(
                    variable_selector=part.value_selector
                )

                text = ''
                if isinstance(value, str | int | float):
                    text = str(value)
                elif isinstance(value, dict):
                    # other types
                    text = json.dumps(value, ensure_ascii=False)
                elif isinstance(value, FileVar):
                    # convert file to markdown
                    text = value.to_markdown()
                elif isinstance(value, list):
                    # only file items are rendered individually
                    for item in value:
                        if isinstance(item, FileVar):
                            text += item.to_markdown() + ' '

                    text = text.strip()

                    if not text and value:
                        # a non-empty list without files is dumped as JSON
                        text = json.dumps(value, ensure_ascii=False)

                answer += text
            else:
                part = cast(TextGenerateRouteChunk, part)
                answer += part.text

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={
                "answer": answer
            }
        )

    @classmethod
    def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
        """
        Extract generate route selectors
        :param config: node config; its ``data`` entry is used to build the node data
        :return: list of text / variable route chunks
        """
        node_data = cls._node_data_cls(**config.get("data", {}))
        node_data = cast(cls._node_data_cls, node_data)

        return cls.extract_generate_route_from_node_data(node_data)

    @classmethod
    def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
        """
        Extract generate route from node data

        Splits the answer template into literal text chunks and variable
        chunks, in template order. Empty text segments are dropped.

        :param node_data: node data object
        :return: list of text / variable route chunks
        """
        # function-scope import keeps this block self-contained
        import re

        variable_template_parser = VariableTemplateParser(template=node_data.answer)
        variable_selectors = variable_template_parser.extract_variable_selectors()

        value_selector_mapping = {
            variable_selector.variable: variable_selector.value_selector
            for variable_selector in variable_selectors
        }

        # format answer template
        template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)

        # Take the intersection of the selector variables and the template variables
        variable_keys = set(value_selector_mapping) & set(template_parser.variable_keys)

        template = node_data.answer
        if not variable_keys:
            return [TextGenerateRouteChunk(text=template)] if template else []

        # Split directly on the '{{var}}' placeholders instead of injecting a
        # sentinel character: a sentinel breaks templates that contain it.
        placeholder_to_key = {f'{{{{{key}}}}}': key for key in variable_keys}
        placeholders = sorted(placeholder_to_key, key=len, reverse=True)
        pattern = re.compile('(' + '|'.join(re.escape(item) for item in placeholders) + ')')

        generate_routes = []
        # with one capturing group, re.split puts the matched placeholders at odd indices
        for index, part in enumerate(pattern.split(template)):
            if not part:
                continue
            if index % 2 == 1:
                generate_routes.append(VarGenerateRouteChunk(
                    value_selector=value_selector_mapping[placeholder_to_key[part]]
                ))
            else:
                generate_routes.append(TextGenerateRouteChunk(
                    text=part
                ))

        return generate_routes

    @classmethod
    def _is_variable(cls, part, variable_keys):
        """
        Tell whether ``part`` is a ``{{var}}`` placeholder for one of ``variable_keys``.

        :param part: template segment
        :param variable_keys: known variable keys
        :return: True if the segment starts with '{{' and names a known variable
        """
        cleaned_part = part.replace('{{', '').replace('}}', '')
        return part.startswith('{{') and cleaned_part in variable_keys

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return: mapping of variable name to its value selector
        """
        node_data = cast(cls._node_data_cls, node_data)

        variable_template_parser = VariableTemplateParser(template=node_data.answer)
        variable_selectors = variable_template_parser.extract_variable_selectors()

        return {
            variable_selector.variable: variable_selector.value_selector
            for variable_selector in variable_selectors
        }

View file

@ -0,0 +1,34 @@
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
class AnswerNodeData(BaseNodeData):
    """Node data of the answer node."""

    # answer template text
    answer: str
class GenerateRouteChunk(BaseModel):
    """Base model of one chunk in a generate route."""

    # chunk kind discriminator
    type: str
class VarGenerateRouteChunk(GenerateRouteChunk):
    """Generate route chunk that refers to a variable through its value selector."""

    type: str = "var"
    # selector path of the referenced variable
    value_selector: list[str]
class TextGenerateRouteChunk(GenerateRouteChunk):
    """Generate route chunk that carries literal text."""

    type: str = "text"
    # literal text of the chunk
    text: str

View file

@ -0,0 +1,142 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
class UserFrom(Enum):
    """Origin of the user."""

    ACCOUNT = "account"
    END_USER = "end-user"

    @classmethod
    def value_of(cls, value: str) -> "UserFrom":
        """
        Resolve a member from its string value.

        :param value: value
        :return: the member whose value equals ``value``
        :raises ValueError: if no member has that value
        """
        found = next((member for member in cls if member.value == value), None)
        if found is None:
            raise ValueError(f"Invalid value: {value}")
        return found
class BaseNode(ABC):
    """
    Abstract base class of workflow nodes.

    Subclasses provide ``_node_data_cls`` (the model built from the ``data``
    entry of the node config) and implement ``_run`` and
    ``_extract_variable_selector_to_variable_mapping``.
    """

    _node_data_cls: type[BaseNodeData]
    _node_type: NodeType

    tenant_id: str
    app_id: str
    workflow_id: str
    user_id: str
    user_from: UserFrom

    node_id: str
    node_data: BaseNodeData
    # result of the most recent run(), None before the first run
    node_run_result: Optional[NodeRunResult] = None

    callbacks: list[BaseWorkflowCallback]

    def __init__(self, tenant_id: str,
                 app_id: str,
                 workflow_id: str,
                 user_id: str,
                 user_from: UserFrom,
                 config: dict,
                 callbacks: Optional[list[BaseWorkflowCallback]] = None) -> None:
        """
        :param tenant_id: tenant id
        :param app_id: app id
        :param workflow_id: workflow id
        :param user_id: user id
        :param user_from: origin of the user
        :param config: node config; must contain a non-empty ``id``, and its
            ``data`` entry (default ``{}``) is passed to ``_node_data_cls``
        :param callbacks: workflow callbacks; None is treated as an empty list
        :raises ValueError: if the config has no node id
        """
        self.tenant_id = tenant_id
        self.app_id = app_id
        self.workflow_id = workflow_id
        self.user_id = user_id
        self.user_from = user_from

        self.node_id = config.get("id")
        if not self.node_id:
            raise ValueError("Node ID is required.")

        self.node_data = self._node_data_cls(**config.get("data", {}))
        self.callbacks = callbacks or []

    @abstractmethod
    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node
        :param variable_pool: variable pool
        :return:
        """
        raise NotImplementedError

    def run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node entry: delegates to ``_run`` and remembers the result.

        :param variable_pool: variable pool
        :return: the result returned by ``_run``
        """
        result = self._run(
            variable_pool=variable_pool
        )

        self.node_run_result = result
        return result

    def publish_text_chunk(self, text: str, value_selector: Optional[list[str]] = None) -> None:
        """
        Publish text chunk to every registered callback.

        :param text: chunk text
        :param value_selector: value selector
        :return:
        """
        for callback in self.callbacks or []:
            callback.on_node_text_chunk(
                node_id=self.node_id,
                text=text,
                metadata={
                    "node_type": self.node_type,
                    "value_selector": value_selector
                }
            )

    @classmethod
    def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param config: node config; its ``data`` entry is used to build the node data
        :return:
        """
        node_data = cls._node_data_cls(**config.get("data", {}))
        return cls._extract_variable_selector_to_variable_mapping(node_data)

    @classmethod
    @abstractmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return:
        """
        raise NotImplementedError

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.
        :param filters: filter by node config parameters.
        :return: empty dict in the base implementation
        """
        return {}

    @property
    def node_type(self) -> NodeType:
        """
        Get node type
        :return: the value of ``_node_type``
        """
        return self._node_type

View file

View file

@ -0,0 +1,348 @@
import os
from typing import Optional, Union, cast
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.code.entities import CodeNodeData
from models.workflow import WorkflowNodeExecutionStatus
# Limits applied when validating code outputs; those read from os.environ can be overridden.
# Inclusive bounds for numeric outputs.
MAX_NUMBER = int(os.environ.get('CODE_MAX_NUMBER', '9223372036854775807'))
MIN_NUMBER = int(os.environ.get('CODE_MIN_NUMBER', '-9223372036854775808'))
# Maximum number of characters after the decimal point in a float output's str() form.
MAX_PRECISION = 20
# Maximum nesting depth of an output object.
MAX_DEPTH = 5
# Maximum length of a string output.
MAX_STRING_LENGTH = int(os.environ.get('CODE_MAX_STRING_LENGTH', '80000'))
# Maximum number of elements in array outputs declared in the output schema, per element type.
MAX_STRING_ARRAY_LENGTH = int(os.environ.get('CODE_MAX_STRING_ARRAY_LENGTH', '30'))
MAX_OBJECT_ARRAY_LENGTH = int(os.environ.get('CODE_MAX_OBJECT_ARRAY_LENGTH', '30'))
MAX_NUMBER_ARRAY_LENGTH = int(os.environ.get('CODE_MAX_NUMBER_ARRAY_LENGTH', '1000'))
# Default code returned by the code node's default config for javascript.
JAVASCRIPT_DEFAULT_CODE = """function main({arg1, arg2}) {
return {
result: arg1 + arg2
}
}"""
# Default code returned by the code node's default config for python3.
PYTHON_DEFAULT_CODE = """def main(arg1: int, arg2: int) -> dict:
return {
"result": arg1 + arg2,
}"""
class CodeNode(BaseNode):
    """Node that executes user code and validates its outputs against the configured schema."""

    _node_data_cls = CodeNodeData
    node_type = NodeType.CODE

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.
        :param filters: filter by node config parameters; ``code_language`` of
            ``javascript`` selects the javascript defaults, anything else python3
        :return: default node config
        """
        if filters and filters.get("code_language") == "javascript":
            code_language, code = "javascript", JAVASCRIPT_DEFAULT_CODE
        else:
            code_language, code = "python3", PYTHON_DEFAULT_CODE

        return {
            "type": "code",
            "config": {
                "variables": [
                    {
                        "variable": "arg1",
                        "value_selector": []
                    },
                    {
                        "variable": "arg2",
                        "value_selector": []
                    }
                ],
                "code_language": code_language,
                "code": code,
                "outputs": {
                    "result": {
                        "type": "string",
                        "children": None
                    }
                }
            }
        }

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run code
        :param variable_pool: variable pool
        :return: SUCCEEDED with the validated outputs, or FAILED with the error
            message when execution or output validation fails
        """
        node_data: CodeNodeData = cast(self._node_data_cls, self.node_data)

        # Get code language
        code_language = node_data.code_language
        code = node_data.code

        # Get variables
        variables = {
            variable_selector.variable: variable_pool.get_variable_value(
                variable_selector=variable_selector.value_selector
            )
            for variable_selector in node_data.variables
        }

        # Run code
        try:
            result = CodeExecutor.execute_code(
                language=code_language,
                code=code,
                inputs=variables
            )

            # Transform result
            result = self._transform_result(result, node_data.outputs)
        except (CodeExecutionException, ValueError) as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e)
            )

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=variables,
            outputs=result
        )

    def _check_string(self, value: str, variable: str) -> str:
        """
        Check string
        :param value: value
        :param variable: variable name used in error messages
        :return: the string with NUL characters removed
        :raises ValueError: if the value is not a string or is too long
        """
        if not isinstance(value, str):
            raise ValueError(f"{variable} in output form must be a string")

        if len(value) > MAX_STRING_LENGTH:
            raise ValueError(f'{variable} in output form must be less than {MAX_STRING_LENGTH} characters')

        return value.replace('\x00', '')

    def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
        """
        Check number
        :param value: value
        :param variable: variable name used in error messages
        :return: the value unchanged
        :raises ValueError: if the value is not a number, is out of range, or
            has too many digits after the decimal point
        """
        if not isinstance(value, int | float):
            raise ValueError(f"{variable} in output form must be a number")

        if value > MAX_NUMBER or value < MIN_NUMBER:
            raise ValueError(f'{variable} in input form is out of range.')

        if isinstance(value, float):
            # str() of some floats (e.g. 1e-07, nan) contains no '.', so indexing
            # the split result would raise IndexError; partition yields '' instead.
            _, _, fraction = str(value).partition('.')
            # raise error if precision is too high
            if len(fraction) > MAX_PRECISION:
                raise ValueError(f'{variable} in output form has too high precision.')

        return value

    def _check_array(self, value, variable: str, max_length: int) -> None:
        """Ensure a schema-declared array output is a list of at most ``max_length`` items."""
        if not isinstance(value, list):
            raise ValueError(
                f'Output {variable} is not an array, got {type(value)} instead.'
            )

        if len(value) > max_length:
            raise ValueError(
                f'{variable} in output form must be less than {max_length} characters.'
            )

    def _validate_untyped_result(self, result: dict, prefix: str, depth: int) -> None:
        """Validate a result without an output schema, by the runtime type of each value."""
        for output_name, output_value in result.items():
            variable = f'{prefix}.{output_name}' if prefix else output_name
            if isinstance(output_value, dict):
                self._transform_result(
                    result=output_value,
                    output_schema=None,
                    prefix=variable,
                    depth=depth + 1
                )
            elif isinstance(output_value, int | float):
                self._check_number(value=output_value, variable=variable)
            elif isinstance(output_value, str):
                self._check_string(value=output_value, variable=variable)
            elif isinstance(output_value, list):
                first_element = output_value[0] if len(output_value) > 0 else None
                if first_element is None:
                    # empty list, or a list starting with None: nothing to validate
                    continue
                if isinstance(first_element, int | float) \
                        and all(isinstance(value, int | float) for value in output_value):
                    for i, value in enumerate(output_value):
                        self._check_number(value=value, variable=f'{variable}[{i}]')
                elif isinstance(first_element, str) and all(isinstance(value, str) for value in output_value):
                    for i, value in enumerate(output_value):
                        self._check_string(value=value, variable=f'{variable}[{i}]')
                elif isinstance(first_element, dict) and all(isinstance(value, dict) for value in output_value):
                    for i, value in enumerate(output_value):
                        self._transform_result(
                            result=value,
                            output_schema=None,
                            prefix=f'{variable}[{i}]',
                            depth=depth + 1
                        )
                else:
                    raise ValueError(
                        f'Output {variable} is not a valid array. make sure all elements are of the same type.'
                    )
            else:
                raise ValueError(f'Output {variable} is not a valid type.')

    def _transform_result(self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]],
                          prefix: str = '',
                          depth: int = 1) -> dict:
        """
        Transform result

        Without a schema the result is validated by runtime type and returned
        as is. With a schema, each declared output is validated (and strings
        sanitized) and a new dict is built from the declared outputs.

        :param result: result
        :param output_schema: output schema
        :param prefix: dotted path of the enclosing output, used in error messages
        :param depth: current nesting depth, starting at 1
        :return: the validated result
        :raises ValueError: on depth overflow, a missing or invalid output, or
            when the result has outputs the schema does not declare
        """
        if depth > MAX_DEPTH:
            raise ValueError("Depth limit reached, object too deep.")

        if output_schema is None:
            # validate output through instance type
            self._validate_untyped_result(result, prefix, depth)
            return result

        transformed_result = {}
        for output_name, output_config in output_schema.items():
            variable = f'{prefix}.{output_name}' if prefix else output_name

            # A declared output the code did not return would otherwise raise
            # KeyError, which the caller does not treat as a validation failure.
            if output_name not in result:
                raise ValueError(f'Output {variable} is missing.')
            value = result[output_name]

            if output_config.type == 'object':
                # check if output is object
                if not isinstance(value, dict):
                    raise ValueError(
                        f'Output {variable} is not an object, got {type(value)} instead.'
                    )

                transformed_result[output_name] = self._transform_result(
                    result=value,
                    output_schema=output_config.children,
                    prefix=variable,
                    depth=depth + 1
                )
            elif output_config.type == 'number':
                # check if number available
                transformed_result[output_name] = self._check_number(
                    value=value,
                    variable=variable
                )
            elif output_config.type == 'string':
                # check if string available
                transformed_result[output_name] = self._check_string(
                    value=value,
                    variable=variable,
                )
            elif output_config.type == 'array[number]':
                # check if array of number available
                self._check_array(value, variable, MAX_NUMBER_ARRAY_LENGTH)
                transformed_result[output_name] = [
                    self._check_number(value=item, variable=f'{variable}[{i}]')
                    for i, item in enumerate(value)
                ]
            elif output_config.type == 'array[string]':
                # check if array of string available
                self._check_array(value, variable, MAX_STRING_ARRAY_LENGTH)
                transformed_result[output_name] = [
                    self._check_string(value=item, variable=f'{variable}[{i}]')
                    for i, item in enumerate(value)
                ]
            elif output_config.type == 'array[object]':
                # check if array of object available
                self._check_array(value, variable, MAX_OBJECT_ARRAY_LENGTH)

                # every element must be an object before any of them is transformed
                for i, item in enumerate(value):
                    if not isinstance(item, dict):
                        raise ValueError(
                            f'Output {variable}[{i}] is not an object, got {type(item)} instead at index {i}.'
                        )

                transformed_result[output_name] = [
                    self._transform_result(
                        result=item,
                        output_schema=output_config.children,
                        prefix=f'{variable}[{i}]',
                        depth=depth + 1
                    )
                    for i, item in enumerate(value)
                ]
            else:
                raise ValueError(f'Output type {output_config.type} is not supported.')

        # check if all output parameters are validated
        if len(transformed_result) != len(result):
            raise ValueError('Not all output parameters are validated.')

        return transformed_result

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return: mapping of variable name to its value selector
        """
        return {
            variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
        }

View file

@ -0,0 +1,20 @@
from typing import Literal, Optional
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class CodeNodeData(BaseNodeData):
    """Node data of the code node."""

    class Output(BaseModel):
        """Declared type of one output; ``children`` optionally maps names to nested outputs."""

        type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]']
        children: Optional[dict[str, 'Output']]

    # input variables passed to the code
    variables: list[VariableSelector]
    # language of the code
    code_language: Literal['python3', 'javascript']
    # source code to execute
    code: str
    # output schema keyed by output name
    outputs: dict[str, Output]

View file

View file

@ -0,0 +1,46 @@
from typing import cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.entities import EndNodeData
from models.workflow import WorkflowNodeExecutionStatus
class EndNode(BaseNode):
    """Node that collects the configured output variables from the variable pool."""

    _node_data_cls = EndNodeData
    node_type = NodeType.END

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node
        :param variable_pool: variable pool
        :return: a SUCCEEDED result whose inputs and outputs are the collected values
        """
        node_data = cast(self._node_data_cls, self.node_data)

        # resolve every declared output variable from the pool
        outputs = {
            selector.variable: variable_pool.get_variable_value(
                variable_selector=selector.value_selector
            )
            for selector in node_data.outputs
        }

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=outputs,
            outputs=outputs
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return: always an empty mapping
        """
        return {}

View file

@ -0,0 +1,9 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class EndNodeData(BaseNodeData):
    """Node data of the end node."""

    # variables to emit as outputs
    outputs: list[VariableSelector]

View file

@ -0,0 +1,43 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, validator
from core.workflow.entities.base_node_data_entities import BaseNodeData
class HttpRequestNodeData(BaseNodeData):
    """
    Http Request Node Data.
    """
    class Authorization(BaseModel):
        """Authorization settings: ``config`` is only kept when ``type`` is not 'no-auth'."""

        class Config(BaseModel):
            type: Literal[None, 'basic', 'bearer', 'custom']
            api_key: Union[None, str]
            header: Union[None, str]

        type: Literal['no-auth', 'api-key']
        config: Optional[Config]

        @validator('config', always=True, pre=True)
        def check_config(cls, v, values):
            """
            Check config, if type is no-auth, config should be None, otherwise it should be a dict.

            :param v: raw config value
            :param values: previously validated fields
            :return: None for 'no-auth', otherwise the raw config dict
            :raises ValueError: if a config is required but missing or not a dict
            """
            # `type` is absent from `values` when its own validation failed;
            # .get() avoids a KeyError escaping the validator in that case.
            if values.get('type') == 'no-auth':
                return None

            if not v or not isinstance(v, dict):
                raise ValueError('config should be a dict')

            return v

    class Body(BaseModel):
        """Request body: a body type and its raw data."""

        type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json']
        data: Union[None, str]

    method: Literal['get', 'post', 'put', 'patch', 'delete', 'head']
    url: str
    authorization: Authorization
    # headers and params are plain strings
    headers: str
    params: str
    body: Optional[Body]

View file

@ -0,0 +1,402 @@
import json
from copy import deepcopy
from random import randint
from typing import Any, Optional, Union
from urllib.parse import urlencode
import httpx
import requests
import core.helper.ssrf_proxy as ssrf_proxy
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.nodes.http_request.entities import HttpRequestNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
# Default request timeout pair.
# NOTE(review): presumably (connect, read) in seconds; confirm against where it is passed.
HTTP_REQUEST_DEFAULT_TIMEOUT = (10, 60)
# Size limits in bytes, each paired with the human-readable form of the same limit.
MAX_BINARY_SIZE = 1024 * 1024 * 10 # 10MB
READABLE_MAX_BINARY_SIZE = '10MB'
MAX_TEXT_SIZE = 1024 * 1024 // 10 # 0.1MB
READABLE_MAX_TEXT_SIZE = '0.1MB'
class HttpExecutorResponse:
    """Uniform wrapper around an ``httpx`` or ``requests`` response."""

    headers: dict[str, str]
    response: Union[httpx.Response, requests.Response]

    def __init__(self, response: Union[httpx.Response, requests.Response] = None):
        """
        Copy the response headers into a plain dict and keep the response.

        :param response: httpx or requests response; any other value yields empty headers
        """
        if isinstance(response, (httpx.Response, requests.Response)):
            self.headers = dict(response.headers.items())
        else:
            self.headers = {}
        self.response = response

    def _require_response(self) -> Union[httpx.Response, requests.Response]:
        """Return the wrapped response, or raise ValueError if it is not a supported type."""
        if not isinstance(self.response, (httpx.Response, requests.Response)):
            raise ValueError(f'Invalid response type {type(self.response)}')
        return self.response

    @property
    def is_file(self) -> bool:
        """
        Whether the content type mentions image, audio or video.
        """
        content_type = self.get_content_type()
        return any(kind in content_type for kind in ('image', 'audio', 'video'))

    def get_content_type(self) -> str:
        """
        Return the value of the content-type header (case-insensitive name match), or ''.
        """
        return next(
            (val for key, val in self.headers.items() if key.lower() == 'content-type'),
            ''
        )

    def extract_file(self) -> tuple[str, bytes]:
        """
        Return (content type, body) for a file-like response, otherwise ('', b'').
        """
        if not self.is_file:
            return '', b''
        return self.get_content_type(), self.body

    @property
    def content(self) -> str:
        """
        Response text.

        :raises ValueError: if the wrapped response is not an httpx/requests response
        """
        return self._require_response().text

    @property
    def body(self) -> bytes:
        """
        Raw response bytes.

        :raises ValueError: if the wrapped response is not an httpx/requests response
        """
        return self._require_response().content

    @property
    def status_code(self) -> int:
        """
        Response status code.

        :raises ValueError: if the wrapped response is not an httpx/requests response
        """
        return self._require_response().status_code

    @property
    def size(self) -> int:
        """
        Length of the body in bytes.
        """
        return len(self.body)

    @property
    def readable_size(self) -> str:
        """
        Body size formatted as bytes, KB or MB.
        """
        size = self.size
        if size < 1024:
            return f'{size} bytes'
        if size < 1024 * 1024:
            return f'{(size / 1024):.2f} KB'
        return f'{(size / 1024 / 1024):.2f} MB'
class HttpExecutor:
    """
    Builds and sends the HTTP request described by an HTTP request node.

    Templates in the url, params, headers and body are rendered against the
    variable pool at construction time. ``invoke`` performs the request and
    ``to_raw_request`` renders it as text.
    """
    server_url: str
    method: str
    authorization: HttpRequestNodeData.Authorization
    params: dict[str, Any]
    headers: dict[str, Any]
    body: Union[None, str]
    files: Union[None, dict[str, Any]]
    # only assigned when the body type is 'form-data'
    boundary: str
    variable_selectors: list[VariableSelector]

    def __init__(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
        """
        Initialise the executor from node data.

        :param node_data: http request node configuration
        :param variable_pool: pool used to render templates; when omitted the
            templates are left unrendered and only their selectors are collected
        """
        self.server_url = node_data.url
        self.method = node_data.method
        self.authorization = node_data.authorization
        self.params = {}
        self.headers = {}
        self.body = None
        self.files = None

        # init template
        self.variable_selectors = []
        self._init_template(node_data, variable_pool)

    def _is_json_body(self, body: HttpRequestNodeData.Body):
        """
        Check whether the body is declared as json and its data parses as JSON.
        """
        if body and body.type == 'json':
            try:
                json.loads(body.data)
                return True
            except (TypeError, ValueError):
                # TypeError: data is not str/bytes (e.g. None);
                # ValueError covers json.JSONDecodeError
                return False

        return False

    @staticmethod
    def _parse_key_value_lines(text: str) -> list[tuple[str, str]]:
        """
        Parse newline separated ``key:value`` lines into (key, value) pairs.

        Blank lines are skipped. Only the first colon separates key from
        value, so values may themselves contain colons (urls, timestamps).
        A line without a colon yields an empty value. Keys are stripped,
        values are returned untouched.
        """
        pairs = []
        for line in text.split('\n'):
            if not line.strip():
                continue
            key, _, value = line.partition(':')
            pairs.append((key.strip(), value))
        return pairs

    def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
        """
        Render url, params, headers and body and collect their variable selectors.
        """
        # extract all template in url
        self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool)

        # extract all template in params
        params, params_variable_selectors = self._format_template(node_data.params, variable_pool)

        # fill in params (values are kept as written, not stripped)
        for key, value in self._parse_key_value_lines(params):
            self.params[key] = value

        # extract all template in headers
        headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool)

        # fill in headers
        for key, value in self._parse_key_value_lines(headers):
            self.headers[key] = value.strip()

        # extract all template in body
        body_data_variable_selectors = []
        if node_data.body:
            # quotes inside variable values are escaped only for valid JSON bodies
            is_valid_json = self._is_json_body(node_data.body)

            body_data = node_data.body.data or ''
            if body_data:
                body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json)

            if node_data.body.type == 'json':
                self.headers['Content-Type'] = 'application/json'
            elif node_data.body.type == 'x-www-form-urlencoded':
                self.headers['Content-Type'] = 'application/x-www-form-urlencoded'

            if node_data.body.type in ['form-data', 'x-www-form-urlencoded']:
                body = dict(self._parse_key_value_lines(body_data))

                if node_data.body.type == 'form-data':
                    self.files = {
                        k: ('', v) for k, v in body.items()
                    }
                    # 16 random lowercase ascii letters
                    random_suffix = ''.join(chr(randint(97, 122)) for _ in range(16))
                    self.boundary = f'----WebKitFormBoundary{random_suffix}'

                    self.headers['Content-Type'] = f'multipart/form-data; boundary={self.boundary}'
                else:
                    self.body = urlencode(body)
            elif node_data.body.type in ['json', 'raw-text']:
                self.body = body_data
            elif node_data.body.type == 'none':
                self.body = ''

        self.variable_selectors = (server_url_variable_selectors + params_variable_selectors
                                   + headers_variable_selectors + body_data_variable_selectors)

    def _assembling_headers(self) -> dict[str, Any]:
        """
        Return a copy of the headers with the authorization header added.

        :raises ValueError: when api-key authorization has no api_key
        """
        # work on copies so neither the node's authorization config nor
        # self.headers is mutated
        authorization = deepcopy(self.authorization)
        headers = deepcopy(self.headers) or {}
        if self.authorization.type == 'api-key':
            if self.authorization.config.api_key is None:
                raise ValueError('api_key is required')

            if not self.authorization.config.header:
                authorization.config.header = 'Authorization'

            if self.authorization.config.type == 'bearer':
                headers[authorization.config.header] = f'Bearer {authorization.config.api_key}'
            elif self.authorization.config.type == 'basic':
                headers[authorization.config.header] = f'Basic {authorization.config.api_key}'
            elif self.authorization.config.type == 'custom':
                headers[authorization.config.header] = authorization.config.api_key

        return headers

    def _validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> HttpExecutorResponse:
        """
        Wrap the response and enforce the size limits.

        :raises ValueError: on an unsupported response type, or when the body
            exceeds MAX_BINARY_SIZE (file responses) or MAX_TEXT_SIZE (others)
        """
        if isinstance(response, httpx.Response | requests.Response):
            executor_response = HttpExecutorResponse(response)
        else:
            raise ValueError(f'Invalid response type {type(response)}')

        if executor_response.is_file:
            if executor_response.size > MAX_BINARY_SIZE:
                raise ValueError(f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.')
        else:
            if executor_response.size > MAX_TEXT_SIZE:
                raise ValueError(f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.')

        return executor_response

    def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response:
        """
        Send the request through ssrf_proxy using the configured method.

        :raises ValueError: when the method is not one of
            get/post/put/delete/patch/head/options
        """
        # do http request
        kwargs = {
            'url': self.server_url,
            'headers': headers,
            'params': self.params,
            'timeout': HTTP_REQUEST_DEFAULT_TIMEOUT,
            'follow_redirects': True
        }

        if self.method == 'get':
            response = ssrf_proxy.get(**kwargs)
        elif self.method == 'post':
            response = ssrf_proxy.post(data=self.body, files=self.files, **kwargs)
        elif self.method == 'put':
            response = ssrf_proxy.put(data=self.body, files=self.files, **kwargs)
        elif self.method == 'delete':
            response = ssrf_proxy.delete(data=self.body, files=self.files, **kwargs)
        elif self.method == 'patch':
            response = ssrf_proxy.patch(data=self.body, files=self.files, **kwargs)
        elif self.method == 'head':
            response = ssrf_proxy.head(**kwargs)
        elif self.method == 'options':
            response = ssrf_proxy.options(**kwargs)
        else:
            raise ValueError(f'Invalid http method {self.method}')

        return response

    def invoke(self) -> HttpExecutorResponse:
        """
        Assemble headers, perform the request and validate the response.
        """
        # assemble headers
        headers = self._assembling_headers()

        # do http request
        response = self._do_http_request(headers)

        # validate response
        return self._validate_and_parse_response(response)

    def to_raw_request(self) -> str:
        """
        Render the request as HTTP/1.1-style text (request line, headers, body).
        """
        server_url = self.server_url
        if self.params:
            server_url += f'?{urlencode(self.params)}'

        raw_request = f'{self.method.upper()} {server_url} HTTP/1.1\n'

        headers = self._assembling_headers()
        for k, v in headers.items():
            raw_request += f'{k}: {v}\n'

        raw_request += '\n'

        # if files, use multipart/form-data with boundary
        if self.files:
            boundary = self.boundary
            raw_request += f'--{boundary}'
            for k, v in self.files.items():
                raw_request += f'\nContent-Disposition: form-data; name="{k}"\n\n'
                raw_request += f'{v[1]}\n'
                raw_request += f'--{boundary}'
            raw_request += '--'
        else:
            raw_request += self.body or ''

        return raw_request

    def _format_template(self, template: str, variable_pool: Optional[VariablePool], escape_quotes: bool = False) \
            -> tuple[str, list[VariableSelector]]:
        """
        Render a template and return it with the selectors it references.

        :param template: template text
        :param variable_pool: pool to read values from; when falsy the template
            is returned unrendered
        :param escape_quotes: backslash-escape double quotes in variable values
        :raises ValueError: when a referenced variable has no value in the pool
        """
        variable_template_parser = VariableTemplateParser(template=template)
        variable_selectors = variable_template_parser.extract_variable_selectors()

        if variable_pool:
            variable_value_mapping = {}
            for variable_selector in variable_selectors:
                value = variable_pool.get_variable_value(
                    variable_selector=variable_selector.value_selector,
                    target_value_type=ValueType.STRING
                )

                if value is None:
                    raise ValueError(f'Variable {variable_selector.variable} not found')

                if escape_quotes:
                    value = value.replace('"', '\\"')

                variable_value_mapping[variable_selector.variable] = value

            return variable_template_parser.format(variable_value_mapping), variable_selectors
        else:
            return template, variable_selectors

View file

@ -0,0 +1,112 @@
import logging
from mimetypes import guess_extension
from os import path
from typing import cast
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.http_request.entities import HttpRequestNodeData
from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse
from models.workflow import WorkflowNodeExecutionStatus
class HttpRequestNode(BaseNode):
    """
    Workflow node that performs an HTTP request and exposes its result.
    """
    _node_data_cls = HttpRequestNodeData
    node_type = NodeType.HTTP_REQUEST

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Execute the configured request.

        :param variable_pool: variable pool used to render the request templates
        :return: a SUCCEEDED result with status_code, body, headers and files,
            or a FAILED result carrying the error message
        """
        node_data: HttpRequestNodeData = cast(self._node_data_cls, self.node_data)

        # init http executor
        http_executor = None
        try:
            http_executor = HttpExecutor(node_data=node_data, variable_pool=variable_pool)

            # invoke http executor
            response = http_executor.invoke()
        except Exception as e:
            process_data = {}
            if http_executor:
                try:
                    process_data = {
                        'request': http_executor.to_raw_request(),
                    }
                except Exception:
                    # rendering the raw request re-runs header assembly, which
                    # can raise for the same reason invoke() did; keep
                    # reporting the original error rather than this one
                    process_data = {}
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                process_data=process_data
            )

        files = self.extract_files(http_executor.server_url, response)

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={
                'status_code': response.status_code,
                # the body is left empty when it was extracted as a file
                'body': response.content if not files else '',
                'headers': response.headers,
                'files': files,
            },
            process_data={
                'request': http_executor.to_raw_request(),
            }
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: HttpRequestNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return: mapping of variable name to value selector; empty when the
            executor cannot be built from the node data
        """
        try:
            # without a variable pool the executor only collects selectors
            http_executor = HttpExecutor(node_data=node_data)

            return {
                variable_selector.variable: variable_selector.value_selector
                for variable_selector in http_executor.variable_selectors
            }
        except Exception as e:
            logging.exception(f"Failed to extract variable selector to variable mapping: {e}")
            return {}

    def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]:
        """
        Store an image response as a tool file and return it as a FileVar.

        :param url: request url, its basename becomes the file name
        :param response: executor response
        :return: a single-element list for image responses, otherwise []
        """
        mimetype, file_binary = response.extract_file()
        # if not image, return directly (also covers the empty mimetype)
        if 'image' not in mimetype:
            return []

        # extract filename from url
        filename = path.basename(url)
        # extract extension if possible
        extension = guess_extension(mimetype) or '.bin'

        tool_file = ToolFileManager.create_file_by_raw(
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
            file_binary=file_binary,
            mimetype=mimetype,
        )

        return [FileVar(
            tenant_id=self.tenant_id,
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.TOOL_FILE,
            related_id=tool_file.id,
            filename=filename,
            extension=extension,
            mime_type=mimetype,
        )]

View file

@ -0,0 +1,26 @@
from typing import Literal, Optional
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
class IfElseNodeData(BaseNodeData):
    """
    If Else Node Data.
    """
    class Condition(BaseModel):
        """
        Condition entity
        """
        # selector of the variable whose value is tested
        variable_selector: list[str]
        comparison_operator: Literal[
            # for string or array
            "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty",
            # for number
            "=", "≠", ">", "<", "≥", "≤", "null", "not null"
        ]
        # expected value; not used by the empty / not empty / null / not null operators
        value: Optional[str] = None

    # how the individual condition results are combined
    logical_operator: Literal["and", "or"] = "and"
    conditions: list[Condition]

View file

@ -0,0 +1,398 @@
from typing import Optional, cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.if_else.entities import IfElseNodeData
from models.workflow import WorkflowNodeExecutionStatus
class IfElseNode(BaseNode):
    """
    Branching node: evaluates its conditions and selects the "true" or
    "false" outgoing edge.
    """
    _node_data_cls = IfElseNodeData
    node_type = NodeType.IF_ELSE

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node
        :param variable_pool: variable pool
        :return: SUCCEEDED result whose edge_source_handle is "true" or
            "false", or a FAILED result when a condition cannot be evaluated
        """
        node_data = cast(self._node_data_cls, self.node_data)

        node_inputs = {
            "conditions": []
        }

        process_datas = {
            "condition_results": []
        }

        # operators comparing the actual value with the expected value
        binary_asserts = {
            "contains": self._assert_contains,
            "not contains": self._assert_not_contains,
            "start with": self._assert_start_with,
            "end with": self._assert_end_with,
            "is": self._assert_is,
            "is not": self._assert_is_not,
            "=": self._assert_equal,
            "≠": self._assert_not_equal,
            ">": self._assert_greater_than,
            "<": self._assert_less_than,
            "≥": self._assert_greater_than_or_equal,
            "≤": self._assert_less_than_or_equal,
        }
        # operators looking at the actual value only
        unary_asserts = {
            "empty": self._assert_empty,
            "not empty": self._assert_not_empty,
            "null": self._assert_null,
            "not null": self._assert_not_null,
        }

        try:
            logical_operator = node_data.logical_operator
            input_conditions = []
            for condition in node_data.conditions:
                actual_value = variable_pool.get_variable_value(
                    variable_selector=condition.variable_selector
                )

                input_conditions.append({
                    "actual_value": actual_value,
                    "expected_value": condition.value,
                    "comparison_operator": condition.comparison_operator
                })

            node_inputs["conditions"] = input_conditions

            for input_condition in input_conditions:
                actual_value = input_condition["actual_value"]
                expected_value = input_condition["expected_value"]
                comparison_operator = input_condition["comparison_operator"]

                if comparison_operator in binary_asserts:
                    compare_result = binary_asserts[comparison_operator](actual_value, expected_value)
                elif comparison_operator in unary_asserts:
                    compare_result = unary_asserts[comparison_operator](actual_value)
                else:
                    # unknown operators are ignored
                    continue

                process_datas["condition_results"].append({
                    **input_condition,
                    "result": compare_result
                })
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=node_inputs,
                process_data=process_datas,
                error=str(e)
            )

        results = [condition["result"] for condition in process_datas["condition_results"]]
        if logical_operator == "and":
            compare_result = all(results)
        else:
            compare_result = any(results)

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=node_inputs,
            process_data=process_datas,
            edge_source_handle="false" if not compare_result else "true",
            outputs={
                "result": compare_result
            }
        )

    def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
        """
        Assert contains
        :param actual_value: actual value
        :param expected_value: expected value
        :return: False for a falsy actual value, else whether it contains expected_value
        """
        if not actual_value:
            return False

        if not isinstance(actual_value, str | list):
            raise ValueError('Invalid actual value type: string or array')

        return expected_value in actual_value

    def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
        """
        Assert not contains
        :param actual_value: actual value
        :param expected_value: expected value
        :return: True for a falsy actual value, else whether it lacks expected_value
        """
        if not actual_value:
            return True

        if not isinstance(actual_value, str | list):
            raise ValueError('Invalid actual value type: string or array')

        return expected_value not in actual_value

    def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
        """
        Assert start with
        :param actual_value: actual value
        :param expected_value: expected value
        :return: False for a falsy actual value, else whether it starts with expected_value
        """
        if not actual_value:
            return False

        if not isinstance(actual_value, str):
            raise ValueError('Invalid actual value type: string')

        return actual_value.startswith(expected_value)

    def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
        """
        Assert end with
        :param actual_value: actual value
        :param expected_value: expected value
        :return: False for a falsy actual value, else whether it ends with expected_value
        """
        if not actual_value:
            return False

        if not isinstance(actual_value, str):
            raise ValueError('Invalid actual value type: string')

        return actual_value.endswith(expected_value)

    def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
        """
        Assert is
        :param actual_value: actual value
        :param expected_value: expected value
        :return: False for None, else whether both strings are equal
        """
        if actual_value is None:
            return False

        if not isinstance(actual_value, str):
            raise ValueError('Invalid actual value type: string')

        return actual_value == expected_value

    def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
        """
        Assert is not
        :param actual_value: actual value
        :param expected_value: expected value
        :return: False for None, else whether the strings differ
        """
        if actual_value is None:
            return False

        if not isinstance(actual_value, str):
            raise ValueError('Invalid actual value type: string')

        return actual_value != expected_value

    def _assert_empty(self, actual_value: Optional[str]) -> bool:
        """
        Assert empty
        :param actual_value: actual value
        :return: whether the actual value is falsy
        """
        return not actual_value

    def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
        """
        Assert not empty
        :param actual_value: actual value
        :return: whether the actual value is truthy
        """
        return bool(actual_value)

    def _coerce_expected_number(self, actual_value: int | float, expected_value: str) -> int | float:
        """
        Convert the expected value to the numeric type of the actual value.
        :param actual_value: actual value, must be int or float
        :param expected_value: expected value as text
        :return: int(expected_value) for an int actual value, else float(expected_value)
        """
        if not isinstance(actual_value, int | float):
            raise ValueError('Invalid actual value type: number')

        if isinstance(actual_value, int):
            return int(expected_value)
        return float(expected_value)

    def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
        """
        Assert equal
        :param actual_value: actual value
        :param expected_value: expected value
        :return: False for None, else actual == expected
        """
        if actual_value is None:
            return False

        return actual_value == self._coerce_expected_number(actual_value, expected_value)

    def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
        """
        Assert not equal
        :param actual_value: actual value
        :param expected_value: expected value
        :return: False for None, else actual != expected
        """
        if actual_value is None:
            return False

        return actual_value != self._coerce_expected_number(actual_value, expected_value)

    def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
        """
        Assert greater than
        :param actual_value: actual value
        :param expected_value: expected value
        :return: False for None, else actual > expected
        """
        if actual_value is None:
            return False

        return actual_value > self._coerce_expected_number(actual_value, expected_value)

    def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
        """
        Assert less than
        :param actual_value: actual value
        :param expected_value: expected value
        :return: False for None, else actual < expected
        """
        if actual_value is None:
            return False

        return actual_value < self._coerce_expected_number(actual_value, expected_value)

    def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
        """
        Assert greater than or equal
        :param actual_value: actual value
        :param expected_value: expected value
        :return: False for None, else actual >= expected
        """
        if actual_value is None:
            return False

        return actual_value >= self._coerce_expected_number(actual_value, expected_value)

    def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
        """
        Assert less than or equal
        :param actual_value: actual value
        :param expected_value: expected value
        :return: False for None, else actual <= expected
        """
        if actual_value is None:
            return False

        return actual_value <= self._coerce_expected_number(actual_value, expected_value)

    def _assert_null(self, actual_value: Optional[int | float]) -> bool:
        """
        Assert null
        :param actual_value: actual value
        :return: whether the actual value is None
        """
        return actual_value is None

    def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
        """
        Assert not null
        :param actual_value: actual value
        :return: whether the actual value is not None
        """
        return actual_value is not None

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return: always an empty mapping
        """
        return {}

View file

@ -0,0 +1,51 @@
from typing import Any, Literal, Optional
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
class RerankingModelConfig(BaseModel):
    """
    Reranking Model Config.

    Identifies a reranking model by provider and model name.
    """
    provider: str
    model: str
class MultipleRetrievalConfig(BaseModel):
    """
    Multiple Retrieval Config.

    Settings attached to a knowledge retrieval node as
    ``multiple_retrieval_config``.
    """
    top_k: int
    # declared Optional but without a default value
    score_threshold: Optional[float]
    reranking_model: RerankingModelConfig
class ModelConfig(BaseModel):
    """
    Model Config.

    Describes the LLM referenced by a single retrieval config.
    """
    # provider and name are used to look up the model instance
    provider: str
    name: str
    # model mode; an empty value is rejected when the model config is fetched
    mode: str
    # completion parameters; may carry a 'stop' entry
    completion_params: dict[str, Any] = {}
class SingleRetrievalConfig(BaseModel):
    """
    Single Retrieval Config.

    Wraps the model config read when fetching the model for single retrieval.
    """
    model: ModelConfig
class KnowledgeRetrievalNodeData(BaseNodeData):
    """
    Knowledge retrieval Node Data.
    """
    type: str = 'knowledge-retrieval'
    # selector of the variable holding the query text
    query_variable_selector: list[str]
    # ids of the datasets to retrieve from
    dataset_ids: list[str]
    # selects between the single and the multiple retrieval path
    retrieval_mode: Literal['single', 'multiple']
    # both configs are declared Optional but without a default value
    multiple_retrieval_config: Optional[MultipleRetrievalConfig]
    single_retrieval_config: Optional[SingleRetrievalConfig]

View file

@ -0,0 +1,443 @@
import threading
from typing import Any, cast
from flask import Flask, current_app
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.datasource.retrieval_service import RetrievalService
from core.rerank.rerank import RerankRunner
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from core.workflow.nodes.knowledge_retrieval.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.workflow.nodes.knowledge_retrieval.multi_dataset_react_route import ReactMultiDatasetRouter
from extensions.ext_database import db
from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
from models.workflow import WorkflowNodeExecutionStatus
# Retrieval settings used for a dataset that defines no retrieval_model of its own.
default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {'reranking_provider_name': '', 'reranking_model_name': ''},
    'top_k': 2,
    'score_threshold_enabled': False,
}
class KnowledgeRetrievalNode(BaseNode):
_node_data_cls = KnowledgeRetrievalNodeData
node_type = NodeType.KNOWLEDGE_RETRIEVAL
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
    """
    Resolve the query variable and retrieve matching knowledge.

    :param variable_pool: variable pool holding the query value
    :return: SUCCEEDED with the retrieved sources under outputs['result'],
        FAILED when the query is empty or retrieval raises
    """
    node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)

    # extract variables
    query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
    inputs = {'query': query}

    if not query:
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            inputs=inputs,
            error="Query is required."
        )

    # retrieve knowledge
    try:
        retrieved = self._fetch_dataset_retriever(node_data=node_data, query=query)
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=inputs,
            process_data=None,
            outputs={'result': retrieved}
        )
    except Exception as exc:
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            inputs=inputs,
            error=str(exc)
        )
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[
    dict[str, Any]]:
    """
    Retrieve documents for the query and describe each hit segment.

    :param node_data: node data
    :param query: query
    :return: one dict per hit segment with 'metadata', 'title' and 'content',
        ordered like the retrieved documents; 'position' counts from 1
    """
    available_datasets = []
    dataset_ids = node_data.dataset_ids
    for dataset_id in dataset_ids:
        # get dataset from dataset id
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == dataset_id
        ).first()

        # pass if dataset does not exist for this tenant
        if not dataset:
            continue

        # pass if dataset has no available documents
        if dataset.available_document_count == 0:
            continue

        available_datasets.append(dataset)

    all_documents = []
    if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
        all_documents = self._single_retrieve(available_datasets, node_data, query)
    elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
        all_documents = self._multiple_retrieve(available_datasets, node_data, query)

    context_list = []
    if not all_documents:
        return context_list

    # score per index node id, only for documents carrying a truthy score
    document_score_list = {}
    for item in all_documents:
        if 'score' in item.metadata and item.metadata['score']:
            document_score_list[item.metadata['doc_id']] = item.metadata['score']

    index_node_ids = [document.metadata['doc_id'] for document in all_documents]
    segments = DocumentSegment.query.filter(
        DocumentSegment.dataset_id.in_(dataset_ids),
        DocumentSegment.completed_at.isnot(None),
        DocumentSegment.status == 'completed',
        DocumentSegment.enabled == True,
        DocumentSegment.index_node_id.in_(index_node_ids)
    ).all()
    if not segments:
        return context_list

    # restore the retrieval order, unknown ids sort last
    index_node_id_to_position = {
        index_node_id: position for position, index_node_id in enumerate(index_node_ids)
    }
    sorted_segments = sorted(
        segments,
        key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float('inf'))
    )

    # running 1-based position of the sources actually returned; kept outside
    # the loop so it advances from one source to the next
    resource_number = 1
    for segment in sorted_segments:
        dataset = Dataset.query.filter_by(
            id=segment.dataset_id
        ).first()
        document = Document.query.filter(Document.id == segment.document_id,
                                         Document.enabled == True,
                                         Document.archived == False,
                                         ).first()

        # segments whose dataset or enabled, non-archived document is missing are skipped
        if not (dataset and document):
            continue

        source = {
            'metadata': {
                '_source': 'knowledge',
                'position': resource_number,
                'dataset_id': dataset.id,
                'dataset_name': dataset.name,
                'document_id': document.id,
                'document_name': document.name,
                'document_data_source_type': document.data_source_type,
                'segment_id': segment.id,
                'retriever_from': 'workflow',
                'score': document_score_list.get(segment.index_node_id, None),
                'segment_hit_count': segment.hit_count,
                'segment_word_count': segment.word_count,
                'segment_position': segment.position,
                'segment_index_node_hash': segment.index_node_hash,
            },
            'title': document.name
        }
        if segment.answer:
            source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
        else:
            source['content'] = segment.content
        context_list.append(source)
        resource_number += 1

    return context_list
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
    """
    Extract variable selector to variable mapping
    :param node_data: node data
    :return: mapping with the single key 'query' pointing at the query selector
    """
    typed_node_data = cast(cls._node_data_cls, node_data)
    return {'query': typed_node_data.query_variable_selector}
def _single_retrieve(self, available_datasets, node_data, query):
    """
    Let the configured model pick one dataset, then retrieve from it.

    :param available_datasets: datasets offered to the router as tools
    :param node_data: node data carrying the single retrieval model config
    :param query: query
    :return: retrieval results; an empty list when the model schema is
        missing, no dataset is chosen or the chosen dataset is not found
    """
    tools = []
    for dataset in available_datasets:
        description = dataset.description
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        description = description.replace('\n', '').replace('\r', '')
        message_tool = PromptMessageTool(
            name=dataset.id,
            description=description,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            }
        )
        tools.append(message_tool)

    # fetch model config
    model_instance, model_config = self._fetch_model_config(node_data)

    # check model is support tool calling
    model_type_instance = model_config.provider_model_bundle.model_type_instance
    model_type_instance = cast(LargeLanguageModel, model_type_instance)

    # get model schema
    model_schema = model_type_instance.get_model_schema(
        model=model_config.model,
        credentials=model_config.credentials
    )

    if not model_schema:
        # return the same type as every other path
        return []

    # use the function-call router when the model supports tool calls,
    # otherwise fall back to the ReAct router
    planning_strategy = PlanningStrategy.REACT_ROUTER
    features = model_schema.features
    if features:
        if ModelFeature.TOOL_CALL in features \
                or ModelFeature.MULTI_TOOL_CALL in features:
            planning_strategy = PlanningStrategy.ROUTER

    dataset_id = None
    if planning_strategy == PlanningStrategy.REACT_ROUTER:
        react_multi_dataset_router = ReactMultiDatasetRouter()
        dataset_id = react_multi_dataset_router.invoke(query, tools, node_data, model_config, model_instance,
                                                       self.user_id, self.tenant_id)
    elif planning_strategy == PlanningStrategy.ROUTER:
        function_call_router = FunctionCallMultiDatasetRouter()
        dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)

    if dataset_id:
        # the id comes back from the model, so scope the lookup to this tenant
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == dataset_id
        ).first()
        if dataset:
            # get retrieval model config
            retrieval_model_config = dataset.retrieval_model \
                if dataset.retrieval_model else default_retrieval_model

            # get top k
            top_k = retrieval_model_config['top_k']
            # get retrieval method
            retrival_method = retrieval_model_config['search_method']
            # get reranking model
            reranking_model = retrieval_model_config['reranking_model'] \
                if retrieval_model_config['reranking_enable'] else None
            # get score threshold
            score_threshold = .0
            score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
            if score_threshold_enabled:
                score_threshold = retrieval_model_config.get("score_threshold")

            results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
                                                query=query,
                                                top_k=top_k, score_threshold=score_threshold,
                                                reranking_model=reranking_model)
            self._on_query(query, [dataset_id])
            if results:
                self._on_retrival_end(results)
            return results
    return []
def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
    ModelInstance, ModelConfigWithCredentialsEntity]:
    """
    Resolve the LLM selected in ``single_retrieval_config`` and build its config.

    :param node_data: node data carrying ``single_retrieval_config.model``
    :return: tuple of (model instance, model config with credentials)
    :raises ValueError: provider model or model schema is missing, or no mode is set
    :raises ProviderTokenNotInitError: model status is NO_CONFIGURE
    :raises ModelCurrentlyNotSupportError: model status is NO_PERMISSION
    :raises QuotaExceededError: model status is QUOTA_EXCEEDED
    """
    model_settings = node_data.single_retrieval_config.model
    model_name = model_settings.name
    provider_name = model_settings.provider

    model_instance = ModelManager().get_model_instance(
        tenant_id=self.tenant_id,
        model_type=ModelType.LLM,
        provider=provider_name,
        model=model_name
    )

    provider_model_bundle = model_instance.provider_model_bundle
    model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
    model_credentials = model_instance.credentials

    # make sure the model exists and is usable for this tenant
    provider_model = provider_model_bundle.configuration.get_provider_model(
        model=model_name,
        model_type=ModelType.LLM
    )
    if provider_model is None:
        raise ValueError(f"Model {model_name} not exist.")

    status = provider_model.status
    if status == ModelStatus.NO_CONFIGURE:
        raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
    if status == ModelStatus.NO_PERMISSION:
        raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
    if status == ModelStatus.QUOTA_EXCEEDED:
        raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

    # 'stop' is taken out of the node's completion params IN PLACE and
    # returned separately; the same dict is reused as ``parameters`` below.
    completion_params = model_settings.completion_params
    stop = completion_params.pop('stop', [])

    model_mode = model_settings.mode
    if not model_mode:
        raise ValueError("LLM mode is required.")

    model_schema = model_type_instance.get_model_schema(
        model_name,
        model_credentials
    )
    if not model_schema:
        raise ValueError(f"Model {model_name} not exist.")

    model_config = ModelConfigWithCredentialsEntity(
        provider=provider_name,
        model=model_name,
        model_schema=model_schema,
        mode=model_mode,
        provider_model_bundle=provider_model_bundle,
        credentials=model_credentials,
        parameters=completion_params,
        stop=stop,
    )
    return model_instance, model_config
def _multiple_retrieve(self, available_datasets, node_data, query):
    """
    Query every available dataset in its own thread, then rerank the merged hits.

    :param available_datasets: datasets to search (each must expose ``id``)
    :param node_data: node data; ``multiple_retrieval_config`` supplies
        ``top_k``, ``score_threshold`` and the reranking model
    :param query: query text
    :return: whatever ``RerankRunner.run`` returns for the collected documents
    """
    collected = []
    workers = []
    dataset_ids = [dataset.id for dataset in available_datasets]
    retrieval_config = node_data.multiple_retrieval_config

    for dataset in available_datasets:
        # each worker appends its hits to the shared ``collected`` list
        worker = threading.Thread(
            target=self._retriever,
            kwargs=dict(
                flask_app=current_app._get_current_object(),
                dataset_id=dataset.id,
                query=query,
                top_k=retrieval_config.top_k,
                all_documents=collected,
            ),
        )
        workers.append(worker)
        worker.start()

    for worker in workers:
        worker.join()

    # rerank everything the workers found
    rerank_model_instance = ModelManager().get_model_instance(
        tenant_id=self.tenant_id,
        provider=retrieval_config.reranking_model.provider,
        model_type=ModelType.RERANK,
        model=retrieval_config.reranking_model.model
    )
    reranked = RerankRunner(rerank_model_instance).run(
        query,
        collected,
        retrieval_config.score_threshold,
        retrieval_config.top_k,
    )

    self._on_query(query, dataset_ids)
    if reranked:
        self._on_retrival_end(reranked)

    return reranked
def _on_retrival_end(self, documents: list[Document]) -> None:
    """Increment the hit count of the document segment behind each retrieved document."""
    for document in documents:
        metadata = document.metadata

        # always match on the index node id; narrow by dataset when it is known
        criteria = [DocumentSegment.index_node_id == metadata['doc_id']]
        if 'dataset_id' in metadata:
            criteria.append(DocumentSegment.dataset_id == metadata['dataset_id'])

        # bump the counter in SQL, without loading the segment
        db.session.query(DocumentSegment).filter(*criteria).update(
            {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
            synchronize_session=False
        )

        db.session.commit()
def _on_query(self, query: str, dataset_ids: list[str]) -> None:
    """
    Record the query once per dataset it was run against.

    Nothing is written when ``query`` is empty.

    :param query: query text
    :param dataset_ids: ids of the datasets that were searched
    """
    if not query:
        return

    db.session.add_all([
        DatasetQuery(
            dataset_id=dataset_id,
            content=query,
            source='app',
            source_app_id=self.app_id,
            created_by_role=self.user_from.value,
            created_by=self.user_id
        )
        for dataset_id in dataset_ids
    ])
    db.session.commit()
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
    """
    Thread target: search one dataset and append its hits to ``all_documents``.

    :param flask_app: Flask application whose app context is entered for the DB access
    :param dataset_id: id of the dataset to search (must belong to this tenant)
    :param query: query text
    :param top_k: maximum number of documents to retrieve
    :param all_documents: shared output list, extended in place
    :return: ``[]`` when the dataset cannot be found, otherwise ``None``
    """
    with flask_app.app_context():
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == dataset_id
        ).first()

        if not dataset:
            return []

        # fall back to the default retrieval settings when the dataset has none
        retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model

        documents = None
        if dataset.indexing_technique == "economy":
            # economy datasets are searched through the keyword table
            documents = RetrievalService.retrieve(
                retrival_method='keyword_search',
                dataset_id=dataset.id,
                query=query,
                top_k=top_k
            )
        elif top_k > 0:
            # read the score threshold keys with .get(), like the
            # single-retrieval path does, so an incomplete stored config
            # means "no threshold" instead of raising inside the thread
            score_threshold = retrieval_model.get('score_threshold') \
                if retrieval_model.get('score_threshold_enabled') else None
            reranking_model = retrieval_model['reranking_model'] \
                if retrieval_model['reranking_enable'] else None
            documents = RetrievalService.retrieve(
                retrival_method=retrieval_model['search_method'],
                dataset_id=dataset.id,
                query=query,
                top_k=top_k,
                score_threshold=score_threshold,
                reranking_model=reranking_model
            )

        # same guard for both branches: an empty/None result adds nothing
        if documents:
            all_documents.extend(documents)

View file

@ -0,0 +1,47 @@
from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
class FunctionCallMultiDatasetRouter:
    """Choose a dataset for a query by letting a tool-calling LLM pick one of the dataset tools."""

    def invoke(
            self,
            query: str,
            dataset_tools: list[PromptMessageTool],
            model_config: ModelConfigWithCredentialsEntity,
            model_instance: ModelInstance,
    ) -> Union[str, None]:
        """
        Decide which dataset tool should answer the query.

        :param query: user query
        :param dataset_tools: one tool per candidate dataset
        :param model_config: model config (not used by this router)
        :param model_instance: model used for the tool-call request
        :return: name of the chosen tool; ``None`` when there are no tools,
            the model made no tool call, or the invocation failed
        """
        # local import: keeps this block independent of the file's import list
        import logging

        if not dataset_tools:
            return None
        if len(dataset_tools) == 1:
            # nothing to choose from, skip the model call
            return dataset_tools[0].name

        try:
            prompt_messages = [
                SystemPromptMessage(content='You are a helpful AI assistant.'),
                UserPromptMessage(content=query)
            ]
            result = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                tools=dataset_tools,
                stream=False,
                model_parameters={
                    'temperature': 0.2,
                    'top_p': 0.3,
                    'max_tokens': 1500
                }
            )
            if result.message.tool_calls:
                # the first tool call names the selected dataset tool
                return result.message.tool_calls[0].function.name
            return None
        except Exception:
            # routing is best effort: keep returning None, but leave a trace
            logging.getLogger(__name__).exception("Function call dataset routing failed")
            return None

View file

@ -0,0 +1,254 @@
from collections.abc import Generator, Sequence
from typing import Optional, Union
from langchain import PromptTemplate
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.schema import AgentAction
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from core.workflow.nodes.llm.llm_node import LLMNode
# Prompt fragment describing the JSON-blob "action" format the model must answer in.
# `{tool_names}` is filled in with str.format() by the prompt builders below, which
# is why the literal braces of the JSON examples are doubled.
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{
"action": "Final Answer",
"action_input": "Final response to human"
}}
```"""
class ReactMultiDatasetRouter:
    """Choose a dataset for a query with a ReAct-style prompt (for models without tool calling)."""

    def invoke(
            self,
            query: str,
            dataset_tools: list[PromptMessageTool],
            node_data: KnowledgeRetrievalNodeData,
            model_config: ModelConfigWithCredentialsEntity,
            model_instance: ModelInstance,
            user_id: str,
            tenant_id: str,
    ) -> Union[str, None]:
        """
        Decide which dataset tool should answer the query.

        :param query: user query
        :param dataset_tools: one tool per candidate dataset
        :param node_data: node data (source of the completion params)
        :param model_config: model config; its ``mode`` selects the prompt style
        :param model_instance: model to invoke
        :param user_id: user id forwarded to the model call
        :param tenant_id: tenant id used for quota deduction
        :return: name of the chosen tool; ``None`` when there are no tools,
            the model chose no tool, or routing failed
        """
        # local import: keeps this block independent of the file's import list
        import logging

        if not dataset_tools:
            return None
        if len(dataset_tools) == 1:
            # nothing to choose from, skip the model call
            return dataset_tools[0].name

        try:
            return self._react_invoke(query=query, node_data=node_data, model_config=model_config,
                                      model_instance=model_instance,
                                      tools=dataset_tools, user_id=user_id, tenant_id=tenant_id)
        except Exception:
            # routing is best effort: keep returning None, but leave a trace
            logging.getLogger(__name__).exception("ReAct dataset routing failed")
            return None

    def _react_invoke(
            self,
            query: str,
            node_data: KnowledgeRetrievalNodeData,
            model_config: ModelConfigWithCredentialsEntity,
            model_instance: ModelInstance,
            tools: Sequence[PromptMessageTool],
            user_id: str,
            tenant_id: str,
            prefix: str = PREFIX,
            suffix: str = SUFFIX,
            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
            format_instructions: str = FORMAT_INSTRUCTIONS,
    ) -> Union[str, None]:
        """
        Build the ReAct prompt, call the model and parse the selected tool.

        :return: tool name when the parsed output is an ``AgentAction``, else ``None``
        """
        if model_config.mode == "chat":
            prompt = self.create_chat_prompt(
                query=query,
                tools=tools,
                prefix=prefix,
                suffix=suffix,
                human_message_template=human_message_template,
                format_instructions=format_instructions,
            )
        else:
            # NOTE(review): this branch never receives ``query`` (and the
            # transform below is called with query=''), so the template's
            # {input} placeholder does not appear to be filled here - confirm
            # that completion-mode routing actually sees the user query.
            prompt = self.create_completion_prompt(
                tools=tools,
                prefix=prefix,
                format_instructions=format_instructions,
                input_variables=None
            )

        # cut generation before the model invents an observation
        stop = ['Observation:']

        prompt_transform = AdvancedPromptTransform()
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt,
            inputs={},
            query='',
            files=[],
            context='',
            memory_config=None,
            memory=None,
            model_config=model_config
        )

        result_text, _ = self._invoke_llm(
            node_data=node_data,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            stop=stop,
            user_id=user_id,
            tenant_id=tenant_id
        )

        agent_decision = StructuredChatOutputParser().parse(result_text)
        if isinstance(agent_decision, AgentAction):
            return agent_decision.tool
        return None

    def _invoke_llm(self, node_data: KnowledgeRetrievalNodeData,
                    model_instance: ModelInstance,
                    prompt_messages: list[PromptMessage],
                    stop: list[str], user_id: str, tenant_id: str) -> tuple[str, LLMUsage]:
        """
        Invoke large language model

        :param node_data: node data
        :param model_instance: model instance
        :param prompt_messages: prompt messages
        :param stop: stop
        :param user_id: user id forwarded to the model call
        :param tenant_id: tenant id the quota is deducted from
        :return: full response text and usage
        """
        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=node_data.single_retrieval_config.model.completion_params,
            stop=stop,
            stream=True,
            user=user_id,
        )

        # drain the stream
        text, usage = self._handle_invoke_result(
            invoke_result=invoke_result
        )

        # deduct quota
        LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)

        return text, usage

    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
        """
        Handle invoke result

        :param invoke_result: stream of result chunks
        :return: concatenated chunk text and the first reported usage
            (an empty usage when no chunk reported one)
        """
        full_text = ''
        usage = None
        for result in invoke_result:
            full_text += result.delta.message.content

            # keep the first usage reported by the stream
            if not usage and result.delta.usage:
                usage = result.delta.usage

        if not usage:
            usage = LLMUsage.empty_usage()

        return full_text, usage

    def create_chat_prompt(
            self,
            query: str,
            tools: Sequence[PromptMessageTool],
            prefix: str = PREFIX,
            suffix: str = SUFFIX,
            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
            format_instructions: str = FORMAT_INSTRUCTIONS,
    ) -> list[ChatModelMessage]:
        """
        Build the chat-mode prompt: a system message with the tool list and
        format instructions, followed by the query as the user message.

        :param query: user query
        :param tools: dataset tools to describe to the model
        :param prefix: text placed before the tool list
        :param suffix: text placed after the format instructions
        :param human_message_template: accepted for signature compatibility; not used here
        :param format_instructions: template containing ``{tool_names}``
        :return: system and user chat messages
        """
        tool_strings = [
            f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}"
            for tool in tools
        ]
        formatted_tools = "\n".join(tool_strings)

        # de-duplicate while keeping the tools' order so the prompt is deterministic
        unique_tool_names = dict.fromkeys(tool.name for tool in tools)
        tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)

        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])

        return [
            ChatModelMessage(
                role=PromptMessageRole.SYSTEM,
                text=template
            ),
            ChatModelMessage(
                role=PromptMessageRole.USER,
                text=query
            ),
        ]

    def create_completion_prompt(
            self,
            tools: Sequence[PromptMessageTool],
            prefix: str = PREFIX,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[list[str]] = None,
    ) -> PromptTemplate:
        """Create prompt in the style of the zero shot agent.

        Args:
            tools: List of tools the agent will have access to, used to format the
                prompt.
            prefix: String to put before the list of tools.
            format_instructions: Template containing ``{tool_names}``.
            input_variables: List of input variables the final prompt will expect;
                defaults to ``["input", "agent_scratchpad"]``.

        Returns:
            A PromptTemplate with the template assembled from the pieces here.
        """
        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""

        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
        tool_names = ", ".join([tool.name for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return PromptTemplate(template=template, input_variables=input_variables)

View file

View file

@ -0,0 +1,49 @@
from typing import Any, Literal, Optional, Union
from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
class ModelConfig(BaseModel):
    """
    Model selection of an LLM node.
    """
    # model provider name
    provider: str
    # model name
    name: str
    # model mode; compared against "chat" when prompts are built
    mode: str
    # parameters handed to the model invocation
    completion_params: dict[str, Any] = {}
class ContextConfig(BaseModel):
    """
    Context settings of an LLM node.
    """
    # whether a context variable is injected at all
    enabled: bool
    # selector of the variable that holds the context
    variable_selector: Optional[list[str]] = None
class VisionConfig(BaseModel):
    """
    Vision settings of an LLM node.
    """

    class Configs(BaseModel):
        """
        Extra vision options.
        """
        # image detail level
        detail: Literal['low', 'high']

    # whether files are passed to the model
    enabled: bool
    # optional extra vision options
    configs: Optional[Configs] = None
class LLMNodeData(BaseNodeData):
    """
    Configuration of an LLM node.
    """
    # model to invoke
    model: ModelConfig
    # chat messages, or one completion prompt template
    prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
    # conversation memory settings; None disables memory
    memory: Optional[MemoryConfig] = None
    # context injection settings
    context: ContextConfig
    # vision (file input) settings
    vision: VisionConfig

View file

@ -0,0 +1,554 @@
from collections.abc import Generator
from typing import Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation
from models.provider import Provider, ProviderType
from models.workflow import WorkflowNodeExecutionStatus
class LLMNode(BaseNode):
    """
    Workflow node that builds a prompt from its inputs, invokes an LLM and
    returns the generated text together with the usage.
    """
    _node_data_cls = LLMNodeData
    node_type = NodeType.LLM

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node

        :param variable_pool: variable pool
        :return: SUCCEEDED result with ``text`` and ``usage`` outputs, or a
            FAILED result carrying the error message of any exception raised
        """
        node_data = cast(self._node_data_cls, self.node_data)

        # filled progressively so a failure reports whatever was gathered so far
        node_inputs = None
        process_data = None

        try:
            # fetch variables and fetch values from variable pool
            inputs = self._fetch_inputs(node_data, variable_pool)

            node_inputs = {}

            # fetch files
            files: list[FileVar] = self._fetch_files(node_data, variable_pool)
            if files:
                node_inputs['#files#'] = [file.to_dict() for file in files]

            # fetch context value
            context = self._fetch_context(node_data, variable_pool)
            if context:
                node_inputs['#context#'] = context

            # fetch model config
            model_instance, model_config = self._fetch_model_config(node_data.model)

            # fetch memory
            memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)

            # the system query is only passed along when memory is configured
            query = variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value]) \
                if node_data.memory else None

            # fetch prompt messages
            prompt_messages, stop = self._fetch_prompt_messages(
                node_data=node_data,
                query=query,
                inputs=inputs,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )

            process_data = {
                'model_mode': model_config.mode,
                'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
                    model_mode=model_config.mode,
                    prompt_messages=prompt_messages
                )
            }

            # handle invoke result
            result_text, usage = self._invoke_llm(
                node_data_model=node_data.model,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop
            )
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                inputs=node_inputs,
                process_data=process_data
            )

        outputs = {
            'text': result_text,
            'usage': jsonable_encoder(usage)
        }

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=node_inputs,
            process_data=process_data,
            outputs=outputs,
            metadata={
                NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
                NodeRunMetadataKey.CURRENCY: usage.currency
            }
        )

    def _invoke_llm(self, node_data_model: ModelConfig,
                    model_instance: ModelInstance,
                    prompt_messages: list[PromptMessage],
                    stop: list[str]) -> tuple[str, LLMUsage]:
        """
        Invoke large language model

        :param node_data_model: node data model
        :param model_instance: model instance
        :param prompt_messages: prompt messages
        :param stop: stop
        :return: full response text and usage
        """
        # the session is closed before the model call
        # NOTE(review): presumably so no DB connection is held while streaming - confirm
        db.session.close()

        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=node_data_model.completion_params,
            stop=stop,
            stream=True,
            user=self.user_id,
        )

        # handle invoke result
        text, usage = self._handle_invoke_result(
            invoke_result=invoke_result
        )

        # deduct quota
        self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)

        return text, usage

    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
        """
        Handle invoke result

        Every chunk is published as a text chunk of this node while the full
        text is accumulated.

        :param invoke_result: stream of result chunks
        :return: concatenated chunk text and the first reported usage
            (an empty usage when no chunk reported one)
        """
        full_text = ''
        usage = None
        for result in invoke_result:
            text = result.delta.message.content
            full_text += text

            self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])

            # keep the first usage reported by the stream
            if not usage and result.delta.usage:
                usage = result.delta.usage

        if not usage:
            usage = LLMUsage.empty_usage()

        return full_text, usage

    def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
        """
        Fetch inputs

        :param node_data: node data
        :param variable_pool: variable pool
        :return: value of every variable referenced by the prompt template
        :raises ValueError: a referenced variable has no value in the pool
        """
        inputs = {}
        prompt_template = node_data.prompt_template

        variable_selectors = []
        if isinstance(prompt_template, list):
            for prompt in prompt_template:
                variable_template_parser = VariableTemplateParser(template=prompt.text)
                variable_selectors.extend(variable_template_parser.extract_variable_selectors())
        elif isinstance(prompt_template, CompletionModelPromptTemplate):
            variable_template_parser = VariableTemplateParser(template=prompt_template.text)
            variable_selectors = variable_template_parser.extract_variable_selectors()

        for variable_selector in variable_selectors:
            variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
            if variable_value is None:
                raise ValueError(f'Variable {variable_selector.variable} not found')

            inputs[variable_selector.variable] = variable_value

        return inputs

    def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
        """
        Fetch files

        :param node_data: node data
        :param variable_pool: variable pool
        :return: the system files variable, or ``[]`` when vision is disabled
            or no files are set
        """
        if not node_data.vision.enabled:
            return []

        files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value])
        if not files:
            return []

        return files

    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]:
        """
        Fetch context

        A string value is returned as is. A list value is joined from each
        item's ``content``; the retriever resources found in the items are
        published to the callbacks.

        :param node_data: node data
        :param variable_pool: variable pool
        :return: context text, or ``None`` when context is disabled, has no
            selector, is empty, or is neither a string nor a list
        :raises ValueError: a list item has no ``content`` key
        """
        if not node_data.context.enabled:
            return None

        if not node_data.context.variable_selector:
            return None

        context_value = variable_pool.get_variable_value(node_data.context.variable_selector)
        if context_value:
            if isinstance(context_value, str):
                return context_value
            elif isinstance(context_value, list):
                context_str = ''
                original_retriever_resource = []
                for item in context_value:
                    if 'content' not in item:
                        raise ValueError(f'Invalid context structure: {item}')

                    context_str += item['content'] + '\n'

                    retriever_resource = self._convert_to_original_retriever_resource(item)
                    if retriever_resource:
                        original_retriever_resource.append(retriever_resource)

                if self.callbacks:
                    for callback in self.callbacks:
                        callback.on_event(
                            event=QueueRetrieverResourcesEvent(
                                retriever_resources=original_retriever_resource
                            )
                        )

                return context_str.strip()

        return None

    def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
        """
        Convert to original retriever resource, temp.

        :param context_dict: context dict
        :return: retriever resource dict when the item's metadata ``_source``
            is ``'knowledge'``, otherwise ``None``
        """
        if ('metadata' in context_dict and '_source' in context_dict['metadata']
                and context_dict['metadata']['_source'] == 'knowledge'):
            metadata = context_dict.get('metadata', {})
            source = {
                'position': metadata.get('position'),
                'dataset_id': metadata.get('dataset_id'),
                'dataset_name': metadata.get('dataset_name'),
                'document_id': metadata.get('document_id'),
                'document_name': metadata.get('document_name'),
                'data_source_type': metadata.get('document_data_source_type'),
                'segment_id': metadata.get('segment_id'),
                'retriever_from': metadata.get('retriever_from'),
                'score': metadata.get('score'),
                'hit_count': metadata.get('segment_hit_count'),
                'word_count': metadata.get('segment_word_count'),
                'segment_position': metadata.get('segment_position'),
                'index_node_hash': metadata.get('segment_index_node_hash'),
                'content': context_dict.get('content'),
            }

            return source

        return None

    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
        """
        Fetch model config

        :param node_data_model: node data model
        :return: tuple of (model instance, model config with credentials)
        :raises ValueError: provider model or model schema is missing, or no mode is set
        :raises ProviderTokenNotInitError: model status is NO_CONFIGURE
        :raises ModelCurrentlyNotSupportError: model status is NO_PERMISSION
        :raises QuotaExceededError: model status is QUOTA_EXCEEDED
        """
        model_name = node_data_model.name
        provider_name = node_data_model.provider

        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id,
            model_type=ModelType.LLM,
            provider=provider_name,
            model=model_name
        )

        provider_model_bundle = model_instance.provider_model_bundle
        model_type_instance = model_instance.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_credentials = model_instance.credentials

        # check model
        provider_model = provider_model_bundle.configuration.get_provider_model(
            model=model_name,
            model_type=ModelType.LLM
        )

        if provider_model is None:
            raise ValueError(f"Model {model_name} not exist.")

        if provider_model.status == ModelStatus.NO_CONFIGURE:
            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
        elif provider_model.status == ModelStatus.NO_PERMISSION:
            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

        # model config
        # 'stop' is removed from the node's completion params IN PLACE:
        # _invoke_llm later passes the same dict as model_parameters and the
        # stop words as a separate argument.
        completion_params = node_data_model.completion_params
        stop = []
        if 'stop' in completion_params:
            stop = completion_params['stop']
            del completion_params['stop']

        # get model mode
        model_mode = node_data_model.mode
        if not model_mode:
            raise ValueError("LLM mode is required.")

        model_schema = model_type_instance.get_model_schema(
            model_name,
            model_credentials
        )

        if not model_schema:
            raise ValueError(f"Model {model_name} not exist.")

        return model_instance, ModelConfigWithCredentialsEntity(
            provider=provider_name,
            model=model_name,
            model_schema=model_schema,
            mode=model_mode,
            provider_model_bundle=provider_model_bundle,
            credentials=model_credentials,
            parameters=completion_params,
            stop=stop,
        )

    def _fetch_memory(self, node_data_memory: Optional[MemoryConfig],
                      variable_pool: VariablePool,
                      model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
        """
        Fetch memory

        :param node_data_memory: node data memory
        :param variable_pool: variable pool
        :param model_instance: model instance handed to the memory
        :return: memory bound to the current conversation, or ``None`` when
            memory is not configured or the conversation cannot be found
        """
        if not node_data_memory:
            return None

        # get conversation id
        conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION.value])
        if conversation_id is None:
            return None

        # get conversation
        conversation = db.session.query(Conversation).filter(
            Conversation.app_id == self.app_id,
            Conversation.id == conversation_id
        ).first()

        if not conversation:
            return None

        memory = TokenBufferMemory(
            conversation=conversation,
            model_instance=model_instance
        )

        return memory

    def _fetch_prompt_messages(self, node_data: LLMNodeData,
                               query: Optional[str],
                               inputs: dict[str, str],
                               files: list[FileVar],
                               context: Optional[str],
                               memory: Optional[TokenBufferMemory],
                               model_config: ModelConfigWithCredentialsEntity) \
            -> tuple[list[PromptMessage], Optional[list[str]]]:
        """
        Fetch prompt messages

        :param node_data: node data
        :param query: query
        :param inputs: inputs
        :param files: files
        :param context: context
        :param memory: memory
        :param model_config: model config
        :return: prompt messages and the stop words taken from ``model_config``
        """
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=node_data.prompt_template,
            inputs=inputs,
            query=query if query else '',
            files=files,
            context=context,
            memory_config=node_data.memory,
            memory=memory,
            model_config=model_config
        )
        stop = model_config.stop

        return prompt_messages, stop

    @classmethod
    def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
        """
        Deduct LLM quota

        Only system (hosted) providers are charged; an unlimited quota
        (``quota_limit == -1``) is left untouched.

        :param tenant_id: tenant id
        :param model_instance: model instance
        :param usage: usage
        :return:
        """
        provider_model_bundle = model_instance.provider_model_bundle
        provider_configuration = provider_model_bundle.configuration

        if provider_configuration.using_provider_type != ProviderType.SYSTEM:
            return

        system_configuration = provider_configuration.system_configuration

        # find the unit of the quota currently in use
        quota_unit = None
        for quota_configuration in system_configuration.quota_configurations:
            if quota_configuration.quota_type == system_configuration.current_quota_type:
                quota_unit = quota_configuration.quota_unit

                if quota_configuration.quota_limit == -1:
                    return

                break

        used_quota = None
        if quota_unit:
            if quota_unit == QuotaUnit.TOKENS:
                used_quota = usage.total_tokens
            elif quota_unit == QuotaUnit.CREDITS:
                # credit-based quota: flat cost per call, higher for gpt-4 models
                used_quota = 20 if 'gpt-4' in model_instance.model else 1
            else:
                used_quota = 1

        if used_quota is not None:
            db.session.query(Provider).filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == model_instance.provider,
                Provider.provider_type == ProviderType.SYSTEM.value,
                Provider.quota_type == system_configuration.current_quota_type.value,
                Provider.quota_limit > Provider.quota_used
            ).update({'quota_used': Provider.quota_used + used_quota})
            db.session.commit()

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping

        :param node_data: node data
        :return: mapping of variable name to value selector for every variable
            referenced by the prompt template, plus ``#context#`` / ``#files#``
            when those features are enabled
        """
        node_data = cast(cls._node_data_cls, node_data)

        prompt_template = node_data.prompt_template

        variable_selectors = []
        if isinstance(prompt_template, list):
            for prompt in prompt_template:
                variable_template_parser = VariableTemplateParser(template=prompt.text)
                variable_selectors.extend(variable_template_parser.extract_variable_selectors())
        else:
            variable_template_parser = VariableTemplateParser(template=prompt_template.text)
            variable_selectors = variable_template_parser.extract_variable_selectors()

        variable_mapping = {}
        for variable_selector in variable_selectors:
            variable_mapping[variable_selector.variable] = variable_selector.value_selector

        # an enabled context without a selector is treated as "no context" by
        # _fetch_context; do the same here so no None selector is mapped
        if node_data.context.enabled and node_data.context.variable_selector:
            variable_mapping['#context#'] = node_data.context.variable_selector

        if node_data.vision.enabled:
            variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]

        return variable_mapping

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.

        :param filters: filter by node config parameters.
        :return: default prompt templates for chat and completion models
        """
        return {
            "type": "llm",
            "config": {
                "prompt_templates": {
                    "chat_model": {
                        "prompts": [
                            {
                                "role": "system",
                                "text": "You are a helpful AI assistant."
                            }
                        ]
                    },
                    "completion_model": {
                        "conversation_histories_role": {
                            "user_prefix": "Human",
                            "assistant_prefix": "Assistant"
                        },
                        "prompt": {
                            "text": "Here is the chat histories between human and assistant, inside "
                                    "<histories></histories> XML tags.\n\n<histories>\n{{"
                                    "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:"
                        },
                        "stop": ["Human:"]
                    }
                }
            }
        }

View file

@ -0,0 +1,36 @@
from typing import Any, Optional
from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
class ModelConfig(BaseModel):
    """
    Model selection of a question classifier node.
    """
    # model provider name
    provider: str
    # model name
    name: str
    # model mode
    mode: str
    # parameters handed to the model invocation
    completion_params: dict[str, Any] = {}
class ClassConfig(BaseModel):
    """
    One class (category) a question can be assigned to.
    """
    # class identifier
    id: str
    # class name
    name: str
class QuestionClassifierNodeData(BaseNodeData):
    """
    Question Classifier Node Data.
    """
    # selector of the variable holding the question to classify
    query_variable_selector: list[str]
    type: str = 'question-classifier'
    # model used for the classification
    model: ModelConfig
    # candidate classes
    classes: list[ClassConfig]
    # optional fields default to None explicitly, like LLMNodeData.memory,
    # so they may be omitted regardless of how bare Optional is treated
    instruction: Optional[str] = None
    memory: Optional[MemoryConfig] = None

View file

@ -0,0 +1,253 @@
import json
import logging
from typing import Optional, Union, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.question_classifier.template_prompts import (
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2,
QUESTION_CLASSIFIER_COMPLETION_PROMPT,
QUESTION_CLASSIFIER_SYSTEM_PROMPT,
QUESTION_CLASSIFIER_USER_PROMPT_1,
QUESTION_CLASSIFIER_USER_PROMPT_2,
QUESTION_CLASSIFIER_USER_PROMPT_3,
)
from models.workflow import WorkflowNodeExecutionStatus
class QuestionClassifierNode(LLMNode):
    """
    Question classifier node.

    Asks the configured LLM to assign one of the node's classes to the query and
    reports the chosen class through ``outputs['class_name']`` and through
    ``edge_source_handle`` (the id of the matching class).
    """
    _node_data_cls = QuestionClassifierNodeData
    # NOTE(review): TemplateTransformNode, ToolNode and VariableAssignerNode declare
    # `_node_type` rather than `node_type`; confirm which attribute the base class reads.
    node_type = NodeType.QUESTION_CLASSIFIER

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node
        :param variable_pool: variable pool
        :return: SUCCEEDED result carrying the chosen class, or a FAILED result
            when a ValueError is raised while the result is assembled
        """
        node_data = cast(QuestionClassifierNodeData, self.node_data)

        # extract variables
        query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
        variables = {
            'query': query
        }

        # fetch model config
        model_instance, model_config = self._fetch_model_config(node_data.model)

        # fetch memory
        memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)

        # fetch prompt messages
        prompt_messages, stop = self._fetch_prompt(
            node_data=node_data,
            context='',
            query=query,
            memory=memory,
            model_config=model_config
        )

        # handle invoke result
        result_text, usage = self._invoke_llm(
            node_data_model=node_data.model,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            stop=stop
        )

        # Fall back to every configured class name when the model output cannot be parsed,
        # which makes the first configured class the effective default.
        categories = [_class.name for _class in node_data.classes]
        try:
            # NOTE: str.strip() treats its argument as a set of characters, not a prefix;
            # it removes any leading/trailing '`', 'J', 'S', 'O', 'N' and newline characters.
            result_text_json = json.loads(result_text.strip('```JSON\n'))
            categories = result_text_json.get('categories', [])
        except Exception:
            logging.error(f"Failed to parse result text: {result_text}")

        # The model may answer with an empty list (or a non-list / non-string entry).
        # Resolve the chosen class name once so every use below is guarded the same way;
        # indexing `categories[0]` unguarded raised an uncaught IndexError on an empty list.
        first_category = categories[0] if isinstance(categories, list) and categories else ''
        category_name = first_category if isinstance(first_category, str) else ''

        try:
            process_data = {
                'model_mode': model_config.mode,
                'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
                    model_mode=model_config.mode,
                    prompt_messages=prompt_messages
                ),
                'usage': jsonable_encoder(usage),
                'topics': category_name
            }
            outputs = {
                'class_name': category_name
            }
            classes_map = {class_.name: class_.id for class_ in node_data.classes}

            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=variables,
                process_data=process_data,
                outputs=outputs,
                # None when the answer does not match any configured class name
                edge_source_handle=classes_map.get(category_name)
            )
        except ValueError as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e)
            )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return: mapping with the single key 'query' pointing at the query variable selector
        """
        node_data = cast(cls._node_data_cls, node_data)
        variable_mapping = {'query': node_data.query_variable_selector}
        return variable_mapping

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.
        :param filters: filter by node config parameters.
        :return:
        """
        return {
            "type": "question-classifier",
            "config": {
                "instructions": ""
            }
        }

    def _fetch_prompt(self, node_data: QuestionClassifierNodeData,
                      query: str,
                      context: Optional[str],
                      memory: Optional[TokenBufferMemory],
                      model_config: ModelConfigWithCredentialsEntity) \
            -> tuple[list[PromptMessage], Optional[list[str]]]:
        """
        Fetch prompt
        :param node_data: node data
        :param query: text to classify
        :param context: context
        :param memory: memory
        :param model_config: model config
        :return: prompt messages and the stop words taken from the model config
        """
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        rest_token = self._calculate_rest_token(node_data, query, model_config, context)
        # the query and the chat history are rendered into the template text here,
        # which is why get_prompt below receives an empty query and no memory
        prompt_template = self._get_prompt_template(node_data, query, memory, rest_token)
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt_template,
            inputs={},
            query='',
            files=[],
            context=context,
            memory_config=node_data.memory,
            memory=None,
            model_config=model_config
        )
        stop = model_config.stop

        return prompt_messages, stop

    def _calculate_rest_token(self, node_data: QuestionClassifierNodeData, query: str,
                              model_config: ModelConfigWithCredentialsEntity,
                              context: Optional[str]) -> int:
        """
        Calculate how many tokens remain for chat history.
        :param node_data: node data
        :param query: text to classify
        :param model_config: model config
        :param context: context
        :return: context size minus max_tokens minus the tokens of the history-free prompt
            (never negative); 2000 when the model schema reports no context size
        """
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        # measure the prompt without history
        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt_template,
            inputs={},
            query='',
            files=[],
            context=context,
            memory_config=node_data.memory,
            memory=None,
            model_config=model_config
        )
        rest_tokens = 2000

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            model_type_instance = model_config.provider_model_bundle.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)

            curr_message_tokens = model_type_instance.get_num_tokens(
                model_config.model,
                model_config.credentials,
                prompt_messages
            )

            # look up the configured max_tokens under the rule's own name or its template name
            max_tokens = 0
            for parameter_rule in model_config.model_schema.parameter_rules:
                if (parameter_rule.name == 'max_tokens'
                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                    max_tokens = (model_config.parameters.get(parameter_rule.name)
                                  or model_config.parameters.get(parameter_rule.use_template)) or 0

            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
            rest_tokens = max(rest_tokens, 0)

        return rest_tokens

    def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str,
                             memory: Optional[TokenBufferMemory],
                             max_token_limit: int = 2000) \
            -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
        """
        Build the classification prompt template.
        :param node_data: node data
        :param query: text to classify
        :param memory: memory whose history text is embedded in the prompt, if any
        :param max_token_limit: token limit passed to the history lookup
        :return: chat messages (system prompt, two worked examples, the actual request)
            for chat models, a single completion template for completion models
        :raises ValueError: when the model mode is neither chat nor completion
        """
        model_mode = ModelMode.value_of(node_data.model.mode)
        classes = node_data.classes
        class_names = [class_.name for class_ in classes]
        class_names_str = ','.join(class_names)
        instruction = node_data.instruction if node_data.instruction else ''
        input_text = query
        memory_str = ''
        if memory:
            memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
                                                        message_limit=node_data.memory.window.size)
        prompt_messages = []
        if model_mode == ModelMode.CHAT:
            system_prompt_messages = ChatModelMessage(
                role=PromptMessageRole.SYSTEM,
                text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
            )
            prompt_messages.append(system_prompt_messages)
            user_prompt_message_1 = ChatModelMessage(
                role=PromptMessageRole.USER,
                text=QUESTION_CLASSIFIER_USER_PROMPT_1
            )
            prompt_messages.append(user_prompt_message_1)
            assistant_prompt_message_1 = ChatModelMessage(
                role=PromptMessageRole.ASSISTANT,
                text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
            )
            prompt_messages.append(assistant_prompt_message_1)
            user_prompt_message_2 = ChatModelMessage(
                role=PromptMessageRole.USER,
                text=QUESTION_CLASSIFIER_USER_PROMPT_2
            )
            prompt_messages.append(user_prompt_message_2)
            assistant_prompt_message_2 = ChatModelMessage(
                role=PromptMessageRole.ASSISTANT,
                text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
            )
            prompt_messages.append(assistant_prompt_message_2)
            user_prompt_message_3 = ChatModelMessage(
                role=PromptMessageRole.USER,
                text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text,
                                                              categories=class_names_str,
                                                              classification_instructions=instruction)
            )
            prompt_messages.append(user_prompt_message_3)
            return prompt_messages
        elif model_mode == ModelMode.COMPLETION:
            return CompletionModelPromptTemplate(
                text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str,
                                                                  input_text=input_text,
                                                                  categories=class_names_str,
                                                                  classification_instructions=instruction)
            )

        else:
            raise ValueError(f"Model mode {model_mode} not support.")

View file

@ -0,0 +1,72 @@
# Prompt templates for the question classifier node.
#
# Templates that are rendered with str.format() escape literal braces as `{{` / `}}`;
# the few-shot example messages (USER/ASSISTANT _1 and _2) are sent verbatim.
# In chat mode each user message ends by opening a ```JSON fence and each assistant
# message ends by closing it, so the model is primed to answer with bare JSON plus
# a closing fence.

# System message (chat mode). Placeholder: {histories}.
QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
### Job Description
You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.
### Task
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output.Additionally, you need to extract the key words from the text that are related to the classification.
### Format
The input text is in the variable input_text.Categories are specified as a comma-separated list in the variable categories or left empty for automatic determination.Classification instructions may be included to improve the classification accuracy.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Memory
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
"""

# Few-shot example 1: request.
QUESTION_CLASSIFIER_USER_PROMPT_1 = """
{ "input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."],
"categories": ["Customer Service, Satisfaction, Sales, Product"],
"classification_instructions": ["classify the text based on the feedback provided by customer"]}```JSON
"""

# Few-shot example 1: expected answer.
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """
{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],
"categories": ["Customer Service"]}```
"""

# Few-shot example 2: request.
QUESTION_CLASSIFIER_USER_PROMPT_2 = """
{"input_text": ["bad service, slow to bring the food"],
"categories": ["Food Quality, Experience, Price" ],
"classification_instructions": []}```JSON
"""

# Few-shot example 2: expected answer (must be valid JSON, it is the pattern the model copies).
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """
{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],
"categories": ["Experience"]}```
"""

# Actual request (chat mode).
# Placeholders: {input_text}, {categories}, {classification_instructions}.
QUESTION_CLASSIFIER_USER_PROMPT_3 = """
{{"input_text": ["{input_text}"],
"categories": ["{categories}" ],
"classification_instructions": ["{classification_instructions}"]}}```JSON
"""

# Single prompt for completion-mode models.
# Placeholders: {histories}, {input_text}, {categories}, {classification_instructions}.
QUESTION_CLASSIFIER_COMPLETION_PROMPT = """
### Job Description
You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.
### Task
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification.
### Format
The input text is in the variable input_text. Categories are specified as a comma-separated list in the variable categories or left empty for automatic determination. Classification instructions may be included to improve the classification accuracy.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."],"categories": ["Customer Service, Satisfaction, Sales, Product"], "classification_instructions": ["classify the text based on the feedback provided by customer"]}}
Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"categories": ["Customer Service"]}}
User:{{"input_text": ["bad service, slow to bring the food"],"categories": ["Food Quality, Experience, Price" ], "classification_instructions": []}}
Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"categories": ["Experience"]}}
</example>
### Memory
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
### User Input
{{"input_text" : ["{input_text}"], "categories" : ["{categories}"],"classification_instructions" : ["{classification_instructions}"]}}
### Assistant Output
"""

View file

@ -0,0 +1,9 @@
from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData
class StartNodeData(BaseNodeData):
    """
    Start Node Data
    """
    # input form variables; StartNode validates the user inputs against them
    variables: list[VariableEntity] = []

View file

@ -0,0 +1,84 @@
from typing import cast
from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus
class StartNode(BaseNode):
    """
    Start node.

    Validates the user inputs against the configured input variables and returns
    them, together with the system variables, as both inputs and outputs.
    """
    _node_data_cls = StartNodeData
    node_type = NodeType.START

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node
        :param variable_pool: variable pool
        :return: succeeded result whose inputs and outputs are the cleaned inputs
        """
        start_node_data = cast(self._node_data_cls, self.node_data)

        # validated user inputs, keyed by variable name
        node_inputs = self._get_cleaned_inputs(start_node_data.variables, variable_pool.user_inputs)

        # expose system variables under a 'sys.' prefix; the conversation variable is left out
        for system_variable, system_value in variable_pool.system_variables.items():
            if system_variable != SystemVariable.CONVERSATION:
                node_inputs['sys.' + system_variable.value] = system_value

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=node_inputs,
            outputs=node_inputs
        )

    def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict):
        """
        Validate user inputs against the variable configs.
        :param variables: configured input variables
        :param user_inputs: raw user inputs, may be None
        :return: dict of variable name to validated value (default or '' when not provided)
        :raises ValueError: when a required value is missing, a value is not a string,
            a select value is not among the options, or a value exceeds max_length
        """
        provided = {} if user_inputs is None else user_inputs
        cleaned = {}

        for config in variables:
            name = config.variable

            # missing or empty value: required variables fail, optional ones get their default
            if name not in provided or not provided[name]:
                if config.required:
                    raise ValueError(f"Input form variable {name} is required")
                cleaned[name] = "" if config.default is None else config.default
                continue

            # past the guard above the value is always truthy
            value = provided[name]
            if not isinstance(value, str):
                raise ValueError(f"{name} in input form must be a string")

            if config.type == VariableEntity.Type.SELECT:
                allowed = [] if config.options is None else config.options
                if value not in allowed:
                    raise ValueError(f"{name} in input form must be one of the following: {allowed}")
            elif config.max_length is not None:
                limit = config.max_length
                if len(value) > limit:
                    raise ValueError(f'{name} in input form must be less than {limit} characters')

            # drop NUL characters from the accepted value
            cleaned[name] = value.replace('\x00', '')

        return cleaned

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return: always an empty mapping
        """
        return {}

View file

@ -0,0 +1,12 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class TemplateTransformNodeData(BaseNodeData):
    """
    Template Transform Node Data.
    """
    # variables made available to the template, each resolved through its value selector
    variables: list[VariableSelector]
    # template text; TemplateTransformNode renders it with the 'jinja2' code executor language
    template: str

View file

@ -0,0 +1,91 @@
import os
from typing import Optional, cast
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from models.workflow import WorkflowNodeExecutionStatus
# Upper bound on the rendered output length, configurable via TEMPLATE_TRANSFORM_MAX_LENGTH
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000'))


class TemplateTransformNode(BaseNode):
    """
    Template transform node: renders a jinja2 template with values from the variable pool.
    """
    _node_data_cls = TemplateTransformNodeData
    _node_type = NodeType.TEMPLATE_TRANSFORM

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.
        :param filters: filter by node config parameters.
        :return:
        """
        default_variable = {
            "variable": "arg1",
            "value_selector": []
        }
        return {
            "type": "template-transform",
            "config": {
                "variables": [default_variable],
                "template": "{{ arg1 }}"
            }
        }

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node
        :param variable_pool: variable pool
        :return: succeeded result with the rendered text under 'output', or a failed
            result when rendering raises CodeExecutionException or the output is too long
        """
        node_data: TemplateTransformNodeData = cast(self._node_data_cls, self.node_data)

        # resolve every configured variable from the pool
        variables = {
            selector.variable: variable_pool.get_variable_value(
                variable_selector=selector.value_selector
            )
            for selector in node_data.variables
        }

        # render the template
        try:
            result = CodeExecutor.execute_code(
                language='jinja2',
                code=node_data.template,
                inputs=variables
            )
        except CodeExecutionException as e:
            return NodeRunResult(
                inputs=variables,
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e)
            )

        rendered = result['result']
        if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
            return NodeRunResult(
                inputs=variables,
                status=WorkflowNodeExecutionStatus.FAILED,
                error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters"
            )

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=variables,
            outputs={
                'output': rendered
            }
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return: mapping of each template variable name to its value selector
        """
        mapping = {}
        for selector in node_data.variables:
            mapping[selector.variable] = selector.value_selector
        return mapping

View file

View file

@ -0,0 +1,37 @@
from typing import Literal, Union
from pydantic import BaseModel, validator
from core.workflow.entities.base_node_data_entities import BaseNodeData
# Scalar value types accepted for a tool configuration entry or a constant tool input.
# NOTE(review): with `str` listed first, some pydantic versions coerce numbers and
# booleans to strings when validating this Union; confirm against the pinned version.
ToolParameterValue = Union[str, int, float, bool]


class ToolEntity(BaseModel):
    """
    Identifies a tool and carries its configuration.
    """
    provider_id: str
    # only builtin and api providers are accepted
    provider_type: Literal['builtin', 'api']
    provider_name: str  # redundancy
    tool_name: str
    tool_label: str  # redundancy
    # configuration values keyed by name
    tool_configurations: dict[str, ToolParameterValue]
class ToolNodeData(BaseNodeData, ToolEntity):
    """
    Tool Node Schema
    """

    class ToolInput(BaseModel):
        """
        One tool parameter input.

        `type` decides how `value` is interpreted:
        - 'mixed': `value` is a template string
        - 'variable': `value` is a variable selector (list)
        - 'constant': `value` is a plain string, int, float or bool
        """
        value: Union[ToolParameterValue, list[str]]
        type: Literal['mixed', 'variable', 'constant']

        @validator('type', pre=True, always=True)
        def check_type(cls, value, values):
            """
            Check that the already-validated `value` field matches the declared type.
            :param value: the raw `type` being validated
            :param values: fields validated so far; `value` is read from here
            :return: the unchanged type
            :raises ValueError: when `value` does not fit the declared type
            """
            typ = value
            value = values.get('value')
            if typ == 'mixed' and not isinstance(value, str):
                raise ValueError('value must be a string')
            elif typ == 'variable' and not isinstance(value, list):
                raise ValueError('value must be a list')
            # isinstance() against a typing.Union alias raises TypeError before Python 3.10,
            # so the accepted scalar types are spelled out as a tuple
            elif typ == 'constant' and not isinstance(value, (str, int, float, bool)):
                raise ValueError('value must be a string, int, float, or bool')
            return typ

    # tool inputs keyed by parameter name
    tool_parameters: dict[str, ToolInput]

View file

@ -0,0 +1,201 @@
from os import path
from typing import cast
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus
class ToolNode(BaseNode):
    """
    Tool Node

    Resolves the configured tool parameters from the variable pool, invokes the
    tool and returns its text and file output.
    """
    _node_data_cls = ToolNodeData
    _node_type = NodeType.TOOL

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run the tool node
        :param variable_pool: variable pool
        :return: succeeded result with 'text' and 'files' outputs, or a failed result
            when the tool runtime cannot be obtained or the invocation raises
        """
        node_data = cast(ToolNodeData, self.node_data)

        # tool info is attached to every result, successful or not
        result_metadata = {
            NodeRunMetadataKey.TOOL_INFO: {
                'provider_type': node_data.provider_type,
                'provider_id': node_data.provider_id
            }
        }

        # get parameters
        parameters = self._generate_parameters(variable_pool, node_data)

        # get tool runtime
        try:
            tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data)
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=parameters,
                metadata=result_metadata,
                error=f'Failed to get tool runtime: {str(e)}'
            )

        # invoke tool
        try:
            messages = ToolEngine.workflow_invoke(
                tool=tool_runtime,
                tool_parameters=parameters,
                user_id=self.user_id,
                workflow_id=self.workflow_id,
                workflow_tool_callback=DifyWorkflowCallbackHandler()
            )
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=parameters,
                metadata=result_metadata,
                error=f'Failed to invoke tool: {str(e)}',
            )

        # convert tool messages
        plain_text, files = self._convert_tool_messages(messages)

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={
                'text': plain_text,
                'files': files
            },
            metadata=result_metadata,
            inputs=parameters
        )

    def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict:
        """
        Generate parameters
        :param variable_pool: variable pool
        :param node_data: tool node data
        :return: resolved parameter values keyed by parameter name
        """
        resolved = {}
        for parameter_name, tool_input in node_data.tool_parameters.items():
            if tool_input.type == 'mixed':
                # template string with embedded variable references
                resolved[parameter_name] = self._format_variable_template(tool_input.value, variable_pool)
            elif tool_input.type == 'variable':
                # value is a variable selector
                resolved[parameter_name] = variable_pool.get_variable_value(tool_input.value)
            elif tool_input.type == 'constant':
                resolved[parameter_name] = tool_input.value

        return resolved

    def _format_variable_template(self, template: str, variable_pool: VariablePool) -> str:
        """
        Format variable template
        :param template: template text containing variable references
        :param variable_pool: variable pool
        :return: template with the references replaced by their pool values
        """
        parser = VariableTemplateParser(template)
        values = {}
        for selector in parser.extract_variable_selectors():
            values[selector.variable] = variable_pool.get_variable_value(selector.value_selector)

        return parser.format(values)

    def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
        """
        Convert ToolInvokeMessages into tuple[plain_text, files]
        """
        # transform message and handle file storage
        transformed = ToolFileMessageTransformer.transform_tool_invoke_messages(
            messages=messages,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
        )
        # extract files first, then plain text
        files = self._extract_tool_response_binary(transformed)
        plain_text = self._extract_tool_response_text(transformed)

        return plain_text, files

    def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]:
        """
        Extract tool response binary
        :param tool_response: tool messages
        :return: one FileVar per image, image link or blob message
        """
        message_types = ToolInvokeMessage.MessageType
        file_vars = []

        for response in tool_response:
            if response.type in (message_types.IMAGE_LINK, message_types.IMAGE):
                url = response.message
                ext = path.splitext(url)[1]
                mimetype = response.meta.get('mime_type', 'image/jpeg')
                filename = response.save_as or url.split('/')[-1]

                # the tool file id is the url's last path segment without its extension
                tool_file_id = url.split('/')[-1].split('.')[0]
                file_vars.append(FileVar(
                    tenant_id=self.tenant_id,
                    type=FileType.IMAGE,
                    transfer_method=FileTransferMethod.TOOL_FILE,
                    related_id=tool_file_id,
                    filename=filename,
                    extension=ext,
                    mime_type=mimetype,
                ))
            elif response.type == message_types.BLOB:
                # the tool file id is the message's last path segment without its extension
                tool_file_id = response.message.split('/')[-1].split('.')[0]
                blob_filename = response.save_as
                blob_extension = path.splitext(response.save_as)[1]
                blob_mimetype = response.meta.get('mime_type', 'application/octet-stream')
                file_vars.append(FileVar(
                    tenant_id=self.tenant_id,
                    type=FileType.IMAGE,
                    transfer_method=FileTransferMethod.TOOL_FILE,
                    related_id=tool_file_id,
                    filename=blob_filename,
                    extension=blob_extension,
                    mime_type=blob_mimetype,
                ))
            elif response.type == message_types.LINK:
                pass  # TODO:

        return file_vars

    def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str:
        """
        Extract tool response text
        :param tool_response: tool messages
        :return: text messages and 'Link: ...' lines, one per line; other messages are ignored
        """
        lines = []
        for message in tool_response:
            if message.type == ToolInvokeMessage.MessageType.TEXT:
                lines.append(f'{message.message}\n')
            elif message.type == ToolInvokeMessage.MessageType.LINK:
                lines.append(f'Link: {message.message}\n')

        return ''.join(lines)

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return: mapping of variable key to value selector for 'mixed' and 'variable' inputs
        """
        mapping = {}
        for parameter_name, tool_input in node_data.tool_parameters.items():
            if tool_input.type == 'mixed':
                # every reference inside the template contributes its own entry
                for selector in VariableTemplateParser(tool_input.value).extract_variable_selectors():
                    mapping[selector.variable] = selector.value_selector
            elif tool_input.type == 'variable':
                mapping[parameter_name] = tool_input.value
            # constants reference no variables

        return mapping

View file

@ -0,0 +1,12 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
class VariableAssignerNodeData(BaseNodeData):
    """
    Variable Assigner Node Data.
    """
    type: str = 'variable-assigner'
    # declared type of the output (not read by VariableAssignerNode._run)
    output_type: str
    # candidate variable selectors; the node outputs the first one whose value is not None
    variables: list[list[str]]

View file

@ -0,0 +1,41 @@
from typing import cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.variable_assigner.entities import VariableAssignerNodeData
from models.workflow import WorkflowNodeExecutionStatus
class VariableAssignerNode(BaseNode):
    """
    Variable assigner node: outputs the first configured variable that has a value.
    """
    _node_data_cls = VariableAssignerNodeData
    _node_type = NodeType.VARIABLE_ASSIGNER

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node
        :param variable_pool: variable pool
        :return: succeeded result; outputs and inputs stay empty when every variable is None
        """
        node_data: VariableAssignerNodeData = cast(self._node_data_cls, self.node_data)

        outputs = {}
        inputs = {}
        for selector in node_data.variables:
            value = variable_pool.get_variable_value(selector)
            if value is None:
                continue

            # first variable with a value wins; the input key drops the selector's first element
            outputs = {"output": value}
            inputs = {'.'.join(selector[1:]): value}
            break

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs=outputs,
            inputs=inputs
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return: always an empty mapping
        """
        return {}

View file

View file

@ -0,0 +1,58 @@
import re
from core.workflow.entities.variable_entities import VariableSelector
# Matches `{{#node_id.var1.var2#}}`; group 1 is the key including the surrounding `#`.
REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")


class VariableTemplateParser:
    """
    Rules:

    1. Template variables must be enclosed in `{{}}`.
    2. The template variable Key can only be: #node_id.var1.var2#.
    3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
    """

    def __init__(self, template: str):
        self.template = template
        self.variable_keys = self.extract()

    def extract(self) -> list:
        """
        Extract the distinct variable keys of the template.
        :return: keys (including the surrounding `#`) in order of first appearance
        """
        # Regular expression to match the template rules
        matches = re.findall(REGEX, self.template)

        # dict.fromkeys de-duplicates while keeping a deterministic, first-seen order
        return list(dict.fromkeys(match[0] for match in matches))

    def extract_variable_selectors(self) -> "list[VariableSelector]":
        """
        Turn every variable key into a VariableSelector.
        :return: selectors whose `variable` is the key and whose `value_selector`
            is the key without `#`, split on '.'
        """
        variable_selectors = []
        for variable_key in self.variable_keys:
            remove_hash = variable_key.replace('#', '')
            split_result = remove_hash.split('.')
            if len(split_result) < 2:
                continue

            variable_selectors.append(VariableSelector(
                variable=variable_key,
                value_selector=split_result
            ))

        return variable_selectors

    def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
        """
        Replace the template variables with their values.
        :param inputs: values keyed by variable key (including the surrounding `#`)
        :param remove_template_variables: when True, `{{#...#}}` markers inside the
            inserted values are reduced to `{#...#}`
        :return: formatted text with `<|...|>` sequences removed
        """
        def replacer(match):
            key = match.group(1)
            value = inputs.get(key, match.group(0))  # return original matched string if key not found

            # re.sub requires the replacement to be a string: a missing value renders as
            # empty text and any other non-string value is rendered through str()
            if value is None:
                value = ''
            elif not isinstance(value, str):
                value = str(value)

            if remove_template_variables:
                return VariableTemplateParser.remove_template_variables(value)
            return value

        prompt = re.sub(REGEX, replacer, self.template)
        return re.sub(r'<\|.*?\|>', '', prompt)

    @classmethod
    def remove_template_variables(cls, text: str):
        """
        Reduce every `{{#...#}}` marker in the text to `{#...#}`.
        """
        return re.sub(REGEX, r'{\1}', text)

View file

@ -0,0 +1,531 @@
import logging
import time
from typing import Optional
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.file.file_obj import FileVar
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base_node import BaseNode, UserFrom
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode
from extensions.ext_database import db
from models.workflow import (
Workflow,
WorkflowNodeExecutionStatus,
)
# Registry mapping each node type to the class implementing it. WorkflowEngineManager
# uses it to collect default configs and to look up the class of a node by its type.
node_classes = {
    NodeType.START: StartNode,
    NodeType.END: EndNode,
    NodeType.ANSWER: AnswerNode,
    NodeType.LLM: LLMNode,
    NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
    NodeType.IF_ELSE: IfElseNode,
    NodeType.CODE: CodeNode,
    NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
    NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
    NodeType.HTTP_REQUEST: HttpRequestNode,
    NodeType.TOOL: ToolNode,
    NodeType.VARIABLE_ASSIGNER: VariableAssignerNode,
}

# module-level logger
logger = logging.getLogger(__name__)
class WorkflowEngineManager:
def get_default_configs(self) -> list[dict]:
    """
    Get default block configs
    :return: the non-empty default configs of all registered node classes
    """
    candidate_configs = (node_class.get_default_config() for node_class in node_classes.values())
    # node classes without a default config are left out
    return [config for config in candidate_configs if config]
def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None) -> Optional[dict]:
    """
    Get default config of node.
    :param node_type: node type
    :param filters: filter by node config parameters.
    :return: the node class's default config, or None when the node type is not
        registered or its default config is empty
    """
    node_class = node_classes.get(node_type)
    if not node_class:
        return None

    # an empty default config is reported as None
    return node_class.get_default_config(filters=filters) or None
def run_workflow(self, workflow: Workflow,
                 user_id: str,
                 user_from: UserFrom,
                 user_inputs: dict,
                 system_inputs: Optional[dict] = None,
                 callbacks: Optional[list[BaseWorkflowCallback]] = None) -> None:
    """
    Run workflow

    Validates the workflow graph, then runs nodes one after another until no next
    node is found or an END node has run. The outcome is reported through the
    callbacks; nothing is returned.

    :param workflow: Workflow instance
    :param user_id: user id
    :param user_from: user from
    :param user_inputs: user variables inputs
    :param system_inputs: system inputs, like: query, files
    :param callbacks: workflow callbacks
    :return:
    :raises ValueError: when the workflow graph is missing or malformed (raised
        before any callback is invoked)
    """
    # fetch workflow graph
    graph = workflow.graph_dict
    if not graph:
        raise ValueError('workflow graph not found')

    if 'nodes' not in graph or 'edges' not in graph:
        raise ValueError('nodes or edges not found in workflow graph')

    if not isinstance(graph.get('nodes'), list):
        raise ValueError('nodes in workflow graph must be a list')

    if not isinstance(graph.get('edges'), list):
        raise ValueError('edges in workflow graph must be a list')

    # init workflow run
    if callbacks:
        for callback in callbacks:
            callback.on_workflow_run_started()

    # init workflow run state
    workflow_run_state = WorkflowRunState(
        workflow=workflow,
        start_at=time.perf_counter(),
        variable_pool=VariablePool(
            system_variables=system_inputs,
            user_inputs=user_inputs
        ),
        user_id=user_id,
        user_from=user_from
    )

    try:
        predecessor_node = None
        # becomes True once at least one node has been selected to run
        has_entry_node = False
        while True:
            # get next node, multiple target nodes in the future
            next_node = self._get_next_node(
                workflow_run_state=workflow_run_state,
                graph=graph,
                predecessor_node=predecessor_node,
                callbacks=callbacks
            )

            if not next_node:
                break

            # check is already ran: a node that already has a result is not run again,
            # it only becomes the predecessor for the next lookup
            if next_node.node_id in [node_and_result.node.node_id
                                     for node_and_result in workflow_run_state.workflow_nodes_and_results]:
                predecessor_node = next_node
                continue

            has_entry_node = True

            # max steps 30 reached
            # NOTE(review): the check fires only once more than 30 results are recorded,
            # so a 31st node still runs; confirm whether the limit is meant to be inclusive
            if len(workflow_run_state.workflow_nodes_and_results) > 30:
                raise ValueError('Max steps 30 reached.')

            # or max execution time 10min reached
            if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=600):
                raise ValueError('Max execution time 10min reached.')

            # run workflow, run multiple target nodes in the future
            self._run_workflow_node(
                workflow_run_state=workflow_run_state,
                node=next_node,
                predecessor_node=predecessor_node,
                callbacks=callbacks
            )

            # an END node terminates the run
            if next_node.node_type in [NodeType.END]:
                break

            predecessor_node = next_node

        # the loop ended without selecting a single node to run
        if not has_entry_node:
            self._workflow_run_failed(
                error='Start node not found in workflow graph.',
                callbacks=callbacks
            )
            return
    except GenerateTaskStoppedException as e:
        # the task was stopped: return without reporting success or failure
        return
    except Exception as e:
        # any other error (including the limits above) is reported as a failed run
        self._workflow_run_failed(
            error=str(e),
            callbacks=callbacks
        )
        return

    # workflow run success
    self._workflow_run_success(
        callbacks=callbacks
    )
def single_step_run_workflow_node(self, workflow: Workflow,
                                  node_id: str,
                                  user_id: str,
                                  user_inputs: dict) -> tuple[BaseNode, NodeRunResult]:
    """
    Run one node of a workflow on its own (single step run).

    The node is looked up by id in the workflow graph and instantiated with
    ``UserFrom.ACCOUNT`` as the user source. Its input variables are read from
    ``user_inputs`` and stored in a fresh variable pool; no other node is run.

    :param workflow: Workflow instance whose graph contains the node
    :param node_id: node id
    :param user_id: user id
    :param user_inputs: user inputs, keyed by the variable keys returned by the
        node class's ``extract_variable_selector_to_variable_mapping``
    :return: tuple of (node instance, node run result)
    :raises ValueError: if the graph, its nodes, the requested node id, or a
        node class for the node's type cannot be found
    :raises WorkflowNodeRunFailedError: if building the variable pool or
        running the node raises; the original exception is chained as the cause
    """
    # fetch node info from workflow graph
    graph = workflow.graph_dict
    if not graph:
        raise ValueError('workflow graph not found')

    nodes = graph.get('nodes')
    if not nodes:
        raise ValueError('nodes not found in workflow graph')

    # fetch node config from node id (first match wins)
    node_config = next((node for node in nodes if node.get('id') == node_id), None)
    if not node_config:
        raise ValueError('node id not found in workflow graph')

    # Get node class; fail with a clear message instead of calling None below
    node_type_value = node_config.get('data', {}).get('type')
    node_cls = node_classes.get(NodeType.value_of(node_type_value))
    if node_cls is None:
        raise ValueError(f'node class not found for node type {node_type_value}')

    # init node instance
    node_instance = node_cls(
        tenant_id=workflow.tenant_id,
        app_id=workflow.app_id,
        workflow_id=workflow.id,
        user_id=user_id,
        user_from=UserFrom.ACCOUNT,
        config=node_config
    )

    try:
        # init variable pool
        variable_pool = VariablePool(
            system_variables={},
            user_inputs={}
        )

        # variable selector to variable mapping;
        # node classes that do not implement the extraction need no inputs
        try:
            variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config)
        except NotImplementedError:
            variable_mapping = {}

        for variable_key, variable_selector in variable_mapping.items():
            if variable_key not in user_inputs:
                raise ValueError(f'Variable key {variable_key} not found in user inputs.')

            # the selector's first element is the source node id,
            # the remaining elements are the key path below that node
            variable_pool.append_variable(
                node_id=variable_selector[0],
                variable_key_list=variable_selector[1:],
                value=user_inputs.get(variable_key)
            )

        # run node
        node_run_result = node_instance.run(
            variable_pool=variable_pool
        )

        # convert special values (e.g. file variables) in the outputs
        node_run_result.outputs = self.handle_special_values(node_run_result.outputs)
    except Exception as e:
        # chain explicitly so the root cause stays attached to the wrapper
        raise WorkflowNodeRunFailedError(
            node_id=node_instance.node_id,
            node_type=node_instance.node_type,
            node_title=node_instance.node_data.title,
            error=str(e)
        ) from e

    return node_instance, node_run_result
def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None:
"""
Workflow run success
:param callbacks: workflow callbacks
:return:
"""
if callbacks:
for callback in callbacks:
callback.on_workflow_run_succeeded()
def _workflow_run_failed(self, error: str,
callbacks: list[BaseWorkflowCallback] = None) -> None:
"""
Workflow run failed
:param error: error message
:param callbacks: workflow callbacks
:return:
"""
if callbacks:
for callback in callbacks:
callback.on_workflow_run_failed(
error=error
)
def _get_next_node(self, workflow_run_state: WorkflowRunState,
graph: dict,
predecessor_node: Optional[BaseNode] = None,
callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]:
"""
Get next node
multiple target nodes in the future.
:param graph: workflow graph
:param predecessor_node: predecessor node
:param callbacks: workflow callbacks
:return:
"""
nodes = graph.get('nodes')
if not nodes:
return None
if not predecessor_node:
for node_config in nodes:
if node_config.get('data', {}).get('type', '') == NodeType.START.value:
return StartNode(
tenant_id=workflow_run_state.tenant_id,
app_id=workflow_run_state.app_id,
workflow_id=workflow_run_state.workflow_id,
user_id=workflow_run_state.user_id,
user_from=workflow_run_state.user_from,
config=node_config,
callbacks=callbacks
)
else:
edges = graph.get('edges')
source_node_id = predecessor_node.node_id
# fetch all outgoing edges from source node
outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id]
if not outgoing_edges:
return None
# fetch target node id from outgoing edges
outgoing_edge = None
source_handle = predecessor_node.node_run_result.edge_source_handle \
if predecessor_node.node_run_result else None
if source_handle:
for edge in outgoing_edges:
if edge.get('sourceHandle') and edge.get('sourceHandle') == source_handle:
outgoing_edge = edge
break
else:
outgoing_edge = outgoing_edges[0]
if not outgoing_edge:
return None
target_node_id = outgoing_edge.get('target')
# fetch target node from target node id
target_node_config = None
for node in nodes:
if node.get('id') == target_node_id:
target_node_config = node
break
if not target_node_config:
return None
# get next node
target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type')))
return target_node(
tenant_id=workflow_run_state.tenant_id,
app_id=workflow_run_state.app_id,
workflow_id=workflow_run_state.workflow_id,
user_id=workflow_run_state.user_id,
user_from=workflow_run_state.user_from,
config=target_node_config,
callbacks=callbacks
)
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
Check timeout
:param start_at: start time
:param max_execution_time: max execution time
:return:
"""
return time.perf_counter() - start_at > max_execution_time
def _run_workflow_node(self, workflow_run_state: WorkflowRunState,
                       node: BaseNode,
                       predecessor_node: Optional[BaseNode] = None,
                       callbacks: list[BaseWorkflowCallback] = None) -> None:
    """
    Run one node as part of a workflow run and record its outcome.

    Steps: notify callbacks that the node started, register the node in the
    run state, run it against the shared variable pool, then either report the
    failure and raise, or store the result, notify callbacks of the success,
    publish the node's outputs into the variable pool and add its token usage
    to the run total.

    :param workflow_run_state: workflow run state (variable pool, executed nodes, token total)
    :param node: node to run
    :param predecessor_node: node that ran right before this one, if any
    :param callbacks: workflow callbacks
    :return:
    :raises ValueError: if the node run result has status FAILED, which includes
        the node raising any exception while running
    """
    for workflow_callback in callbacks or ():
        workflow_callback.on_workflow_node_execute_started(
            node_id=node.node_id,
            node_type=node.node_type,
            node_data=node.node_data,
            # 1-based position of this node among the nodes run so far
            node_run_index=len(workflow_run_state.workflow_nodes_and_results) + 1,
            predecessor_node_id=predecessor_node.node_id if predecessor_node else None
        )

    # NOTE(review): presumably releases the DB session before the node runs — confirm
    db.session.close()

    # register the node before running it; the result is filled in only on success
    node_record = WorkflowNodeAndResult(
        node=node,
        result=None
    )
    workflow_run_state.workflow_nodes_and_results.append(node_record)

    try:
        # run node, result must have inputs, process_data, outputs, execution_metadata
        node_run_result = node.run(
            variable_pool=workflow_run_state.variable_pool
        )
    except GenerateTaskStoppedException:
        # a stop request is reported as a failed node run
        node_run_result = NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            error='Workflow stopped.'
        )
    except Exception as e:
        logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
        node_run_result = NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            error=str(e)
        )

    if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
        # node run failed: report it, then abort by raising
        for workflow_callback in callbacks or ():
            workflow_callback.on_workflow_node_execute_failed(
                node_id=node.node_id,
                node_type=node.node_type,
                node_data=node.node_data,
                error=node_run_result.error,
                inputs=node_run_result.inputs,
                outputs=node_run_result.outputs,
                process_data=node_run_result.process_data,
            )

        raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}")

    # node run success
    node_record.result = node_run_result

    for workflow_callback in callbacks or ():
        workflow_callback.on_workflow_node_execute_succeeded(
            node_id=node.node_id,
            node_type=node.node_type,
            node_data=node.node_data,
            inputs=node_run_result.inputs,
            process_data=node_run_result.process_data,
            outputs=node_run_result.outputs,
            execution_metadata=node_run_result.metadata
        )

    # append each output (and its nested dict entries) to the variable pool
    for output_key, output_value in (node_run_result.outputs or {}).items():
        self._append_variables_recursively(
            variable_pool=workflow_run_state.variable_pool,
            node_id=node.node_id,
            variable_key_list=[output_key],
            variable_value=output_value
        )

    # add this node's token usage, when reported, to the run total
    run_metadata = node_run_result.metadata
    if run_metadata:
        token_count = run_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)
        if token_count:
            workflow_run_state.total_tokens += int(token_count)

    db.session.close()
def _append_variables_recursively(self, variable_pool: VariablePool,
node_id: str,
variable_key_list: list[str],
variable_value: VariableValue):
"""
Append variables recursively
:param variable_pool: variable pool
:param node_id: node id
:param variable_key_list: variable key list
:param variable_value: variable value
:return:
"""
variable_pool.append_variable(
node_id=node_id,
variable_key_list=variable_key_list,
value=variable_value
)
# if variable_value is a dict, then recursively append variables
if isinstance(variable_value, dict):
for key, value in variable_value.items():
# construct new key list
new_key_list = variable_key_list + [key]
self._append_variables_recursively(
variable_pool=variable_pool,
node_id=node_id,
variable_key_list=new_key_list,
variable_value=value
)
@classmethod
def handle_special_values(cls, value: Optional[dict]) -> Optional[dict]:
    """
    Return a shallow copy of ``value`` with file variables turned into dicts.

    A ``FileVar`` stored directly under a key, or as an element of a list
    stored under a key, is replaced by its ``to_dict()`` result. Deeper
    nesting is left untouched, and the input itself is not modified.

    :param value: value
    :return: the converted copy, or None when ``value`` is None or empty
    """
    if not value:
        return None

    converted = value.copy()
    if not isinstance(converted, dict):
        return converted

    # only existing keys are reassigned, so the dict size is stable while iterating
    for key, item in converted.items():
        if isinstance(item, FileVar):
            converted[key] = item.to_dict()
        elif isinstance(item, list):
            converted[key] = [
                element.to_dict() if isinstance(element, FileVar) else element
                for element in item
            ]

    return converted