Tool Middleware

Iris v0.11.0 introduces a powerful middleware system for tools. Middleware wraps tool execution to add cross-cutting concerns like logging, caching, rate limiting, validation, and circuit breakers without modifying tool implementations.

import (
    "time"

    "github.com/petal-labs/iris/tools"
)

// Create middleware chain
middleware := tools.Chain(
    tools.WithLogging(logger),
    tools.WithTimeout(30 * time.Second),
    tools.WithRetry(tools.DefaultRetryConfig()),
)

// Apply to a single tool
wrappedTool := tools.ApplyMiddleware(myTool, middleware)

// Or apply to an entire registry
registry := tools.NewRegistry()
registry.Use(middleware)
registry.Register(tool1)
registry.Register(tool2)

Log tool calls for debugging and observability:

// Basic logging (tool name + duration)
tools.WithLogging(logger)

// Detailed logging (includes arguments - use only in development)
tools.WithDetailedLogging(logger)

// Example with the standard library logger
import "log"

middleware := tools.WithLogging(log.Default())

// Or use a custom logger implementing tools.Logger
type Logger interface {
    Printf(format string, v ...any)
}

Output:

tool call start: search_logs
tool call success: search_logs, duration=120ms
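
Any type with a matching Printf method satisfies tools.Logger. As an illustration, a minimal adapter that forwards to log/slog might look like this (the slogAdapter type is an assumption for the sketch, not part of the library):

import (
    "fmt"
    "log/slog"

    "github.com/petal-labs/iris/tools"
)

// slogAdapter satisfies tools.Logger by formatting and forwarding to slog.
type slogAdapter struct{ l *slog.Logger }

func (a slogAdapter) Printf(format string, v ...any) {
    a.l.Info(fmt.Sprintf(format, v...))
}

middleware := tools.WithLogging(slogAdapter{l: slog.Default()})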

Enforce execution time limits:

// 30 second timeout per tool call
tools.WithTimeout(30 * time.Second)

If the tool exceeds the timeout, it returns an error:

tool execution timeout after 30s
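
A sketch of handling the timeout at the call site; the matching logic mirrors the Retryable check shown later, and the tool name slow_search is illustrative:

result, err := registry.Execute(ctx, "slow_search", args)
if err != nil {
    if errors.Is(err, context.DeadlineExceeded) || strings.Contains(err.Error(), "timeout") {
        // Fall back to a default result or surface the error to the model
        log.Printf("tool timed out: %v", err)
    }
}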

Control tool call frequency:

// Token bucket: 10 calls per second with burst of 20
tools.WithRateLimit(10.0)
// Custom rate limiter
limiter := NewMyRateLimiter()
tools.WithRateLimiter(limiter)

The rate limiter interface:

type RateLimiter interface {
    Allow() bool                    // Returns true if request can proceed
    Wait(ctx context.Context) error // Blocks until allowed or context canceled
}
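
Because Go interfaces are satisfied structurally, a *rate.Limiter from golang.org/x/time/rate already implements this interface (its Allow and Wait methods have exactly these signatures). A sketch of plugging it in, assuming WithRateLimiter accepts any RateLimiter value:

import "golang.org/x/time/rate"

// 10 calls per second with a burst of 20; satisfies tools.RateLimiter as-is
limiter := rate.NewLimiter(rate.Limit(10), 20)
tools.WithRateLimiter(limiter)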

Cache tool results to avoid redundant calls:

// In-memory cache with 5 minute TTL
cache := tools.NewMemoryCache()
tools.WithCache(cache, 5 * time.Minute)

// Custom cache key function
tools.WithCacheCustomKey(cache, 5 * time.Minute, func(toolName string, args json.RawMessage) string {
    // Generate cache key from tool name and arguments
    return fmt.Sprintf("%s:%x", toolName, sha256.Sum256(args))
})

The cache interface:

type Cache interface {
    Get(key string) (any, bool)
    Set(key string, value any, ttl time.Duration)
}
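
The built-in tools.NewMemoryCache covers the common case; for a custom backend such as Redis, implement the two methods above. A minimal in-memory sketch to illustrate the contract (the mapCache type is illustrative, not part of the library):

import (
    "sync"
    "time"
)

type cacheItem struct {
    value     any
    expiresAt time.Time
}

// mapCache is a minimal TTL cache satisfying the Cache interface.
type mapCache struct {
    mu    sync.Mutex
    items map[string]cacheItem
}

func newMapCache() *mapCache {
    return &mapCache{items: make(map[string]cacheItem)}
}

func (c *mapCache) Get(key string) (any, bool) {
    c.mu.Lock()
    defer c.mu.Unlock()
    item, ok := c.items[key]
    if !ok || time.Now().After(item.expiresAt) {
        return nil, false
    }
    return item.value, true
}

func (c *mapCache) Set(key string, value any, ttl time.Duration) {
    c.mu.Lock()
    defer c.mu.Unlock()
    c.items[key] = cacheItem{value: value, expiresAt: time.Now().Add(ttl)}
}

// Use it like the built-in cache
tools.WithCache(newMapCache(), 5 * time.Minute)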

Validate arguments before execution:

// Basic JSON validation
tools.WithBasicValidation()
// Schema validation with custom validator
tools.WithValidation(schemaValidator)

Collect execution metrics:

type MetricsCollector interface {
    RecordCall(toolName string, duration time.Duration, err error)
}

collector := NewPrometheusCollector()
tools.WithMetrics(collector)

Example Prometheus implementation:

type PrometheusCollector struct {
    callDuration *prometheus.HistogramVec
    callErrors   *prometheus.CounterVec
}

func (c *PrometheusCollector) RecordCall(name string, duration time.Duration, err error) {
    c.callDuration.WithLabelValues(name).Observe(duration.Seconds())
    if err != nil {
        c.callErrors.WithLabelValues(name).Inc()
    }
}
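
A possible constructor for the collector above; the metric names and registration strategy are illustrative, not prescribed by Iris:

import "github.com/prometheus/client_golang/prometheus"

func NewPrometheusCollector() *PrometheusCollector {
    c := &PrometheusCollector{
        callDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
            Name: "tool_call_duration_seconds",
            Help: "Duration of tool calls in seconds.",
        }, []string{"tool"}),
        callErrors: prometheus.NewCounterVec(prometheus.CounterOpts{
            Name: "tool_call_errors_total",
            Help: "Total number of failed tool calls.",
        }, []string{"tool"}),
    }
    // Register both metrics with the default Prometheus registry
    prometheus.MustRegister(c.callDuration, c.callErrors)
    return c
}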

Automatically retry failed tool calls:

// Default: 3 attempts, exponential backoff
tools.WithRetry(tools.DefaultRetryConfig())

// Custom configuration
tools.WithRetry(tools.RetryConfig{
    MaxAttempts: 5,
    InitialWait: 100 * time.Millisecond,
    MaxWait:     10 * time.Second,
    Multiplier:  2.0,
    Retryable: func(err error) bool {
        // Only retry transient errors
        return errors.Is(err, context.DeadlineExceeded) ||
            strings.Contains(err.Error(), "timeout")
    },
})

Prevent cascading failures with circuit breaker pattern:

// Default: opens after 5 failures, closes after 2 successes in half-open
tools.WithCircuitBreaker(tools.DefaultCircuitBreakerConfig())

// Custom configuration
tools.WithCircuitBreaker(tools.CircuitBreakerConfig{
    FailureThreshold: 3,           // Open after 3 failures
    SuccessThreshold: 2,           // Close after 2 successes in half-open
    OpenDuration:     time.Minute, // Stay open for 1 minute
})

Circuit breaker states:

State       Description
Closed      Normal operation, requests pass through
Open        Failing, requests immediately rejected with ErrCircuitOpen
Half-Open   Testing recovery, limited requests allowed
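
When the circuit is open, calls fail fast, so callers can detect the rejection and degrade gracefully. A sketch, assuming the sentinel is exported as tools.ErrCircuitOpen and wrapped so errors.Is matches:

result, err := registry.Execute(ctx, "search_api", args)
if errors.Is(err, tools.ErrCircuitOpen) {
    // The downstream service is considered unhealthy; return a fallback
    // instead of sending it more requests.
    result = "search is temporarily unavailable, please try again later"
}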

Apply middleware only to specific tools:

// Only apply rate limiting to API tools
tools.ForTools(
    []string{"search_api", "fetch_data"},
    tools.WithRateLimit(5.0),
)

// Apply caching to all tools except those with side effects
tools.ExceptTools(
    []string{"send_email", "update_ticket", "notify_slack"},
    tools.WithCache(cache, 5 * time.Minute),
)

Middleware executes in the order specified (first is outermost):

middleware := tools.Chain(
    tools.WithLogging(logger),             // 1. Log start
    tools.WithMetrics(collector),          // 2. Start timing
    tools.WithCircuitBreaker(config),      // 3. Check circuit
    tools.WithRateLimit(10.0),             // 4. Check rate limit
    tools.WithTimeout(30 * time.Second),   // 5. Apply timeout
    tools.WithRetry(retryConfig),          // 6. Retry on failure
    tools.WithCache(cache, 5 * time.Minute), // 7. Check cache
)
// Tool executes here
// Then unwinds: cache → retry → timeout → rate → circuit → metrics → logging

★ Insight: Middleware Order Matters
Place logging/metrics outermost to capture all behavior. Place retry inside timeout so retries respect the overall time limit. Place cache innermost so a cache hit returns immediately instead of re-executing the tool.

Middleware can access and share data via ToolContext:

func MyMiddleware(next tools.ToolCallFunc) tools.ToolCallFunc {
    return func(ctx context.Context, args json.RawMessage) (any, error) {
        // Get tool context
        tc := tools.ToolContextFromContext(ctx)
        if tc != nil {
            fmt.Printf("Tool: %s, Call ID: %s\n", tc.ToolName, tc.CallID)
            // Share data with other middleware
            tc.Metadata["request_id"] = generateRequestID()
        }
        return next(ctx, args)
    }
}

ToolContext fields:

type ToolContext struct {
    ToolName  string         // Name of the tool being called
    CallID    string         // Unique identifier for this invocation
    Iteration int            // Agent loop iteration (if in agent context)
    Metadata  map[string]any // Shared data between middleware
}
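
Data placed in Metadata by an outer middleware is visible to everything further down the chain. A sketch of a downstream middleware reading the request_id stored above (the constructor name is illustrative):

func WithRequestIDLogging(logger tools.Logger) tools.Middleware {
    return func(next tools.ToolCallFunc) tools.ToolCallFunc {
        return func(ctx context.Context, args json.RawMessage) (any, error) {
            if tc := tools.ToolContextFromContext(ctx); tc != nil {
                // Read the value an earlier middleware stored in Metadata
                if id, ok := tc.Metadata["request_id"].(string); ok {
                    logger.Printf("tool %s handling request %s", tc.ToolName, id)
                }
            }
            return next(ctx, args)
        }
    }
}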

Create your own middleware:

// Simple middleware that adds request tracing
func WithTracing(tracer Tracer) tools.Middleware {
    return func(next tools.ToolCallFunc) tools.ToolCallFunc {
        return func(ctx context.Context, args json.RawMessage) (any, error) {
            tc := tools.ToolContextFromContext(ctx)
            toolName := "unknown"
            if tc != nil {
                toolName = tc.ToolName
            }
            span := tracer.StartSpan("tool.call", toolName)
            defer span.End()

            result, err := next(ctx, args)
            if err != nil {
                span.SetError(err)
            }
            return result, err
        }
    }
}
// Middleware that sanitizes sensitive data in arguments
func WithArgumentSanitization(fields []string) tools.Middleware {
    return func(next tools.ToolCallFunc) tools.ToolCallFunc {
        return func(ctx context.Context, args json.RawMessage) (any, error) {
            var data map[string]any
            if err := json.Unmarshal(args, &data); err != nil {
                // Not a JSON object; pass the arguments through unchanged
                return next(ctx, args)
            }
            for _, field := range fields {
                if _, ok := data[field]; ok {
                    data[field] = "[REDACTED]"
                }
            }
            sanitized, err := json.Marshal(data)
            if err != nil {
                return next(ctx, args)
            }
            return next(ctx, sanitized)
        }
    }
}

Apply middleware to all tools in a registry:

registry := tools.NewRegistry()

// Apply middleware to all registered tools
registry.Use(
    tools.WithLogging(logger),
    tools.WithTimeout(30 * time.Second),
)

// These tools automatically get the middleware
registry.Register(searchTool)
registry.Register(lookupTool)
registry.Register(notifyTool)

// Override middleware for specific tools
registry.RegisterWithMiddleware(
    dangerousTool,
    tools.WithRateLimit(1.0), // Extra rate limiting
)

Complete example with recommended middleware stack:

package main

import (
    "context"
    "log"
    "time"

    "github.com/petal-labs/iris/core"
    "github.com/petal-labs/iris/providers/openai"
    "github.com/petal-labs/iris/tools"
)

func main() {
    provider, _ := openai.NewFromKeystore()
    client := core.NewClient(provider)

    // Create comprehensive middleware stack
    commonMiddleware := tools.Chain(
        // Observability (outermost)
        tools.WithLogging(log.Default()),
        tools.WithMetrics(metricsCollector),

        // Resilience
        tools.WithCircuitBreaker(tools.CircuitBreakerConfig{
            FailureThreshold: 5,
            SuccessThreshold: 2,
            OpenDuration:     30 * time.Second,
        }),

        // Resource management
        tools.WithTimeout(30 * time.Second),
        tools.WithRetry(tools.RetryConfig{
            MaxAttempts: 3,
            InitialWait: 100 * time.Millisecond,
            MaxWait:     5 * time.Second,
            Multiplier:  2.0,
        }),
    )

    // Create registry with middleware
    registry := tools.NewRegistry()
    registry.Use(commonMiddleware)

    // Apply caching only to read-only tools
    registry.Use(
        tools.ForTools(
            []string{"search_database", "lookup_user", "get_weather"},
            tools.WithCache(tools.NewMemoryCache(), 5 * time.Minute),
        ),
    )

    // Apply rate limiting only to external APIs
    registry.Use(
        tools.ForTools(
            []string{"search_api", "translate_text"},
            tools.WithRateLimit(10.0),
        ),
    )

    // Register tools
    registry.Register(searchDatabaseTool)
    registry.Register(lookupUserTool)
    registry.Register(searchApiTool)
    registry.Register(sendNotificationTool)

    // Use with chat builder
    resp, err := client.Chat("gpt-4o").
        System("You are a helpful assistant.").
        User("Find user john@example.com and search for related tickets").
        Tools(registry.List()...).
        GetResponse(context.Background())
    if err != nil {
        log.Fatal(err)
    }

    // Handle tool calls if present
    if len(resp.ToolCalls) > 0 {
        for _, tc := range resp.ToolCalls {
            result, err := registry.Execute(context.Background(), tc.Name, tc.Arguments)
            if err != nil {
                log.Printf("Tool error: %v", err)
                continue
            }
            log.Printf("Tool %s result: %v", tc.Name, result)
        }
    } else {
        log.Println(resp.Output)
    }
}

Practice              Recommendation
Middleware order      Logging → Metrics → Circuit → Rate → Timeout → Retry → Cache
Cache selectively     Only cache idempotent, read-only tools
Rate limit APIs       Apply rate limiting to external API calls
Circuit breakers      Use for tools calling unreliable external services
Timeout < iteration   Tool timeout should be less than agent iteration timeout
Retry judiciously     Only retry transient errors, not validation failures

Agent Tools Example

See middleware in action with AgentRunner. Agent Tools →

Memory Guide

Manage conversation history. Memory →