dragonllm-finance-models / tests /test_openai_compatibility.py
jeanbaptdzd's picture
feat: Clean deployment to HuggingFace Space with model config test endpoint
8c0b652
"""
OpenAI API compatibility tests
"""
import pytest
import requests
import json
from typing import Dict, Any
class TestOpenAICompatibility:
    """Test OpenAI-compatible endpoints.

    These are integration tests: each one sends real HTTP requests to the
    server whose root URL is provided by the ``base_url`` fixture.
    """

    # Model identifier sent in every request payload.
    MODEL = "llama3.1-8b"

    # (connect, read) timeouts in seconds. ``requests`` waits forever when no
    # timeout is given, so a stalled server would otherwise hang the test run.
    # The read timeout is generous because text generation can be slow.
    REQUEST_TIMEOUT = (5, 120)

    @pytest.fixture
    def base_url(self):
        """Return the root URL of the server under test."""
        return "http://localhost:8000"

    def _post_json(self, base_url, path, payload):
        """POST ``payload`` as JSON to ``path`` and return the decoded JSON body.

        Asserts that the server answered with HTTP 200 before decoding.
        """
        response = requests.post(
            f"{base_url}{path}",
            json=payload,
            timeout=self.REQUEST_TIMEOUT,
        )
        assert response.status_code == 200, (
            f"POST {path} returned {response.status_code}: {response.text[:200]}"
        )
        return response.json()

    @staticmethod
    def _parse_json_content(content):
        """Decode ``content`` as JSON, failing the test if it is not valid JSON."""
        try:
            return json.loads(content)
        except (json.JSONDecodeError, TypeError):
            # TypeError covers non-string content (e.g. a null message body).
            pytest.fail("Response is not valid JSON")

    def test_chat_completions(self, base_url):
        """Test OpenAI chat completions endpoint"""
        payload = {
            "model": self.MODEL,
            "messages": [
                {"role": "user", "content": "What is EBITDA?"},
            ],
            "max_tokens": 100,
            "temperature": 0.7,
        }
        data = self._post_json(base_url, "/v1/chat/completions", payload)

        # Check OpenAI response format
        assert "choices" in data
        assert "usage" in data
        assert "model" in data
        assert len(data["choices"]) > 0
        assert "message" in data["choices"][0]
        assert "content" in data["choices"][0]["message"]

    def test_chat_completions_with_system_message(self, base_url):
        """Test chat completions with system message"""
        payload = {
            "model": self.MODEL,
            "messages": [
                {"role": "system", "content": "You are a financial expert."},
                {
                    "role": "user",
                    "content": "Explain the difference between revenue and profit.",
                },
            ],
            "max_tokens": 150,
            "temperature": 0.6,
        }
        data = self._post_json(base_url, "/v1/chat/completions", payload)

        assert "choices" in data
        assert len(data["choices"][0]["message"]["content"]) > 0

    def test_text_completions(self, base_url):
        """Test OpenAI text completions endpoint"""
        payload = {
            "model": self.MODEL,
            "prompt": "The key financial ratios for a healthy company include:",
            "max_tokens": 100,
            "temperature": 0.5,
        }
        data = self._post_json(base_url, "/v1/completions", payload)

        # Check OpenAI response format
        assert "choices" in data
        assert "usage" in data
        assert "model" in data
        assert len(data["choices"]) > 0
        assert "text" in data["choices"][0]

    def test_json_response_format(self, base_url):
        """Test structured JSON output"""
        payload = {
            "model": self.MODEL,
            "messages": [
                {
                    "role": "user",
                    "content": "Return financial metrics in JSON format: revenue, profit, debt_ratio",
                },
            ],
            "response_format": {"type": "json_object"},
            "max_tokens": 150,
        }
        data = self._post_json(base_url, "/v1/chat/completions", payload)

        # Check that response is valid JSON
        content = data["choices"][0]["message"]["content"]
        json_data = self._parse_json_content(content)
        assert isinstance(json_data, dict)

    def test_streaming_response(self, base_url):
        """Test streaming chat completions"""
        payload = {
            "model": self.MODEL,
            "messages": [
                {
                    "role": "user",
                    "content": "Explain financial risk management strategies.",
                },
            ],
            "stream": True,
            "max_tokens": 100,
        }

        chunks = []
        # The context manager closes the connection even though the body is
        # only partially consumed (we stop after a few chunks).
        with requests.post(
            f"{base_url}/v1/chat/completions",
            json=payload,
            stream=True,
            timeout=self.REQUEST_TIMEOUT,
        ) as response:
            assert response.status_code == 200

            # The header may carry parameters (e.g. "; charset=utf-8"), so
            # compare only the media type itself.
            content_type = response.headers.get("content-type", "")
            media_type = content_type.split(";", 1)[0].strip().lower()
            assert media_type == "text/event-stream", (
                f"Unexpected content-type: {content_type!r}"
            )

            # Check that we get streaming data
            for line in response.iter_lines():
                if not line:
                    continue
                chunks.append(line.decode("utf-8"))
                if len(chunks) >= 3:  # Get a few chunks
                    break

        assert len(chunks) > 0
        # Check SSE format
        assert any("data:" in chunk for chunk in chunks)

    def test_rag_example(self, base_url):
        """Test RAG-style document analysis"""
        payload = {
            "model": self.MODEL,
            "messages": [
                {
                    "role": "system",
                    "content": "You are a financial analyst. Extract key metrics from financial documents.",
                },
                {
                    "role": "user",
                    "content": (
                        "Analyze this financial statement and return the data in "
                        "JSON format with fields: revenue, expenses, net_income.\n\n"
                        "Document:\n"
                        "Q3 2024 Results:\n"
                        "Revenue: $2.5M\n"
                        "Expenses: $1.8M\n"
                        "Net Income: $700K"
                    ),
                },
            ],
            "response_format": {"type": "json_object"},
            "max_tokens": 200,
        }
        data = self._post_json(base_url, "/v1/chat/completions", payload)

        content = data["choices"][0]["message"]["content"]
        # Malformed output is reported as a test failure, matching
        # test_json_response_format, rather than as an unhandled exception.
        json_data = self._parse_json_content(content)

        # Check that key financial metrics are extracted; a field counts as
        # present either as an exact key or, with spaces instead of
        # underscores, anywhere in the lower-cased rendering of the data.
        expected_fields = ["revenue", "expenses", "net_income"]
        rendered = str(json_data).lower()
        for field in expected_fields:
            assert field in json_data or field.replace("_", " ") in rendered, (
                f"Missing expected field: {field}"
            )