mirror of https://github.com/llvm/torch-mlir
Add Decompostion for `Aten_SafeSoftmaxOp` (#3708)
Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3710/head
parent
edf725ef42
commit
d61986cfcf
|
@ -8370,6 +8370,31 @@ def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_Aten_SafeSoftmaxOp : Torch_Op<"aten._safe_softmax", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::_safe_softmax : (Tensor, int, int?) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
Torch_IntType:$dim,
|
||||||
|
AnyTorchOptionalIntType:$dtype
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult Aten_SafeSoftmaxOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||||
|
}
|
||||||
|
void Aten_SafeSoftmaxOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 3, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenMeanOp : Torch_Op<"aten.mean", [
|
def Torch_AtenMeanOp : Torch_Op<"aten.mean", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -6772,6 +6772,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten._safe_softmax\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||||
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
|
" return %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.softmax.int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.softmax.int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -15367,6 +15371,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" }\n"
|
" }\n"
|
||||||
" return %1 : !torch.int\n"
|
" return %1 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten._safe_softmax\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||||
|
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
|
||||||
|
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
|
||||||
|
" torch.prim.If.yield %2 : !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" torch.prim.If.yield %2#1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" return %1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
|
||||||
" %int6 = torch.constant.int 6\n"
|
" %int6 = torch.constant.int 6\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
|
|
|
@ -2148,6 +2148,62 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Ref:
|
||||||
|
// https://github.com/pytorch/pytorch/blob/5314ae2660a778b87987030182f787bb6cb092c0/aten/src/ATen/native/transformers/attention.cpp#L663-L673
|
||||||
|
namespace {
|
||||||
|
class DecomposeAten_SafeSoftmaxOp
|
||||||
|
: public OpRewritePattern<Aten_SafeSoftmaxOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(Aten_SafeSoftmaxOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
BaseTensorType resultTensorType = cast<BaseTensorType>(op.getType());
|
||||||
|
if (!resultTensorType.hasDtype() || !resultTensorType.hasSizes()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "expected result type to have sizes and dtype");
|
||||||
|
}
|
||||||
|
SmallVector<int64_t> sizes(resultTensorType.getSizes());
|
||||||
|
|
||||||
|
int64_t dimInt;
|
||||||
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt)))
|
||||||
|
return rewriter.notifyMatchFailure(op, "Unsupported: non-constant dim");
|
||||||
|
|
||||||
|
dimInt = toPositiveDim(dimInt, sizes.size());
|
||||||
|
if (!isValidDim(dimInt, sizes.size()))
|
||||||
|
return rewriter.notifyMatchFailure(op, "dim int is not valid");
|
||||||
|
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
Value softmax = rewriter.create<AtenSoftmaxIntOp>(
|
||||||
|
loc, op.getType(), op.getSelf(), op.getDim(), op.getDtype());
|
||||||
|
|
||||||
|
Type resultTensorDtype = resultTensorType.getDtype();
|
||||||
|
|
||||||
|
Value negInfinity = getConstantWithGivenDtypeAndValue(
|
||||||
|
rewriter, loc, -std::numeric_limits<double>::infinity(),
|
||||||
|
resultTensorDtype);
|
||||||
|
|
||||||
|
auto boolDtype = rewriter.getI1Type();
|
||||||
|
auto boolTensorType =
|
||||||
|
resultTensorType.getWithSizesAndDtype(sizes, boolDtype);
|
||||||
|
Value masked = rewriter.create<AtenEqScalarOp>(loc, boolTensorType,
|
||||||
|
op.getSelf(), negInfinity);
|
||||||
|
|
||||||
|
sizes[dimInt] = 1;
|
||||||
|
auto maskedRowsType =
|
||||||
|
resultTensorType.getWithSizesAndDtype(sizes, boolDtype);
|
||||||
|
Value cstTrue =
|
||||||
|
rewriter.create<Torch::ConstantBoolOp>(loc, rewriter.getBoolAttr(true));
|
||||||
|
Value maskedRows = rewriter.create<AtenAllDimOp>(
|
||||||
|
loc, maskedRowsType, masked, op.getDim(), cstTrue);
|
||||||
|
Value cstZero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0.0,
|
||||||
|
resultTensorDtype);
|
||||||
|
rewriter.replaceOpWithNewOp<AtenWhereScalarSelfOp>(
|
||||||
|
op, resultTensorType, maskedRows, cstZero, softmax);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) =>
|
// Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) =>
|
||||||
// newGrad = gradOutput * output
|
// newGrad = gradOutput * output
|
||||||
// result = newGrad - output * sum(newGrad, dim))
|
// result = newGrad - output * sum(newGrad, dim))
|
||||||
|
@ -9608,6 +9664,7 @@ public:
|
||||||
patterns);
|
patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAten_SafeSoftmaxOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLogSigmoidOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLogSigmoidOp>(patterns);
|
||||||
|
|
|
@ -371,6 +371,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
llvm::StringSet<> backendLegalOpsSet) {
|
llvm::StringSet<> backendLegalOpsSet) {
|
||||||
target.addIllegalOp<AtenSoftmaxIntOp>();
|
target.addIllegalOp<AtenSoftmaxIntOp>();
|
||||||
target.addIllegalOp<Aten_SoftmaxOp>();
|
target.addIllegalOp<Aten_SoftmaxOp>();
|
||||||
|
target.addIllegalOp<Aten_SafeSoftmaxOp>();
|
||||||
target.addIllegalOp<Aten_LogSoftmaxOp>();
|
target.addIllegalOp<Aten_LogSoftmaxOp>();
|
||||||
target.addIllegalOp<AtenLogSoftmaxIntOp>();
|
target.addIllegalOp<AtenLogSoftmaxIntOp>();
|
||||||
target.addIllegalOp<AtenLogSigmoidOp>();
|
target.addIllegalOp<AtenLogSigmoidOp>();
|
||||||
|
|
|
@ -504,14 +504,6 @@ FX_IMPORTER_XFAIL_SET = {
|
||||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||||
"ViewSizeFromOtherTensor_basic",
|
"ViewSizeFromOtherTensor_basic",
|
||||||
"WeightNormInterfaceModule_basic",
|
"WeightNormInterfaceModule_basic",
|
||||||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
|
||||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
|
||||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
|
||||||
"ScaledDotProductAttentionDifferentModule_basic",
|
|
||||||
"ScaledDotProductAttentionMaskModule_basic",
|
|
||||||
"ScaledDotProductAttentionSameCausalModule_basic",
|
|
||||||
"ScaledDotProductAttentionSameDynamicModule_basic",
|
|
||||||
"ScaledDotProductAttentionSameModule_basic",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
|
FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
|
||||||
|
@ -826,6 +818,9 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"ReplicationPad2dModule_top0",
|
"ReplicationPad2dModule_top0",
|
||||||
"RsubInt0d_NumToTensor_Module_basic",
|
"RsubInt0d_NumToTensor_Module_basic",
|
||||||
"ScalarImplicitFloatModule_basic",
|
"ScalarImplicitFloatModule_basic",
|
||||||
|
# need aten.all.dim lowering to stablehlo
|
||||||
|
"SafeSoftmaxModule_basic",
|
||||||
|
"SafeSoftmaxNonNoneDtypeModule_basic",
|
||||||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||||
|
@ -2770,6 +2765,8 @@ ONNX_XFAIL_SET = {
|
||||||
"ReshapeAliasExpandModule_basic",
|
"ReshapeAliasExpandModule_basic",
|
||||||
"ReshapeExpandModule_basic",
|
"ReshapeExpandModule_basic",
|
||||||
"Rot90DynamicDimsModule_basic",
|
"Rot90DynamicDimsModule_basic",
|
||||||
|
"SafeSoftmaxModule_basic",
|
||||||
|
"SafeSoftmaxNonNoneDtypeModule_basic",
|
||||||
"ScalarConstantTupleModule_basic",
|
"ScalarConstantTupleModule_basic",
|
||||||
"ScalarImplicitFloatModule_basic",
|
"ScalarImplicitFloatModule_basic",
|
||||||
"ScalarImplicitIntModule_basic",
|
"ScalarImplicitIntModule_basic",
|
||||||
|
|
|
@ -348,6 +348,9 @@ def aten〇glu〡shape(self: List[int], dim: int = -1) -> List[int]:
|
||||||
def aten〇_softmax〡shape(self: List[int], dim: int, half_to_float: bool) -> List[int]:
|
def aten〇_softmax〡shape(self: List[int], dim: int, half_to_float: bool) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
|
def aten〇_safe_softmax〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
|
||||||
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
def aten〇softmax〇int〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
|
def aten〇softmax〇int〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
|
@ -5426,6 +5429,12 @@ def aten〇_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half_to_
|
||||||
return torch.float32
|
return torch.float32
|
||||||
return self_dtype
|
return self_dtype
|
||||||
|
|
||||||
|
def aten〇_safe_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int:
|
||||||
|
if dtype is not None:
|
||||||
|
return dtype
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
return self_dtype
|
||||||
|
|
||||||
@check_dtype_function(
|
@check_dtype_function(
|
||||||
# _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False) +
|
# _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False) +
|
||||||
_check_tensors_with_the_same_dtype(
|
_check_tensors_with_the_same_dtype(
|
||||||
|
|
|
@ -692,6 +692,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::__lshift__.Scalar : (Tensor, Scalar) -> (Tensor)")
|
emit("aten::__lshift__.Scalar : (Tensor, Scalar) -> (Tensor)")
|
||||||
emit("aten::__rshift__.Scalar : (Tensor, Scalar) -> (Tensor)")
|
emit("aten::__rshift__.Scalar : (Tensor, Scalar) -> (Tensor)")
|
||||||
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
|
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
|
||||||
|
emit("aten::_safe_softmax : (Tensor, int, int?) -> (Tensor)")
|
||||||
emit("aten::mean : (Tensor, int?) -> (Tensor)")
|
emit("aten::mean : (Tensor, int?) -> (Tensor)")
|
||||||
emit("aten::std : (Tensor, bool) -> (Tensor)")
|
emit("aten::std : (Tensor, bool) -> (Tensor)")
|
||||||
emit("aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
|
emit("aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
|
||||||
|
|
|
@ -1907,6 +1907,52 @@ def _LogSoftmaxModuleStable_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class SafeSoftmaxModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, tensor):
|
||||||
|
return torch.ops.aten._safe_softmax(tensor, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: SafeSoftmaxModule())
|
||||||
|
def SafeSoftmaxModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 2, 4))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class SafeSoftmaxNonNoneDtypeModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, tensor):
|
||||||
|
return torch.ops.aten._safe_softmax(tensor, dim=2, dtype=torch.float64)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: SafeSoftmaxNonNoneDtypeModule())
|
||||||
|
def SafeSoftmaxNonNoneDtypeModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 2, 4))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class SoftplusModule(torch.nn.Module):
|
class SoftplusModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue