Skip to content

Commit

Permalink
[python] Remove unused MLIR components
Browse files Browse the repository at this point in the history
We don't need to take everything from MLIR for our python bindings. This
change cherry picks the upstream components our compiler depends on.

The commit also cleans up some unnecessary code that ends up registering
dialects more than once, and surfaces the `register_all_dialects`
function to a less obscure location.

Signed-off-by: boschmitt <[email protected]>
  • Loading branch information
boschmitt committed Jan 31, 2025
1 parent d893046 commit e62b6cf
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 66 deletions.
2 changes: 0 additions & 2 deletions python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs):
else:
self.ctx = Context()
register_all_dialects(self.ctx)
quake.register_dialect(self.ctx)
cc.register_dialect(self.ctx)
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)
self.loc = Location.unknown(context=self.ctx)
self.module = Module.create(loc=self.loc)
Expand Down
5 changes: 2 additions & 3 deletions python/cudaq/kernel/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@

# We need static initializers to run in the CAPI `ExecutionEngine`,
# so here we run a simple JIT compile at global scope
with Context():
with Context() as ctx:
register_all_dialects(ctx)
module = Module.parse(r"""
llvm.func @none() {
llvm.return
Expand Down Expand Up @@ -246,8 +247,6 @@ class PyKernel(object):
def __init__(self, argTypeList):
self.ctx = Context()
register_all_dialects(self.ctx)
quake.register_dialect(self.ctx)
cc.register_dialect(self.ctx)
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)

self.metadata = {'conditionalOnMeasure': False}
Expand Down
9 changes: 9 additions & 0 deletions python/cudaq/mlir/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# ============================================================================ #
# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #

from ._mlir_libs._quakeDialects import register_all_dialects
22 changes: 15 additions & 7 deletions python/extension/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,14 @@ add_mlir_python_common_capi_library(CUDAQuantumPythonCAPI
RELATIVE_INSTALL_ROOT "../.."
DECLARED_SOURCES
CUDAQuantumPythonSources
# TODO: Remove this in favor of showing fine grained registration once
# available.
MLIRPythonExtension.RegisterEverything
MLIRPythonSources.Core
MLIRPythonSources.Dialects.arith
MLIRPythonSources.Dialects.builtin
MLIRPythonSources.Dialects.cf
MLIRPythonSources.Dialects.complex
MLIRPythonSources.Dialects.func
MLIRPythonSources.Dialects.math
MLIRPythonSources.ExecutionEngine
)

################################################################################
Expand All @@ -134,10 +138,14 @@ add_mlir_python_modules(CUDAQuantumPythonModules
INSTALL_PREFIX "cudaq/mlir"
DECLARED_SOURCES
CUDAQuantumPythonSources
# TODO: Remove this in favor of showing fine grained registration once
# available.
MLIRPythonExtension.RegisterEverything
MLIRPythonSources
MLIRPythonSources.Core
MLIRPythonSources.Dialects.arith
MLIRPythonSources.Dialects.builtin
MLIRPythonSources.Dialects.cf
MLIRPythonSources.Dialects.complex
MLIRPythonSources.Dialects.func
MLIRPythonSources.Dialects.math
MLIRPythonSources.ExecutionEngine
COMMON_CAPI_LINK_LIBS
CUDAQuantumPythonCAPI
)
Expand Down
75 changes: 22 additions & 53 deletions python/runtime/mlir/py_register_dialects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,16 @@
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#include "mlir/Bindings/Python/PybindAdaptors.h"

#include "py_register_dialects.h"
#include "cudaq/Optimizer/Builder/Intrinsics.h"
#include "cudaq/Optimizer/CAPI/Dialects.h"
#include "cudaq/Optimizer/CodeGen/Passes.h"
#include "cudaq/Optimizer/CodeGen/Pipelines.h"
#include "cudaq/Optimizer/Dialect/CC/CCDialect.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/CC/CCTypes.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
#include "cudaq/Optimizer/InitAllDialects.h"
#include "cudaq/Optimizer/Transforms/Passes.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/CAPI/IR.h"
#include "mlir/Transforms/Passes.h"
#include <fmt/core.h>
#include <pybind11/complex.h>
#include <pybind11/stl.h>
Expand All @@ -27,33 +24,9 @@ namespace py = pybind11;
using namespace mlir::python::adaptors;
using namespace mlir;

namespace cudaq {
static bool registered = false;

void registerQuakeDialectAndTypes(py::module &m) {
static void registerQuakeTypes(py::module &m) {
auto quakeMod = m.def_submodule("quake");

quakeMod.def(
"register_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle handle = mlirGetDialectHandle__quake__();
mlirDialectHandleRegisterDialect(handle, context);
if (load) {
mlirDialectHandleLoadDialect(handle, context);
}

if (!registered) {
cudaq::opt::registerOptCodeGenPasses();
cudaq::opt::registerOptTransformsPasses();
cudaq::opt::registerAggressiveEarlyInlining();
cudaq::opt::registerUnrollingPipeline();
cudaq::opt::registerTargetPipelines();
cudaq::opt::registerMappingPipeline();
registered = true;
}
},
py::arg("context") = py::none(), py::arg("load") = true);

mlir_type_subclass(quakeMod, "RefType", [](MlirType type) {
return unwrap(type).isa<quake::RefType>();
}).def_classmethod("get", [](py::object cls, MlirContext ctx) {
Expand Down Expand Up @@ -144,21 +117,10 @@ void registerQuakeDialectAndTypes(py::module &m) {
});
}

void registerCCDialectAndTypes(py::module &m) {
static void registerCCTypes(py::module &m) {

auto ccMod = m.def_submodule("cc");

ccMod.def(
"register_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__();
mlirDialectHandleRegisterDialect(ccHandle, context);
if (load) {
mlirDialectHandleLoadDialect(ccHandle, context);
}
},
py::arg("context") = py::none(), py::arg("load") = true);

mlir_type_subclass(ccMod, "CharspanType", [](MlirType type) {
return unwrap(type).isa<cudaq::cc::CharspanType>();
}).def_classmethod("get", [](py::object cls, MlirContext ctx) {
Expand Down Expand Up @@ -298,10 +260,7 @@ void registerCCDialectAndTypes(py::module &m) {
});
}

void bindRegisterDialects(py::module &mod) {
registerQuakeDialectAndTypes(mod);
registerCCDialectAndTypes(mod);

void cudaq::bindRegisterDialects(py::module &mod) {
mod.def("load_intrinsic", [](MlirModule module, std::string name) {
auto unwrapped = unwrap(module);
cudaq::IRBuilder builder = IRBuilder::atBlockEnd(unwrapped.getBody());
Expand All @@ -311,14 +270,25 @@ void bindRegisterDialects(py::module &mod) {

mod.def("register_all_dialects", [](MlirContext context) {
DialectRegistry registry;
registry.insert<quake::QuakeDialect, cudaq::cc::CCDialect>();
cudaq::registerAllDialects(registry);
cudaq::opt::registerCodeGenDialect(registry);
registerAllDialects(registry);
auto *mlirContext = unwrap(context);
MLIRContext *mlirContext = unwrap(context);
mlirContext->appendDialectRegistry(registry);
mlirContext->loadAllAvailableDialects();
});

// Register type and passes once, when the module is loaded.
registerQuakeTypes(mod);
registerCCTypes(mod);

mlir::registerTransformsPasses();
cudaq::opt::registerOptCodeGenPasses();
cudaq::opt::registerOptTransformsPasses();
cudaq::opt::registerAggressiveEarlyInlining();
cudaq::opt::registerUnrollingPipeline();
cudaq::opt::registerTargetPipelines();
cudaq::opt::registerMappingPipeline();

mod.def("gen_vector_of_complex_constant", [](MlirLocation loc,
MlirModule module,
std::string name,
Expand All @@ -330,4 +300,3 @@ void bindRegisterDialects(py::module &mod) {
builder.genVectorOfConstants(unwrap(loc), modOp, name, newValues);
});
}
} // namespace cudaq
3 changes: 2 additions & 1 deletion python/tests/mlir/bare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@

# RUN: PYTHONPATH=../../ python3 %s | FileCheck %s

from cudaq.mlir import register_all_dialects
from cudaq.mlir.ir import *
from cudaq.mlir.dialects import quake
from cudaq.mlir.dialects import builtin, func, arith

with Context() as ctx:
quake.register_dialect()
register_all_dialects(ctx)
m = Module.create(loc=Location.unknown())
with InsertionPoint(m.body), Location.unknown():
f = func.FuncOp('main', ([], []))
Expand Down

0 comments on commit e62b6cf

Please sign in to comment.