From a156bdc8e6156641b9a29699c461fed0b0866cc6 Mon Sep 17 00:00:00 2001 From: boschmitt <7152025+boschmitt@users.noreply.github.com> Date: Fri, 29 Nov 2024 16:45:51 +0100 Subject: [PATCH] [python] Remove unused MLIR components We don't need to take everything from MLIR for our python bindings. This change cherry picks the upstream components our compiler depends on. The commit also cleans up some unnecessary code that ends up registering dialects more than once, and surfaces the `register_all_dialects` function to a less obscure location. Signed-off-by: boschmitt <7152025+boschmitt@users.noreply.github.com> --- python/cudaq/kernel/ast_bridge.py | 2 - python/cudaq/kernel/kernel_builder.py | 5 +- python/cudaq/mlir/__init__.py | 9 +++ python/extension/CMakeLists.txt | 22 ++++-- python/runtime/mlir/py_register_dialects.cpp | 78 +++++++------------- python/tests/mlir/bare.py | 3 +- 6 files changed, 54 insertions(+), 65 deletions(-) create mode 100644 python/cudaq/mlir/__init__.py diff --git a/python/cudaq/kernel/ast_bridge.py b/python/cudaq/kernel/ast_bridge.py index 1e062e42a7..aa8f3d4555 100644 --- a/python/cudaq/kernel/ast_bridge.py +++ b/python/cudaq/kernel/ast_bridge.py @@ -123,8 +123,6 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs): else: self.ctx = Context() register_all_dialects(self.ctx) - quake.register_dialect(self.ctx) - cc.register_dialect(self.ctx) cudaq_runtime.registerLLVMDialectTranslation(self.ctx) self.loc = Location.unknown(context=self.ctx) self.module = Module.create(loc=self.loc) diff --git a/python/cudaq/kernel/kernel_builder.py b/python/cudaq/kernel/kernel_builder.py index 3e83f16147..74be3b1272 100644 --- a/python/cudaq/kernel/kernel_builder.py +++ b/python/cudaq/kernel/kernel_builder.py @@ -37,7 +37,8 @@ # We need static initializers to run in the CAPI `ExecutionEngine`, # so here we run a simple JIT compile at global scope -with Context(): +with Context() as ctx: + register_all_dialects(ctx) module = Module.parse(r""" llvm.func @none() { llvm.return @@ -246,8 +247,6 @@ class PyKernel(object): def __init__(self, argTypeList): self.ctx = Context() register_all_dialects(self.ctx) - quake.register_dialect(self.ctx) - cc.register_dialect(self.ctx) cudaq_runtime.registerLLVMDialectTranslation(self.ctx) self.metadata = {'conditionalOnMeasure': False} diff --git a/python/cudaq/mlir/__init__.py b/python/cudaq/mlir/__init__.py new file mode 100644 index 0000000000..eda2e6614f --- /dev/null +++ b/python/cudaq/mlir/__init__.py @@ -0,0 +1,9 @@ +# ============================================================================ # +# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # + +from ._mlir_libs._quakeDialects import register_all_dialects diff --git a/python/extension/CMakeLists.txt b/python/extension/CMakeLists.txt index fe8431828e..91023373cd 100644 --- a/python/extension/CMakeLists.txt +++ b/python/extension/CMakeLists.txt @@ -119,10 +119,14 @@ add_mlir_python_common_capi_library(CUDAQuantumPythonCAPI RELATIVE_INSTALL_ROOT "../.." DECLARED_SOURCES CUDAQuantumPythonSources - # TODO: Remove this in favor of showing fine grained registration once - # available. - MLIRPythonExtension.RegisterEverything MLIRPythonSources.Core + MLIRPythonSources.Dialects.arith + MLIRPythonSources.Dialects.builtin + MLIRPythonSources.Dialects.cf + MLIRPythonSources.Dialects.complex + MLIRPythonSources.Dialects.func + MLIRPythonSources.Dialects.math + MLIRPythonSources.ExecutionEngine ) ################################################################################ @@ -134,10 +138,14 @@ add_mlir_python_modules(CUDAQuantumPythonModules INSTALL_PREFIX "cudaq/mlir" DECLARED_SOURCES CUDAQuantumPythonSources - # TODO: Remove this in favor of showing fine grained registration once - # available. - MLIRPythonExtension.RegisterEverything - MLIRPythonSources + MLIRPythonSources.Core + MLIRPythonSources.Dialects.arith + MLIRPythonSources.Dialects.builtin + MLIRPythonSources.Dialects.cf + MLIRPythonSources.Dialects.complex + MLIRPythonSources.Dialects.func + MLIRPythonSources.Dialects.math + MLIRPythonSources.ExecutionEngine COMMON_CAPI_LINK_LIBS CUDAQuantumPythonCAPI ) diff --git a/python/runtime/mlir/py_register_dialects.cpp b/python/runtime/mlir/py_register_dialects.cpp index 157a91d921..241c531525 100644 --- a/python/runtime/mlir/py_register_dialects.cpp +++ b/python/runtime/mlir/py_register_dialects.cpp @@ -6,19 +6,16 @@ * the terms of the Apache License 2.0 which accompanies this distribution. * ******************************************************************************/ -#include "mlir/Bindings/Python/PybindAdaptors.h" - +#include "py_register_dialects.h" #include "cudaq/Optimizer/Builder/Intrinsics.h" -#include "cudaq/Optimizer/CAPI/Dialects.h" #include "cudaq/Optimizer/CodeGen/Passes.h" -#include "cudaq/Optimizer/CodeGen/Pipelines.h" -#include "cudaq/Optimizer/Dialect/CC/CCDialect.h" -#include "cudaq/Optimizer/Dialect/CC/CCOps.h" #include "cudaq/Optimizer/Dialect/CC/CCTypes.h" -#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h" +#include "cudaq/Optimizer/InitAllDialects.h" #include "cudaq/Optimizer/Transforms/Passes.h" -#include "mlir/InitAllDialects.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/CAPI/IR.h" +#include "mlir/Transforms/Passes.h" #include #include #include @@ -27,33 +24,9 @@ namespace py = pybind11; using namespace mlir::python::adaptors; using namespace mlir; -namespace cudaq { -static bool registered = false; - -void registerQuakeDialectAndTypes(py::module &m) { +static void registerQuakeTypes(py::module &m) { auto quakeMod = m.def_submodule("quake"); - quakeMod.def( - "register_dialect", - [](MlirContext context, bool load) { - MlirDialectHandle handle = mlirGetDialectHandle__quake__(); - mlirDialectHandleRegisterDialect(handle, context); - if (load) { - mlirDialectHandleLoadDialect(handle, context); - } - - if (!registered) { - cudaq::opt::registerOptCodeGenPasses(); - cudaq::opt::registerOptTransformsPasses(); - cudaq::opt::registerAggressiveEarlyInlining(); - cudaq::opt::registerUnrollingPipeline(); - cudaq::opt::registerTargetPipelines(); - cudaq::opt::registerMappingPipeline(); - registered = true; - } - }, - py::arg("context") = py::none(), py::arg("load") = true); - mlir_type_subclass(quakeMod, "RefType", [](MlirType type) { return unwrap(type).isa(); }).def_classmethod("get", [](py::object cls, MlirContext ctx) { @@ -144,21 +117,10 @@ void registerQuakeDialectAndTypes(py::module &m) { }); } -void registerCCDialectAndTypes(py::module &m) { +static void registerCCTypes(py::module &m) { auto ccMod = m.def_submodule("cc"); - ccMod.def( - "register_dialect", - [](MlirContext context, bool load) { - MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__(); - mlirDialectHandleRegisterDialect(ccHandle, context); - if (load) { - mlirDialectHandleLoadDialect(ccHandle, context); - } - }, - py::arg("context") = py::none(), py::arg("load") = true); - mlir_type_subclass(ccMod, "CharspanType", [](MlirType type) { return unwrap(type).isa(); }).def_classmethod("get", [](py::object cls, MlirContext ctx) { @@ -298,10 +260,9 @@ void registerCCDialectAndTypes(py::module &m) { }); } -void bindRegisterDialects(py::module &mod) { - registerQuakeDialectAndTypes(mod); - registerCCDialectAndTypes(mod); +static bool registered = false; +void cudaq::bindRegisterDialects(py::module &mod) { mod.def("load_intrinsic", [](MlirModule module, std::string name) { auto unwrapped = unwrap(module); cudaq::IRBuilder builder = IRBuilder::atBlockEnd(unwrapped.getBody()); @@ -311,14 +272,28 @@ void bindRegisterDialects(py::module &mod) { mod.def("register_all_dialects", [](MlirContext context) { DialectRegistry registry; - registry.insert(); + cudaq::registerAllDialects(registry); cudaq::opt::registerCodeGenDialect(registry); - registerAllDialects(registry); - auto *mlirContext = unwrap(context); + MLIRContext *mlirContext = unwrap(context); mlirContext->appendDialectRegistry(registry); mlirContext->loadAllAvailableDialects(); }); + // Register type and passes once, when the module is loaded. + registerQuakeTypes(mod); + registerCCTypes(mod); + + if (!registered) { + mlir::registerTransformsPasses(); + cudaq::opt::registerOptCodeGenPasses(); + cudaq::opt::registerOptTransformsPasses(); + cudaq::opt::registerAggressiveEarlyInlining(); + cudaq::opt::registerUnrollingPipeline(); + cudaq::opt::registerTargetPipelines(); + cudaq::opt::registerMappingPipeline(); + registered = true; + } + mod.def("gen_vector_of_complex_constant", [](MlirLocation loc, MlirModule module, std::string name, @@ -330,4 +305,3 @@ void bindRegisterDialects(py::module &mod) { builder.genVectorOfConstants(unwrap(loc), modOp, name, newValues); }); } -} // namespace cudaq diff --git a/python/tests/mlir/bare.py b/python/tests/mlir/bare.py index 2b007b5a10..e7a2d71049 100644 --- a/python/tests/mlir/bare.py +++ b/python/tests/mlir/bare.py @@ -8,12 +8,13 @@ # RUN: PYTHONPATH=../../ python3 %s | FileCheck %s +from cudaq.mlir import register_all_dialects from cudaq.mlir.ir import * from cudaq.mlir.dialects import quake from cudaq.mlir.dialects import builtin, func, arith with Context() as ctx: - quake.register_dialect() + register_all_dialects(ctx) m = Module.create(loc=Location.unknown()) with InsertionPoint(m.body), Location.unknown(): f = func.FuncOp('main', ([], []))