# suno_prompt_gen/glm_huggingface.py
from __future__ import annotations
import os
import sys
import html
import re
from typing import Any, List
from paths import get_project_root
from gradio_client import Client # type: ignore


def run_glm(prompt: str, sys_prompt: str) -> str:
"""
Run LLM using the GLM Gradio Space via gradio_client.
Environment variables (loaded from .env at the repo root if available):
- GLM_GRADIO_SPACE (default: "zai-org/GLM-4.5-Space")
- GLM_SYS_PROMPT (default mirrors the system prompt used in test_gradio.py)
- GLM_TEMPERATURE (float; default: 1)
- GLM_THINKING_ENABLED (bool; "1"/"true"/"yes" => True; default: False)
- GLM_API_NAME (default: "/chat_wrapper")
"""
text = (prompt or "").strip()
if not text:
return prompt
_load_env()
space = os.getenv("GLM_GRADIO_SPACE", "zai-org/GLM-4.5-Space").strip() or "zai-org/GLM-4.5-Space"
sys_prompt = sys_prompt or ""
api_name = os.getenv("GLM_API_NAME", "/chat_wrapper").strip() or "/chat_wrapper"
temperature = _coerce_float(os.getenv("GLM_TEMPERATURE", "1"), default=1.0)
thinking_enabled = _coerce_bool(os.getenv("GLM_THINKING_ENABLED", "0"))
try:
client = Client(space)
result = client.predict(
msg=text,
sys_prompt=sys_prompt,
thinking_enabled=thinking_enabled,
temperature=temperature,
api_name=api_name,
)
content = _extract_clean_assistant_text(result).strip()
return content if content else prompt
    except Exception as e:
        _warn(f"GLM inference failed: {type(e).__name__}: {e}")
        return prompt


# --- Internal helpers ---------------------------------------------------------

def _warn(msg: str) -> None:
    print(f"[llm_refiner_glm] {msg}", file=sys.stderr)


def _load_env() -> None:
"""
Attempt to load .env from the project root (best-effort).
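
    Example .env (illustrative values, not project defaults):
        GLM_GRADIO_SPACE=zai-org/GLM-4.5-Space
        GLM_API_NAME=/chat_wrapper
        GLM_TEMPERATURE=0.8
        GLM_THINKING_ENABLED=1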
"""
env_path = get_project_root() / ".env"
try:
from dotenv import load_dotenv # type: ignore
load_dotenv(dotenv_path=env_path)
except Exception:
# If python-dotenv isn't installed, silently rely on existing process env
pass


def _coerce_bool(val: Any) -> bool:
    """Interpret common truthy strings ("1", "true", "yes", "on") as True."""
    s = str(val).strip().lower()
    return s in {"1", "true", "yes", "on"}


def _coerce_float(val: Any, default: float) -> float:
    """Return ``val`` parsed as a float, or ``default`` if parsing fails."""
    try:
        return float(val)
    except Exception:
        return default


def _clean_html_to_text(html_content: Any) -> str:
"""
Convert HTML (possibly with entities) to plain text with minimal structure preserved via newlines.
Mirrors the robust cleaning approach used in test_gradio.py.
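
    Illustrative example (hypothetical input, not output from the Space):

        >>> print(_clean_html_to_text("<p>Hello &amp; world</p>done"))
        Hello & world
        done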
"""
if not isinstance(html_content, str):
html_content = str(html_content)
# Unescape entities first to normalize both raw and escaped HTML
text = html.unescape(html_content)
# Convert common block/line-break tags to newlines
text = re.sub(r"(?is)<br\s*/?>|</p>|</div>", "\n", text)
text = re.sub(r"(?is)<(script|style)\b.*?>.*?</\1>", "", text)
# Remove all remaining HTML tags (raw)
text = re.sub(r"(?s)<[^>]+>", "", text)
# Normalize whitespace: CRLF->LF, but don't remove empty lines or collapse whitespace
text = text.replace("\r\n", "\n").replace("\r", "\n")
    # splitlines()/join() keeps interior blank lines and whitespace intact,
    # but drops any single trailing newline
    return "\n".join(text.splitlines())


def _extract_clean_assistant_text(result: Any) -> str:
"""
Extract 'assistant' messages from a Gradio chat result and return plain text without HTML.
Supports result formats:
- (messages, *rest)
- messages
Each message is a dict with 'role' and 'content'.
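
    Example messages value (illustrative):
        [{"role": "user", "content": "hi"},
         {"role": "assistant", "content": "<p>hello</p>"}]
    would yield "hello".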
"""
# Normalize to a list of messages
if isinstance(result, tuple) and result:
messages = result[0]
else:
messages = result
clean_segments: List[str] = []
if isinstance(messages, list):
for m in messages:
if isinstance(m, dict) and m.get("role") == "assistant":
content = m.get("content", "")
clean_segments.append(_clean_html_to_text(content))
if not clean_segments:
return ""
return "\n\n".join(s for s in clean_segments if s)