-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathprofiling_tools.rs
74 lines (65 loc) · 2.46 KB
/
profiling_tools.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use std::time::Instant;
use tch::{nn, Device, Tensor};
// Struct to store profiling information
struct ProfilingInfo {
name: String,
duration: std::time::Duration,
}
// Function to generate audio samples using the pre-trained model
fn generate_audio_samples(model: &tch::CModule, num_samples: usize) -> Vec<f32> {
// Generate random noise as input to the model
let noise = Tensor::rand(&[1, 100], kind::FLOAT_CPU);
// Use the loaded model to generate audio samples
let output = model
.forward_ts(&[noise])
.expect("Model forward pass failed") // Possible Error: Forward pass might fail
.squeeze()
.to(Device::Cpu)
.data()
.as_slice::<f32>()
.expect("Failed to convert tensor data to slice") // Possible Error: Conversion might fail
.to_vec();
// Truncate or pad the output to the desired number of samples
if output.len() >= num_samples {
output[..num_samples].to_vec()
} else {
let mut samples = output;
samples.resize(num_samples, 0.0); // Pad with zeros if not enough samples
samples
}
}
// Python binding function to generate audio samples with profiling
#[pyfunction]
fn generate_audio_with_profiling(
model_path: &str,
num_samples: usize,
) -> PyResult<(Vec<f32>, Vec<ProfilingInfo>)> {
let mut profiling_info = Vec::new();
// Load the pre-trained model
let start_time = Instant::now();
let model = tch::CModule::load(model_path).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Failed to load model: {}", e))
// Possible Error: Model loading might fail due to incorrect path or file format
// Solution: Ensure the model path is correct and the file is a valid PyTorch model
})?;
profiling_info.push(ProfilingInfo {
name: "Model Loading".to_string(),
duration: start_time.elapsed(),
});
// Generate audio samples
let start_time = Instant::now();
let samples = generate_audio_samples(&model, num_samples);
profiling_info.push(ProfilingInfo {
name: "Audio Generation".to_string(),
duration: start_time.elapsed(),
});
Ok((samples, profiling_info))
}
// Python module to expose the audio generation function with profiling
#[pymodule]
fn profiling_tools(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(generate_audio_with_profiling, m)?)?;
Ok(())
}