//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifdef TORCH_MLIR_ENABLE_MHLO
#include "PassDetail.h"

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/ChloOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::TorchConversion;

namespace {

/// Verifies that a module satisfies the MHLO backend contract.
///
/// The check is phrased as a dialect-conversion legality query. A full
/// conversion is run with an empty pattern set, so it succeeds only if every
/// operation in the module is already legal under the target configured below.
///
/// Legal operations:
///   * any op from the mhlo, chlo, tensor and arith dialects, regardless of
///     its types;
///   * builtin.module, func.func, func.return and shape.shape_of, provided
///     all of their operand and result types are legal.
///
/// Type legality: a tensor type is legal when its element type is a valid
/// memref element type; any other type is legal when it is itself a valid
/// memref element type.
class VerifyMhloBackendContractPass
    : public VerifyMhloBackendContractBase<VerifyMhloBackendContractPass> {
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    auto module = getOperation();

    // The converter serves only as a legality oracle. Returning `type`
    // unchanged marks it legal; returning nullptr marks it illegal.
    TypeConverter converter;
    converter.addConversion([](Type type) -> Type {
      Type elemTy = type;
      // For tensors, legality is decided by the element type alone.
      if (auto tensorTy = dyn_cast<TensorType>(type))
        elemTy = tensorTy.getElementType();
      if (BaseMemRefType::isValidElementType(elemTy))
        return type;
      return nullptr;
    });

    // NOTE(review): TypeConverter::isLegal(Operation *) inspects only operand
    // and result types. func.func has neither, so its signature is presumably
    // not checked here. Argument types are caught only indirectly, through the
    // ops that use them. Confirm whether signature types should be verified
    // explicitly.
    auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };

    ConversionTarget target(*context);

    // Structural operations.
    target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
        opHasLegalTypes);
    // Shape operations.
    target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);

    // Every op from these dialects is legal, independent of its types.
    target.addLegalDialect<mhlo::MhloDialect>();
    target.addLegalDialect<chlo::ChloDialect>();
    target.addLegalDialect<tensor::TensorDialect>();
    target.addLegalDialect<arith::ArithmeticDialect>();

    // With no rewrite patterns, the full conversion fails iff some op is
    // illegal, which makes it a pure verification step.
    RewritePatternSet patterns(context);
    if (failed(applyFullConversion(module, target, std::move(patterns)))) {
      // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
      // doesn't unnecessarily spew out the entire module.
      emitError(module.getLoc())
          << "Module does not conform to the MHLO backend contract. "
             "See dialect conversion legality information above.";
      return signalPassFailure();
    }
  }
};

} // namespace

/// Creates the pass that checks a module against the MHLO backend contract.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::TorchConversion::createVerifyMhloBackendContractPass() {
  return std::make_unique<VerifyMhloBackendContractPass>();
}
#endif // TORCH_MLIR_ENABLE_MHLO