[MLIR][TORCH] Add decomposition for aten.randn_like op

This commit decomposes aten.randn_like op into aten.randn.generator op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/1776/head snapshot-20230118.722
Vivek Khandelwal 2023-01-16 17:10:21 +05:30
parent 999fd9036b
commit f9d59eb500
9 changed files with 119 additions and 0 deletions

View File

@ -787,4 +787,6 @@ LTC_XFAIL_SET = {
"ElementwisePreluModule_basic",
"VarMeanBiasedModule_basic",
"VarMeanUnbiasedModule_basic",
"RandnLikeModule_basic",
"RandnLikeDtypeModule_basic",
}

View File

@ -3771,6 +3771,34 @@ def Torch_AtenRandnGeneratorOp : Torch_Op<"aten.randn.generator", [
}];
}
def Torch_AtenRandnLikeOp : Torch_Op<"aten.randn_like", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::randn_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalIntType:$dtype,
AnyTorchOptionalIntType:$layout,
AnyTorchOptionalDeviceType:$device,
AnyTorchOptionalBoolType:$pin_memory,
AnyTorchOptionalIntType:$memory_format
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRandnLikeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenRandnLikeOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -6639,6 +6639,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.randn_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg2 : !torch.list<int>\n"
" }\n"

View File

@ -3535,6 +3535,40 @@ public:
};
} // namespace
namespace {
// Decompose `aten.randn_like` op into `aten.randn.generator` op.
class DecomposeAtenRandnLikeOp : public OpRewritePattern<AtenRandnLikeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRandnLikeOp op,
PatternRewriter &rewriter) const override {
// Only `none`, `contiguous` and `preserve` memory_format is supported.
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
int64_t memoryFormat;
if (!matchPattern(op.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat)))
return rewriter.notifyMatchFailure(
op, "unimplemented: the memory format should be specified in "
"an integer constant");
if (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
memoryFormat != torch_upstream::MemoryFormat::Preserve)
return rewriter.notifyMatchFailure(
op, "unimplemented: only none, contiguous and preserve "
"memory_format is supported");
}
Value none = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
auto sizeListType =
Torch::ListType::get(Torch::IntType::get(op.getContext()));
Value sizeList =
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf());
rewriter.replaceOpWithNewOp<AtenRandnGeneratorOp>(
op, op.getType(), sizeList, /*generator=*/none, op.getDtype(),
op.getLayout(), op.getDevice(), op.getPinMemory());
return success();
}
};
} // namespace
namespace {
class DecomposeAtenVarMeanOp : public OpRewritePattern<AtenVarMeanOp> {
public:
@ -3704,6 +3738,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposePrimsSqrtOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);

View File

@ -442,6 +442,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<PrimsSqrtOp>();
target.addIllegalOp<AtenRandnOp>();
target.addIllegalOp<AtenRandnGeneratorOp>();
target.addIllegalOp<AtenRandnLikeOp>();
target.addIllegalOp<AtenVarMeanOp>();
for (std::string opName : backendLegalOps) {
target.addLegalOp(OperationName(opName, context));

View File

@ -1039,6 +1039,9 @@ void TypeAnalysis::visitOperation(Operation *op,
} else if (auto randLike = dyn_cast<AtenRandLikeOp>(op)) {
visitConstantTensorAllocLikeOp<AtenRandLikeOp>(randLike, operands);
return;
} else if (auto randLike = dyn_cast<AtenRandnLikeOp>(op)) {
visitConstantTensorAllocLikeOp<AtenRandnLikeOp>(randLike, operands);
return;
} else if (auto toCopy = dyn_cast<Aten_ToCopyOp>(op)) {
visitConstantTensorAllocLikeOp<Aten_ToCopyOp>(toCopy, operands);
return;

View File

@ -616,6 +616,9 @@ def atencumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None
def atenrand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
return self
def atenrandn_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
return self
def atenrandintlow〡shape(low: int, high: int, size: List[int], dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return size

View File

@ -335,6 +335,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit_with_mutating_variants("aten::bernoulli.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)")
emit("aten::randn : (int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::randn.generator : (int[], Generator?, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::randn_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)

View File

@ -364,3 +364,46 @@ class RandnGeneratorModule(torch.nn.Module):
@register_test_case(module_factory=lambda: RandnGeneratorModule())
def RandnGeneratorModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class RandnLikeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, x):
a = torch.ops.aten.randn_like(x)
std = torch.std(a)
return std
@register_test_case(module_factory=lambda: RandnLikeModule())
def RandnLikeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 512, 1024).double())
# ==============================================================================
class RandnLikeDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float64, True),
])
def forward(self, x):
a = torch.ops.aten.randn_like(x, dtype=torch.float32)
std = torch.std(a)
return std
@register_test_case(module_factory=lambda: RandnLikeDtypeModule())
def RandnLikeDtypeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(256, 1024).double())