import gradio as gr import json import re from bs4 import BeautifulSoup from transformers import AutoTokenizer from vllm import LLM, SamplingParams # Load model and tokenizer MODEL_NAME = "tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3" SYS_CONTENT = ( "あなたは誠実で優秀な日本人の新聞記者です。質問には正確に具体的に答えることができます。" "入力される記事について,誰(who)が何(what)をいつ(when)どこ(where)でどうした(how)と書いてますか?" "次のJSONの値を埋めて返して下さい.どこ(where)には地図で示せるくらい具体的な地名や施設名を入れてください。" "もしも該当の情報が記事になければJSONの値を空にしてください。" "{ \"who\": \"...\", \"what\": \"...\", \"when\": \"...\", \"where\": \"...\", \"how\": \"...\"} " ) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) llm = LLM( model=MODEL_NAME, tensor_parallel_size=1, ) def preprocess_text(text: str) -> str: # HTMLタグの削除 soup = BeautifulSoup(text, 'html.parser') text = soup.get_text() # 独自タグの削除 (<...> ) text = re.sub(r'<[^>]+>', '', text) # 改行、タブ、余分な空白(半角・全角)の削除 text = re.sub(r'[\n\t]', '', text) text = re.sub(r'[\s ]+', ' ', text) # 連続する空白を1つの半角スペースに置換 text = text.strip() return text def inference(content: str, max_tokens: int, temperature: float, top_p: float): sampling_params = SamplingParams( temperature=temperature, top_p=top_p, max_tokens=max_tokens, stop="<|eot_id|>" ) # 入力テキストの前処理 processed_content = preprocess_text(content) message = [ { "role": "system", "content": SYS_CONTENT }, { "role": "user", "content": processed_content, }, ] try: prompt = tokenizer.apply_chat_template( message, tokenize=False, add_generation_prompt=True ) output = llm.generate(prompt, sampling_params) result_text = output[0].outputs[0].text # JSONを抽出 json_pattern = r'\{[^{}]*\}' match = re.search(json_pattern, result_text) if not match: return "エラー: 生成されたテキストからJSONが見つかりませんでした。" try: json_data = json.loads(match.group()) return json.dumps(json_data, ensure_ascii=False, indent=2) except json.JSONDecodeError as e: return f"JSONパースエラー: {str(e)}" except Exception as e: return f"生成エラー: {str(e)}" # Gradioインターフェースの作成 demo = gr.Interface( fn=inference, inputs=[ gr.Textbox(label="入力テキスト", lines=10), gr.Number(label="最大トークン数", value=512), gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.3, step=0.1), gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.1), ], outputs=gr.Textbox(label="解析結果", lines=10), title="意味解析エンジン", description="テキストを入力すると、5W(Who, What, When, Where, How)の形式で情報を抽出します.テキスト内に混入した改行や空白,独自タグ等を削除する整形処理を入れてますが,きちんとテストしていません.エラーが出る場合は事前に整形してからテキストを入れて下さい.", ) if __name__ == "__main__": demo.launch()