"""Langflow component exposing LangChain's `ChatLiteLLM` chat-model wrapper."""
import os
from typing import Any, Callable, Dict, Optional, Union

from langchain_community.chat_models import ChatLiteLLM

from langflow import CustomComponent
from langflow.field_typing import BaseLanguageModel


class LiteLLMComponent(CustomComponent):
    """Build a `ChatLiteLLM` language model from user-configurable fields.

    LiteLLM routes a single chat-completions interface to many providers;
    the provider is selected by the `model` string (e.g. ``gpt-3.5-turbo``
    or ``replicate/...``).
    """

    display_name = "LiteLLM"
    description = "`LiteLLM` collection of large language models."
    documentation = "https://python.langchain.com/docs/integrations/chat/litellm"

    def build_config(self):
        """Return the field configuration shown in the Langflow UI."""
        return {
            "model": {
                "display_name": "Model name",
                "field_type": "str",
                "advanced": False,
                "required": True,
                "info": "The name of the model to use. For example, `gpt-3.5-turbo`.",
            },
            "api_key": {
                "display_name": "API key",
                "field_type": "str",
                "advanced": False,
                "required": False,
                "password": True,
            },
            "streaming": {
                "display_name": "Streaming",
                "field_type": "bool",
                "advanced": True,
                "required": False,
                "default": True,
            },
            "temperature": {
                "display_name": "Temperature",
                "field_type": "float",
                "advanced": True,
                "required": False,
                "default": 0.7,
            },
            "model_kwargs": {
                "display_name": "Model kwargs",
                "field_type": "dict",
                "advanced": True,
                "required": False,
                "default": {},
            },
            "top_p": {
                "display_name": "Top p",
                "field_type": "float",
                "advanced": True,
                "required": False,
            },
            "top_k": {
                "display_name": "Top k",
                "field_type": "int",
                "advanced": True,
                "required": False,
            },
            "n": {
                "display_name": "N",
                "field_type": "int",
                "advanced": True,
                "required": False,
                "info": "Number of chat completions to generate for each prompt. "
                "Note that the API may not return the full n completions if duplicates are generated.",
                "default": 1,
            },
            "max_tokens": {
                "display_name": "Max tokens",
                "field_type": "int",
                "advanced": True,
                "required": False,
                "default": 256,
                "info": "The maximum number of tokens to generate for each chat completion.",
            },
            "max_retries": {
                "display_name": "Max retries",
                "field_type": "int",
                "advanced": True,
                "required": False,
                "default": 6,
            },
        }

    def build(
        self,
        model: str,
        api_key: str,
        streaming: bool = True,
        temperature: Optional[float] = 0.7,
        # NOTE: default is None (not {}) to avoid the shared-mutable-default
        # pitfall; a fresh dict is created per call below.
        model_kwargs: Optional[Dict[str, Any]] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        n: Optional[int] = 1,
        max_tokens: Optional[int] = 256,
        max_retries: Optional[int] = 6,
    ) -> Union[BaseLanguageModel, Callable]:
        """Construct and return the configured `ChatLiteLLM` instance.

        Parameters mirror `build_config`; see that method for field
        descriptions. If `api_key` is given and the model name indicates a
        Perplexity or Replicate model, the key is exported via the
        environment variable LiteLLM reads for that provider; other
        providers are assumed to pick up their keys from the environment.
        """
        if api_key:
            if "perplexity" in model:
                os.environ["PERPLEXITYAI_API_KEY"] = api_key
            elif "replicate" in model:
                os.environ["REPLICATE_API_KEY"] = api_key
        return ChatLiteLLM(
            model=model,
            streaming=streaming,
            temperature=temperature,
            model_kwargs=model_kwargs if model_kwargs is not None else {},
            top_p=top_p,
            top_k=top_k,
            n=n,
            max_tokens=max_tokens,
            max_retries=max_retries,
        )