[Torch Dialect] Support aten.native_dropout (#2259)

* [Torch Dialect] Support aten.native_dropout

* update
Yuanqiang Liu 2023-06-27 14:19:33 +08:00 committed by GitHub
parent 1ea2b57ab7
commit 859885c1d3
6 changed files with 168 additions and 0 deletions
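For context (this note and the snippet are not part of the diff): unlike torch.dropout, which returns a single tensor, torch.native_dropout returns a pair (output, boolean mask), which is why the new shape and dtype functions below produce tuples. A minimal eager-mode sketch of the contract the lowering has to match; the shapes and probabilities are arbitrary:

import torch

x = torch.rand(4, 8)
# train=False: the output is the input unchanged and the mask is all True,
# mirroring the inference branch of the decomposition added below.
out, mask = torch.native_dropout(x, 0.5, train=False)
assert torch.equal(out, x) and mask.dtype == torch.bool and bool(mask.all())
# train=True: dropped elements are zeroed, kept ones are scaled by 1 / (1 - p).
out, mask = torch.native_dropout(x, 0.5, train=True)
assert torch.allclose(out, x * mask / 0.5)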


@@ -699,6 +699,9 @@ STABLEHLO_PASS_SET = {
"NewZerosStaticModuleLayoutStrided_basic",
"DropoutEvalIntModule_basic",
"DropoutEvalFloatModule_basic",
"DropoutTrainStaticShapeModule_basic",
"NativeDropoutEvalFloatModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"ContiguousModule_basic",
"DropoutModule_basic",
"ViewCollapseModule_basic",
@@ -1258,6 +1261,9 @@ LTC_XFAIL_SET = {
"BernoulliModule_basic",
"BernoulliPModule_basic",
"DropoutTrainModule_basic",
"DropoutTrainStaticShapeModule_basic",
"NativeDropoutTrainModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"StdCorrectionKeepDimModule_basic",
"StdCorrectionNoneModule_basic",
"VarBiasedModule_basic",


@@ -6437,6 +6437,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.native_dropout\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<bool>) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %1 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.gelu\"(%arg0: !torch.list<int>, %arg1: !torch.str) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -8244,6 +8249,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.native_dropout\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.optional<bool>) -> !torch.tuple<int, int> {\n"
" %int11 = torch.constant.int 11\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.TupleConstruct %0#1, %int11 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.expand_as\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"

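A note on the generated dtype function above (not part of the diff): the constant 11 is the integer value of PyTorch's ScalarType::Bool, so the function reports the input dtype for the first result and bool for the mask. A quick eager-mode check of that contract:

import torch

out, mask = torch.native_dropout(torch.randn(2, 3), 0.5, train=True)
assert out.dtype == torch.float32   # the output keeps the input dtype
assert mask.dtype == torch.bool     # matches the %int11 (ScalarType::Bool) above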

@@ -2128,6 +2128,58 @@ public:
    return success();
  }
};

class DecomposeAtenNativeDropoutOp
    : public OpRewritePattern<AtenNativeDropoutOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenNativeDropoutOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *context = op->getContext();
    Value input = op.getInput();
    Value prob = op.getP();
    bool train = false;
    if (!op.getTrain().getType().isa<Torch::NoneType>()) {
      if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) {
        return rewriter.notifyMatchFailure(
            op, "train must be a boolean constant or none");
      }
    }
    Value noneVal = rewriter.create<ConstantNoneOp>(loc);
    if (!train) {
      // Inference mode: return the input unchanged together with an all-true
      // boolean mask of the same shape.
      Value i1Type =
          getDtypeIntValueForType(rewriter, loc, IntegerType::get(context, 1));
      Value inputSize = rewriter.create<AtenSizeOp>(
          loc, Torch::ListType::get(Torch::IntType::get(context)), input);
      Value trueValue = rewriter.create<ConstantIntOp>(loc, 1);
      Value trueMask = rewriter.create<AtenFullOp>(
          loc, op->getResultTypes()[1], inputSize, trueValue, i1Type,
          /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal);
      rewriter.replaceOp(op, ArrayRef<Value>{input, trueMask});
      return success();
    }
    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
    if (!inputType.hasDtype() ||
        !inputType.getDtype().isa<mlir::FloatType>()) {
      return rewriter.notifyMatchFailure(
          op, "only support floating type input for training mode");
    }
    // Training mode: draw a Bernoulli(1 - p) keep-mask, zero out dropped
    // elements, and rescale the survivors by 1 / (1 - p).
    Value floatOne =
        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
    Value oneMinusP = rewriter.create<AtenSubFloatOp>(loc, floatOne, prob);
    Value boolMask = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
        loc, inputType, input, oneMinusP, /*generator=*/noneVal);
    Value maskedInput =
        rewriter.create<AtenMulTensorOp>(loc, inputType, boolMask, input);
    Value output = rewriter.create<AtenDivScalarOp>(
        loc, op->getResultTypes()[0], maskedInput, oneMinusP);
    rewriter.replaceOp(
        op, ArrayRef<Value>{
                output, convertTensorToDtype(rewriter, loc, boolMask,
                                             IntegerType::get(context, 1))});
    return success();
  }
};
} // namespace
// Decompose aten.var into: aten.var.dim op.
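For readers skimming the pattern above, the training branch is just inverted dropout written with torch ops: draw a Bernoulli(1 - p) keep-mask (ValsemVariantAtenBernoulliFloatOp), multiply it into the input (AtenMulTensorOp), rescale by 1 / (1 - p) (AtenDivScalarOp), and return the mask as a boolean tensor. A small PyTorch sketch of the same computation; the function name here is made up for illustration:

import torch

def native_dropout_train_reference(x: torch.Tensor, p: float):
    # Per-element Bernoulli(1 - p) keep-mask, like the bernoulli.float op above.
    keep = torch.bernoulli(torch.full_like(x, 1.0 - p))
    # Zero out dropped elements, then rescale survivors by 1 / (1 - p).
    out = x * keep / (1.0 - p)
    return out, keep.to(torch.bool)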
@@ -4654,6 +4706,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCopyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenDropoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeDropoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);


@@ -440,6 +440,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenExpandAsOp>();
target.addIllegalOp<Aten_ToCopyOp>();
target.addIllegalOp<AtenDropoutOp>();
target.addIllegalOp<AtenNativeDropoutOp>();
target.addIllegalOp<AtenNewEmptyOp>();
target.addIllegalOp<AtenIndexPutHackedTwinOp>();
target.addIllegalOp<AtenPadOp>();
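(Not part of the diff: this registration is what makes the new pattern fire. markDecomposedOpsAsIllegal puts aten.native_dropout on the list of ops the backend contract rejects, and addPatternIfTargetOpIsIllegal in the previous file only installs DecomposeAtenNativeDropoutOp when the op is marked illegal here, so the two one-line changes belong together.)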


@@ -203,6 +203,10 @@ def aten〇type_as〡shape(self: List[int], other: List[int]) -> List[int]:
def aten〇dropout〡shape(input: List[int], p: float, train: bool) -> List[int]:
    return upstream_shape_functions.unary(input)

def aten〇native_dropout〡shape(input: List[int], p: float, train: Optional[bool]) -> Tuple[List[int], List[int]]:
    shape = upstream_shape_functions.unary(input)
    return shape, shape

def aten〇gelu〡shape(self: List[int], approximate: str = "none") -> List[int]:
    return upstream_shape_functions.unary(self)

@@ -1458,6 +1462,11 @@ def aten〇dropout〡dtype(input_rank_dtype: Tuple[int, int], p: float, train: bool) -> int:
    input_rank, input_dtype = input_rank_dtype
    return input_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, p=0.5, train=False))
def aten〇native_dropout〡dtype(input_rank_dtype: Tuple[int, int], p: float, train: Optional[bool]) -> Tuple[int, int]:
    input_rank, input_dtype = input_rank_dtype
    return input_dtype, torch.bool

@check_dtype_function(_check_two_tensor_op())
def aten〇expand_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
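These Python definitions are the source from which the generated getAbstractInterpLibrary() string shown in the second file above is produced, and they are ordinary Python, so the new shape rule can be checked standalone. A minimal sketch, assuming only that the upstream torch.jit._shape_functions module is importable:

import torch.jit._shape_functions as upstream_shape_functions

shape = upstream_shape_functions.unary([1024, 1536])
# aten.native_dropout's shape function returns the (unchanged) input shape twice.
print(shape, shape)  # [1024, 1536] [1024, 1536]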


@@ -1786,6 +1786,94 @@ class DropoutTrainModule(torch.nn.Module):
def DropoutTrainModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1024, 1536))

# ==============================================================================

class DropoutTrainStaticShapeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1024, 1536], torch.float32, True),
    ])
    def forward(self, x):
        res = torch.dropout(x, 0.3, train=True)
        return torch.mean(res), torch.std(res)

@register_test_case(module_factory=lambda: DropoutTrainStaticShapeModule())
def DropoutTrainStaticShapeModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1024, 1536))

# ==============================================================================

class NativeDropoutEvalFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.native_dropout(x, 0.1, train=False)

@register_test_case(module_factory=lambda: NativeDropoutEvalFloatModule())
def NativeDropoutEvalFloatModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))

# ==============================================================================

class NativeDropoutTrainModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        res = torch.native_dropout(x, 0.3, train=True)
        return torch.mean(res[0]), torch.std(res[0]), torch.mean(res[1].to(torch.float32)), torch.std(res[1].to(torch.float32))

@register_test_case(module_factory=lambda: NativeDropoutTrainModule())
def NativeDropoutTrainModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1024, 1536))

# ==============================================================================

class NativeDropoutTrainStaticShapeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1024, 1536], torch.float32, True),
    ])
    def forward(self, x):
        res = torch.native_dropout(x, 0.3, train=True)
        return torch.mean(res[0]), torch.std(res[0]), torch.mean(res[1].to(torch.float32)), torch.std(res[1].to(torch.float32))

@register_test_case(module_factory=lambda: NativeDropoutTrainStaticShapeModule())
def NativeDropoutTrainStaticShapeModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1024, 1536))

# ==============================================================================
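A note on the testing strategy (not part of the diff): the train-mode modules return statistics rather than raw tensors because dropout is random, so only the mean and standard deviation are stable enough to compare against eager PyTorch. Since kept elements are rescaled by 1 / (1 - p), the output mean matches the input mean in expectation, and the mask's mean is roughly 1 - p; a quick eager sanity check with the same shape and probability as the tests:

import torch

x = torch.rand(1024, 1536)
out, mask = torch.native_dropout(x, 0.3, train=True)
print(x.mean().item(), out.mean().item())  # both close to 0.5
print(mask.float().mean().item())          # close to 1 - 0.3 = 0.7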