# rag.py — Simple retriever based on FAISS + sentence-transformers
from __future__ import annotations
import os, json
from pathlib import Path
from typing import List, Dict, Any

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

BASE_DIR = Path(__file__).resolve().parent
INDEX_DIR = BASE_DIR / "indexes"

# Must be the same model name that was used during ingest
EMB_MODEL = os.getenv("EMB_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")

def _load_index(courseid: int):
    """Load the FAISS index and chunk metadata for a single course.

    Raises FileNotFoundError when either artifact is missing on disk.
    """
    index_file = INDEX_DIR / f"course_{courseid}.faiss"
    metadata_file = INDEX_DIR / f"course_{courseid}.meta.json"
    if not (index_file.exists() and metadata_file.exists()):
        raise FileNotFoundError(f"Índice do curso {courseid} não encontrado ({index_file})")

    faiss_index = faiss.read_index(str(index_file))
    metadata = json.loads(metadata_file.read_text(encoding="utf-8"))

    # Each entry in metadata["chunks"] carries "text", "source" and "url".
    return faiss_index, metadata.get("chunks", [])

class Retriever:
    """Semantic search over a per-course FAISS index.

    Embeds queries with the same sentence-transformers model that built
    the index and returns the best-matching chunks with their metadata.
    """

    def __init__(self, courseid: int):
        self.courseid = courseid
        # Must match the model used at ingest time, otherwise query
        # vectors live in a different embedding space than the index.
        self.model = SentenceTransformer(EMB_MODEL)
        self.index, self.chunks = _load_index(courseid)

    def _embed(self, texts: List[str]) -> np.ndarray:
        """Encode *texts* into a float32 matrix of normalized vectors."""
        vecs = self.model.encode(texts, normalize_embeddings=True)
        if isinstance(vecs, list):
            vecs = np.array(vecs, dtype="float32")
        # FAISS only accepts float32 input.
        return vecs.astype("float32")

    def search(self, query: str, k: int = 6) -> List[Dict[str, Any]]:
        """Return up to *k* chunks most similar to *query*.

        Each hit is a dict with "rank", "score", "text", "source", "url".
        Returns an empty list when the index has no chunks or k <= 0
        (FAISS rejects a search with k <= 0, so guard before calling it).
        """
        top_k = min(k, len(self.chunks))
        if top_k <= 0:
            return []

        qv = self._embed([query])
        D, I = self.index.search(qv, top_k)

        out: List[Dict[str, Any]] = []
        # Ranks follow result positions, including skipped padding slots.
        for rank, (idx, dist) in enumerate(zip(I[0].tolist(), D[0].tolist()), start=1):
            # FAISS pads with -1 when fewer than top_k neighbors exist.
            if idx < 0 or idx >= len(self.chunks):
                continue
            c = self.chunks[idx]
            out.append({
                "rank": rank,
                # NOTE(review): 1 - distance as a similarity proxy — its
                # meaning depends on the index metric (L2 vs inner
                # product), which is not visible here; confirm vs ingest.
                "score": float(1.0 - dist),
                "text": c.get("text", ""),
                "source": c.get("source", ""),
                "url": c.get("url", ""),
            })
        return out

def build_prompt(course_name: str, question: str, passages: List[Dict[str, Any]]) -> str:
    """Assemble the LLM prompt from course name, retrieved passages and question.

    Passages are numbered [1..n] with their source/url header so the
    model can ground its answer in the retrieved context only.
    """
    numbered = [
        f"[{n}] {p.get('source','')} {p.get('url','')}\n{p.get('text','')}\n"
        for n, p in enumerate(passages, 1)
    ]
    context = "\n".join(numbered).strip()

    return (
        f"Curso: {course_name}\n\n"
        f"Contexto (trechos relevantes):\n{context}\n\n"
        f"Pergunta: {question}\n\n"
        "Responda de forma objetiva, usando apenas o contexto acima. "
        "Se faltar informação, explique o que falta e não invente.\n"
    )
