fix: Add LanguageModel to field_typing module (#2410)

* feat: Add LanguageModel to field_typing module

* chore: Fix type annotations in model build methods

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-07-01 13:29:27 -03:00
commit 6a6bb3e13a
14 changed files with 31 additions and 38 deletions

View file

@ -143,13 +143,13 @@ class LCModelComponent(Component):
messages.append(HumanMessage(content=input_value))
inputs: Union[list, dict] = messages or {}
try:
runnable = runnable.with_config(
{"run_name": self.display_name, "project_name": self._tracing_service.project_name}
runnable = runnable.with_config( # type: ignore
{"run_name": self.display_name, "project_name": self._tracing_service.project_name} # type: ignore
)
if stream:
return runnable.stream(inputs)
return runnable.stream(inputs) # type: ignore
else:
message = runnable.invoke(inputs)
message = runnable.invoke(inputs) # type: ignore
result = message.content if hasattr(message, "content") else message
if isinstance(message, AIMessage):
status_message = self.build_status_message(message)

View file

@ -127,7 +127,7 @@ class ChatLiteLLMModelComponent(LCModelComponent):
Output(display_name="Language Model", name="model_output", method="build_model"),
]
def build_model(self) -> LanguageModel:
def build_model(self) -> LanguageModel: # type: ignore[type-var]
try:
import litellm # type: ignore
@ -176,5 +176,4 @@ class ChatLiteLLMModelComponent(LCModelComponent):
openrouter_api_key=api_keys["openrouter_api_key"],
)
return output
return output
return output # type: ignore

View file

@ -69,7 +69,7 @@ class AmazonBedrockComponent(LCModelComponent):
Output(display_name="Language Model", name="model_output", method="build_model"),
]
def build_model(self) -> LanguageModel:
def build_model(self) -> LanguageModel: # type: ignore[type-var]
model_id = self.model_id
credentials_profile_name = self.credentials_profile_name
region_name = self.region_name
@ -89,4 +89,4 @@ class AmazonBedrockComponent(LCModelComponent):
)
except Exception as e:
raise ValueError("Could not connect to AmazonBedrock API.") from e
return output
return output # type: ignore

View file

@ -64,7 +64,7 @@ class AnthropicModelComponent(LCModelComponent):
Output(display_name="Language Model", name="model_output", method="build_model"),
]
def build_model(self) -> LanguageModel:
def build_model(self) -> LanguageModel: # type: ignore[type-var]
model = self.model
anthropic_api_key = self.anthropic_api_key
max_tokens = self.max_tokens
@ -83,7 +83,7 @@ class AnthropicModelComponent(LCModelComponent):
except Exception as e:
raise ValueError("Could not connect to Anthropic API.") from e
return output
return output # type: ignore
def _get_exception_message(self, exception: Exception) -> str | None:
"""

View file

@ -78,7 +78,7 @@ class AzureChatOpenAIComponent(LCModelComponent):
Output(display_name="Language Model", name="model_output", method="model_response"),
]
def model_response(self) -> LanguageModel:
def model_response(self) -> LanguageModel: # type: ignore[type-var]
model = self.model
azure_endpoint = self.azure_endpoint
azure_deployment = self.azure_deployment
@ -107,4 +107,4 @@ class AzureChatOpenAIComponent(LCModelComponent):
except Exception as e:
raise ValueError("Could not connect to AzureOpenAI API.") from e
return output
return output # type: ignore

View file

@ -51,4 +51,4 @@ class CohereComponent(LCModelComponent):
cohere_api_key=api_key,
)
return output
return output # type: ignore

View file

@ -3,15 +3,7 @@ from pydantic.v1 import SecretStr
from langflow.base.constants import STREAM_INFO_TEXT
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.inputs import (
BoolInput,
DropdownInput,
FloatInput,
IntInput,
MessageInput,
SecretStrInput,
StrInput,
)
from langflow.inputs import BoolInput, DropdownInput, FloatInput, IntInput, MessageInput, SecretStrInput, StrInput
class GoogleGenerativeAIComponent(LCModelComponent):
@ -66,7 +58,7 @@ class GoogleGenerativeAIComponent(LCModelComponent):
),
]
def build_model(self) -> LanguageModel:
def build_model(self) -> LanguageModel: # type: ignore[type-var]
try:
from langchain_google_genai import ChatGoogleGenerativeAI
except ImportError:
@ -90,4 +82,4 @@ class GoogleGenerativeAIComponent(LCModelComponent):
google_api_key=SecretStr(google_api_key),
)
return output
return output # type: ignore

View file

@ -68,7 +68,7 @@ class GroqModel(LCModelComponent):
),
]
def build_model(self) -> LanguageModel:
def build_model(self) -> LanguageModel: # type: ignore[type-var]
groq_api_key = self.groq_api_key
model_name = self.model_name
max_tokens = self.max_tokens
@ -87,4 +87,4 @@ class GroqModel(LCModelComponent):
streaming=stream,
)
return output
return output # type: ignore

View file

@ -36,7 +36,7 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
Output(display_name="Language Model", name="model_output", method="build_model"),
]
def build_model(self) -> LanguageModel:
def build_model(self) -> LanguageModel: # type: ignore[type-var]
endpoint_url = self.endpoint_url
task = self.task
huggingfacehub_api_token = self.huggingfacehub_api_token
@ -53,4 +53,4 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
raise ValueError("Could not connect to HuggingFace Endpoints API.") from e
output = ChatHuggingFace(llm=llm)
return output
return output # type: ignore

View file

@ -70,7 +70,7 @@ class MistralAIModelComponent(LCModelComponent):
Output(display_name="Language Model", name="model_output", method="build_model"),
]
def build_model(self) -> LanguageModel:
def build_model(self) -> LanguageModel: # type: ignore[type-var]
mistral_api_key = self.mistral_api_key
temperature = self.temperature
model_name = self.model_name
@ -102,4 +102,4 @@ class MistralAIModelComponent(LCModelComponent):
safe_mode=safe_mode,
)
return output
return output # type: ignore

View file

@ -223,7 +223,7 @@ class ChatOllamaComponent(LCModelComponent):
Output(display_name="Language Model", name="model_output", method="build_model"),
]
def build_model(self) -> LanguageModel:
def build_model(self) -> LanguageModel: # type: ignore[type-var]
# Mapping mirostat settings to their corresponding values
mirostat_options = {"Mirostat": 1, "Mirostat 2.0": 2}
@ -272,4 +272,4 @@ class ChatOllamaComponent(LCModelComponent):
except Exception as e:
raise ValueError("Could not initialize Ollama LLM.") from e
return output
return output # type: ignore

View file

@ -80,8 +80,8 @@ class OpenAIModelComponent(LCModelComponent):
),
]
def build_model(self) -> LanguageModel:
# self.output_schea is a list of dictionaries
def build_model(self) -> LanguageModel: # type: ignore[type-var]
# self.output_schea is a list of dictionarie s
# let's convert it to a dictionary
output_schema_dict: dict[str, str] = reduce(operator.ior, self.output_schema or {}, {})
openai_api_key = self.openai_api_key
@ -112,7 +112,7 @@ class OpenAIModelComponent(LCModelComponent):
else:
output = output.bind(response_format={"type": "json_object"}) # type: ignore
return output
return output # type: ignore
def _get_exception_message(self, e: Exception):
"""

View file

@ -52,7 +52,7 @@ class ChatVertexAIComponent(LCModelComponent):
Output(display_name="Language Model", name="model_output", method="build_model"),
]
def build_model(self) -> LanguageModel:
def build_model(self) -> LanguageModel: # type: ignore[type-var]
credentials = self.credentials
location = self.location
max_output_tokens = self.max_output_tokens
@ -75,4 +75,4 @@ class ChatVertexAIComponent(LCModelComponent):
verbose=verbose,
)
return output
return output # type: ignore

View file

@ -26,6 +26,7 @@ from .constants import (
TextSplitter,
Tool,
VectorStore,
LanguageModel,
)
from .range_spec import RangeSpec
@ -84,4 +85,5 @@ __all__ = [
"BaseChatModel",
"Retriever",
"Text",
"LanguageModel",
]