diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 549c96cc6..9a2efc029 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -464,6 +464,9 @@ TOSA_PASS_SET = {
     "ArangeStartNegativeStepIntModule_basic",
     "ArangeZeroElementOutputModule_basic",
     "NumToTensorIntModule_basic",
+    "ToDtypeBoolLayoutNoneStaticModule_basic",
+    "ToCopyBoolDTypeStaticModule_basic",
+    "HardTanhIntModule_basic",
 }
 
 LTC_XFAIL_SET = {
diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
index 44f00eea9..3d8e4bffc 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -53,6 +53,9 @@ template <typename T>
 llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
                                      ArrayRef<T> vec, ArrayRef<int64_t> shape);
 
+LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
+                                   Value src, Type destType, Value &result);
+
 // Creates a TOSA operation and performs shape inference on the individual
 // op. This allows shape inference during the framework to TOSA lowering.
 template <typename TosaOp, typename... Args>
diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index d98dc2eca..323846fe3 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -52,6 +52,9 @@ int getTensorRank(Value tensor);
 
 bool isViewLikeOp(Operation *op);
 
+Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc,
+                                        float value, Type dtype);
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index d26cc88f9..a4fc236e1 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3078,6 +3078,110 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+template <>
+LogicalResult ConvertAtenOp<ValsemVariantAtenCopyOp>::matchAndRewrite(
+    ValsemVariantAtenCopyOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  auto srcType = adaptor.src().getType().dyn_cast<TensorType>();
+  if (!selfType || !selfType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types with static shape are supported");
+
+  if (!srcType || !srcType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types with static shape are supported");
+
+  // The non_blocking should be a constant `False`.
+  bool nonBlocking;
+  if (!matchPattern(op.non_blocking(), m_TorchConstantBool(&nonBlocking))) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: non_blocking must be a constant");
+  } else if (nonBlocking) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: non_blocking is expected to be false");
+  }
+
+  SmallVector<int64_t> selfShape(selfType.getShape());
+  SmallVector<int64_t> srcShape(srcType.getShape());
+
+  if (llvm::equal(selfShape, srcShape) || selfShape.size() == 0) {
+    // If we reach here, the given case is handled by TOSA's implicit
+    // broadcasting.
+    Value result;
+    if (failed(tosa::tosaCastTensorToType(
+            rewriter, op, adaptor.src(),
+            getTypeConverter()->convertType(op.getType()), result)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: cast to result type not supported");
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+  return rewriter.notifyMatchFailure(
+      op, "unimplemented: valsem.aten.copy op not supported for this case.");
+}
+
+// Legalizes the torch.aten.to.dtype op
+template <>
+LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
+    AtenToDtypeOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  if (!selfType || !selfType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types with static shape are supported");
+
+  // The non_blocking arg should be a constant `False`.
+  bool nonBlocking;
+  if (!matchPattern(op.non_blocking(), m_TorchConstantBool(&nonBlocking))) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: non_blocking arg must be a constant");
+  } else if (nonBlocking) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: non_blocking arg is expected to be false");
+  }
+
+  // The copy arg should be a constant `False`.
+  bool copy;
+  if (!matchPattern(op.copy(), m_TorchConstantBool(&copy))) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: copy arg must be a constant");
+  } else if (copy) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: copy arg is expected to be false");
+  }
+
+  // Only `none`, `contiguous` and `preserve` memory_format is supported.
+  if (!op.memory_format().getType().isa<Torch::NoneType>()) {
+    int64_t memoryFormat;
+    if (!matchPattern(op.memory_format(), m_TorchConstantInt(&memoryFormat)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: the memory format should be specified in "
+              "an integer constant");
+    if (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
+        memoryFormat != torch_upstream::MemoryFormat::Preserve)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only none, contiguous and preserve "
+              "memory_format is supported");
+  }
+
+  auto resultTy = getTypeConverter()
+                      ->convertType(op.getResult().getType())
+                      .cast<RankedTensorType>();
+
+  Value result;
+  if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.self(), resultTy,
+                                        result)))
+    return rewriter.notifyMatchFailure(op, "conversion to result type failed");
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 template <typename AtenOpT, typename TosaOpT>
 class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
 public:
@@ -3728,6 +3832,8 @@ public:
     INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
     INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
     INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
+    INSERT_ATENOP_PATTERN(ValsemVariantAtenCopyOp);
+    INSERT_ATENOP_PATTERN(AtenToDtypeOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index e0c213837..685a6dd86 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -221,6 +221,64 @@ llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter,
   return const_op.getResult();
 }
 
+static LogicalResult checkValidityOfCast(Type src, Type dest) {
+  if ((src.isInteger(64) && dest.isInteger(32)) ||
+      (src.isInteger(32) && dest.isInteger(64)) ||
+      (src.isInteger(64) && dest.isInteger(1)) ||
+      (src.isInteger(32) && dest.isInteger(1)) ||
+      (src.isInteger(8) && dest.isInteger(1)) ||
+      (src.isF32() && dest.isInteger(1))) {
+    return success();
+  }
+  return failure();
+}
+
+// Casts `src` to `destType`. Casts to i1 are emitted as a compare-against-zero
+// (tosa.equal) followed by tosa.logical_not; all other supported casts use
+// tosa.cast.
+LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
+                                   Value src, Type destType, Value &result) {
+
+  Type srcElemTy = src.getType().dyn_cast<TensorType>().getElementType();
+  Type destElemTy = destType.dyn_cast<TensorType>().getElementType();
+
+  if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
+    return rewriter.notifyMatchFailure(
+        op, "casting to result dtype is invalid or unsupported");
+
+  if (destElemTy.isInteger(1)) {
+    auto srcType = src.getType().dyn_cast<TensorType>();
+    SmallVector<int64_t> srcShape(srcType.getShape());
+    uint64_t num_total_elements = 1;
+    for (int64_t a : srcShape)
+      num_total_elements *= a;
+
+    llvm::Optional<Value> constOp;
+    if (srcElemTy.isInteger(64)) {
+      SmallVector<int64_t> values(num_total_elements, 0);
+      constOp =
+          tosa::getConstTensor<int64_t>(rewriter, op, values, srcShape).value();
+    } else if (srcElemTy.isInteger(32)) {
+      SmallVector<int32_t> values(num_total_elements, 0);
+      constOp =
+          tosa::getConstTensor<int32_t>(rewriter, op, values, srcShape).value();
+    } else if (srcElemTy.isF32()) {
+      SmallVector<float> values(num_total_elements, 0.0);
+      constOp =
+          tosa::getConstTensor<float>(rewriter, op, values, srcShape).value();
+    } else if (srcElemTy.isInteger(8)) {
+      SmallVector<int8_t> values(num_total_elements, 0);
+      constOp =
+          tosa::getConstTensor<int8_t>(rewriter, op, values, srcShape).value();
+    }
+    Value equalToZero = rewriter.create<tosa::EqualOp>(op->getLoc(), destType,
+                                                       src, constOp.value());
+    result = rewriter.create<tosa::LogicalNotOp>(op->getLoc(), destType,
+                                                 equalToZero);
+  } else {
+    result = rewriter.create<tosa::CastOp>(op->getLoc(), destType, src);
+  }
+  return success();
+}
+
 // Template instantiation
 template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter &,
                                                        Operation *,
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 49ac08b2c..4574b2399 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2200,8 +2200,9 @@ public:
   using OpRewritePattern<Aten_ToCopyOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(Aten_ToCopyOp op,
                                 PatternRewriter &rewriter) const override {
-    Value zero = rewriter.create<ConstantFloatOp>(
-        op.getLoc(), rewriter.getF64FloatAttr(0.0));
+    Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
+    Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
+                                                   resultDtype);
     Value emptyTensor = rewriter.create<AtenFullLikeOp>(
         op.getLoc(), op.getType(), op.self(), zero, op.dtype(), op.layout(),
         op.device(), op.pin_memory(), op.memory_format());
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 1ff3b1608..d729f81ae 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -163,3 +163,18 @@ bool Torch::isViewLikeOp(Operation *op) {
                  TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
                  AtenNarrowOp, AtenToDeviceOp>(op);
 }
+
+Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
+                                               Location loc, float value,
+                                               Type dtype) {
+  // Create a constant that satisfies the backend contract for the given dtype.
+  if (dtype.isInteger(64) || dtype.isInteger(32) || dtype.isInteger(8) ||
+      dtype.isInteger(1))
+    return rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr((int64_t)value));
+  if (dtype.isF64() || dtype.isF32() || dtype.isF16() || dtype.isBF16())
+    return rewriter.create<ConstantFloatOp>(loc,
+                                            rewriter.getF64FloatAttr(value));
+  llvm::report_fatal_error(
+      "unhandled type for getConstantWithGivenDtypeAndValue");
+}
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index cdfcbb54f..cf945d0cf 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -2513,6 +2513,25 @@ def ToCopyWithDTypeFalsePinMemoryModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 2, 4))
 
 
+class ToCopyBoolDTypeStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 1, 5, 5], torch.uint8, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten._to_copy(x, dtype=torch.bool)
+
+
+@register_test_case(module_factory=lambda: ToCopyBoolDTypeStaticModule())
+def ToCopyBoolDTypeStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(1, 1, 5, 5).to(dtype=torch.uint8))
+
+
 # ==============================================================================
 
 
diff --git a/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/python/torch_mlir_e2e_test/test_suite/type_conversion.py
index 2df66184e..53f2d2e0a 100644
--- a/python/torch_mlir_e2e_test/test_suite/type_conversion.py
+++ b/python/torch_mlir_e2e_test/test_suite/type_conversion.py
@@ -193,13 +193,13 @@ def ToDtypeLayoutStridedModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))
 
 
-class ToDtypeBoolLayoutNoneModule(torch.nn.Module):
+class ToDtypeBoolLayoutNoneStaticModule(torch.nn.Module):
 
     def __init__(self):
        super().__init__()
 
     @export
-    @annotate_args([None, ([-1, -1], torch.float32, True)])
+    @annotate_args([None, ([3, 5], torch.int64, True)])
     def forward(self, x):
         return torch.ops.aten.to(x,
                                  dtype=torch.bool,
@@ -211,9 +211,9 @@ class ToDtypeBoolLayoutNoneModule(torch.nn.Module):
                                  memory_format=None)
 
 
-@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneModule())
-def ToDtypeBoolLayoutNoneModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3, 5))
+@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneStaticModule())
+def ToDtypeBoolLayoutNoneStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 5))
 
 
 class TypeAsSameModule(torch.nn.Module):
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 70fb95f36..48061c15c 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -859,3 +859,57 @@ func.func @torch.prim.NumToTensor.Scalar() -> !torch.vtensor<[],si64> {
   %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
   return %0 : !torch.vtensor<[],si64>
 }
+
+// -----
+// CHECK-LABEL: func.func @torch.valsem.aten.copy(
+// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> {
+// CHECK: %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,1,5,5],ui8> -> tensor<1x1x5x5xi8>
+// CHECK: %[[CST5:.*]] = torch.constant.int 5
+// CHECK: %[[CST1:.*]] = torch.constant.int 1
+// CHECK: %[[CST11:.*]] = torch.constant.int 11
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[CST0:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
+// CHECK: %[[VAL_1:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
+// CHECK: %[[VAL_2:.*]] = "tosa.equal"(%[[VAL_0]], %[[VAL_1]]) : (tensor<i64>, tensor<i64>) -> tensor<i1>
+// CHECK: %[[VAL_3:.*]] = "tosa.logical_not"(%[[VAL_2]]) : (tensor<i1>) -> tensor<i1>
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<0> : tensor<1x1x5x5xi8>} : () -> tensor<1x1x5x5xi8>
+// CHECK: %[[VAL_5:.*]] = "tosa.equal"(%[[INP]], %[[VAL_4]]) : (tensor<1x1x5x5xi8>, tensor<1x1x5x5xi8>) -> tensor<1x1x5x5xi1>
+// CHECK: %[[VAL_6:.*]] = "tosa.logical_not"(%[[VAL_5]]) : (tensor<1x1x5x5xi1>) -> tensor<1x1x5x5xi1>
+// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x1x5x5xi1> -> !torch.vtensor<[1,1,5,5],i1>
+// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,1,5,5],i1>
+func.func @torch.valsem.aten.copy(%arg0: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> {
+  %int5 = torch.constant.int 5
+  %int1 = torch.constant.int 1
+  %int11 = torch.constant.int 11
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %int0 = torch.constant.int 0
+  %0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
+  %1 = torch.aten.to.dtype %0, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
+  %2 = torch.prim.ListConstruct %int1, %int1, %int5, %int5 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.broadcast_to %1, %2 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,5,5],i1>
+  %4 = torch.valsem.aten.copy %3, %arg0, %false : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,1,5,5],ui8>, !torch.bool -> !torch.vtensor<[1,1,5,5],i1>
+  return %4 : !torch.vtensor<[1,1,5,5],i1>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.to.dtype(
+// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> {
+// CHECK: %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[3,5],si64> -> tensor<3x5xi64>
+// CHECK: %[[CST11:.*]] = torch.constant.int 11
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<0> : tensor<3x5xi64>} : () -> tensor<3x5xi64>
+// CHECK: %[[VAL_1:.*]] = "tosa.equal"(%[[INP]], %[[VAL_0]]) : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1>
+// CHECK: %[[VAL_2:.*]] = "tosa.logical_not"(%[[VAL_1]]) : (tensor<3x5xi1>) -> tensor<3x5xi1>
+// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1>
+// CHECK: return %[[VAL_3]] : !torch.vtensor<[3,5],i1>
+func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> {
+  %int11 = torch.constant.int 11
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[3,5],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],i1>
+  return %0 : !torch.vtensor<[3,5],i1>
+}
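
Illustrative addendum (not part of the patch): a hypothetical FileCheck-style case sketching the f32 -> i1 path of the new tosaCastTensorToType helper, which the checks above do not exercise. It assumes the same const/equal/logical_not expansion as the integer cases; the function name, shapes, and exact CHECK formatting are placeholders rather than verified output.

// -----
// CHECK-LABEL: func.func @torch.aten.to.dtype$float(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],i1> {
// CHECK: %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
// CHECK: %[[ZERO:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<3x4xf32>} : () -> tensor<3x4xf32>
// CHECK: %[[EQ:.*]] = "tosa.equal"(%[[INP]], %[[ZERO]]) : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xi1>
// CHECK: %[[NOT:.*]] = "tosa.logical_not"(%[[EQ]]) : (tensor<3x4xi1>) -> tensor<3x4xi1>
// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[NOT]] : tensor<3x4xi1> -> !torch.vtensor<[3,4],i1>
// CHECK: return %[[RES]] : !torch.vtensor<[3,4],i1>
func.func @torch.aten.to.dtype$float(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],i1> {
  %int11 = torch.constant.int 11
  %none = torch.constant.none
  %false = torch.constant.bool false
  %0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],i1>
  return %0 : !torch.vtensor<[3,4],i1>
}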