""" |
|
|
Helion-V1-Embeddings Training Data Generator |
|
|
Generate sentence pairs for training embeddings model |
|
|
Optimized for semantic similarity and retrieval tasks |
|
|
""" |
|
|
|
|
|
import json |
|
|
import logging |
|
|
import random |
|
|
from typing import List, Dict, Tuple |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
|
|
|
logging.basicConfig( |
|
|
level=logging.INFO, |
|
|
format='%(asctime)s - %(levelname)s - %(message)s' |
|
|
) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class EmbeddingsDataGenerator:
    """Generate training data for an embeddings model."""

    def __init__(self, output_dir: str = "./embeddings_training_data"):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def generate_paraphrase_pairs(self) -> List[Dict]:
        """
        Generate paraphrase pairs (high similarity).
        Score: 0.85-1.0
        """
        pairs = [
            {
                "sentence1": "How do I install Python on Windows?",
                "sentence2": "What's the process to set up Python on a Windows computer?",
                "score": 0.95
            },
            {
                "sentence1": "What is machine learning?",
                "sentence2": "Can you explain machine learning to me?",
                "score": 0.92
            },
            {
                "sentence1": "How to fix a bug in my code?",
                "sentence2": "What's the best way to debug my program?",
                "score": 0.88
            },
            {
                "sentence1": "Reset password instructions",
                "sentence2": "How do I reset my password?",
                "score": 0.93
            },
            {
                "sentence1": "Database connection error",
                "sentence2": "Can't connect to the database",
                "score": 0.90
            },

            {
                "sentence1": "What is the capital of France?",
                "sentence2": "Tell me the capital city of France",
                "score": 0.96
            },
            {
                "sentence1": "Best restaurants in New York",
                "sentence2": "Where to eat in New York City",
                "score": 0.89
            },
            {
                "sentence1": "Weather forecast for tomorrow",
                "sentence2": "What will the weather be like tomorrow?",
                "score": 0.91
            },
            {
                "sentence1": "How to learn a new language",
                "sentence2": "Tips for learning foreign languages",
                "score": 0.87
            },
            {
                "sentence1": "Symptoms of the flu",
                "sentence2": "What are flu symptoms?",
                "score": 0.94
            },

            {
                "sentence1": "How to cancel my subscription",
                "sentence2": "Steps to unsubscribe from the service",
                "score": 0.90
            },
            {
                "sentence1": "Return policy for products",
                "sentence2": "How do I return an item?",
                "score": 0.86
            },
            {
                "sentence1": "Customer support contact",
                "sentence2": "How to reach customer service",
                "score": 0.92
            },
            {
                "sentence1": "Shipping tracking information",
                "sentence2": "Where is my order?",
                "score": 0.85
            },
            {
                "sentence1": "Update payment method",
                "sentence2": "Change my credit card information",
                "score": 0.88
            },
        ]

        logger.info(f"Generated {len(pairs)} paraphrase pairs")
        return pairs

    def generate_similar_pairs(self) -> List[Dict]:
        """
        Generate semantically similar pairs (medium-high similarity).
        Score: 0.60-0.85
        """
        pairs = [
            {
                "sentence1": "Machine learning algorithms",
                "sentence2": "Neural network architectures",
                "score": 0.75
            },
            {
                "sentence1": "Python programming language",
                "sentence2": "JavaScript coding tutorial",
                "score": 0.68
            },
            {
                "sentence1": "Data science career path",
                "sentence2": "Becoming a data analyst",
                "score": 0.72
            },
            {
                "sentence1": "Cloud computing services",
                "sentence2": "AWS infrastructure guide",
                "score": 0.70
            },
            {
                "sentence1": "Web development frameworks",
                "sentence2": "React and Vue.js comparison",
                "score": 0.74
            },

            {
                "sentence1": "How to lose weight?",
                "sentence2": "Healthy eating habits",
                "score": 0.65
            },
            {
                "sentence1": "Best laptops for programming",
                "sentence2": "Computer hardware for developers",
                "score": 0.71
            },
            {
                "sentence1": "Learning guitar for beginners",
                "sentence2": "Music theory basics",
                "score": 0.62
            },
            {
                "sentence1": "Travel tips for Europe",
                "sentence2": "Budget travel guide",
                "score": 0.67
            },
            {
                "sentence1": "Home workout routines",
                "sentence2": "Fitness exercises without equipment",
                "score": 0.73
            },

            {
                "sentence1": "Project management best practices",
                "sentence2": "Agile methodology guide",
                "score": 0.69
            },
            {
                "sentence1": "Resume writing tips",
                "sentence2": "Job interview preparation",
                "score": 0.64
            },
            {
                "sentence1": "Team collaboration tools",
                "sentence2": "Remote work software solutions",
                "score": 0.72
            },
            {
                "sentence1": "Time management techniques",
                "sentence2": "Productivity improvement strategies",
                "score": 0.76
            },
            {
                "sentence1": "Business communication skills",
                "sentence2": "Professional email etiquette",
                "score": 0.66
            },
        ]

        logger.info(f"Generated {len(pairs)} similar pairs")
        return pairs

    def generate_dissimilar_pairs(self) -> List[Dict]:
        """
        Generate unrelated pairs (low similarity).
        Score: 0.0-0.30
        """
        pairs = [
            {
                "sentence1": "How to bake chocolate cake",
                "sentence2": "Installing Linux operating system",
                "score": 0.05
            },
            {
                "sentence1": "Football match results",
                "sentence2": "Quantum physics equations",
                "score": 0.02
            },
            {
                "sentence1": "Dog training tips",
                "sentence2": "Stock market analysis",
                "score": 0.08
            },
            {
                "sentence1": "Car repair manual",
                "sentence2": "Ancient Roman history",
                "score": 0.04
            },
            {
                "sentence1": "Gardening for beginners",
                "sentence2": "Cryptocurrency trading strategies",
                "score": 0.06
            },

            {
                "sentence1": "Piano lessons online",
                "sentence2": "Chemical engineering degree",
                "score": 0.10
            },
            {
                "sentence1": "Knitting patterns",
                "sentence2": "Cybersecurity threats",
                "score": 0.03
            },
            {
                "sentence1": "Mediterranean diet recipes",
                "sentence2": "Smartphone app development",
                "score": 0.07
            },
            {
                "sentence1": "Yoga poses for flexibility",
                "sentence2": "Legal contract templates",
                "score": 0.05
            },
            {
                "sentence1": "Movie reviews 2024",
                "sentence2": "Database optimization techniques",
                "score": 0.09
            },

            {
                "sentence1": "Wedding planning checklist",
                "sentence2": "Machine learning deployment",
                "score": 0.04
            },
            {
                "sentence1": "Child development stages",
                "sentence2": "Network security protocols",
                "score": 0.06
            },
            {
                "sentence1": "Photography lighting techniques",
                "sentence2": "Tax filing requirements",
                "score": 0.08
            },
            {
                "sentence1": "Fashion trends 2024",
                "sentence2": "Docker container orchestration",
                "score": 0.02
            },
            {
                "sentence1": "Scuba diving certification",
                "sentence2": "Financial portfolio management",
                "score": 0.07
            },
        ]

        logger.info(f"Generated {len(pairs)} dissimilar pairs")
        return pairs

    def generate_question_answer_pairs(self) -> List[Dict]:
        """
        Generate question-answer pairs for retrieval training.
        Score: 0.80-0.95
        """
        pairs = [
            {
                "sentence1": "What is Python?",
                "sentence2": "Python is a high-level programming language known for its simplicity and versatility.",
                "score": 0.88
            },
            {
                "sentence1": "How does HTTP work?",
                "sentence2": "HTTP is a protocol that enables communication between web browsers and servers.",
                "score": 0.85
            },
            {
                "sentence1": "What is artificial intelligence?",
                "sentence2": "AI is the simulation of human intelligence by machines and computer systems.",
                "score": 0.90
            },
            {
                "sentence1": "Define cloud computing",
                "sentence2": "Cloud computing delivers computing services over the internet including storage and processing.",
                "score": 0.87
            },
            {
                "sentence1": "What is a database?",
                "sentence2": "A database is an organized collection of structured information stored electronically.",
                "score": 0.89
            },
            {
                "sentence1": "Explain REST API",
                "sentence2": "REST API is an architectural style for building web services using HTTP requests.",
                "score": 0.84
            },
            {
                "sentence1": "What is version control?",
                "sentence2": "Version control is a system that tracks changes to files over time.",
                "score": 0.86
            },
            {
                "sentence1": "Define responsive design",
                "sentence2": "Responsive design ensures websites work well on all devices and screen sizes.",
                "score": 0.88
            },
            {
                "sentence1": "What is encryption?",
                "sentence2": "Encryption is the process of encoding information to prevent unauthorized access.",
                "score": 0.91
            },
            {
                "sentence1": "Explain agile methodology",
                "sentence2": "Agile is an iterative approach to project management focused on flexibility.",
                "score": 0.83
            },
        ]

        logger.info(f"Generated {len(pairs)} question-answer pairs")
        return pairs

    def generate_domain_specific_pairs(self) -> List[Dict]:
        """
        Generate domain-specific sentence pairs.
        Score: Various
        """
        pairs = [
            {
                "sentence1": "Python list comprehension",
                "sentence2": "Creating lists in Python efficiently",
                "score": 0.86
            },
            {
                "sentence1": "Git merge conflicts",
                "sentence2": "Resolving version control conflicts",
                "score": 0.84
            },
            {
                "sentence1": "React component lifecycle",
                "sentence2": "Understanding React hooks",
                "score": 0.72
            },

            {
                "sentence1": "Blood pressure medication",
                "sentence2": "Treating hypertension",
                "score": 0.78
            },
            {
                "sentence1": "Physical therapy exercises",
                "sentence2": "Rehabilitation program",
                "score": 0.80
            },

            {
                "sentence1": "Investment portfolio diversification",
                "sentence2": "Managing financial risk",
                "score": 0.75
            },
            {
                "sentence1": "Mortgage interest rates",
                "sentence2": "Home loan options",
                "score": 0.82
            },

            {
                "sentence1": "Online course platforms",
                "sentence2": "E-learning systems",
                "score": 0.88
            },
            {
                "sentence1": "Study techniques for exams",
                "sentence2": "Test preparation strategies",
                "score": 0.85
            },

            {
                "sentence1": "Product recommendation system",
                "sentence2": "Personalized shopping suggestions",
                "score": 0.83
            },
            {
                "sentence1": "Shopping cart abandonment",
                "sentence2": "Incomplete purchase behavior",
                "score": 0.86
            },
        ]

        logger.info(f"Generated {len(pairs)} domain-specific pairs")
        return pairs

    def format_for_training(self, pairs: List[Dict]) -> List[Dict]:
        """
        Format sentence pairs for training.

        Args:
            pairs: List of sentence pair dictionaries

        Returns:
            Formatted training examples
        """
        formatted = []

        for pair in pairs:
            formatted.append({
                "sentence1": pair["sentence1"],
                "sentence2": pair["sentence2"],
                "score": pair["score"]
            })

        return formatted

    def create_contrastive_examples(self, pairs: List[Dict]) -> List[Dict]:
        """
        Create contrastive examples (anchor, positive, negative).

        Args:
            pairs: Sentence pairs with scores

        Returns:
            Triplet examples
        """
        contrastive = []

        high_sim = [p for p in pairs if p["score"] >= 0.80]
        low_sim = [p for p in pairs if p["score"] <= 0.30]

        for positive_pair in high_sim[:20]:
            if low_sim:
                negative_pair = random.choice(low_sim)
                contrastive.append({
                    "anchor": positive_pair["sentence1"],
                    "positive": positive_pair["sentence2"],
                    "negative": negative_pair["sentence2"]
                })

        logger.info(f"Created {len(contrastive)} contrastive examples")
        return contrastive

    def save_data(self, data: Union[List[Dict], Dict], filename: str, format: str = "json"):
        """Save training data (or dataset statistics) to a file."""
        filepath = self.output_dir / filename

        if format == "json":
            with open(filepath, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
        elif format == "jsonl":
            with open(filepath, 'w', encoding='utf-8') as f:
                for item in data:
                    f.write(json.dumps(item, ensure_ascii=False) + '\n')
        else:
            raise ValueError(f"Unsupported format: {format}")

        logger.info(f"Saved {len(data)} entries to {filepath}")

    def generate_full_dataset(self, format: str = "json") -> str:
        """
        Generate the complete embeddings training dataset.

        Args:
            format: Output format ('json' or 'jsonl')

        Returns:
            Output directory path
        """
        logger.info("Generating embeddings training dataset...")

        # Collect every pair category into a single pool
        all_pairs = []

        paraphrase_pairs = self.generate_paraphrase_pairs()
        all_pairs.extend(paraphrase_pairs)

        similar_pairs = self.generate_similar_pairs()
        all_pairs.extend(similar_pairs)

        dissimilar_pairs = self.generate_dissimilar_pairs()
        all_pairs.extend(dissimilar_pairs)

        qa_pairs = self.generate_question_answer_pairs()
        all_pairs.extend(qa_pairs)

        domain_pairs = self.generate_domain_specific_pairs()
        all_pairs.extend(domain_pairs)

        # Shuffle, then split 90/10 into train and validation sets
        random.shuffle(all_pairs)

        split_idx = int(len(all_pairs) * 0.9)
        train_pairs = all_pairs[:split_idx]
        val_pairs = all_pairs[split_idx:]

        logger.info(f"Train: {len(train_pairs)} pairs")
        logger.info(f"Validation: {len(val_pairs)} pairs")

        train_data = self.format_for_training(train_pairs)
        val_data = self.format_for_training(val_pairs)

        self.save_data(train_data, f"train_pairs.{format}", format)
        self.save_data(val_data, f"validation_pairs.{format}", format)

        # Contrastive triplets are built from the full (unsplit) pool
        contrastive_data = self.create_contrastive_examples(all_pairs)
        self.save_data(contrastive_data, f"contrastive_triplets.{format}", format)

        # Dataset statistics
        stats = {
            "total_pairs": len(all_pairs),
            "train_size": len(train_pairs),
            "validation_size": len(val_pairs),
            "contrastive_triplets": len(contrastive_data),
            "paraphrase_pairs": len(paraphrase_pairs),
            "similar_pairs": len(similar_pairs),
            "dissimilar_pairs": len(dissimilar_pairs),
            "qa_pairs": len(qa_pairs),
            "domain_pairs": len(domain_pairs),
            "score_distribution": {
                "high (0.8-1.0)": len([p for p in all_pairs if p["score"] >= 0.8]),
                "medium (0.5-0.8)": len([p for p in all_pairs if 0.5 <= p["score"] < 0.8]),
                "low (0.0-0.5)": len([p for p in all_pairs if p["score"] < 0.5])
            },
            "generated_at": datetime.now().isoformat(),
            "format": format
        }

        self.save_data(stats, "embeddings_dataset_stats.json", "json")

        logger.info("="*60)
        logger.info("Embeddings dataset generation complete!")
        logger.info(f"Total pairs: {len(all_pairs)}")
        logger.info(f"Output directory: {self.output_dir}")
        logger.info("="*60)

        return str(self.output_dir)


def main():
    """Main function for data generation."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate training data for Helion-V1-Embeddings"
    )
    parser.add_argument(
        "--output-dir",
        default="./embeddings_training_data",
        help="Output directory for training data"
    )
    parser.add_argument(
        "--format",
        choices=["json", "jsonl"],
        default="json",
        help="Output format"
    )

    args = parser.parse_args()

    generator = EmbeddingsDataGenerator(output_dir=args.output_dir)
    output_path = generator.generate_full_dataset(format=args.format)

    print("\n" + "="*60)
    print("Embeddings Training Data Ready!")
    print("="*60)
    print(f"Location: {output_path}")
    print(f"Format: {args.format}")
    print("\nFiles created:")
    print(f"  • train_pairs.{args.format} - Training sentence pairs")
    print(f"  • validation_pairs.{args.format} - Validation pairs")
    print(f"  • contrastive_triplets.{args.format} - Triplet examples")
    print("  • embeddings_dataset_stats.json - Dataset statistics")
    print("\nTraining data includes:")
    print("  • Paraphrase pairs (high similarity)")
    print("  • Similar concept pairs (medium similarity)")
    print("  • Dissimilar pairs (low similarity)")
    print("  • Question-answer pairs")
    print("  • Domain-specific examples")
    print("\nNext step:")
    print(f"  python train_embeddings.py --data-file {output_path}/train_pairs.{args.format}")
    print("="*60)


if __name__ == "__main__":
    main()
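
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of this module): the generated
# train_pairs file can be consumed by a sentence-transformers training loop
# with CosineSimilarityLoss, since every record carries a similarity score in
# [0, 1]. The base model name and hyperparameters below are placeholder
# assumptions, not values prescribed by this generator.
#
#   import json
#   from torch.utils.data import DataLoader
#   from sentence_transformers import SentenceTransformer, InputExample, losses
#
#   with open("embeddings_training_data/train_pairs.json", encoding="utf-8") as f:
#       pairs = json.load(f)
#   examples = [
#       InputExample(texts=[p["sentence1"], p["sentence2"]], label=float(p["score"]))
#       for p in pairs
#   ]
#   model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder base model
#   loader = DataLoader(examples, shuffle=True, batch_size=16)
#   loss = losses.CosineSimilarityLoss(model)
#   model.fit(train_objectives=[(loader, loss)], epochs=1)
# ---------------------------------------------------------------------------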