mirror of https://github.com/llvm/torch-mlir
[tosa] Support for Aten[Zeros|Ones|Fill_Scalar] ops (#604)
Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>
Co-authored-by: Anup Gangwar <anup.gangwar@arm.com>

parent 126dac3ded
commit c60468f141
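For orientation, these are the eager-level constructs that produce the three ops being legalized (an illustrative sketch against the public libtorch C++ API, not part of the patch):

#include <torch/torch.h>

int main() {
  // aten::zeros and aten::ones lower through ConvertAtenConstPatternOp below.
  auto z = torch::zeros({3, 4});                // AtenZerosOp
  auto o = torch::ones({3, 4}, torch::kInt64);  // AtenOnesOp
  // aten::fill_.Scalar lowers through ConvertAtenFillScalarOp below.
  auto f = torch::empty({3, 4}).fill_(2.5);     // AtenFill_ScalarOp
  return 0;
}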
@@ -106,4 +106,14 @@ TOSA_PASS_SET = {
     "ElementwiseSubScalarIntModule_basic",
     "ElementwiseAddScalarIntModule_basic",
     "ElementwiseMulScalarModule_basic",
+    "ZerosModuleDefaultDtype_basic",
+    "ZerosModuleInt2D_basic",
+    "ZerosModuleInt3D_basic",
+    "ZerosModuleFloat2D_basic",
+    "ZerosModuleFloat3D_basic",
+    "ZerosModuleFalsePinMemory_basic",
+    "OnesModuleDefaultDtype_basic",
+    "OnesModuleInt_basic",
+    "OnesModuleFloat_basic",
+    "OnesModuleFalsePinMemory_basic",
 }
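TOSA_PASS_SET lists the end-to-end tests expected to pass when compiled through the TOSA backend; the ten new entries enable the zeros/ones e2e coverage for the lowerings added below.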
@@ -222,11 +222,9 @@ public:
           "Integers with widths greater than 32 are not supported");
     }

-    auto outType =
-        static_cast<Type>(
-            OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-                op.getType()))
-            .cast<TensorType>();
+    auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
+                       ->convertType(op.getType())
+                       .template cast<TensorType>();

     Type outElemTy = outType.getElementType();
     if (!outElemTy.isIntOrFloat()) {
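Here (and in the identical hunk below) the static_cast<Type> detour is replaced by a direct chained call; the new form needs the 'template' disambiguator because the call now happens on an expression whose type depends on AtenOpT. A minimal, self-contained illustration of that C++ rule (generic C++, not MLIR API):

struct Holder {
  template <typename U> U as() { return U{}; }
};

// 'h' has the dependent type T, so 'as' is a dependent name; without the
// 'template' keyword, 'h.as<int>' would parse as the comparison (h.as < int).
template <typename T>
int get(T h) {
  return h.template as<int>();
}

int main() { return get(Holder{}); }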
@@ -347,11 +345,9 @@ public:
     if (!lhsType)
       return op.emitError("Only Tensor types supported in TOSA");

-    auto outType =
-        static_cast<Type>(
-            OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-                op.getType()))
-            .cast<TensorType>();
+    auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
+                       ->convertType(op.getType())
+                       .template cast<TensorType>();

     Type outElemTy = outType.getElementType();
     if (!outElemTy.isIntOrFloat())
@@ -2604,6 +2600,92 @@ public:
   }
 };

+// Ref: Error checking based on the Torch to LinAlg lowering
+template <typename AtenOpT, int fillVal>
+class ConvertAtenConstPatternOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
+                       ->convertType(op.getType())
+                       .template dyn_cast<TensorType>();
+
+    if (!outType)
+      return op.emitError("Only Tensor types supported in TOSA");
+
+    Type outElemTy = outType.getElementType();
+    if (!outElemTy.isIntOrFloat())
+      return op.emitError(
+          "Only floating-point or integer datatype legalization supported");
+
+    // FIXME: Handle layout, device and pin_memory. Assume dtype has been
+    // processed to set output type correctly?
+    if (!op.layout().getType().template isa<Torch::NoneType>())
+      return op.emitError("Only default layout is supported");
+
+    bool pinMemory;
+    if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
+        (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
+         pinMemory)) {
+      return op.emitError(
+          "Unsupported pin_memory, should be either None or false");
+    }
+
+    SmallVector<int64_t> shape;
+    if (!matchPattern(op.size(), m_TorchConstantIntList(shape))) {
+      return op.emitError("Shape must be a list of Scalar constants");
+    }
+
+    int64_t size = 1;
+    for (auto s : shape)
+      size *= s;
+
+    SmallVector<int32_t> values(size, fillVal);
+    auto constOp =
+        tosa::getConstTensor<int32_t>(rewriter, op, values, shape).getValue();
+
+    rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);
+
+    return success();
+  }
+};
+
+template <typename AtenOpT>
+class ConvertAtenFillScalarOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
+                       ->convertType(op.getType())
+                       .template dyn_cast<TensorType>();
+
+    if (!outType || !outType.hasStaticShape())
+      return op.emitError(
+          "Only Tensor types with static shapes are currently supported");
+
+    Type outElemTy = outType.getElementType();
+    if (!outElemTy.isIntOrFloat()) {
+      return op.emitError(
+          "Only floating-point or integer datatype legalization supported");
+    }
+    Value constOp;
+    if (failed(torchScalarToTosaTensor(rewriter, op, op.value(), constOp,
+                                       outElemTy, outType.getShape())))
+      return op.emitError("Supplied value must be a Scalar constant");
+
+    rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);
+
+    return success();
+  }
+};
+
 } // namespace

 // -----------------------------------------------------------------------------
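Design note: ConvertAtenConstPatternOp materializes the fill value as an i32 splat (tosa::getConstTensor<int32_t>) and then inserts a tosa.cast to the requested element type. That is lossless here because fillVal is only ever 0 or 1 (see the registrations in the next hunk), values every supported integer and floating-point element type represents exactly.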
@@ -2765,6 +2847,20 @@ public:
                                        tosa::AvgPool2dOp);
 #undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN

+#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal)                          \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter,      \
+                                                           context);
+  INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
+  INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
+#undef INSERT_CONSTANT_FILL_PATTERN
+
+#define INSERT_FILL_SCALAR_PATTERN(AtenOp)                                     \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenFillScalarOp<AtenOp>>(typeConverter, context);
+  INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp);
+#undef INSERT_FILL_SCALAR_PATTERN
+
 #define INSERT_ATENOP_PATTERN(AtenOp)                                          \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
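Expanded by hand, the first new registration in the hunk above is equivalent to the following (an illustrative expansion only; the patch keeps the macro form):

// What INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1) expands to:
target.addIllegalOp<AtenOnesOp>();
patterns.add<ConvertAtenConstPatternOp<AtenOnesOp, 1>>(typeConverter, context);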
@@ -665,3 +665,45 @@ func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor
   %0 = torch.aten.log2 %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
+// CHECK: %[[VAL_0:.*]] = torch.constant.int 4
+// CHECK: %[[VAL_1:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<0> : tensor<3x4xi32>} : () -> tensor<3x4xi32>
+// CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor<3x4xi32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
+// CHECK: }
+func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
+  %int4 = torch.constant.int 4
+  %int3 = torch.constant.int 3
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %1 = torch.aten.zeros %0, %none, %none, %none, %none : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],f32>
+  return %1 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
+// CHECK: %[[VAL_0:.*]] = torch.constant.int 4
+// CHECK: %[[VAL_1:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<1> : tensor<3x4xi32>} : () -> tensor<3x4xi32>
+// CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor<3x4xi32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
+// CHECK: }
+func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
+  %int4 = torch.constant.int 4
+  %int3 = torch.constant.int 3
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+  %1 = torch.aten.ones %0, %none, %none, %none, %none : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],f32>
+  return %1 : !torch.vtensor<[3,4],f32>
+}
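Both FileCheck tests pin down the same lowering shape as the patterns above: a dense i32 splat materialized by tosa.const, followed by a tosa.cast to the f32 result type; the torch.prim.ListConstruct size list is read statically via matchPattern rather than lowered to TOSA ops.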