Ignore constants in the legality error (#2328)

pull/2330/head
Matthias Gehre 2023-07-25 10:11:40 +02:00 committed by GitHub
parent 31ef08b63d
commit c56cb531d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 0 deletions

View File

@ -19,6 +19,7 @@
#include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@ -4775,6 +4776,22 @@ public:
typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter); TorchConversion::setupBackendTypeConversion(target, typeConverter);
// The following ops are never the primary reason why lowering fails.
// The backend contract only allows functions to return tensors thus there
// is always another op using them.
// When we have a chain of torch.constant.int followed by a unsupported
// torch op, we want the pass to mention the unsupported torch op
// in the error message.
target.addLegalOp<ConstantNoneOp>();
target.addLegalOp<ConstantBoolOp>();
target.addLegalOp<ConstantIntOp>();
target.addLegalOp<ConstantFloatOp>();
target.addLegalOp<ConstantStrOp>();
target.addLegalOp<ConstantDeviceOp>();
target.addLegalOp<PrimListConstructOp>();
target.addLegalOp<PrimTupleConstructOp>();
target.addIllegalDialect<Torch::TorchDialect>();
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \ #define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \

View File

@ -124,3 +124,11 @@ func.func @torch.aten.pow.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],f16>) ->
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// -----
func.func @torch.prim.TupleConstruct() {
%int128 = torch.constant.int 128
%0 = torch.prim.TupleConstruct %int128 : !torch.int -> !torch.tuple<int>
// expected-error @below {{failed to legalize operation 'torch.prim.Print' that was explicitly marked illegal}}
torch.prim.Print(%0) : !torch.tuple<int>
return
}