diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py index 21a463246..eeb66f1d2 100644 --- a/e2e_testing/torchscript/basic.py +++ b/e2e_testing/torchscript/basic.py @@ -1491,3 +1491,80 @@ class ExpandAsIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ExpandAsIntModule()) def ExpandAsIntModule_basic(module, tu: TestUtils): module.forward(torch.randint(100, (1, 1, 1)), torch.randint(200, (4, 5, 6))) + +# ============================================================================== + +class CopyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x, y): + return torch.ops.aten.copy_(x, y) + + +@register_test_case(module_factory=lambda: CopyModule()) +def CopyModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 2, 4), tu.rand(3, 2, 4)) + + +class CopyWithDifferentSizesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, 4], torch.float32, True), + ([-1, -1, 1], torch.float32, True), + ]) + def forward(self, x, y): + return torch.ops.aten.copy_(x, y) + + +@register_test_case(module_factory=lambda: CopyWithDifferentSizesModule()) +def CopyWithDifferentSizesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 2, 4), tu.rand(3, 2, 1)) + + +class CopyWithDifferentDTypesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.int64, True), + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x, y): + return torch.ops.aten.copy_(x, y) + + +@register_test_case(module_factory=lambda: CopyWithDifferentDTypesModule()) +def CopyWithDifferentDTypesModule_basic(module, tu: TestUtils): + module.forward(torch.randint(100, (3, 2, 4)), tu.rand(3, 2, 4)) + + +class CopyWithDifferentDTypesAndSizesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, 4], torch.float32, True), + ([-1, -1, 1], torch.int64, True), + ]) + def forward(self, x, y): + return torch.ops.aten.copy_(x, y) + + +@register_test_case(module_factory=lambda: CopyWithDifferentDTypesAndSizesModule()) +def CopyWithDifferentDTypesAndSizesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 2, 4), torch.randint(1000, (3, 2, 1))) diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index 0dd8c85df..e18dbc77d 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -1023,7 +1023,7 @@ def Torch_ValsemVariantAtenIndexPutImplOp: Torch_Op<"valsem.aten.index_put_impl" HasValueSemantics, ReadOnly ]> { -let summary = "`index_put_impl op : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`"; + let summary = "`index_put_impl op : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchOptionalTensorListType:$indices, @@ -1031,12 +1031,31 @@ let summary = "`index_put_impl op : (Tensor, Tensor?[], Tensor, bool, bool) -> ( Torch_BoolType:$accumulate, Torch_BoolType:$unsafe ); -let results = (outs + let results = (outs AnyTorchTensorType:$result ); let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate `,` $unsafe attr-dict `:` qualified(type($self)) `,` qualified(type($indices)) `,` qualified(type($values)) `,` qualified(type($accumulate)) `,` qualified(type($unsafe)) `->` qualified(type($result))"; } +// The corresponding without underscore variant for `torch.aten.copy_` +// doesn't exist in the pytorch ops registry. Add it here. +def Torch_ValsemVariantAtenCopyOp : Torch_Op<"valsem.aten.copy", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "copy op : (Tensor, Tensor, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$src, + Torch_BoolType:$non_blocking + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $src `,` $non_blocking attr-dict `:` qualified(type($self)) `,` qualified(type($src)) `,` qualified(type($non_blocking)) `->` qualified(type($result))"; +} + // To handle runtime assertions, torchscript provides us `torch._assert` operation. // But TS compiler introduces control flow for `torch._assert` operation. The // `torch._assert` would introduce control flow like: diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 10550486b..469d82532 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -878,6 +878,90 @@ public: }; } // namespace +// Broadcasts input tensor based on the broadcastToShape. +static LogicalResult broadcastToGivenShape(Operation *op, + ConversionPatternRewriter &rewriter, + Value input, + SmallVector broadcastToShape, + Value &result) { + RankedTensorType inputType = input.getType().cast(); + ArrayRef inputShape = inputType.getShape(); + if (broadcastToShape.size() < inputShape.size()) { + return rewriter.notifyMatchFailure( + op, "invalid shape: broadcastToShape size must not be smaller than the " + "size of the input shape"); + } + + Type elementType = inputType.getElementType(); + Location loc = op->getLoc(); + MLIRContext *context = op->getContext(); + SmallVector outShape; + + // Create affine map and shapes for tensor initialization. + SmallVector outExpr; + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + size_t diff = broadcastToShape.size() - inputShape.size(); + for (size_t i = 0; i < broadcastToShape.size(); i++) { + Value shapeValue = broadcastToShape[i]; + size_t j = i - diff; + if (i < diff) { + Value isValid = rewriter.create( + loc, arith::CmpIPredicate::sge, shapeValue, zero); + rewriter.create( + loc, isValid, + rewriter.getStringAttr( + "negative values not allowed in new dimensions")); + outShape.push_back(castIntToIndex(rewriter, loc, shapeValue)); + continue; + } + if (inputShape[j] == 1) { + // Broadcast singleton dimension + Value one = + rewriter.create(loc, rewriter.getIndexAttr(1)); + Value isNegative = rewriter.create( + loc, arith::CmpIPredicate::slt, shapeValue, zero); + Value select = rewriter.create( + loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue)); + outShape.push_back(select); + outExpr.push_back(mlir::getAffineConstantExpr(0, context)); + continue; + } + // Non-broadcast case + Value dim = getDimOp(rewriter, loc, input, j); + Value isNegative = rewriter.create( + loc, arith::CmpIPredicate::slt, shapeValue, zero); + Value isEqual = rewriter.create( + loc, arith::CmpIPredicate::eq, castIndexToInt(rewriter, loc, dim), + shapeValue); + Value isValid = rewriter.create(loc, isNegative, isEqual); + rewriter.create( + loc, isValid, + rewriter.getStringAttr( + "only broadcasting singleton dimensions supported")); + outShape.push_back(dim); + outExpr.push_back(mlir::getAffineDimExpr(i, context)); + } + + Value outTensor = + rewriter.create(loc, outShape, elementType); + + SmallVector indexingMaps = { + AffineMap::get(broadcastToShape.size(), 0, outExpr, context), + rewriter.getMultiDimIdentityMap(broadcastToShape.size())}; + SmallVector iteratorTypes(broadcastToShape.size(), "parallel"); + result = rewriter + .create( + loc, outTensor.getType(), input, outTensor, indexingMaps, + iteratorTypes, + [](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }) + .getResult(0); + + return success(); +} + namespace { class ConvertAtenBroadcastToOp : public OpConversionPattern { public: @@ -889,88 +973,24 @@ public: if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Value self = adaptor.self(); - auto selfType = self.getType().cast(); - ArrayRef selfShape = selfType.getShape(); - Type elementType = selfType.getElementType(); - Location loc = op.getLoc(); - MLIRContext *context = op->getContext(); - SmallVector inShape, outShape; + SmallVector inShape; if (!getListConstructElements(adaptor.size(), inShape)) { return rewriter.notifyMatchFailure( op, "unimplemented: the size list is not from list construct"); } - SmallVector inShapeConverted = - getTypeConvertedValues(rewriter, loc, getTypeConverter(), inShape); - if (inShape.size() < selfShape.size()) + SmallVector inShapeConverted = getTypeConvertedValues( + rewriter, op.getLoc(), getTypeConverter(), inShape); + + Value result; + if (failed(broadcastToGivenShape(op, rewriter, self, inShapeConverted, + result))) { return rewriter.notifyMatchFailure( - op, "invalid shape: must not be smaller than rank of tensor"); - size_t diff = inShape.size() - selfShape.size(); - - // Create affine map and shapes for tensor initialization. - SmallVector outExpr; - Value zero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); - for (size_t i = 0; i < inShape.size(); i++) { - Value shapeValue = inShapeConverted[i]; - size_t j = i - diff; - if (i < diff) { - Value isValid = rewriter.create( - loc, arith::CmpIPredicate::sge, shapeValue, zero); - rewriter.create( - loc, isValid, - rewriter.getStringAttr( - "negative values not allowed in new dimensions")); - outShape.push_back(castIntToIndex(rewriter, loc, shapeValue)); - continue; - } - if (selfShape[j] == 1) { - // Broadcast singleton dimension - Value one = - rewriter.create(loc, rewriter.getIndexAttr(1)); - Value isNegative = rewriter.create( - loc, arith::CmpIPredicate::slt, shapeValue, zero); - Value select = rewriter.create( - loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue)); - outShape.push_back(select); - outExpr.push_back(mlir::getAffineConstantExpr(0, context)); - continue; - } - // Non-broadcast case - Value dim = getDimOp(rewriter, loc, self, j); - Value isNegative = rewriter.create( - loc, arith::CmpIPredicate::slt, shapeValue, zero); - Value isEqual = rewriter.create( - loc, arith::CmpIPredicate::eq, castIndexToInt(rewriter, loc, dim), - shapeValue); - Value isValid = rewriter.create(loc, isNegative, isEqual); - rewriter.create( - loc, isValid, - rewriter.getStringAttr( - "only broadcasting singleton dimensions supported")); - outShape.push_back(dim); - outExpr.push_back(mlir::getAffineDimExpr(i, context)); + op, "unable to perform broadcast operation"); } - Value outTensor = - rewriter.create(loc, outShape, elementType); - - SmallVector indexingMaps = { - AffineMap::get(inShape.size(), 0, outExpr, context), - rewriter.getMultiDimIdentityMap(inShape.size())}; - SmallVector iteratorTypes(inShape.size(), "parallel"); - Value result = rewriter - .create( - loc, outTensor.getType(), self, outTensor, - indexingMaps, iteratorTypes, - [](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args[0]); - }) - .getResult(0); - Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, result); - return success(); } }; @@ -992,6 +1012,74 @@ public: }; } // namespace +namespace { +class ConvertValsemVariantAtenCopyOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(ValsemVariantAtenCopyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + Location loc = op.getLoc(); + Value self = adaptor.self(); + Value src = adaptor.src(); + RankedTensorType selfType = self.getType().cast(); + + // The non_blocking should be a constant `False`. + bool nonBlocking; + if (!matchPattern(op.non_blocking(), m_TorchConstantBool(&nonBlocking))) { + return rewriter.notifyMatchFailure( + op, "unimplemented: non_blocking must be a constant"); + } else if (nonBlocking) { + return rewriter.notifyMatchFailure( + op, "unimplemented: non_blocking is expected to be false"); + } + + // The size of the src tensor can be different from the self but should be + // broadcastable. Therefore, broadcasting the src tensor to match the size + // of the self tensor. + SmallVector selfSizes = getTensorSizes(rewriter, loc, self); + for (unsigned i = 0; i < selfSizes.size(); i++) + selfSizes[i] = castIndexToInt(rewriter, loc, selfSizes[i]); + Value broadcastedSrc; + if (failed(broadcastToGivenShape(op, rewriter, src, selfSizes, + broadcastedSrc))) { + return rewriter.notifyMatchFailure( + op, "unable to perform broadcast operation"); + } + + AffineMap id = AffineMap::getMultiDimIdentityMap(selfType.getRank(), + rewriter.getContext()); + SmallVector iteratorTypes(selfType.getRank(), + getParallelIteratorTypeName()); + Value result = rewriter + .create( + loc, + /*resultType=*/selfType, + /*inputs=*/broadcastedSrc, + /*outputs=*/self, + /*indexingMaps=*/llvm::makeArrayRef({id, id}), + /*iteratorTypes=*/iteratorTypes, + [](OpBuilder &b, Location loc, ValueRange args) { + Value result = args[0]; + if (args[0].getType() != args[1].getType()) { + result = convertScalarToDtype(b, loc, args[0], + args[1].getType()); + } + b.create(loc, result); + }) + ->getResult(0); + + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -1018,4 +1106,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index 5f99cc954..310307eb0 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -161,6 +161,9 @@ public: } else if (isa(op)) { newOp = rewriter.create( loc, op->getResultTypes(), op->getOperands()); + } else if (isa(op)) { + newOp = rewriter.create( + loc, op->getResultTypes(), op->getOperands()); } else { return failure(); } @@ -241,6 +244,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase { target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *op) { if (op->hasTrait()) { auto hasValueSemantics = [](Type t) { diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 79105abe0..3f8730c8c 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -517,7 +517,8 @@ ChangeResult TypeAnalyzer::visitOperation( AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenIndexTensorOp, - ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp>(op)) { + ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp, + ValsemVariantAtenCopyOp>(op)) { ValueKnowledge knowledge = ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); knowledge.dtype = operands[0]->getValue().dtype; diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index 165317ac8..17ad74f78 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -1796,6 +1796,10 @@ module { func @"__torch_mlir_shape_fn.aten.fill.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { return %arg0 : !torch.list } + func @"__torch_mlir_shape_fn.aten.copy"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.list { + %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list) -> !torch.list + return %0 : !torch.list + } func @"__torch_mlir_shape_fn.aten.uniform"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list { return %arg0 : !torch.list } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index 7545eb08a..15b3acfc7 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -597,6 +597,10 @@ def aten〇new_ones(self: List[int], size: List[int], dtype: Optional[int] = Non def aten〇fill〇Scalar(self: List[int], value: float) -> List[int]: return self +@not_present_in_registry +def aten〇copy(self: List[int], src: List[int], non_blocking: bool = False) -> List[int]: + return upstream_shape_helpers.unary(self) + @not_present_in_registry def aten〇uniform(self: List[int], from_: float = 0., to: float = 1., generator: Any = None) -> List[int]: return self diff --git a/test/Dialect/Torch/reduce-op-variants.mlir b/test/Dialect/Torch/reduce-op-variants.mlir index febfdd209..a6d5e582b 100644 --- a/test/Dialect/Torch/reduce-op-variants.mlir +++ b/test/Dialect/Torch/reduce-op-variants.mlir @@ -198,3 +198,20 @@ func @torch.aten._index_put_impl_(%self: !torch.tensor, %index: !torch.tensor, % %ret = torch.aten._index_put_impl_ %self, %indicesList, %values, %true, %false : !torch.tensor, !torch.list>, !torch.tensor, !torch.bool, !torch.bool -> !torch.tensor return %ret : !torch.tensor } + +// CHECK-LABEL: func @torch.aten.copy_( +// CHECK-SAME: %[[DST:.*]]: !torch.tensor, +// CHECK-SAME: %[[SRC:.*]]: !torch.tensor) -> !torch.tensor { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[DST_VTENSOR:.*]] = torch.copy.to_vtensor %[[DST]] : !torch.vtensor +// CHECK: %[[SRC_VTENSOR:.*]] = torch.copy.to_vtensor %[[SRC]] : !torch.vtensor +// CHECK: %[[VRET:.*]] = torch.valsem.aten.copy %[[DST_VTENSOR]], %[[SRC_VTENSOR]], %[[FALSE]] : !torch.vtensor, !torch.vtensor, !torch.bool -> !torch.vtensor +// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor +// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor +// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[DST]] : !torch.vtensor, !torch.tensor +// CHECK: return %[[DST]] : !torch.tensor +func @torch.aten.copy_(%dst: !torch.tensor, %src : !torch.tensor) -> !torch.tensor { + %false = torch.constant.bool false + %ret = torch.aten.copy_ %dst, %src, %false : !torch.tensor, !torch.tensor, !torch.bool -> !torch.tensor + return %ret : !torch.tensor +}