[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op (#2340)

[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op and configure crashing e2e sets for stablehlo backend.
pull/2356/head
Jiawei Wu 2023-07-29 21:55:49 +08:00 committed by GitHub
parent 0109bf705b
commit 16923fdbd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 95 additions and 1 deletions

View File

@ -31,6 +31,7 @@ from .xfail_sets import (
LINALG_XFAIL_SET,
MAKE_FX_TOSA_PASS_SET,
STABLEHLO_PASS_SET,
STABLEHLO_CRASHING_SET,
TOSA_PASS_SET,
LTC_XFAIL_SET,
TORCHDYNAMO_XFAIL_SET,
@ -101,7 +102,7 @@ def main():
elif args.config == "stablehlo":
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
crashing_set = set()
crashing_set = STABLEHLO_CRASHING_SET
elif args.config == "native_torch":
config = NativeTorchTestConfig()
xfail_set = set()

View File

@ -310,6 +310,32 @@ TORCHDYNAMO_CRASHING_SET = {
}
STABLEHLO_PASS_SET = {
"AddIntModule_basic",
"AtenIntBoolOpModule_basic",
"AtenIntTensorByteDtypeModule_basic",
"AtenIntTensorCharDtypeModule_basic",
"BoolFloatFalseModule_basic",
"BoolFloatTrueModule_basic",
"BoolIntFalseModule_basic",
"BoolIntTrueModule_basic",
"CeilFloatModule_basic",
"DivFloatModule_basic",
"DivIntModule_basic",
"EqIntModule_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
"GeIntModule_basic",
"GtFloatIntModule_basic",
"GtIntModule_basic",
"MulIntModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
"SqrtIntModule_basic",
"SubFloatModule_basic",
"SubIntModule_basic",
"TensorToBoolZeroRank_basic",
"TensorToIntZeroRank_basic",
"TensorToFloatZeroRank_basic",
"AliasModule_basic",
"TensorIntModule_basic",
"AllBoolFalseModule_basic",
@ -826,6 +852,18 @@ STABLEHLO_PASS_SET = {
"TupleModule_basic",
}
STABLEHLO_CRASHING_SET = {
# These e2e tests crash because mlir-hlo's shape-component-analysis currently
# supports only a single index operand on a tensor::ExtractOp when the
# extracted value feeds into some tensor's shape. REF:
# https://github.com/tensorflow/mlir-hlo/blob/master/mhlo/analysis/shape_component_analysis.cc#L586
# FIXME: remove these entries once upstream mlir-hlo fixes this limitation.
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"Aten_EmbeddingBagExample_basic"
}
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {

View File

@ -232,6 +232,51 @@ public:
} // namespace
namespace {
// Casts a tensor of exactly one element to an elemental (scalar) type.
// Much of this is borrowed from
// `lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp`.
template <typename AtenOpT>
class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;

  // Lowers aten.{Int,Float,Bool}.Tensor by runtime-checking that every
  // dimension of the input is 1, extracting the single element at index
  // (0, ..., 0), and converting it to the converted scalar result type.
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputType =
        adaptor.getA().getType().template dyn_cast<RankedTensorType>();
    // Bail out of the match on a non-ranked-tensor operand. The original
    // code only emitted a diagnostic here and then fell through, so the
    // pattern would incorrectly report success(); fail the match instead.
    if (!inputType)
      return rewriter.notifyMatchFailure(
          op, "only Tensor types supported in StableHLO");

    Location loc = op.getLoc();
    Value input = adaptor.getA();
    SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
    int64_t inputRank = inputSizes.size();
    Type inputDtype =
        op.getA().getType().template cast<BaseTensorType>().getDtype();

    // Runtime-assert the tensor holds exactly one element (all dims == 1).
    Value constantOne =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
    for (int64_t i = 0; i < inputRank; i++)
      checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);

    // Extract the single element at index (0, ..., 0).
    Value constantZero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    SmallVector<Value> indices(inputRank, constantZero);
    Value result = rewriter.create<tensor::ExtractOp>(loc, input, indices);

    // Convert the extracted element to the converted result dtype. (The
    // previously-computed-but-unused `outType` local has been removed; this
    // is the single place the result type is needed.)
    Type resultType =
        this->getTypeConverter()->convertType(op->getResult(0).getType());
    rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result,
                                                resultType, inputDtype));
    return success();
  }
};
} // namespace
// The binary broadcast patterns
namespace {
template <typename AtenOpT, typename ChloOpT>
@ -1662,6 +1707,16 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
#undef INSERT_CONSTANT_FILL_PATTERN
#define INSERT_TENSOR_TO_SCALAR_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenOp>>(typeConverter, \
context)
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenIntTensorOp);
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenFloatTensorOp);
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenBoolTensorOp);
#undef INSERT_TENSOR_TO_SCALAR_PATTERN
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context)