Tool Input Schema
By default, tools infer their argument schema by inspecting the function signature. For stricter requirements, you can specify a custom input schema along with custom validation logic. The example below restricts a requests tool to an approved list of domains.
from typing import Any, Dict
from langchain.agents import AgentType, initialize_agent
from langchain.llms import OpenAI
from langchain.tools.requests.tool import RequestsGetTool, TextRequestsWrapper
from pydantic import BaseModel, Field, root_validator
# temperature=0 keeps the agent's completions as repeatable as possible for this demo.
llm = OpenAI(temperature=0)
!pip install tldextract > /dev/null
[notice] A new release of pip is available: 23.0.1 -> 23.1
[notice] To update, run: pip install --upgrade pip
import tldextract
_APPROVED_DOMAINS = {
"langchain",
"wikipedia",
}
class ToolInputSchema(BaseModel):
    """Argument schema for a requests tool that restricts the target domain.

    The single ``url`` argument is accepted only when its registered domain
    (as computed by ``tldextract``) is a member of ``_APPROVED_DOMAINS``.
    """

    url: str = Field(...)

    # skip_on_failure=True: if ``url`` itself failed field validation (missing
    # or not coercible to str), it is absent from ``values``. Without the flag,
    # pydantic v1 still runs this validator, and ``values["url"]`` raises a raw
    # KeyError. Pydantic does not convert KeyError, so it would mask the real
    # ValidationError. Pydantic 2's ``root_validator`` shim also requires this
    # flag for post (non-``pre``) validators.
    @root_validator(skip_on_failure=True)
    def validate_query(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject URLs whose registered domain is not on the approved list.

        Args:
            values: Field values that have already passed field validation.

        Returns:
            ``values`` unchanged when the URL's domain is approved.

        Raises:
            ValueError: If the domain is not in ``_APPROVED_DOMAINS``; pydantic
                reports this to the caller as a ``ValidationError``.
        """
        url = values["url"]
        # Only the registered-domain label is compared (e.g. "langchain" for
        # "https://docs.langchain.com"). The public suffix is ignored, so an
        # approved name is accepted under any suffix (".com", ".org", ...).
        domain = tldextract.extract(url).domain
        if domain not in _APPROVED_DOMAINS:
            raise ValueError(
                f"Domain {domain} is not on the approved list:"
                f" {sorted(_APPROVED_DOMAINS)}"
            )
        return values
# Plain GET tool whose input is parsed through the restricted schema, so URLs
# outside the approved-domain list are rejected by the schema's validator.
tool = RequestsGetTool(
    requests_wrapper=TextRequestsWrapper(),
    args_schema=ToolInputSchema,
)
# Zero-shot ReAct agent that has only the guarded tool available; verbose
# tracing is turned off.
agent = initialize_agent(
    [tool],
    llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=False,
)
# This succeeds: the URL's registered domain, "langchain", is in
# _APPROVED_DOMAINS, so the schema's root validator accepts it.
answer = agent.run("What's the main title on langchain.com?")
print(answer)
The main title of langchain.com is "LANG CHAIN 🦜️🔗 Official Home Page"
# Expected to fail: the registered domain "google" is not in _APPROVED_DOMAINS,
# so the tool's input schema rejects the URL, and the ValidationError shown
# below propagates out of agent.run().
agent.run("What's the main title on google.com?")
ValidationError: 1 validation error for ToolInputSchema
__root__
Domain google is not on the approved list: ['langchain', 'wikipedia'] (type=value_error)