Spaces:
Paused
Paused
| import asyncio | |
| import httpx | |
| import json | |
| import requests | |
| import math | |
| import os | |
# Shared module-level async HTTP client (fetch_stream below streams through it).
client = httpx.AsyncClient()


def _require_env(name):
    """Return environment variable ``name``, failing with a clear message if unset.

    Without this check a missing variable surfaces at import time as an opaque
    ``TypeError`` from concatenating ``"https://"`` with ``None``.
    """
    value = os.getenv(name)
    if value is None:
        raise RuntimeError(f"Required environment variable {name!r} is not set")
    return value


# Request URLs, assembled from environment variables.
recommand_base_url = "https://" + _require_env("recommand_base_url")
chat_url = "https://" + _require_env("chat_url")
model_url = "https://" + _require_env("model_url")
character_url = "https://" + _require_env("character_url")
avatar_url = "https://" + _require_env("avatar_url")
image_url = "https://" + _require_env("image_url")
# Optional at import time: these stay None when the variable is unset.
auth = os.getenv("auth")
# NOTE(review): used verbatim (no "https://" prepended, unlike the URLs above);
# presumably the variable already includes the scheme — confirm.
moment_url = os.getenv("moment_url")
# HTTP headers.
# UI language name -> language code sent in the X-Language header.
_LANGUAGE_CODES = {
    'Chinese': 'zh',
    'English': 'en',
    'Japanese': 'ja',
    'Korean': 'ko',
}


def create_headers(language):
    """Build the header dict attached to backend API requests.

    Args:
        language: UI language name. 'Chinese', 'English', 'Japanese' and
            'Korean' are recognised; any other value is sent as 'zh'.

    Returns:
        A fresh dict of request headers whose ``Authorization`` entry is the
        module-level ``auth`` value.
    """
    code = _LANGUAGE_CODES.get(language, 'zh')
    return {
        'X-Refresh-Token': '',
        'X-Language': code,
        'accept-language': '',
        'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',
        'Authorization': auth,
        'Accept': '*/*',
        'Connection': 'keep-alive',
        'X-nsfw': '1',
    }
def recommand_character(language, timeout=10):
    """Fetch the recommended character list from ``character_url``.

    Args:
        language: UI language name, forwarded to ``create_headers``.
        timeout: Seconds to wait for the server. Without a timeout,
            ``requests`` can wait indefinitely on an unresponsive server.

    Returns:
        One ``{"name", "_id", "avatar_url"}`` dict per entry of the response's
        ``data`` array.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    response = requests.get(
        character_url, headers=create_headers(language), timeout=timeout
    )
    # Fail with a clear HTTP error instead of failing later while parsing an error body.
    response.raise_for_status()
    return [
        {
            "name": item["name"],
            "_id": item["_id"],
            "avatar_url": avatar_url + item["_id"] + "_avatar.webp",
        }
        for item in response.json()["data"]
    ]
def id_to_avatar(char_id):
    """Return the avatar image URL for the character with id ``char_id``."""
    suffix = "_avatar.webp"
    return "".join((avatar_url, char_id, suffix))
# GET the model list.
# Models that are never offered.
_EXCLUDED_MODELS = frozenset({'mythomax-13b'})
# Models that are offered, but with a fixed low weight to reduce their frequency.
_DAMPED_MODELS = frozenset({'gemma-2-9b', 'llama-3.1-8b'})
_DAMPED_WEIGHT = 12


def _weighted_model_pool(models):
    """Expand model records into a list where each name repeats by its weight.

    Each record needs a 'model' key; non-damped records also need 'price'.
    Excluded models are dropped, damped models get ``_DAMPED_WEIGHT`` copies,
    and every other model gets ``ceil(25 / (price + 0.5))`` copies, so
    cheaper models appear more often. Input order is preserved.
    """
    pool = []
    for info in models:
        name = info['model']
        if name in _EXCLUDED_MODELS:
            continue
        if name in _DAMPED_MODELS:
            weight = _DAMPED_WEIGHT
        else:
            weight = math.ceil(25 / (info['price'] + 0.5))
        pool.extend([name] * weight)
    return pool


def get_models(timeout=10):
    """Fetch the model list from ``model_url`` and return a weighted name pool.

    Args:
        timeout: Seconds to wait for the server. Without a timeout,
            ``requests`` can wait indefinitely on an unresponsive server.

    Returns:
        The list built by ``_weighted_model_pool`` from the response's
        ``data`` array, or an empty list when the status is not 200.
    """
    response = requests.get(model_url, timeout=timeout)
    if response.status_code != 200:
        return []
    return _weighted_model_pool(response.json()['data'])
# Parse the recommendation JSON.
async def extract_recommand(data, language):
    """Turn a recommendation API payload into display records.

    Args:
        data: Parsed JSON; each element of ``data["data"]`` must have
            "character_id", "_id", "description" and "title".
        language: UI language name, forwarded to ``get_moment_opening``.

    Returns:
        One dict per item, in input order, with the avatar/image URLs and the
        moment's opening text (None when it could not be fetched).
    """
    items = data["data"]
    # Fetch every opening concurrently instead of one request at a time;
    # gather() returns results in the same order as its arguments.
    openings = await asyncio.gather(
        *(get_moment_opening(item["_id"], language) for item in items)
    )
    return [
        {
            "character_id": item["character_id"],
            "avatar_url": avatar_url + item["character_id"] + "_avatar.webp",
            "_id": item["_id"],
            "image_url": image_url + item["_id"] + "_large.webp",
            "description": item["description"],
            "name": item["title"],
            "opening": opening,
        }
        for item, opening in zip(items, openings)
    ]
async def get_moment_opening(moment_id, language):
    """Fetch the opening text of a moment.

    A dedicated ``httpx.AsyncClient`` is opened for this single request.

    Args:
        moment_id: Id appended directly to ``moment_url``.
        language: UI language name, forwarded to ``create_headers``.

    Returns:
        ``["data"]["opening"]`` of the JSON body on HTTP 200, otherwise None.
    """
    endpoint = f"{moment_url}{moment_id}"
    async with httpx.AsyncClient() as http:
        reply = await http.get(endpoint, headers=create_headers(language))
        if reply.status_code != 200:
            return None
        return reply.json()['data']['opening']
# Call the recommendation API.
async def recommand(char_id, language, num=20, offset=0):
    """Fetch moments recommended for character ``char_id`` and enrich them.

    Args:
        char_id: Character id appended to ``recommand_base_url``.
        language: UI language name, forwarded to ``create_headers``.
        num: Number of results requested (``num`` query parameter).
        offset: Result offset (``offset`` query parameter).

    Returns:
        The list produced by ``extract_recommand`` for the response JSON.

    Raises:
        httpx.HTTPStatusError: if the API answers with a non-2xx status.
    """
    recommand_url = f"{recommand_base_url}{char_id}?num={num}&offset={offset}"
    async with httpx.AsyncClient() as http:
        response = await http.get(recommand_url, headers=create_headers(language))
        # Fail clearly instead of parsing an error body as recommendation data.
        response.raise_for_status()
        json_data = response.json()
    # Enrich after this client is closed; extract_recommand issues its own requests.
    return await extract_recommand(json_data, language)
async def fetch_stream(query, model, moment_id, session_id, bio, request_name, queue, language):
    """Stream one chat response into ``queue`` as tagged events.

    Events are ``(request_name, kind, data)`` tuples:
      * ``"header"``: the response headers as a dict (HTTP 200 only);
      * ``"content"``: a decoded text fragment, or ``"Error Occur!"`` when
        the server answers with a non-200 status;
      * ``"end"``: enqueued exactly once and last on every path, including
        exceptions, so a consumer counting "end" events never waits forever.

    Args:
        query, model, moment_id, bio: Sent in the JSON payload to ``chat_url``.
        session_id: Sent as "session_id" only when truthy.
        request_name: Tag placed first in every queued event.
        queue: ``asyncio.Queue`` receiving the events.
        language: UI language name, forwarded to ``create_headers``.

    Returns:
        The ``x-session-id`` response header (None if absent), or None on a
        non-200 response.

    Raises:
        Whatever the HTTP client raises (e.g. network errors), or
        UnicodeDecodeError for invalid UTF-8; "end" is still enqueued first.
    """
    import codecs  # local import: only this function needs it

    payload = {"query": query, "model": model, "bio": bio, "moment_id": moment_id}
    if session_id:
        payload["session_id"] = session_id
    try:
        async with client.stream(
            "POST", chat_url, json=payload, headers=create_headers(language)
        ) as response:
            if response.status_code != 200:
                await queue.put((request_name, "content", "Error Occur!"))
                return None
            # Forward the response headers to the consumer.
            response_headers = dict(response.headers)
            session_id = response_headers.get("x-session-id")
            await queue.put((request_name, "header", response_headers))
            # Decode incrementally: a multi-byte UTF-8 character (e.g. CJK text)
            # can be split across network chunks, which chunk.decode() rejects.
            decoder = codecs.getincrementaldecoder("utf-8")()
            async for chunk in response.aiter_bytes():
                text = decoder.decode(chunk)
                if text:
                    await queue.put((request_name, "content", text))
            # Flush buffered bytes; raises if the stream ended mid-character.
            text = decoder.decode(b"", final=True)
            if text:
                await queue.put((request_name, "content", text))
        return session_id
    finally:
        # Mark the stream finished on every path (success, non-200, exception).
        await queue.put((request_name, "end", None))
| async def combine_streams( | |
| query_a, | |
| query_b, | |
| model_a, | |
| model_b, | |
| moment_id_a, | |
| moment_id_b, | |
| session_id_a, | |
| session_id_b, | |
| bio_a, | |
| bio_b, | |
| language | |
| ): | |
| queue = asyncio.Queue() | |
| task_a = asyncio.create_task( | |
| fetch_stream( | |
| query_a, model_a, moment_id_a, session_id_a, bio_a, "requestA", queue, language | |
| ) | |
| ) | |
| task_b = asyncio.create_task( | |
| fetch_stream( | |
| query_b, model_b, moment_id_b, session_id_b, bio_b, "requestB", queue, language | |
| ) | |
| ) | |
| headers = {} | |
| content = {"requestA": "", "requestB": ""} | |
| active_streams = 2 | |
| while active_streams > 0: | |
| request_name, data_type, data = await queue.get() | |
| if data_type == "header": | |
| headers[f"{request_name}_header"] = data | |
| if len(headers) == 2: | |
| yield headers | |
| elif data_type == "content": | |
| content[request_name] = data.strip() | |
| if content["requestA"] or content["requestB"]: | |
| yield content | |
| content = {"requestA": "", "requestB": ""} | |
| elif data_type == "end": | |
| active_streams -= 1 | |
| session_id_a = await task_a | |
| session_id_b = await task_b | |