-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdeployment_and_packaging.rs
103 lines (83 loc) · 3.93 KB
/
deployment_and_packaging.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use std::env;
use std::fs::File;
use std::io::Write;
use std::path::Path;
use tch::{nn, Device, Tensor};
use rand::Rng;
// Function to load a pre-trained PyTorch model
fn load_model(model_path: &str) -> Result<nn::Sequential, Box<dyn std::error::Error>> {
// Determine the device to use (CUDA if available, otherwise CPU)
let device = Device::cuda_if_available();
// Load the pre-trained model from the specified path
let model = nn::Sequential::load(model_path)?;
// Set the model to evaluation mode
model.set_eval();
// Move the model to the selected device
model.to_device(device);
// Return the loaded model
Ok(model)
}
// Function to generate audio samples using a pre-trained PyTorch model
fn generate_audio_samples(model: &nn::Sequential, num_samples: usize) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
// Generate random input data for the model
let mut rng = rand::thread_rng();
let input_data: Vec<f32> = (0..num_samples).map(|_| rng.gen()).collect();
// Convert the input data to a PyTorch tensor
let input_tensor = Tensor::of_slice(&input_data).view([-1, 1, num_samples as i64]);
// Use the pre-trained model to generate audio samples
let output_tensor = model.forward(&input_tensor).squeeze();
// Extract the output data as a vector of f32
let output_data = output_tensor.data().as_slice::<f32>()?.to_vec();
// Return the generated audio samples
Ok(output_data)
}
// Function to package the application and its dependencies
fn package_application(output_dir: &str) -> Result<(), Box<dyn std::error::Error>> {
// Create the output directory if it doesn't exist
if !Path::new(output_dir).exists() {
std::fs::create_dir(output_dir)?;
}
// Copy the pre-trained PyTorch model to the output directory
let model_path = "path/to/pretrained/model.pt";
let model_file_name = Path::new(model_path).file_name().unwrap().to_str().unwrap();
let output_model_path = format!("{}/{}", output_dir, model_file_name);
std::fs::copy(model_path, output_model_path)?;
// Create a requirements file with the necessary dependencies
let requirements_path = format!("{}/requirements.txt", output_dir);
let mut requirements_file = File::create(requirements_path)?;
requirements_file.write_all(b"torch\ntqdm\nnumpy\n")?;
// Copy the Rust executable to the output directory
let executable_path = env::current_exe()?;
let executable_file_name = executable_path.file_name().unwrap().to_str().unwrap();
let output_executable_path = format!("{}/{}", output_dir, executable_file_name);
std::fs::copy(executable_path, output_executable_path)?;
// Return Ok if all operations succeed
Ok(())
}
// Python module to expose the audio generator functionality
#[pymodule]
fn deployment_and_packaging(_py: Python, m: &PyModule) -> PyResult<()> {
// Define a Python function to generate audio samples
#[pyfn(m)]
#[pyo3(name = "generate_audio")]
fn generate_audio_py(model_path: &str, num_samples: usize) -> PyResult<Vec<f32>> {
// Load the pre-trained PyTorch model
let model = load_model(model_path)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Failed to load model: {}", e)))?;
// Generate audio samples using the pre-trained model
let samples = generate_audio_samples(&model, num_samples)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("Failed to generate audio: {}", e)))?;
Ok(samples)
}
// Define a Python function to package the application
#[pyfn(m)]
#[pyo3(name = "package_application")]
fn package_application_py(output_dir: &str) -> PyResult<()> {
package_application(output_dir)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Failed to package application: {}", e)))?;
Ok(())
}
Ok(())
}