eri64 commited on
Commit
c9f8489
·
verified ·
1 Parent(s): 97ff569

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -22
app.py CHANGED
@@ -7,31 +7,38 @@ import faiss
7
  from sentence_transformers import SentenceTransformer
8
  from huggingface_hub import InferenceClient
9
 
10
- # ===================== 📚 RAG 模块 =====================
 
 
11
 
 
12
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
13
- EMBEDDING_PATH = "embeddings.index"
14
- CHUNKS_PATH = "text_chunks.json"
 
 
 
 
 
 
 
 
15
 
16
- def load_or_create_embeddings(novel_files):
17
- """上传txt文件 -> 切分 -> 生成embedding -> 存储"""
18
  model = SentenceTransformer(EMBED_MODEL_NAME)
19
  texts = []
20
 
21
- for file_path in novel_files:
22
  with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
23
  text = f.read()
24
-
25
- # 分段(每段约500字)
26
  chunk_size = 500
27
  chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
28
  texts.extend(chunks)
29
 
30
- # 生成向量
31
  embeddings = model.encode(texts, show_progress_bar=True, normalize_embeddings=True)
32
  embeddings = np.array(embeddings, dtype=np.float32)
33
 
34
- # 保存索引
35
  index = faiss.IndexFlatIP(embeddings.shape[1])
36
  index.add(embeddings)
37
  faiss.write_index(index, EMBEDDING_PATH)
@@ -41,14 +48,22 @@ def load_or_create_embeddings(novel_files):
41
 
42
  return f"✅ 已建立知识库,共 {len(texts)} 段文本。"
43
 
 
 
 
 
 
 
 
 
 
44
  def search_similar(query, top_k=5):
45
- """根据用户输入检索最相关的小说片段"""
46
- if not os.path.exists(EMBEDDING_PATH) or not os.path.exists(CHUNKS_PATH):
47
  return ""
48
 
49
  model = SentenceTransformer(EMBED_MODEL_NAME)
50
  index = faiss.read_index(EMBEDDING_PATH)
51
-
52
  with open(CHUNKS_PATH, "r", encoding="utf-8") as f:
53
  texts = json.load(f)
54
 
@@ -58,8 +73,9 @@ def search_similar(query, top_k=5):
58
 
59
  return "\n".join(results)
60
 
61
-
62
- # ===================== 💬 聊天逻辑 =====================
 
63
 
64
  def respond(
65
  message,
@@ -79,7 +95,7 @@ def respond(
79
  messages = [{"role": "system", "content": system_message}]
80
  messages.extend(history)
81
 
82
- # === 🔍 检索相关小说文本 ===
83
  related_text = search_similar(message, top_k=5)
84
  if related_text:
85
  message = f"以下是参考小说片段:\n{related_text}\n\n请模仿上文风格续写或回答:{message}"
@@ -87,7 +103,6 @@ def respond(
87
  messages.append({"role": "user", "content": message})
88
 
89
  response = ""
90
-
91
  for chunk in client.chat_completion(
92
  messages,
93
  max_tokens=max_tokens,
@@ -99,16 +114,41 @@ def respond(
99
  token = ""
100
  if len(choices) and choices[0].delta.content:
101
  token = choices[0].delta.content
102
-
103
- # ⚙️ 清除异常标记
104
  token = re.sub(r"<\|im_(start|end)\|>", "", token)
105
  token = re.sub(r"(?i)<\|assistant\|>", "", token)
106
-
107
  response += token
108
  yield re.sub(r"<\|im_(start|end)\|>", "", response).strip()
109
 
110
 
111
- # ===================== 🧱 Gradio UI =====================
 
 
112
 
113
  chatbot = gr.ChatInterface(
114
- res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from sentence_transformers import SentenceTransformer
8
  from huggingface_hub import InferenceClient
9
 
10
+ # ======================================================
11
+ # 📚 RAG 知识库模块
12
+ # ======================================================
13
 
14
+ KNOWLEDGE_DIR = "knowledge" # 存放小说txt的文件夹
15
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
16
+ EMBEDDING_PATH = os.path.join(KNOWLEDGE_DIR, "embeddings.index")
17
+ CHUNKS_PATH = os.path.join(KNOWLEDGE_DIR, "text_chunks.json")
18
+
19
+ def build_knowledge_base():
20
+ """从 space/knowledge/*.txt 构建知识库(仅首次执行)"""
21
+ os.makedirs(KNOWLEDGE_DIR, exist_ok=True)
22
+ txt_files = [os.path.join(KNOWLEDGE_DIR, f) for f in os.listdir(KNOWLEDGE_DIR) if f.endswith(".txt")]
23
+ if not txt_files:
24
+ print("⚠️ 未检测到小说文件,请将小说txt放入 space/knowledge/ 文件夹。")
25
+ return "❌ 未检测到小说文件,请先上传至 knowledge/。"
26
 
 
 
27
  model = SentenceTransformer(EMBED_MODEL_NAME)
28
  texts = []
29
 
30
+ for file_path in txt_files:
31
  with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
32
  text = f.read()
33
+ # 分段
 
34
  chunk_size = 500
35
  chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
36
  texts.extend(chunks)
37
 
38
+ # 生成embedding
39
  embeddings = model.encode(texts, show_progress_bar=True, normalize_embeddings=True)
40
  embeddings = np.array(embeddings, dtype=np.float32)
41
 
 
42
  index = faiss.IndexFlatIP(embeddings.shape[1])
43
  index.add(embeddings)
44
  faiss.write_index(index, EMBEDDING_PATH)
 
48
 
49
  return f"✅ 已建立知识库,共 {len(texts)} 段文本。"
50
 
51
+ def ensure_knowledge_base():
52
+ """确保知识库存在(Space启动时调用)"""
53
+ if os.path.exists(EMBEDDING_PATH) and os.path.exists(CHUNKS_PATH):
54
+ print("✅ 已检测到现有知识库,直接加载。")
55
+ return "知识库已加载。"
56
+ else:
57
+ print("📘 正在首次构建知识库...")
58
+ return build_knowledge_base()
59
+
60
  def search_similar(query, top_k=5):
61
+ """检索相关小说片段"""
62
+ if not (os.path.exists(EMBEDDING_PATH) and os.path.exists(CHUNKS_PATH)):
63
  return ""
64
 
65
  model = SentenceTransformer(EMBED_MODEL_NAME)
66
  index = faiss.read_index(EMBEDDING_PATH)
 
67
  with open(CHUNKS_PATH, "r", encoding="utf-8") as f:
68
  texts = json.load(f)
69
 
 
73
 
74
  return "\n".join(results)
75
 
76
+ # ======================================================
77
+ # 💬 聊天逻辑
78
+ # ======================================================
79
 
80
  def respond(
81
  message,
 
95
  messages = [{"role": "system", "content": system_message}]
96
  messages.extend(history)
97
 
98
+ # 🔍 RAG 检索
99
  related_text = search_similar(message, top_k=5)
100
  if related_text:
101
  message = f"以下是参考小说片段:\n{related_text}\n\n请模仿上文风格续写或回答:{message}"
 
103
  messages.append({"role": "user", "content": message})
104
 
105
  response = ""
 
106
  for chunk in client.chat_completion(
107
  messages,
108
  max_tokens=max_tokens,
 
114
  token = ""
115
  if len(choices) and choices[0].delta.content:
116
  token = choices[0].delta.content
117
+ # 过滤异常标记
 
118
  token = re.sub(r"<\|im_(start|end)\|>", "", token)
119
  token = re.sub(r"(?i)<\|assistant\|>", "", token)
 
120
  response += token
121
  yield re.sub(r"<\|im_(start|end)\|>", "", response).strip()
122
 
123
 
124
+ # ======================================================
125
+ # 🧱 Gradio 界面
126
+ # ======================================================
127
 
128
  chatbot = gr.ChatInterface(
129
+ respond,
130
+ type="messages",
131
+ additional_inputs=[
132
+ gr.Textbox(value="你是一位文风细腻、善于模仿小说风格的AI写作助手。", label="System message"),
133
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
134
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
135
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
136
+ ],
137
+ )
138
+
139
+ with gr.Blocks() as demo:
140
+ gr.Markdown("# 📖 DeepSeek 小说创作助手(内置知识库版)")
141
+
142
+ # 启动时自动加载知识库
143
+ status = ensure_knowledge_base()
144
+ gr.Markdown(f"### 🧠 {status}")
145
+
146
+ with gr.Sidebar():
147
+ gr.LoginButton()
148
+ gr.Markdown("小说知识库路径:`space/knowledge/`")
149
+ gr.Markdown("如需更新内容,请替换该目录内的txt文件并重新部署Space。")
150
+
151
+ chatbot.render()
152
+
153
+ if __name__ == "__main__":
154
+ demo.launch()