# Ryanhub - file viewer
# filename: chat/main.py
# branch: main
# back to repo
from model import call_model
from parser import parse_router_reply
from tools import TOOLS
import json
import os
import logging
import sys


def _load_text(path, default=""):
    """Return the text of ``path``, preferring a same-named file under ``agent/``.

    Two candidates are tried in order: ``agent/<basename of path>`` and then
    ``path`` itself. Each is opened as UTF-8 text. Any exception raised while
    opening or reading a candidate is swallowed and the next candidate is
    tried; ``default`` is returned when neither could be read.
    """
    candidates = (os.path.join("agent", os.path.basename(path)), path)
    for candidate in candidates:
        try:
            with open(candidate, "r", encoding="utf-8") as handle:
                text = handle.read()
        except Exception:
            # Best effort: an unreadable candidate just means "try the next".
            continue
        else:
            return text
    return default


def _load_json(path, default=None):
    """Parse the UTF-8 JSON file at ``path`` and return the decoded value.

    Any exception (missing file, unreadable file, malformed JSON, ...) is
    swallowed and ``default`` is returned instead.
    """
    try:
        with open(path, "r", encoding="utf-8") as handle:
            payload = handle.read()
        decoded = json.loads(payload)
    except Exception:
        # Best effort: callers supply the value to use when loading fails.
        return default
    return decoded


# Configuration, prompts and tool metadata are loaded once at import time.
# The loaders return their fallback value when a file is missing or invalid,
# so the module still imports without any of these files present.
CFG = _load_json(os.path.join("agent", "config.json"), {})
if not isinstance(CFG, dict):
    # A JSON file whose root is not an object (list, string, null, ...) cannot
    # be queried with .get() below; treat it as "no configuration".
    CFG = {}
# Prompt files: _load_text tries agent/<name> first, then prompts/<name>.
ROUTER_PROMPT = _load_text(os.path.join("prompts", "router.txt"))
CHAT_PROMPT = _load_text(os.path.join("prompts", "chat.txt"))
TOOLS_META = _load_json(os.path.join("agent", "tools.json"), {})
if not isinstance(TOOLS_META, dict):
    # Same reasoning as CFG: later code calls TOOLS_META.get("tools", []).
    TOOLS_META = {}

MAX_CHAT_HISTORY = 8   # number of messages _add_chat keeps in CHAT_HISTORY
MAX_TASK_MEMORY = 20   # number of tool records _add_task_memory keeps
try:
    # Upper bound on the tool loop in main(); overridable via config.json.
    MAX_STEPS = int(CFG.get("MAX_STEPS", 12))
except (TypeError, ValueError):
    # A non-numeric config value must not prevent the CLI from starting.
    MAX_STEPS = 12
PREVIEW = 500          # not referenced in the visible part of this file

# Mutable module state shared by the helper functions below.
CHAT_HISTORY = []      # recent {"role", "content"} chat messages
TASK_MEMORY = []       # recent {"tool", "args", "result"} tool records
GOALS = []             # goal strings managed by the goal commands

LOG_LEVEL = os.environ.get("LOG_LEVEL", "WARNING").strip().upper()
if not isinstance(logging.getLevelName(LOG_LEVEL), int):
    # getLevelName() maps a registered level name to its number and returns a
    # "Level X" string otherwise. basicConfig() would raise ValueError on an
    # unknown name, so fall back instead of crashing at import.
    LOG_LEVEL = "WARNING"
logging.basicConfig(level=LOG_LEVEL, format="[%(levelname)s] %(message)s")
logger = logging.getLogger("local-agent")


def _add_chat(role, content):
    """Record one chat message and trim the history in place.

    ``content`` is converted to ``str`` and cut to its first 300 characters.
    After appending, ``CHAT_HISTORY`` is reduced (same list object) to its
    trailing ``MAX_CHAT_HISTORY`` entries.
    """
    message = {"role": role, "content": str(content)[:300]}
    CHAT_HISTORY.append(message)
    kept = CHAT_HISTORY[-MAX_CHAT_HISTORY:]
    CHAT_HISTORY.clear()
    CHAT_HISTORY.extend(kept)


def _add_task_memory(tool, args, result):
    """Record one tool invocation and trim the task memory in place.

    ``result`` is stored as a string limited to 4000 characters. After
    appending, ``TASK_MEMORY`` is reduced (same list object) to its trailing
    ``MAX_TASK_MEMORY`` entries.
    """
    record = {"tool": tool, "args": args, "result": str(result)[:4000]}
    TASK_MEMORY.append(record)
    kept = TASK_MEMORY[-MAX_TASK_MEMORY:]
    TASK_MEMORY.clear()
    TASK_MEMORY.extend(kept)


def _safe_text(text):
    """Return ``str(text)`` with characters stdout cannot encode replaced.

    The string is round-tripped through ``sys.stdout.encoding`` (UTF-8 when
    that is unset) using the ``replace`` error handler. If that round trip
    raises, the same round trip is done with UTF-8 instead.
    """
    rendered = str(text)
    target = sys.stdout.encoding
    if not target:
        target = "utf-8"
    try:
        encoded = rendered.encode(target, errors="replace")
        printable = encoded.decode(target, errors="replace")
    except Exception:
        fallback = rendered.encode("utf-8", errors="replace")
        printable = fallback.decode("utf-8", errors="replace")
    return printable


def _task_memory_summary():
    """Render the last eight task-memory records as a bulleted text block.

    Each line has the form ``- tool(args) -> result`` with the result cut to
    220 characters. Returns an empty string when there is no task memory.
    """
    lines = []
    for entry in TASK_MEMORY[-8:]:
        lines.append(f"- {entry['tool']}({entry['args']}) -> {entry['result'][:220]}")
    # Joining an empty list yields "", which covers the no-memory case.
    return "\n".join(lines)


def _goals_context():
    """Return a prompt fragment listing the five most recent goals.

    The fragment is a fixed header line followed by one ``- goal`` bullet per
    goal. Returns an empty string when no goals are set.
    """
    if GOALS:
        header = "Goals (work toward these; return chat when satisfied or need user input):\n"
        bullets = [f"- {goal}" for goal in GOALS[-5:]]
        return header + "\n".join(bullets)
    return ""


def _handle_goal_commands(user_input):
    """Interpret goal-management commands typed at the prompt.

    Commands are matched case-insensitively after stripping whitespace:

    - ``goal: <text>``: replace all goals with ``<text>``; the agent is then
      run with ``<text>`` as the request.
    - ``add goal: <text>``: append ``<text>`` to the goals; the agent is then
      run with the full stripped input as the request.
    - ``goals`` / ``show goals``: print the current goals.
    - ``clear goals``: remove all goals.
    - ``continue``: only when goals exist, ask the agent to take the next step.

    A ``goal:`` or ``add goal:`` command with no text is rejected with a usage
    hint and leaves the existing goals untouched, so empty goals are never
    stored and the agent is never run on an empty request.

    Returns:
        ``(effective_user_input, skip_run)``. ``skip_run`` is True when the
        command was fully handled here and the agent must not be run; in that
        case ``effective_user_input`` is None. Input that is not a goal
        command is returned stripped with ``skip_run`` False.
    """
    set_prefix = "goal:"
    add_prefix = "add goal:"
    empty_goal_hint = "Goal text is empty. Use 'goal: <text>' or 'add goal: <text>'.\n"

    raw = user_input.strip()
    lower = raw.lower()

    if lower.startswith(set_prefix):
        goal_text = raw[len(set_prefix):].strip()
        if not goal_text:
            print(_safe_text(empty_goal_hint))
            return None, True
        GOALS.clear()
        GOALS.append(goal_text)
        return goal_text, False

    if lower.startswith(add_prefix):
        goal_text = raw[len(add_prefix):].strip()
        if not goal_text:
            print(_safe_text(empty_goal_hint))
            return None, True
        GOALS.append(goal_text)
        return raw, False

    if lower in ("goals", "show goals"):
        if not GOALS:
            print(_safe_text("No goals set. Use 'goal: <text>' or 'add goal: <text>'.\n"))
        else:
            numbered = "\n".join(f"  {i}. {g}" for i, g in enumerate(GOALS, start=1))
            print(_safe_text("Current goals:\n" + numbered + "\n"))
        return None, True

    if lower == "clear goals":
        GOALS.clear()
        print(_safe_text("Goals cleared.\n"))
        return None, True

    if lower == "continue" and GOALS:
        return "Continue working toward your goals. Take the next step.", False

    # Not a goal command (or "continue" without goals): pass it through.
    return raw, False


def _tool_catalog():
    """Describe every tool listed in ``TOOLS_META["tools"]``, one per line.

    Each line shows the tool id, its description and its params as found in
    the metadata. For the tool ids known here, a short "Use for" hint is
    appended. Returns an empty string when no tools are listed.
    """
    hints = {
        "list_files": "Use for: listing which files exist (directory listing).",
        "read_file": "Use for: reading file CONTENTS. Requires path in arguments.",
        "write_file": "Use for: creating/overwriting a file. Requires path and content in arguments.",
        "run_command": "Use for: running a shell command. Requires command in arguments.",
        "research": "Use for: research queries. Requires query in arguments.",
    }

    def describe(tool):
        tool_id = tool.get("id", "")
        text = f"- {tool_id}: {tool.get('description')} params={tool.get('params')}"
        if tool_id in hints:
            text = text + f" {hints[tool_id]}"
        return text

    return "\n".join(describe(tool) for tool in TOOLS_META.get("tools", []))


def _call_router(user_text, task_summary, tool_error=None, last_route=None, goals_context=None):
    """Ask the router model which tool to use for ``user_text``.

    The conversation sent to the model is: the router prompt (prefixed by
    ``goals_context`` when given), the user message, then optional system
    messages carrying the task memory, the tool catalog and - when both
    ``tool_error`` and ``last_route`` are given - a description of the failed
    previous tool call.

    The model is called at most twice. A second call happens only when the
    first reply parses to a dict with ``error == "parse_error"``; a corrective
    system message is appended before retrying.

    Returns:
        Always a dict. On success, the parsed route (no truthy ``"error"``
        key). Otherwise an error dict with ``"error"`` and ``"message"``:
        ``model_error`` when ``call_model`` raises ``RuntimeError``, the
        parser's own error dict, or a ``parse_error`` dict when the parser
        returned something that is not a dict. The last case keeps callers
        that use ``route.get(...)`` from crashing.
    """
    system_prompt = ROUTER_PROMPT
    if goals_context:
        system_prompt = goals_context + "\n\n" + system_prompt
    msgs = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_text},
    ]

    if task_summary:
        msgs.append({"role": "system", "content": "Task memory:\n" + task_summary})
    if TOOLS_META:
        msgs.append({"role": "system", "content": "Available tools:\n" + _tool_catalog()})
    if tool_error and last_route:
        msgs.append({
            "role": "system",
            "content": (
                f"The last tool call failed: {tool_error}\n"
                f"You had returned: {json.dumps(last_route)}.\n"
                "Return a new JSON with the correct tool and required arguments filled from the user message (e.g. read_file needs {\"path\": \"...\"})."
            ),
        })

    max_attempts = 2
    for attempt in range(max_attempts):
        try:
            raw = call_model(msgs, task="router", temperature=0.0)
        except RuntimeError as e:
            return {"error": "model_error", "message": str(e)}
        logger.info("Router raw reply: %s", raw)
        parsed = parse_router_reply(raw)
        logger.info("Parsed route: %s", parsed)

        if not isinstance(parsed, dict):
            # Normalize so the caller can rely on the dict contract.
            return {
                "error": "parse_error",
                "message": f"router reply parsed to {type(parsed).__name__}, expected a JSON object",
                "raw": raw,
            }

        if not parsed.get("error"):
            return parsed

        is_last_attempt = attempt == max_attempts - 1
        if is_last_attempt or parsed.get("error") != "parse_error":
            # Only format problems are worth a retry; everything else is final.
            return parsed

        err_msg = parsed.get("message", "invalid format")
        raw_preview = parsed.get("raw", raw)
        if isinstance(raw_preview, dict):
            raw_preview = json.dumps(raw_preview)
        retry_instruction = (
            f"Your previous reply was invalid: {err_msg}\n"
            f"Your reply was: {raw_preview}\n"
            "You MUST respond with exactly one JSON object with top-level key \"tool\" (string, the tool id) "
            "and optionally \"arguments\" (object). Tool parameters (path, content, etc.) go INSIDE \"arguments\", "
            "not as top-level keys. Example: {\"tool\": \"write_file\", \"arguments\": {\"path\": \"safe/new.txt\", \"content\": \"hello\"}}.\n"
            "Try again with the same user request."
        )
        msgs.append({"role": "system", "content": retry_instruction})

    # Not reached (the final attempt always returns above); kept so the
    # function can never fall through and return None.
    return {"error": "parse_error", "message": "router did not produce a valid reply"}


def _dispatch_tool(action, arguments):
    """Invoke the tool registered under ``action`` with ``arguments`` as keywords.

    Returns the tool's return value. Failures are reported as a string that
    starts with ``[tool error]`` rather than raised: either the tool id is not
    in ``TOOLS`` or the call raised an exception.
    """
    logger.info("Dispatching tool '%s' with arguments: %s", action, arguments)
    if action in TOOLS:
        try:
            outcome = TOOLS[action](**arguments)
        except Exception as exc:
            outcome = f"[tool error] {exc}"
        return outcome
    return f"[tool error] unknown tool: {action}"


def _run_chat(user_input, task_summary, tool_context=None, goals_context=None):
    """Produce and print the chat model's reply to ``user_input``.

    The message list is assembled in this order: goals context, chat prompt,
    stored chat history, task memory, the ids of the available tools, tool
    output from the current turn, and finally the user message. Empty parts
    are left out.

    On success the reply is printed and both the user input and the reply are
    added to the chat history. If ``call_model`` raises ``RuntimeError`` the
    error text is printed and nothing is recorded.
    """
    def system(text):
        return {"role": "system", "content": text}

    conversation = [system(part) for part in (goals_context, CHAT_PROMPT) if part]
    conversation.extend(CHAT_HISTORY)
    if task_summary:
        conversation.append(system("Task memory:\n" + task_summary))
    if TOOLS_META:
        tool_ids = [tool.get("id") for tool in TOOLS_META.get("tools", [])]
        conversation.append(system("Available tools (routed externally): " + ", ".join(tool_ids)))
    if tool_context:
        conversation.append(system(tool_context))
    conversation.append({"role": "user", "content": user_input})

    try:
        reply = call_model(conversation, task="chat")
    except RuntimeError as exc:
        print(_safe_text(f"\n{exc}\n"))
        return

    # Keep the log line short; the full reply goes to stdout below.
    if len(reply) > 200:
        preview = reply[:200] + "..."
    else:
        preview = reply
    logger.info("Chat response: %s", preview)
    print(_safe_text("\n" + reply + "\n"))
    _add_chat("user", user_input)
    _add_chat("assistant", reply)


def _build_tool_context(turn_results):
    """Format this turn's tool calls as one text block for the chat model.

    ``turn_results`` is a list of dicts with ``tool``, ``args`` and ``result``
    keys. Each entry is rendered with its tool name, JSON-encoded arguments
    and the first 2000 characters of its stringified result; entries are
    separated by a ``---`` line. Returns None when the list is empty or None.
    """
    if not turn_results:
        return None

    def render(entry):
        tool_name = entry['tool']
        arguments = json.dumps(entry['args'])
        output = str(entry['result'])[:2000]
        return f"Tool called: {tool_name}\nArguments: {arguments}\nOutput:\n{output}"

    separator = "\n\n---\n\n"
    return separator.join(render(entry) for entry in turn_results)


def _run_agent_turn(user_input):
    """Handle one user request: route to tools until the router picks chat.

    Each loop iteration asks the router for the next action. The turn ends
    with a chat reply (built from this turn's tool output) when the router
    chooses ``chat``, names something that is not a registered tool, repeats
    the previous tool call verbatim, or when ``MAX_STEPS`` is used up. A
    router error is printed and ends the turn without a chat reply.

    A failing tool call is retried through the router up to two times with the
    error fed back. The third consecutive failure is recorded like a normal
    result and counted as a step, so a router that keeps proposing failing
    calls cannot keep the loop running forever.
    """
    goals_ctx = _goals_context() or None
    task_summary = _task_memory_summary()
    turn_tool_results = []
    steps = 0
    tool_error = None
    last_route = None
    tool_error_retries = 0

    while steps < MAX_STEPS:
        route = _call_router(
            user_input,
            task_summary,
            tool_error=tool_error,
            last_route=last_route,
            goals_context=goals_ctx,
        )
        # The error feedback applies to exactly one router call.
        tool_error = None
        last_route = None

        if not isinstance(route, dict):
            # Guard against a router result that cannot be queried with .get().
            print(_safe_text("[router error] router returned an unexpected reply"))
            return
        if route.get("error"):
            print(_safe_text(f"[router error] {route.get('message')}"))
            return

        action = route.get("tool")
        args = route.get("arguments", {})
        if not isinstance(args, dict):
            args = {}

        # "chat", a missing/non-string tool id, or an unknown tool all end the
        # tool loop and fall through to the chat reply below.
        if not isinstance(action, str) or action == "chat" or action not in TOOLS:
            break

        if turn_tool_results:
            last = turn_tool_results[-1]
            if last["tool"] == action and last.get("args") == args:
                # The router is repeating itself; answer with what we have.
                break

        result = _dispatch_tool(action, args)
        failed = isinstance(result, str) and result.startswith("[tool error]")

        if failed and tool_error_retries < 2:
            # Give the router another chance, telling it what went wrong.
            tool_error = result
            last_route = {"tool": action, "arguments": args}
            tool_error_retries += 1
            continue

        if failed:
            # Retries exhausted. `result` already carries the "[tool error]"
            # prefix, so print it as-is.
            print(_safe_text(f"\n{result}\n"))
        else:
            print(_safe_text(f"\n[tool:{action}]\n{result}\n"))

        tool_error_retries = 0
        _add_task_memory(action, args, result)
        turn_tool_results.append({"tool": action, "args": args, "result": result})
        # Recorded failures count as steps too; this bounds the loop.
        steps += 1
        task_summary = _task_memory_summary()

    _run_chat(
        user_input,
        task_summary,
        tool_context=_build_tool_context(turn_tool_results),
        goals_context=goals_ctx,
    )


def main():
    """Run the interactive prompt loop until the user exits.

    ``exit`` / ``quit`` (case-insensitive, surrounding whitespace ignored),
    end-of-input (Ctrl-D / closed stdin) and Ctrl-C at the prompt all end the
    session cleanly. Goal commands are handled first; blank input is ignored;
    every other line is processed as one agent turn.
    """
    print(_safe_text("Local Agent CLI (type 'exit' to quit)\n"))
    print(_safe_text("Goal mode: goal: <text> | add goal: <text> | goals | clear goals | continue\n"))

    while True:
        try:
            user_input = input("> ")
        except (EOFError, KeyboardInterrupt):
            # input() raises EOFError when stdin is exhausted; leave quietly
            # instead of dumping a traceback.
            print()
            break
        if user_input.strip().lower() in {"exit", "quit"}:
            break

        user_input, skip = _handle_goal_commands(user_input)
        if skip or user_input is None:
            continue
        if not user_input:
            # Nothing to route; avoid a pointless model call.
            continue

        _run_agent_turn(user_input)


if __name__ == "__main__":
    main()