From e04b6f8ac7bd53bcf457e1bc12ffb838c707cf2f Mon Sep 17 00:00:00 2001 From: boschmitt <7152025+boschmitt@users.noreply.github.com> Date: Fri, 29 Nov 2024 16:45:51 +0100 Subject: [PATCH] [python] Remove unused MLIR components We don't need to take everything from MLIR for our python bindings. This change cherry picks the upstream components our compiler depends on. The commit also cleans up some unnecessary code that ends up registering dialects more than once, and surfaces the `register_all_dialects` function to a less obscure location. Signed-off-by: boschmitt <7152025+boschmitt@users.noreply.github.com> --- python/cudaq/kernel/ast_bridge.py | 2 - python/cudaq/kernel/kernel_builder.py | 5 +- python/cudaq/mlir/__init__.py | 9 +++ python/extension/CMakeLists.txt | 22 +++++-- python/runtime/mlir/py_register_dialects.cpp | 69 +++++--------------- python/tests/mlir/bare.py | 3 +- 6 files changed, 43 insertions(+), 67 deletions(-) create mode 100644 python/cudaq/mlir/__init__.py diff --git a/python/cudaq/kernel/ast_bridge.py b/python/cudaq/kernel/ast_bridge.py index ba5e664ea85..6162470c7f8 100644 --- a/python/cudaq/kernel/ast_bridge.py +++ b/python/cudaq/kernel/ast_bridge.py @@ -123,8 +123,6 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs): else: self.ctx = Context() register_all_dialects(self.ctx) - quake.register_dialect(self.ctx) - cc.register_dialect(self.ctx) cudaq_runtime.registerLLVMDialectTranslation(self.ctx) self.loc = Location.unknown(context=self.ctx) self.module = Module.create(loc=self.loc) diff --git a/python/cudaq/kernel/kernel_builder.py b/python/cudaq/kernel/kernel_builder.py index 3e83f161470..74be3b12727 100644 --- a/python/cudaq/kernel/kernel_builder.py +++ b/python/cudaq/kernel/kernel_builder.py @@ -37,7 +37,8 @@ # We need static initializers to run in the CAPI `ExecutionEngine`, # so here we run a simple JIT compile at global scope -with Context(): +with Context() as ctx: + register_all_dialects(ctx) module = Module.parse(r""" llvm.func @none() { llvm.return @@ -246,8 +247,6 @@ class PyKernel(object): def __init__(self, argTypeList): self.ctx = Context() register_all_dialects(self.ctx) - quake.register_dialect(self.ctx) - cc.register_dialect(self.ctx) cudaq_runtime.registerLLVMDialectTranslation(self.ctx) self.metadata = {'conditionalOnMeasure': False} diff --git a/python/cudaq/mlir/__init__.py b/python/cudaq/mlir/__init__.py new file mode 100644 index 00000000000..eda2e6614f5 --- /dev/null +++ b/python/cudaq/mlir/__init__.py @@ -0,0 +1,9 @@ +# ============================================================================ # +# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # + +from ._mlir_libs._quakeDialects import register_all_dialects diff --git a/python/extension/CMakeLists.txt b/python/extension/CMakeLists.txt index fe8431828e0..91023373cd3 100644 --- a/python/extension/CMakeLists.txt +++ b/python/extension/CMakeLists.txt @@ -119,10 +119,14 @@ add_mlir_python_common_capi_library(CUDAQuantumPythonCAPI RELATIVE_INSTALL_ROOT "../.." DECLARED_SOURCES CUDAQuantumPythonSources - # TODO: Remove this in favor of showing fine grained registration once - # available. - MLIRPythonExtension.RegisterEverything MLIRPythonSources.Core + MLIRPythonSources.Dialects.arith + MLIRPythonSources.Dialects.builtin + MLIRPythonSources.Dialects.cf + MLIRPythonSources.Dialects.complex + MLIRPythonSources.Dialects.func + MLIRPythonSources.Dialects.math + MLIRPythonSources.ExecutionEngine ) ################################################################################ @@ -134,10 +138,14 @@ add_mlir_python_modules(CUDAQuantumPythonModules INSTALL_PREFIX "cudaq/mlir" DECLARED_SOURCES CUDAQuantumPythonSources - # TODO: Remove this in favor of showing fine grained registration once - # available. - MLIRPythonExtension.RegisterEverything - MLIRPythonSources + MLIRPythonSources.Core + MLIRPythonSources.Dialects.arith + MLIRPythonSources.Dialects.builtin + MLIRPythonSources.Dialects.cf + MLIRPythonSources.Dialects.complex + MLIRPythonSources.Dialects.func + MLIRPythonSources.Dialects.math + MLIRPythonSources.ExecutionEngine COMMON_CAPI_LINK_LIBS CUDAQuantumPythonCAPI ) diff --git a/python/runtime/mlir/py_register_dialects.cpp b/python/runtime/mlir/py_register_dialects.cpp index 157a91d9211..2c831229d6a 100644 --- a/python/runtime/mlir/py_register_dialects.cpp +++ b/python/runtime/mlir/py_register_dialects.cpp @@ -6,19 +6,15 @@ * the terms of the Apache License 2.0 which accompanies this distribution. * ******************************************************************************/ -#include "mlir/Bindings/Python/PybindAdaptors.h" - +#include "py_register_dialects.h" #include "cudaq/Optimizer/Builder/Intrinsics.h" -#include "cudaq/Optimizer/CAPI/Dialects.h" #include "cudaq/Optimizer/CodeGen/Passes.h" -#include "cudaq/Optimizer/CodeGen/Pipelines.h" -#include "cudaq/Optimizer/Dialect/CC/CCDialect.h" -#include "cudaq/Optimizer/Dialect/CC/CCOps.h" #include "cudaq/Optimizer/Dialect/CC/CCTypes.h" -#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h" -#include "cudaq/Optimizer/Transforms/Passes.h" -#include "mlir/InitAllDialects.h" +#include "cudaq/Optimizer/InitAllDialects.h" +#include "cudaq/Optimizer/InitAllPasses.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/CAPI/IR.h" #include #include #include @@ -27,33 +23,9 @@ namespace py = pybind11; using namespace mlir::python::adaptors; using namespace mlir; -namespace cudaq { -static bool registered = false; - -void registerQuakeDialectAndTypes(py::module &m) { +static void registerQuakeTypes(py::module &m) { auto quakeMod = m.def_submodule("quake"); - quakeMod.def( - "register_dialect", - [](MlirContext context, bool load) { - MlirDialectHandle handle = mlirGetDialectHandle__quake__(); - mlirDialectHandleRegisterDialect(handle, context); - if (load) { - mlirDialectHandleLoadDialect(handle, context); - } - - if (!registered) { - cudaq::opt::registerOptCodeGenPasses(); - cudaq::opt::registerOptTransformsPasses(); - cudaq::opt::registerAggressiveEarlyInlining(); - cudaq::opt::registerUnrollingPipeline(); - cudaq::opt::registerTargetPipelines(); - cudaq::opt::registerMappingPipeline(); - registered = true; - } - }, - py::arg("context") = py::none(), py::arg("load") = true); - mlir_type_subclass(quakeMod, "RefType", [](MlirType type) { return unwrap(type).isa(); }).def_classmethod("get", [](py::object cls, MlirContext ctx) { @@ -144,21 +116,10 @@ void registerQuakeDialectAndTypes(py::module &m) { }); } -void registerCCDialectAndTypes(py::module &m) { +static void registerCCTypes(py::module &m) { auto ccMod = m.def_submodule("cc"); - ccMod.def( - "register_dialect", - [](MlirContext context, bool load) { - MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__(); - mlirDialectHandleRegisterDialect(ccHandle, context); - if (load) { - mlirDialectHandleLoadDialect(ccHandle, context); - } - }, - py::arg("context") = py::none(), py::arg("load") = true); - mlir_type_subclass(ccMod, "CharspanType", [](MlirType type) { return unwrap(type).isa(); }).def_classmethod("get", [](py::object cls, MlirContext ctx) { @@ -298,10 +259,7 @@ void registerCCDialectAndTypes(py::module &m) { }); } -void bindRegisterDialects(py::module &mod) { - registerQuakeDialectAndTypes(mod); - registerCCDialectAndTypes(mod); - +void cudaq::bindRegisterDialects(py::module &mod) { mod.def("load_intrinsic", [](MlirModule module, std::string name) { auto unwrapped = unwrap(module); cudaq::IRBuilder builder = IRBuilder::atBlockEnd(unwrapped.getBody()); @@ -311,14 +269,18 @@ void bindRegisterDialects(py::module &mod) { mod.def("register_all_dialects", [](MlirContext context) { DialectRegistry registry; - registry.insert(); + cudaq::registerAllDialects(registry); cudaq::opt::registerCodeGenDialect(registry); - registerAllDialects(registry); - auto *mlirContext = unwrap(context); + MLIRContext *mlirContext = unwrap(context); mlirContext->appendDialectRegistry(registry); mlirContext->loadAllAvailableDialects(); }); + // Register type as passes once, when the module is loaded. + registerQuakeTypes(mod); + registerCCTypes(mod); + cudaq::registerAllPasses(); + mod.def("gen_vector_of_complex_constant", [](MlirLocation loc, MlirModule module, std::string name, @@ -330,4 +292,3 @@ void bindRegisterDialects(py::module &mod) { builder.genVectorOfConstants(unwrap(loc), modOp, name, newValues); }); } -} // namespace cudaq diff --git a/python/tests/mlir/bare.py b/python/tests/mlir/bare.py index 2b007b5a100..e7a2d71049c 100644 --- a/python/tests/mlir/bare.py +++ b/python/tests/mlir/bare.py @@ -8,12 +8,13 @@ # RUN: PYTHONPATH=../../ python3 %s | FileCheck %s +from cudaq.mlir import register_all_dialects from cudaq.mlir.ir import * from cudaq.mlir.dialects import quake from cudaq.mlir.dialects import builtin, func, arith with Context() as ctx: - quake.register_dialect() + register_all_dialects(ctx) m = Module.create(loc=Location.unknown()) with InsertionPoint(m.body), Location.unknown(): f = func.FuncOp('main', ([], []))