Refactor embeddings and chat components

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-22 09:10:08 -03:00
commit b43e4de38f
3 changed files with 11 additions and 12 deletions

View file

@ -1,6 +1,7 @@
from typing import Optional
from langchain_community.embeddings.cohere import CohereEmbeddings
from langflow import CustomComponent
@ -21,7 +22,7 @@ class CohereEmbeddingsComponent(CustomComponent):
self,
request_timeout: Optional[float] = None,
cohere_api_key: str = "",
max_retries: Optional[int] = None,
max_retries: int = 3,
model: str = "embed-english-v2.0",
truncate: Optional[str] = None,
user_agent: str = "langchain",

View file

@ -1,10 +1,10 @@
from pydantic.v1.types import SecretStr
from langflow import CustomComponent
from typing import Optional, Union, Callable
from langflow.field_typing import BaseLanguageModel
from typing import Callable, Optional, Union
# from langchain_community.chat_models.anthropic import ChatAnthropic
from langchain_anthropic import ChatAnthropic
from pydantic.v1.types import SecretStr
from langflow import CustomComponent
from langflow.field_typing import BaseLanguageModel
class ChatAnthropicComponent(CustomComponent):
@ -60,7 +60,7 @@ class ChatAnthropicComponent(CustomComponent):
model_kwargs=model_kwargs,
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
max_tokens=max_tokens, # type: ignore
top_k=top_k,
top_p=top_p,
)

View file

@ -7,13 +7,13 @@ from langchain.schema import AgentAction, Document
from langchain_community.vectorstores import VectorStore
from langchain_core.messages import AIMessage
from langchain_core.runnables.base import Runnable
from loguru import logger
from pydantic import BaseModel
from langflow.graph.graph.base import Graph
from langflow.interface.custom.custom_component import CustomComponent
from langflow.interface.run import build_sorted_vertices, get_memory_key, update_memory_keys
from langflow.services.deps import get_session_service
from langflow.services.session.service import SessionService
from loguru import logger
from pydantic import BaseModel
def fix_memory_inputs(langchain_object):
@ -172,8 +172,6 @@ async def process_inputs_dict(built_object: Union[Chain, VectorStore, Runnable],
result = result.content
else:
result = result
elif hasattr(built_object, "run") and isinstance(built_object, CustomComponent):
result = built_object.run(inputs)
else:
result = None