From 405f884522e769d4979e4ddb262dd48708d59629 Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Thu, 16 May 2024 11:03:43 +0800 Subject: [PATCH] [stablehlo] verify stablehlo backend contract (#3338) --- .../Transforms/VerifyStablehloBackendContract.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp index 0c8cdf2fc..3ff6e4732 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp @@ -11,10 +11,9 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/OpDefinition.h" #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" @@ -47,13 +46,21 @@ class VerifyStablehloBackendContractPass // Structural operations. target.addDynamicallyLegalOp( opHasLegalTypes); - // Shape operations. - target.addDynamicallyLegalOp(opHasLegalTypes); target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + auto moduleOp = getOperation(); + RewritePatternSet patterns(context); + if (failed(applyFullConversion(moduleOp, target, std::move(patterns)))) { + emitError(moduleOp.getLoc()) + << "Module does not conform to the Stablehlo backend contract. " + "See dialect conversion legality information above."; + return signalPassFailure(); + } } }; } // namespace