//===- VerifyInvariantsBeforeBackendLowering.cpp -----------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::TorchConversion;

static LogicalResult checkValueInvariants(Operation *errorReportOp, Value v) {
  // TODO: Make this an allowlist instead of a denylist.
  // TODO: Make this stricter.
  auto type = v.getType();
  if (auto valueTensorType = type.dyn_cast<Torch::ValueTensorType>()) {
    if (!valueTensorType.hasDtype() || !valueTensorType.hasSizes())
      return errorReportOp->emitError()
          .append("unsupported by backend lowering: tensor with unknown rank "
                  "or dtype")
          .attachNote()
          .append("this is likely due to a missing shape transfer function in "
                  "shape_lib_gen.py or missing case in RefineTypes");
  }
  return success();
}

namespace {
class VerifyInvariantsBeforeBackendLoweringPass
    : public VerifyInvariantsBeforeBackendLoweringBase<
          VerifyInvariantsBeforeBackendLoweringPass> {
  void runOnOperation() override {
    auto walkResult = getOperation().walk([&](Block *block) {
      // Check invariants on all the Value's in the program.
      // That is, check all BlockArgument's and OpResult's.
      for (BlockArgument arg : block->getArguments())
        if (failed(checkValueInvariants(block->getParentOp(), arg)))
          return WalkResult::interrupt();

      for (Operation &op : *block) {
        if (isa<Torch::OperatorOp>(op)) {
          op.emitError()
              .append("unsupported by backend lowering: `torch.operator` op")
              .attachNote()
              .append("this is likely due to a missing op that needs to be "
                      "generated by torch_ods_gen.py");
          return WalkResult::interrupt();
        }
        for (OpResult result : op.getResults())
          if (failed(checkValueInvariants(&op, result)))
            return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    if (walkResult.wasInterrupted())
      return signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass() {
  return std::make_unique<VerifyInvariantsBeforeBackendLoweringPass>();
}
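
// A minimal illustration of what this pass rejects (a sketch, not taken from
// the file above; the exact `func` op spelling depends on the MLIR version in
// use): a value of type `!torch.vtensor`, i.e. a value tensor with no known
// sizes or dtype, trips `checkValueInvariants`, whereas a fully refined type
// such as `!torch.vtensor<[2,3],f32>` passes the check.
//
//   func.func @f(%arg0: !torch.vtensor) -> !torch.vtensor {
//     return %arg0 : !torch.vtensor
//   }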