[tosa] Support for Torch.squeeze (#487)

pull/483/head snapshot-20211216.147
Suraj Sudhir 2021-12-15 21:40:29 -08:00 committed by GitHub
parent 396ab35c9d
commit 0cd95b5c68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 0 deletions

View File

@ -38,4 +38,7 @@ TOSA_PASS_SET = {
"BoolTensorReturnTrueModule_basic",
"BoolTensorReturnMixedModule_basic",
"ElementwiseRsqrtModule_basic",
"SqueezeModule_static",
"SqueezeModule_noUnitDim",
"SqueezeModule_allUnitDim",
}

View File

@ -526,6 +526,41 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
AtenSqueezeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.self();
auto selfTy = self.getType().template cast<RankedTensorType>();
if (!selfTy)
return op.emitError("Only ranked tensor types supported in TOSA argmax");
auto selfShape = selfTy.getShape();
SmallVector<int64_t> newOutputShape;
for (auto &dim : selfShape) {
if (dim != 1)
newOutputShape.push_back(dim);
}
auto resultTy = getTypeConverter()
->convertType(op.getResult().getType())
.cast<RankedTensorType>();
auto resultElemTy = resultTy.getElementType();
auto newOutputTy = RankedTensorType::get(newOutputShape, resultElemTy);
auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), getTypeConverter()->convertType(newOutputTy), self,
rewriter.getI64ArrayAttr(newOutputShape));
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(newOutputTy), reshapeOp);
return success();
}
} // namespace
// -----------------------------------------------------------------------------
@ -624,6 +659,7 @@ public:
INSERT_ATENOP_PATTERN(AtenMulTensorOp);
INSERT_ATENOP_PATTERN(AtenDivTensorOp);
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
INSERT_ATENOP_PATTERN(AtenSqueezeOp);
#undef INSERT_ATENOP_PATTERN
if (failed(applyPartialConversion(getOperation(), target,