mirror of https://github.com/llvm/torch-mlir
[stablehlo] support aten.view.dtype lowering (#3778)
parent 94f5410913
commit d0041dc310
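
For context: aten.view.dtype is the overload behind torch.Tensor.view(dtype). It reinterprets
the tensor's underlying bytes as another element type without copying; when the target dtype is
narrower the last dimension is multiplied by the bit-width ratio, and when it is wider the last
dimension must equal that ratio and is dropped. A minimal PyTorch illustration (not part of the
patch):

    import torch

    x = torch.zeros(12, 1, dtype=torch.float32)
    y = x.view(torch.int8)     # reinterpret the underlying bytes, no copy
    print(y.shape)             # torch.Size([12, 4]): 32 bits / 8 bits = 4
    z = y.view(torch.float32)  # widening: last dim must equal 4, and it is dropped
    print(z.shape)             # torch.Size([12, 1])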

@@ -161,12 +161,70 @@ public:
   using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
   using OpAdaptor = typename AtenOpT::Adaptor;

+  unsigned getBitWidth(Type type) const {
+    if (auto complexTy = dyn_cast<ComplexType>(type))
+      return 2 * getBitWidth(complexTy.getElementType());
+    return type.getIntOrFloatBitWidth();
+  }
+
   LogicalResult
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto rankType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
     if (!rankType)
-      return op.emitError("Only ranked tensor types are currently supported");
+      return op.emitError("Only ranked tensor types are currently supported.");
+    auto loc = op.getLoc();
+
+    // support AtenViewDtypeOp
+    if (isa<AtenViewDtypeOp>(op)) {
+      auto self = adaptor.getSelf();
+      auto baseResultTy = dyn_cast<BaseTensorType>(op.getType());
+
+      // infer the result shape
+      auto operandElt = rankType.getElementType();
+      auto targetElt = baseResultTy.getDtype();
+      auto operandEltBitWidth = getBitWidth(operandElt);
+      auto targetEltBitWidth = getBitWidth(targetElt);
+      auto operandSizes = rankType.getShape();
+      SmallVector<int64_t> castShape(operandSizes);
+      if (operandEltBitWidth > targetEltBitWidth) {
+        int64_t last_size = operandEltBitWidth / targetEltBitWidth;
+        castShape.push_back(last_size);
+      } else if (operandEltBitWidth < targetEltBitWidth) {
+        int64_t last_size = targetEltBitWidth / operandEltBitWidth;
+        if (!ShapedType::isDynamic(castShape.back()) and
+            last_size != castShape.back()) {
+          return rewriter.notifyMatchFailure(
+              op, "The last dim size is not equal to targetEltBitWidth / "
+                  "operandEltBitWidth.");
+        } else {
+          castShape.pop_back();
+        }
+      }
+
+      auto resultType =
+          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+              baseResultTy);
+      if (!dyn_cast<ShapedType>(resultType).hasStaticShape()) {
+        return rewriter.notifyMatchFailure(
+            op, "Currently only support static output shape.");
+      }
+
+      auto castType =
+          baseResultTy.getWithSizesAndDtype(castShape, baseResultTy.getDtype());
+      auto cast = rewriter.create<stablehlo::BitcastConvertOp>(
+          loc,
+          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+              castType),
+          self);
+
+      auto reshape =
+          rewriter.create<stablehlo::ReshapeOp>(loc, resultType, cast);
+
+      rewriter.replaceOp(op, reshape);
+
+      return success();
+    }

     // collect Value of dims
     SmallVector<Value, 4> dimSizes;
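
The new AtenViewDtypeOp branch above lowers the op as stablehlo.bitcast_convert followed by
stablehlo.reshape: bitcast_convert changes the element type and, when the bit widths differ,
appends or removes an innermost dimension, and the reshape then produces the converted Torch
result type (the pattern bails out when that type is not static). A rough Python sketch of the
castShape arithmetic, using a hypothetical helper name; illustration only, assuming static
shapes, not the actual implementation:

    def infer_cast_shape(shape, operand_bits, target_bits):
        """Shape fed to bitcast_convert, mirroring the castShape logic above."""
        cast_shape = list(shape)
        if operand_bits > target_bits:
            # narrowing: append the bit-width ratio as a new innermost dimension
            cast_shape.append(operand_bits // target_bits)
        elif operand_bits < target_bits:
            # widening: the last dimension must equal the ratio and is folded away
            if cast_shape[-1] != target_bits // operand_bits:
                raise ValueError("last dim must equal target_bits // operand_bits")
            cast_shape.pop()
        return cast_shape

    # f32 [12, 1] viewed as i8: bitcast_convert to [12, 1, 4], then reshape to [12, 4]
    print(infer_cast_shape([12, 1], 32, 8))  # [12, 1, 4]

Because the dtype overload carries no size list and is fully handled in this branch, the
getAtenViewOpSizes specialization added further down simply returns false.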

@@ -174,7 +232,6 @@ public:
       return op.emitError("Dims size must be a list of Scalar");
     }

-    auto loc = op.getLoc();
     if (dimSizes.size() == 0 || rankType.getRank() == 0) {
       rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
           op,

@@ -236,6 +293,13 @@ public:
                                   SmallVector<Value, 4> &dimSizes) const;
 };

+template <>
+bool ConvertAtenViewOp<AtenViewDtypeOp>::getAtenViewOpSizes(
+    AtenViewDtypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
+    SmallVector<Value, 4> &dimSizes) const {
+  return false;
+}
+
 template <>
 bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
     AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,

@@ -496,6 +560,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
 #define INSERT_VIEW_OP_PATTERN(AtenOp)                                         \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context, options)
+  INSERT_VIEW_OP_PATTERN(AtenViewDtypeOp);
   INSERT_VIEW_OP_PATTERN(AtenViewOp);
   INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
 #undef INSERT_VIEW_OP_PATTERN

@@ -506,6 +506,7 @@ FX_IMPORTER_XFAIL_SET = {
     "UpSampleNearest2dDynamicFactor_basic",
     "ViewCollapseDynamicWithAtenSizeIntModule_basic",
     "ViewSizeFromOtherTensor_basic",
+    "ViewDtypeStaticModule_basic",
     "WeightNormInterfaceModule_basic",
     # Error: `aten.as_strided` op is not supported
     "ChunkListUnpackDynamic_Module_basic",

@@ -3169,6 +3170,7 @@ ONNX_XFAIL_SET = {
     "Unfold_Module_Rank_Zero_basic",
     "Unfold_Module_Rank_Zero_Size_Zero_basic",
     "Unfold_Module_Dynamic_basic",
+    "ViewDtypeStaticModule_basic",
 }

 if torch_version_for_comparison() < version.parse("2.3.0.dev"):

@@ -1174,6 +1174,30 @@ def ReshapeDynamicModule_basic(module, tu: TestUtils):
 # ==============================================================================


+class ViewDtypeStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([12, 1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        res = a.view(torch.int8)
+        return res
+
+
+@register_test_case(module_factory=lambda: ViewDtypeStaticModule())
+def ViewDtypeStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(12, 1))
+
+
+# ==============================================================================
+
+
 class ReshapeAliasCollapseModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
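
The new ViewDtypeStaticModule e2e test reinterprets a static [12, 1] float32 input as int8, so
the expected result shape is [12, 4]; the xfail entries above mark the same test as expected to
fail under the fx_importer and ONNX configurations. For intuition, the view exposes the raw
IEEE-754 bytes (illustration only; the byte order shown assumes a little-endian host):

    import torch

    a = torch.tensor([[1.0]], dtype=torch.float32)  # 1.0f is 0x3F800000
    print(a.view(torch.int8))                       # tensor([[0, 0, -128, 63]], dtype=torch.int8)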