[tosa] Support for Aten[Zeros|Ones|Fill_Scalar] ops (#604)

Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>

Co-authored-by: Anup Gangwar <anup.gangwar@arm.com>
pull/590/head
Anup Gangwar 2022-02-16 11:53:51 -06:00 committed by GitHub
parent 126dac3ded
commit c60468f141
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 158 additions and 10 deletions

View File

@ -106,4 +106,14 @@ TOSA_PASS_SET = {
"ElementwiseSubScalarIntModule_basic",
"ElementwiseAddScalarIntModule_basic",
"ElementwiseMulScalarModule_basic",
"ZerosModuleDefaultDtype_basic",
"ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic",
"ZerosModuleFloat2D_basic",
"ZerosModuleFloat3D_basic",
"ZerosModuleFalsePinMemory_basic",
"OnesModuleDefaultDtype_basic",
"OnesModuleInt_basic",
"OnesModuleFloat_basic",
"OnesModuleFalsePinMemory_basic",
}

View File

@ -222,11 +222,9 @@ public:
"Integers with widths greater than 32 are not supported");
}
auto outType =
static_cast<Type>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()))
.cast<TensorType>();
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
@ -347,11 +345,9 @@ public:
if (!lhsType)
return op.emitError("Only Tensor types supported in TOSA");
auto outType =
static_cast<Type>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()))
.cast<TensorType>();
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat())
@ -2604,6 +2600,92 @@ public:
}
};
// Ref: Error checking based on the Torch to LinAlg lowering
template <typename AtenOpT, int fillVal>
class ConvertAtenConstPatternOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();
if (!outType)
return op.emitError("Only Tensor types supported in TOSA");
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat())
return op.emitError(
"Only floating-point or integer datatype legalization supported");
// FIXME: Handle layout, device and pin_memory. Assume dtype has been
// processed to set output type correctly?
if (!op.layout().getType().template isa<Torch::NoneType>())
return op.emitError("Only default layout is supported");
bool pinMemory;
if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
(!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
pinMemory)) {
return op.emitError(
"Unsupported pin_memory, should be either None or false");
}
SmallVector<int64_t> shape;
if (!matchPattern(op.size(), m_TorchConstantIntList(shape))) {
return op.emitError("Shape must be a list of Scalar constants");
}
int64_t size = 1;
for (auto s : shape)
size *= s;
SmallVector<int32_t> values(size, fillVal);
auto constOp =
tosa::getConstTensor<int32_t>(rewriter, op, values, shape).getValue();
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);
return success();
}
};
template <typename AtenOpT>
class ConvertAtenFillScalarOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();
if (!outType || !outType.hasStaticShape())
return op.emitError(
"Only Tensor types with static shapes are currently supported");
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
return op.emitError(
"Only floating-point or integer datatype legalization supported");
}
Value constOp;
if (failed(torchScalarToTosaTensor(rewriter, op, op.value(), constOp,
outElemTy, outType.getShape())))
return op.emitError("Supplied value must be a Scalar constant");
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);
return success();
}
};
} // namespace
// -----------------------------------------------------------------------------
@ -2765,6 +2847,20 @@ public:
tosa::AvgPool2dOp);
#undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
context);
INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
#undef INSERT_CONSTANT_FILL_PATTERN
#define INSERT_FILL_SCALAR_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenFillScalarOp<AtenOp>>(typeConverter, context);
INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp);
#undef INSERT_FILL_SCALAR_PATTERN
#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);

View File

@ -665,3 +665,45 @@ func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor
%0 = torch.aten.log2 %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_0:.*]] = torch.constant.int 4
// CHECK: %[[VAL_1:.*]] = torch.constant.int 3
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<0> : tensor<3x4xi32>} : () -> tensor<3x4xi32>
// CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor<3x4xi32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
%int4 = torch.constant.int 4
%int3 = torch.constant.int 3
%none = torch.constant.none
%0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.aten.zeros %0, %none, %none, %none, %none : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],f32>
return %1 : !torch.vtensor<[3,4],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_0:.*]] = torch.constant.int 4
// CHECK: %[[VAL_1:.*]] = torch.constant.int 3
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<1> : tensor<3x4xi32>} : () -> tensor<3x4xi32>
// CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor<3x4xi32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
%int4 = torch.constant.int 4
%int3 = torch.constant.int 3
%none = torch.constant.none
%0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.aten.ones %0, %none, %none, %none, %none : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],f32>
return %1 : !torch.vtensor<[3,4],f32>
}