diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 12907c9a6..0b1a8b257 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -8370,6 +8370,31 @@ def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [ }]; } +def Torch_Aten_SafeSoftmaxOp : Torch_Op<"aten._safe_softmax", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_safe_softmax : (Tensor, int, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_SafeSoftmaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void Aten_SafeSoftmaxOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenMeanOp : Torch_Op<"aten.mean", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 27a2f1e2c..59cf69393 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6772,6 +6772,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten._safe_softmax\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.softmax.int\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -15367,6 +15371,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._safe_softmax\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._log_softmax\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %none = torch.constant.none\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index b60eda351..ed0ef9e5b 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2148,6 +2148,62 @@ public: }; } // namespace +// Ref: +// https://github.com/pytorch/pytorch/blob/5314ae2660a778b87987030182f787bb6cb092c0/aten/src/ATen/native/transformers/attention.cpp#L663-L673 +namespace { +class DecomposeAten_SafeSoftmaxOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Aten_SafeSoftmaxOp op, + PatternRewriter &rewriter) const override { + BaseTensorType resultTensorType = cast(op.getType()); + if (!resultTensorType.hasDtype() || !resultTensorType.hasSizes()) { + return rewriter.notifyMatchFailure( + op, "expected result type to have sizes and dtype"); + } + SmallVector sizes(resultTensorType.getSizes()); + + int64_t dimInt; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) + return rewriter.notifyMatchFailure(op, "Unsupported: non-constant dim"); + + dimInt = toPositiveDim(dimInt, sizes.size()); + if (!isValidDim(dimInt, sizes.size())) + return rewriter.notifyMatchFailure(op, "dim int is not valid"); + + Location loc = op.getLoc(); + Value softmax = rewriter.create( + loc, op.getType(), op.getSelf(), op.getDim(), op.getDtype()); + + Type resultTensorDtype = resultTensorType.getDtype(); + + Value negInfinity = getConstantWithGivenDtypeAndValue( + rewriter, loc, -std::numeric_limits::infinity(), + resultTensorDtype); + + auto boolDtype = rewriter.getI1Type(); + auto boolTensorType = + resultTensorType.getWithSizesAndDtype(sizes, boolDtype); + Value masked = rewriter.create(loc, boolTensorType, + op.getSelf(), negInfinity); + + sizes[dimInt] = 1; + auto maskedRowsType = + resultTensorType.getWithSizesAndDtype(sizes, boolDtype); + Value cstTrue = + rewriter.create(loc, rewriter.getBoolAttr(true)); + Value maskedRows = rewriter.create( + loc, maskedRowsType, masked, op.getDim(), cstTrue); + Value cstZero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0.0, + resultTensorDtype); + rewriter.replaceOpWithNewOp( + op, resultTensorType, maskedRows, cstZero, softmax); + return success(); + } +}; +} // namespace + // Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) => // newGrad = gradOutput * output // result = newGrad - output * sum(newGrad, dim)) @@ -9608,6 +9664,7 @@ public: patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index aa81a68ca..ebc43faa5 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -371,6 +371,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, llvm::StringSet<> backendLegalOpsSet) { target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 0430ba9d5..918cbae63 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -504,14 +504,6 @@ FX_IMPORTER_XFAIL_SET = { "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", "WeightNormInterfaceModule_basic", - # REMOVE WHEN ENABLE_GQA IS ADDED - "ScaledDotProductAttentionBoolMaskModule_basic", - "ScaledDotProductAttentionDifferentCausalModule_basic", - "ScaledDotProductAttentionDifferentModule_basic", - "ScaledDotProductAttentionMaskModule_basic", - "ScaledDotProductAttentionSameCausalModule_basic", - "ScaledDotProductAttentionSameDynamicModule_basic", - "ScaledDotProductAttentionSameModule_basic", } FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | { @@ -826,6 +818,9 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "ReplicationPad2dModule_top0", "RsubInt0d_NumToTensor_Module_basic", "ScalarImplicitFloatModule_basic", + # need aten.all.dim lowering to stablehlo + "SafeSoftmaxModule_basic", + "SafeSoftmaxNonNoneDtypeModule_basic", # REMOVE WHEN ENABLE_GQA IS ADDED "ScaledDotProductAttentionBoolMaskModule_basic", "ScaledDotProductAttentionDifferentCausalModule_basic", @@ -2770,6 +2765,8 @@ ONNX_XFAIL_SET = { "ReshapeAliasExpandModule_basic", "ReshapeExpandModule_basic", "Rot90DynamicDimsModule_basic", + "SafeSoftmaxModule_basic", + "SafeSoftmaxNonNoneDtypeModule_basic", "ScalarConstantTupleModule_basic", "ScalarImplicitFloatModule_basic", "ScalarImplicitIntModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 1870c5829..bc49757ee 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -348,6 +348,9 @@ def aten〇glu〡shape(self: List[int], dim: int = -1) -> List[int]: def aten〇_softmax〡shape(self: List[int], dim: int, half_to_float: bool) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇_safe_softmax〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇softmax〇int〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.unary(self) @@ -5426,6 +5429,12 @@ def aten〇_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half_to_ return torch.float32 return self_dtype +def aten〇_safe_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function( # _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False) + _check_tensors_with_the_same_dtype( diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 2421fda24..5f53e17b9 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -692,6 +692,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::__lshift__.Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::__rshift__.Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)") + emit("aten::_safe_softmax : (Tensor, int, int?) -> (Tensor)") emit("aten::mean : (Tensor, int?) -> (Tensor)") emit("aten::std : (Tensor, bool) -> (Tensor)") emit("aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 03e16ab2c..ce9a254f6 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -1907,6 +1907,52 @@ def _LogSoftmaxModuleStable_basic(module, tu: TestUtils): # ============================================================================== +class SafeSoftmaxModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, tensor): + return torch.ops.aten._safe_softmax(tensor, dim=0) + + +@register_test_case(module_factory=lambda: SafeSoftmaxModule()) +def SafeSoftmaxModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 2, 4)) + + +# ============================================================================== + + +class SafeSoftmaxNonNoneDtypeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, tensor): + return torch.ops.aten._safe_softmax(tensor, dim=2, dtype=torch.float64) + + +@register_test_case(module_factory=lambda: SafeSoftmaxNonNoneDtypeModule()) +def SafeSoftmaxNonNoneDtypeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 2, 4)) + + +# ============================================================================== + + class SoftplusModule(torch.nn.Module): def __init__(self): super().__init__()