mirror of https://github.com/llvm/torch-mlir
parent
396ab35c9d
commit
0cd95b5c68
|
@ -38,4 +38,7 @@ TOSA_PASS_SET = {
|
|||
"BoolTensorReturnTrueModule_basic",
|
||||
"BoolTensorReturnMixedModule_basic",
|
||||
"ElementwiseRsqrtModule_basic",
|
||||
"SqueezeModule_static",
|
||||
"SqueezeModule_noUnitDim",
|
||||
"SqueezeModule_allUnitDim",
|
||||
}
|
||||
|
|
|
@ -526,6 +526,41 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
|
||||
AtenSqueezeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
Value self = adaptor.self();
|
||||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
||||
|
||||
if (!selfTy)
|
||||
return op.emitError("Only ranked tensor types supported in TOSA argmax");
|
||||
|
||||
auto selfShape = selfTy.getShape();
|
||||
|
||||
SmallVector<int64_t> newOutputShape;
|
||||
for (auto &dim : selfShape) {
|
||||
if (dim != 1)
|
||||
newOutputShape.push_back(dim);
|
||||
}
|
||||
|
||||
auto resultTy = getTypeConverter()
|
||||
->convertType(op.getResult().getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto resultElemTy = resultTy.getElementType();
|
||||
|
||||
auto newOutputTy = RankedTensorType::get(newOutputShape, resultElemTy);
|
||||
|
||||
auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
|
||||
op->getLoc(), getTypeConverter()->convertType(newOutputTy), self,
|
||||
rewriter.getI64ArrayAttr(newOutputShape));
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
op, getTypeConverter()->convertType(newOutputTy), reshapeOp);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -624,6 +659,7 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenMulTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenDivTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSqueezeOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
|
|
Loading…
Reference in New Issue