Merge branch 'zustand/io/migration' into cz/zustand/io/migration
This commit is contained in:
commit
0f23df1c8d
24 changed files with 857 additions and 532 deletions
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "langflow"
|
||||
version = "0.6.5a9"
|
||||
version = "0.6.5a12"
|
||||
description = "A Python package with a built-in web application"
|
||||
authors = ["Logspace <contact@logspace.ai>"]
|
||||
maintainers = [
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
"""Add unique constraints
|
||||
|
||||
Revision ID: b2fa308044b5
|
||||
Revises: 0b8757876a7c
|
||||
Create Date: 2024-01-26 13:31:14.797548
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b2fa308044b5'
|
||||
down_revision: Union[str, None] = '0b8757876a7c'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
op.drop_table('flowstyle')
|
||||
with op.batch_alter_table('flow', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('is_component', sa.Boolean(), nullable=True))
|
||||
batch_op.add_column(sa.Column('updated_at', sa.DateTime(), nullable=True))
|
||||
batch_op.add_column(sa.Column('folder', sqlmodel.sql.sqltypes.AutoString(), nullable=True))
|
||||
batch_op.add_column(sa.Column('user_id', sqlmodel.sql.sqltypes.GUID(), nullable=True))
|
||||
batch_op.create_index(batch_op.f('ix_flow_user_id'), ['user_id'], unique=False)
|
||||
batch_op.create_foreign_key(None, 'user', ['user_id'], ['id'])
|
||||
except Exception:
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
with op.batch_alter_table('flow', schema=None) as batch_op:
|
||||
batch_op.drop_constraint(None, type_='foreignkey')
|
||||
batch_op.drop_index(batch_op.f('ix_flow_user_id'))
|
||||
batch_op.drop_column('user_id')
|
||||
batch_op.drop_column('folder')
|
||||
batch_op.drop_column('updated_at')
|
||||
batch_op.drop_column('is_component')
|
||||
|
||||
op.create_table('flowstyle',
|
||||
sa.Column('color', sa.VARCHAR(), nullable=False),
|
||||
sa.Column('emoji', sa.VARCHAR(), nullable=False),
|
||||
sa.Column('flow_id', sa.CHAR(length=32), nullable=True),
|
||||
sa.Column('id', sa.CHAR(length=32), nullable=False),
|
||||
sa.ForeignKeyConstraint(['flow_id'], ['flow.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
"""New fixes
|
||||
|
||||
Revision ID: bc2f01c40e4a
|
||||
Revises: b2fa308044b5
|
||||
Create Date: 2024-01-26 13:34:14.496769
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'bc2f01c40e4a'
|
||||
down_revision: Union[str, None] = 'b2fa308044b5'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
with op.batch_alter_table('flow', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('is_component', sa.Boolean(), nullable=True))
|
||||
batch_op.add_column(sa.Column('updated_at', sa.DateTime(), nullable=True))
|
||||
batch_op.add_column(sa.Column('folder', sqlmodel.sql.sqltypes.AutoString(), nullable=True))
|
||||
batch_op.add_column(sa.Column('user_id', sqlmodel.sql.sqltypes.GUID(), nullable=True))
|
||||
batch_op.create_index(batch_op.f('ix_flow_user_id'), ['user_id'], unique=False)
|
||||
batch_op.create_foreign_key('flow_user_id_fkey'
|
||||
, 'user', ['user_id'], ['id'])
|
||||
except Exception:
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
with op.batch_alter_table('flow', schema=None) as batch_op:
|
||||
batch_op.drop_constraint('flow_user_id_fkey', type_='foreignkey')
|
||||
batch_op.drop_index(batch_op.f('ix_flow_user_id'))
|
||||
batch_op.drop_column('user_id')
|
||||
batch_op.drop_column('folder')
|
||||
batch_op.drop_column('updated_at')
|
||||
batch_op.drop_column('is_component')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -4,7 +4,6 @@ from io import BytesIO
|
|||
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from langflow.api.v1.schemas import UploadFileResponse
|
||||
from langflow.services.auth.utils import get_current_active_user
|
||||
from langflow.services.database.models.flow import Flow
|
||||
|
|
@ -42,7 +41,7 @@ async def upload_file(
|
|||
file_content = await file.read()
|
||||
file_name = file.filename or hashlib.sha256(file_content).hexdigest()
|
||||
folder = flow_id
|
||||
storage_service.save_file(flow_id=folder, file_name=file_name, data=file_content)
|
||||
await storage_service.save_file(flow_id=folder, file_name=file_name, data=file_content)
|
||||
return UploadFileResponse(flowId=flow_id, file_path=f"{folder}/{file_name}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
|
@ -61,7 +60,7 @@ async def download_file(file_name: str, flow_id: str, storage_service: StorageSe
|
|||
if not content_type:
|
||||
raise HTTPException(status_code=500, detail=f"Content type not found for extension {extension}")
|
||||
|
||||
file_content = storage_service.get_file(flow_id=flow_id, file_name=file_name)
|
||||
file_content = await storage_service.get_file(flow_id=flow_id, file_name=file_name)
|
||||
headers = {
|
||||
"Content-Disposition": f"attachment; filename={file_name} filename*=UTF-8''{file_name}",
|
||||
"Content-Type": "application/octet-stream",
|
||||
|
|
@ -87,7 +86,7 @@ async def download_image(file_name: str, flow_id: str, storage_service: StorageS
|
|||
elif not content_type.startswith("image"):
|
||||
raise HTTPException(status_code=500, detail=f"Content type {content_type} is not an image")
|
||||
|
||||
file_content = storage_service.get_file(flow_id=flow_id, file_name=file_name)
|
||||
file_content = await storage_service.get_file(flow_id=flow_id, file_name=file_name)
|
||||
return StreamingResponse(BytesIO(file_content), media_type=content_type)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
|
@ -98,7 +97,7 @@ async def list_files(
|
|||
flow_id: str = Depends(get_flow_id), storage_service: StorageService = Depends(get_storage_service)
|
||||
):
|
||||
try:
|
||||
files = storage_service.list_files(flow_id=flow_id)
|
||||
files = await storage_service.list_files(flow_id=flow_id)
|
||||
return {"files": files}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
|
@ -109,7 +108,7 @@ async def delete_file(
|
|||
file_name: str, flow_id: str = Depends(get_flow_id), storage_service: StorageService = Depends(get_storage_service)
|
||||
):
|
||||
try:
|
||||
storage_service.delete_file(flow_id=flow_id, file_name=file_name)
|
||||
await storage_service.delete_file(flow_id=flow_id, file_name=file_name)
|
||||
return {"message": f"File {file_name} deleted successfully"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from fastapi import Request, Response, APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlmodel import Session
|
||||
|
||||
|
|
@ -33,8 +33,8 @@ async def login_to_get_access_token(
|
|||
|
||||
if user:
|
||||
tokens = create_user_tokens(user_id=user.id, db=db, update_last_login=True)
|
||||
response.set_cookie("refresh_token_lf", tokens["refresh_token"], httponly=True, secure=True, samesite="strict")
|
||||
response.set_cookie("access_token_lf", tokens["access_token"], httponly=False, secure=True, samesite="strict")
|
||||
response.set_cookie("refresh_token_lf", tokens["refresh_token"], httponly=True)
|
||||
response.set_cookie("access_token_lf", tokens["access_token"], httponly=False)
|
||||
return tokens
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -50,7 +50,7 @@ async def auto_login(
|
|||
):
|
||||
if settings_service.auth_settings.AUTO_LOGIN:
|
||||
tokens = create_user_longterm_token(db)
|
||||
response.set_cookie("access_token_lf", tokens["access_token"], httponly=False, secure=True, samesite="strict")
|
||||
response.set_cookie("access_token_lf", tokens["access_token"], httponly=False)
|
||||
return tokens
|
||||
|
||||
raise HTTPException(
|
||||
|
|
@ -67,8 +67,8 @@ async def refresh_token(request: Request, response: Response):
|
|||
token = request.cookies.get("refresh_token_lf")
|
||||
if token:
|
||||
tokens = create_refresh_token(token)
|
||||
response.set_cookie("refresh_token_lf", tokens["refresh_token"], httponly=True, secure=True, samesite="strict")
|
||||
response.set_cookie("access_token_lf", tokens["access_token"], httponly=False, secure=True, samesite="strict")
|
||||
response.set_cookie("refresh_token_lf", tokens["refresh_token"], httponly=True)
|
||||
response.set_cookie("access_token_lf", tokens["access_token"], httponly=False)
|
||||
return tokens
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from langflow import CustomComponent
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from langchain.chains import ConversationChain
|
||||
from typing import Optional, Union, Callable
|
||||
from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain, Text
|
||||
|
||||
|
||||
class ConversationChainComponent(CustomComponent):
|
||||
|
|
@ -23,7 +25,17 @@ class ConversationChainComponent(CustomComponent):
|
|||
self,
|
||||
llm: BaseLanguageModel,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
) -> Union[Chain, Callable]:
|
||||
inputs: dict = {},
|
||||
) -> Union[Chain, Callable, Text]:
|
||||
if memory is None:
|
||||
return ConversationChain(llm=llm)
|
||||
return ConversationChain(llm=llm, memory=memory)
|
||||
chain = ConversationChain(llm=llm)
|
||||
chain = ConversationChain(llm=llm, memory=memory)
|
||||
result = chain.invoke(inputs)
|
||||
# result is an AIMessage which is a subclass of BaseMessage
|
||||
# We need to check if it is a string or a BaseMessage
|
||||
if hasattr(result, "content") and isinstance(result.content, str):
|
||||
return result.content
|
||||
elif isinstance(result, str):
|
||||
return result
|
||||
|
||||
return str(result)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from typing import Callable, Optional, Union
|
||||
|
||||
from langchain.chains import LLMChain
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import BaseLanguageModel, BaseMemory, BasePromptTemplate, Chain
|
||||
from langflow.field_typing import BaseLanguageModel, BaseMemory, BasePromptTemplate, Chain, Text
|
||||
|
||||
|
||||
class LLMChainComponent(CustomComponent):
|
||||
|
|
@ -22,5 +23,5 @@ class LLMChainComponent(CustomComponent):
|
|||
prompt: BasePromptTemplate,
|
||||
llm: BaseLanguageModel,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
) -> Union[Chain, Callable]:
|
||||
) -> Union[Chain, Callable, Text]:
|
||||
return LLMChain(prompt=prompt, llm=llm, memory=memory)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from langflow import CustomComponent
|
||||
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import Document
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import Text
|
||||
|
||||
|
||||
class PromptRunner(CustomComponent):
|
||||
|
|
@ -18,11 +19,15 @@ class PromptRunner(CustomComponent):
|
|||
"code": {"show": False},
|
||||
}
|
||||
|
||||
def build(self, llm: BaseLLM, prompt: PromptTemplate, inputs: dict = {}) -> Document:
|
||||
def build(self, llm: BaseLLM, prompt: PromptTemplate, inputs: dict = {}) -> Text:
|
||||
chain = prompt | llm
|
||||
# The input is an empty dict because the prompt is already filled
|
||||
result = chain.invoke(input=inputs)
|
||||
if hasattr(result, "content"):
|
||||
result = result.content
|
||||
result_message: BaseMessage = chain.invoke(input=inputs)
|
||||
if hasattr(result_message, "content"):
|
||||
result: str = result_message.content
|
||||
elif isinstance(result_message, str):
|
||||
result = result_message
|
||||
else:
|
||||
result = str(result_message)
|
||||
self.repr_value = result
|
||||
return Document(page_content=str(result))
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -5,13 +5,15 @@ from uuid import UUID
|
|||
import yaml
|
||||
from cachetools import TTLCache, cachedmethod
|
||||
from fastapi import HTTPException
|
||||
|
||||
from langflow.interface.custom.code_parser.utils import (
|
||||
extract_inner_type_from_generic_alias,
|
||||
extract_union_types_from_generic_alias,
|
||||
)
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_credential_service, get_db_service
|
||||
from langflow.services.deps import get_credential_service, get_db_service, get_storage_service
|
||||
from langflow.services.storage.service import StorageService
|
||||
from langflow.utils import validate
|
||||
|
||||
from .component import Component
|
||||
|
|
@ -34,6 +36,13 @@ class CustomComponent(Component):
|
|||
self.cache = TTLCache(maxsize=1024, ttl=60)
|
||||
super().__init__(**data)
|
||||
|
||||
def get_full_path(self, path: str) -> str:
|
||||
storage_svc: "StorageService" = get_storage_service()
|
||||
|
||||
flow_id = path.split("/")[0]
|
||||
file_name = path.split("/")[1]
|
||||
return storage_svc.build_full_path(flow_id, file_name)
|
||||
|
||||
def custom_repr(self):
|
||||
if self.repr_value == "":
|
||||
self.repr_value = self.status
|
||||
|
|
|
|||
|
|
@ -119,8 +119,8 @@ async def instantiate_based_on_type(class_object, base_type, node_type, params,
|
|||
|
||||
async def instantiate_custom_component(node_type, class_object, params, user_id):
|
||||
params_copy = params.copy()
|
||||
class_object: "CustomComponent" = eval_custom_component_code(params_copy.pop("code"))
|
||||
custom_component = class_object(user_id=user_id)
|
||||
class_object: Type["CustomComponent"] = eval_custom_component_code(params_copy.pop("code"))
|
||||
custom_component: "CustomComponent" = class_object(user_id=user_id)
|
||||
|
||||
if "retriever" in params_copy and hasattr(params_copy["retriever"], "as_retriever"):
|
||||
params_copy["retriever"] = params_copy["retriever"].as_retriever()
|
||||
|
|
|
|||
|
|
@ -1,31 +1,96 @@
|
|||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
from loguru import logger
|
||||
|
||||
from .service import StorageService
|
||||
|
||||
|
||||
class LocalStorageService(StorageService):
|
||||
def __init__(self, session_service, settings_service):
|
||||
super().__init__(session_service, settings_service)
|
||||
self.data_dir = settings_service.settings.CONFIG_DIR
|
||||
"""A service class for handling local storage operations."""
|
||||
|
||||
def __init__(self, session_service, settings_service):
|
||||
"""Initialize the local storage service with session and settings services."""
|
||||
super().__init__(session_service, settings_service)
|
||||
self.data_dir = Path(settings_service.settings.CONFIG_DIR)
|
||||
self.set_ready()
|
||||
|
||||
def save_file(self, flow_id: str, file_name: str, data: bytes):
|
||||
folder_path = Path(f"{self.data_dir}/{flow_id}")
|
||||
def build_full_path(self, flow_id: str, file_name: str) -> str:
|
||||
"""Build the full path of a file in the local storage."""
|
||||
return str(self.data_dir / flow_id / file_name)
|
||||
|
||||
async def save_file(self, flow_id: str, file_name: str, data: bytes):
|
||||
"""
|
||||
Save a file in the local storage.
|
||||
|
||||
:param flow_id: The identifier for the flow.
|
||||
:param file_name: The name of the file to be saved.
|
||||
:param data: The byte content of the file.
|
||||
:raises FileNotFoundError: If the specified flow does not exist.
|
||||
:raises IsADirectoryError: If the file name is a directory.
|
||||
:raises PermissionError: If there is no permission to write the file.
|
||||
"""
|
||||
folder_path = self.data_dir / flow_id
|
||||
folder_path.mkdir(parents=True, exist_ok=True)
|
||||
with open(folder_path / file_name, "wb") as f:
|
||||
f.write(data)
|
||||
file_path = folder_path / file_name
|
||||
|
||||
def get_file(self, flow_id: str, file_name: str) -> bytes:
|
||||
with open(f"{self.data_dir}/{flow_id}/{file_name}", "rb") as f:
|
||||
return f.read()
|
||||
try:
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(data)
|
||||
logger.info(f"File {file_name} saved successfully in flow {flow_id}.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving file {file_name} in flow {flow_id}: {e}")
|
||||
raise e
|
||||
|
||||
def list_files(self, flow_id: str):
|
||||
folder_path = Path(f"{self.data_dir}/{flow_id}")
|
||||
return [file.name for file in folder_path.iterdir() if file.is_file()]
|
||||
async def get_file(self, flow_id: str, file_name: str) -> bytes:
|
||||
"""
|
||||
Retrieve a file from the local storage.
|
||||
|
||||
def delete_file(self, flow_id: str, file_name: str):
|
||||
Path(f"{self.data_dir}/{flow_id}/{file_name}").unlink()
|
||||
:param flow_id: The identifier for the flow.
|
||||
:param file_name: The name of the file to be retrieved.
|
||||
:return: The byte content of the file.
|
||||
:raises FileNotFoundError: If the file does not exist.
|
||||
"""
|
||||
file_path = self.data_dir / flow_id / file_name
|
||||
if not file_path.exists():
|
||||
logger.warning(f"File {file_name} not found in flow {flow_id}.")
|
||||
raise FileNotFoundError(f"File {file_name} not found in flow {flow_id}")
|
||||
|
||||
async with aiofiles.open(file_path, "rb") as f:
|
||||
logger.info(f"File {file_name} retrieved successfully from flow {flow_id}.")
|
||||
return await f.read()
|
||||
|
||||
async def list_files(self, flow_id: str):
|
||||
"""
|
||||
List all files in a specified flow.
|
||||
|
||||
:param flow_id: The identifier for the flow.
|
||||
:return: A list of file names.
|
||||
:raises FileNotFoundError: If the flow directory does not exist.
|
||||
"""
|
||||
folder_path = self.data_dir / flow_id
|
||||
if not folder_path.exists() or not folder_path.is_dir():
|
||||
logger.warning(f"Flow {flow_id} directory does not exist.")
|
||||
raise FileNotFoundError(f"Flow {flow_id} directory does not exist.")
|
||||
|
||||
files = [file.name for file in folder_path.iterdir() if file.is_file()]
|
||||
logger.info(f"Listed {len(files)} files in flow {flow_id}.")
|
||||
return files
|
||||
|
||||
async def delete_file(self, flow_id: str, file_name: str):
|
||||
"""
|
||||
Delete a file from the local storage.
|
||||
|
||||
:param flow_id: The identifier for the flow.
|
||||
:param file_name: The name of the file to be deleted.
|
||||
"""
|
||||
file_path = self.data_dir / flow_id / file_name
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
logger.info(f"File {file_name} deleted successfully from flow {flow_id}.")
|
||||
else:
|
||||
logger.warning(f"Attempted to delete non-existent file {file_name} in flow {flow_id}.")
|
||||
|
||||
def teardown(self):
|
||||
pass
|
||||
"""Perform any cleanup operations when the service is being torn down."""
|
||||
pass # No specific teardown actions required for local storage at the moment.
|
||||
|
|
|
|||
|
|
@ -1,43 +1,89 @@
|
|||
import boto3
|
||||
from botocore.exceptions import ClientError, NoCredentialsError
|
||||
from loguru import logger
|
||||
|
||||
from .service import StorageService
|
||||
|
||||
|
||||
class S3StorageService(StorageService):
|
||||
def __init__(self, session_service, settings_service):
|
||||
"""A service class for handling operations with AWS S3 storage."""
|
||||
|
||||
async def __init__(self, session_service, settings_service):
|
||||
"""Initialize the S3 storage service with session and settings services."""
|
||||
super().__init__(session_service, settings_service)
|
||||
self.bucket = "langflow"
|
||||
self.s3_client = boto3.client("s3")
|
||||
self.set_ready()
|
||||
|
||||
def save_file(self, folder: str, file_name: str, data):
|
||||
async def save_file(self, folder: str, file_name: str, data):
|
||||
"""
|
||||
Save a file to the S3 bucket.
|
||||
|
||||
:param folder: The folder in the bucket to save the file.
|
||||
:param file_name: The name of the file to be saved.
|
||||
:param data: The byte content of the file.
|
||||
:raises Exception: If an error occurs during file saving.
|
||||
"""
|
||||
try:
|
||||
self.s3_client.put_object(Bucket=self.bucket, Key=f"{folder}/{file_name}", Body=data)
|
||||
logger.info(f"File {file_name} saved successfully in folder {folder}.")
|
||||
except NoCredentialsError:
|
||||
raise Exception("Credentials not available for AWS S3.")
|
||||
logger.error("Credentials not available for AWS S3.")
|
||||
raise
|
||||
except ClientError as e:
|
||||
raise Exception(f"An error occurred: {e}")
|
||||
logger.error(f"Error saving file {file_name} in folder {folder}: {e}")
|
||||
raise
|
||||
|
||||
def get_file(self, folder: str, file_name: str):
|
||||
async def get_file(self, folder: str, file_name: str):
|
||||
"""
|
||||
Retrieve a file from the S3 bucket.
|
||||
|
||||
:param folder: The folder in the bucket where the file is stored.
|
||||
:param file_name: The name of the file to be retrieved.
|
||||
:return: The byte content of the file.
|
||||
:raises Exception: If an error occurs during file retrieval.
|
||||
"""
|
||||
try:
|
||||
response = self.s3_client.get_object(Bucket=self.bucket, Key=f"{folder}/{file_name}")
|
||||
logger.info(f"File {file_name} retrieved successfully from folder {folder}.")
|
||||
return response["Body"].read()
|
||||
except ClientError as e:
|
||||
raise Exception(f"An error occurred: {e}")
|
||||
logger.error(f"Error retrieving file {file_name} from folder {folder}: {e}")
|
||||
raise
|
||||
|
||||
def list_files(self, folder: str):
|
||||
async def list_files(self, folder: str):
|
||||
"""
|
||||
List all files in a specified folder of the S3 bucket.
|
||||
|
||||
:param folder: The folder in the bucket to list files from.
|
||||
:return: A list of file names.
|
||||
:raises Exception: If an error occurs during file listing.
|
||||
"""
|
||||
try:
|
||||
response = self.s3_client.list_objects_v2(Bucket=self.bucket, Prefix=folder)
|
||||
return [item["Key"] for item in response.get("Contents", []) if "/" not in item["Key"][len(folder) :]]
|
||||
files = [item["Key"] for item in response.get("Contents", []) if "/" not in item["Key"][len(folder) :]]
|
||||
logger.info(f"{len(files)} files listed in folder {folder}.")
|
||||
return files
|
||||
except ClientError as e:
|
||||
raise Exception(f"An error occurred: {e}")
|
||||
logger.error(f"Error listing files in folder {folder}: {e}")
|
||||
raise
|
||||
|
||||
def delete_file(self, folder: str, file_name: str):
|
||||
async def delete_file(self, folder: str, file_name: str):
|
||||
"""
|
||||
Delete a file from the S3 bucket.
|
||||
|
||||
:param folder: The folder in the bucket where the file is stored.
|
||||
:param file_name: The name of the file to be deleted.
|
||||
:raises Exception: If an error occurs during file deletion.
|
||||
"""
|
||||
try:
|
||||
self.s3_client.delete_object(Bucket=self.bucket, Key=f"{folder}/{file_name}")
|
||||
logger.info(f"File {file_name} deleted successfully from folder {folder}.")
|
||||
except ClientError as e:
|
||||
raise Exception(f"An error occurred: {e}")
|
||||
logger.error(f"Error deleting file {file_name} from folder {folder}: {e}")
|
||||
raise
|
||||
|
||||
def teardown(self):
|
||||
async def teardown(self):
|
||||
"""Perform any cleanup operations when the service is being torn down."""
|
||||
# No specific teardown actions required for S3 storage at the moment.
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -16,23 +16,26 @@ class StorageService(Service):
|
|||
self.session_service = session_service
|
||||
self.set_ready()
|
||||
|
||||
def build_full_path(self, flow_id: str, file_name: str) -> str:
|
||||
pass
|
||||
|
||||
def set_ready(self):
|
||||
self.ready = True
|
||||
|
||||
@abstractmethod
|
||||
def save_file(self, flow_id: str, file_name: str, data) -> None:
|
||||
async def save_file(self, flow_id: str, file_name: str, data) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_file(self, flow_id: str, file_name: str) -> bytes:
|
||||
async def get_file(self, flow_id: str, file_name: str) -> bytes:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_files(self, flow_id: str) -> list[str]:
|
||||
async def list_files(self, flow_id: str) -> list[str]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_file(self, flow_id: str, file_name: str) -> bool:
|
||||
async def delete_file(self, flow_id: str, file_name: str) -> bool:
|
||||
pass
|
||||
|
||||
def teardown(self):
|
||||
|
|
|
|||
695
src/frontend/package-lock.json
generated
695
src/frontend/package-lock.json
generated
File diff suppressed because it is too large
Load diff
|
|
@ -83,6 +83,13 @@ export default function GenericNode({
|
|||
countHandles();
|
||||
}, [data, data.node]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!selected) {
|
||||
setInputName(false);
|
||||
setInputDescription(false);
|
||||
}
|
||||
}, [selected]);
|
||||
|
||||
// State for outline color
|
||||
const isBuilding = useFlowStore((state) => state.isBuilding);
|
||||
|
||||
|
|
@ -383,14 +390,14 @@ export default function GenericNode({
|
|||
)
|
||||
}
|
||||
>
|
||||
<div className="generic-node-status-position flex items-center">
|
||||
<div className="generic-node-status-position flex items-center justify-center">
|
||||
<IconComponent
|
||||
name="Zap"
|
||||
className={classNames(
|
||||
validationStatus && validationStatus.valid
|
||||
? "green-status"
|
||||
: "status-build-animation",
|
||||
"h-5 stroke-1"
|
||||
"absolute h-5 stroke-1"
|
||||
)}
|
||||
/>
|
||||
<IconComponent
|
||||
|
|
@ -399,7 +406,7 @@ export default function GenericNode({
|
|||
validationStatus && !validationStatus.valid
|
||||
? "red-status"
|
||||
: "status-build-animation",
|
||||
"h-5 stroke-1"
|
||||
"absolute h-5 stroke-1"
|
||||
)}
|
||||
/>
|
||||
<IconComponent
|
||||
|
|
@ -408,7 +415,7 @@ export default function GenericNode({
|
|||
!validationStatus || isBuilding
|
||||
? "yellow-status"
|
||||
: "status-build-animation",
|
||||
"h-5 stroke-1"
|
||||
"absolute h-5 stroke-1"
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -1,42 +1,54 @@
|
|||
import { cloneDeep } from "lodash";
|
||||
import useFlowStore from "../../stores/flowStore";
|
||||
import { IOInputProps } from "../../types/components";
|
||||
import IOFileInput from "../IOInputs/FileInput";
|
||||
import { Textarea } from "../ui/textarea";
|
||||
|
||||
export default function IOInputField({
|
||||
inputType,
|
||||
field,
|
||||
updateValue,
|
||||
inputId,
|
||||
}: IOInputProps): JSX.Element | undefined {
|
||||
const nodes = useFlowStore((state) => state.nodes);
|
||||
const setNode = useFlowStore((state) => state.setNode);
|
||||
const node = nodes.find((node) => node.id === inputId);
|
||||
function handleInputType() {
|
||||
console.log(inputType);
|
||||
|
||||
const handleUpdateValue = (e) => {
|
||||
updateValue(e, "text");
|
||||
};
|
||||
|
||||
if (!node) return "no node found";
|
||||
switch (inputType) {
|
||||
case "TextInput":
|
||||
return (
|
||||
<Textarea
|
||||
className="custom-scroll"
|
||||
className="h-full w-full custom-scroll"
|
||||
placeholder={"Enter text..."}
|
||||
value={field?.value}
|
||||
onChange={handleUpdateValue}
|
||||
value={node.data.node!.template["value"].value}
|
||||
onChange={(e) => {
|
||||
e.target.value;
|
||||
if (node) {
|
||||
let newNode = cloneDeep(node);
|
||||
newNode.data.node!.template["value"].value = e.target.value;
|
||||
setNode(node.id, newNode);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
);
|
||||
case "FileLoader":
|
||||
return <IOFileInput field={field} updateValue={updateValue} />;
|
||||
case "fileLoader":
|
||||
// return <IOFileInput />;
|
||||
|
||||
default:
|
||||
return (
|
||||
<Textarea
|
||||
className="custom-scroll"
|
||||
className="h-full w-full custom-scroll"
|
||||
placeholder={"Enter text..."}
|
||||
value={field?.value}
|
||||
onChange={handleUpdateValue}
|
||||
value={node.data.node!.template["value"]}
|
||||
onChange={(e) => {
|
||||
e.target.value;
|
||||
if (node) {
|
||||
let newNode = cloneDeep(node);
|
||||
newNode.data.node!.template["value"].value = e.target.value;
|
||||
setNode(node.id, newNode);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
}
|
||||
return <div className="h-full">{handleInputType()}</div>;
|
||||
return <div className="h-full w-full">{handleInputType()}</div>;
|
||||
}
|
||||
|
|
|
|||
47
src/frontend/src/components/IOOutputView/index.tsx
Normal file
47
src/frontend/src/components/IOOutputView/index.tsx
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
import { cloneDeep } from "lodash";
|
||||
import useFlowStore from "../../stores/flowStore";
|
||||
import { IOOutputProps } from "../../types/components";
|
||||
import { Textarea } from "../ui/textarea";
|
||||
|
||||
export default function IOOutputView({
|
||||
outputType,
|
||||
outputId,
|
||||
}: IOOutputProps): JSX.Element | undefined {
|
||||
const nodes = useFlowStore((state) => state.nodes);
|
||||
const setNode = useFlowStore((state) => state.setNode);
|
||||
const flowPool = useFlowStore((state) => state.flowPool);
|
||||
const node = nodes.find((node) => node.id === outputId);
|
||||
function handleOutputType() {
|
||||
if (!node) return "no node found";
|
||||
switch (outputType) {
|
||||
case "TextOutput":
|
||||
return (
|
||||
<Textarea
|
||||
className="h-full w-full custom-scroll"
|
||||
placeholder={"Enter text..."}
|
||||
// update to real value on flowPool
|
||||
value={flowPool[node.id][flowPool[node.id].length - 1].data.results}
|
||||
readOnly
|
||||
/>
|
||||
);
|
||||
|
||||
default:
|
||||
return (
|
||||
<Textarea
|
||||
className="h-full w-full custom-scroll"
|
||||
placeholder={"Enter text..."}
|
||||
value={node.data.node!.template["value"]}
|
||||
onChange={(e) => {
|
||||
e.target.value;
|
||||
if (node) {
|
||||
let newNode = cloneDeep(node);
|
||||
newNode.data.node!.template["value"].value = e.target.value;
|
||||
setNode(node.id, newNode);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
}
|
||||
return <div className="h-full w-full">{handleOutputType()}</div>;
|
||||
}
|
||||
|
|
@ -1,10 +1,11 @@
|
|||
import { cloneDeep } from "lodash";
|
||||
import { ReactNode, useState } from "react";
|
||||
import useFlowStore from "../../stores/flowStore";
|
||||
import { NodeType } from "../../types/flow";
|
||||
import { extractTypeFromLongId } from "../../utils/utils";
|
||||
import { isInputType, isOutputType } from "../../utils/reactflowUtils";
|
||||
import { classNames } from "../../utils/utils";
|
||||
import AccordionComponent from "../AccordionComponent";
|
||||
import IOInputField from "../IOInputField";
|
||||
import IOOutputView from "../IOOutputView";
|
||||
import IconComponent from "../genericIconComponent";
|
||||
import NewChatView from "../newChatView";
|
||||
import { Badge } from "../ui/badge";
|
||||
|
|
@ -16,32 +17,91 @@ export default function IOView(): JSX.Element {
|
|||
const outputIds = outputs.map((obj) => obj.id);
|
||||
const nodes = useFlowStore((state) => state.nodes);
|
||||
const setNode = useFlowStore((state) => state.setNode);
|
||||
const options = inputIds.concat(outputIds);
|
||||
//TODO: show output options for view
|
||||
const [selectedView, setSelectedView] = useState<ReactNode>(
|
||||
handleSelectChange(options[0])
|
||||
const categories = getCategories();
|
||||
const [selectedCategory, setSelectedCategory] = useState<string>(
|
||||
categories[0]
|
||||
);
|
||||
// if (outputTypes.includes("ChatOutput")) {
|
||||
// return <NewChatView />;
|
||||
// }
|
||||
function handleSelectChange(selected: string) {
|
||||
const type = extractTypeFromLongId(selected);
|
||||
return <NewChatView />;
|
||||
switch (type) {
|
||||
case "ChatOutput":
|
||||
return <NewChatView />;
|
||||
break;
|
||||
const [showChat, setShowChat] = useState<boolean>(false);
|
||||
const [selectedView, setSelectedView] = useState<{
|
||||
type: string;
|
||||
id?: string;
|
||||
}>(handleInitialView());
|
||||
|
||||
function handleInitialView() {
|
||||
if (outputs.map((output) => output.type).includes("ChatOutput")) {
|
||||
return { type: "ChatOutput" };
|
||||
}
|
||||
return { type: "" };
|
||||
}
|
||||
console.log(inputs);
|
||||
|
||||
function getCategories() {
|
||||
const categories: string[] = [];
|
||||
if (inputs.length > 0) categories.push("Inputs");
|
||||
if (outputs.filter((output) => output.type !== "ChatOutput").length > 0)
|
||||
categories.push("Outputs");
|
||||
return categories;
|
||||
}
|
||||
|
||||
function handleSelectChange(): ReactNode {
|
||||
const { type, id } = selectedView;
|
||||
if (type === "ChatOutput") return <NewChatView />;
|
||||
if (isInputType(type))
|
||||
return <IOInputField inputId={id!} inputType={type} />;
|
||||
if (isOutputType(type))
|
||||
return <IOOutputView outputId={id!} outputType={type} />;
|
||||
else return <div>no view selected</div>;
|
||||
}
|
||||
|
||||
function UpdateAccordion() {
|
||||
return selectedCategory === "Inputs" ? inputs : outputs;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="form-modal-iv-box">
|
||||
<div className="mr-6 flex h-full w-2/6 flex-col justify-start overflow-auto scrollbar-hide">
|
||||
<div className="file-component-arrangement">
|
||||
<IconComponent name="Variable" className=" file-component-variable" />
|
||||
<span className="file-component-variables-span text-md">Inputs</span>
|
||||
<div className="flex items-start gap-4 py-2">
|
||||
{categories.map((category, index) => {
|
||||
return (
|
||||
//hide chat button if chat is alredy on the view
|
||||
<button
|
||||
onClick={() => setSelectedCategory(category)}
|
||||
className={classNames(
|
||||
"cursor flex items-center rounded-md rounded-b-none px-1",
|
||||
category == selectedCategory
|
||||
? "border border-b-0 border-muted-foreground"
|
||||
: "hover:bg-muted-foreground"
|
||||
)}
|
||||
key={index}
|
||||
>
|
||||
<IconComponent
|
||||
name="Variable"
|
||||
className=" file-component-variable"
|
||||
/>
|
||||
<span className="file-component-variables-span text-md">
|
||||
{category}
|
||||
</span>
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
{selectedView.type !== "ChatOutput" && (
|
||||
<button
|
||||
onClick={() => setSelectedView({ type: "ChatOutput" })}
|
||||
className={
|
||||
"cursor flex items-center rounded-md rounded-b-none px-1 hover:bg-muted-foreground"
|
||||
}
|
||||
key={"chat"}
|
||||
>
|
||||
<IconComponent
|
||||
name="Variable"
|
||||
className=" file-component-variable"
|
||||
/>
|
||||
<span className="file-component-variables-span text-md">
|
||||
Chat
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
{inputs
|
||||
{UpdateAccordion()
|
||||
.filter((input) => input.type !== "ChatInput")
|
||||
.map((input, index) => {
|
||||
const node: NodeType = nodes.find((node) => node.id === input.id)!;
|
||||
|
|
@ -54,11 +114,17 @@ export default function IOView(): JSX.Element {
|
|||
{input.id}
|
||||
</Badge>
|
||||
<div
|
||||
className="-mb-1"
|
||||
className="-mb-1 pr-4"
|
||||
onClick={(event) => {
|
||||
event.stopPropagation();
|
||||
setSelectedView({ type: input.type, id: input.id });
|
||||
}}
|
||||
></div>
|
||||
>
|
||||
<IconComponent
|
||||
className="h-4 w-4"
|
||||
name="ScreenShare"
|
||||
></IconComponent>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
key={index}
|
||||
|
|
@ -66,30 +132,7 @@ export default function IOView(): JSX.Element {
|
|||
>
|
||||
<div className="file-component-tab-column">
|
||||
{node && (
|
||||
<IOInputField
|
||||
field={
|
||||
node.data.node!.template["value"] ||
|
||||
node.data.node!.template["file_path"]["value"]
|
||||
}
|
||||
inputType={input.type}
|
||||
updateValue={(e, type) => {
|
||||
if (type === "file") {
|
||||
if (node) {
|
||||
let newNode = cloneDeep(node);
|
||||
newNode.data.node!.template["file_path"].value =
|
||||
e;
|
||||
setNode(node.id, newNode);
|
||||
}
|
||||
} else {
|
||||
if (node) {
|
||||
let newNode = cloneDeep(node);
|
||||
newNode.data.node!.template["value"].value =
|
||||
e.target.value;
|
||||
setNode(node.id, newNode);
|
||||
}
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<IOInputField inputType={input.type} inputId={input.id} />
|
||||
)}
|
||||
</div>
|
||||
</AccordionComponent>
|
||||
|
|
@ -97,7 +140,7 @@ export default function IOView(): JSX.Element {
|
|||
);
|
||||
})}
|
||||
</div>
|
||||
{selectedView}
|
||||
{handleSelectChange()}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ export default function newChatView(): JSX.Element {
|
|||
useEffect(() => {
|
||||
const chatOutputResponses: FlowPoolObjectType[] = [];
|
||||
outputIds.forEach((outputId) => {
|
||||
console.log("rodou", flowPool[outputId]);
|
||||
if (outputId.includes("ChatOutput")) {
|
||||
if (flowPool[outputId] && flowPool[outputId].length > 0) {
|
||||
chatOutputResponses.push(...flowPool[outputId]);
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import {
|
|||
applyNodeChanges,
|
||||
} from "reactflow";
|
||||
import { create } from "zustand";
|
||||
import { updateFlowInDatabase } from "../controllers/API";
|
||||
import {
|
||||
NodeDataType,
|
||||
NodeType,
|
||||
|
|
@ -45,12 +46,19 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
|
|||
set({ flowPool });
|
||||
},
|
||||
addDataToFlowPool: (data: any, nodeId: string) => {
|
||||
const currentFlow = useFlowsManagerStore.getState().currentFlow;
|
||||
let newFlowPool = cloneDeep({ ...get().flowPool });
|
||||
if (!newFlowPool[nodeId]) newFlowPool[nodeId] = [data];
|
||||
else {
|
||||
newFlowPool[nodeId].push(data);
|
||||
}
|
||||
get().setFlowPool(newFlowPool);
|
||||
if (currentFlow) {
|
||||
window.sessionStorage.setItem(
|
||||
`${currentFlow!.id}`,
|
||||
JSON.stringify(newFlowPool)
|
||||
);
|
||||
}
|
||||
},
|
||||
CleanFlowPool: () => {
|
||||
get().setFlowPool({});
|
||||
|
|
@ -59,16 +67,23 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
|
|||
set({ isPending });
|
||||
},
|
||||
resetFlow: ({ nodes, edges, viewport }) => {
|
||||
const currentFlow = useFlowsManagerStore.getState().currentFlow;
|
||||
let flowPool = {};
|
||||
if (currentFlow) {
|
||||
flowPool = JSON.parse(
|
||||
window.sessionStorage.getItem(`${currentFlow!.id}`) ?? "{}"
|
||||
);
|
||||
}
|
||||
let newEdges = cleanEdges(nodes, edges);
|
||||
const { inputs, outputs } = getInputsAndOutputs(nodes);
|
||||
|
||||
set({
|
||||
nodes,
|
||||
edges: newEdges,
|
||||
flowState: undefined,
|
||||
inputs,
|
||||
outputs,
|
||||
hasIO: inputs.length > 0 && outputs.length > 0,
|
||||
hasIO: inputs.length > 0 || outputs.length > 0,
|
||||
flowPool,
|
||||
});
|
||||
get().reactFlowInstance!.setViewport(viewport);
|
||||
},
|
||||
|
|
@ -109,7 +124,7 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
|
|||
flowState: undefined,
|
||||
inputs,
|
||||
outputs,
|
||||
hasIO: inputs.length > 0 && outputs.length > 0,
|
||||
hasIO: inputs.length > 0 || outputs.length > 0,
|
||||
});
|
||||
|
||||
const flowsManager = useFlowsManagerStore.getState();
|
||||
|
|
@ -330,6 +345,16 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
|
|||
function handleBuildUpdate(data: any) {
|
||||
get().addDataToFlowPool(data.data[data.id], data.id);
|
||||
}
|
||||
await updateFlowInDatabase({
|
||||
data: {
|
||||
nodes: get().nodes,
|
||||
edges: get().edges,
|
||||
viewport: get().reactFlowInstance?.getViewport()!,
|
||||
},
|
||||
id: currentFlow!.id,
|
||||
name: currentFlow!.name,
|
||||
description: currentFlow!.description,
|
||||
});
|
||||
return buildVertices({
|
||||
flowId: currentFlow!.id,
|
||||
nodeId,
|
||||
|
|
|
|||
|
|
@ -290,7 +290,7 @@
|
|||
@apply hidden h-4 w-4 animate-spin rounded-full bg-ring opacity-0;
|
||||
}
|
||||
.generic-node-status {
|
||||
@apply opacity-100;
|
||||
@apply opacity-100 animate-wiggle;
|
||||
}
|
||||
.green-status {
|
||||
@apply generic-node-status text-status-green fill-status-green;
|
||||
|
|
@ -302,7 +302,7 @@
|
|||
@apply generic-node-status text-status-yellow fill-status-yellow;
|
||||
}
|
||||
.status-build-animation {
|
||||
@apply hidden animate-spin text-ring opacity-0;
|
||||
@apply opacity-0;
|
||||
}
|
||||
.status-div {
|
||||
@apply absolute w-4 duration-200 ease-in-out;
|
||||
|
|
@ -1012,7 +1012,7 @@
|
|||
}
|
||||
|
||||
.beta-badge-wrapper {
|
||||
@apply absolute right-0 top-0 h-16 w-16 overflow-hidden rounded-tr-lg;
|
||||
@apply absolute right-0 top-0 h-16 w-16 overflow-hidden rounded-tr-lg pointer-events-none;
|
||||
}
|
||||
.beta-badge-content {
|
||||
@apply mt-2 w-24 rotate-45 bg-beta-background text-center text-xs font-semibold text-beta-foreground;
|
||||
|
|
|
|||
|
|
@ -649,8 +649,11 @@ export type dropdownButtonPropsType = {
|
|||
|
||||
export type IOInputProps = {
|
||||
inputType: string;
|
||||
field: TemplateVariableType;
|
||||
updateValue: (e: any, type: string) => void;
|
||||
inputId: string;
|
||||
};
|
||||
export type IOOutputProps = {
|
||||
outputType: string;
|
||||
outputId: string;
|
||||
};
|
||||
|
||||
export type IOFileInputProps = {
|
||||
|
|
|
|||
|
|
@ -677,6 +677,7 @@ export function validateSelection(
|
|||
if (selection.edges.length === 0) {
|
||||
selection.edges = edges;
|
||||
}
|
||||
|
||||
// get only edges that are connected to the nodes in the selection
|
||||
// first creates a set of all the nodes ids
|
||||
let nodesSet = new Set(selection.nodes.map((n) => n.id));
|
||||
|
|
@ -692,7 +693,17 @@ export function validateSelection(
|
|||
if (selection.nodes.length < 2) {
|
||||
errorsArray.push("Please select more than one node");
|
||||
}
|
||||
|
||||
if (
|
||||
selection.nodes.some(
|
||||
(node) =>
|
||||
isInputNode(node.data as NodeDataType) ||
|
||||
isOutputNode(node.data as NodeDataType)
|
||||
)
|
||||
) {
|
||||
errorsArray.push(
|
||||
"Please select only nodes that are not input or output nodes"
|
||||
);
|
||||
}
|
||||
//check if there are two or more nodes with free outputs
|
||||
if (
|
||||
selection.nodes.filter(
|
||||
|
|
@ -1296,3 +1307,11 @@ export function isInputNode(nodeData: NodeDataType): boolean {
|
|||
export function isOutputNode(nodeData: NodeDataType): boolean {
|
||||
return OUTPUT_TYPES.has(nodeData.type);
|
||||
}
|
||||
|
||||
export function isInputType(type: string): boolean {
|
||||
return INPUT_TYPES.has(type);
|
||||
}
|
||||
|
||||
export function isOutputType(type: string): boolean {
|
||||
return OUTPUT_TYPES.has(type);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ import {
|
|||
Save,
|
||||
SaveAll,
|
||||
Scissors,
|
||||
ScreenShare,
|
||||
Search,
|
||||
Settings2,
|
||||
Share,
|
||||
|
|
@ -388,4 +389,5 @@ export const nodeIconsLucide: iconsType = {
|
|||
TerminalIcon,
|
||||
Repeat,
|
||||
io: ArrowDownUp,
|
||||
ScreenShare,
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue