[MLIR][TORCH] Add value tensor variant to aten::copy_ op

This commit adds the op `ValsemVariantAtenCopyOp`, which represents
`AtenCopy_Op` without the trailing underscore. It is needed so that
the `ReduceOpVariants` pass can turn the in-place op into an op that
takes value tensors as inputs; otherwise, the
`MaximizeValueSemantics` pass would not be able to add value
semantics correctly.

This commit also adds the lowering of `ValsemVariantAtenCopyOp`.
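For reference, the in-place behavior being modeled looks like this in eager
PyTorch (an illustrative sketch; shapes and dtypes chosen arbitrarily):

import torch

# aten::copy_ overwrites `dst` in place: `src` is broadcast to `dst`'s
# shape and cast to `dst`'s dtype.
dst = torch.zeros(3, 2, 4, dtype=torch.float32)
src = torch.arange(6, dtype=torch.int64).reshape(3, 2, 1)
dst.copy_(src)  # broadcasts along the last dim, casts int64 -> float32
print(dst.dtype, dst.shape)  # torch.float32 torch.Size([3, 2, 4])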

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
pull/682/head
Vivek Khandelwal 2022-03-15 22:09:58 +05:30
parent 4c0cd5c23d
commit 13383b03b8
8 changed files with 291 additions and 75 deletions


@@ -1491,3 +1491,80 @@ class ExpandAsIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ExpandAsIntModule())
def ExpandAsIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(100, (1, 1, 1)), torch.randint(200, (4, 5, 6)))

# ==============================================================================

class CopyModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x, y):
        return torch.ops.aten.copy_(x, y)


@register_test_case(module_factory=lambda: CopyModule())
def CopyModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 2, 4), tu.rand(3, 2, 4))


class CopyWithDifferentSizesModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, 4], torch.float32, True),
        ([-1, -1, 1], torch.float32, True),
    ])
    def forward(self, x, y):
        return torch.ops.aten.copy_(x, y)


@register_test_case(module_factory=lambda: CopyWithDifferentSizesModule())
def CopyWithDifferentSizesModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 2, 4), tu.rand(3, 2, 1))


class CopyWithDifferentDTypesModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.int64, True),
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x, y):
        return torch.ops.aten.copy_(x, y)


@register_test_case(module_factory=lambda: CopyWithDifferentDTypesModule())
def CopyWithDifferentDTypesModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(100, (3, 2, 4)), tu.rand(3, 2, 4))


class CopyWithDifferentDTypesAndSizesModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, 4], torch.float32, True),
        ([-1, -1, 1], torch.int64, True),
    ])
    def forward(self, x, y):
        return torch.ops.aten.copy_(x, y)


@register_test_case(module_factory=lambda: CopyWithDifferentDTypesAndSizesModule())
def CopyWithDifferentDTypesAndSizesModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 2, 4), torch.randint(1000, (3, 2, 1)))


@@ -1023,7 +1023,7 @@ def Torch_ValsemVariantAtenIndexPutImplOp: Torch_Op<"valsem.aten.index_put_impl"
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "`index_put_impl op : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchOptionalTensorListType:$indices,
@@ -1031,12 +1031,31 @@ let summary = "`index_put_impl op : (Tensor, Tensor?[], Tensor, bool, bool) -> (
    Torch_BoolType:$accumulate,
    Torch_BoolType:$unsafe
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate `,` $unsafe attr-dict `:` qualified(type($self)) `,` qualified(type($indices)) `,` qualified(type($values)) `,` qualified(type($accumulate)) `,` qualified(type($unsafe)) `->` qualified(type($result))";
}

// The variant of `torch.aten.copy_` without the trailing underscore does not
// exist in the PyTorch op registry, so it is added here.
def Torch_ValsemVariantAtenCopyOp : Torch_Op<"valsem.aten.copy", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "`copy op : (Tensor, Tensor, bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$src,
    Torch_BoolType:$non_blocking
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $src `,` $non_blocking attr-dict `:` qualified(type($self)) `,` qualified(type($src)) `,` qualified(type($non_blocking)) `->` qualified(type($result))";
}
// To handle runtime assertions, TorchScript provides the `torch._assert`
// operation. However, the TS compiler introduces control flow for
// `torch._assert`, like:


@@ -878,6 +878,90 @@ public:
};
} // namespace
// Broadcasts input tensor based on the broadcastToShape.
static LogicalResult broadcastToGivenShape(Operation *op,
                                           ConversionPatternRewriter &rewriter,
                                           Value input,
                                           SmallVector<Value> broadcastToShape,
                                           Value &result) {
  RankedTensorType inputType = input.getType().cast<RankedTensorType>();
  ArrayRef<int64_t> inputShape = inputType.getShape();
  if (broadcastToShape.size() < inputShape.size()) {
    return rewriter.notifyMatchFailure(
        op, "invalid shape: broadcastToShape size must not be smaller than the "
            "size of the input shape");
  }

  Type elementType = inputType.getElementType();
  Location loc = op->getLoc();
  MLIRContext *context = op->getContext();
  SmallVector<Value> outShape;

  // Create affine map and shapes for tensor initialization.
  SmallVector<AffineExpr> outExpr;
  Value zero =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
  size_t diff = broadcastToShape.size() - inputShape.size();
  for (size_t i = 0; i < broadcastToShape.size(); i++) {
    Value shapeValue = broadcastToShape[i];
    size_t j = i - diff;
    if (i < diff) {
      Value isValid = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::sge, shapeValue, zero);
      rewriter.create<cf::AssertOp>(
          loc, isValid,
          rewriter.getStringAttr(
              "negative values not allowed in new dimensions"));
      outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
      continue;
    }
    if (inputShape[j] == 1) {
      // Broadcast singleton dimension.
      Value one =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
      Value isNegative = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, shapeValue, zero);
      Value select = rewriter.create<arith::SelectOp>(
          loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue));
      outShape.push_back(select);
      outExpr.push_back(mlir::getAffineConstantExpr(0, context));
      continue;
    }
    // Non-broadcast case.
    Value dim = getDimOp(rewriter, loc, input, j);
    Value isNegative = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, shapeValue, zero);
    Value isEqual = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, castIndexToInt(rewriter, loc, dim),
        shapeValue);
    Value isValid = rewriter.create<arith::OrIOp>(loc, isNegative, isEqual);
    rewriter.create<cf::AssertOp>(
        loc, isValid,
        rewriter.getStringAttr(
            "only broadcasting singleton dimensions supported"));
    outShape.push_back(dim);
    outExpr.push_back(mlir::getAffineDimExpr(i, context));
  }

  Value outTensor =
      rewriter.create<linalg::InitTensorOp>(loc, outShape, elementType);
  SmallVector<AffineMap> indexingMaps = {
      AffineMap::get(broadcastToShape.size(), 0, outExpr, context),
      rewriter.getMultiDimIdentityMap(broadcastToShape.size())};
  SmallVector<StringRef> iteratorTypes(broadcastToShape.size(), "parallel");
  result = rewriter
               .create<linalg::GenericOp>(
                   loc, outTensor.getType(), input, outTensor, indexingMaps,
                   iteratorTypes,
                   [](OpBuilder &b, Location loc, ValueRange args) {
                     b.create<linalg::YieldOp>(loc, args[0]);
                   })
               .getResult(0);
  return success();
}
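For intuition, the shape rules enforced by this helper can be modeled in plain
Python (an illustrative sketch, not part of the commit; a negative entry in
the target shape keeps the input's dimension):

def broadcast_shape(input_shape, target_shape):
    # Mirrors broadcastToGivenShape: new leading dims must be non-negative;
    # existing dims either match the target or are singletons that get
    # expanded; a negative target entry keeps the input's dimension.
    if len(target_shape) < len(input_shape):
        raise ValueError("target rank must not be smaller than input rank")
    diff = len(target_shape) - len(input_shape)
    out = []
    for i, s in enumerate(target_shape):
        if i < diff:
            assert s >= 0, "negative values not allowed in new dimensions"
            out.append(s)
        elif input_shape[i - diff] == 1:
            out.append(1 if s < 0 else s)
        else:
            assert s < 0 or s == input_shape[i - diff], \
                "only broadcasting singleton dimensions supported"
            out.append(input_shape[i - diff])
    return out

assert broadcast_shape([3, 2, 1], [3, 2, 4]) == [3, 2, 4]
assert broadcast_shape([2, 1], [5, 2, 4]) == [5, 2, 4]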
namespace {
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
public:
@@ -889,88 +973,24 @@ public:
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
     Value self = adaptor.self();
-    auto selfType = self.getType().cast<RankedTensorType>();
-    ArrayRef<int64_t> selfShape = selfType.getShape();
-    Type elementType = selfType.getElementType();
-    Location loc = op.getLoc();
-    MLIRContext *context = op->getContext();
-    SmallVector<Value> inShape, outShape;
+    SmallVector<Value> inShape;
     if (!getListConstructElements(adaptor.size(), inShape)) {
       return rewriter.notifyMatchFailure(
           op, "unimplemented: the size list is not from list construct");
     }
-    SmallVector<Value> inShapeConverted =
-        getTypeConvertedValues(rewriter, loc, getTypeConverter(), inShape);
-    if (inShape.size() < selfShape.size())
-      return rewriter.notifyMatchFailure(
-          op, "invalid shape: must not be smaller than rank of tensor");
-    size_t diff = inShape.size() - selfShape.size();
-    // Create affine map and shapes for tensor initialization.
-    SmallVector<AffineExpr> outExpr;
-    Value zero =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
-    for (size_t i = 0; i < inShape.size(); i++) {
-      Value shapeValue = inShapeConverted[i];
-      size_t j = i - diff;
-      if (i < diff) {
-        Value isValid = rewriter.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::sge, shapeValue, zero);
-        rewriter.create<cf::AssertOp>(
-            loc, isValid,
-            rewriter.getStringAttr(
-                "negative values not allowed in new dimensions"));
-        outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
-        continue;
-      }
-      if (selfShape[j] == 1) {
-        // Broadcast singleton dimension
-        Value one =
-            rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
-        Value isNegative = rewriter.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::slt, shapeValue, zero);
-        Value select = rewriter.create<arith::SelectOp>(
-            loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue));
-        outShape.push_back(select);
-        outExpr.push_back(mlir::getAffineConstantExpr(0, context));
-        continue;
-      }
-      // Non-broadcast case
-      Value dim = getDimOp(rewriter, loc, self, j);
-      Value isNegative = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::slt, shapeValue, zero);
-      Value isEqual = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::eq, castIndexToInt(rewriter, loc, dim),
-          shapeValue);
-      Value isValid = rewriter.create<arith::OrIOp>(loc, isNegative, isEqual);
-      rewriter.create<cf::AssertOp>(
-          loc, isValid,
-          rewriter.getStringAttr(
-              "only broadcasting singleton dimensions supported"));
-      outShape.push_back(dim);
-      outExpr.push_back(mlir::getAffineDimExpr(i, context));
-    }
-    Value outTensor =
-        rewriter.create<linalg::InitTensorOp>(loc, outShape, elementType);
-    SmallVector<AffineMap> indexingMaps = {
-        AffineMap::get(inShape.size(), 0, outExpr, context),
-        rewriter.getMultiDimIdentityMap(inShape.size())};
-    SmallVector<StringRef> iteratorTypes(inShape.size(), "parallel");
-    Value result = rewriter
-                       .create<linalg::GenericOp>(
-                           loc, outTensor.getType(), self, outTensor,
-                           indexingMaps, iteratorTypes,
-                           [](OpBuilder &b, Location loc, ValueRange args) {
-                             b.create<linalg::YieldOp>(loc, args[0]);
-                           })
-                       .getResult(0);
+    SmallVector<Value> inShapeConverted = getTypeConvertedValues(
+        rewriter, op.getLoc(), getTypeConverter(), inShape);
+    Value result;
+    if (failed(broadcastToGivenShape(op, rewriter, self, inShapeConverted,
+                                     result))) {
+      return rewriter.notifyMatchFailure(
+          op, "unable to perform broadcast operation");
+    }
     Type newResultType = getTypeConverter()->convertType(op.getType());
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
     return success();
   }
 };
@@ -992,6 +1012,74 @@ public:
};
} // namespace
namespace {
class ConvertValsemVariantAtenCopyOp
    : public OpConversionPattern<ValsemVariantAtenCopyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ValsemVariantAtenCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    Location loc = op.getLoc();
    Value self = adaptor.self();
    Value src = adaptor.src();
    RankedTensorType selfType = self.getType().cast<RankedTensorType>();

    // The `non_blocking` argument must be a constant `False`.
    bool nonBlocking;
    if (!matchPattern(op.non_blocking(), m_TorchConstantBool(&nonBlocking))) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: non_blocking must be a constant");
    } else if (nonBlocking) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: non_blocking is expected to be false");
    }

    // The size of the src tensor can differ from that of self, but it must be
    // broadcastable; therefore, broadcast the src tensor to match the size of
    // the self tensor.
    SmallVector<Value> selfSizes = getTensorSizes(rewriter, loc, self);
    for (unsigned i = 0; i < selfSizes.size(); i++)
      selfSizes[i] = castIndexToInt(rewriter, loc, selfSizes[i]);
    Value broadcastedSrc;
    if (failed(broadcastToGivenShape(op, rewriter, src, selfSizes,
                                     broadcastedSrc))) {
      return rewriter.notifyMatchFailure(
          op, "unable to perform broadcast operation");
    }

    AffineMap id = AffineMap::getMultiDimIdentityMap(selfType.getRank(),
                                                     rewriter.getContext());
    SmallVector<StringRef> iteratorTypes(selfType.getRank(),
                                         getParallelIteratorTypeName());
    Value result = rewriter
                       .create<linalg::GenericOp>(
                           loc,
                           /*resultType=*/selfType,
                           /*inputs=*/broadcastedSrc,
                           /*outputs=*/self,
                           /*indexingMaps=*/llvm::makeArrayRef({id, id}),
                           /*iteratorTypes=*/iteratorTypes,
                           [](OpBuilder &b, Location loc, ValueRange args) {
                             Value result = args[0];
                             if (args[0].getType() != args[1].getType()) {
                               result = convertScalarToDtype(b, loc, args[0],
                                                             args[1].getType());
                             }
                             b.create<linalg::YieldOp>(loc, result);
                           })
                       ->getResult(0);
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
@@ -1018,4 +1106,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
  patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
  target.addIllegalOp<AtenContiguousOp>();
  patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
  target.addIllegalOp<ValsemVariantAtenCopyOp>();
  patterns.add<ConvertValsemVariantAtenCopyOp>(typeConverter, context);
}
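Functionally, the new conversion computes the following (a minimal eager-mode
PyTorch sketch of the pattern above, assuming non_blocking is the constant
false; `valsem_copy` is a hypothetical name):

import torch

def valsem_copy(self_t: torch.Tensor, src_t: torch.Tensor) -> torch.Tensor:
    # Step 1: broadcast src to self's shape (broadcastToGivenShape).
    broadcasted = src_t.broadcast_to(self_t.shape)
    # Step 2: elementwise copy with dtype conversion (the linalg.generic
    # body's convertScalarToDtype), producing a new value tensor.
    return broadcasted.to(self_t.dtype)

out = valsem_copy(torch.zeros(3, 2, 4), torch.randint(10, (3, 2, 1)))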


@@ -161,6 +161,9 @@ public:
    } else if (isa<Aten_IndexPutImpl_Op>(op)) {
      newOp = rewriter.create<ValsemVariantAtenIndexPutImplOp>(
          loc, op->getResultTypes(), op->getOperands());
    } else if (isa<AtenCopy_Op>(op)) {
      newOp = rewriter.create<ValsemVariantAtenCopyOp>(
          loc, op->getResultTypes(), op->getOperands());
    } else {
      return failure();
    }
@@ -241,6 +244,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
    target.addIllegalOp<AtenBernoulli_TensorOp>();
    target.addIllegalOp<AtenFill_ScalarOp>();
    target.addIllegalOp<Aten_IndexPutImpl_Op>();
    target.addIllegalOp<AtenCopy_Op>();
    target.markUnknownOpDynamicallyLegal([](Operation *op) {
      if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
        auto hasValueSemantics = [](Type t) {


@@ -517,7 +517,8 @@ ChangeResult TypeAnalyzer::visitOperation(
           AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp,
           AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp,
           AtenConstantPadNdOp, AtenIndexTensorOp,
-          ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp>(op)) {
+          ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
+          ValsemVariantAtenCopyOp>(op)) {
     ValueKnowledge knowledge =
         ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
     knowledge.dtype = operands[0]->getValue().dtype;
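The rule this encodes, namely that the result dtype follows `self` rather than
`src`, matches eager PyTorch behavior (an illustrative check, not part of the
commit):

import torch

dst = torch.zeros(2, 2, dtype=torch.int64)
res = dst.copy_(torch.rand(2, 2))  # float src
assert res.dtype == torch.int64    # result dtype follows `self`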


@@ -1796,6 +1796,10 @@ module {
  func @"__torch_mlir_shape_fn.aten.fill.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
    return %arg0 : !torch.list<int>
  }
  func @"__torch_mlir_shape_fn.aten.copy"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {
    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>
  }
  func @"__torch_mlir_shape_fn.aten.uniform"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list<int> {
    return %arg0 : !torch.list<int>
  }


@@ -597,6 +597,10 @@ def aten〇new_ones(self: List[int], size: List[int], dtype: Optional[int] = Non
def aten〇fill〇Scalar(self: List[int], value: float) -> List[int]:
    return self

@not_present_in_registry
def aten〇copy(self: List[int], src: List[int], non_blocking: bool = False) -> List[int]:
    return upstream_shape_helpers.unary(self)

@not_present_in_registry
def aten〇uniform(self: List[int], from_: float = 0., to: float = 1., generator: Any = None) -> List[int]:
    return self
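The shape function ignores `src` entirely; a hypothetical standalone
equivalent with `upstream_shape_helpers.unary` inlined:

from typing import List

def aten_copy_shape(self: List[int], src: List[int],
                    non_blocking: bool = False) -> List[int]:
    # upstream_shape_helpers.unary simply returns the input shape.
    return list(self)

assert aten_copy_shape([3, 2, 4], [3, 2, 1]) == [3, 2, 4]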


@@ -198,3 +198,20 @@ func @torch.aten._index_put_impl_(%self: !torch.tensor, %index: !torch.tensor, %
  %ret = torch.aten._index_put_impl_ %self, %indicesList, %values, %true, %false : !torch.tensor, !torch.list<optional<tensor>>, !torch.tensor, !torch.bool, !torch.bool -> !torch.tensor
  return %ret : !torch.tensor
}

// CHECK-LABEL: func @torch.aten.copy_(
// CHECK-SAME: %[[DST:.*]]: !torch.tensor,
// CHECK-SAME: %[[SRC:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[DST_VTENSOR:.*]] = torch.copy.to_vtensor %[[DST]] : !torch.vtensor
// CHECK: %[[SRC_VTENSOR:.*]] = torch.copy.to_vtensor %[[SRC]] : !torch.vtensor
// CHECK: %[[VRET:.*]] = torch.valsem.aten.copy %[[DST_VTENSOR]], %[[SRC_VTENSOR]], %[[FALSE]] : !torch.vtensor, !torch.vtensor, !torch.bool -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[DST]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[DST]] : !torch.tensor
func @torch.aten.copy_(%dst: !torch.tensor, %src : !torch.tensor) -> !torch.tensor {
  %false = torch.constant.bool false
  %ret = torch.aten.copy_ %dst, %src, %false : !torch.tensor, !torch.tensor, !torch.bool -> !torch.tensor
  return %ret : !torch.tensor
}