mirror of https://github.com/llvm/torch-mlir
[Stablehlo] support dynamic shape when convert aten.fill.Scalar (#2349)
parent
991eba2b51
commit
c7c59b540e
|
@ -474,8 +474,16 @@ STABLEHLO_PASS_SET = {
|
||||||
"EmbeddingModuleI32_basic",
|
"EmbeddingModuleI32_basic",
|
||||||
"EmbeddingModuleI64_basic",
|
"EmbeddingModuleI64_basic",
|
||||||
"EmbeddingModuleF16_basic",
|
"EmbeddingModuleF16_basic",
|
||||||
|
"EmptyLikeMemoryFormatModule_basic",
|
||||||
|
"EmptyLikeModule_defaultDtype",
|
||||||
|
"EmptyLikeModule_falsePinMemory",
|
||||||
|
"EmptyLikeModule_float",
|
||||||
|
"EmptyLikeModule_int",
|
||||||
"ExpandAsIntModule_basic",
|
"ExpandAsIntModule_basic",
|
||||||
"ExpandModule_basic",
|
"ExpandModule_basic",
|
||||||
|
"Fill_TensorFloat64WithFloat32_basic",
|
||||||
|
"Fill_TensorFloat64WithFloat64_basic",
|
||||||
|
"Fill_TensorFloat64WithInt64_basic",
|
||||||
"Fill_TensorFloat64WithFloat32Static_basic",
|
"Fill_TensorFloat64WithFloat32Static_basic",
|
||||||
"Fill_TensorFloat64WithInt64Static_basic",
|
"Fill_TensorFloat64WithInt64Static_basic",
|
||||||
"FlipModuleStaticShape_basic",
|
"FlipModuleStaticShape_basic",
|
||||||
|
@ -672,6 +680,9 @@ STABLEHLO_PASS_SET = {
|
||||||
"PermuteModule_basic",
|
"PermuteModule_basic",
|
||||||
"PermuteNegativeIndexModule_basic",
|
"PermuteNegativeIndexModule_basic",
|
||||||
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
|
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
|
||||||
|
"ZeroFloat32Module_basic",
|
||||||
|
"ZeroInt32Module_basic",
|
||||||
|
"ZeroInt64Module_basic",
|
||||||
"ZerosLikeModule_defaultDtype",
|
"ZerosLikeModule_defaultDtype",
|
||||||
"ZerosLikeModule_falsePinMemory",
|
"ZerosLikeModule_falsePinMemory",
|
||||||
"ZerosLikeModule_float",
|
"ZerosLikeModule_float",
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "stablehlo/dialect/ChloOps.h"
|
#include "stablehlo/dialect/ChloOps.h"
|
||||||
#include "stablehlo/dialect/StablehloOps.h"
|
#include "stablehlo/dialect/StablehloOps.h"
|
||||||
|
@ -1583,8 +1584,11 @@ LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
|
||||||
auto dtype = outType.getElementType();
|
auto dtype = outType.getElementType();
|
||||||
Value scalarTensor =
|
Value scalarTensor =
|
||||||
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getValue(), dtype);
|
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getValue(), dtype);
|
||||||
Value bcastScalar = rewriter.create<stablehlo::BroadcastInDimOp>(
|
Value shapeTensor =
|
||||||
op->getLoc(), outType, scalarTensor, rewriter.getI64TensorAttr({}));
|
rewriter.create<shape::ShapeOfOp>(op->getLoc(), adaptor.getSelf());
|
||||||
|
Value bcastScalar = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
|
||||||
|
op->getLoc(), outType, scalarTensor, shapeTensor,
|
||||||
|
rewriter.getI64TensorAttr({}));
|
||||||
rewriter.replaceOp(op, bcastScalar);
|
rewriter.replaceOp(op, bcastScalar);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,6 +44,7 @@ public:
|
||||||
registry.insert<chlo::ChloDialect>();
|
registry.insert<chlo::ChloDialect>();
|
||||||
registry.insert<stablehlo::StablehloDialect>();
|
registry.insert<stablehlo::StablehloDialect>();
|
||||||
registry.insert<tensor::TensorDialect>();
|
registry.insert<tensor::TensorDialect>();
|
||||||
|
registry.insert<shape::ShapeDialect>();
|
||||||
registry.insert<arith::ArithDialect>();
|
registry.insert<arith::ArithDialect>();
|
||||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||||
}
|
}
|
||||||
|
@ -51,7 +52,8 @@ public:
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addLegalDialect<chlo::ChloDialect, stablehlo::StablehloDialect,
|
target.addLegalDialect<chlo::ChloDialect, stablehlo::StablehloDialect,
|
||||||
tensor::TensorDialect, arith::ArithDialect>();
|
tensor::TensorDialect, arith::ArithDialect,
|
||||||
|
shape::ShapeDialect>();
|
||||||
|
|
||||||
TypeConverter typeConverter;
|
TypeConverter typeConverter;
|
||||||
typeConverter.addConversion([](Type type) { return type; });
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
|
|
Loading…
Reference in New Issue