File size: 12,852 Bytes
ad7e512 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 |
import os
import shutil
import torch
from transformers import pipeline, AutoTokenizer, BitsAndBytesConfig
import gradio as gr
import logging
from pathlib import Path
import json
from typing import Tuple # generate_response の型ヒントのために追加
# ragsys03.py から RAGSystem クラスと QueryResult をインポート
from ragsys03 import RAGSystem, QueryResult
# ロギング設定
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# グローバル変数
rag_system: RAGSystem = None
llm_pipeline = None
# RAGデータ保存用のベースディレクトリ
RAG_BASE_DIR = Path('rag_data')
def initialize_systems():
"""
RAGシステムとLLMパイプラインを初期化します。
"""
global rag_system, llm_pipeline
if rag_system is None:
logger.info("RAGSystemを初期化中...")
try:
# 最新のインデックスをロードするか、新しいディレクトリを作成する
rag_system = RAGSystem(
model_name='all-MiniLM-L6-v2', # SentenceTransformer モデル
index_type='ivf', # または 'flat'
n_clusters=100, # IVFの場合のクラスタ数
index_base_dir=RAG_BASE_DIR,
load_latest=True # 最新のインデックスを自動的にロード
)
logger.info("RAGSystemの初期化が完了しました。")
# 初期化時にインデックスが存在しない場合、ここで構築を促す
if rag_system.index is None or rag_system.index.ntotal == 0:
logger.warning("RAGシステムにインデックスがありません。文書をアップロードして「インデックスを構築」ボタンを押してください。")
except Exception as e:
logger.critical(f"RAGSystemの初期化中にエラーが発生しました: {e}")
rag_system = None # 初期化失敗時はNoneに設定
# Gradio UIでエラーメッセージを表示するための処理を考慮するか、
# 各アクションでrag_systemがNoneの場合の処理を堅牢にする
# ここでは単にログに出力し、後続の関数でNoneチェックを行う
if llm_pipeline is None:
logger.info("LLMパイプラインを初期化中: rinna/japanese-gpt-neox-3.6b-instruction-sft")
try:
# 量子化を無効にした設定
llm_pipeline = pipeline(
"text-generation",
model="rinna/japanese-gpt-neox-3.6b-instruction-sft",
tokenizer=AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft", use_fast=False),
torch_dtype=torch.bfloat16, # CPUでも動くようにbfloat16のまま
device_map="auto", # autoのままでOK (GPUがなければCPUに自動的に割り当たる)
# model_kwargs={"quantization_config": quantization_config} を削除!
)
logger.info("LLMパイプラインの初期化が完了しました。")
except Exception as e:
logger.critical(f"LLMパイプラインの初期化中にエラーが発生しました: {e}")
llm_pipeline = None # 初期化失敗時はNoneに設定
# 同様に、エラー発生時はllm_pipelineをNoneのままにする
def generate_response(query: str, top_k: int, similarity_threshold: float) -> Tuple[str, str]:
"""
RAGシステムとLLMを使用して応答を生成します。
Args:
query (str): ユーザーからのクエリ。
top_k (int): 検索する文書の数。
similarity_threshold (float): コサイン類似度の閾値。
Returns:
Tuple[str, str]: LLMの回答と、検索された文書のテキスト。
"""
if rag_system is None or llm_pipeline is None:
return "システムが初期化されていません。アプリケーションを再起動してください。", ""
logger.info(f"クエリ処理開始: {query}")
query_result: QueryResult = rag_system.query(query, top_k=top_k, similarity_threshold=similarity_threshold)
llm_prompt = query_result['llm_prompt']
retrieved_docs = query_result['retrieved_documents']
search_time = query_result['search_time']
if not retrieved_docs:
llm_answer = "関連する情報が見つかりませんでした。別の表現で質問していただくか、より具体的な内容で質問してください。"
return llm_answer, "関連文書なし"
try:
outputs = llm_pipeline(
llm_prompt,
max_new_tokens=512,
do_sample=True,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.1,
num_return_sequences=1,
pad_token_id=llm_pipeline.tokenizer.eos_token_id # pad_token_idを指定
)
llm_answer = outputs[0]['generated_text']
# プロンプト部分を除去して回答のみを抽出 (モデルの出力形式による)
if llm_answer.startswith(llm_prompt):
llm_answer = llm_answer[len(llm_prompt):].strip()
logger.info(f"LLM応答生成完了 (検索時間: {search_time:.4f}秒)")
retrieved_docs_text_formatted = "\n\n".join([f"**文書{i+1}:** {doc}" for i, doc in enumerate(retrieved_docs)])
return llm_answer, retrieved_docs_text_formatted
except Exception as e:
logger.error(f"LLMによる応答生成中にエラーが発生しました: {e}")
return f"応答生成中にエラーが発生しました: {str(e)}", ""
def upload_documents(file):
"""
アップロードされたJSONファイルから文書をロードし、インデックスを再構築します。
"""
global rag_system
if rag_system is None:
# RAGSystemが未初期化の場合、ここで初期化を試みる
# initialize_systems() は、Gradioの起動時に一度しか呼ばれないため、
# ここでrag_systemがNoneの場合、初期化に失敗している可能性がある。
# 代わりに、エラーメッセージを返す
return "エラー: RAGシステムが初期化されていません。アプリケーションを再起動してください。", ""
if file is None:
return "エラー: ファイルがアップロードされていません。", ""
uploaded_file_path = Path(file.name)
logger.info(f"アップロードされたファイル: {uploaded_file_path}")
try:
# ragsys03.py の load_documents メソッドを使用
rag_system.load_documents(uploaded_file_path)
# インデックス構築のメッセージを表示し、ユーザーにボタンを押してもらう
return (
f"'{uploaded_file_path.name}' から {len(rag_system.documents)} 件の文書を正常にロードしました。\n"
"続けて「インデックスを構築」ボタンを押してください。",
"文書がロードされました。インデックス構築が必要です。"
)
except Exception as e:
logger.error(f"文書のロード中にエラーが発生しました: {e}")
return f"文書のロード中にエラーが発生しました: {str(e)}", ""
def build_index_action():
"""
RAGインデックスを構築します。
"""
global rag_system
if rag_system is None:
return "エラー: RAGシステムが初期化されていません。", ""
if not rag_system.documents:
return "エラー: インデックス作成用の文書がロードされていません。", ""
try:
# force_rebuild=True で常に再構築
rag_system.build_index(force_rebuild=True)
stats = rag_system.get_stats()
return (
f"インデックスが正常に構築されました。\n"
f"総文書数: {stats.get('documents_count', 'N/A')}\n"
f"インデックスタイプ: {stats.get('index_type', 'N/A')}\n"
f"インデックスディレクトリ: {stats.get('current_index_dir', 'N/A')}\n"
f"ファイルサイズ: {stats.get('index_file_size_mb', 'N/A')} MB"
), "インデックス構築完了"
except Exception as e:
logger.error(f"インデックスの構築中にエラーが発生しました: {e}")
return f"エラー: インデックスの構築中に問題が発生しました - {str(e)}", ""
def get_rag_stats_action():
"""RAGシステムの統計情報を表示します。"""
global rag_system
if rag_system is None:
return "RAGシステムはまだ初期化されていません。"
try:
stats = rag_system.get_stats()
stats_str = "\n".join([f"{k}: {v}" for k, v in stats.items()])
return f"RAGシステムの状態:\n{stats_str}"
except Exception as e:
logger.error(f"RAG統計情報の取得中にエラーが発生しました: {e}")
return f"統計情報の取得中にエラーが発生しました: {str(e)}"
# Gradio UI の構築
with gr.Blocks() as demo:
gr.Markdown(
"""
# RAG (Retrieval Augmented Generation) デモ
rinna/japanese-gpt-neox-3.6b-instruction-sft と FAISS を利用したRAGシステムです。
JSONファイルをアップロードして独自の知識ベースを構築し、質問をすることができます。
## 使用方法
1. 「文書をアップロード」セクションで、`documents` キーに文字列のリストを持つJSONファイルをアップロードします。
2. 「インデックスを構築」ボタンをクリックして、アップロードされた文書から検索インデックスを作成します。
3. 「RAG質問」セクションで質問を入力し、「質問を送信」ボタンをクリックします。
"""
)
with gr.Tab("RAG質問"):
gr.Markdown("### RAG質問")
with gr.Row():
query_input = gr.Textbox(label="質問を入力してください", placeholder="RAGシステムの主な利点は何ですか?", lines=2)
submit_button = gr.Button("質問を送信")
with gr.Row():
top_k_slider = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="取得文書数 (top_k)")
similarity_threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.05, label="類似度閾値 (0.0-1.0)")
llm_output = gr.Textbox(label="LLMの回答", lines=10, interactive=False)
retrieved_docs_output = gr.Markdown("検索された関連文書", visible=True)
with gr.Tab("文書管理"):
gr.Markdown("### 文書とインデックス管理")
with gr.Row():
file_upload_button = gr.File(label="文書JSONファイルをアップロード", file_types=[".json"])
upload_status = gr.Textbox(label="アップロードステータス", interactive=False)
with gr.Row():
build_index_button = gr.Button("インデックスを構築")
build_index_status = gr.Textbox(label="インデックス構築ステータス", interactive=False)
gr.Markdown("### RAGシステム情報")
get_stats_button = gr.Button("RAGシステムの状態を表示")
rag_stats_output = gr.Textbox(label="RAGシステム情報", interactive=False, lines=5)
# イベントハンドラの登録
submit_button.click(
fn=generate_response,
inputs=[query_input, top_k_slider, similarity_threshold_slider],
outputs=[llm_output, retrieved_docs_output]
)
file_upload_button.upload(
fn=upload_documents,
inputs=file_upload_button,
outputs=[upload_status, retrieved_docs_output] # アップロード結果と関連文書表示エリアを更新
)
build_index_button.click(
fn=build_index_action,
inputs=[],
outputs=[build_index_status, retrieved_docs_output] # 構築結果と関連文書表示エリアを更新
)
get_stats_button.click(
fn=get_rag_stats_action,
inputs=[],
outputs=rag_stats_output
)
# システムの初期化をアプリケーション起動時に実行
# initialize_systems() の中でエラーが発生した場合、rag_system や llm_pipeline が None になる
# これにより、後続の関数呼び出しでこれらの変数がNoneであることのチェックが機能する
initialize_systems()
if __name__ == "__main__":
demo.launch(debug=True, share=True)
|