Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] Remove unused MLIR components #2443

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs):
else:
self.ctx = Context()
register_all_dialects(self.ctx)
quake.register_dialect(self.ctx)
cc.register_dialect(self.ctx)
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)
self.loc = Location.unknown(context=self.ctx)
self.module = Module.create(loc=self.loc)
Expand Down
5 changes: 2 additions & 3 deletions python/cudaq/kernel/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@

# We need static initializers to run in the CAPI `ExecutionEngine`,
# so here we run a simple JIT compile at global scope
with Context():
with Context() as ctx:
register_all_dialects(ctx)
module = Module.parse(r"""
llvm.func @none() {
llvm.return
Expand Down Expand Up @@ -246,8 +247,6 @@ class PyKernel(object):
def __init__(self, argTypeList):
self.ctx = Context()
register_all_dialects(self.ctx)
quake.register_dialect(self.ctx)
cc.register_dialect(self.ctx)
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)

self.metadata = {'conditionalOnMeasure': False}
Expand Down
9 changes: 9 additions & 0 deletions python/cudaq/mlir/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# ============================================================================ #
# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #

from ._mlir_libs._quakeDialects import register_all_dialects
22 changes: 15 additions & 7 deletions python/extension/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,14 @@ add_mlir_python_common_capi_library(CUDAQuantumPythonCAPI
RELATIVE_INSTALL_ROOT "../.."
DECLARED_SOURCES
CUDAQuantumPythonSources
# TODO: Remove this in favor of showing fine grained registration once
# available.
MLIRPythonExtension.RegisterEverything
MLIRPythonSources.Core
MLIRPythonSources.Dialects.arith
MLIRPythonSources.Dialects.builtin
MLIRPythonSources.Dialects.cf
MLIRPythonSources.Dialects.complex
MLIRPythonSources.Dialects.func
MLIRPythonSources.Dialects.math
MLIRPythonSources.ExecutionEngine
)

################################################################################
Expand All @@ -134,10 +138,14 @@ add_mlir_python_modules(CUDAQuantumPythonModules
INSTALL_PREFIX "cudaq/mlir"
DECLARED_SOURCES
CUDAQuantumPythonSources
# TODO: Remove this in favor of showing fine grained registration once
# available.
MLIRPythonExtension.RegisterEverything
MLIRPythonSources
MLIRPythonSources.Core
MLIRPythonSources.Dialects.arith
MLIRPythonSources.Dialects.builtin
MLIRPythonSources.Dialects.cf
MLIRPythonSources.Dialects.complex
MLIRPythonSources.Dialects.func
MLIRPythonSources.Dialects.math
MLIRPythonSources.ExecutionEngine
COMMON_CAPI_LINK_LIBS
CUDAQuantumPythonCAPI
)
Expand Down
63 changes: 18 additions & 45 deletions python/runtime/mlir/py_register_dialects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,15 @@
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#include "py_register_dialects.h"
#include "cudaq/Optimizer/Builder/Intrinsics.h"
#include "cudaq/Optimizer/CAPI/Dialects.h"
#include "cudaq/Optimizer/CodeGen/Passes.h"
#include "cudaq/Optimizer/CodeGen/Pipelines.h"
#include "cudaq/Optimizer/Dialect/CC/CCDialect.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/CC/CCTypes.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
#include "cudaq/Optimizer/InitAllDialects.h"
#include "cudaq/Optimizer/InitAllPasses.h"
#include "cudaq/Optimizer/Transforms/Passes.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/InitAllDialects.h"
#include "mlir/CAPI/IR.h"
#include <fmt/core.h>
#include <pybind11/complex.h>
#include <pybind11/stl.h>
Expand All @@ -27,27 +23,9 @@ namespace py = pybind11;
using namespace mlir::python::adaptors;
using namespace mlir;

namespace cudaq {
static bool registered = false;

void registerQuakeDialectAndTypes(py::module &m) {
static void registerQuakeTypes(py::module &m) {
auto quakeMod = m.def_submodule("quake");

quakeMod.def(
"register_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle handle = mlirGetDialectHandle__quake__();
mlirDialectHandleRegisterDialect(handle, context);
if (load)
mlirDialectHandleLoadDialect(handle, context);

if (!registered) {
cudaq::registerCudaqPassesAndPipelines();
registered = true;
}
},
py::arg("context") = py::none(), py::arg("load") = true);

mlir_type_subclass(quakeMod, "RefType", [](MlirType type) {
return unwrap(type).isa<quake::RefType>();
}).def_classmethod("get", [](py::object cls, MlirContext ctx) {
Expand Down Expand Up @@ -138,21 +116,10 @@ void registerQuakeDialectAndTypes(py::module &m) {
});
}

void registerCCDialectAndTypes(py::module &m) {
static void registerCCTypes(py::module &m) {

auto ccMod = m.def_submodule("cc");

ccMod.def(
"register_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__();
mlirDialectHandleRegisterDialect(ccHandle, context);
if (load) {
mlirDialectHandleLoadDialect(ccHandle, context);
}
},
py::arg("context") = py::none(), py::arg("load") = true);

mlir_type_subclass(ccMod, "CharspanType", [](MlirType type) {
return unwrap(type).isa<cudaq::cc::CharspanType>();
}).def_classmethod("get", [](py::object cls, MlirContext ctx) {
Expand Down Expand Up @@ -292,10 +259,9 @@ void registerCCDialectAndTypes(py::module &m) {
});
}

void bindRegisterDialects(py::module &mod) {
registerQuakeDialectAndTypes(mod);
registerCCDialectAndTypes(mod);
static bool registered = false;

void cudaq::bindRegisterDialects(py::module &mod) {
mod.def("load_intrinsic", [](MlirModule module, std::string name) {
auto unwrapped = unwrap(module);
cudaq::IRBuilder builder = IRBuilder::atBlockEnd(unwrapped.getBody());
Expand All @@ -305,14 +271,22 @@ void bindRegisterDialects(py::module &mod) {

mod.def("register_all_dialects", [](MlirContext context) {
DialectRegistry registry;
registry.insert<quake::QuakeDialect, cudaq::cc::CCDialect>();
cudaq::registerAllDialects(registry);
cudaq::opt::registerCodeGenDialect(registry);
registerAllDialects(registry);
auto *mlirContext = unwrap(context);
MLIRContext *mlirContext = unwrap(context);
mlirContext->appendDialectRegistry(registry);
mlirContext->loadAllAvailableDialects();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does loadAllAvailableDialects actually load Quake, CC, and CodeGen?

Do we need an explicit test to verify that all the dialects we expect to be loaded are loaded?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

loadAllAvailableDialects() will load all dialects registered in the context. So, Quake and CC will be loaded. The CodeGen dialect is not registered in cudaq::registerAllDialects. I'm not sure why, but its being registered in other places and in different ways:

registry.insert<cudaq::codegen::CodeGenDialect>();

cudaq::opt::registerCodeGenDialect(registry);

I can make it uniform in a follow-up PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The codegen dialect was intended to be sort of second-class in that it would only be applicable at codegen time. The point being we don't want folks to start using it in ad hoc ways in random passes, etc. Adding a few landmines was sort of the idea.

OTOH, we could clean it up a bit and remove some of those hurdles. We still don't want that dialect used "in the wild" though...

});

// Register type and passes once, when the module is loaded.
registerQuakeTypes(mod);
registerCCTypes(mod);

if (!registered) {
cudaq::registerAllPasses();
registered = true;
}

mod.def("gen_vector_of_complex_constant", [](MlirLocation loc,
MlirModule module,
std::string name,
Expand All @@ -324,4 +298,3 @@ void bindRegisterDialects(py::module &mod) {
builder.genVectorOfConstants(unwrap(loc), modOp, name, newValues);
});
}
} // namespace cudaq
3 changes: 2 additions & 1 deletion python/tests/mlir/bare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@

# RUN: PYTHONPATH=../../ python3 %s | FileCheck %s

from cudaq.mlir import register_all_dialects
from cudaq.mlir.ir import *
from cudaq.mlir.dialects import quake
from cudaq.mlir.dialects import builtin, func, arith

with Context() as ctx:
quake.register_dialect()
register_all_dialects(ctx)
m = Module.create(loc=Location.unknown())
with InsertionPoint(m.body), Location.unknown():
f = func.FuncOp('main', ([], []))
Expand Down
Loading