mirror of https://github.com/llvm/torch-mlir
[Stablehlo] support aten.all.dim (#3746)
parent eb4e59e189
commit 5f74de5ba0
@@ -110,7 +110,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
     }
   }

-  if (isa<AtenAllOp>(op)) {
+  if (isa<AtenAllOp, AtenAllDimOp>(op)) {
     auto constAttr =
         DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)});
     return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
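For context: `createInitialValueForReduceOp` supplies the identity element that seeds the `stablehlo.reduce`. For `all` (and now `all.dim`) the combiner is logical AND, whose identity is true, hence the 1-bit constant `APInt(/*numBits=*/1, 1)` above. A minimal Python sketch of that identity/fold relationship (illustrative only, not torch-mlir code):

import operator
from functools import reduce

def all_reduce(bits, init=True):
    # `init` plays the role of the i1 "true" constant that seeds the
    # stablehlo.reduce: AND-ing anything with true changes nothing.
    return reduce(operator.and_, bits, init)

assert all_reduce([True, False]) is False
assert all_reduce([]) is True  # identity: all() over empty input is True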
@@ -166,7 +166,7 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
                  AtenLinalgVectorNormOp>(op)) {
     result = rewriter.create<stablehlo::AddOp>(
         op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-  } else if (isa<AtenAllOp>(op)) {
+  } else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
     result = rewriter.create<stablehlo::AndOp>(
         op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
   } else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
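Inside the reduction region the two block arguments are combined with `stablehlo.and`, so reducing one dimension of a boolean tensor computes a logical AND along that axis. A hedged NumPy sketch of the resulting semantics, with `axis` standing in for the op's `dim`:

import numpy as np

x = np.array([[True, False, True],
              [True, True, True]])

# A stablehlo.reduce whose body is `and` behaves like an AND fold per axis:
print(np.logical_and.reduce(x, axis=0))  # [ True False  True]
print(np.logical_and.reduce(x, axis=1))  # [False  True]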
@@ -887,6 +887,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   patterns.add<ConvertAtenReduceOneDimOp<AtenOp>>(typeConverter, context,    \
                                                   options)
   INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAnyDimOp);
+  INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAllDimOp);
 #undef INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN

 #define INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenOp) \
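With `AtenAllDimOp` registered through the same `INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN` macro as `AtenAnyDimOp`, both single-dimension reductions take the `ConvertAtenReduceOneDimOp` path. At the PyTorch level the two ops are duals; a small runnable check (not part of the patch):

import torch

x = torch.tensor([[1.0, 0.0], [2.0, 3.0]])
# any.dim ORs over one dimension, all.dim ANDs over it;
# nonzero elements count as True.
print(torch.ops.aten.any(x, dim=1))  # tensor([True, True])
print(torch.ops.aten.all(x, dim=1))  # tensor([False, True])

The remaining hunks update the e2e test expectations and add coverage for the new lowering.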
@@ -815,10 +815,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "RandnLikeDtypeModule_basic",
     "RandnLikeModule_basic",
     "RandnModule_basic",
-    "ReduceAllDimBool_basic",
-    "ReduceAllDimEmpty_basic",
-    "ReduceAllDimFloat_basic",
-    "ReduceAllDimInt_basic",
     "ReduceProdDimIntFloatModule_basic",
     "ReflectionPad1dModule2dInput_Right",
     "ReflectionPad1dModule2dInput_basic",
@@ -836,18 +832,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "ReplicationPad2dModule_top0",
     "RsubInt0d_NumToTensor_Module_basic",
     "ScalarImplicitFloatModule_basic",
-    # need aten.all.dim lowering to stablehlo
-    "SafeSoftmaxModule_basic",
-    "SafeSoftmaxNonNoneDtypeModule_basic",
     # REMOVE WHEN ENABLE_GQA IS ADDED
-    "ScaledDotProductAttentionBoolMaskModule_basic",
-    "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
-    "ScaledDotProductAttentionDifferentCausalModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
-    "ScaledDotProductAttentionMaskModule_basic",
-    "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
     "ScatterReduceFloatMaxModule",
     "ScatterReduceFloatMaxModuleIncludeSelf",
     "ScatterReduceFloatMeanModule",
@@ -170,6 +170,26 @@ def ReduceAllFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5))


+class ReduceAllDimFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.ops.aten.all(a, dim=0)
+
+
+@register_test_case(module_factory=lambda: ReduceAllDimFloatModule())
+def ReduceAllDimFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+
 # ==============================================================================
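For reference, what the new e2e test computes, runnable in plain PyTorch outside the harness (shapes as in the test; the assertion is a sanity check of my own, not part of the patch):

import torch

a = torch.rand(3, 4, 5)
out = torch.ops.aten.all(a, dim=0)
# Reducing dim 0 of a (3, 4, 5) tensor yields a (4, 5) bool tensor.
print(out.shape, out.dtype)  # torch.Size([4, 5]) torch.bool
assert torch.equal(out, torch.all(a, dim=0))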