[tosa] Add lowering for aten.to.dtype and aten._to_copy ops

This commit adds the TorchToTosa lowerings for the `aten.to.dtype` and
`aten._to_copy` ops.

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
pull/1467/head snapshot-20221006.618
Vivek Khandelwal 2022-10-04 18:35:59 +05:30
parent 56f9a9b5de
commit d3cc3f1aff
10 changed files with 269 additions and 7 deletions
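
Both ops are dtype conversions at the Torch level; the case that needs real work in TOSA is conversion to `torch.bool`, where every non-zero element becomes `True`. A small sketch of the semantics the new lowering has to reproduce (plain PyTorch, not part of the patch):

import torch

x = torch.tensor([[0, 3, 0], [7, 0, 1]], dtype=torch.int64)

# aten.to.dtype / aten._to_copy with dtype=torch.bool: non-zero -> True.
a = x.to(dtype=torch.bool)                        # aten.to.dtype
b = torch.ops.aten._to_copy(x, dtype=torch.bool)  # aten._to_copy

# The TOSA lowering expresses the same thing as logical_not(equal(x, 0)).
c = torch.logical_not(torch.eq(x, torch.zeros_like(x)))

assert torch.equal(a, b) and torch.equal(a, c)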


@@ -464,6 +464,9 @@ TOSA_PASS_SET = {
    "ArangeStartNegativeStepIntModule_basic",
    "ArangeZeroElementOutputModule_basic",
    "NumToTensorIntModule_basic",
    "ToDtypeBoolLayoutNoneStaticModule_basic",
    "ToCopyBoolDTypeStaticModule_basic",
    "HardTanhIntModule_basic",
}

LTC_XFAIL_SET = {


@@ -53,6 +53,9 @@ template <typename T>
llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
                                     ArrayRef<T> vec, ArrayRef<int64_t> shape);

LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
                                   Value src, Type destType, Value &result);

// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the framework to TOSA lowering.
template <typename TosaOp, typename... Args>


@@ -52,6 +52,9 @@ int getTensorRank(Value tensor);
bool isViewLikeOp(Operation *op);

Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc,
                                        float value, Type dtype);

} // namespace Torch
} // namespace torch
} // namespace mlir


@@ -3078,6 +3078,110 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
  return success();
}

template <>
LogicalResult ConvertAtenOp<ValsemVariantAtenCopyOp>::matchAndRewrite(
    ValsemVariantAtenCopyOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // Not a tensor type.
  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
  auto srcType = adaptor.src().getType().dyn_cast<TensorType>();
  if (!selfType || !selfType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        op, "Only tensor types with static shape are supported");

  if (!srcType || !srcType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        op, "Only tensor types with static shape are supported");

  // The non_blocking should be a constant `False`.
  bool nonBlocking;
  if (!matchPattern(op.non_blocking(), m_TorchConstantBool(&nonBlocking))) {
    return rewriter.notifyMatchFailure(
        op, "unimplemented: non_blocking must be a constant");
  } else if (nonBlocking) {
    return rewriter.notifyMatchFailure(
        op, "unimplemented: non_blocking is expected to be false");
  }

  SmallVector<int64_t> selfShape(selfType.getShape());
  SmallVector<int64_t> srcShape(srcType.getShape());

  if (llvm::equal(selfShape, srcShape) || selfShape.size() == 0) {
    // If we reach here, then it means the given case is handled by implicit
    // broadcasting done by tosa.
    Value result;
    if (failed(tosa::tosaCastTensorToType(
            rewriter, op, adaptor.src(),
            getTypeConverter()->convertType(op.getType()), result)))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: cast to result type not supported");
    rewriter.replaceOp(op, result);
    return success();
  }
  return rewriter.notifyMatchFailure(
      op, "unimplemented: valsem.aten.copy op not supported for this case.");
}

// Legalizes the torch.aten.to.dtype op
template <>
LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
    AtenToDtypeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // Not a tensor type.
  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
  if (!selfType || !selfType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        op, "Only tensor types with static shape are supported");

  // The non_blocking arg should be a constant `False`.
  bool nonBlocking;
  if (!matchPattern(op.non_blocking(), m_TorchConstantBool(&nonBlocking))) {
    return rewriter.notifyMatchFailure(
        op, "unimplemented: non_blocking arg must be a constant");
  } else if (nonBlocking) {
    return rewriter.notifyMatchFailure(
        op, "unimplemented: non_blocking arg is expected to be false");
  }

  // The copy arg should be a constant `False`.
  bool copy;
  if (!matchPattern(op.copy(), m_TorchConstantBool(&copy))) {
    return rewriter.notifyMatchFailure(
        op, "unimplemented: copy arg must be a constant");
  } else if (copy) {
    return rewriter.notifyMatchFailure(
        op, "unimplemented: copy arg is expected to be false");
  }

  // Only `none`, `contiguous` and `preserve` memory_format values are
  // supported.
  if (!op.memory_format().getType().isa<Torch::NoneType>()) {
    int64_t memoryFormat;
    if (!matchPattern(op.memory_format(), m_TorchConstantInt(&memoryFormat)))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: the memory format should be specified in "
              "an integer constant");
    if (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
        memoryFormat != torch_upstream::MemoryFormat::Preserve)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only none, contiguous and preserve "
              "memory_format is supported");
  }

  auto resultTy = getTypeConverter()
                      ->convertType(op.getResult().getType())
                      .cast<RankedTensorType>();

  Value result;
  if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.self(), resultTy,
                                        result)))
    return rewriter.notifyMatchFailure(op, "conversion to result type failed");
  rewriter.replaceOp(op, result);

  return success();
}

template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
public:

@@ -3728,6 +3832,8 @@ public:
    INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
    INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
    INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
    INSERT_ATENOP_PATTERN(ValsemVariantAtenCopyOp);
    INSERT_ATENOP_PATTERN(AtenToDtypeOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
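
The two patterns registered above only fire for statically shaped tensors and for `non_blocking = False`; the `aten.to.dtype` pattern additionally requires `copy = False` and a `none`, contiguous, or preserve `memory_format`. A minimal module that stays inside those constraints (illustrative only; the class name is made up and not part of the patch):

import torch

class CastToBool(torch.nn.Module):
    # A static 3x5 input and default keyword arguments (non_blocking=False,
    # copy=False, memory_format left unset) keep this within the cases the
    # AtenToDtypeOp pattern accepts.
    def forward(self, x):
        return x.to(dtype=torch.bool)

print(CastToBool()(torch.arange(15, dtype=torch.int64).reshape(3, 5)))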


@@ -221,6 +221,64 @@ llvm::Optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
  return const_op.getResult();
}

static LogicalResult checkValidityOfCast(Type src, Type dest) {
  if ((src.isInteger(64) && dest.isInteger(32)) ||
      (src.isInteger(32) && dest.isInteger(64)) ||
      (src.isInteger(64) && dest.isInteger(1)) ||
      (src.isInteger(32) && dest.isInteger(1)) ||
      (src.isInteger(8) && dest.isInteger(1)) ||
      (src.isF32() && dest.isInteger(1))) {
    return success();
  }
  return failure();
}

// Casts `src` to the element type of `destType`. Casts to i1 are emitted as a
// tosa.equal against a zero constant followed by tosa.logical_not; all other
// supported casts become a single tosa.cast.
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
                                   Value src, Type destType, Value &result) {

  Type srcElemTy = src.getType().dyn_cast<TensorType>().getElementType();
  Type destElemTy = destType.dyn_cast<TensorType>().getElementType();

  if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
    return rewriter.notifyMatchFailure(
        op, "casting to result dtype is invalid or unsupported");

  if (destElemTy.isInteger(1)) {
    auto srcType = src.getType().dyn_cast<TensorType>();
    SmallVector<int64_t> srcShape(srcType.getShape());
    uint64_t num_total_elements = 1;
    for (int64_t a : srcShape)
      num_total_elements *= a;

    llvm::Optional<Value> constOp;
    if (srcElemTy.isInteger(64)) {
      SmallVector<int64_t> values(num_total_elements, 0);
      constOp =
          tosa::getConstTensor<int64_t>(rewriter, op, values, srcShape).value();
    } else if (srcElemTy.isInteger(32)) {
      SmallVector<int32_t> values(num_total_elements, 0);
      constOp =
          tosa::getConstTensor<int32_t>(rewriter, op, values, srcShape).value();
    } else if (srcElemTy.isF32()) {
      SmallVector<float> values(num_total_elements, 0.0);
      constOp =
          tosa::getConstTensor<float>(rewriter, op, values, srcShape).value();
    } else if (srcElemTy.isInteger(8)) {
      SmallVector<int8_t> values(num_total_elements, 0);
      constOp =
          tosa::getConstTensor<int8_t>(rewriter, op, values, srcShape).value();
    }
    Value equalToZero = rewriter.create<tosa::EqualOp>(op->getLoc(), destType,
                                                       src, constOp.value());
    result = rewriter.create<tosa::LogicalNotOp>(op->getLoc(), destType,
                                                 equalToZero);
  } else {
    result = rewriter.create<tosa::CastOp>(op->getLoc(), destType, src);
  }
  return success();
}

// Template instantiation
template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter &,
                                                       Operation *,
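
For casts to `i1`, `tosaCastTensorToType` materializes an all-zero constant of the source shape and emits `tosa.equal` followed by `tosa.logical_not`; every other cast accepted by `checkValidityOfCast` becomes a single `tosa.cast`. A rough PyTorch-level sketch of that behavior (illustrative only, not the C++ API):

import torch

def cast_like_tosa(src: torch.Tensor, dest_dtype: torch.dtype) -> torch.Tensor:
    """Mimics the i1 special case of tosaCastTensorToType."""
    if dest_dtype == torch.bool:
        zeros = torch.zeros_like(src)           # tosa.const of zeros, same shape
        return torch.logical_not(src == zeros)  # tosa.equal + tosa.logical_not
    return src.to(dest_dtype)                   # plain tosa.cast otherwise

x = torch.tensor([[0, -2], [5, 0]], dtype=torch.int32)
assert torch.equal(cast_like_tosa(x, torch.bool), x.to(torch.bool))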


@@ -2200,8 +2200,9 @@ public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(Aten_ToCopyOp op,
                                PatternRewriter &rewriter) const override {
    Value zero = rewriter.create<ConstantFloatOp>(
        op.getLoc(), rewriter.getF64FloatAttr(0.0));
    Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
    Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
                                                   resultDtype);
    Value emptyTensor = rewriter.create<AtenFullLikeOp>(
        op.getLoc(), op.getType(), op.self(), zero, op.dtype(), op.layout(),
        op.device(), op.pin_memory(), op.memory_format());
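
This decomposition rewrites `aten._to_copy` into `aten.full_like` (filled with zero) followed by a copy; the change here creates that zero with the result dtype, via the new `getConstantWithGivenDtypeAndValue` helper, instead of always emitting a float constant. A rough Python-level sketch of the decomposed computation (illustrative, not the C++ pattern itself):

import torch

def decompose_to_copy(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # The zero fill value is interpreted in the *result* dtype, mirroring
    # getConstantWithGivenDtypeAndValue above.
    filled = torch.full_like(x, 0, dtype=dtype)  # aten.full_like
    return filled.copy_(x)                       # valsem.aten.copy

x = torch.tensor([[0, 2], [3, 0]], dtype=torch.uint8)
assert torch.equal(decompose_to_copy(x, torch.bool),
                   torch.ops.aten._to_copy(x, dtype=torch.bool))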


@@ -163,3 +163,18 @@ bool Torch::isViewLikeOp(Operation *op) {
             TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
             AtenNarrowOp, AtenToDeviceOp>(op);
}

Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
                                               Location loc, float value,
                                               Type dtype) {
  // Creating constants satisfying backend contract.
  if (dtype.isInteger(64) || dtype.isInteger(32) || dtype.isInteger(8) ||
      dtype.isInteger(1))
    return rewriter.create<ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr((int64_t)value));
  if (dtype.isF64() || dtype.isF32() || dtype.isF16() || dtype.isBF16())
    return rewriter.create<ConstantFloatOp>(loc,
                                            rewriter.getF64FloatAttr(value));
  llvm::report_fatal_error(
      "unhandled type for getConstantWithGivenDtypeAndValue");
}


@@ -2513,6 +2513,25 @@ def ToCopyWithDTypeFalsePinMemoryModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 2, 4))


class ToCopyBoolDTypeStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1, 1, 5, 5], torch.uint8, True),
    ])
    def forward(self, x):
        return torch.ops.aten._to_copy(x, dtype=torch.bool)


@register_test_case(module_factory=lambda: ToCopyBoolDTypeStaticModule())
def ToCopyBoolDTypeStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(1, 1, 5, 5).to(dtype=torch.uint8))

# ==============================================================================
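
The new e2e test feeds a statically shaped `uint8` tensor through `aten._to_copy(dtype=torch.bool)` and compares the TOSA result against eager PyTorch, where the expected behavior is simply non-zero -> True. For example (plain PyTorch):

import torch

x = torch.tensor([[0, 1, 2, 0, 255]], dtype=torch.uint8)
print(torch.ops.aten._to_copy(x, dtype=torch.bool))
# tensor([[False,  True,  True, False,  True]])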


@@ -193,13 +193,13 @@ def ToDtypeLayoutStridedModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5))


class ToDtypeBoolLayoutNoneModule(torch.nn.Module):
class ToDtypeBoolLayoutNoneStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1, -1], torch.float32, True)])
    @annotate_args([None, ([3, 5], torch.int64, True)])
    def forward(self, x):
        return torch.ops.aten.to(x,
                                 dtype=torch.bool,

@@ -211,9 +211,9 @@ class ToDtypeBoolLayoutNoneModule(torch.nn.Module):
                                 memory_format=None)


@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneModule())
def ToDtypeBoolLayoutNoneModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5))
@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneStaticModule())
def ToDtypeBoolLayoutNoneStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 5))

class TypeAsSameModule(torch.nn.Module):


@@ -859,3 +859,57 @@ func.func @torch.prim.NumToTensor.Scalar() -> !torch.vtensor<[],si64> {
  %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
  return %0 : !torch.vtensor<[],si64>
}

// -----
// CHECK-LABEL: func.func @torch.valsem.aten.copy(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> {
// CHECK: %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,1,5,5],ui8> -> tensor<1x1x5x5xi8>
// CHECK: %[[CST5:.*]] = torch.constant.int 5
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[CST11:.*]] = torch.constant.int 11
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
// CHECK: %[[VAL_1:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
// CHECK: %[[VAL_2:.*]] = "tosa.equal"(%[[VAL_0]], %[[VAL_1]]) : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK: %[[VAL_3:.*]] = "tosa.logical_not"(%[[VAL_2]]) : (tensor<i1>) -> tensor<i1>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<0> : tensor<1x1x5x5xi8>} : () -> tensor<1x1x5x5xi8>
// CHECK: %[[VAL_5:.*]] = "tosa.equal"(%[[INP]], %[[VAL_4]]) : (tensor<1x1x5x5xi8>, tensor<1x1x5x5xi8>) -> tensor<1x1x5x5xi1>
// CHECK: %[[VAL_6:.*]] = "tosa.logical_not"(%[[VAL_5]]) : (tensor<1x1x5x5xi1>) -> tensor<1x1x5x5xi1>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x1x5x5xi1> -> !torch.vtensor<[1,1,5,5],i1>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,1,5,5],i1>
func.func @torch.valsem.aten.copy(%arg0: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> {
  %int5 = torch.constant.int 5
  %int1 = torch.constant.int 1
  %int11 = torch.constant.int 11
  %none = torch.constant.none
  %false = torch.constant.bool false
  %int0 = torch.constant.int 0
  %0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
  %1 = torch.aten.to.dtype %0, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
  %2 = torch.prim.ListConstruct %int1, %int1, %int5, %int5 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.aten.broadcast_to %1, %2 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,5,5],i1>
  %4 = torch.valsem.aten.copy %3, %arg0, %false : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,1,5,5],ui8>, !torch.bool -> !torch.vtensor<[1,1,5,5],i1>
  return %4 : !torch.vtensor<[1,1,5,5],i1>
}

// -----
// CHECK-LABEL: func.func @torch.aten.to.dtype(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> {
// CHECK: %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[3,5],si64> -> tensor<3x5xi64>
// CHECK: %[[CST11:.*]] = torch.constant.int 11
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<0> : tensor<3x5xi64>} : () -> tensor<3x5xi64>
// CHECK: %[[VAL_1:.*]] = "tosa.equal"(%[[INP]], %[[VAL_0]]) : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1>
// CHECK: %[[VAL_2:.*]] = "tosa.logical_not"(%[[VAL_1]]) : (tensor<3x5xi1>) -> tensor<3x5xi1>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[3,5],i1>
func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> {
  %int11 = torch.constant.int 11
  %none = torch.constant.none
  %false = torch.constant.bool false
  %0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[3,5],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],i1>
  return %0 : !torch.vtensor<[3,5],i1>
}
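
The CHECK lines pin the lowering to a zero `tosa.const`, a `tosa.equal`, and a `tosa.logical_not`. That sequence can be sanity-checked against `aten.to.dtype` numerically (a sketch, not part of the lit test):

import torch

x = torch.randint(-3, 3, (3, 5), dtype=torch.int64)

# Expected lowering from the CHECK lines: logical_not(equal(x, const<0>)).
lowered = torch.logical_not(torch.eq(x, torch.zeros_like(x)))
reference = x.to(dtype=torch.bool)

assert torch.equal(lowered, reference)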