mirror of https://github.com/llvm/torch-mlir
parent
396ab35c9d
commit
0cd95b5c68
|
@ -38,4 +38,7 @@ TOSA_PASS_SET = {
|
||||||
"BoolTensorReturnTrueModule_basic",
|
"BoolTensorReturnTrueModule_basic",
|
||||||
"BoolTensorReturnMixedModule_basic",
|
"BoolTensorReturnMixedModule_basic",
|
||||||
"ElementwiseRsqrtModule_basic",
|
"ElementwiseRsqrtModule_basic",
|
||||||
|
"SqueezeModule_static",
|
||||||
|
"SqueezeModule_noUnitDim",
|
||||||
|
"SqueezeModule_allUnitDim",
|
||||||
}
|
}
|
||||||
|
|
|
@ -526,6 +526,41 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
|
||||||
|
AtenSqueezeOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
|
Value self = adaptor.self();
|
||||||
|
auto selfTy = self.getType().template cast<RankedTensorType>();
|
||||||
|
|
||||||
|
if (!selfTy)
|
||||||
|
return op.emitError("Only ranked tensor types supported in TOSA argmax");
|
||||||
|
|
||||||
|
auto selfShape = selfTy.getShape();
|
||||||
|
|
||||||
|
SmallVector<int64_t> newOutputShape;
|
||||||
|
for (auto &dim : selfShape) {
|
||||||
|
if (dim != 1)
|
||||||
|
newOutputShape.push_back(dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto resultTy = getTypeConverter()
|
||||||
|
->convertType(op.getResult().getType())
|
||||||
|
.cast<RankedTensorType>();
|
||||||
|
auto resultElemTy = resultTy.getElementType();
|
||||||
|
|
||||||
|
auto newOutputTy = RankedTensorType::get(newOutputShape, resultElemTy);
|
||||||
|
|
||||||
|
auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
|
||||||
|
op->getLoc(), getTypeConverter()->convertType(newOutputTy), self,
|
||||||
|
rewriter.getI64ArrayAttr(newOutputShape));
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||||
|
op, getTypeConverter()->convertType(newOutputTy), reshapeOp);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
@ -624,6 +659,7 @@ public:
|
||||||
INSERT_ATENOP_PATTERN(AtenMulTensorOp);
|
INSERT_ATENOP_PATTERN(AtenMulTensorOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenDivTensorOp);
|
INSERT_ATENOP_PATTERN(AtenDivTensorOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
|
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenSqueezeOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
|
|
Loading…
Reference in New Issue