# helion-v1-embeddings / prepare_embeddings_data.py
# Commit 3d8d9e4 by Specific-Cognito: "Create prepare_embeddings_data.py" (verified)
"""
Helion-V1-Embeddings Training Data Generator
Generate sentence pairs for training embeddings model
Optimized for semantic similarity and retrieval tasks
"""
import json
import logging
import random
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
# Configure root logging once at import time: INFO and above, timestamped.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-scoped logger shared by everything in this file.
logger = logging.getLogger(__name__)
class EmbeddingsDataGenerator:
    """Generate training data for the Helion-V1 embeddings model.

    Produces hand-curated sentence pairs scored for similarity in [0, 1],
    splits them into train/validation sets, derives (anchor, positive,
    negative) triplets from the training split, and writes everything as
    JSON or JSONL files under ``output_dir``.
    """

    # Output formats understood by ``save_data``.
    SUPPORTED_FORMATS = ("json", "jsonl")

    # Source of randomness for shuffling / negative sampling. The module
    # itself is the default so a global ``random.seed()`` still controls
    # the output; ``__init__`` swaps in a private generator when seeded.
    _rng = random

    def __init__(self, output_dir: str = "./embeddings_training_data",
                 seed: Optional[int] = None):
        """
        Args:
            output_dir: Directory for generated files; created (including
                parents) if it does not already exist.
            seed: Optional seed for a private ``random.Random`` used for
                the train/validation shuffle and negative sampling, making
                runs reproducible. When None, the global ``random`` module
                is used (original behaviour).
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        if seed is not None:
            self._rng = random.Random(seed)

    def generate_paraphrase_pairs(self) -> List[Dict]:
        """
        Generate paraphrase pairs (high similarity).
        Score: 0.85-1.0

        Returns:
            List of ``{"sentence1", "sentence2", "score"}`` dicts.
        """
        pairs = [
            # Technical questions
            {
                "sentence1": "How do I install Python on Windows?",
                "sentence2": "What's the process to set up Python on a Windows computer?",
                "score": 0.95
            },
            {
                "sentence1": "What is machine learning?",
                "sentence2": "Can you explain machine learning to me?",
                "score": 0.92
            },
            {
                "sentence1": "How to fix a bug in my code?",
                "sentence2": "What's the best way to debug my program?",
                "score": 0.88
            },
            {
                "sentence1": "Reset password instructions",
                "sentence2": "How do I reset my password?",
                "score": 0.93
            },
            {
                "sentence1": "Database connection error",
                "sentence2": "Can't connect to the database",
                "score": 0.90
            },
            # General knowledge
            {
                "sentence1": "What is the capital of France?",
                "sentence2": "Tell me the capital city of France",
                "score": 0.96
            },
            {
                "sentence1": "Best restaurants in New York",
                "sentence2": "Where to eat in New York City",
                "score": 0.89
            },
            {
                "sentence1": "Weather forecast for tomorrow",
                "sentence2": "What will the weather be like tomorrow?",
                "score": 0.91
            },
            {
                "sentence1": "How to learn a new language",
                "sentence2": "Tips for learning foreign languages",
                "score": 0.87
            },
            {
                "sentence1": "Symptoms of the flu",
                "sentence2": "What are flu symptoms?",
                "score": 0.94
            },
            # Product/service queries
            {
                "sentence1": "How to cancel my subscription",
                "sentence2": "Steps to unsubscribe from the service",
                "score": 0.90
            },
            {
                "sentence1": "Return policy for products",
                "sentence2": "How do I return an item?",
                "score": 0.86
            },
            {
                "sentence1": "Customer support contact",
                "sentence2": "How to reach customer service",
                "score": 0.92
            },
            {
                "sentence1": "Shipping tracking information",
                "sentence2": "Where is my order?",
                "score": 0.85
            },
            {
                "sentence1": "Update payment method",
                "sentence2": "Change my credit card information",
                "score": 0.88
            },
        ]
        logger.info(f"Generated {len(pairs)} paraphrase pairs")
        return pairs

    def generate_similar_pairs(self) -> List[Dict]:
        """
        Generate semantically similar pairs (medium-high similarity).
        Score: 0.60-0.85

        Returns:
            List of ``{"sentence1", "sentence2", "score"}`` dicts.
        """
        pairs = [
            # Related concepts
            {
                "sentence1": "Machine learning algorithms",
                "sentence2": "Neural network architectures",
                "score": 0.75
            },
            {
                "sentence1": "Python programming language",
                "sentence2": "JavaScript coding tutorial",
                "score": 0.68
            },
            {
                "sentence1": "Data science career path",
                "sentence2": "Becoming a data analyst",
                "score": 0.72
            },
            {
                "sentence1": "Cloud computing services",
                "sentence2": "AWS infrastructure guide",
                "score": 0.70
            },
            {
                "sentence1": "Web development frameworks",
                "sentence2": "React and Vue.js comparison",
                "score": 0.74
            },
            # Related questions
            {
                "sentence1": "How to lose weight?",
                "sentence2": "Healthy eating habits",
                "score": 0.65
            },
            {
                "sentence1": "Best laptops for programming",
                "sentence2": "Computer hardware for developers",
                "score": 0.71
            },
            {
                "sentence1": "Learning guitar for beginners",
                "sentence2": "Music theory basics",
                "score": 0.62
            },
            {
                "sentence1": "Travel tips for Europe",
                "sentence2": "Budget travel guide",
                "score": 0.67
            },
            {
                "sentence1": "Home workout routines",
                "sentence2": "Fitness exercises without equipment",
                "score": 0.73
            },
            # Professional context
            {
                "sentence1": "Project management best practices",
                "sentence2": "Agile methodology guide",
                "score": 0.69
            },
            {
                "sentence1": "Resume writing tips",
                "sentence2": "Job interview preparation",
                "score": 0.64
            },
            {
                "sentence1": "Team collaboration tools",
                "sentence2": "Remote work software solutions",
                "score": 0.72
            },
            {
                "sentence1": "Time management techniques",
                "sentence2": "Productivity improvement strategies",
                "score": 0.76
            },
            {
                "sentence1": "Business communication skills",
                "sentence2": "Professional email etiquette",
                "score": 0.66
            },
        ]
        logger.info(f"Generated {len(pairs)} similar pairs")
        return pairs

    def generate_dissimilar_pairs(self) -> List[Dict]:
        """
        Generate unrelated pairs (low similarity).
        Score: 0.0-0.30

        Returns:
            List of ``{"sentence1", "sentence2", "score"}`` dicts.
        """
        pairs = [
            # Completely unrelated
            {
                "sentence1": "How to bake chocolate cake",
                "sentence2": "Installing Linux operating system",
                "score": 0.05
            },
            {
                "sentence1": "Football match results",
                "sentence2": "Quantum physics equations",
                "score": 0.02
            },
            {
                "sentence1": "Dog training tips",
                "sentence2": "Stock market analysis",
                "score": 0.08
            },
            {
                "sentence1": "Car repair manual",
                "sentence2": "Ancient Roman history",
                "score": 0.04
            },
            {
                "sentence1": "Gardening for beginners",
                "sentence2": "Cryptocurrency trading strategies",
                "score": 0.06
            },
            # Different domains
            {
                "sentence1": "Piano lessons online",
                "sentence2": "Chemical engineering degree",
                "score": 0.10
            },
            {
                "sentence1": "Knitting patterns",
                "sentence2": "Cybersecurity threats",
                "score": 0.03
            },
            {
                "sentence1": "Mediterranean diet recipes",
                "sentence2": "Smartphone app development",
                "score": 0.07
            },
            {
                "sentence1": "Yoga poses for flexibility",
                "sentence2": "Legal contract templates",
                "score": 0.05
            },
            {
                "sentence1": "Movie reviews 2024",
                "sentence2": "Database optimization techniques",
                "score": 0.09
            },
            # Topic mismatch
            {
                "sentence1": "Wedding planning checklist",
                "sentence2": "Machine learning deployment",
                "score": 0.04
            },
            {
                "sentence1": "Child development stages",
                "sentence2": "Network security protocols",
                "score": 0.06
            },
            {
                "sentence1": "Photography lighting techniques",
                "sentence2": "Tax filing requirements",
                "score": 0.08
            },
            {
                "sentence1": "Fashion trends 2024",
                "sentence2": "Docker container orchestration",
                "score": 0.02
            },
            {
                "sentence1": "Scuba diving certification",
                "sentence2": "Financial portfolio management",
                "score": 0.07
            },
        ]
        logger.info(f"Generated {len(pairs)} dissimilar pairs")
        return pairs

    def generate_question_answer_pairs(self) -> List[Dict]:
        """
        Generate question-answer pairs for retrieval training.
        Score: 0.80-0.95

        Returns:
            List of ``{"sentence1", "sentence2", "score"}`` dicts where
            ``sentence1`` is the query and ``sentence2`` the answer passage.
        """
        pairs = [
            {
                "sentence1": "What is Python?",
                "sentence2": "Python is a high-level programming language known for its simplicity and versatility.",
                "score": 0.88
            },
            {
                "sentence1": "How does HTTP work?",
                "sentence2": "HTTP is a protocol that enables communication between web browsers and servers.",
                "score": 0.85
            },
            {
                "sentence1": "What is artificial intelligence?",
                "sentence2": "AI is the simulation of human intelligence by machines and computer systems.",
                "score": 0.90
            },
            {
                "sentence1": "Define cloud computing",
                "sentence2": "Cloud computing delivers computing services over the internet including storage and processing.",
                "score": 0.87
            },
            {
                "sentence1": "What is a database?",
                "sentence2": "A database is an organized collection of structured information stored electronically.",
                "score": 0.89
            },
            {
                "sentence1": "Explain REST API",
                "sentence2": "REST API is an architectural style for building web services using HTTP requests.",
                "score": 0.84
            },
            {
                "sentence1": "What is version control?",
                "sentence2": "Version control is a system that tracks changes to files over time.",
                "score": 0.86
            },
            {
                "sentence1": "Define responsive design",
                "sentence2": "Responsive design ensures websites work well on all devices and screen sizes.",
                "score": 0.88
            },
            {
                "sentence1": "What is encryption?",
                "sentence2": "Encryption is the process of encoding information to prevent unauthorized access.",
                "score": 0.91
            },
            {
                "sentence1": "Explain agile methodology",
                "sentence2": "Agile is an iterative approach to project management focused on flexibility.",
                "score": 0.83
            },
        ]
        logger.info(f"Generated {len(pairs)} question-answer pairs")
        return pairs

    def generate_domain_specific_pairs(self) -> List[Dict]:
        """
        Generate domain-specific sentence pairs (programming, healthcare,
        finance, education, e-commerce).
        Score: varies per pair (currently 0.72-0.88).

        Returns:
            List of ``{"sentence1", "sentence2", "score"}`` dicts.
        """
        pairs = [
            # Programming
            {
                "sentence1": "Python list comprehension",
                "sentence2": "Creating lists in Python efficiently",
                "score": 0.86
            },
            {
                "sentence1": "Git merge conflicts",
                "sentence2": "Resolving version control conflicts",
                "score": 0.84
            },
            {
                "sentence1": "React component lifecycle",
                "sentence2": "Understanding React hooks",
                "score": 0.72
            },
            # Healthcare
            {
                "sentence1": "Blood pressure medication",
                "sentence2": "Treating hypertension",
                "score": 0.78
            },
            {
                "sentence1": "Physical therapy exercises",
                "sentence2": "Rehabilitation program",
                "score": 0.80
            },
            # Finance
            {
                "sentence1": "Investment portfolio diversification",
                "sentence2": "Managing financial risk",
                "score": 0.75
            },
            {
                "sentence1": "Mortgage interest rates",
                "sentence2": "Home loan options",
                "score": 0.82
            },
            # Education
            {
                "sentence1": "Online course platforms",
                "sentence2": "E-learning systems",
                "score": 0.88
            },
            {
                "sentence1": "Study techniques for exams",
                "sentence2": "Test preparation strategies",
                "score": 0.85
            },
            # E-commerce
            {
                "sentence1": "Product recommendation system",
                "sentence2": "Personalized shopping suggestions",
                "score": 0.83
            },
            {
                "sentence1": "Shopping cart abandonment",
                "sentence2": "Incomplete purchase behavior",
                "score": 0.86
            },
        ]
        logger.info(f"Generated {len(pairs)} domain-specific pairs")
        return pairs

    def format_for_training(self, pairs: List[Dict]) -> List[Dict]:
        """
        Format sentence pairs for training.

        Args:
            pairs: List of sentence pair dictionaries.

        Returns:
            New dicts holding only ``sentence1``, ``sentence2`` and ``score``;
            any other keys on the input pairs are dropped.
        """
        return [
            {
                "sentence1": pair["sentence1"],
                "sentence2": pair["sentence2"],
                "score": pair["score"],
            }
            for pair in pairs
        ]

    def create_contrastive_examples(
        self,
        pairs: List[Dict],
        max_triplets: Optional[int] = 20,
        positive_threshold: float = 0.80,
        negative_threshold: float = 0.30,
    ) -> List[Dict]:
        """
        Create contrastive examples (anchor, positive, negative).

        Each pair scoring ``>= positive_threshold`` supplies an anchor
        (``sentence1``) and positive (``sentence2``). The negative is the
        ``sentence2`` of a randomly chosen pair scoring
        ``<= negative_threshold``.

        Args:
            pairs: Sentence pairs with scores. Pass only training pairs to
                keep validation data out of the triplets.
            max_triplets: Use at most this many positive pairs, taken in
                input order. None uses every qualifying pair.
            positive_threshold: Minimum score for a positive pair.
            negative_threshold: Maximum score for a pair that supplies
                negatives.

        Returns:
            Triplet dicts with keys ``anchor``, ``positive``, ``negative``.
            Empty if no pair qualifies as a negative source.
        """
        high_sim = [p for p in pairs if p["score"] >= positive_threshold]
        low_sim = [p for p in pairs if p["score"] <= negative_threshold]

        contrastive = []
        # Without any low-similarity pair there is nothing to sample from.
        if low_sim:
            for positive_pair in high_sim[:max_triplets]:
                negative_pair = self._rng.choice(low_sim)
                contrastive.append({
                    "anchor": positive_pair["sentence1"],
                    "positive": positive_pair["sentence2"],
                    "negative": negative_pair["sentence2"]
                })

        logger.info(f"Created {len(contrastive)} contrastive examples")
        return contrastive

    def save_data(self, data: Union[List[Dict], Dict], filename: str,
                  format: str = "json"):
        """
        Save training data to ``output_dir / filename``.

        Args:
            data: Records to write. ``"json"`` writes the whole object as
                one indented document; ``"jsonl"`` writes one compact JSON
                document per element of ``data``.
            filename: File name relative to ``output_dir``.
            format: ``"json"`` or ``"jsonl"``.

        Raises:
            ValueError: If ``format`` is not supported. Nothing is written.
        """
        # Validate up front: previously an unknown format wrote no file yet
        # still logged success.
        if format not in self.SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported format {format!r}; "
                f"expected one of {self.SUPPORTED_FORMATS}"
            )

        filepath = self.output_dir / filename
        with open(filepath, 'w', encoding='utf-8') as f:
            if format == "json":
                json.dump(data, f, indent=2, ensure_ascii=False)
            else:
                for item in data:
                    f.write(json.dumps(item, ensure_ascii=False) + '\n')
        logger.info(f"Saved {len(data)} examples to {filepath}")

    def generate_full_dataset(self, format: str = "json",
                              train_ratio: float = 0.9) -> str:
        """
        Generate complete embeddings training dataset.

        Writes ``train_pairs``, ``validation_pairs`` and
        ``contrastive_triplets`` in the requested format, plus
        ``embeddings_dataset_stats.json``.

        Args:
            format: Output format ('json' or 'jsonl').
            train_ratio: Fraction of the shuffled pairs assigned to the
                training split. The remainder becomes the validation split.

        Returns:
            Output directory path.

        Raises:
            ValueError: If ``train_ratio`` is outside [0, 1] or ``format`` is
                unsupported.
        """
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")

        logger.info("Generating embeddings training dataset...")

        # Collect all pairs
        paraphrase_pairs = self.generate_paraphrase_pairs()
        similar_pairs = self.generate_similar_pairs()
        dissimilar_pairs = self.generate_dissimilar_pairs()
        qa_pairs = self.generate_question_answer_pairs()
        domain_pairs = self.generate_domain_specific_pairs()
        all_pairs = [
            *paraphrase_pairs,
            *similar_pairs,
            *dissimilar_pairs,
            *qa_pairs,
            *domain_pairs,
        ]

        # Shuffle, then split train/validation
        self._rng.shuffle(all_pairs)
        split_idx = int(len(all_pairs) * train_ratio)
        train_pairs = all_pairs[:split_idx]
        val_pairs = all_pairs[split_idx:]
        logger.info(f"Train: {len(train_pairs)} pairs")
        logger.info(f"Validation: {len(val_pairs)} pairs")

        # Format data
        train_data = self.format_for_training(train_pairs)
        val_data = self.format_for_training(val_pairs)

        # Save sentence pair format
        self.save_data(train_data, f"train_pairs.{format}", format)
        self.save_data(val_data, f"validation_pairs.{format}", format)

        # Triplets come from the training split only. Building them from all
        # pairs leaked validation sentences into training data.
        contrastive_data = self.create_contrastive_examples(train_pairs)
        self.save_data(contrastive_data, f"contrastive_triplets.{format}", format)

        # Generate statistics
        stats = {
            "total_pairs": len(all_pairs),
            "train_size": len(train_pairs),
            "validation_size": len(val_pairs),
            "contrastive_triplets": len(contrastive_data),
            "paraphrase_pairs": len(paraphrase_pairs),
            "similar_pairs": len(similar_pairs),
            "dissimilar_pairs": len(dissimilar_pairs),
            "qa_pairs": len(qa_pairs),
            "domain_pairs": len(domain_pairs),
            "score_distribution": {
                "high (0.8-1.0)": len([p for p in all_pairs if p["score"] >= 0.8]),
                "medium (0.5-0.8)": len([p for p in all_pairs if 0.5 <= p["score"] < 0.8]),
                "low (0.0-0.5)": len([p for p in all_pairs if p["score"] < 0.5])
            },
            "generated_at": datetime.now().isoformat(),
            "format": format
        }
        self.save_data(stats, "embeddings_dataset_stats.json", "json")

        logger.info("=" * 60)
        logger.info("✅ Embeddings dataset generation complete!")
        logger.info(f"Total pairs: {len(all_pairs)}")
        logger.info(f"Output directory: {self.output_dir}")
        logger.info("=" * 60)
        return str(self.output_dir)
def main(argv: Optional[List[str]] = None):
    """Command-line entry point: generate the dataset and print a summary.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. None keeps
            the normal command-line behaviour; passing a list is useful in
            tests.

    Accepts ``--output-dir`` and ``--format`` (``json`` or ``jsonl``), runs
    ``EmbeddingsDataGenerator.generate_full_dataset``, then prints the files
    written and a suggested training command.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate training data for Helion-V1-Embeddings"
    )
    parser.add_argument(
        "--output-dir",
        default="./embeddings_training_data",
        help="Output directory for training data"
    )
    parser.add_argument(
        "--format",
        choices=["json", "jsonl"],
        default="json",
        help="Output format"
    )
    args = parser.parse_args(argv)

    # Generate dataset
    generator = EmbeddingsDataGenerator(output_dir=args.output_dir)
    output_path = generator.generate_full_dataset(format=args.format)

    print("\n" + "=" * 60)
    print("🎯 Embeddings Training Data Ready!")
    print("=" * 60)
    print(f"📁 Location: {output_path}")
    print(f"📊 Format: {args.format}")
    print("\n📄 Files created:")
    print(f" • train_pairs.{args.format} - Training sentence pairs")
    print(f" • validation_pairs.{args.format} - Validation pairs")
    print(f" • contrastive_triplets.{args.format} - Triplet examples")
    print(" • embeddings_dataset_stats.json - Dataset statistics")
    print("\n💡 Training data includes:")
    print(" • Paraphrase pairs (high similarity)")
    print(" • Similar concept pairs (medium similarity)")
    print(" • Dissimilar pairs (low similarity)")
    print(" • Question-answer pairs")
    print(" • Domain-specific examples")
    print("\n🚀 Next step:")
    print(f" python train_embeddings.py --data-file {output_path}/train_pairs.{args.format}")
    print("=" * 60)


if __name__ == "__main__":
    main()