jruffle committed on
Commit
9d5fcb3
·
verified ·
1 Parent(s): 4c93b7d

Upload run_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run_inference.py +113 -0
run_inference.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Inference script for transcriptome autoencoder model
4
+ Generated automatically during training
5
+ """
6
+
7
+ import torch
8
+ import pandas as pd
9
+ import numpy as np
10
+ import json
11
+ import argparse
12
+ import os
13
+
14
def load_model_and_config(model_dir):
    """Rebuild the saved autoencoder and return it with its configuration.

    Reads ``model_config.json`` from ``model_dir``, instantiates the model
    class selected by ``config['model_info']['model_type']``, loads the saved
    state dict onto the CPU and puts the model in eval mode.

    Args:
        model_dir: Directory holding ``model_config.json`` and the weights
            file named by ``config['model_info']['saved_model_file']``.

    Returns:
        A ``(model, config)`` tuple: the model in eval mode and the parsed
        JSON configuration dict.
    """
    with open(os.path.join(model_dir, 'model_config.json')) as handle:
        config = json.load(handle)

    info = config['model_info']
    weights_path = os.path.join(model_dir, info['saved_model_file'])

    # Imported here rather than at module level, as in the generated script.
    from compress_data_unified import SimpleAE, AE

    # Read every architecture field up front so a missing key fails early.
    latent_dims = info['latent_dims']
    input_dim = info['input_dim']
    layer_sizes = info['layer_sizes']
    model_type = info['model_type']

    if model_type != 'SimpleAE':
        # Standard AE
        model = AE(layer_sizes, use_simple=False)
    elif isinstance(layer_sizes, list) and len(layer_sizes) > 1:
        # SimpleAE wrapped in the AE class
        model = AE(layer_sizes, use_simple=True)
    else:
        # Direct SimpleAE
        model = SimpleAE(input_dim, latent_dims)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state)
    model.eval()

    return model, config
47
+
48
def preprocess_data(data, config):
    """Min-max scale ``data`` into roughly the [-1, 1] range.

    The minimum and maximum are taken over the whole input, ignoring NaNs.
    When the value range is smaller than a small epsilon the input is returned
    unchanged instead of being scaled.

    Args:
        data: Array-like numeric data (an ndarray or DataFrame).
        config: Model configuration; accepted for interface parity but not
            read by this function.

    Returns:
        The scaled data, or ``data`` itself when its range is below epsilon.
    """
    eps = 1e-8
    lowest, highest = np.nanmin(data), np.nanmax(data)
    span = highest - lowest
    if span < eps:
        # Effectively constant input: scaling would divide by ~0.
        return data
    # eps in the denominator keeps the division well-defined.
    return 2 * (data - lowest) / (span + eps) - 1
58
+
59
def run_inference(model_dir, input_data_path, output_path=None):
    """Encode and reconstruct the rows of a CSV file with the saved model.

    Loads the model from ``model_dir``, reads ``input_data_path`` as a CSV
    whose first column is the index, preprocesses it, and runs the model's
    ``encode`` and ``decode`` under ``torch.no_grad()``. Both results are
    written as CSV files into ``output_path``.

    Args:
        model_dir: Directory containing the trained model and its config.
        input_data_path: Path to the input CSV file.
        output_path: Directory for the result files; defaults to
            ``'inference_results'`` when ``None``. Created if missing.

    Returns:
        A ``(latent_df, reconstructed_df)`` tuple of DataFrames indexed like
        the input data.
    """
    model, config = load_model_and_config(model_dir)

    # Load, scale and convert the input to a float tensor.
    frame = pd.read_csv(input_data_path, index_col=0)
    scaled = preprocess_data(frame, config)
    batch = torch.FloatTensor(scaled.values)

    # Encode to latent space, then decode back to the input space.
    with torch.no_grad():
        latent = model.encode(batch)
        reconstructed = model.decode(latent)

    latent_df = pd.DataFrame(
        latent.numpy(),
        index=frame.index,
        columns=[f'latent_{k + 1}' for k in range(config['model_info']['latent_dims'])],
    )
    reconstructed_df = pd.DataFrame(
        reconstructed.numpy(),
        index=frame.index,
        columns=frame.columns,
    )

    if output_path is None:
        output_path = 'inference_results'
    os.makedirs(output_path, exist_ok=True)

    latent_file = os.path.join(output_path, 'latent_representation.csv')
    reconstructed_file = os.path.join(output_path, 'reconstructed_data.csv')
    latent_df.to_csv(latent_file)
    reconstructed_df.to_csv(reconstructed_file)

    print("Inference completed:")
    print(f" Latent representation saved: {latent_file}")
    print(f" Reconstructed data saved: {reconstructed_file}")

    return latent_df, reconstructed_df
99
+
100
if __name__ == "__main__":
    # Command-line entry point: parse the three options, run inference and
    # report the shapes of the two resulting DataFrames.
    cli = argparse.ArgumentParser(
        description='Run inference with trained autoencoder',
    )
    cli.add_argument(
        '--model_dir',
        type=str,
        required=True,
        help='Directory containing trained model and config',
    )
    cli.add_argument(
        '--input_data',
        type=str,
        required=True,
        help='Path to input data CSV file',
    )
    cli.add_argument(
        '--output_dir',
        type=str,
        default='inference_results',
        help='Output directory for results',
    )

    opts = cli.parse_args()

    latent, reconstructed = run_inference(
        opts.model_dir,
        opts.input_data,
        opts.output_dir,
    )
    print(f"Latent dimensions: {latent.shape}")
    print(f"Reconstructed dimensions: {reconstructed.shape}")