mirror of https://github.com/llvm/torch-mlir
[tosa] Add lowering for aten.to.dtype and aten._to_copy op
This commit adds the TorchToTosa lowering for `aten.to.dtype` and `aten._to_copy` op. Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>pull/1467/head snapshot-20221006.618
parent
56f9a9b5de
commit
d3cc3f1aff
|
@ -464,6 +464,9 @@ TOSA_PASS_SET = {
|
|||
"ArangeStartNegativeStepIntModule_basic",
|
||||
"ArangeZeroElementOutputModule_basic",
|
||||
"NumToTensorIntModule_basic",
|
||||
"ToDtypeBoolLayoutNoneStaticModule_basic",
|
||||
"ToCopyBoolDTypeStaticModule_basic",
|
||||
"HardTanhIntModule_basic",
|
||||
}
|
||||
|
||||
LTC_XFAIL_SET = {
|
||||
|
|
|
@ -53,6 +53,9 @@ template <typename T>
|
|||
llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<T> vec, ArrayRef<int64_t> shape);
|
||||
|
||||
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
|
||||
Value src, Type destType, Value &result);
|
||||
|
||||
// Creates a TOSA operation and performs shape inference on the individual
|
||||
// op. This allows shape inference during the framework to TOSA lowering.
|
||||
template <typename TosaOp, typename... Args>
|
||||
|
|
|
@ -52,6 +52,9 @@ int getTensorRank(Value tensor);
|
|||
|
||||
bool isViewLikeOp(Operation *op);
|
||||
|
||||
Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc,
|
||||
float value, Type dtype);
|
||||
|
||||
} // namespace Torch
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -3078,6 +3078,110 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<ValsemVariantAtenCopyOp>::matchAndRewrite(
|
||||
ValsemVariantAtenCopyOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
|
||||
auto srcType = adaptor.src().getType().dyn_cast<TensorType>();
|
||||
if (!selfType || !selfType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types with static shape are supported");
|
||||
|
||||
if (!srcType || !srcType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types with static shape are supported");
|
||||
|
||||
// The non_blocking should be a constant `False`.
|
||||
bool nonBlocking;
|
||||
if (!matchPattern(op.non_blocking(), m_TorchConstantBool(&nonBlocking))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: non_blocking must be a constant");
|
||||
} else if (nonBlocking) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: non_blocking is expected to be false");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> selfShape(selfType.getShape());
|
||||
SmallVector<int64_t> srcShape(srcType.getShape());
|
||||
|
||||
if (llvm::equal(selfShape, srcShape) || selfShape.size() == 0) {
|
||||
// If we reach here, then it means the given case is handled by implicit
|
||||
// broadcasting done by tosa.
|
||||
Value result;
|
||||
if (failed(tosa::tosaCastTensorToType(
|
||||
rewriter, op, adaptor.src(),
|
||||
getTypeConverter()->convertType(op.getType()), result)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: cast to result type not supported");
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: valsem.aten.copy op not supported for this case.");
|
||||
}
|
||||
|
||||
// Legalizes the torch.aten.to.dtype op
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
|
||||
AtenToDtypeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
|
||||
if (!selfType || !selfType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types with static shape are supported");
|
||||
|
||||
// The non_blocking arg should be a constant `False`.
|
||||
bool nonBlocking;
|
||||
if (!matchPattern(op.non_blocking(), m_TorchConstantBool(&nonBlocking))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: non_blocking arg must be a constant");
|
||||
} else if (nonBlocking) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: non_blocking arg is expected to be false");
|
||||
}
|
||||
|
||||
// The copy arg should be a constant `False`.
|
||||
bool copy;
|
||||
if (!matchPattern(op.copy(), m_TorchConstantBool(©))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: copy arg must be a constant");
|
||||
} else if (copy) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: copy arg is expected to be false");
|
||||
}
|
||||
|
||||
// Only `none`, `contiguous` and `preserve` memory_format is supported.
|
||||
if (!op.memory_format().getType().isa<Torch::NoneType>()) {
|
||||
int64_t memoryFormat;
|
||||
if (!matchPattern(op.memory_format(), m_TorchConstantInt(&memoryFormat)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: the memory format should be specified in "
|
||||
"an integer constant");
|
||||
if (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
|
||||
memoryFormat != torch_upstream::MemoryFormat::Preserve)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only none, contiguous and preserve "
|
||||
"memory_format is supported");
|
||||
}
|
||||
|
||||
auto resultTy = getTypeConverter()
|
||||
->convertType(op.getResult().getType())
|
||||
.cast<RankedTensorType>();
|
||||
|
||||
Value result;
|
||||
if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.self(), resultTy,
|
||||
result)))
|
||||
return rewriter.notifyMatchFailure(op, "conversion to result type failed");
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename AtenOpT, typename TosaOpT>
|
||||
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
|
@ -3728,6 +3832,8 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
|
||||
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
|
||||
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
|
||||
INSERT_ATENOP_PATTERN(ValsemVariantAtenCopyOp);
|
||||
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||
|
|
|
@ -221,6 +221,64 @@ llvm::Optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
|
|||
return const_op.getResult();
|
||||
}
|
||||
|
||||
static LogicalResult checkValidityOfCast(Type src, Type dest) {
|
||||
if ((src.isInteger(64) && dest.isInteger(32)) ||
|
||||
(src.isInteger(32) && dest.isInteger(64)) ||
|
||||
(src.isInteger(64) && dest.isInteger(1)) ||
|
||||
(src.isInteger(32) && dest.isInteger(1)) ||
|
||||
(src.isInteger(8) && dest.isInteger(1)) ||
|
||||
(src.isF32() && dest.isInteger(1))) {
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Template specialization for float
|
||||
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
|
||||
Value src, Type destType, Value &result) {
|
||||
|
||||
Type srcElemTy = src.getType().dyn_cast<TensorType>().getElementType();
|
||||
Type destElemTy = destType.dyn_cast<TensorType>().getElementType();
|
||||
|
||||
if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "casting to result dtype is invalid or unsupported");
|
||||
|
||||
if (destElemTy.isInteger(1)) {
|
||||
auto srcType = src.getType().dyn_cast<TensorType>();
|
||||
SmallVector<int64_t> srcShape(srcType.getShape());
|
||||
uint64_t num_total_elements = 1;
|
||||
for (int64_t a : srcShape)
|
||||
num_total_elements *= a;
|
||||
|
||||
llvm::Optional<Value> constOp;
|
||||
if (srcElemTy.isInteger(64)) {
|
||||
SmallVector<int64_t> values(num_total_elements, 0);
|
||||
constOp =
|
||||
tosa::getConstTensor<int64_t>(rewriter, op, values, srcShape).value();
|
||||
} else if (srcElemTy.isInteger(32)) {
|
||||
SmallVector<int32_t> values(num_total_elements, 0);
|
||||
constOp =
|
||||
tosa::getConstTensor<int32_t>(rewriter, op, values, srcShape).value();
|
||||
} else if (srcElemTy.isF32()) {
|
||||
SmallVector<float> values(num_total_elements, 0.0);
|
||||
constOp =
|
||||
tosa::getConstTensor<float>(rewriter, op, values, srcShape).value();
|
||||
} else if (srcElemTy.isInteger(8)) {
|
||||
SmallVector<int8_t> values(num_total_elements, 0);
|
||||
constOp =
|
||||
tosa::getConstTensor<int8_t>(rewriter, op, values, srcShape).value();
|
||||
}
|
||||
Value equalToZero = rewriter.create<tosa::EqualOp>(op->getLoc(), destType,
|
||||
src, constOp.value());
|
||||
result = rewriter.create<tosa::LogicalNotOp>(op->getLoc(), destType,
|
||||
equalToZero);
|
||||
} else {
|
||||
result = rewriter.create<tosa::CastOp>(op->getLoc(), destType, src);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// Template instantiation
|
||||
template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter &,
|
||||
Operation *,
|
||||
|
|
|
@ -2200,8 +2200,9 @@ public:
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(Aten_ToCopyOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value zero = rewriter.create<ConstantFloatOp>(
|
||||
op.getLoc(), rewriter.getF64FloatAttr(0.0));
|
||||
Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
|
||||
Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
|
||||
resultDtype);
|
||||
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
|
||||
op.getLoc(), op.getType(), op.self(), zero, op.dtype(), op.layout(),
|
||||
op.device(), op.pin_memory(), op.memory_format());
|
||||
|
|
|
@ -163,3 +163,18 @@ bool Torch::isViewLikeOp(Operation *op) {
|
|||
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
|
||||
AtenNarrowOp, AtenToDeviceOp>(op);
|
||||
}
|
||||
|
||||
Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
|
||||
Location loc, float value,
|
||||
Type dtype) {
|
||||
// Creating constants satisfying backend contract.
|
||||
if (dtype.isInteger(64) || dtype.isInteger(32) || dtype.isInteger(8) ||
|
||||
dtype.isInteger(1))
|
||||
return rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr((int64_t)value));
|
||||
if (dtype.isF64() || dtype.isF32() || dtype.isF16() || dtype.isBF16())
|
||||
return rewriter.create<ConstantFloatOp>(loc,
|
||||
rewriter.getF64FloatAttr(value));
|
||||
llvm::report_fatal_error(
|
||||
"unhandled type for getConstantWithGivenDtypeAndValue");
|
||||
}
|
||||
|
|
|
@ -2513,6 +2513,25 @@ def ToCopyWithDTypeFalsePinMemoryModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(3, 2, 4))
|
||||
|
||||
|
||||
class ToCopyBoolDTypeStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1, 1, 5, 5], torch.uint8, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten._to_copy(x, dtype=torch.bool)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ToCopyBoolDTypeStaticModule())
|
||||
def ToCopyBoolDTypeStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(1, 1, 5, 5).to(dtype=torch.uint8))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
|
|
@ -193,13 +193,13 @@ def ToDtypeLayoutStridedModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
class ToDtypeBoolLayoutNoneModule(torch.nn.Module):
|
||||
class ToDtypeBoolLayoutNoneStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-1, -1], torch.float32, True)])
|
||||
@annotate_args([None, ([3, 5], torch.int64, True)])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.to(x,
|
||||
dtype=torch.bool,
|
||||
|
@ -211,9 +211,9 @@ class ToDtypeBoolLayoutNoneModule(torch.nn.Module):
|
|||
memory_format=None)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneModule())
|
||||
def ToDtypeBoolLayoutNoneModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneStaticModule())
|
||||
def ToDtypeBoolLayoutNoneStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 5))
|
||||
|
||||
|
||||
class TypeAsSameModule(torch.nn.Module):
|
||||
|
|
|
@ -859,3 +859,57 @@ func.func @torch.prim.NumToTensor.Scalar() -> !torch.vtensor<[],si64> {
|
|||
%0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
|
||||
return %0 : !torch.vtensor<[],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.valsem.aten.copy(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> {
|
||||
// CHECK: %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,1,5,5],ui8> -> tensor<1x1x5x5xi8>
|
||||
// CHECK: %[[CST5:.*]] = torch.constant.int 5
|
||||
// CHECK: %[[CST1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[CST11:.*]] = torch.constant.int 11
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK: %[[VAL_1:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.equal"(%[[VAL_0]], %[[VAL_1]]) : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.logical_not"(%[[VAL_2]]) : (tensor<i1>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<0> : tensor<1x1x5x5xi8>} : () -> tensor<1x1x5x5xi8>
|
||||
// CHECK: %[[VAL_5:.*]] = "tosa.equal"(%[[INP]], %[[VAL_4]]) : (tensor<1x1x5x5xi8>, tensor<1x1x5x5xi8>) -> tensor<1x1x5x5xi1>
|
||||
// CHECK: %[[VAL_6:.*]] = "tosa.logical_not"(%[[VAL_5]]) : (tensor<1x1x5x5xi1>) -> tensor<1x1x5x5xi1>
|
||||
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x1x5x5xi1> -> !torch.vtensor<[1,1,5,5],i1>
|
||||
// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,1,5,5],i1>
|
||||
func.func @torch.valsem.aten.copy(%arg0: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> {
|
||||
%int5 = torch.constant.int 5
|
||||
%int1 = torch.constant.int 1
|
||||
%int11 = torch.constant.int 11
|
||||
%none = torch.constant.none
|
||||
%false = torch.constant.bool false
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
|
||||
%1 = torch.aten.to.dtype %0, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
|
||||
%2 = torch.prim.ListConstruct %int1, %int1, %int5, %int5 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
%3 = torch.aten.broadcast_to %1, %2 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,5,5],i1>
|
||||
%4 = torch.valsem.aten.copy %3, %arg0, %false : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,1,5,5],ui8>, !torch.bool -> !torch.vtensor<[1,1,5,5],i1>
|
||||
return %4 : !torch.vtensor<[1,1,5,5],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.to.dtype(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> {
|
||||
// CHECK: %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[3,5],si64> -> tensor<3x5xi64>
|
||||
// CHECK: %[[CST11:.*]] = torch.constant.int 11
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<0> : tensor<3x5xi64>} : () -> tensor<3x5xi64>
|
||||
// CHECK: %[[VAL_1:.*]] = "tosa.equal"(%[[INP]], %[[VAL_0]]) : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1>
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.logical_not"(%[[VAL_1]]) : (tensor<3x5xi1>) -> tensor<3x5xi1>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1>
|
||||
// CHECK: return %[[VAL_3]] : !torch.vtensor<[3,5],i1>
|
||||
func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> {
|
||||
%int11 = torch.constant.int 11
|
||||
%none = torch.constant.none
|
||||
%false = torch.constant.bool false
|
||||
%0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[3,5],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],i1>
|
||||
return %0 : !torch.vtensor<[3,5],i1>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue