Spaces:
Sleeping
Sleeping
| """ | |
| GNN-based Medical RAG System - Main Application | |
| Entry point cho HuggingFace Spaces deployment. | |
| Pipeline đầy đủ: | |
| MedQuad Dataset | |
| → LlamaIndex PropertyGraphIndex | |
| → PyTorch Geometric graph | |
| → 2-layer GCN (message passing) | |
| → GNNHybridRetriever (α × semantic + (1-α) × structural) | |
| → Qwen3.5-4B GGUF via llama-cpp | |
| → Gradio UI | |
| """ | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # BOOTSTRAP — phải chạy TRƯỚC mọi import khác | |
| # | |
| # Vấn đề: KHÔNG có pre-built wheel nào từ PyPI / abetlen/releases hoạt động | |
| # trên HuggingFace Spaces vì tất cả đều dùng musl libc (Alpine build), | |
| # trong khi HF Spaces chạy Debian (glibc). | |
| # | |
| # Giải pháp: Query GitHub API để tìm manylinux wheel đúng version + python tag, | |
| # sau đó cài trực tiếp. Fallback về build từ source nếu không tìm được. | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| import subprocess | |
| import sys | |
| import json | |
| import urllib.request | |
| def _find_manylinux_wheel(py_tag: str, version: str = "0.3.19") -> str | None: | |
| """ | |
| Query GitHub API để tìm manylinux_x86_64 wheel cho đúng python tag. | |
| Thử repo eswarthammana/llama-cpp-wheels trước (chuyên build manylinux), | |
| sau đó fallback về abetlen/llama-cpp-python. | |
| """ | |
| repos = [ | |
| f"eswarthammana/llama-cpp-wheels", | |
| f"abetlen/llama-cpp-python", | |
| ] | |
| target_suffix = f"{py_tag}-{py_tag}-manylinux" | |
| for repo in repos: | |
| try: | |
| api_url = f"https://api.github.com/repos/{repo}/releases" | |
| req = urllib.request.Request( | |
| api_url, | |
| headers={"User-Agent": "python-bootstrap/1.0"}, | |
| ) | |
| with urllib.request.urlopen(req, timeout=15) as resp: | |
| releases = json.loads(resp.read()) | |
| for release in releases[:10]: # check 10 releases gần nhất | |
| tag = release.get("tag_name", "") | |
| # Tìm release có version khớp (hoặc gần nhất) | |
| if version not in tag and version.replace(".", "") not in tag.replace(".", ""): | |
| continue | |
| assets = release.get("assets", []) | |
| for asset in assets: | |
| name = asset.get("name", "") | |
| if target_suffix in name and name.endswith(".whl"): | |
| url = asset.get("browser_download_url", "") | |
| print(f"[bootstrap] Tìm thấy: {name} ({repo})") | |
| return url | |
| except Exception as e: | |
| print(f"[bootstrap] Không query được {repo}: {e}") | |
| continue | |
| return None | |
def _install_wheel(url: str) -> bool:
    """Install the wheel at ``url`` with pip.

    pip is run through the current interpreter with ``--quiet``,
    ``--force-reinstall`` and ``--no-deps``. Returns ``True`` when pip exits
    with status 0; otherwise prints the first 400 characters of pip's stderr
    and returns ``False``.
    """
    pip_cmd = [sys.executable, "-m", "pip", "install"]
    pip_cmd += ["--quiet", "--force-reinstall", "--no-deps", url]

    proc = subprocess.run(pip_cmd, capture_output=True, text=True)
    if proc.returncode != 0:
        print(f"[bootstrap] pip lỗi: {proc.stderr.strip()[:400]}")
        return False
    return True
def _clear_llama_cache():
    """Remove every ``sys.modules`` entry whose name starts with ``llama_cpp``.

    This lets a later ``import llama_cpp`` load the freshly installed package
    instead of a previously cached module object.
    """
    # Snapshot the keys first: the dict is mutated while entries are removed.
    stale = [name for name in tuple(sys.modules) if name.startswith("llama_cpp")]
    for name in stale:
        del sys.modules[name]
def _llama_cpp_importable(report: bool = True) -> bool:
    """Return True when ``from llama_cpp import Llama`` works in this process.

    Args:
        report: When True, print the import error before returning False.
    """
    import importlib

    # Packages installed while the interpreter is running may not be seen by
    # the import system's cached finders until the caches are invalidated.
    importlib.invalidate_caches()
    try:
        from llama_cpp import Llama  # noqa: F401
    except (ImportError, OSError) as exc:
        if report:
            print(f"[bootstrap] Đã cài nhưng import vẫn lỗi: {exc}")
        # Drop any half-initialised modules so the next attempt starts clean.
        _clear_llama_cache()
        return False
    return True


def _install_and_verify(url: str) -> bool:
    """Install the wheel at ``url`` and confirm llama_cpp can then be imported."""
    if not _install_wheel(url):
        return False
    _clear_llama_cache()
    return _llama_cpp_importable()


def _ensure_llama_cpp():
    """Make sure llama-cpp-python is importable, installing it if necessary.

    Strategy, stopping at the first step that yields a working import:
        1. Import the already-installed package.
        2. Ask the GitHub API for a manylinux wheel and install it.
        3. Try hard-coded fallback wheel URLs for a few known versions.
        4. Build from source with pip (last resort).

    Unlike a bare "pip exited with 0" check, steps 2-4 only count as
    successful when the import actually works afterwards, so a wheel that
    installs but cannot be loaded does not end the fallback chain early.

    Raises:
        RuntimeError: when the source build fails or its result still cannot
            be imported.
    """
    # Local import: the module-level ``import os`` is executed only after this
    # bootstrap function has been called.
    import os

    version = "0.3.19"
    fatal_message = (
        "[bootstrap] FATAL: Không thể cài llama-cpp-python bằng bất kỳ cách nào. "
        "Kiểm tra log phía trên để biết chi tiết."
    )

    # ── Step 1: try the existing installation ────────────────────────────────
    if _llama_cpp_importable(report=False):
        print("[bootstrap] ✓ llama-cpp-python sẵn sàng (không cần cài lại).")
        return

    v = sys.version_info
    py_tag = f"cp{v.major}{v.minor}"
    print(f"[bootstrap] llama-cpp-python chưa OK. Bắt đầu cài ({py_tag})...")

    # ── Step 2: look for a manylinux wheel through the GitHub API ────────────
    print("[bootstrap] Đang tìm manylinux wheel qua GitHub API...")
    url = _find_manylinux_wheel(py_tag, version=version)
    if url and _install_and_verify(url):
        print("[bootstrap] ✓ Cài thành công (GitHub API).")
        return

    # ── Step 3: hard-coded fallback URLs ─────────────────────────────────────
    fallback_versions = ["0.3.14", "0.3.10", "0.3.4"]
    fallback_repos = [
        "eswarthammana/llama-cpp-wheels",
        "mrzeeshanahmed/llama-cpp-python",  # community manylinux repo
    ]
    for ver in fallback_versions:
        for repo_slug in fallback_repos:
            fallback_url = (
                f"https://github.com/{repo_slug}/releases/download/"
                f"v{ver}/"
                f"llama_cpp_python-{ver}-{py_tag}-{py_tag}-manylinux_x86_64.whl"
            )
            print(f"[bootstrap] Thử: v{ver} từ {repo_slug}...")
            if _install_and_verify(fallback_url):
                print(f"[bootstrap] ✓ Cài thành công (fallback v{ver}).")
                return

    # ── Step 4: build from source ────────────────────────────────────────────
    print("[bootstrap] Tất cả pre-built wheel thất bại.")
    print("[bootstrap] Build từ source (có thể mất 5-10 phút)...")
    env = {
        **os.environ,
        "CMAKE_ARGS": "-DGGML_BLAS=OFF -DGGML_OPENMP=OFF",
        "FORCE_CMAKE": "1",
    }
    try:
        subprocess.run(
            [
                sys.executable, "-m", "pip", "install",
                "--quiet",
                f"llama-cpp-python=={version}",
                "--no-binary", "llama-cpp-python",
            ],
            check=True,
            env=env,
        )
    except subprocess.CalledProcessError as e:
        raise RuntimeError(fatal_message) from e

    _clear_llama_cache()
    if not _llama_cpp_importable():
        raise RuntimeError(fatal_message)
    print("[bootstrap] ✓ Build từ source thành công.")
# ─── RUN BOOTSTRAP ───────────────────────────────────────────────────────────
# Executed at import time, before the main imports further down in this file.
_ensure_llama_cpp()
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # IMPORTS CHÍNH (sau khi bootstrap đã xong) | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| import os | |
| import pickle | |
| import gradio as gr | |
| import torch | |
| from typing import Tuple | |
| from huggingface_hub import hf_hub_download, list_repo_files | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| from llama_index.core.schema import QueryBundle | |
| from utils.graph_builder import ( | |
| load_healthcare_documents, | |
| build_llamaindex_property_graph, | |
| build_pyg_graph, | |
| ) | |
| from utils.gnn_model import get_structural_embeddings | |
| from utils.retriever import GNNHybridRetriever | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # CONSTANTS | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
# Name handed to HuggingFaceEmbedding (reported as 384-dim in the startup log).
EMBED_MODEL_NAME = "BAAI/bge-small-en-v1.5"
# Passed as num_samples to load_healthcare_documents when the graph is built.
NUM_DOCS = 250
# Pickle file holding the cached graph (pyg_data, nodes, raw_embs).
CACHE_PATH = "rag_cache.pkl"
# hidden_channels argument for get_structural_embeddings.
GNN_HIDDEN_DIM = 256
# Initial retriever settings; also the default values of the UI sliders.
DEFAULT_ALPHA = 0.6
DEFAULT_TOP_K = 5
# HuggingFace repo from which the GGUF model file is listed and downloaded.
MODEL_REPO = "Jackrong/Qwen3.5-4B-Neo-GGUF"
# Alternative repo, currently disabled:
# MODEL_REPO = "Jackrong/Qwen3.5-4B-Claude-4.6-Opus-Reasoning-Distilledv2-GGUF"
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # INITIALIZATION | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
# Startup banner.
print("\n" + "=" * 60)
print(" Initializing GNN-based Medical RAG System")
print("=" * 60)
# ── [1/5] Embedding Model ────────────────────────────────────────────────────
# Module-level embedding model; it is reused by the graph-building step and by
# the retriever created later in this file.
print("\n[1/5] Loading embedding model...")
embed_model = HuggingFaceEmbedding(
    model_name=EMBED_MODEL_NAME,
    max_length=512,
)
print(f" ✓ {EMBED_MODEL_NAME} (384-dim)")
# ── [2/5] Build / Load Graph ─────────────────────────────────────────────────
# Produces three module-level names used by the following steps:
# pyg_data, nodes and raw_embs.
if os.path.exists(CACHE_PATH):
    # Fast path: reuse the graph pickled by a previous run.
    # NOTE(review): pickle.load can execute arbitrary code; this is only safe
    # as long as CACHE_PATH is a file written by this script itself.
    # NOTE(review): the cache is not checked against NUM_DOCS or the embedding
    # model — delete the file after changing either; confirm this is intended.
    print(f"\n[2/5] Loading cached graph từ {CACHE_PATH}...")
    with open(CACHE_PATH, "rb") as f:
        cache = pickle.load(f)
    pyg_data = cache["pyg_data"]
    nodes = cache["nodes"]
    raw_embs = cache["raw_embs"]
    print(f" ✓ {len(nodes)} nodes, {pyg_data.edge_index.shape[1]} edges")
else:
    # Slow path: load the documents, build the graph, then cache it on disk.
    print(f"\n[2/5] Building knowledge graph từ MedQuad ({NUM_DOCS} docs)...")
    documents = load_healthcare_documents(num_samples=NUM_DOCS)
    # `index` is not used again in this file; only text_nodes feeds the next call.
    index, text_nodes = build_llamaindex_property_graph(documents, embed_model)
    pyg_data, nodes, raw_embs = build_pyg_graph(
        text_nodes,
        embed_model,
        similarity_threshold=0.70,
    )
    with open(CACHE_PATH, "wb") as f:
        pickle.dump({"pyg_data": pyg_data, "nodes": nodes, "raw_embs": raw_embs}, f)
    print(" ✓ Graph cached to disk")
# ── [3/5] GNN Message Passing ────────────────────────────────────────────────
print("\n[3/5] Running GNN message passing...")
# Input width of the GNN is the second dimension of the node feature matrix.
in_dim = pyg_data.x.shape[1]
struct_embs = get_structural_embeddings(
    pyg_data,
    in_channels=in_dim,
    hidden_channels=GNN_HIDDEN_DIM,
)
# ── [4/5] Hybrid Retriever ────────────────────────────────────────────────────
# Module-level retriever shared by rag_pipeline and get_score_analysis, which
# overwrite its alpha / top_k attributes on each request.
print("\n[4/5] Initializing GNNHybridRetriever...")
retriever = GNNHybridRetriever(
    nodes=nodes,
    raw_embeddings=raw_embs,
    structural_embeddings=struct_embs,
    embed_model=embed_model,
    alpha=DEFAULT_ALPHA,
    top_k=DEFAULT_TOP_K,
)
print(f" ✓ Retriever ready (α={DEFAULT_ALPHA}, top_k={DEFAULT_TOP_K})")
# ── [5/5] Load GGUF LLM ───────────────────────────────────────────────────────
print(f"\n[5/5] Loading GGUF model từ {MODEL_REPO}...")
try:
    # Pick a .gguf file from the repo, preferring a Q4_K_M quantisation.
    all_files = list(list_repo_files(MODEL_REPO))
    gguf_files = [f for f in all_files if f.endswith(".gguf")]
    preferred = [f for f in gguf_files if "Q4_K_M" in f.upper()]
    # Raises IndexError when the repo has no .gguf file at all; that case is
    # handled by the except branch below, like a failed listing.
    model_filename = (preferred or gguf_files)[0]
    print(f" Downloading: {model_filename}")
except Exception as e:
    # Fall back to a hard-coded file name when the listing cannot be used.
    print(f" Không list được files ({e}), dùng tên mặc định...")
    model_filename = "qwen3.5-4b-neo-Q4_K_M.gguf"
model_path = hf_hub_download(
    repo_id=MODEL_REPO,
    filename=model_filename,
    local_dir="./models",
)
# Deferred import — done only after the bootstrap has made sure a suitable
# wheel is installed.
from llama_cpp import Llama  # noqa: E402
# Module-level LLM handle used by rag_pipeline.
llm = Llama(
    model_path=model_path,
    n_ctx=2048,
    n_threads=os.cpu_count() or 4,  # os.cpu_count() may return None
    n_gpu_layers=0,
    verbose=False,
)
print(f" ✓ LLM loaded: {model_filename}")
print("\n" + "=" * 60)
print(" System ready!")
print("=" * 60 + "\n")
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # RAG PIPELINE | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def rag_pipeline(question: str, alpha: float, top_k: int) -> Tuple[str, str]:
    """Run the RAG flow for one question: retrieve, build context, generate.

    Args:
        question: User question; a blank value short-circuits with a prompt
            to enter a question.
        alpha: Value written to ``retriever.alpha`` before retrieval.
        top_k: Value written to ``retriever.top_k`` before retrieval.

    Returns:
        ``(answer, sources)`` where ``sources`` has one line per retrieved
        node with its hybrid score and the related question from metadata.
        Generation errors are reported inside ``answer`` instead of raising.
    """
    if not question.strip():
        return "Vui lòng nhập câu hỏi.", ""

    # Per-request settings are stored on the shared module-level retriever.
    retriever.alpha = float(alpha)
    retriever.top_k = int(top_k)
    hits = retriever._retrieve(QueryBundle(query_str=question))

    context_blocks = []
    source_rows = []
    for rank, hit in enumerate(hits, start=1):
        context_blocks.append(f"[{rank}] {hit.node.text}")
        related_q = hit.node.metadata.get("question", "N/A")
        source_rows.append(
            f"{rank}. Hybrid Score={hit.score:.4f} | "
            f"Related Q: {related_q[:90]}"
        )
    context = "\n\n".join(context_blocks)
    sources = "\n".join(source_rows)

    prompt_pieces = [
        "You are a knowledgeable and accurate medical assistant. ",
        "Answer the question using ONLY the provided context. ",
        "Be concise, clear, and factual. ",
        "If the context does not contain enough information, ",
        "say 'I don't have enough information to answer this accurately.'\n\n",
        f"Context:\n{context}\n\n",
        f"Question: {question}\n\n",
        "Answer:",
    ]
    prompt = "".join(prompt_pieces)

    generation_kwargs = {
        "max_tokens": 400,
        "temperature": 0.1,
        "top_p": 0.9,
        "repeat_penalty": 1.1,
        "stop": ["Question:", "\n\nContext:", "\n\n\n"],
        "echo": False,
    }
    try:
        completion = llm(prompt, **generation_kwargs)
        generated = completion["choices"][0]["text"].strip()
        answer = generated if generated else "Không thể tạo câu trả lời."
    except Exception as e:
        # Surface the failure in the UI rather than crashing the handler.
        answer = f"Lỗi khi generate: {str(e)}"
    return answer, sources
def get_score_analysis(question: str, alpha: float) -> str:
    """Return a text report of the hybrid score components for ``question``.

    A blank question yields a short hint instead. Otherwise ``alpha`` is
    written to the shared retriever and its ``get_score_breakdown`` result
    (requested with ``top_k=5``) is formatted as one entry per item, showing
    hybrid, semantic and structural scores plus a text preview.
    """
    if not question.strip():
        return "Nhập câu hỏi để phân tích."

    retriever.alpha = float(alpha)
    rows = retriever.get_score_breakdown(question, top_k=5)

    report = ["📊 Score Breakdown (top 5 nodes):\n"]
    report.extend(
        f"Rank {row['rank']}: Hybrid={row['hybrid_score']:.4f} "
        f"| Semantic={row['semantic_score']:.4f} "
        f"| Structural={row['structural_score']:.4f}\n"
        f" Preview: {row['text_preview'][:100]}\n"
        for row in rows
    )
    return "\n".join(report)
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # GRADIO UI | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
# Declarative Gradio layout: a header, a two-column row (inputs on the left,
# outputs on the right), example questions, and the two button callbacks.
with gr.Blocks(title="GNN-based Medical RAG", theme=gr.themes.Soft()) as demo:
    # Header and component overview shown at the top of the page.
    gr.Markdown("""
# 🏥 GNN-based Medical RAG System
Hệ thống hỏi đáp y tế sử dụng **Graph Neural Network-enhanced RAG**.
| Component | Chi tiết |
|-----------|----------|
| **Dataset** | MedQuad Medical QA (HuggingFace) |
| **Embeddings** | BAAI/bge-small-en-v1.5 (384-dim) |
| **Graph** | Document cosine similarity graph |
| **GNN** | 2-layer GCN – message passing enrichment |
| **Retrieval** | α × semantic + (1-α) × structural |
| **Generator** | Qwen3.5-4B-Neo GGUF via llama-cpp |
""")
    with gr.Row():
        # Left column: question box, retrieval settings, action buttons.
        with gr.Column(scale=2):
            question_input = gr.Textbox(
                label="Medical Question",
                placeholder="e.g. What are the symptoms of Type 2 diabetes?",
                lines=3,
            )
            with gr.Accordion("⚙️ Retrieval Settings", open=False):
                # Slider values are forwarded to the callbacks as alpha / top_k.
                alpha_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=DEFAULT_ALPHA, step=0.05,
                    label="α — Semantic weight (1-α = Structural weight)",
                    info="Tăng α → semantic | Giảm α → GNN structural",
                )
                topk_slider = gr.Slider(
                    minimum=1, maximum=10, value=DEFAULT_TOP_K, step=1,
                    label="Top-K retrieved contexts",
                )
            with gr.Row():
                submit_btn = gr.Button("🔍 Get Answer", variant="primary")
                analyze_btn = gr.Button("📊 Analyze Scores")
        # Right column: answer, retrieved sources, optional score analysis.
        with gr.Column(scale=3):
            answer_output = gr.Textbox(label="Answer", lines=8)
            sources_output = gr.Textbox(label="Retrieved Sources & Scores", lines=6)
            analysis_output = gr.Textbox(label="Score Analysis (optional)", lines=8)
    # Example questions that fill the question box.
    gr.Examples(
        examples=[
            ["What are the symptoms of Type 2 diabetes?"],
            ["How is hypertension treated?"],
            ["What causes Alzheimer's disease?"],
            ["What is the recommended treatment for asthma?"],
            ["How does the immune system respond to infection?"],
            ["What are the risk factors for heart disease?"],
        ],
        inputs=question_input,
        label="Example Questions",
    )
    # "Get Answer" runs the full RAG pipeline and fills answer + sources.
    submit_btn.click(
        fn=rag_pipeline,
        inputs=[question_input, alpha_slider, topk_slider],
        outputs=[answer_output, sources_output],
    )
    # "Analyze Scores" only fills the score-analysis box.
    analyze_btn.click(
        fn=get_score_analysis,
        inputs=[question_input, alpha_slider],
        outputs=analysis_output,
    )

# Launch the UI only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()