mirror of https://github.com/llvm/torch-mlir
[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op (#2340)
[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op and configure crashing e2e sets for stablehlo backend.pull/2356/head
parent
0109bf705b
commit
16923fdbd2
|
@ -31,6 +31,7 @@ from .xfail_sets import (
|
||||||
LINALG_XFAIL_SET,
|
LINALG_XFAIL_SET,
|
||||||
MAKE_FX_TOSA_PASS_SET,
|
MAKE_FX_TOSA_PASS_SET,
|
||||||
STABLEHLO_PASS_SET,
|
STABLEHLO_PASS_SET,
|
||||||
|
STABLEHLO_CRASHING_SET,
|
||||||
TOSA_PASS_SET,
|
TOSA_PASS_SET,
|
||||||
LTC_XFAIL_SET,
|
LTC_XFAIL_SET,
|
||||||
TORCHDYNAMO_XFAIL_SET,
|
TORCHDYNAMO_XFAIL_SET,
|
||||||
|
@ -101,7 +102,7 @@ def main():
|
||||||
elif args.config == "stablehlo":
|
elif args.config == "stablehlo":
|
||||||
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
|
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
|
||||||
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
|
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
|
||||||
crashing_set = set()
|
crashing_set = STABLEHLO_CRASHING_SET
|
||||||
elif args.config == "native_torch":
|
elif args.config == "native_torch":
|
||||||
config = NativeTorchTestConfig()
|
config = NativeTorchTestConfig()
|
||||||
xfail_set = set()
|
xfail_set = set()
|
||||||
|
|
|
@ -310,6 +310,32 @@ TORCHDYNAMO_CRASHING_SET = {
|
||||||
}
|
}
|
||||||
|
|
||||||
STABLEHLO_PASS_SET = {
|
STABLEHLO_PASS_SET = {
|
||||||
|
"AddIntModule_basic",
|
||||||
|
"AtenIntBoolOpModule_basic",
|
||||||
|
"AtenIntTensorByteDtypeModule_basic",
|
||||||
|
"AtenIntTensorCharDtypeModule_basic",
|
||||||
|
"BoolFloatFalseModule_basic",
|
||||||
|
"BoolFloatTrueModule_basic",
|
||||||
|
"BoolIntFalseModule_basic",
|
||||||
|
"BoolIntTrueModule_basic",
|
||||||
|
"CeilFloatModule_basic",
|
||||||
|
"DivFloatModule_basic",
|
||||||
|
"DivIntModule_basic",
|
||||||
|
"EqIntModule_basic",
|
||||||
|
"GeFloatIntModule_basic",
|
||||||
|
"GeFloatModule_basic",
|
||||||
|
"GeIntModule_basic",
|
||||||
|
"GtFloatIntModule_basic",
|
||||||
|
"GtIntModule_basic",
|
||||||
|
"MulIntModule_basic",
|
||||||
|
"NeFloatIntModule_basic",
|
||||||
|
"NeIntModule_basic",
|
||||||
|
"SqrtIntModule_basic",
|
||||||
|
"SubFloatModule_basic",
|
||||||
|
"SubIntModule_basic",
|
||||||
|
"TensorToBoolZeroRank_basic",
|
||||||
|
"TensorToIntZeroRank_basic",
|
||||||
|
"TensorToFloatZeroRank_basic",
|
||||||
"AliasModule_basic",
|
"AliasModule_basic",
|
||||||
"TensorIntModule_basic",
|
"TensorIntModule_basic",
|
||||||
"AllBoolFalseModule_basic",
|
"AllBoolFalseModule_basic",
|
||||||
|
@ -826,6 +852,18 @@ STABLEHLO_PASS_SET = {
|
||||||
"TupleModule_basic",
|
"TupleModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
STABLEHLO_CRASHING_SET = {
|
||||||
|
# These e2e tests crash because currently mlir-hlo's shape-component-analysis
|
||||||
|
# only support exact one index in tensor::ExtractOp when it's related with
|
||||||
|
# some tensors' shape. REF:
|
||||||
|
# https://github.com/tensorflow/mlir-hlo/blob/master/mhlo/analysis/shape_component_analysis.cc#L586
|
||||||
|
# FIXME if upstream mlir-hlo fix this.
|
||||||
|
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||||
|
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||||
|
|
||||||
|
"Aten_EmbeddingBagExample_basic"
|
||||||
|
}
|
||||||
|
|
||||||
# Write the TOSA set as a "passing" set as it is very early in development
|
# Write the TOSA set as a "passing" set as it is very early in development
|
||||||
# and very few tests work yet.
|
# and very few tests work yet.
|
||||||
TOSA_PASS_SET = {
|
TOSA_PASS_SET = {
|
||||||
|
|
|
@ -232,6 +232,51 @@ public:
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Casts a tensor of exactly one element to an elemental type.
|
||||||
|
// Many codes borrowed from
|
||||||
|
// `lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp`
|
||||||
|
template <typename AtenOpT>
|
||||||
|
class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<AtenOpT> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||||
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto inputType =
|
||||||
|
adaptor.getA().getType().template dyn_cast<RankedTensorType>();
|
||||||
|
if (!inputType)
|
||||||
|
|
||||||
|
op.emitError("only Tensor types supported in StableHLO");
|
||||||
|
auto outType =
|
||||||
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||||
|
op.getType());
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
Value input = adaptor.getA();
|
||||||
|
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
|
||||||
|
int64_t inputRank = inputSizes.size();
|
||||||
|
Type inputDtype =
|
||||||
|
op.getA().getType().template cast<BaseTensorType>().getDtype();
|
||||||
|
|
||||||
|
Value constantOne =
|
||||||
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
for (int64_t i = 0; i < inputRank; i++)
|
||||||
|
checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);
|
||||||
|
|
||||||
|
Value constantZero =
|
||||||
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
||||||
|
SmallVector<Value> indices(inputRank, constantZero);
|
||||||
|
Value result = rewriter.create<tensor::ExtractOp>(loc, input, indices);
|
||||||
|
Type resultType =
|
||||||
|
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
||||||
|
rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result,
|
||||||
|
resultType, inputDtype));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// The binary broadcast patterns
|
// The binary broadcast patterns
|
||||||
namespace {
|
namespace {
|
||||||
template <typename AtenOpT, typename ChloOpT>
|
template <typename AtenOpT, typename ChloOpT>
|
||||||
|
@ -1662,6 +1707,16 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||||
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
|
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
|
||||||
#undef INSERT_CONSTANT_FILL_PATTERN
|
#undef INSERT_CONSTANT_FILL_PATTERN
|
||||||
|
|
||||||
|
#define INSERT_TENSOR_TO_SCALAR_PATTERN(AtenOp) \
|
||||||
|
target.addIllegalOp<AtenOp>(); \
|
||||||
|
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenOp>>(typeConverter, \
|
||||||
|
context)
|
||||||
|
|
||||||
|
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenIntTensorOp);
|
||||||
|
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenFloatTensorOp);
|
||||||
|
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenBoolTensorOp);
|
||||||
|
#undef INSERT_TENSOR_TO_SCALAR_PATTERN
|
||||||
|
|
||||||
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \
|
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context)
|
patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context)
|
||||||
|
|
Loading…
Reference in New Issue