diff --git a/python/cudaq/kernel/ast_bridge.py b/python/cudaq/kernel/ast_bridge.py index ba5e664ea85..6162470c7f8 100644 --- a/python/cudaq/kernel/ast_bridge.py +++ b/python/cudaq/kernel/ast_bridge.py @@ -123,8 +123,6 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs): else: self.ctx = Context() register_all_dialects(self.ctx) - quake.register_dialect(self.ctx) - cc.register_dialect(self.ctx) cudaq_runtime.registerLLVMDialectTranslation(self.ctx) self.loc = Location.unknown(context=self.ctx) self.module = Module.create(loc=self.loc) diff --git a/python/cudaq/kernel/kernel_builder.py b/python/cudaq/kernel/kernel_builder.py index 3e83f161470..74be3b12727 100644 --- a/python/cudaq/kernel/kernel_builder.py +++ b/python/cudaq/kernel/kernel_builder.py @@ -37,7 +37,8 @@ # We need static initializers to run in the CAPI `ExecutionEngine`, # so here we run a simple JIT compile at global scope -with Context(): +with Context() as ctx: + register_all_dialects(ctx) module = Module.parse(r""" llvm.func @none() { llvm.return @@ -246,8 +247,6 @@ class PyKernel(object): def __init__(self, argTypeList): self.ctx = Context() register_all_dialects(self.ctx) - quake.register_dialect(self.ctx) - cc.register_dialect(self.ctx) cudaq_runtime.registerLLVMDialectTranslation(self.ctx) self.metadata = {'conditionalOnMeasure': False} diff --git a/python/cudaq/mlir/__init__.py b/python/cudaq/mlir/__init__.py new file mode 100644 index 00000000000..eda2e6614f5 --- /dev/null +++ b/python/cudaq/mlir/__init__.py @@ -0,0 +1,9 @@ +# ============================================================================ # +# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # + +from ._mlir_libs._quakeDialects import register_all_dialects diff --git a/python/extension/CMakeLists.txt b/python/extension/CMakeLists.txt index fe8431828e0..91023373cd3 100644 --- a/python/extension/CMakeLists.txt +++ b/python/extension/CMakeLists.txt @@ -119,10 +119,14 @@ add_mlir_python_common_capi_library(CUDAQuantumPythonCAPI RELATIVE_INSTALL_ROOT "../.." DECLARED_SOURCES CUDAQuantumPythonSources - # TODO: Remove this in favor of showing fine grained registration once - # available. - MLIRPythonExtension.RegisterEverything MLIRPythonSources.Core + MLIRPythonSources.Dialects.arith + MLIRPythonSources.Dialects.builtin + MLIRPythonSources.Dialects.cf + MLIRPythonSources.Dialects.complex + MLIRPythonSources.Dialects.func + MLIRPythonSources.Dialects.math + MLIRPythonSources.ExecutionEngine ) ################################################################################ @@ -134,10 +138,14 @@ add_mlir_python_modules(CUDAQuantumPythonModules INSTALL_PREFIX "cudaq/mlir" DECLARED_SOURCES CUDAQuantumPythonSources - # TODO: Remove this in favor of showing fine grained registration once - # available. - MLIRPythonExtension.RegisterEverything - MLIRPythonSources + MLIRPythonSources.Core + MLIRPythonSources.Dialects.arith + MLIRPythonSources.Dialects.builtin + MLIRPythonSources.Dialects.cf + MLIRPythonSources.Dialects.complex + MLIRPythonSources.Dialects.func + MLIRPythonSources.Dialects.math + MLIRPythonSources.ExecutionEngine COMMON_CAPI_LINK_LIBS CUDAQuantumPythonCAPI ) diff --git a/python/runtime/mlir/py_register_dialects.cpp b/python/runtime/mlir/py_register_dialects.cpp index 157a91d9211..2c831229d6a 100644 --- a/python/runtime/mlir/py_register_dialects.cpp +++ b/python/runtime/mlir/py_register_dialects.cpp @@ -6,19 +6,15 @@ * the terms of the Apache License 2.0 which accompanies this distribution. * ******************************************************************************/ -#include "mlir/Bindings/Python/PybindAdaptors.h" - +#include "py_register_dialects.h" #include "cudaq/Optimizer/Builder/Intrinsics.h" -#include "cudaq/Optimizer/CAPI/Dialects.h" #include "cudaq/Optimizer/CodeGen/Passes.h" -#include "cudaq/Optimizer/CodeGen/Pipelines.h" -#include "cudaq/Optimizer/Dialect/CC/CCDialect.h" -#include "cudaq/Optimizer/Dialect/CC/CCOps.h" #include "cudaq/Optimizer/Dialect/CC/CCTypes.h" -#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h" -#include "cudaq/Optimizer/Transforms/Passes.h" -#include "mlir/InitAllDialects.h" +#include "cudaq/Optimizer/InitAllDialects.h" +#include "cudaq/Optimizer/InitAllPasses.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/CAPI/IR.h" #include #include #include @@ -27,33 +23,9 @@ namespace py = pybind11; using namespace mlir::python::adaptors; using namespace mlir; -namespace cudaq { -static bool registered = false; - -void registerQuakeDialectAndTypes(py::module &m) { +static void registerQuakeTypes(py::module &m) { auto quakeMod = m.def_submodule("quake"); - quakeMod.def( - "register_dialect", - [](MlirContext context, bool load) { - MlirDialectHandle handle = mlirGetDialectHandle__quake__(); - mlirDialectHandleRegisterDialect(handle, context); - if (load) { - mlirDialectHandleLoadDialect(handle, context); - } - - if (!registered) { - cudaq::opt::registerOptCodeGenPasses(); - cudaq::opt::registerOptTransformsPasses(); - cudaq::opt::registerAggressiveEarlyInlining(); - cudaq::opt::registerUnrollingPipeline(); - cudaq::opt::registerTargetPipelines(); - cudaq::opt::registerMappingPipeline(); - registered = true; - } - }, - py::arg("context") = py::none(), py::arg("load") = true); - mlir_type_subclass(quakeMod, "RefType", [](MlirType type) { return unwrap(type).isa(); }).def_classmethod("get", [](py::object cls, MlirContext ctx) { @@ -144,21 +116,10 @@ void registerQuakeDialectAndTypes(py::module &m) { }); } -void registerCCDialectAndTypes(py::module &m) { +static void registerCCTypes(py::module &m) { auto ccMod = m.def_submodule("cc"); - ccMod.def( - "register_dialect", - [](MlirContext context, bool load) { - MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__(); - mlirDialectHandleRegisterDialect(ccHandle, context); - if (load) { - mlirDialectHandleLoadDialect(ccHandle, context); - } - }, - py::arg("context") = py::none(), py::arg("load") = true); - mlir_type_subclass(ccMod, "CharspanType", [](MlirType type) { return unwrap(type).isa(); }).def_classmethod("get", [](py::object cls, MlirContext ctx) { @@ -298,10 +259,7 @@ void registerCCDialectAndTypes(py::module &m) { }); } -void bindRegisterDialects(py::module &mod) { - registerQuakeDialectAndTypes(mod); - registerCCDialectAndTypes(mod); - +void cudaq::bindRegisterDialects(py::module &mod) { mod.def("load_intrinsic", [](MlirModule module, std::string name) { auto unwrapped = unwrap(module); cudaq::IRBuilder builder = IRBuilder::atBlockEnd(unwrapped.getBody()); @@ -311,14 +269,18 @@ void bindRegisterDialects(py::module &mod) { mod.def("register_all_dialects", [](MlirContext context) { DialectRegistry registry; - registry.insert(); + cudaq::registerAllDialects(registry); cudaq::opt::registerCodeGenDialect(registry); - registerAllDialects(registry); - auto *mlirContext = unwrap(context); + MLIRContext *mlirContext = unwrap(context); mlirContext->appendDialectRegistry(registry); mlirContext->loadAllAvailableDialects(); }); + // Register type as passes once, when the module is loaded. + registerQuakeTypes(mod); + registerCCTypes(mod); + cudaq::registerAllPasses(); + mod.def("gen_vector_of_complex_constant", [](MlirLocation loc, MlirModule module, std::string name, @@ -330,4 +292,3 @@ void bindRegisterDialects(py::module &mod) { builder.genVectorOfConstants(unwrap(loc), modOp, name, newValues); }); } -} // namespace cudaq diff --git a/python/tests/mlir/bare.py b/python/tests/mlir/bare.py index 2b007b5a100..e7a2d71049c 100644 --- a/python/tests/mlir/bare.py +++ b/python/tests/mlir/bare.py @@ -8,12 +8,13 @@ # RUN: PYTHONPATH=../../ python3 %s | FileCheck %s +from cudaq.mlir import register_all_dialects from cudaq.mlir.ir import * from cudaq.mlir.dialects import quake from cudaq.mlir.dialects import builtin, func, arith with Context() as ctx: - quake.register_dialect() + register_all_dialects(ctx) m = Module.create(loc=Location.unknown()) with InsertionPoint(m.body), Location.unknown(): f = func.FuncOp('main', ([], []))