mirror of https://github.com/llvm/torch-mlir
[stablehlo] support aten.view.dtype lowering (#3778)
parent
94f5410913
commit
d0041dc310
|
@ -161,12 +161,70 @@ public:
|
|||
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
|
||||
unsigned getBitWidth(Type type) const {
|
||||
if (auto complexTy = dyn_cast<ComplexType>(type))
|
||||
return 2 * getBitWidth(complexTy.getElementType());
|
||||
return type.getIntOrFloatBitWidth();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto rankType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||
if (!rankType)
|
||||
return op.emitError("Only ranked tensor types are currently supported");
|
||||
return op.emitError("Only ranked tensor types are currently supported.");
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// support AtenViewDtypeOp
|
||||
if (isa<AtenViewDtypeOp>(op)) {
|
||||
auto self = adaptor.getSelf();
|
||||
auto baseResultTy = dyn_cast<BaseTensorType>(op.getType());
|
||||
|
||||
// infer the result shape
|
||||
auto operandElt = rankType.getElementType();
|
||||
auto targetElt = baseResultTy.getDtype();
|
||||
auto operandEltBitWidth = getBitWidth(operandElt);
|
||||
auto targetEltBitWidth = getBitWidth(targetElt);
|
||||
auto operandSizes = rankType.getShape();
|
||||
SmallVector<int64_t> castShape(operandSizes);
|
||||
if (operandEltBitWidth > targetEltBitWidth) {
|
||||
int64_t last_size = operandEltBitWidth / targetEltBitWidth;
|
||||
castShape.push_back(last_size);
|
||||
} else if (operandEltBitWidth < targetEltBitWidth) {
|
||||
int64_t last_size = targetEltBitWidth / operandEltBitWidth;
|
||||
if (!ShapedType::isDynamic(castShape.back()) and
|
||||
last_size != castShape.back()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "The last dim size is not equal to targetEltBitWidth / "
|
||||
"operandEltBitWidth.");
|
||||
} else {
|
||||
castShape.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
auto resultType =
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
baseResultTy);
|
||||
if (!dyn_cast<ShapedType>(resultType).hasStaticShape()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Currently only support static output shape.");
|
||||
}
|
||||
|
||||
auto castType =
|
||||
baseResultTy.getWithSizesAndDtype(castShape, baseResultTy.getDtype());
|
||||
auto cast = rewriter.create<stablehlo::BitcastConvertOp>(
|
||||
loc,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
castType),
|
||||
self);
|
||||
|
||||
auto reshape =
|
||||
rewriter.create<stablehlo::ReshapeOp>(loc, resultType, cast);
|
||||
|
||||
rewriter.replaceOp(op, reshape);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// collect Value of dims
|
||||
SmallVector<Value, 4> dimSizes;
|
||||
|
@ -174,7 +232,6 @@ public:
|
|||
return op.emitError("Dims size must be a list of Scalar");
|
||||
}
|
||||
|
||||
auto loc = op.getLoc();
|
||||
if (dimSizes.size() == 0 || rankType.getRank() == 0) {
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
||||
op,
|
||||
|
@ -236,6 +293,13 @@ public:
|
|||
SmallVector<Value, 4> &dimSizes) const;
|
||||
};
|
||||
|
||||
template <>
|
||||
bool ConvertAtenViewOp<AtenViewDtypeOp>::getAtenViewOpSizes(
|
||||
AtenViewDtypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
|
||||
SmallVector<Value, 4> &dimSizes) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
|
||||
AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
|
||||
|
@ -496,6 +560,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
|||
#define INSERT_VIEW_OP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context, options)
|
||||
INSERT_VIEW_OP_PATTERN(AtenViewDtypeOp);
|
||||
INSERT_VIEW_OP_PATTERN(AtenViewOp);
|
||||
INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
|
||||
#undef INSERT_VIEW_OP_PATTERN
|
||||
|
|
|
@ -506,6 +506,7 @@ FX_IMPORTER_XFAIL_SET = {
|
|||
"UpSampleNearest2dDynamicFactor_basic",
|
||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
"ViewDtypeStaticModule_basic",
|
||||
"WeightNormInterfaceModule_basic",
|
||||
# Error: `aten.as_strided` op is not supported
|
||||
"ChunkListUnpackDynamic_Module_basic",
|
||||
|
@ -3169,6 +3170,7 @@ ONNX_XFAIL_SET = {
|
|||
"Unfold_Module_Rank_Zero_basic",
|
||||
"Unfold_Module_Rank_Zero_Size_Zero_basic",
|
||||
"Unfold_Module_Dynamic_basic",
|
||||
"ViewDtypeStaticModule_basic",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() < version.parse("2.3.0.dev"):
|
||||
|
|
|
@ -1174,6 +1174,30 @@ def ReshapeDynamicModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewDtypeStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([12, 1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
res = a.view(torch.int8)
|
||||
return res
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewDtypeStaticModule())
|
||||
def ViewDtypeStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(12, 1))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReshapeAliasCollapseModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue