mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add value tensor variant to aten::copy_ op
This commit adds the op `ValsemVariantAtenCopyOp` that represents `AtenCopy_Op` without the underscore. This is needed to make sure that the `ReduceOpVariants` pass turns the in-place op into an op that takes value tensors as inputs, otherwise the `MaximizeValueSemantics` pass will not be able to add value semantics correctly. This commit also adds the lowering of `ValsemVariantAtenCopyOp`. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/682/head
parent
4c0cd5c23d
commit
13383b03b8
|
@ -1491,3 +1491,80 @@ class ExpandAsIntModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ExpandAsIntModule())
|
||||
def ExpandAsIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (1, 1, 1)), torch.randint(200, (4, 5, 6)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class CopyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.copy_(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: CopyModule())
|
||||
def CopyModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 2, 4), tu.rand(3, 2, 4))
|
||||
|
||||
|
||||
class CopyWithDifferentSizesModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, 4], torch.float32, True),
|
||||
([-1, -1, 1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.copy_(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: CopyWithDifferentSizesModule())
|
||||
def CopyWithDifferentSizesModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 2, 4), tu.rand(3, 2, 1))
|
||||
|
||||
|
||||
class CopyWithDifferentDTypesModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.copy_(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: CopyWithDifferentDTypesModule())
|
||||
def CopyWithDifferentDTypesModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3, 2, 4)), tu.rand(3, 2, 4))
|
||||
|
||||
|
||||
class CopyWithDifferentDTypesAndSizesModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, 4], torch.float32, True),
|
||||
([-1, -1, 1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.copy_(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: CopyWithDifferentDTypesAndSizesModule())
|
||||
def CopyWithDifferentDTypesAndSizesModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 2, 4), torch.randint(1000, (3, 2, 1)))
|
||||
|
|
|
@ -1023,7 +1023,7 @@ def Torch_ValsemVariantAtenIndexPutImplOp: Torch_Op<"valsem.aten.index_put_impl"
|
|||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "`index_put_impl op : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`";
|
||||
let summary = "`index_put_impl op : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalTensorListType:$indices,
|
||||
|
@ -1031,12 +1031,31 @@ let summary = "`index_put_impl op : (Tensor, Tensor?[], Tensor, bool, bool) -> (
|
|||
Torch_BoolType:$accumulate,
|
||||
Torch_BoolType:$unsafe
|
||||
);
|
||||
let results = (outs
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate `,` $unsafe attr-dict `:` qualified(type($self)) `,` qualified(type($indices)) `,` qualified(type($values)) `,` qualified(type($accumulate)) `,` qualified(type($unsafe)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
// The corresponding without underscore variant for `torch.aten.copy_`
|
||||
// doesn't exist in the pytorch ops registry. Add it here.
|
||||
def Torch_ValsemVariantAtenCopyOp : Torch_Op<"valsem.aten.copy", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "copy op : (Tensor, Tensor, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$src,
|
||||
Torch_BoolType:$non_blocking
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $src `,` $non_blocking attr-dict `:` qualified(type($self)) `,` qualified(type($src)) `,` qualified(type($non_blocking)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
// To handle runtime assertions, torchscript provides us `torch._assert` operation.
|
||||
// But TS compiler introduces control flow for `torch._assert` operation. The
|
||||
// `torch._assert` would introduce control flow like:
|
||||
|
|
|
@ -878,6 +878,90 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Broadcasts input tensor based on the broadcastToShape.
|
||||
static LogicalResult broadcastToGivenShape(Operation *op,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Value input,
|
||||
SmallVector<Value> broadcastToShape,
|
||||
Value &result) {
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
if (broadcastToShape.size() < inputShape.size()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "invalid shape: broadcastToShape size must not be smaller than the "
|
||||
"size of the input shape");
|
||||
}
|
||||
|
||||
Type elementType = inputType.getElementType();
|
||||
Location loc = op->getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
SmallVector<Value> outShape;
|
||||
|
||||
// Create affine map and shapes for tensor initialization.
|
||||
SmallVector<AffineExpr> outExpr;
|
||||
Value zero =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
size_t diff = broadcastToShape.size() - inputShape.size();
|
||||
for (size_t i = 0; i < broadcastToShape.size(); i++) {
|
||||
Value shapeValue = broadcastToShape[i];
|
||||
size_t j = i - diff;
|
||||
if (i < diff) {
|
||||
Value isValid = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, shapeValue, zero);
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, isValid,
|
||||
rewriter.getStringAttr(
|
||||
"negative values not allowed in new dimensions"));
|
||||
outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
|
||||
continue;
|
||||
}
|
||||
if (inputShape[j] == 1) {
|
||||
// Broadcast singleton dimension
|
||||
Value one =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||
Value isNegative = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, shapeValue, zero);
|
||||
Value select = rewriter.create<arith::SelectOp>(
|
||||
loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue));
|
||||
outShape.push_back(select);
|
||||
outExpr.push_back(mlir::getAffineConstantExpr(0, context));
|
||||
continue;
|
||||
}
|
||||
// Non-broadcast case
|
||||
Value dim = getDimOp(rewriter, loc, input, j);
|
||||
Value isNegative = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, shapeValue, zero);
|
||||
Value isEqual = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, castIndexToInt(rewriter, loc, dim),
|
||||
shapeValue);
|
||||
Value isValid = rewriter.create<arith::OrIOp>(loc, isNegative, isEqual);
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, isValid,
|
||||
rewriter.getStringAttr(
|
||||
"only broadcasting singleton dimensions supported"));
|
||||
outShape.push_back(dim);
|
||||
outExpr.push_back(mlir::getAffineDimExpr(i, context));
|
||||
}
|
||||
|
||||
Value outTensor =
|
||||
rewriter.create<linalg::InitTensorOp>(loc, outShape, elementType);
|
||||
|
||||
SmallVector<AffineMap> indexingMaps = {
|
||||
AffineMap::get(broadcastToShape.size(), 0, outExpr, context),
|
||||
rewriter.getMultiDimIdentityMap(broadcastToShape.size())};
|
||||
SmallVector<StringRef> iteratorTypes(broadcastToShape.size(), "parallel");
|
||||
result = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, outTensor.getType(), input, outTensor, indexingMaps,
|
||||
iteratorTypes,
|
||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(loc, args[0]);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
|
||||
public:
|
||||
|
@ -889,88 +973,24 @@ public:
|
|||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
Value self = adaptor.self();
|
||||
auto selfType = self.getType().cast<RankedTensorType>();
|
||||
ArrayRef<int64_t> selfShape = selfType.getShape();
|
||||
Type elementType = selfType.getElementType();
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
|
||||
SmallVector<Value> inShape, outShape;
|
||||
SmallVector<Value> inShape;
|
||||
if (!getListConstructElements(adaptor.size(), inShape)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: the size list is not from list construct");
|
||||
}
|
||||
SmallVector<Value> inShapeConverted =
|
||||
getTypeConvertedValues(rewriter, loc, getTypeConverter(), inShape);
|
||||
if (inShape.size() < selfShape.size())
|
||||
SmallVector<Value> inShapeConverted = getTypeConvertedValues(
|
||||
rewriter, op.getLoc(), getTypeConverter(), inShape);
|
||||
|
||||
Value result;
|
||||
if (failed(broadcastToGivenShape(op, rewriter, self, inShapeConverted,
|
||||
result))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "invalid shape: must not be smaller than rank of tensor");
|
||||
size_t diff = inShape.size() - selfShape.size();
|
||||
|
||||
// Create affine map and shapes for tensor initialization.
|
||||
SmallVector<AffineExpr> outExpr;
|
||||
Value zero =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
for (size_t i = 0; i < inShape.size(); i++) {
|
||||
Value shapeValue = inShapeConverted[i];
|
||||
size_t j = i - diff;
|
||||
if (i < diff) {
|
||||
Value isValid = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, shapeValue, zero);
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, isValid,
|
||||
rewriter.getStringAttr(
|
||||
"negative values not allowed in new dimensions"));
|
||||
outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
|
||||
continue;
|
||||
}
|
||||
if (selfShape[j] == 1) {
|
||||
// Broadcast singleton dimension
|
||||
Value one =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||
Value isNegative = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, shapeValue, zero);
|
||||
Value select = rewriter.create<arith::SelectOp>(
|
||||
loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue));
|
||||
outShape.push_back(select);
|
||||
outExpr.push_back(mlir::getAffineConstantExpr(0, context));
|
||||
continue;
|
||||
}
|
||||
// Non-broadcast case
|
||||
Value dim = getDimOp(rewriter, loc, self, j);
|
||||
Value isNegative = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, shapeValue, zero);
|
||||
Value isEqual = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, castIndexToInt(rewriter, loc, dim),
|
||||
shapeValue);
|
||||
Value isValid = rewriter.create<arith::OrIOp>(loc, isNegative, isEqual);
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, isValid,
|
||||
rewriter.getStringAttr(
|
||||
"only broadcasting singleton dimensions supported"));
|
||||
outShape.push_back(dim);
|
||||
outExpr.push_back(mlir::getAffineDimExpr(i, context));
|
||||
op, "unable to perform broadcast operation");
|
||||
}
|
||||
|
||||
Value outTensor =
|
||||
rewriter.create<linalg::InitTensorOp>(loc, outShape, elementType);
|
||||
|
||||
SmallVector<AffineMap> indexingMaps = {
|
||||
AffineMap::get(inShape.size(), 0, outExpr, context),
|
||||
rewriter.getMultiDimIdentityMap(inShape.size())};
|
||||
SmallVector<StringRef> iteratorTypes(inShape.size(), "parallel");
|
||||
Value result = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, outTensor.getType(), self, outTensor,
|
||||
indexingMaps, iteratorTypes,
|
||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(loc, args[0]);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -992,6 +1012,74 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertValsemVariantAtenCopyOp
|
||||
: public OpConversionPattern<ValsemVariantAtenCopyOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(ValsemVariantAtenCopyOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Value self = adaptor.self();
|
||||
Value src = adaptor.src();
|
||||
RankedTensorType selfType = self.getType().cast<RankedTensorType>();
|
||||
|
||||
// The non_blocking should be a constant `False`.
|
||||
bool nonBlocking;
|
||||
if (!matchPattern(op.non_blocking(), m_TorchConstantBool(&nonBlocking))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: non_blocking must be a constant");
|
||||
} else if (nonBlocking) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: non_blocking is expected to be false");
|
||||
}
|
||||
|
||||
// The size of the src tensor can be different from the self but should be
|
||||
// broadcastable. Therefore, broadcasting the src tensor to match the size
|
||||
// of the self tensor.
|
||||
SmallVector<Value> selfSizes = getTensorSizes(rewriter, loc, self);
|
||||
for (unsigned i = 0; i < selfSizes.size(); i++)
|
||||
selfSizes[i] = castIndexToInt(rewriter, loc, selfSizes[i]);
|
||||
Value broadcastedSrc;
|
||||
if (failed(broadcastToGivenShape(op, rewriter, src, selfSizes,
|
||||
broadcastedSrc))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unable to perform broadcast operation");
|
||||
}
|
||||
|
||||
AffineMap id = AffineMap::getMultiDimIdentityMap(selfType.getRank(),
|
||||
rewriter.getContext());
|
||||
SmallVector<StringRef> iteratorTypes(selfType.getRank(),
|
||||
getParallelIteratorTypeName());
|
||||
Value result = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc,
|
||||
/*resultType=*/selfType,
|
||||
/*inputs=*/broadcastedSrc,
|
||||
/*outputs=*/self,
|
||||
/*indexingMaps=*/llvm::makeArrayRef({id, id}),
|
||||
/*iteratorTypes=*/iteratorTypes,
|
||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value result = args[0];
|
||||
if (args[0].getType() != args[1].getType()) {
|
||||
result = convertScalarToDtype(b, loc, args[0],
|
||||
args[1].getType());
|
||||
}
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
})
|
||||
->getResult(0);
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
|
@ -1018,4 +1106,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
|||
patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenContiguousOp>();
|
||||
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
||||
target.addIllegalOp<ValsemVariantAtenCopyOp>();
|
||||
patterns.add<ConvertValsemVariantAtenCopyOp>(typeConverter, context);
|
||||
}
|
||||
|
|
|
@ -161,6 +161,9 @@ public:
|
|||
} else if (isa<Aten_IndexPutImpl_Op>(op)) {
|
||||
newOp = rewriter.create<ValsemVariantAtenIndexPutImplOp>(
|
||||
loc, op->getResultTypes(), op->getOperands());
|
||||
} else if (isa<AtenCopy_Op>(op)) {
|
||||
newOp = rewriter.create<ValsemVariantAtenCopyOp>(
|
||||
loc, op->getResultTypes(), op->getOperands());
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
|
@ -241,6 +244,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
|||
target.addIllegalOp<AtenBernoulli_TensorOp>();
|
||||
target.addIllegalOp<AtenFill_ScalarOp>();
|
||||
target.addIllegalOp<Aten_IndexPutImpl_Op>();
|
||||
target.addIllegalOp<AtenCopy_Op>();
|
||||
target.markUnknownOpDynamicallyLegal([](Operation *op) {
|
||||
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
|
||||
auto hasValueSemantics = [](Type t) {
|
||||
|
|
|
@ -517,7 +517,8 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp,
|
||||
AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp,
|
||||
AtenConstantPadNdOp, AtenIndexTensorOp,
|
||||
ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp>(op)) {
|
||||
ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
|
||||
ValsemVariantAtenCopyOp>(op)) {
|
||||
ValueKnowledge knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.dtype = operands[0]->getValue().dtype;
|
||||
|
|
|
@ -1796,6 +1796,10 @@ module {
|
|||
func @"__torch_mlir_shape_fn.aten.fill.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
|
||||
return %arg0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.copy"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.uniform"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list<int> {
|
||||
return %arg0 : !torch.list<int>
|
||||
}
|
||||
|
|
|
@ -597,6 +597,10 @@ def aten〇new_ones(self: List[int], size: List[int], dtype: Optional[int] = Non
|
|||
def aten〇fill〇Scalar(self: List[int], value: float) -> List[int]:
|
||||
return self
|
||||
|
||||
@not_present_in_registry
|
||||
def aten〇copy(self: List[int], src: List[int], non_blocking: bool = False) -> List[int]:
|
||||
return upstream_shape_helpers.unary(self)
|
||||
|
||||
@not_present_in_registry
|
||||
def aten〇uniform(self: List[int], from_: float = 0., to: float = 1., generator: Any = None) -> List[int]:
|
||||
return self
|
||||
|
|
|
@ -198,3 +198,20 @@ func @torch.aten._index_put_impl_(%self: !torch.tensor, %index: !torch.tensor, %
|
|||
%ret = torch.aten._index_put_impl_ %self, %indicesList, %values, %true, %false : !torch.tensor, !torch.list<optional<tensor>>, !torch.tensor, !torch.bool, !torch.bool -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.copy_(
|
||||
// CHECK-SAME: %[[DST:.*]]: !torch.tensor,
|
||||
// CHECK-SAME: %[[SRC:.*]]: !torch.tensor) -> !torch.tensor {
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[DST_VTENSOR:.*]] = torch.copy.to_vtensor %[[DST]] : !torch.vtensor
|
||||
// CHECK: %[[SRC_VTENSOR:.*]] = torch.copy.to_vtensor %[[SRC]] : !torch.vtensor
|
||||
// CHECK: %[[VRET:.*]] = torch.valsem.aten.copy %[[DST_VTENSOR]], %[[SRC_VTENSOR]], %[[FALSE]] : !torch.vtensor, !torch.vtensor, !torch.bool -> !torch.vtensor
|
||||
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
|
||||
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
|
||||
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[DST]] : !torch.vtensor, !torch.tensor
|
||||
// CHECK: return %[[DST]] : !torch.tensor
|
||||
func @torch.aten.copy_(%dst: !torch.tensor, %src : !torch.tensor) -> !torch.tensor {
|
||||
%false = torch.constant.bool false
|
||||
%ret = torch.aten.copy_ %dst, %src, %false : !torch.tensor, !torch.tensor, !torch.bool -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue