mirror of https://github.com/llvm/torch-mlir
[Stablehlo] support dynamic shape when convert aten.fill.Scalar (#2349)
parent
991eba2b51
commit
c7c59b540e
|
@ -474,8 +474,16 @@ STABLEHLO_PASS_SET = {
|
|||
"EmbeddingModuleI32_basic",
|
||||
"EmbeddingModuleI64_basic",
|
||||
"EmbeddingModuleF16_basic",
|
||||
"EmptyLikeMemoryFormatModule_basic",
|
||||
"EmptyLikeModule_defaultDtype",
|
||||
"EmptyLikeModule_falsePinMemory",
|
||||
"EmptyLikeModule_float",
|
||||
"EmptyLikeModule_int",
|
||||
"ExpandAsIntModule_basic",
|
||||
"ExpandModule_basic",
|
||||
"Fill_TensorFloat64WithFloat32_basic",
|
||||
"Fill_TensorFloat64WithFloat64_basic",
|
||||
"Fill_TensorFloat64WithInt64_basic",
|
||||
"Fill_TensorFloat64WithFloat32Static_basic",
|
||||
"Fill_TensorFloat64WithInt64Static_basic",
|
||||
"FlipModuleStaticShape_basic",
|
||||
|
@ -672,6 +680,9 @@ STABLEHLO_PASS_SET = {
|
|||
"PermuteModule_basic",
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
|
||||
"ZeroFloat32Module_basic",
|
||||
"ZeroInt32Module_basic",
|
||||
"ZeroInt64Module_basic",
|
||||
"ZerosLikeModule_defaultDtype",
|
||||
"ZerosLikeModule_falsePinMemory",
|
||||
"ZerosLikeModule_float",
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include "PopulatePatterns.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/ChloOps.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
|
@ -1583,8 +1584,11 @@ LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
|
|||
auto dtype = outType.getElementType();
|
||||
Value scalarTensor =
|
||||
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getValue(), dtype);
|
||||
Value bcastScalar = rewriter.create<stablehlo::BroadcastInDimOp>(
|
||||
op->getLoc(), outType, scalarTensor, rewriter.getI64TensorAttr({}));
|
||||
Value shapeTensor =
|
||||
rewriter.create<shape::ShapeOfOp>(op->getLoc(), adaptor.getSelf());
|
||||
Value bcastScalar = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
|
||||
op->getLoc(), outType, scalarTensor, shapeTensor,
|
||||
rewriter.getI64TensorAttr({}));
|
||||
rewriter.replaceOp(op, bcastScalar);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@ public:
|
|||
registry.insert<chlo::ChloDialect>();
|
||||
registry.insert<stablehlo::StablehloDialect>();
|
||||
registry.insert<tensor::TensorDialect>();
|
||||
registry.insert<shape::ShapeDialect>();
|
||||
registry.insert<arith::ArithDialect>();
|
||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||
}
|
||||
|
@ -51,7 +52,8 @@ public:
|
|||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<chlo::ChloDialect, stablehlo::StablehloDialect,
|
||||
tensor::TensorDialect, arith::ArithDialect>();
|
||||
tensor::TensorDialect, arith::ArithDialect,
|
||||
shape::ShapeDialect>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
|
|
Loading…
Reference in New Issue