Refactor model components to support streaming

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-27 22:44:50 -03:00
commit 6c2a35afb1
11 changed files with 114 additions and 40 deletions

View file

@ -35,6 +35,10 @@ class AmazonBedrockComponent(CustomComponent):
"cache": {"display_name": "Cache"},
"code": {"advanced": True},
"input_value": {"display_name": "Input"},
"stream": {
"display_name": "Stream",
"info": "Stream the response from the model.",
},
}
def build(
@ -47,6 +51,7 @@ class AmazonBedrockComponent(CustomComponent):
endpoint_url: Optional[str] = None,
streaming: bool = False,
cache: Optional[bool] = None,
stream: bool = False,
) -> Text:
try:
output = BedrockChat(
@ -60,7 +65,10 @@ class AmazonBedrockComponent(CustomComponent):
) # type: ignore
except Exception as e:
raise ValueError("Could not connect to AmazonBedrock API.") from e
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
if stream:
result = output.stream(input_value)
else:
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
return result

View file

@ -50,6 +50,10 @@ class AnthropicLLM(CustomComponent):
},
"code": {"show": False},
"input_value": {"display_name": "Input"},
"stream": {
"display_name": "Stream",
"info": "Stream the response from the model.",
},
}
def build(
@ -60,6 +64,7 @@ class AnthropicLLM(CustomComponent):
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
api_endpoint: Optional[str] = None,
stream: bool = False,
) -> Text:
# Set default API endpoint if not provided
if not api_endpoint:
@ -77,7 +82,10 @@ class AnthropicLLM(CustomComponent):
)
except Exception as e:
raise ValueError("Could not connect to Anthropic API.") from e
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
if stream:
result = output.stream(input_value)
else:
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
return result

View file

@ -74,6 +74,10 @@ class AzureChatOpenAIComponent(CustomComponent):
},
"code": {"show": False},
"input_value": {"display_name": "Input"},
"stream": {
"display_name": "Stream",
"info": "Stream the response from the model.",
},
}
def build(
@ -86,6 +90,7 @@ class AzureChatOpenAIComponent(CustomComponent):
api_version: str,
temperature: float = 0.7,
max_tokens: Optional[int] = 1000,
stream: bool = False,
) -> BaseLanguageModel:
try:
output = AzureChatOpenAI(
@ -99,7 +104,10 @@ class AzureChatOpenAIComponent(CustomComponent):
)
except Exception as e:
raise ValueError("Could not connect to AzureOpenAI API.") from e
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
if stream:
result = output.stream(input_value)
else:
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
return result

View file

@ -69,6 +69,10 @@ class QianfanChatEndpointComponent(CustomComponent):
},
"code": {"show": False},
"input_value": {"display_name": "Input"},
"stream": {
"display_name": "Stream",
"info": "Stream the response from the model.",
},
}
def build(
@ -81,6 +85,7 @@ class QianfanChatEndpointComponent(CustomComponent):
temperature: Optional[float] = None,
penalty_score: Optional[float] = None,
endpoint: Optional[str] = None,
stream: bool = False,
) -> Text:
try:
output = QianfanChatEndpoint( # type: ignore
@ -94,7 +99,10 @@ class QianfanChatEndpointComponent(CustomComponent):
)
except Exception as e:
raise ValueError("Could not connect to Baidu Qianfan API.") from e
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
if stream:
result = output.stream(input_value)
else:
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
return result

View file

@ -29,6 +29,10 @@ class CTransformersComponent(CustomComponent):
"value": '{"top_k":40,"top_p":0.95,"temperature":0.8,"repetition_penalty":1.1,"last_n_tokens":64,"seed":-1,"max_new_tokens":256,"stop":"","stream":"False","reset":"True","batch_size":8,"threads":-1,"context_length":-1,"gpu_layers":0}',
},
"input_value": {"display_name": "Input"},
"stream": {
"display_name": "Stream",
"info": "Stream the response from the model.",
},
}
def build(
@ -38,11 +42,15 @@ class CTransformersComponent(CustomComponent):
input_value: str,
model_type: str,
config: Optional[Dict] = None,
stream: Optional[bool] = False,
) -> Text:
output = CTransformers(
model=model, model_file=model_file, model_type=model_type, config=config
)
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
if stream:
result = output.stream(input_value)
else:
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
return result

View file

@ -43,8 +43,10 @@ class CohereComponent(CustomComponent):
max_tokens=max_tokens,
temperature=temperature,
)
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
return result
return result
if stream:
result = output.stream(input_value)
else:
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
return result

View file

@ -70,7 +70,10 @@ class GoogleGenerativeAIComponent(CustomComponent):
n=n or 1,
google_api_key=SecretStr(google_api_key),
)
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
return result
if stream:
result = output.stream(input_value)
else:
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
return result

View file

@ -57,6 +57,10 @@ class LlamaCppComponent(CustomComponent):
"verbose": {"display_name": "Verbose", "advanced": True},
"vocab_only": {"display_name": "Vocab Only", "advanced": True},
"input_value": {"display_name": "Input"},
"stream": {
"display_name": "Stream",
"info": "Stream the response from the model.",
},
}
def build(
@ -97,6 +101,7 @@ class LlamaCppComponent(CustomComponent):
use_mmap: Optional[bool] = True,
verbose: bool = True,
vocab_only: bool = False,
stream: bool = False,
) -> Text:
output = LlamaCpp(
model_path=model_path,
@ -135,9 +140,10 @@ class LlamaCppComponent(CustomComponent):
verbose=verbose,
vocab_only=vocab_only,
)
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
return result
self.status = result
if stream:
result = output.stream(input_value)
else:
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
return result

View file

@ -165,6 +165,10 @@ class ChatOllamaComponent(CustomComponent):
"advanced": True,
},
"input_value": {"display_name": "Input"},
"stream": {
"display_name": "Stream",
"info": "Stream the response from the model.",
},
}
def build(
@ -197,6 +201,7 @@ class ChatOllamaComponent(CustomComponent):
timeout: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
stream: Optional[bool] = False,
) -> Text:
if not base_url:
base_url = "http://localhost:11434"
@ -250,7 +255,10 @@ class ChatOllamaComponent(CustomComponent):
output = ChatOllama(**llm_params) # type: ignore
except Exception as e:
raise ValueError("Could not initialize Ollama LLM.") from e
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
if stream:
result = output.stream(input_value)
else:
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
return result

View file

@ -57,6 +57,10 @@ class OpenAIModelComponent(CustomComponent):
"required": False,
"value": 0.7,
},
"stream": {
"display_name": "Stream",
"info": "Stream the response from the model.",
},
}
def build(
@ -68,10 +72,11 @@ class OpenAIModelComponent(CustomComponent):
openai_api_base: Optional[str] = None,
openai_api_key: Optional[str] = None,
temperature: float = 0.7,
stream: Optional[bool] = False,
) -> Text:
if not openai_api_base:
openai_api_base = "https://api.openai.com/v1"
model = ChatOpenAI(
output = ChatOpenAI(
max_tokens=max_tokens,
model_kwargs=model_kwargs,
model=model_name,
@ -79,8 +84,10 @@ class OpenAIModelComponent(CustomComponent):
api_key=openai_api_key,
temperature=temperature,
)
message = model.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
if stream:
result = output.stream(input_value)
else:
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
return result

View file

@ -58,6 +58,10 @@ class ChatVertexAIComponent(CustomComponent):
"advanced": True,
},
"input_value": {"display_name": "Input"},
"stream": {
"display_name": "Stream",
"info": "Stream the response from the model.",
},
}
def build(
@ -73,6 +77,7 @@ class ChatVertexAIComponent(CustomComponent):
top_k: int = 40,
top_p: float = 0.95,
verbose: bool = False,
stream: bool = False,
) -> Text:
try:
from langchain_google_vertexai import ChatVertexAI
@ -92,7 +97,10 @@ class ChatVertexAIComponent(CustomComponent):
top_p=top_p,
verbose=verbose,
)
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
if stream:
result = output.stream(input_value)
else:
message = output.invoke(input_value)
result = message.content if hasattr(message, "content") else message
self.status = result
return result