From a1d3afdba9f3baaa502654c2ab52fb86f358e6ec Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Sun, 6 Nov 2022 18:14:05 +0530
Subject: [PATCH] [MLIR][TORCH] Add E2E support for aten.randint.low op

This commit registers the `aten.randint.low` op in the Torch dialect and
decomposes it into `aten.empty.memory_format` + `aten.uniform` +
`aten.to.dtype`, along with shape and dtype inference support and e2e tests.

Signed-Off By: Vivek Khandelwal
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 29 ++++++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 54 ++++++++++++++++++-
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  | 11 ++++
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp |  3 ++
 .../jit_ir/build_tools/shape_lib_gen.py       |  3 ++
 .../jit_ir/build_tools/torch_ods_gen.py       |  1 +
 python/torch_mlir_e2e_test/test_suite/rng.py  | 40 ++++++++++++++
 7 files changed, 140 insertions(+), 1 deletion(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 7e7bcd05c..5ec9b7032 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3334,6 +3334,35 @@ def Torch_AtenBernoulli_FloatOp : Torch_Op<"aten.bernoulli_.float", [
   }];
 }
 
+def Torch_AtenRandintLowOp : Torch_Op<"aten.randint.low", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)`";
+  let arguments = (ins
+    Torch_IntType:$low,
+    Torch_IntType:$high,
+    AnyTorchListOfTorchIntType:$size,
+    AnyTorchOptionalIntType:$dtype,
+    AnyTorchOptionalIntType:$layout,
+    AnyTorchOptionalDeviceType:$device,
+    AnyTorchOptionalBoolType:$pin_memory
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRandintLowOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 7, 1);
+    }
+    void AtenRandintLowOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 7, 1);
+    }
+  }];
+}
+
 def Torch_AtenBernoulliTensorOp : Torch_Op<"aten.bernoulli.Tensor", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 2e9f1748f..21e149b00 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -241,7 +241,7 @@ public:
         rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
     Value startPlusLength =
         rewriter.create<AtenAddIntOp>(loc, one.getType(), start, length);
-    
+
     rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(
         op, op.getResult().getType(), op.self(), /*dim=*/dim, /*start=*/start,
         /*end=*/startPlusLength, /*step=*/one);
@@ -3063,6 +3063,56 @@ public:
 };
 } // namespace
 
+namespace {
+class DecomposeAtenRandintLowOp : public OpRewritePattern<AtenRandintLowOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenRandintLowOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    Type resultType = op.getType();
+    BaseTensorType resultTensorType = resultType.cast<BaseTensorType>();
+
+    int64_t cstLow, cstHigh;
+    if (!matchPattern(op.low(), m_TorchConstantInt(&cstLow)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: low must be a constant integer");
+    if (!matchPattern(op.high(), m_TorchConstantInt(&cstHigh)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: high must be a constant integer");
+
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
+    Value low = rewriter.create<Torch::ConstantFloatOp>(
+        loc, rewriter.getF64FloatAttr((double)cstLow));
+    Value high = rewriter.create<Torch::ConstantFloatOp>(
+        loc, rewriter.getF64FloatAttr((double)cstHigh));
+
+    BaseTensorType floatResultType =
+        resultTensorType
+            .getWithSizesAndDtype(resultTensorType.getSizes(),
+                                  rewriter.getF32Type())
+            .cast<BaseTensorType>();
+    Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
+        loc, floatResultType, op.size(), /*dtype=*/none, /*layout=*/op.layout(),
+        /*device=*/op.device(), /*pin_memory=*/op.pin_memory(),
+        /*memory_format=*/none);
+
+    Value result =
+        rewriter.create<AtenUniformOp>(loc, floatResultType, emptyTensor,
+                                       /*from=*/low,
+                                       /*to=*/high,
+                                       /*generator=*/none);
+    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
+        op, resultType, result,
+        getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()),
+        /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -3264,6 +3314,8 @@ public:
     target.addIllegalOp();
    patterns.add(context);
    target.addIllegalOp();
+    patterns.add<DecomposeAtenRandintLowOp>(context);
+    target.addIllegalOp<AtenRandintLowOp>();
 
     for (std::string opName : legalOps) {
       target.addLegalOp(OperationName(opName, context));
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 333e6fb31..1268caa4d 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -1158,6 +1158,17 @@ void TypeAnalysis::visitOperation(Operation *op,
     return;
   }
 
+  if (auto randIntLow = dyn_cast<AtenRandintLowOp>(op)) {
+    auto knowledge =
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
+    Type defaultDtype =
+        IntegerType::get(op->getContext(), 64, IntegerType::Signed);
+    knowledge.dtype =
+        getDtypeOrDefault(op->getContext(), randIntLow.dtype(), defaultDtype);
+    incorporateKnowledge(randIntLow.getResult(), knowledge);
+    return;
+  }
+
   // Otherwise, this is an unknown operation, so reset the state.
   setAllToEntryStates(results);
   return;
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index ab73458d4..d0455e0bd 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6327,6 +6327,9 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
+"    return %arg2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.arange.start_step\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
 "    %0 = torch.derefine %arg0 : !torch.float to !torch.union<float, int>\n"
 "    %1 = torch.derefine %arg1 : !torch.float to !torch.union<float, int>\n"
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 76e6f3d97..10afa3f39 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -828,6 +828,9 @@ def aten〇cumsum(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
 def aten〇rand_like(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
     return self
 
+def aten〇randint〇low(low: int, high: int, size: List[int], dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
+    return size
+
 def aten〇arange〇start_step(start: float, end: float, step: float = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
     return upstream_shape_functions.arange_start_step(start, end, step, dtype, layout, device, pin_memory)
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index faa694250..08093839f 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -324,6 +324,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
     emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)")
     emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)")
+    emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit_with_mutating_variants("aten::bernoulli.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)")
     emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/rng.py b/python/torch_mlir_e2e_test/test_suite/rng.py
index dcbea55dd..e9a8898f8 100644
--- a/python/torch_mlir_e2e_test/test_suite/rng.py
+++ b/python/torch_mlir_e2e_test/test_suite/rng.py
@@ -276,3 +276,43 @@ class RandLikeDtypeModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: RandLikeDtypeModule())
 def RandLikeDtypeModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1024, 1024).double())
+
+# ==============================================================================
+
+class RandIntLowModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        a = torch.ops.aten.randint(low=1, high=1000, size=[1024, 1024])
+        mean = torch.mean(a.to(torch.float32))
+        return mean
+
+
+@register_test_case(module_factory=lambda: RandIntLowModule())
+def RandIntLowModule_basic(module, tu: TestUtils):
+    module.forward()
+
+# ==============================================================================
+
+class RandIntLowDtypeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        a = torch.ops.aten.randint(low=1, high=1000, size=[128, 256, 512], dtype=torch.float64)
+        mean = torch.mean(a)
+        return mean
+
+
+@register_test_case(module_factory=lambda: RandIntLowDtypeModule())
+def RandIntLowDtypeModule_basic(module, tu: TestUtils):
+    module.forward()
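
Reviewer note (not part of the patch): as a quick sanity check of the new decomposition outside the e2e suite, something along the lines of the sketch below can be used. It assumes the `torch_mlir.compile` Python API available around the time of this change, that an empty `example_args` list is accepted for a zero-argument `forward`, and that the `"torch"` output type string is valid; adjust to the local API if any of these differ.

    # Hypothetical sanity check, not included in this patch. Assumes
    # torch_mlir.compile(model, example_args, output_type=...) as shipped by
    # torch-mlir around this time; adjust if the API differs locally.
    import torch
    import torch_mlir


    class RandIntLow(torch.nn.Module):
        def forward(self):
            # low/high must be compile-time constants for
            # DecomposeAtenRandintLowOp to fire.
            x = torch.ops.aten.randint(low=1, high=1000, size=[4, 4])
            return torch.mean(x.to(torch.float32))


    # Lowering to the "torch" backend contract runs DecomposeComplexOps, so the
    # printed IR should show aten.randint.low rewritten into
    # aten.empty.memory_format + aten.uniform + aten.to.dtype.
    module = torch_mlir.compile(RandIntLow(), [], output_type="torch")
    print(module)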