""" OpenAI API compatibility tests """ import pytest import requests import json from typing import Dict, Any class TestOpenAICompatibility: """Test OpenAI-compatible endpoints""" @pytest.fixture def base_url(self): return "http://localhost:8000" def test_chat_completions(self, base_url): """Test OpenAI chat completions endpoint""" payload = { "model": "llama3.1-8b", "messages": [ {"role": "user", "content": "What is EBITDA?"} ], "max_tokens": 100, "temperature": 0.7 } response = requests.post(f"{base_url}/v1/chat/completions", json=payload) assert response.status_code == 200 data = response.json() # Check OpenAI response format assert "choices" in data assert "usage" in data assert "model" in data assert len(data["choices"]) > 0 assert "message" in data["choices"][0] assert "content" in data["choices"][0]["message"] def test_chat_completions_with_system_message(self, base_url): """Test chat completions with system message""" payload = { "model": "llama3.1-8b", "messages": [ {"role": "system", "content": "You are a financial expert."}, {"role": "user", "content": "Explain the difference between revenue and profit."} ], "max_tokens": 150, "temperature": 0.6 } response = requests.post(f"{base_url}/v1/chat/completions", json=payload) assert response.status_code == 200 data = response.json() assert "choices" in data assert len(data["choices"][0]["message"]["content"]) > 0 def test_text_completions(self, base_url): """Test OpenAI text completions endpoint""" payload = { "model": "llama3.1-8b", "prompt": "The key financial ratios for a healthy company include:", "max_tokens": 100, "temperature": 0.5 } response = requests.post(f"{base_url}/v1/completions", json=payload) assert response.status_code == 200 data = response.json() # Check OpenAI response format assert "choices" in data assert "usage" in data assert "model" in data assert len(data["choices"]) > 0 assert "text" in data["choices"][0] def test_json_response_format(self, base_url): """Test structured JSON output""" payload = { "model": "llama3.1-8b", "messages": [ { "role": "user", "content": "Return financial metrics in JSON format: revenue, profit, debt_ratio" } ], "response_format": {"type": "json_object"}, "max_tokens": 150 } response = requests.post(f"{base_url}/v1/chat/completions", json=payload) assert response.status_code == 200 data = response.json() # Check that response is valid JSON content = data["choices"][0]["message"]["content"] try: json_data = json.loads(content) assert isinstance(json_data, dict) except json.JSONDecodeError: pytest.fail("Response is not valid JSON") def test_streaming_response(self, base_url): """Test streaming chat completions""" payload = { "model": "llama3.1-8b", "messages": [ {"role": "user", "content": "Explain financial risk management strategies."} ], "stream": True, "max_tokens": 100 } response = requests.post(f"{base_url}/v1/chat/completions", json=payload, stream=True) assert response.status_code == 200 assert response.headers["content-type"] == "text/event-stream" # Check that we get streaming data chunks = [] for line in response.iter_lines(): if line: chunks.append(line.decode('utf-8')) if len(chunks) >= 3: # Get a few chunks break assert len(chunks) > 0 # Check SSE format assert any("data:" in chunk for chunk in chunks) def test_rag_example(self, base_url): """Test RAG-style document analysis""" payload = { "model": "llama3.1-8b", "messages": [ { "role": "system", "content": "You are a financial analyst. Extract key metrics from financial documents." }, { "role": "user", "content": "Analyze this financial statement and return the data in JSON format with fields: revenue, expenses, net_income.\n\nDocument:\nQ3 2024 Results:\nRevenue: $2.5M\nExpenses: $1.8M\nNet Income: $700K" } ], "response_format": {"type": "json_object"}, "max_tokens": 200 } response = requests.post(f"{base_url}/v1/chat/completions", json=payload) assert response.status_code == 200 data = response.json() content = data["choices"][0]["message"]["content"] json_data = json.loads(content) # Check that key financial metrics are extracted expected_fields = ["revenue", "expenses", "net_income"] for field in expected_fields: assert field in json_data or field.replace("_", " ") in str(json_data).lower()