Skip to content

Commit d4f37fb

Browse files
committed
refactor: improve tomlir trait
1 parent 74b581c commit d4f37fb

File tree

5 files changed

+275
-262
lines changed

5 files changed

+275
-262
lines changed

descend_derive/build.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::path::Path;
44
fn main() {
55
// Tell Cargo to re-run this build script if any .desc files change
66
let examples_dir = "examples/core";
7-
7+
88
if Path::new(examples_dir).exists() {
99
// Walk through the directory and tell Cargo to re-run if any .desc files change
1010
if let Ok(entries) = fs::read_dir(examples_dir) {
@@ -17,7 +17,7 @@ fn main() {
1717
}
1818
}
1919
}
20-
20+
2121
// Also watch the entire directory for new files
2222
println!("cargo:rerun-if-changed={}", examples_dir);
2323
}

src/codegen/mlir/mod.rs

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,13 @@ use builder::MlirBuilder;
3636
use error::MlirError;
3737
use melior::{
3838
dialect::DialectRegistry,
39-
ir::{Location, Module, operation::OperationLike},
39+
ir::{operation::OperationLike, Location, Module},
4040
utility::register_all_dialects,
4141
Context,
4242
};
4343

4444
use crate::ast::CompilUnit;
45+
use crate::ast::{DataTyKind, Memory, TyKind};
4546

4647
/// Internal helper function to build MLIR module
4748
fn build_module_internal(comp_unit: &CompilUnit) -> Result<String, MlirError> {
@@ -92,13 +93,10 @@ pub fn gen_checked(comp_unit: &CompilUnit, _idx_checks: bool) -> Result<String,
9293
fn needs_hivm_address_space(comp_unit: &CompilUnit) -> bool {
9394
for item in &comp_unit.items {
9495
if let crate::ast::Item::FunDef(fun) = item {
95-
// Only check the main function or functions that are not HIVM placeholders
96-
if fun.ident.name == "main".into() || !is_hivm_placeholder_function(fun) {
97-
for param in &fun.param_decls {
98-
if let Some(ty) = &param.ty {
99-
if has_gpu_memory(ty) {
100-
return true;
101-
}
96+
for param in &fun.param_decls {
97+
if let Some(ty) = &param.ty {
98+
if has_gpu_memory(ty) {
99+
return true;
102100
}
103101
}
104102
}
@@ -107,31 +105,19 @@ fn needs_hivm_address_space(comp_unit: &CompilUnit) -> bool {
107105
false
108106
}
109107

110-
/// Check if a function is a HIVM placeholder function
111-
fn is_hivm_placeholder_function(fun: &crate::ast::FunDef) -> bool {
112-
fun.ident.name.starts_with("hivm_")
113-
}
114-
115108
/// Check if a type has GPU memory qualifiers
116109
fn has_gpu_memory(ty: &crate::ast::Ty) -> bool {
110+
fn mem_is_gpu(mem: &Memory) -> bool {
111+
matches!(
112+
mem,
113+
Memory::GpuGlobal | Memory::GpuShared | Memory::GpuLocal
114+
)
115+
}
116+
117117
match &ty.ty {
118-
crate::ast::TyKind::Data(data_ty) => match &data_ty.dty {
119-
crate::ast::DataTyKind::At(_, mem) => {
120-
matches!(
121-
mem,
122-
crate::ast::Memory::GpuGlobal
123-
| crate::ast::Memory::GpuShared
124-
| crate::ast::Memory::GpuLocal
125-
)
126-
}
127-
crate::ast::DataTyKind::Ref(ref_dty) => {
128-
matches!(
129-
ref_dty.mem,
130-
crate::ast::Memory::GpuGlobal
131-
| crate::ast::Memory::GpuShared
132-
| crate::ast::Memory::GpuLocal
133-
)
134-
}
118+
TyKind::Data(data_ty) => match &data_ty.dty {
119+
DataTyKind::At(_, mem) => mem_is_gpu(mem),
120+
DataTyKind::Ref(ref_dty) => mem_is_gpu(&ref_dty.mem),
135121
_ => false,
136122
},
137123
_ => false,

0 commit comments

Comments
 (0)