Ryanhub - file viewer
filename: assistant/server/api.go
branch: main
back to repo
package server

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"runtime"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"assistant/agent"
	"assistant/memory"
	"assistant/util"
)

// API holds the dependencies shared by the HTTP handlers in this file.
// Agent, Telemetry, and Store may be nil: the handlers check for nil and either
// report that the dependency is not configured or leave the related data zeroed.
type API struct {
	// Agent runs prompts and manual tools and owns the session history that the
	// /undo, /regenerate, /clear, and /compact slash commands operate on.
	Agent              *agent.Agent
	// Telemetry supplies the snapshot embedded in handleStatus responses.
	Telemetry          *agent.Telemetry
	// Store is the long-term memory store used by /clear-long, /compact-long,
	// and the memory/calendar figures in handleStatus.
	Store              *memory.Store
	// Model is reported verbatim by handleStatus.
	Model              string
	// ContextWindowChars is the denominator for the usage percentages reported
	// by handleStatus; a value <= 0 leaves those percentages at zero.
	ContextWindowChars int
	// ScratchpadPath is the absolute path to the shared Markdown file when scratchpad is enabled; empty otherwise.
	ScratchpadPath string
	// mu guards lastUndonePrompt.
	mu               sync.Mutex
	// lastUndonePrompt is the prompt removed by the most recent /undo.
	// /regenerate consumes (and clears) it; /clear resets it.
	lastUndonePrompt string
}

// askRequest is the JSON body accepted by handleAsk and handleAskStream.
type askRequest struct {
	// Prompt is either chat text for the agent or a slash command such as /help.
	Prompt string `json:"prompt"`
}

// askResponse is the JSON body written by handleAsk. handleAsk sets either
// Reply or Error, never both; Error is omitted from the JSON when empty.
type askResponse struct {
	Reply string `json:"reply"`
	Error string `json:"error,omitempty"`
}

// previewText shortens s to at most maxRunes runes, appending "…" when it had
// to cut. Strings that already fit, and non-positive limits, are returned
// unchanged. Cutting happens on rune boundaries, so multi-byte UTF-8 characters
// are never split.
func previewText(s string, maxRunes int) string {
	if maxRunes <= 0 || utf8.RuneCountInString(s) <= maxRunes {
		return s
	}
	// Advance over the first maxRunes runes to find the byte offset of the cut.
	// This avoids copying the whole string into a []rune. The loop cannot run
	// past the end, because s is known to hold more than maxRunes runes.
	cut := 0
	for i := 0; i < maxRunes; i++ {
		_, size := utf8.DecodeRuneInString(s[cut:])
		cut += size
	}
	return s[:cut] + "…"
}

// handleAsk serves POST requests whose JSON body is an askRequest and answers
// with a single askResponse once the slash command or agent run has finished.
//
// Status codes:
//   - 405 for non-POST requests;
//   - 400 for an undecodable body, a blank prompt, or a failed slash command;
//   - 500 when a normal (non-slash) prompt arrives but no agent is configured;
//   - 502 when the agent run fails;
//   - 200 otherwise.
//
// Each request gets its own 15-minute deadline on top of the client's context.
func (a *API) handleAsk(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var body askRequest
	if err := util.DecodeJSON(r.Body, 1<<20, &body); err != nil {
		_ = util.WriteJSON(w, http.StatusBadRequest, askResponse{Error: "invalid JSON body"})
		return
	}
	// Reject whitespace-only prompts as well. Slash-command parsing trims its
	// input, so such prompts would otherwise reach the model verbatim.
	trimmed := strings.TrimSpace(body.Prompt)
	if trimmed == "" {
		_ = util.WriteJSON(w, http.StatusBadRequest, askResponse{Error: "prompt required"})
		return
	}
	ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute)
	defer cancel()
	if ok, reply, regenPrompt, err := a.runSlashCommand(ctx, trimmed); ok {
		if err != nil {
			_ = util.WriteJSON(w, http.StatusBadRequest, askResponse{Error: err.Error()})
			return
		}
		// /regenerate hands back a prompt to rerun. The agent's answer replaces
		// the status message in reply.
		if regenPrompt != "" {
			reply, err = a.Agent.Run(ctx, regenPrompt)
			if err != nil {
				_ = util.WriteJSON(w, http.StatusBadGateway, askResponse{Error: err.Error()})
				return
			}
		}
		_ = util.WriteJSON(w, http.StatusOK, askResponse{Reply: reply})
		return
	}
	if a.Agent == nil {
		_ = util.WriteJSON(w, http.StatusInternalServerError, askResponse{Error: "agent not configured"})
		return
	}
	reply, err := a.Agent.Run(ctx, body.Prompt)
	if err != nil {
		_ = util.WriteJSON(w, http.StatusBadGateway, askResponse{Error: err.Error()})
		return
	}
	// A fresh, successful prompt supersedes any pending /undo. A later
	// /regenerate should then retry this turn, not the stale undone prompt.
	a.mu.Lock()
	a.lastUndonePrompt = ""
	a.mu.Unlock()
	_ = util.WriteJSON(w, http.StatusOK, askResponse{Reply: reply})
}

// handleAskStream is the Server-Sent Events variant of handleAsk. It accepts
// the same POST body and writes each agent.Event as a `data: <json>` frame,
// flushing after every frame.
//
// Method, flusher, and body validation failures are answered with ordinary
// JSON error responses. After validation, failures (including slash-command
// errors and a missing agent) are delivered as "error" events on the stream.
func (a *API) handleAskStream(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	flusher, ok := w.(http.Flusher)
	if !ok {
		_ = util.WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "streaming unsupported"})
		return
	}
	var body askRequest
	if err := util.DecodeJSON(r.Body, 1<<20, &body); err != nil {
		_ = util.WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON body"})
		return
	}
	// Reject whitespace-only prompts as well. Slash-command parsing trims its
	// input, so such prompts would otherwise reach the model verbatim.
	trimmed := strings.TrimSpace(body.Prompt)
	if trimmed == "" {
		_ = util.WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "prompt required"})
		return
	}

	w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	// Tell reverse proxies such as nginx not to buffer the stream.
	w.Header().Set("X-Accel-Buffering", "no")

	ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute)
	defer cancel()

	// send writes one SSE frame and flushes it immediately so the client sees
	// progress as it happens.
	send := func(ev agent.Event) error {
		b, err := json.Marshal(ev)
		if err != nil {
			return err
		}
		if _, err := fmt.Fprintf(w, "data: %s\n\n", b); err != nil {
			return err
		}
		flusher.Flush()
		return nil
	}
	// forward adapts send to the event callback. Write errors (for example a
	// disconnected client) are dropped here. net/http cancels r.Context() when
	// the client goes away, and that cancellation propagates to ctx.
	forward := func(ev agent.Event) {
		_ = send(ev)
	}

	if ok, reply, regenPrompt, err := a.runSlashCommand(ctx, trimmed); ok {
		if err != nil {
			_ = send(agent.Event{Type: "error", Message: err.Error()})
			return
		}
		if regenPrompt != "" {
			_, _ = a.Agent.RunWithEvents(ctx, regenPrompt, forward)
			return
		}
		_ = send(agent.Event{Type: "final", Text: reply})
		return
	}
	if a.Agent == nil {
		_ = send(agent.Event{Type: "error", Message: "agent not configured"})
		return
	}

	// NOTE(review): the run's error is not re-sent to the client here. This
	// presumably relies on RunWithEvents reporting failures through the event
	// callback — confirm.
	if _, err := a.Agent.RunWithEvents(ctx, body.Prompt, forward); err != nil {
		return
	}
	// A fresh, successful prompt supersedes any pending /undo. A later
	// /regenerate should then retry this turn, not the stale undone prompt.
	a.mu.Lock()
	a.lastUndonePrompt = ""
	a.mu.Unlock()
}

// tryParseManualTool recognizes "/tool <name> [arg=value ...]" invocations.
// The "/tool" keyword is matched case-insensitively.
//
// It returns ("", "", nil) when prompt is not a /tool invocation. That covers
// "/tool-list" and any word that merely starts with "/tool" (such as
// "/toolbox"), because the keyword must be followed by a space or tab.
// A bare "/tool" with no tool name yields a usage error. Anything else is
// handed to agent.ParseHumanToolInvocation.
func tryParseManualTool(prompt string) (name string, argsJSON string, err error) {
	const keyword = "/tool"
	p := strings.TrimSpace(prompt)
	if len(p) < len(keyword) || !strings.EqualFold(p[:len(keyword)], keyword) {
		return "", "", nil
	}
	rest := p[len(keyword):]
	if rest == "" {
		// Previously unreachable (the length guard required at least 6 bytes),
		// so a bare "/tool" surfaced as "unknown command" instead of this hint.
		return "", "", fmt.Errorf(`usage: /tool <tool_name> [arg=value ...]  see /tool-list for examples`)
	}
	// Require a separator so "/tool-list", "/toolbox", etc. are not tool calls.
	if rest[0] != ' ' && rest[0] != '\t' {
		return "", "", nil
	}
	return agent.ParseHumanToolInvocation(strings.TrimSpace(rest))
}

// runSlashCommand interprets prompt as a slash command.
//
// Surrounding whitespace is trimmed first. If the remaining text does not start
// with '/', handled is false and the caller should treat the prompt as ordinary
// chat. Otherwise handled is true and:
//   - err != nil means the command was malformed, unknown, or failed;
//   - regeneratePrompt != "" asks the caller to run that prompt through the
//     agent (reply then only carries a status message);
//   - otherwise reply is the text to show the user.
//
// ctx is forwarded to manually invoked tools (/tool).
func (a *API) runSlashCommand(ctx context.Context, prompt string) (handled bool, reply string, regeneratePrompt string, err error) {
	// Trim before looking for '/', so leading whitespace cannot hide a command.
	prompt = strings.TrimSpace(prompt)
	if !strings.HasPrefix(prompt, "/") {
		return false, "", "", nil
	}

	toolName, toolArgs, err := tryParseManualTool(prompt)
	if err != nil {
		return true, "", "", err
	}
	if toolName != "" {
		if a.Agent == nil {
			return true, "", "", fmt.Errorf("agent not configured")
		}
		out, runErr := a.Agent.RunToolManual(ctx, toolName, toolArgs)
		if runErr != nil {
			return true, "", "", runErr
		}
		return true, out, "", nil
	}

	// prompt is non-empty and starts with '/', so Fields yields at least one word.
	cmd := strings.ToLower(strings.Fields(prompt)[0])
	switch cmd {
	case "/help":
		return true, strings.Join([]string{
			"Slash commands:",
			"- /undo: remove the most recent prompt+reply from session history",
			"- /regenerate: rerun the most recent prompt after undoing its previous answer",
			"- /clear: clear on-screen chat and all retained session history",
			"- /compact: keep recent session history and drop older messages",
			"- /clear-long: clear long-term memory store",
			"- /compact-long: dedupe / compact long-term memory store",
			"- /tool-list: human-readable tools, args, and example calls",
			"- /tool <name> [arg=value ...]: run a tool (quote values with spaces; JSON still allowed)",
			"- scratchpad: GET/PUT /api/scratchpad (full file); model uses scratchpad_read and scratchpad_write (append-only: newline + content at end)",
		}, "\n"), "", nil
	case "/tool-list":
		if a.Agent == nil {
			return true, "", "", fmt.Errorf("agent not configured")
		}
		return true, a.Agent.ToolListText(), "", nil
	case "/undo":
		if a.Agent == nil {
			return true, "", "", fmt.Errorf("agent not configured")
		}
		undone, ok := a.Agent.UndoLastTurn()
		if !ok {
			return true, "No prior turn to undo.", "", nil
		}
		// Remember the undone prompt so a following /regenerate can retry it.
		a.mu.Lock()
		a.lastUndonePrompt = undone
		a.mu.Unlock()
		backTo := strings.TrimSpace(a.Agent.LastUserPrompt())
		if backTo == "" {
			return true, "Undid most recent turn. You are now at the start of session history. Run /regenerate to retry that prompt.", "", nil
		}
		return true, fmt.Sprintf("Undid most recent turn. Back to: \"%s\". Run /regenerate to retry that prompt.", previewText(backTo, 70)), "", nil
	case "/regenerate":
		if a.Agent == nil {
			return true, "", "", fmt.Errorf("agent not configured")
		}
		// Consume the pending undone prompt (if any) atomically.
		a.mu.Lock()
		pending := strings.TrimSpace(a.lastUndonePrompt)
		a.lastUndonePrompt = ""
		a.mu.Unlock()
		if pending != "" {
			return true, "Regenerating last undone prompt…", pending, nil
		}
		// If no explicit undo happened, regenerate the latest turn by undoing it first.
		latest, ok := a.Agent.UndoLastTurn()
		if !ok || strings.TrimSpace(latest) == "" {
			return true, "", "", fmt.Errorf("no prior prompt available to regenerate")
		}
		return true, "Regenerating most recent prompt…", latest, nil
	case "/clear", "/clear-session":
		if a.Agent == nil {
			return true, "", "", fmt.Errorf("agent not configured")
		}
		n := a.Agent.ClearHistory()
		a.mu.Lock()
		a.lastUndonePrompt = ""
		a.mu.Unlock()
		return true, fmt.Sprintf("Cleared session history (%d messages removed).", n), "", nil
	case "/compact", "/compact-session":
		if a.Agent == nil {
			return true, "", "", fmt.Errorf("agent not configured")
		}
		n := a.Agent.CompactHistory()
		return true, fmt.Sprintf("Compacted session history (%d old messages removed).", n), "", nil
	case "/clear-long", "/clear-longterm":
		if a.Store == nil {
			return true, "", "", fmt.Errorf("memory store is not configured")
		}
		removed, clearErr := a.Store.ClearMemories()
		if clearErr != nil {
			return true, "", "", clearErr
		}
		return true, fmt.Sprintf("Cleared long-term memory (%d entries removed).", removed), "", nil
	case "/compact-long", "/compact-longterm":
		if a.Store == nil {
			return true, "", "", fmt.Errorf("memory store is not configured")
		}
		removed, compactErr := a.Store.CompactMemories()
		if compactErr != nil {
			return true, "", "", compactErr
		}
		return true, fmt.Sprintf("Compacted long-term memory (%d duplicate entries removed).", removed), "", nil
	default:
		return true, "", "", fmt.Errorf("unknown command: %s (try /help)", cmd)
	}
}

// statusResponse is the JSON body written by handleStatus. It embeds the
// telemetry snapshot (zero-valued when API.Telemetry is nil) and adds the
// model name, Go heap figures, memory-store and session-history sizes, and
// context-window usage estimates.
type statusResponse struct {
	agent.StatusSnapshot
	Model string `json:"model"`

	// Memory holds Go runtime heap figures read via runtime.ReadMemStats.
	Memory struct {
		HeapAllocBytes uint64  `json:"heap_alloc_bytes"`
		HeapSysBytes   uint64  `json:"heap_sys_bytes"`
		HeapAllocMB    float64 `json:"heap_alloc_mb"`
		NumGC          uint64  `json:"num_gc"`
	} `json:"memory"`

	// MemoryStore describes the long-term memory store. Count/TotalChars
	// currently mirror LongCount/LongChars. UsagePct is TotalChars as a
	// percentage of ContextWindowChars, and stays 0 when the window is unset.
	MemoryStore struct {
		Count              int64   `json:"count"`
		TotalChars         int64   `json:"total_chars"`
		LongCount          int64   `json:"long_count"`
		LongChars          int64   `json:"long_chars"`
		ContextWindowChars int64   `json:"context_window_chars"`
		UsagePct           float64 `json:"usage_pct"`
	} `json:"memory_store"`

	// History reports the size of the agent's session history.
	History struct {
		Count      int64 `json:"count"`
		TotalChars int64 `json:"total_chars"`
	} `json:"history"`

	// ContextConsumption estimates context-window usage, where
	// UsedChars = History.TotalChars + MemoryStore.TotalChars.
	ContextConsumption struct {
		UsedChars          int64   `json:"used_chars"`
		ContextWindowChars int64   `json:"context_window_chars"`
		UsagePct           float64 `json:"usage_pct"`
	} `json:"context_consumption"`

	// CalendarRevision is read from the store; it stays 0 when unavailable.
	CalendarRevision int64 `json:"calendar_revision"`

	// Scratchpad.Enabled is true when API.ScratchpadPath is non-blank.
	Scratchpad struct {
		Enabled bool `json:"enabled"`
	} `json:"scratchpad"`
}

// handleStatus answers GET requests with a statusResponse. Each optional
// dependency (Telemetry, Agent, Store) contributes its figures only when it is
// configured. Store lookup errors are tolerated and leave those figures at
// zero. Usage percentages are only filled in when ContextWindowChars is
// positive.
func (a *API) handleStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	resp := statusResponse{Model: a.Model}
	if a.Telemetry != nil {
		resp.StatusSnapshot = a.Telemetry.Snapshot()
	}

	var heap runtime.MemStats
	runtime.ReadMemStats(&heap)
	const bytesPerMB = 1024 * 1024
	resp.Memory.HeapAllocBytes, resp.Memory.HeapSysBytes = heap.HeapAlloc, heap.HeapSys
	resp.Memory.HeapAllocMB = float64(heap.HeapAlloc) / bytesPerMB
	resp.Memory.NumGC = uint64(heap.NumGC)

	if ag := a.Agent; ag != nil {
		msgCount, msgChars := ag.HistoryStats()
		resp.History.Count, resp.History.TotalChars = int64(msgCount), int64(msgChars)
	}
	resp.Scratchpad.Enabled = strings.TrimSpace(a.ScratchpadPath) != ""

	if st := a.Store; st != nil {
		if entries, chars, statsErr := st.MemoryStats(); statsErr == nil {
			// The "long" fields and the generic ones intentionally carry the same numbers.
			resp.MemoryStore.LongCount, resp.MemoryStore.LongChars = entries, chars
			resp.MemoryStore.Count, resp.MemoryStore.TotalChars = entries, chars
		}
		if revision, revErr := st.CalendarRevision(); revErr == nil {
			resp.CalendarRevision = revision
		}
	}

	resp.ContextConsumption.UsedChars = resp.History.TotalChars + resp.MemoryStore.TotalChars
	if window := int64(a.ContextWindowChars); window > 0 {
		// percentOf leaves non-positive usage at exactly zero.
		percentOf := func(used int64) float64 {
			if used <= 0 {
				return 0
			}
			return float64(used) / float64(window) * 100
		}
		resp.MemoryStore.ContextWindowChars = window
		resp.ContextConsumption.ContextWindowChars = window
		resp.MemoryStore.UsagePct = percentOf(resp.MemoryStore.TotalChars)
		resp.ContextConsumption.UsagePct = percentOf(resp.ContextConsumption.UsedChars)
	}

	_ = util.WriteJSON(w, http.StatusOK, resp)
}

// handleNotFound is the fallback handler for unmatched routes. It always
// responds with status 404 and the payload {"error": "not found"}, written via
// util.WriteJSON. The request itself is not inspected.
func (a *API) handleNotFound(w http.ResponseWriter, _ *http.Request) {
	payload := map[string]string{"error": "not found"}
	_ = util.WriteJSON(w, http.StatusNotFound, payload)
}