"""
Inference functionality tests
"""
import time

import pytest
import requests
class TestInference:
"""Test inference endpoints"""
@pytest.fixture
def base_url(self):
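        # A hypothetical extension (not wired in here): read the URL from an
        # environment variable so the same tests can target a remote deployment,
        # e.g. os.environ.get("INFERENCE_BASE_URL", "http://localhost:8000").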
return "http://localhost:8000"
def test_basic_inference(self, base_url):
"""Test basic inference endpoint"""
payload = {
"prompt": "What is EBITDA?",
"max_new_tokens": 50,
"temperature": 0.6
}
        response = requests.post(f"{base_url}/inference", json=payload, timeout=60)
assert response.status_code == 200
data = response.json()
assert "response" in data
assert "model_used" in data
assert len(data["response"]) > 0
    def test_inference_with_different_parameters(self, base_url):
        """Test inference with different generation parameters"""
payload = {
"prompt": "Explain financial risk management",
"max_new_tokens": 100,
"temperature": 0.3
}
        response = requests.post(f"{base_url}/inference", json=payload, timeout=60)
assert response.status_code == 200
data = response.json()
assert "response" in data
assert len(data["response"]) > 50 # Should be substantial response
def test_inference_error_handling(self, base_url):
"""Test inference error handling"""
# Test with invalid parameters
payload = {
"prompt": "", # Empty prompt
"max_new_tokens": 50
}
        response = requests.post(f"{base_url}/inference", json=payload, timeout=60)
# Should handle gracefully (either 400 or 200 with error message)
assert response.status_code in [200, 400]
def test_inference_performance(self, base_url):
"""Test inference performance (basic timing)"""
payload = {
"prompt": "What is the current ratio?",
"max_new_tokens": 30,
"temperature": 0.5
}
start_time = time.time()
        response = requests.post(f"{base_url}/inference", json=payload, timeout=60)
end_time = time.time()
assert response.status_code == 200
response_time = end_time - start_time
# Should respond within reasonable time (adjust based on your setup)
        assert response_time < 30  # 30 seconds max for a simple query