mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] Support aten.native_dropout (#2259)
* [Torch Dialect] Support aten.native_dropout
* update

Branch: pull/2269/head
Tag: snapshot-20230627.882
parent 1ea2b57ab7
commit 859885c1d3
@@ -699,6 +699,9 @@ STABLEHLO_PASS_SET = {
     "NewZerosStaticModuleLayoutStrided_basic",
     "DropoutEvalIntModule_basic",
     "DropoutEvalFloatModule_basic",
+    "DropoutTrainStaticShapeModule_basic",
+    "NativeDropoutEvalFloatModule_basic",
+    "NativeDropoutTrainStaticShapeModule_basic",
     "ContiguousModule_basic",
     "DropoutModule_basic",
     "ViewCollapseModule_basic",
@@ -1258,6 +1261,9 @@ LTC_XFAIL_SET = {
     "BernoulliModule_basic",
     "BernoulliPModule_basic",
     "DropoutTrainModule_basic",
+    "DropoutTrainStaticShapeModule_basic",
+    "NativeDropoutTrainModule_basic",
+    "NativeDropoutTrainStaticShapeModule_basic",
     "StdCorrectionKeepDimModule_basic",
     "StdCorrectionNoneModule_basic",
     "VarBiasedModule_basic",
@@ -6437,6 +6437,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.native_dropout\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<bool>) -> !torch.tuple<list<int>, list<int>> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
+"    return %1 : !torch.tuple<list<int>, list<int>>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.gelu\"(%arg0: !torch.list<int>, %arg1: !torch.str) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -8244,6 +8249,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.native_dropout\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.optional<bool>) -> !torch.tuple<int, int> {\n"
+"    %int11 = torch.constant.int 11\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.prim.TupleConstruct %0#1, %int11 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
+"    return %1 : !torch.tuple<int, int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.expand_as\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
@@ -2128,6 +2128,58 @@ public:
     return success();
   }
 };
+
+class DeomposeAtenNativeDropoutOp
+    : public OpRewritePattern<AtenNativeDropoutOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNativeDropoutOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = op->getContext();
+    Value input = op.getInput();
+    Value prob = op.getP();
+    bool train = false;
+    if (!op.getTrain().getType().isa<Torch::NoneType>()) {
+      if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) {
+        return rewriter.notifyMatchFailure(
+            op, "train must be a boolean constant or none");
+      }
+    }
+    Value noneVal = rewriter.create<ConstantNoneOp>(loc);
+    if (!train) {
+      Value i1Type =
+          getDtypeIntValueForType(rewriter, loc, IntegerType::get(context, 1));
+      Value inputSize = rewriter.create<AtenSizeOp>(
+          loc, Torch::ListType::get(Torch::IntType::get(context)), input);
+      Value trueValue = rewriter.create<ConstantIntOp>(loc, 1);
+      Value trueMask = rewriter.create<AtenFullOp>(
+          loc, op->getResultTypes()[1], inputSize, trueValue, i1Type,
+          /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal);
+      rewriter.replaceOp(op, ArrayRef<Value>{input, trueMask});
+      return success();
+    }
+    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+    if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "only support floating type input for training mode");
+    }
+    Value floatOne =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+    Value oneMinusP = rewriter.create<AtenSubFloatOp>(loc, floatOne, prob);
+    Value boolMask = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
+        loc, inputType, input, oneMinusP, /*generator=*/noneVal);
+    Value maskedInput =
+        rewriter.create<AtenMulTensorOp>(loc, inputType, boolMask, input);
+    Value output = rewriter.create<AtenDivScalarOp>(
+        loc, op->getResultTypes()[0], maskedInput, oneMinusP);
+    rewriter.replaceOp(
+        op, ArrayRef<Value>{
+                output, convertTensorToDtype(rewriter, loc, boolMask,
+                                             IntegerType::get(context, 1))});
+    return success();
+  }
+};
 } // namespace
 
 // Decompose aten.var into: aten.var.dim op.
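For readers of the pattern above, a minimal PyTorch sketch (my own illustration, not code from this patch) of what the decomposition computes: in eval mode the op forwards the input and produces an all-true boolean mask; in training mode it samples a Bernoulli keep-mask with probability 1 - p, zeroes the dropped elements, and rescales survivors by 1/(1 - p), mirroring the bernoulli.float + mul + div sequence the pattern emits.

import torch

def native_dropout_sketch(x: torch.Tensor, p: float, train: bool):
    """Reference semantics of aten.native_dropout (illustrative only)."""
    if not train:
        # Eval path of the pattern: identity output plus an all-true mask.
        return x, torch.ones_like(x, dtype=torch.bool)
    keep_prob = 1.0 - p
    # Training path: Bernoulli keep-mask, masked multiply, rescale by 1/(1-p).
    mask = torch.bernoulli(torch.full_like(x, keep_prob))
    output = x * mask / keep_prob
    return output, mask.to(torch.bool)

# Eval mode is deterministic, so it can be checked against PyTorch directly.
x = torch.randn(4, 4)
out, mask = native_dropout_sketch(x, 0.5, train=False)
ref_out, ref_mask = torch.native_dropout(x, 0.5, train=False)
assert torch.equal(out, ref_out) and torch.equal(mask, ref_mask)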
@@ -4654,6 +4706,7 @@ public:
     addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenCopyOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenDropoutOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DeomposeAtenNativeDropoutOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);
@@ -440,6 +440,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenExpandAsOp>();
   target.addIllegalOp<Aten_ToCopyOp>();
   target.addIllegalOp<AtenDropoutOp>();
+  target.addIllegalOp<AtenNativeDropoutOp>();
   target.addIllegalOp<AtenNewEmptyOp>();
   target.addIllegalOp<AtenIndexPutHackedTwinOp>();
   target.addIllegalOp<AtenPadOp>();
@@ -203,6 +203,10 @@ def aten〇type_as〡shape(self: List[int], other: List[int]) -> List[int]:
 def aten〇dropout〡shape(input: List[int], p: float, train: bool) -> List[int]:
     return upstream_shape_functions.unary(input)
 
+def aten〇native_dropout〡shape(input: List[int], p: float, train: Optional[bool]) -> Tuple[List[int], List[int]]:
+    shape = upstream_shape_functions.unary(input)
+    return shape, shape
+
 def aten〇gelu〡shape(self: List[int], approximate: str = "none") -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -1458,6 +1462,11 @@ def aten〇dropout〡dtype(input_rank_dtype: Tuple[int, int], p: float, train: b
     input_rank, input_dtype = input_rank_dtype
     return input_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, p=0.5, train=False))
+def aten〇native_dropout〡dtype(input_rank_dtype: Tuple[int, int], p: float, train: Optional[bool]) -> Tuple[int, int]:
+    input_rank, input_dtype = input_rank_dtype
+    return input_dtype, torch.bool
+
 @check_dtype_function(_check_two_tensor_op())
 def aten〇expand_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
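Taken together, the rules added above say that aten.native_dropout is shape-preserving on both results and that the mask result is always torch.bool; the %int11 constant in the generated dtype function is torch.bool's ScalarType code. Below is a standalone sketch of the same rules in plain Python, outside the generator framework; the function names are my own, and it assumes upstream_shape_functions is torch.jit._shape_functions, as the generated calls to @__torch__.torch.jit._shape_functions.unary indicate.

from typing import List, Optional, Tuple

import torch
import torch.jit._shape_functions as upstream_shape_functions

# Hypothetical standalone mirror of aten〇native_dropout〡shape / 〡dtype.
def native_dropout_shape(input: List[int], p: float,
                         train: Optional[bool]) -> Tuple[List[int], List[int]]:
    shape = upstream_shape_functions.unary(input)  # copy of the input shape
    return shape, shape

def native_dropout_dtype(input_dtype: torch.dtype, p: float,
                         train: Optional[bool]) -> Tuple[torch.dtype, torch.dtype]:
    return input_dtype, torch.bool  # the mask is always boolean

print(native_dropout_shape([1024, 1536], 0.3, True))   # ([1024, 1536], [1024, 1536])
print(native_dropout_dtype(torch.float32, 0.3, True))  # (torch.float32, torch.bool)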
@@ -1786,6 +1786,94 @@ class DropoutTrainModule(torch.nn.Module):
 def DropoutTrainModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1024, 1536))
 
+# ==============================================================================
+
+
+class DropoutTrainStaticShapeModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1024, 1536], torch.float32, True),
+    ])
+    def forward(self, x):
+        res = torch.dropout(x, 0.3, train=True)
+        return torch.mean(res), torch.std(res)
+
+
+@register_test_case(module_factory=lambda: DropoutTrainStaticShapeModule())
+def DropoutTrainStaticShapeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1024, 1536))
+
+# ==============================================================================
+
+
+class NativeDropoutEvalFloatModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.native_dropout(x, 0.1, train=False)
+
+
+@register_test_case(module_factory=lambda: NativeDropoutEvalFloatModule())
+def NativeDropoutEvalFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+
+# ==============================================================================
+
+
+class NativeDropoutTrainModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        res = torch.native_dropout(x, 0.3, train=True)
+        return torch.mean(res[0]), torch.std(res[0]), torch.mean(res[1].to(torch.float32)), torch.std(res[1].to(torch.float32))
+
+
+@register_test_case(module_factory=lambda: NativeDropoutTrainModule())
+def NativeDropoutTrainModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1024, 1536))
+
+
+# ==============================================================================
+
+
+class NativeDropoutTrainStaticShapeModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1024, 1536], torch.float32, True),
+    ])
+    def forward(self, x):
+        res = torch.native_dropout(x, 0.3, train=True)
+        return torch.mean(res[0]), torch.std(res[0]), torch.mean(res[1].to(torch.float32)), torch.std(res[1].to(torch.float32))
+
+
+@register_test_case(module_factory=lambda: NativeDropoutTrainStaticShapeModule())
+def NativeDropoutTrainStaticShapeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1024, 1536))
 
 # ==============================================================================
 
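A note on the train-mode tests above: the dropout mask is random, so the tests compare only the mean and standard deviation of the results rather than element values. A quick sanity check of those statistics (my own illustration; it assumes tu.rand draws from the same uniform [0, 1) distribution as torch.rand):

import torch

torch.manual_seed(0)
x = torch.rand(1024, 1536)           # uniform [0, 1): mean ~= 0.5
out, mask = torch.native_dropout(x, 0.3, train=True)

# Dropout is mean-preserving in expectation: kept elements are scaled by 1/(1-p).
print(x.mean().item(), out.mean().item())      # both close to 0.5
# About 70% of mask entries should be True when p = 0.3.
print(mask.to(torch.float32).mean().item())    # close to 0.7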