mirror of https://github.com/llvm/torch-mlir
[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op (#2340)
[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op and configure crashing e2e sets for stablehlo backend.pull/2356/head
parent
0109bf705b
commit
16923fdbd2
|
@ -31,6 +31,7 @@ from .xfail_sets import (
|
|||
LINALG_XFAIL_SET,
|
||||
MAKE_FX_TOSA_PASS_SET,
|
||||
STABLEHLO_PASS_SET,
|
||||
STABLEHLO_CRASHING_SET,
|
||||
TOSA_PASS_SET,
|
||||
LTC_XFAIL_SET,
|
||||
TORCHDYNAMO_XFAIL_SET,
|
||||
|
@ -101,7 +102,7 @@ def main():
|
|||
elif args.config == "stablehlo":
|
||||
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
|
||||
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
|
||||
crashing_set = set()
|
||||
crashing_set = STABLEHLO_CRASHING_SET
|
||||
elif args.config == "native_torch":
|
||||
config = NativeTorchTestConfig()
|
||||
xfail_set = set()
|
||||
|
|
|
@ -310,6 +310,32 @@ TORCHDYNAMO_CRASHING_SET = {
|
|||
}
|
||||
|
||||
STABLEHLO_PASS_SET = {
|
||||
"AddIntModule_basic",
|
||||
"AtenIntBoolOpModule_basic",
|
||||
"AtenIntTensorByteDtypeModule_basic",
|
||||
"AtenIntTensorCharDtypeModule_basic",
|
||||
"BoolFloatFalseModule_basic",
|
||||
"BoolFloatTrueModule_basic",
|
||||
"BoolIntFalseModule_basic",
|
||||
"BoolIntTrueModule_basic",
|
||||
"CeilFloatModule_basic",
|
||||
"DivFloatModule_basic",
|
||||
"DivIntModule_basic",
|
||||
"EqIntModule_basic",
|
||||
"GeFloatIntModule_basic",
|
||||
"GeFloatModule_basic",
|
||||
"GeIntModule_basic",
|
||||
"GtFloatIntModule_basic",
|
||||
"GtIntModule_basic",
|
||||
"MulIntModule_basic",
|
||||
"NeFloatIntModule_basic",
|
||||
"NeIntModule_basic",
|
||||
"SqrtIntModule_basic",
|
||||
"SubFloatModule_basic",
|
||||
"SubIntModule_basic",
|
||||
"TensorToBoolZeroRank_basic",
|
||||
"TensorToIntZeroRank_basic",
|
||||
"TensorToFloatZeroRank_basic",
|
||||
"AliasModule_basic",
|
||||
"TensorIntModule_basic",
|
||||
"AllBoolFalseModule_basic",
|
||||
|
@ -826,6 +852,18 @@ STABLEHLO_PASS_SET = {
|
|||
"TupleModule_basic",
|
||||
}
|
||||
|
||||
STABLEHLO_CRASHING_SET = {
|
||||
# These e2e tests crash because currently mlir-hlo's shape-component-analysis
|
||||
# only support exact one index in tensor::ExtractOp when it's related with
|
||||
# some tensors' shape. REF:
|
||||
# https://github.com/tensorflow/mlir-hlo/blob/master/mhlo/analysis/shape_component_analysis.cc#L586
|
||||
# FIXME if upstream mlir-hlo fix this.
|
||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
|
||||
"Aten_EmbeddingBagExample_basic"
|
||||
}
|
||||
|
||||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
|
|
|
@ -232,6 +232,51 @@ public:
|
|||
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Casts a tensor of exactly one element to an elemental type.
|
||||
// Many codes borrowed from
|
||||
// `lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp`
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto inputType =
|
||||
adaptor.getA().getType().template dyn_cast<RankedTensorType>();
|
||||
if (!inputType)
|
||||
|
||||
op.emitError("only Tensor types supported in StableHLO");
|
||||
auto outType =
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType());
|
||||
Location loc = op.getLoc();
|
||||
Value input = adaptor.getA();
|
||||
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
|
||||
int64_t inputRank = inputSizes.size();
|
||||
Type inputDtype =
|
||||
op.getA().getType().template cast<BaseTensorType>().getDtype();
|
||||
|
||||
Value constantOne =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
for (int64_t i = 0; i < inputRank; i++)
|
||||
checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);
|
||||
|
||||
Value constantZero =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
||||
SmallVector<Value> indices(inputRank, constantZero);
|
||||
Value result = rewriter.create<tensor::ExtractOp>(loc, input, indices);
|
||||
Type resultType =
|
||||
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
||||
rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result,
|
||||
resultType, inputDtype));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// The binary broadcast patterns
|
||||
namespace {
|
||||
template <typename AtenOpT, typename ChloOpT>
|
||||
|
@ -1662,6 +1707,16 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
|
||||
#undef INSERT_CONSTANT_FILL_PATTERN
|
||||
|
||||
#define INSERT_TENSOR_TO_SCALAR_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenOp>>(typeConverter, \
|
||||
context)
|
||||
|
||||
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenIntTensorOp);
|
||||
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenFloatTensorOp);
|
||||
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenBoolTensorOp);
|
||||
#undef INSERT_TENSOR_TO_SCALAR_PATTERN
|
||||
|
||||
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context)
|
||||
|
|
Loading…
Reference in New Issue