"""GNN-based Medical RAG System - main application.

Entry point for the HuggingFace Spaces deployment.

Full pipeline:
    MedQuad Dataset
    → LlamaIndex PropertyGraphIndex
    → PyTorch Geometric graph
    → 2-layer GCN (message passing)
    → GNNHybridRetriever (α × semantic + (1-α) × structural)
    → Qwen3.5-4B GGUF via llama-cpp
    → Gradio UI
"""

# ══════════════════════════════════════════════════════════════════════════════
# BOOTSTRAP — must run BEFORE every other (non-stdlib) import
#
# Problem: no pre-built wheel from PyPI / the abetlen releases works on
#   HuggingFace Spaces, because they are all built against musl libc (Alpine),
#   while HF Spaces runs Debian (glibc).
#
# Solution: query the GitHub API for a manylinux wheel matching the wanted
#   version and Python tag, then install it directly. Fall back to building
#   from source when no wheel can be found.
# ══════════════════════════════════════════════════════════════════════════════

import importlib
import json
import os
import subprocess
import sys
import urllib.request


def _find_manylinux_wheel(py_tag: str, version: str = "0.3.19") -> str | None:
    """Look up a manylinux x86_64 wheel for ``py_tag`` via the GitHub API.

    The repo ``eswarthammana/llama-cpp-wheels`` (dedicated manylinux builds)
    is tried first, then ``abetlen/llama-cpp-python``. Only the 10 most
    recent releases of each repo are inspected.

    Args:
        py_tag: CPython tag such as ``"cp311"``.
        version: llama-cpp-python version that the release tag must contain
            (compared with and without dots).

    Returns:
        The download URL of the first matching wheel asset, or ``None`` when
        no repo offers one or the API could not be queried.
    """
    repos = (
        "eswarthammana/llama-cpp-wheels",
        "abetlen/llama-cpp-python",
    )
    abi_marker = f"{py_tag}-{py_tag}-manylinux"
    compact_version = version.replace(".", "")

    for repo in repos:
        api_url = f"https://api.github.com/repos/{repo}/releases"
        request = urllib.request.Request(
            api_url,
            headers={"User-Agent": "python-bootstrap/1.0"},
        )
        # Deliberately broad: the lookup is best-effort and any failure
        # (network, rate limit, malformed JSON) just moves on to the next repo.
        try:
            with urllib.request.urlopen(request, timeout=15) as resp:
                releases = json.loads(resp.read())
        except Exception as e:
            print(f"[bootstrap] Không query được {repo}: {e}")
            continue

        if not isinstance(releases, list):
            # The API answers with a JSON object (not a list) on errors.
            print(f"[bootstrap] Không query được {repo}: unexpected API payload")
            continue

        for release in releases[:10]:  # only the 10 most recent releases
            tag = release.get("tag_name", "")
            # Accept a release whose tag contains the version, with or
            # without dots.
            if version not in tag and compact_version not in tag.replace(".", ""):
                continue

            for asset in release.get("assets", []):
                name = asset.get("name", "")
                url = asset.get("browser_download_url", "")
                is_match = (
                    name.endswith(".whl")
                    and abi_marker in name
                    # The bootstrap targets x86_64 only; skip e.g. aarch64.
                    and "x86_64" in name
                )
                # An asset without a download URL is useless: keep searching
                # instead of returning an empty string.
                if is_match and url:
                    print(f"[bootstrap] Tìm thấy: {name} ({repo})")
                    return url

    return None


def _install_wheel(url: str) -> bool:
    """Install a wheel from ``url`` with pip.

    NOTE(security): the wheel is a binary downloaded from a third-party
    GitHub repository and is installed without any hash verification.

    Args:
        url: Direct download URL of the wheel.

    Returns:
        ``True`` when pip exits with status 0, ``False`` otherwise (the first
        400 characters of pip's stderr are printed in that case).
    """
    result = subprocess.run(
        [
            sys.executable, "-m", "pip", "install",
            "--quiet",
            "--force-reinstall",
            "--no-deps",
            url,
        ],
        capture_output=True,
        text=True,
    )
    if result.returncode == 0:
        return True
    print(f"[bootstrap] pip lỗi: {result.stderr.strip()[:400]}")
    return False


def _clear_llama_cache():
    """Forget every cached ``llama_cpp*`` module so it can be re-imported.

    Also invalidates the import finders' caches, which the importlib docs
    require after a package is installed while the program is running.
    """
    for mod in list(sys.modules):
        if mod.startswith("llama_cpp"):
            del sys.modules[mod]
    importlib.invalidate_caches()


def _llama_cpp_importable() -> bool:
    """Return ``True`` when ``llama_cpp`` can really be imported.

    A wheel can install cleanly and still be unusable (for example when its
    shared library was linked against another libc), so an actual import is
    the only reliable check.
    """
    _clear_llama_cache()
    try:
        from llama_cpp import Llama  # noqa: F401
    except (ImportError, OSError, RuntimeError) as e:
        # RuntimeError is included because llama-cpp-python is believed to
        # wrap shared-library load failures in it — TODO confirm against the
        # installed version.
        print(f"[bootstrap] llama_cpp import failed: {e}")
        # Drop any partially initialised modules left by the failed import.
        _clear_llama_cache()
        return False
    return True


def _ensure_llama_cpp():
    """Make sure llama-cpp-python works on Debian/glibc (HF Spaces).

    Strategy:
        1. If it already imports → done.
        2. Query the GitHub API for a manylinux wheel → install it.
        3. Try hard-coded fallback URLs for commonly available versions.
        4. Build from source (last resort).

    After every install attempt the import is verified; an install that
    succeeds but does not import is treated as a failure and the next
    strategy is tried.

    Raises:
        RuntimeError: when none of the strategies yields a working install.
    """
    # ── Step 1: try to import ────────────────────────────────────────────────
    if _llama_cpp_importable():
        print("[bootstrap] ✓ llama-cpp-python sẵn sàng (không cần cài lại).")
        return

    v = sys.version_info
    py_tag = f"cp{v.major}{v.minor}"
    print(f"[bootstrap] llama-cpp-python chưa OK. Bắt đầu cài ({py_tag})...")

    # ── Step 2: query the GitHub API for a manylinux wheel ───────────────────
    print("[bootstrap] Đang tìm manylinux wheel qua GitHub API...")
    url = _find_manylinux_wheel(py_tag, version="0.3.19")
    if url and _install_wheel(url) and _llama_cpp_importable():
        print("[bootstrap] ✓ Cài thành công (GitHub API).")
        return

    # ── Step 3: try hard-coded fallback URLs ─────────────────────────────────
    # Versions believed to have a manylinux wheel from eswarthammana.
    fallback_versions = ["0.3.14", "0.3.10", "0.3.4"]
    fallback_repos = [
        "eswarthammana/llama-cpp-wheels",
        "mrzeeshanahmed/llama-cpp-python",  # community manylinux repo
    ]
    for ver in fallback_versions:
        for repo_slug in fallback_repos:
            fallback_url = (
                f"https://github.com/{repo_slug}/releases/download/"
                f"v{ver}/"
                f"llama_cpp_python-{ver}-{py_tag}-{py_tag}-manylinux_x86_64.whl"
            )
            print(f"[bootstrap] Thử: v{ver} từ {repo_slug}...")
            if _install_wheel(fallback_url) and _llama_cpp_importable():
                print(f"[bootstrap] ✓ Cài thành công (fallback v{ver}).")
                return

    # ── Step 4: build from source ────────────────────────────────────────────
    print("[bootstrap] Tất cả pre-built wheel thất bại.")
    print("[bootstrap] Build từ source (có thể mất 5-10 phút)...")
    env = {
        **os.environ,
        "CMAKE_ARGS": "-DGGML_BLAS=OFF -DGGML_OPENMP=OFF",
        "FORCE_CMAKE": "1",
    }
    fatal_message = (
        "[bootstrap] FATAL: Không thể cài llama-cpp-python bằng bất kỳ cách nào. "
        "Kiểm tra log phía trên để biết chi tiết."
    )
    try:
        subprocess.run(
            [
                sys.executable, "-m", "pip", "install",
                "--quiet",
                "llama-cpp-python==0.3.19",
                "--no-binary", "llama-cpp-python",
            ],
            check=True,
            env=env,
        )
    except subprocess.CalledProcessError as e:
        raise RuntimeError(fatal_message) from e

    if not _llama_cpp_importable():
        raise RuntimeError(fatal_message)
    print("[bootstrap] ✓ Build từ source thành công.")


# ─── RUN THE BOOTSTRAP ───────────────────────────────────────────────────────
_ensure_llama_cpp()

# ══════════════════════════════════════════════════════════════════════════════
# MAIN IMPORTS (after the bootstrap has finished)
# ══════════════════════════════════════════════════════════════════════════════

import pickle
import gradio as gr
import torch
from typing import Tuple
from huggingface_hub import hf_hub_download, list_repo_files

from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.schema import QueryBundle

from utils.graph_builder import (
    load_healthcare_documents,
    build_llamaindex_property_graph,
    build_pyg_graph,
)
from utils.gnn_model import get_structural_embeddings
from utils.retriever import GNNHybridRetriever

# ══════════════════════════════════════════════════════════════════════════════
# CONSTANTS
# ══════════════════════════════════════════════════════════════════════════════

EMBED_MODEL_NAME = "BAAI/bge-small-en-v1.5"
NUM_DOCS = 250
CACHE_PATH = "rag_cache.pkl"
GNN_HIDDEN_DIM = 256
DEFAULT_ALPHA = 0.6
DEFAULT_TOP_K = 5
MODEL_REPO = "Jackrong/Qwen3.5-4B-Neo-GGUF"
# Alternative model repo:
# MODEL_REPO = "Jackrong/Qwen3.5-4B-Claude-4.6-Opus-Reasoning-Distilledv2-GGUF"

# File name used when the repo's file list cannot be retrieved.
DEFAULT_MODEL_FILENAME = "qwen3.5-4b-neo-Q4_K_M.gguf"

# Context window given to llama-cpp and the number of tokens reserved for the
# generated answer; the retrieved passages must fit in what is left.
LLM_N_CTX = 2048
LLM_MAX_NEW_TOKENS = 400
# Slack for tokenisation not being exactly additive across concatenated text.
_PROMPT_SAFETY_MARGIN = 32
# Tokens charged per passage for the "\n\n" separator.
_SEPARATOR_TOKENS = 2
# A clipped passage shorter than this is dropped instead of included.
_MIN_PASSAGE_TOKENS = 32


def _load_cached_graph(path: str):
    """Load ``(pyg_data, nodes, raw_embs)`` from the pickle cache at ``path``.

    NOTE(security): ``pickle.load`` executes arbitrary code from the file. It
    is only acceptable here because the cache is written by this application;
    never point ``path`` at a file from an untrusted source.

    Returns:
        The 3-tuple on success, or ``None`` when the file is unreadable,
        truncated, or does not have the expected keys (the caller rebuilds
        the graph in that case).
    """
    try:
        with open(path, "rb") as f:
            cache = pickle.load(f)
        return cache["pyg_data"], cache["nodes"], cache["raw_embs"]
    except (
        OSError,
        EOFError,
        pickle.UnpicklingError,
        KeyError,
        TypeError,
        AttributeError,
        ImportError,
    ) as e:
        print(f"      Cache unusable ({e}); rebuilding the graph...")
        return None


def _save_graph_cache(path: str, payload: dict) -> None:
    """Write ``payload`` to ``path`` atomically; a failure is only logged.

    The data goes to a temporary sibling file first and is then moved into
    place, so an interrupted write never leaves a truncated cache behind.
    The cache is an optimisation, so failing to write it is not fatal.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(payload, f)
        os.replace(tmp_path, path)
    except (OSError, pickle.PicklingError) as e:
        print(f"      Could not write graph cache ({e}); continuing without it.")
    else:
        print(" ✓ Graph cached to disk")


# ══════════════════════════════════════════════════════════════════════════════
# INITIALIZATION
# ══════════════════════════════════════════════════════════════════════════════

print("\n" + "=" * 60)
print(" Initializing GNN-based Medical RAG System")
print("=" * 60)

# ── [1/5] Embedding Model ────────────────────────────────────────────────────
print("\n[1/5] Loading embedding model...")
embed_model = HuggingFaceEmbedding(
    model_name=EMBED_MODEL_NAME,
    max_length=512,
)
print(f" ✓ {EMBED_MODEL_NAME} (384-dim)")

# ── [2/5] Build / Load Graph ─────────────────────────────────────────────────
_cached_graph = None
if os.path.exists(CACHE_PATH):
    print(f"\n[2/5] Loading cached graph từ {CACHE_PATH}...")
    _cached_graph = _load_cached_graph(CACHE_PATH)

if _cached_graph is not None:
    pyg_data, nodes, raw_embs = _cached_graph
    print(f" ✓ {len(nodes)} nodes, {pyg_data.edge_index.shape[1]} edges")
else:
    print(f"\n[2/5] Building knowledge graph từ MedQuad ({NUM_DOCS} docs)...")
    documents = load_healthcare_documents(num_samples=NUM_DOCS)
    index, text_nodes = build_llamaindex_property_graph(documents, embed_model)
    pyg_data, nodes, raw_embs = build_pyg_graph(
        text_nodes,
        embed_model,
        similarity_threshold=0.70,
    )
    _save_graph_cache(
        CACHE_PATH,
        {"pyg_data": pyg_data, "nodes": nodes, "raw_embs": raw_embs},
    )

# ── [3/5] GNN Message Passing ────────────────────────────────────────────────
print("\n[3/5] Running GNN message passing...")
in_dim = pyg_data.x.shape[1]
struct_embs = get_structural_embeddings(
    pyg_data,
    in_channels=in_dim,
    hidden_channels=GNN_HIDDEN_DIM,
)

# ── [4/5] Hybrid Retriever ───────────────────────────────────────────────────
print("\n[4/5] Initializing GNNHybridRetriever...")
retriever = GNNHybridRetriever(
    nodes=nodes,
    raw_embeddings=raw_embs,
    structural_embeddings=struct_embs,
    embed_model=embed_model,
    alpha=DEFAULT_ALPHA,
    top_k=DEFAULT_TOP_K,
)
print(f" ✓ Retriever ready (α={DEFAULT_ALPHA}, top_k={DEFAULT_TOP_K})")

# ── [5/5] Load GGUF LLM ──────────────────────────────────────────────────────
print(f"\n[5/5] Loading GGUF model từ {MODEL_REPO}...")

# Best-effort discovery of the model file: any failure (network error, empty
# repo) falls back to the default file name.
try:
    all_files = list(list_repo_files(MODEL_REPO))
    gguf_files = [f for f in all_files if f.endswith(".gguf")]
    if not gguf_files:
        raise FileNotFoundError(f"no .gguf files in {MODEL_REPO}")
    # Prefer the Q4_K_M quantisation when the repo offers it.
    preferred = [f for f in gguf_files if "Q4_K_M" in f.upper()]
    model_filename = (preferred or gguf_files)[0]
    print(f" Downloading: {model_filename}")
except Exception as e:
    print(f" Không list được files ({e}), dùng tên mặc định...")
    model_filename = DEFAULT_MODEL_FILENAME

model_path = hf_hub_download(
    repo_id=MODEL_REPO,
    filename=model_filename,
    local_dir="./models",
)

# Lazy import — only after the bootstrap has made sure a working wheel is
# installed.
from llama_cpp import Llama  # noqa: E402

llm = Llama(
    model_path=model_path,
    n_ctx=LLM_N_CTX,
    n_threads=os.cpu_count() or 4,
    n_gpu_layers=0,
    verbose=False,
)
print(f" ✓ LLM loaded: {model_filename}")

print("\n" + "=" * 60)
print(" System ready!")
print("=" * 60 + "\n")


# ══════════════════════════════════════════════════════════════════════════════
# RAG PIPELINE
# ══════════════════════════════════════════════════════════════════════════════

def _build_prompt(context: str, question: str) -> str:
    """Assemble the completion prompt from the context and the question."""
    return (
        "You are a knowledgeable and accurate medical assistant. "
        "Answer the question using ONLY the provided context. "
        "Be concise, clear, and factual. "
        "If the context does not contain enough information, "
        "say 'I don't have enough information to answer this accurately.'\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n\n"
        "Answer:"
    )


def _fit_passages(passages: list, budget: int) -> list:
    """Keep the leading passages that fit into ``budget`` tokens.

    Passages are expected best-first. Whole passages are kept while they fit;
    the first one that does not fit is clipped to the remaining budget (or
    dropped when too little room is left) and everything after it is dropped.
    The result is therefore always a prefix of ``passages``.
    """
    kept = []
    remaining = budget
    for passage in passages:
        tokens = llm.tokenize(passage.encode("utf-8"), add_bos=False)
        cost = len(tokens) + _SEPARATOR_TOKENS
        if cost <= remaining:
            kept.append(passage)
            remaining -= cost
            continue
        room = remaining - _SEPARATOR_TOKENS
        if room >= _MIN_PASSAGE_TOKENS:
            clipped = llm.detokenize(tokens[:room])
            # Clipping may cut a multi-byte character in half.
            kept.append(clipped.decode("utf-8", errors="ignore"))
        break
    return kept


def rag_pipeline(question: str, alpha: float, top_k: int) -> Tuple[str, str]:
    """Run the full RAG pipeline: retrieve → build context → generate.

    Args:
        question: The user's question.
        alpha: Semantic weight for the hybrid retriever (written to the
            shared ``retriever``).
        top_k: Number of passages to retrieve (written to the shared
            ``retriever``).

    Returns:
        ``(answer, sources)``. ``sources`` has one line per passage that was
        actually given to the model; passages that did not fit into the
        context window are not listed. If generation fails, ``answer``
        carries the error message instead of raising.
    """
    if not question or not question.strip():
        return "Vui lòng nhập câu hỏi.", ""

    retriever.alpha = float(alpha)
    retriever.top_k = int(top_k)

    query = QueryBundle(query_str=question)
    results = retriever._retrieve(query)

    passages, source_lines = [], []
    for rank, hit in enumerate(results, start=1):
        passages.append(f"[{rank}] {hit.node.text}")
        # Metadata may be missing or not a string; normalise before slicing.
        related_q = str(hit.node.metadata.get("question") or "N/A")
        source_lines.append(
            f"{rank}. Hybrid Score={hit.score:.4f} | "
            f"Related Q: {related_q[:90]}"
        )

    used = len(source_lines)
    try:
        # With top_k up to 10 the raw passages can exceed the context window;
        # trim them so the prompt plus the answer fits into LLM_N_CTX.
        scaffold = _build_prompt("", question)
        scaffold_tokens = len(llm.tokenize(scaffold.encode("utf-8")))
        budget = (
            LLM_N_CTX
            - LLM_MAX_NEW_TOKENS
            - scaffold_tokens
            - _PROMPT_SAFETY_MARGIN
        )
        context_parts = _fit_passages(passages, budget)
        used = len(context_parts)

        prompt = _build_prompt("\n\n".join(context_parts), question)
        output = llm(
            prompt,
            max_tokens=LLM_MAX_NEW_TOKENS,
            temperature=0.1,
            top_p=0.9,
            repeat_penalty=1.1,
            stop=["Question:", "\n\nContext:", "\n\n\n"],
            echo=False,
        )
        answer = output["choices"][0]["text"].strip() or "Không thể tạo câu trả lời."
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        answer = f"Lỗi khi generate: {str(e)}"

    sources = "\n".join(source_lines[:used])
    return answer, sources


def get_score_analysis(question: str, alpha: float) -> str:
    """Return a per-node breakdown of the hybrid scores for ``question``.

    Args:
        question: The user's question.
        alpha: Semantic weight for the hybrid retriever (written to the
            shared ``retriever``).

    Returns:
        A multi-line text listing hybrid, semantic and structural scores and
        a text preview for the top 5 nodes.
    """
    if not question or not question.strip():
        return "Nhập câu hỏi để phân tích."

    retriever.alpha = float(alpha)
    breakdown = retriever.get_score_breakdown(question, top_k=5)

    lines = ["📊 Score Breakdown (top 5 nodes):\n"]
    for item in breakdown:
        lines.append(
            f"Rank {item['rank']}: Hybrid={item['hybrid_score']:.4f} "
            f"| Semantic={item['semantic_score']:.4f} "
            f"| Structural={item['structural_score']:.4f}\n"
            f" Preview: {item['text_preview'][:100]}\n"
        )
    return "\n".join(lines)


# ══════════════════════════════════════════════════════════════════════════════
# GRADIO UI
# ══════════════════════════════════════════════════════════════════════════════

with gr.Blocks(title="GNN-based Medical RAG", theme=gr.themes.Soft()) as demo:

    gr.Markdown("""
    # 🏥 GNN-based Medical RAG System
    Hệ thống hỏi đáp y tế sử dụng **Graph Neural Network-enhanced RAG**.

    | Component | Chi tiết |
    |-----------|----------|
    | **Dataset** | MedQuad Medical QA (HuggingFace) |
    | **Embeddings** | BAAI/bge-small-en-v1.5 (384-dim) |
    | **Graph** | Document cosine similarity graph |
    | **GNN** | 2-layer GCN – message passing enrichment |
    | **Retrieval** | α × semantic + (1-α) × structural |
    | **Generator** | Qwen3.5-4B-Neo GGUF via llama-cpp |
    """)

    with gr.Row():
        with gr.Column(scale=2):
            question_input = gr.Textbox(
                label="Medical Question",
                placeholder="e.g. What are the symptoms of Type 2 diabetes?",
                lines=3,
            )
            with gr.Accordion("⚙️ Retrieval Settings", open=False):
                alpha_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=DEFAULT_ALPHA,
                    step=0.05,
                    label="α — Semantic weight (1-α = Structural weight)",
                    info="Tăng α → semantic | Giảm α → GNN structural",
                )
                topk_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=DEFAULT_TOP_K,
                    step=1,
                    label="Top-K retrieved contexts",
                )
            with gr.Row():
                submit_btn = gr.Button("🔍 Get Answer", variant="primary")
                analyze_btn = gr.Button("📊 Analyze Scores")

        with gr.Column(scale=3):
            answer_output = gr.Textbox(label="Answer", lines=8)
            sources_output = gr.Textbox(label="Retrieved Sources & Scores", lines=6)
            analysis_output = gr.Textbox(label="Score Analysis (optional)", lines=8)

    gr.Examples(
        examples=[
            ["What are the symptoms of Type 2 diabetes?"],
            ["How is hypertension treated?"],
            ["What causes Alzheimer's disease?"],
            ["What is the recommended treatment for asthma?"],
            ["How does the immune system respond to infection?"],
            ["What are the risk factors for heart disease?"],
        ],
        inputs=question_input,
        label="Example Questions",
    )

    submit_btn.click(
        fn=rag_pipeline,
        inputs=[question_input, alpha_slider, topk_slider],
        outputs=[answer_output, sources_output],
    )
    analyze_btn.click(
        fn=get_score_analysis,
        inputs=[question_input, alpha_slider],
        outputs=analysis_output,
    )


if __name__ == "__main__":
    demo.launch()