# Source: Lab3/app.py by NamPhoenix — commit da91b61 ("Update app.py", verified)
"""
GNN-based Medical RAG System - Main Application
Entry point cho HuggingFace Spaces deployment.
Pipeline đầy đủ:
MedQuad Dataset
→ LlamaIndex PropertyGraphIndex
→ PyTorch Geometric graph
→ 2-layer GCN (message passing)
→ GNNHybridRetriever (α × semantic + (1-α) × structural)
→ Qwen3.5-4B GGUF via llama-cpp
→ Gradio UI
"""
# ══════════════════════════════════════════════════════════════════════════════
# BOOTSTRAP — phải chạy TRƯỚC mọi import khác
#
# Vấn đề: KHÔNG có pre-built wheel nào từ PyPI / abetlen/releases hoạt động
# trên HuggingFace Spaces vì tất cả đều dùng musl libc (Alpine build),
# trong khi HF Spaces chạy Debian (glibc).
#
# Giải pháp: Query GitHub API để tìm manylinux wheel đúng version + python tag,
# sau đó cài trực tiếp. Fallback về build từ source nếu không tìm được.
# ══════════════════════════════════════════════════════════════════════════════
import subprocess
import sys
import json
import urllib.request
def _find_manylinux_wheel(py_tag: str, version: str = "0.3.19") -> str | None:
"""
Query GitHub API để tìm manylinux_x86_64 wheel cho đúng python tag.
Thử repo eswarthammana/llama-cpp-wheels trước (chuyên build manylinux),
sau đó fallback về abetlen/llama-cpp-python.
"""
repos = [
f"eswarthammana/llama-cpp-wheels",
f"abetlen/llama-cpp-python",
]
target_suffix = f"{py_tag}-{py_tag}-manylinux"
for repo in repos:
try:
api_url = f"https://api.github.com/repos/{repo}/releases"
req = urllib.request.Request(
api_url,
headers={"User-Agent": "python-bootstrap/1.0"},
)
with urllib.request.urlopen(req, timeout=15) as resp:
releases = json.loads(resp.read())
for release in releases[:10]: # check 10 releases gần nhất
tag = release.get("tag_name", "")
# Tìm release có version khớp (hoặc gần nhất)
if version not in tag and version.replace(".", "") not in tag.replace(".", ""):
continue
assets = release.get("assets", [])
for asset in assets:
name = asset.get("name", "")
if target_suffix in name and name.endswith(".whl"):
url = asset.get("browser_download_url", "")
print(f"[bootstrap] Tìm thấy: {name} ({repo})")
return url
except Exception as e:
print(f"[bootstrap] Không query được {repo}: {e}")
continue
return None
def _install_wheel(url: str) -> bool:
"""Cài wheel từ URL, trả về True nếu thành công."""
result = subprocess.run(
[
sys.executable, "-m", "pip", "install",
"--quiet",
"--force-reinstall",
"--no-deps",
url,
],
capture_output=True,
text=True,
)
if result.returncode == 0:
return True
print(f"[bootstrap] pip lỗi: {result.stderr.strip()[:400]}")
return False
def _clear_llama_cache():
    """Forget any imported ``llama_cpp`` modules so the next import is fresh.

    Removes ``llama_cpp`` and its submodules (``llama_cpp.*``) from
    ``sys.modules``. Unrelated packages that merely share the prefix, such as
    ``llama_cpp_agent``, are left alone. It then invalidates the import
    system's finder caches so a package that pip installed after start-up is
    discovered.
    """
    import importlib

    stale = [
        name for name in sys.modules
        if name == "llama_cpp" or name.startswith("llama_cpp.")
    ]
    for name in stale:
        del sys.modules[name]
    importlib.invalidate_caches()
def _llama_cpp_importable(timeout: float = 300) -> bool:
    """Return True if ``from llama_cpp import Llama`` succeeds in a fresh interpreter.

    The check runs in a child process. A wheel whose native library cannot be
    loaded therefore leaves no half-imported modules or loaded shared objects
    behind in the current process.
    """
    try:
        result = subprocess.run(
            [sys.executable, "-c", "from llama_cpp import Llama"],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        print("[bootstrap] Import thử llama_cpp quá thời gian.")
        return False
    if result.returncode == 0:
        return True
    print(f"[bootstrap] Cài xong nhưng import llama_cpp lỗi: {result.stderr.strip()[-400:]}")
    return False


def _ensure_llama_cpp():
    """Make sure ``llama_cpp`` is importable, installing llama-cpp-python if needed.

    Intended for Debian/glibc hosts such as HF Spaces. The strategy stops at
    the first step that yields a package that actually loads:

    1. Already importable → nothing to do.
    2. Install a manylinux wheel located through the GitHub releases API.
    3. Try hard-coded fallback wheel URLs for older versions.
    4. Build llama-cpp-python 0.3.19 from source (last resort).

    After every install the import is re-checked in a fresh interpreter. A
    wheel built against another libc can install cleanly with pip yet fail
    to load, and such a wheel must not be reported as success.

    Raises:
        RuntimeError: if the source build fails, or succeeds but the result
            still cannot be imported.
    """
    import os

    # ── Step 1: already importable? ──────────────────────────────────────────
    try:
        from llama_cpp import Llama  # noqa: F401
        print("[bootstrap] ✓ llama-cpp-python sẵn sàng (không cần cài lại).")
        return
    except (ImportError, OSError):
        # OSError: the package is present but its native library failed to load.
        pass

    v = sys.version_info
    py_tag = f"cp{v.major}{v.minor}"
    print(f"[bootstrap] llama-cpp-python chưa OK. Bắt đầu cài ({py_tag})...")

    def _install_and_verify(wheel_url: str) -> bool:
        """Install *wheel_url* and confirm the result can really be imported."""
        if not _install_wheel(wheel_url):
            return False
        _clear_llama_cache()
        return _llama_cpp_importable()

    # ── Step 2: manylinux wheel found via the GitHub API ────────────────────
    print("[bootstrap] Đang tìm manylinux wheel qua GitHub API...")
    url = _find_manylinux_wheel(py_tag, version="0.3.19")
    if url and _install_and_verify(url):
        print("[bootstrap] ✓ Cài thành công (GitHub API).")
        return

    # ── Step 3: hard-coded fallback URLs ─────────────────────────────────────
    # NOTE(review): "manylinux_x86_64" is not a PEP 600 platform tag (real ones
    # look like "manylinux_2_17_x86_64"), so pip will likely reject these
    # file names even if they exist — confirm the actual asset names.
    fallback_versions = ["0.3.14", "0.3.10", "0.3.4"]
    fallback_repos = [
        "eswarthammana/llama-cpp-wheels",
        "mrzeeshanahmed/llama-cpp-python",  # community manylinux builds
    ]
    for ver in fallback_versions:
        for repo_slug in fallback_repos:
            fallback_url = (
                f"https://github.com/{repo_slug}/releases/download/"
                f"v{ver}/"
                f"llama_cpp_python-{ver}-{py_tag}-{py_tag}-manylinux_x86_64.whl"
            )
            print(f"[bootstrap] Thử: v{ver} từ {repo_slug}...")
            if _install_and_verify(fallback_url):
                print(f"[bootstrap] ✓ Cài thành công (fallback v{ver}).")
                return

    # ── Step 4: build from source ────────────────────────────────────────────
    print("[bootstrap] Tất cả pre-built wheel thất bại.")
    print("[bootstrap] Build từ source (có thể mất 5-10 phút)...")
    # An earlier step may have left an installed-but-unloadable 0.3.19 wheel.
    # pip would treat that as satisfying the ==0.3.19 pin and skip the build,
    # so remove it first. Failure here (e.g. nothing installed) is irrelevant.
    subprocess.run(
        [sys.executable, "-m", "pip", "uninstall", "-y", "--quiet", "llama-cpp-python"],
        capture_output=True,
        text=True,
    )
    env = {
        **os.environ,
        "CMAKE_ARGS": "-DGGML_BLAS=OFF -DGGML_OPENMP=OFF",
        "FORCE_CMAKE": "1",
    }
    try:
        subprocess.run(
            [
                sys.executable, "-m", "pip", "install",
                "--quiet",
                "llama-cpp-python==0.3.19",
                "--no-binary", "llama-cpp-python",
            ],
            check=True,
            env=env,
        )
    except subprocess.CalledProcessError as e:
        raise RuntimeError(
            "[bootstrap] FATAL: Không thể cài llama-cpp-python bằng bất kỳ cách nào. "
            "Kiểm tra log phía trên để biết chi tiết."
        ) from e
    _clear_llama_cache()
    if not _llama_cpp_importable():
        raise RuntimeError(
            "[bootstrap] FATAL: Build từ source xong nhưng vẫn không import được llama_cpp. "
            "Kiểm tra log phía trên để biết chi tiết."
        )
    print("[bootstrap] ✓ Build từ source thành công.")
# ─── RUN BOOTSTRAP ───────────────────────────────────────────────────────────
# Executed at import time, before the main imports below, so that a loadable
# llama-cpp-python is in place before `from llama_cpp import Llama` is reached
# later in this file.
_ensure_llama_cpp()
# ══════════════════════════════════════════════════════════════════════════════
# IMPORTS CHÍNH (sau khi bootstrap đã xong)
# ══════════════════════════════════════════════════════════════════════════════
import os
import pickle
import gradio as gr
import torch
from typing import Tuple
from huggingface_hub import hf_hub_download, list_repo_files
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.schema import QueryBundle
from utils.graph_builder import (
load_healthcare_documents,
build_llamaindex_property_graph,
build_pyg_graph,
)
from utils.gnn_model import get_structural_embeddings
from utils.retriever import GNNHybridRetriever
# ══════════════════════════════════════════════════════════════════════════════
# CONSTANTS
# ══════════════════════════════════════════════════════════════════════════════
# Hugging Face embedding model used for documents and queries.
EMBED_MODEL_NAME: str = "BAAI/bge-small-en-v1.5"
# How many MedQuad documents are loaded when the graph is built.
NUM_DOCS: int = 250
# Pickle file that stores the built graph between runs.
CACHE_PATH: str = "rag_cache.pkl"
# Passed as `hidden_channels` to the GNN.
GNN_HIDDEN_DIM: int = 256
# Initial retriever settings; also the defaults of the UI sliders.
DEFAULT_ALPHA: float = 0.6
DEFAULT_TOP_K: int = 5
# Hugging Face repo the GGUF generator is downloaded from.
# Alternative: "Jackrong/Qwen3.5-4B-Claude-4.6-Opus-Reasoning-Distilledv2-GGUF"
MODEL_REPO: str = "Jackrong/Qwen3.5-4B-Neo-GGUF"
# ══════════════════════════════════════════════════════════════════════════════
# INITIALIZATION
# ══════════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 60)
print(" Initializing GNN-based Medical RAG System")
print("=" * 60)

# ── [1/5] Embedding Model ────────────────────────────────────────────────────
print("\n[1/5] Loading embedding model...")
embed_model = HuggingFaceEmbedding(
    model_name=EMBED_MODEL_NAME,
    max_length=512,
)
print(f" ✓ {EMBED_MODEL_NAME} (384-dim)")

# ── [2/5] Build / Load Graph ─────────────────────────────────────────────────
# Similarity threshold passed to build_pyg_graph when linking documents.
SIMILARITY_THRESHOLD = 0.70

# Settings that determine what the cached graph contains. They are stored in
# the cache file, so a cache built under different settings (e.g. after
# NUM_DOCS or the embedding model changes) is rebuilt instead of silently reused.
GRAPH_CACHE_CONFIG = {
    "embed_model": EMBED_MODEL_NAME,
    "num_docs": NUM_DOCS,
    "similarity_threshold": SIMILARITY_THRESHOLD,
}


def _load_graph_cache(path, config):
    """Return ``(pyg_data, nodes, raw_embs)`` from the pickle at *path*, or None.

    None means "rebuild" and covers three cases: the file is missing, the file
    is unreadable or corrupt, or it was written under settings different from
    *config*. Caches written before the settings were recorded (no ``"config"``
    key) are still accepted, so an existing cache file keeps working.
    """
    if not os.path.exists(path):
        return None
    print(f"\n[2/5] Loading cached graph từ {path}...")
    try:
        with open(path, "rb") as fh:
            # SECURITY: pickle.load can execute arbitrary code; only ever load
            # a cache file written by this application.
            cached = pickle.load(fh)
        stored_config = cached.get("config")
        graph = (cached["pyg_data"], cached["nodes"], cached["raw_embs"])
    except Exception as exc:
        print(f" ⚠ Cache không dùng được ({exc}), build lại...")
        return None
    if stored_config is not None and stored_config != config:
        print(" ⚠ Cache được build với cấu hình khác, build lại...")
        return None
    return graph


def _save_graph_cache(path, config, graph_data, graph_nodes, graph_embs):
    """Pickle the graph to *path* atomically; failures are reported, not raised.

    The data is written to a temporary file and then renamed over *path*, so
    an interrupted write never leaves a truncated cache that would break the
    next start-up. The graph is already in memory, so a failed write (e.g. a
    read-only disk) only costs a rebuild next time.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "wb") as fh:
            pickle.dump(
                {
                    "pyg_data": graph_data,
                    "nodes": graph_nodes,
                    "raw_embs": graph_embs,
                    "config": config,
                },
                fh,
            )
        os.replace(tmp_path, path)
    except Exception as exc:
        print(f" ⚠ Không ghi được cache ({exc})")
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return
    print(" ✓ Graph cached to disk")


_cached_graph = _load_graph_cache(CACHE_PATH, GRAPH_CACHE_CONFIG)
if _cached_graph is not None:
    pyg_data, nodes, raw_embs = _cached_graph
    print(f" ✓ {len(nodes)} nodes, {pyg_data.edge_index.shape[1]} edges")
else:
    print(f"\n[2/5] Building knowledge graph từ MedQuad ({NUM_DOCS} docs)...")
    documents = load_healthcare_documents(num_samples=NUM_DOCS)
    # Only the text nodes are used below; the returned index is kept for inspection.
    index, text_nodes = build_llamaindex_property_graph(documents, embed_model)
    pyg_data, nodes, raw_embs = build_pyg_graph(
        text_nodes,
        embed_model,
        similarity_threshold=SIMILARITY_THRESHOLD,
    )
    _save_graph_cache(CACHE_PATH, GRAPH_CACHE_CONFIG, pyg_data, nodes, raw_embs)

# ── [3/5] GNN Message Passing ────────────────────────────────────────────────
print("\n[3/5] Running GNN message passing...")
# GNN input width = column count of pyg_data.x (presumably the node-feature matrix).
in_dim = pyg_data.x.shape[1]
struct_embs = get_structural_embeddings(
    pyg_data,
    in_channels=in_dim,
    hidden_channels=GNN_HIDDEN_DIM,
)

# ── [4/5] Hybrid Retriever ───────────────────────────────────────────────────
print("\n[4/5] Initializing GNNHybridRetriever...")
retriever = GNNHybridRetriever(
    nodes=nodes,
    raw_embeddings=raw_embs,
    structural_embeddings=struct_embs,
    embed_model=embed_model,
    alpha=DEFAULT_ALPHA,
    top_k=DEFAULT_TOP_K,
)
print(f" ✓ Retriever ready (α={DEFAULT_ALPHA}, top_k={DEFAULT_TOP_K})")

# ── [5/5] Load GGUF LLM ──────────────────────────────────────────────────────
print(f"\n[5/5] Loading GGUF model từ {MODEL_REPO}...")
try:
    all_files = list(list_repo_files(MODEL_REPO))
    gguf_files = [f for f in all_files if f.endswith(".gguf")]
    if not gguf_files:
        raise FileNotFoundError(f"no .gguf files in {MODEL_REPO}")
    # Prefer the Q4_K_M quantisation; otherwise take the first GGUF listed.
    preferred = [f for f in gguf_files if "Q4_K_M" in f.upper()]
    model_filename = (preferred or gguf_files)[0]
    print(f" Downloading: {model_filename}")
except Exception as e:
    # Listing failed (network/API) or the repo has no GGUF: use a known file name.
    print(f" Không list được files ({e}), dùng tên mặc định...")
    model_filename = "qwen3.5-4b-neo-Q4_K_M.gguf"
model_path = hf_hub_download(
    repo_id=MODEL_REPO,
    filename=model_filename,
    local_dir="./models",
)
# Imported lazily — the bootstrap above has made sure a loadable wheel is installed.
from llama_cpp import Llama  # noqa: E402
llm = Llama(
    model_path=model_path,
    n_ctx=2048,
    n_threads=os.cpu_count() or 4,
    n_gpu_layers=0,  # CPU-only inference
    verbose=False,
)
print(f" ✓ LLM loaded: {model_filename}")
print("\n" + "=" * 60)
print(" System ready!")
print("=" * 60 + "\n")
# ══════════════════════════════════════════════════════════════════════════════
# RAG PIPELINE
# ══════════════════════════════════════════════════════════════════════════════
def rag_pipeline(question: str, alpha: float, top_k: int) -> Tuple[str, str]:
"""Full RAG pipeline: Retrieve → Build Context → Generate."""
if not question.strip():
return "Vui lòng nhập câu hỏi.", ""
retriever.alpha = float(alpha)
retriever.top_k = int(top_k)
query = QueryBundle(query_str=question)
results = retriever._retrieve(query)
context_parts, source_lines = [], []
for i, node_score in enumerate(results):
context_parts.append(f"[{i+1}] {node_score.node.text}")
q_meta = node_score.node.metadata.get("question", "N/A")
source_lines.append(
f"{i+1}. Hybrid Score={node_score.score:.4f} | "
f"Related Q: {q_meta[:90]}"
)
context = "\n\n".join(context_parts)
sources = "\n".join(source_lines)
prompt = (
"You are a knowledgeable and accurate medical assistant. "
"Answer the question using ONLY the provided context. "
"Be concise, clear, and factual. "
"If the context does not contain enough information, "
"say 'I don't have enough information to answer this accurately.'\n\n"
f"Context:\n{context}\n\n"
f"Question: {question}\n\n"
"Answer:"
)
try:
output = llm(
prompt,
max_tokens=400,
temperature=0.1,
top_p=0.9,
repeat_penalty=1.1,
stop=["Question:", "\n\nContext:", "\n\n\n"],
echo=False,
)
answer = output["choices"][0]["text"].strip() or "Không thể tạo câu trả lời."
except Exception as e:
answer = f"Lỗi khi generate: {str(e)}"
return answer, sources
def get_score_analysis(question: str, alpha: float) -> str:
    """Render a per-node breakdown of the hybrid retrieval scores.

    Sets the shared retriever's alpha and asks it for its top 5 nodes. For each
    node it reports the rank, the hybrid/semantic/structural scores and the
    first 100 characters of its text preview.

    Args:
        question: Query text; a blank value returns a prompt message instead.
        alpha: Semantic weight written onto ``retriever`` before scoring.

    Returns:
        A multi-line report for the UI.
    """
    if question.strip() == "":
        return "Nhập câu hỏi để phân tích."

    retriever.alpha = float(alpha)
    report = ["📊 Score Breakdown (top 5 nodes):\n"]
    report.extend(
        f"Rank {row['rank']}: Hybrid={row['hybrid_score']:.4f} "
        f"| Semantic={row['semantic_score']:.4f} "
        f"| Structural={row['structural_score']:.4f}\n"
        f" Preview: {row['text_preview'][:100]}\n"
        for row in retriever.get_score_breakdown(question, top_k=5)
    )
    return "\n".join(report)
# ══════════════════════════════════════════════════════════════════════════════
# GRADIO UI
# ══════════════════════════════════════════════════════════════════════════════
# Gradio front-end: question input and retrieval settings on the left, answer /
# sources / score analysis on the right, with example questions below.
# NOTE(review): rag_pipeline and get_score_analysis both assign retriever.alpha
# (and rag_pipeline also retriever.top_k) on the shared retriever. Whether the
# two click handlers can run concurrently depends on Gradio's queue/concurrency
# settings — confirm, or give both events a shared concurrency group.
with gr.Blocks(title="GNN-based Medical RAG", theme=gr.themes.Soft()) as demo:
    # Header: project summary table rendered as Markdown.
    gr.Markdown("""
# 🏥 GNN-based Medical RAG System
Hệ thống hỏi đáp y tế sử dụng **Graph Neural Network-enhanced RAG**.
| Component | Chi tiết |
|-----------|----------|
| **Dataset** | MedQuad Medical QA (HuggingFace) |
| **Embeddings** | BAAI/bge-small-en-v1.5 (384-dim) |
| **Graph** | Document cosine similarity graph |
| **GNN** | 2-layer GCN – message passing enrichment |
| **Retrieval** | α × semantic + (1-α) × structural |
| **Generator** | Qwen3.5-4B-Neo GGUF via llama-cpp |
""")
    with gr.Row():
        # Left column: question input, retrieval settings and action buttons.
        with gr.Column(scale=2):
            question_input = gr.Textbox(
                label="Medical Question",
                placeholder="e.g. What are the symptoms of Type 2 diabetes?",
                lines=3,
            )
            # Collapsed by default; the slider values are passed to the handlers below.
            with gr.Accordion("⚙️ Retrieval Settings", open=False):
                alpha_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=DEFAULT_ALPHA, step=0.05,
                    label="α — Semantic weight (1-α = Structural weight)",
                    info="Tăng α → semantic | Giảm α → GNN structural",
                )
                topk_slider = gr.Slider(
                    minimum=1, maximum=10, value=DEFAULT_TOP_K, step=1,
                    label="Top-K retrieved contexts",
                )
            with gr.Row():
                submit_btn = gr.Button("🔍 Get Answer", variant="primary")
                analyze_btn = gr.Button("📊 Analyze Scores")
        # Right column: generated answer, retrieved sources, optional score analysis.
        with gr.Column(scale=3):
            answer_output = gr.Textbox(label="Answer", lines=8)
            sources_output = gr.Textbox(label="Retrieved Sources & Scores", lines=6)
            analysis_output = gr.Textbox(label="Score Analysis (optional)", lines=8)
    # Sample questions; clicking one fills question_input.
    gr.Examples(
        examples=[
            ["What are the symptoms of Type 2 diabetes?"],
            ["How is hypertension treated?"],
            ["What causes Alzheimer's disease?"],
            ["What is the recommended treatment for asthma?"],
            ["How does the immune system respond to infection?"],
            ["What are the risk factors for heart disease?"],
        ],
        inputs=question_input,
        label="Example Questions",
    )
    # "Get Answer" runs the full RAG pipeline -> (answer, sources).
    submit_btn.click(
        fn=rag_pipeline,
        inputs=[question_input, alpha_slider, topk_slider],
        outputs=[answer_output, sources_output],
    )
    # "Analyze Scores" only reports the top-5 score breakdown (uses alpha, not top-K).
    analyze_btn.click(
        fn=get_score_analysis,
        inputs=[question_input, alpha_slider],
        outputs=analysis_output,
    )

# Start the web server when run as a script.
if __name__ == "__main__":
    demo.launch()