Spaces:
Sleeping
Sleeping
[update] adding community GPU support
Browse files
app.py
CHANGED
|
@@ -21,22 +21,10 @@ from pathlib import Path
|
|
| 21 |
from typing import Optional, Generator
|
| 22 |
from queue import Queue
|
| 23 |
from threading import Thread
|
| 24 |
-
|
| 25 |
-
# ZeroGPU support
|
| 26 |
-
try:
|
| 27 |
-
import spaces
|
| 28 |
-
ZEROGPU_AVAILABLE = True
|
| 29 |
-
except ImportError:
|
| 30 |
-
ZEROGPU_AVAILABLE = False
|
| 31 |
-
# Create a no-op decorator for non-ZeroGPU environments
|
| 32 |
-
class spaces:
|
| 33 |
-
@staticmethod
|
| 34 |
-
def GPU(duration=None):
|
| 35 |
-
def decorator(func):
|
| 36 |
-
return func
|
| 37 |
-
return decorator if duration else lambda f: f
|
| 38 |
-
|
| 39 |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
|
|
|
|
|
|
|
|
|
| 40 |
from rosetta.utils.evaluate import load_rosetta_model, load_hf_model, set_default_chat_template
|
| 41 |
from rosetta.model.wrapper import RosettaModel
|
| 42 |
from rosetta.baseline.multi_stage import TwoStageInference
|
|
@@ -623,8 +611,7 @@ def main():
|
|
| 623 |
server_name="0.0.0.0",
|
| 624 |
server_port=7860,
|
| 625 |
share=False,
|
| 626 |
-
show_error=True
|
| 627 |
-
ssr_mode=False
|
| 628 |
)
|
| 629 |
|
| 630 |
|
|
|
|
| 21 |
from typing import Optional, Generator
|
| 22 |
from queue import Queue
|
| 23 |
from threading import Thread
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
| 25 |
+
import spaces
|
| 26 |
+
ZEROGPU_AVAILABLE = os.getenv("ZERO_GPU", "").lower() == "true" # ZeroGPU support - HuggingFace Spaces sets ZERO_GPU=true when ZeroGPU is available
|
| 27 |
+
|
| 28 |
from rosetta.utils.evaluate import load_rosetta_model, load_hf_model, set_default_chat_template
|
| 29 |
from rosetta.model.wrapper import RosettaModel
|
| 30 |
from rosetta.baseline.multi_stage import TwoStageInference
|
|
|
|
| 611 |
server_name="0.0.0.0",
|
| 612 |
server_port=7860,
|
| 613 |
share=False,
|
| 614 |
+
show_error=True
|
|
|
|
| 615 |
)
|
| 616 |
|
| 617 |
|