[tosa] Support for Aten[Unsqueeze|Contiguous|Dropout|Reshape|View] ops (#700)

pull/702/head snapshot-20220325.346
Anup Gangwar 2022-03-25 16:15:07 -05:00 committed by GitHub
parent 6b637a9fd9
commit 5d7a6c2976
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 175 additions and 3 deletions

View File

@ -144,4 +144,12 @@ TOSA_PASS_SET = {
"SiluModule_basic",
"DropoutEvalIntModule_basic",
"DropoutEvalFloatModule_basic",
"ContiguousModule_basic",
"DropoutModule_basic",
"ViewExpandModule_basic",
"ViewCollapseInferredDimModule_basic",
"ViewExpandInferredDimModule_basic",
"ViewNoChangeStaticModule_basic",
"UnsafeViewExpandModule_basic",
"ReshapeCollapseModule_basic",
}

View File

@ -2329,6 +2329,116 @@ LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
AtenUnsqueezeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a tensor type.
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
if (!selfType) {
return op.emitError("Only tensor types are currently supported");
}
auto selfRank = selfType.getRank();
auto selfElemTy = selfType.getElementType();
if (!selfElemTy.isIntOrFloat()) {
return op.emitError(
"Only floating-point or integer datatype legalization supported");
}
int64_t dim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
return op->emitError("dim must be a Scalar constant");
dim = toPositiveDim(dim, selfRank);
if (!isValidDim(dim, selfRank))
return op.emitError("dim is statically invalid");
SmallVector<int64_t> outShape;
for (auto en : llvm::enumerate(selfType.getShape())) {
if (static_cast<int64_t>(en.index()) == dim)
outShape.push_back(1);
outShape.push_back(en.value());
}
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
rewriter.getI64ArrayAttr(outShape));
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
AtenContiguousOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a tensor type.
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
if (!selfType)
return op.emitError("Only tensor types are currently supported");
// FIXME: memory_format is not handled.
rewriter.replaceOp(op, adaptor.self());
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenDropoutOp>::matchAndRewrite(
AtenDropoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a tensor type.
auto selfType = adaptor.input().getType().dyn_cast<TensorType>();
if (!selfType)
return op.emitError("Only tensor types are currently supported");
// FIXME: train and p are not handled.
bool train;
if (!matchPattern(op.train(), m_TorchConstantBool(&train)))
op.emitError("train must be a Scalar constant");
if (train)
op.emitError("train must be false");
rewriter.replaceOpWithNewOp<tosa::CastOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.input());
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
AtenViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a tensor type.
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
if (!selfType)
return op.emitError("Only tensor types are currently supported");
auto selfElemTy = selfType.getElementType();
if (!selfElemTy.isIntOrFloat()) {
return op.emitError(
"Only floating-point or integer datatype legalization supported");
}
SmallVector<int64_t> outShape;
if (!matchPattern(op.size(), m_TorchConstantIntList(outShape)))
return op.emitError("size must consist of Scalar constants");
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
rewriter.getI64ArrayAttr(outShape));
return success();
}
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
public:
@ -2879,6 +2989,10 @@ public:
INSERT_ATENOP_PATTERN(AtenPermuteOp);
INSERT_ATENOP_PATTERN(AtenLog2Op);
INSERT_ATENOP_PATTERN(AtenThresholdOp);
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
INSERT_ATENOP_PATTERN(AtenContiguousOp);
INSERT_ATENOP_PATTERN(AtenDropoutOp);
INSERT_ATENOP_PATTERN(AtenViewOp);
#undef INSERT_ATENOP_PATTERN
if (failed(applyPartialConversion(getOperation(), target,

View File

@ -492,8 +492,8 @@ func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.v
// -----
// CHECK-LABEL: func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,4,3],f32>) -> !torch.vtensor<[10,4,3],f32> {
// CHECK-LABEL: func @torch.aten.native_batch_norm$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,4,3],f32>) -> !torch.vtensor<[10,4,3],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,4,3],f32> -> tensor<10x4x3xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.const"() {value = dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>} : () -> tensor<4xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.const"() {value = dense<[3.000000e+00, 2.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<4xf32>} : () -> tensor<4xf32>
@ -515,7 +515,7 @@ func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.v
// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32>
// CHECK: return %[[VAL_19]] : !torch.vtensor<[10,4,3],f32>
// CHECK: }
func @forward(%arg0: !torch.vtensor<[10,4,3],f32> ) -> !torch.vtensor<[10,4,3],f32> {
func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32> ) -> !torch.vtensor<[10,4,3],f32> {
%0 = torch.vtensor.literal(dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%1 = torch.vtensor.literal(dense<[3.000000e+00, 2.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%float1.000000e-01 = torch.constant.float 1.000000e-01
@ -689,6 +689,38 @@ func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
// -----
// CHECK-LABEL: func @torch.aten.unsqueeze$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,1,3],si32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [4, 1, 3]} : (tensor<4x3xi32>) -> tensor<4x1x3xi32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x1x3xi32> -> !torch.vtensor<[4,1,3],si32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[4,1,3],si32>
// CHECK: }
func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !torch.vtensor<[4,1,3],si32> {
%int1 = torch.constant.int 1
%0 = torch.aten.unsqueeze %arg0, %int1 : !torch.vtensor<[4,3],si32>, !torch.int -> !torch.vtensor<[4,1,3],si32>
return %0 : !torch.vtensor<[4,1,3],si32>
}
// -----
// CHECK-LABEL: func @torch.aten.contiguous$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
%int0 = torch.constant.int 0
%0 = torch.aten.contiguous %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_0:.*]] = torch.constant.int 4
// CHECK: %[[VAL_1:.*]] = torch.constant.int 3
@ -707,3 +739,21 @@ func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
%1 = torch.aten.ones %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],f32>
return %1 : !torch.vtensor<[3,4],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.dropout$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
// CHECK: %[[VAL_4:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
%float0.000000e00 = torch.constant.float 0.000000e+00
%false = torch.constant.bool false
%0 = torch.aten.dropout %arg0, %float0.000000e00, %false : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}