From 8a06419980cecda344f092b0ecb1a594ba437d40 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Mon, 2 May 2022 16:50:54 +0530
Subject: [PATCH] [MLIR][TORCH] Add E2E support for aten.masked_fill.Scalar op

This commit adds lowering of the `aten.masked_fill.Scalar` op.
This commit also fixes the formatting of the file constant_alloc.py.

Signed-Off By: Vivek Khandelwal
---
 .../TorchToLinalg/Uncategorized.cpp           |  16 +-
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |   2 +-
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp |   6 +-
 .../jit_ir/build_tools/shape_lib_gen.py       |   3 +
 .../test_suite/constant_alloc.py              | 265 +++++++++++++++++-
 5 files changed, 277 insertions(+), 15 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 0e7ee4f44..801d18600 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -894,6 +894,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
                                         threshold);
     return b.create<arith::SelectOp>(loc, predicate, constantZero, grad);
   }
+  if (auto maskedFill = dyn_cast<AtenMaskedFillScalarOp>(op)) {
+    AtenMaskedFillScalarOp::Adaptor adaptor(operands);
+    Type dtype = converter->convertType(maskedFill.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+
+    Value input = payloadArgs[0];
+    Value mask = payloadArgs[1];
+    Value fillValue = convertScalarToDtype(b, loc, adaptor.value(), dtype);
+
+    return b.create<arith::SelectOp>(loc, mask, fillValue, input);
+  }
 
   op->emitError("unimplemented lowering in "
                 "createLinalgPayloadCalculationForElementwiseOp");
@@ -939,7 +951,7 @@ public:
             AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
             AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
             AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
-            AtenNeScalarOp, AtenNegOp>(op))
+            AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1657,7 +1669,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
       AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
       AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
-      AtenSinOp, AtenCosOp, AtenNeScalarOp>();
+      AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 34b1955df..8fe9f37f9 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -499,7 +499,7 @@ ChangeResult TypeAnalyzer::visitOperation(
           AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op,
           AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
           ValsemVariantAtenCopyOp, ValsemVariantAtenZeroOp,
-          AtenIndexPutHackedTwinOp>(op)) {
+          AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp>(op)) {
     ValueKnowledge knowledge =
         ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
     knowledge.dtype = operands[0]->getValue().dtype;
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index eae7687e0..84e8d3e96 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -1981,6 +1981,10 @@ module {
     %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func @"__torch_mlir_shape_fn.aten.masked_fill.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {
+    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func @"__torch_mlir_shape_fn.aten.zero"(%arg0: !torch.list<int>) -> !torch.list<int> {
     return %arg0 : !torch.list<int>
   }
@@ -2047,7 +2051,7 @@ module {
      torch.prim.If.yield
    }
    %2 = torch.aten.sub.float %arg1, %arg0 : !torch.float, !torch.float -> !torch.float
-    %3 = torch.operator "aten.div.float"(%2, %arg2) : (!torch.float, !torch.float) -> !torch.float
+    %3 = torch.aten.div.float %2, %arg2 : !torch.float, !torch.float -> !torch.float
    %4 = torch.operator "aten.ceil.float"(%3) : (!torch.float) -> !torch.int
    %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>
    return %5 : !torch.list<int>
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index eeb1c38e4..66dceccf8 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -625,6 +625,9 @@ def aten〇new_empty(self: List[int], size: List[int], dtype: Optional[int] = No
 def aten〇_to_copy(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]:
     return upstream_shape_helpers.unary(self)
 
+def aten〇masked_fill〇Scalar(self: List[int], mask: List[int], value: float) -> List[int]:
+    return upstream_shape_helpers.unary(self)
+
 @not_present_in_registry
 def aten〇zero(self: List[int]) -> List[int]:
     return self
diff --git a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
index 61444167a..5b3905d5f 100644
--- a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
+++ b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
@@ -11,7 +11,9 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 
 # ==============================================================================
 
+
 class ZerosModuleDefaultDtype(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -22,12 +24,14 @@ class ZerosModuleDefaultDtype(torch.nn.Module):
     def forward(self):
         return torch.zeros(3, 4)
 
+
 @register_test_case(module_factory=lambda: ZerosModuleDefaultDtype())
 def ZerosModuleDefaultDtype_basic(module, tu: TestUtils):
     module.forward()
 
 
 class ZerosModuleInt2D(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -38,12 +42,14 @@ class ZerosModuleInt2D(torch.nn.Module):
     def forward(self):
         return torch.zeros(3, 4, dtype=torch.int64)
 
+
 @register_test_case(module_factory=lambda: ZerosModuleInt2D())
 def ZerosModuleInt2D_basic(module, tu: TestUtils):
     module.forward()
 
 
 class ZerosModuleInt3D(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -54,12 +60,14 @@ class ZerosModuleInt3D(torch.nn.Module):
     def forward(self):
         return torch.zeros(3, 4, 5, dtype=torch.int64)
 
+
 @register_test_case(module_factory=lambda: ZerosModuleInt3D())
 def ZerosModuleInt3D_basic(module, tu: TestUtils):
     module.forward()
 
 
 class 
ZerosModuleFloat2D(torch.nn.Module): + def __init__(self): super().__init__() @@ -70,12 +78,14 @@ class ZerosModuleFloat2D(torch.nn.Module): def forward(self): return torch.zeros(3, 4, dtype=torch.float32) + @register_test_case(module_factory=lambda: ZerosModuleFloat2D()) def ZerosModuleFloat2D_basic(module, tu: TestUtils): module.forward() class ZerosModuleFloat3D(torch.nn.Module): + def __init__(self): super().__init__() @@ -86,12 +96,14 @@ class ZerosModuleFloat3D(torch.nn.Module): def forward(self): return torch.zeros(3, 4, 5, dtype=torch.float32) + @register_test_case(module_factory=lambda: ZerosModuleFloat3D()) def ZerosModuleFloat3D_basic(module, tu: TestUtils): module.forward() class ZerosModuleFalsePinMemory(torch.nn.Module): + def __init__(self): super().__init__() @@ -102,13 +114,17 @@ class ZerosModuleFalsePinMemory(torch.nn.Module): def forward(self): return torch.zeros(3, 4, dtype=torch.float32, pin_memory=False) + @register_test_case(module_factory=lambda: ZerosModuleFalsePinMemory()) def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils): module.forward() + # ============================================================================== + class OnesModuleDefaultDtype(torch.nn.Module): + def __init__(self): super().__init__() @@ -119,12 +135,14 @@ class OnesModuleDefaultDtype(torch.nn.Module): def forward(self): return torch.ones(3, 4) + @register_test_case(module_factory=lambda: OnesModuleDefaultDtype()) def OnesModuleDefaultDtype_basic(module, tu: TestUtils): module.forward() class OnesModuleInt(torch.nn.Module): + def __init__(self): super().__init__() @@ -135,12 +153,14 @@ class OnesModuleInt(torch.nn.Module): def forward(self): return torch.ones(3, 4, dtype=torch.int64) + @register_test_case(module_factory=lambda: OnesModuleInt()) def OnesModuleInt_basic(module, tu: TestUtils): module.forward() class OnesModuleFloat(torch.nn.Module): + def __init__(self): super().__init__() @@ -151,12 +171,14 @@ class OnesModuleFloat(torch.nn.Module): def forward(self): return torch.ones(3, 4, dtype=torch.float32) + @register_test_case(module_factory=lambda: OnesModuleFloat()) def OnesModuleFloat_basic(module, tu: TestUtils): module.forward() class OnesModuleFalsePinMemory(torch.nn.Module): + def __init__(self): super().__init__() @@ -167,13 +189,17 @@ class OnesModuleFalsePinMemory(torch.nn.Module): def forward(self): return torch.ones(3, 4, dtype=torch.float32, pin_memory=False) + @register_test_case(module_factory=lambda: OnesModuleFalsePinMemory()) def OnesModuleFalsePinMemory_basic(module, tu: TestUtils): module.forward() + # ============================================================================== + class EmptyContiguousModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -185,12 +211,14 @@ class EmptyContiguousModule(torch.nn.Module): return torch.empty((3, 4), memory_format=torch.contiguous_format).fill_(0) + @register_test_case(module_factory=lambda: EmptyContiguousModule()) def EmptyModule_contiguous(module, tu: TestUtils): module.forward() class EmptyDefaultDtypeModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -201,12 +229,14 @@ class EmptyDefaultDtypeModule(torch.nn.Module): def forward(self): return torch.empty((3, 4)).fill_(0) + @register_test_case(module_factory=lambda: EmptyDefaultDtypeModule()) def EmptyModule_defaultDtype(module, tu: TestUtils): module.forward() class EmptyIntModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -217,12 +247,14 @@ class EmptyIntModule(torch.nn.Module): def forward(self): 
return torch.empty((3, 4), dtype=torch.int64).fill_(0) + @register_test_case(module_factory=lambda: EmptyIntModule()) def EmptyModule_int(module, tu: TestUtils): module.forward() class EmptyFloatModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -233,12 +265,14 @@ class EmptyFloatModule(torch.nn.Module): def forward(self): return torch.empty((3, 4), dtype=torch.float32).fill_(0) + @register_test_case(module_factory=lambda: EmptyFloatModule()) def EmptyModule_float(module, tu: TestUtils): module.forward() class EmptyFalsePinMemoryModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -247,16 +281,20 @@ class EmptyFalsePinMemoryModule(torch.nn.Module): None, ]) def forward(self): - return torch.empty((3, 4), dtype=torch.float32, + return torch.empty((3, 4), dtype=torch.float32, pin_memory=False).fill_(0) + @register_test_case(module_factory=lambda: EmptyFalsePinMemoryModule()) def EmptyModule_falsePinMemory(module, tu: TestUtils): module.forward() + # ============================================================================== + class EmptyLikeDefaultDtypeModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -268,12 +306,14 @@ class EmptyLikeDefaultDtypeModule(torch.nn.Module): def forward(self, a): return torch.empty_like(a).fill_(0) + @register_test_case(module_factory=lambda: EmptyLikeDefaultDtypeModule()) def EmptyLikeModule_defaultDtype(module, tu: TestUtils): module.forward(tu.rand(3, 5)) class EmptyLikeIntModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -285,12 +325,14 @@ class EmptyLikeIntModule(torch.nn.Module): def forward(self, a): return torch.empty_like(a, dtype=torch.int32).fill_(0) + @register_test_case(module_factory=lambda: EmptyLikeIntModule()) def EmptyLikeModule_int(module, tu: TestUtils): module.forward(torch.randint(10, (3, 5))) class EmptyLikeFloatModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -302,12 +344,14 @@ class EmptyLikeFloatModule(torch.nn.Module): def forward(self, a): return torch.empty_like(a, dtype=torch.float32).fill_(0) + @register_test_case(module_factory=lambda: EmptyLikeFloatModule()) def EmptyLikeModule_float(module, tu: TestUtils): module.forward(tu.rand(4, 5)) class EmptyLikeFalsePinMemoryModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -320,13 +364,17 @@ class EmptyLikeFalsePinMemoryModule(torch.nn.Module): return torch.empty_like(a, dtype=torch.float64, pin_memory=False).fill_(0) + @register_test_case(module_factory=lambda: EmptyLikeFalsePinMemoryModule()) def EmptyLikeModule_falsePinMemory(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) + # ============================================================================== + class ZerosLikeDefaultDtypeModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -338,12 +386,14 @@ class ZerosLikeDefaultDtypeModule(torch.nn.Module): def forward(self, a): return torch.zeros_like(a) + @register_test_case(module_factory=lambda: ZerosLikeDefaultDtypeModule()) def ZerosLikeModule_defaultDtype(module, tu: TestUtils): module.forward(tu.rand(3, 5)) class ZerosLikeIntModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -355,12 +405,14 @@ class ZerosLikeIntModule(torch.nn.Module): def forward(self, a): return torch.zeros_like(a, dtype=torch.int32) + @register_test_case(module_factory=lambda: ZerosLikeIntModule()) def ZerosLikeModule_int(module, tu: TestUtils): module.forward(torch.randint(10, (3, 5))) class ZerosLikeFloatModule(torch.nn.Module): + def __init__(self): 
super().__init__() @@ -372,12 +424,14 @@ class ZerosLikeFloatModule(torch.nn.Module): def forward(self, a): return torch.zeros_like(a, dtype=torch.float32) + @register_test_case(module_factory=lambda: ZerosLikeFloatModule()) def ZerosLikeModule_float(module, tu: TestUtils): module.forward(tu.rand(4, 5)) class ZerosLikeFalsePinMemoryModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -389,13 +443,17 @@ class ZerosLikeFalsePinMemoryModule(torch.nn.Module): def forward(self, a): return torch.zeros_like(a, dtype=torch.float64, pin_memory=False) + @register_test_case(module_factory=lambda: ZerosLikeFalsePinMemoryModule()) def ZerosLikeModule_falsePinMemory(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) + # ============================================================================== + class OnesLikeDefaultDtypeModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -407,12 +465,14 @@ class OnesLikeDefaultDtypeModule(torch.nn.Module): def forward(self, a): return torch.ones_like(a) + @register_test_case(module_factory=lambda: OnesLikeDefaultDtypeModule()) def OnesLikeModule_defaultDtype(module, tu: TestUtils): module.forward(tu.rand(3, 5)) class OnesLikeIntModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -424,12 +484,14 @@ class OnesLikeIntModule(torch.nn.Module): def forward(self, a): return torch.ones_like(a, dtype=torch.int32) + @register_test_case(module_factory=lambda: OnesLikeIntModule()) def OnesLikeModule_int(module, tu: TestUtils): module.forward(torch.randint(10, (3, 5))) class OnesLikeFloatModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -441,12 +503,14 @@ class OnesLikeFloatModule(torch.nn.Module): def forward(self, a): return torch.ones_like(a, dtype=torch.float32) + @register_test_case(module_factory=lambda: OnesLikeFloatModule()) def OnesLikeModule_float(module, tu: TestUtils): module.forward(tu.rand(4, 5)) class OnesLikeFalsePinMemoryModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -458,13 +522,17 @@ class OnesLikeFalsePinMemoryModule(torch.nn.Module): def forward(self, a): return torch.ones_like(a, dtype=torch.float64, pin_memory=False) + @register_test_case(module_factory=lambda: OnesLikeFalsePinMemoryModule()) def OnesLikeModule_falsePinMemory(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) + # ============================================================================== + class Fill_TensorFloat64WithFloat32(torch.nn.Module): + def __init__(self): super().__init__() @@ -476,12 +544,14 @@ class Fill_TensorFloat64WithFloat32(torch.nn.Module): def forward(self, tensor): return torch.ops.aten.fill_(tensor, 3.0) + @register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat32()) def Fill_TensorFloat64WithFloat32_basic(module, tu: TestUtils): module.forward(torch.randn(3, 2, 4)) class Fill_TensorFloat64WithFloat64(torch.nn.Module): + def __init__(self): super().__init__() @@ -493,12 +563,14 @@ class Fill_TensorFloat64WithFloat64(torch.nn.Module): def forward(self, tensor): return torch.ops.aten.fill_(tensor, 3.0) + @register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat64()) def Fill_TensorFloat64WithFloat64_basic(module, tu: TestUtils): module.forward(torch.randn(3, 2, 4).to(torch.float64)) class Fill_TensorFloat64WithInt64(torch.nn.Module): + def __init__(self): super().__init__() @@ -510,6 +582,7 @@ class Fill_TensorFloat64WithInt64(torch.nn.Module): def forward(self, tensor): return torch.ops.aten.fill_(tensor, 3) + 
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithInt64()) def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils): module.forward(torch.randn(3, 2, 4).to(torch.float64)) @@ -517,7 +590,9 @@ def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils): # ============================================================================== + class NewZerosModuleDefaultDtype(torch.nn.Module): + def __init__(self): super().__init__() @@ -529,12 +604,14 @@ class NewZerosModuleDefaultDtype(torch.nn.Module): def forward(self, a): return torch.ops.aten.new_zeros(a, [3, 4]) + @register_test_case(module_factory=lambda: NewZerosModuleDefaultDtype()) def NewZerosModuleDefaultDtype_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3)) class NewZerosModuleInt2D(torch.nn.Module): + def __init__(self): super().__init__() @@ -546,12 +623,14 @@ class NewZerosModuleInt2D(torch.nn.Module): def forward(self, a): return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.int64) + @register_test_case(module_factory=lambda: NewZerosModuleInt2D()) def NewZerosModuleInt2D_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) class NewZerosModuleInt3D(torch.nn.Module): + def __init__(self): super().__init__() @@ -563,12 +642,14 @@ class NewZerosModuleInt3D(torch.nn.Module): def forward(self, a): return torch.ops.aten.new_zeros(a, [3, 4, 5], dtype=torch.int64) + @register_test_case(module_factory=lambda: NewZerosModuleInt3D()) def NewZerosModuleInt3D_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3)) class NewZerosModuleFloat2D(torch.nn.Module): + def __init__(self): super().__init__() @@ -580,12 +661,14 @@ class NewZerosModuleFloat2D(torch.nn.Module): def forward(self, a): return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.float32) + @register_test_case(module_factory=lambda: NewZerosModuleFloat2D()) def NewZerosModuleFloat2D_basic(module, tu: TestUtils): module.forward(torch.randint(10, (2, 3, 4))) class NewZerosModuleFloat3D(torch.nn.Module): + def __init__(self): super().__init__() @@ -597,12 +680,14 @@ class NewZerosModuleFloat3D(torch.nn.Module): def forward(self, a): return torch.ops.aten.new_zeros(a, [3, 4, 5], dtype=torch.float32) + @register_test_case(module_factory=lambda: NewZerosModuleFloat3D()) def NewZerosModuleFloat3D_basic(module, tu: TestUtils): module.forward(torch.randint(10, (2, 3))) class NewZerosModuleFalsePinMemory(torch.nn.Module): + def __init__(self): super().__init__() @@ -612,15 +697,21 @@ class NewZerosModuleFalsePinMemory(torch.nn.Module): ([-1, -1], torch.int64, True), ]) def forward(self, a): - return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.float32, pin_memory=False) + return torch.ops.aten.new_zeros(a, [3, 4], + dtype=torch.float32, + pin_memory=False) + @register_test_case(module_factory=lambda: NewZerosModuleFalsePinMemory()) def NewZerosModuleFalsePinMemory_basic(module, tu: TestUtils): module.forward(torch.randint(10, (2, 3))) + # ============================================================================== + class NewOnesModuleDefaultDtype(torch.nn.Module): + def __init__(self): super().__init__() @@ -632,12 +723,14 @@ class NewOnesModuleDefaultDtype(torch.nn.Module): def forward(self, a): return torch.ops.aten.new_ones(a, [3, 4]) + @register_test_case(module_factory=lambda: NewOnesModuleDefaultDtype()) def NewOnesModuleDefaultDtype_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3)) class NewOnesModuleInt2D(torch.nn.Module): + def __init__(self): super().__init__() @@ -649,12 +742,14 @@ class 
NewOnesModuleInt2D(torch.nn.Module): def forward(self, a): return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.int64) + @register_test_case(module_factory=lambda: NewOnesModuleInt2D()) def NewOnesModuleInt2D_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) class NewOnesModuleInt3D(torch.nn.Module): + def __init__(self): super().__init__() @@ -666,12 +761,14 @@ class NewOnesModuleInt3D(torch.nn.Module): def forward(self, a): return torch.ops.aten.new_ones(a, [3, 4, 5], dtype=torch.int64) + @register_test_case(module_factory=lambda: NewOnesModuleInt3D()) def NewOnesModuleInt3D_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3)) class NewOnesModuleFloat2D(torch.nn.Module): + def __init__(self): super().__init__() @@ -683,12 +780,14 @@ class NewOnesModuleFloat2D(torch.nn.Module): def forward(self, a): return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.float32) + @register_test_case(module_factory=lambda: NewOnesModuleFloat2D()) def NewOnesModuleFloat2D_basic(module, tu: TestUtils): module.forward(torch.randint(10, (2, 3, 4))) class NewOnesModuleFloat3D(torch.nn.Module): + def __init__(self): super().__init__() @@ -700,12 +799,14 @@ class NewOnesModuleFloat3D(torch.nn.Module): def forward(self, a): return torch.ops.aten.new_ones(a, [3, 4, 5], dtype=torch.float32) + @register_test_case(module_factory=lambda: NewOnesModuleFloat3D()) def NewOnesModuleFloat3D_basic(module, tu: TestUtils): module.forward(torch.randint(10, (2, 3))) class NewOnesModuleFalsePinMemory(torch.nn.Module): + def __init__(self): super().__init__() @@ -715,15 +816,21 @@ class NewOnesModuleFalsePinMemory(torch.nn.Module): ([-1, -1], torch.int64, True), ]) def forward(self, a): - return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.float32, pin_memory=False) + return torch.ops.aten.new_ones(a, [3, 4], + dtype=torch.float32, + pin_memory=False) + @register_test_case(module_factory=lambda: NewOnesModuleFalsePinMemory()) def NewOnesModuleFalsePinMemory_basic(module, tu: TestUtils): module.forward(torch.randint(10, (2, 3))) + # ============================================================================== + class FullModuleDefaultDtype(torch.nn.Module): + def __init__(self): super().__init__() @@ -734,12 +841,14 @@ class FullModuleDefaultDtype(torch.nn.Module): def forward(self): return torch.ops.aten.full([2, 3], 5.0) + @register_test_case(module_factory=lambda: FullModuleDefaultDtype()) def FullModuleDefaultDtype_basic(module, tu: TestUtils): module.forward() class FullModuleInt2D(torch.nn.Module): + def __init__(self): super().__init__() @@ -750,12 +859,14 @@ class FullModuleInt2D(torch.nn.Module): def forward(self): return torch.ops.aten.full([10, 5], 10.5, dtype=torch.int64) + @register_test_case(module_factory=lambda: FullModuleInt2D()) def FullModuleInt2D_basic(module, tu: TestUtils): module.forward() class FullModuleInt3D(torch.nn.Module): + def __init__(self): super().__init__() @@ -766,12 +877,14 @@ class FullModuleInt3D(torch.nn.Module): def forward(self): return torch.ops.aten.full([2, 3, 4], 5) + @register_test_case(module_factory=lambda: FullModuleInt3D()) def FullModuleInt3D_basic(module, tu: TestUtils): module.forward() class FullModuleFloat2D(torch.nn.Module): + def __init__(self): super().__init__() @@ -782,12 +895,14 @@ class FullModuleFloat2D(torch.nn.Module): def forward(self): return torch.ops.aten.full([10, 5], 10, dtype=torch.float32) + @register_test_case(module_factory=lambda: FullModuleFloat2D()) def FullModuleFloat2D_basic(module, tu: TestUtils): module.forward() class 
FullModuleFloat3D(torch.nn.Module): + def __init__(self): super().__init__() @@ -798,12 +913,14 @@ class FullModuleFloat3D(torch.nn.Module): def forward(self): return torch.ops.aten.full([2, 3, 4], 5.0) + @register_test_case(module_factory=lambda: FullModuleFloat3D()) def FullModuleFloat3D_basic(module, tu: TestUtils): module.forward() class FullModuleFalsePinMemory(torch.nn.Module): + def __init__(self): super().__init__() @@ -812,15 +929,22 @@ class FullModuleFalsePinMemory(torch.nn.Module): None, ]) def forward(self): - return torch.ops.aten.full([2, 3], 5.0, dtype=torch.int64, pin_memory=False) + return torch.ops.aten.full([2, 3], + 5.0, + dtype=torch.int64, + pin_memory=False) + @register_test_case(module_factory=lambda: FullModuleFalsePinMemory()) def FullModuleFalsePinMemory_basic(module, tu: TestUtils): module.forward() + # ============================================================================== + class FullLikeModuleDefaultDtype(torch.nn.Module): + def __init__(self): super().__init__() @@ -832,12 +956,14 @@ class FullLikeModuleDefaultDtype(torch.nn.Module): def forward(self, a): return torch.ops.aten.full_like(a, 5) + @register_test_case(module_factory=lambda: FullLikeModuleDefaultDtype()) def FullLikeModuleDefaultDtype_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3)) class FullLikeModuleInt2D(torch.nn.Module): + def __init__(self): super().__init__() @@ -849,12 +975,14 @@ class FullLikeModuleInt2D(torch.nn.Module): def forward(self, a): return torch.ops.aten.full_like(a, 10.5) + @register_test_case(module_factory=lambda: FullLikeModuleInt2D()) def FullLikeModuleInt2D_basic(module, tu: TestUtils): module.forward(torch.randint(10, (4, 5))) class FullLikeModuleInt3D(torch.nn.Module): + def __init__(self): super().__init__() @@ -866,12 +994,14 @@ class FullLikeModuleInt3D(torch.nn.Module): def forward(self, a): return torch.ops.aten.full_like(a, 5.0, dtype=torch.int64) + @register_test_case(module_factory=lambda: FullLikeModuleInt3D()) def FullLikeModuleInt3D_basic(module, tu: TestUtils): module.forward(torch.randint(100, (10, 4, 5)).to(torch.int32)) class FullLikeModuleInt2DStatic(torch.nn.Module): + def __init__(self): super().__init__() @@ -883,12 +1013,14 @@ class FullLikeModuleInt2DStatic(torch.nn.Module): def forward(self, a): return torch.ops.aten.full_like(a, 10) + @register_test_case(module_factory=lambda: FullLikeModuleInt2DStatic()) def FullLikeModuleInt2DStatic_basic(module, tu: TestUtils): module.forward(torch.randint(10, (4, 5))) class FullLikeModuleFloat2D(torch.nn.Module): + def __init__(self): super().__init__() @@ -900,12 +1032,14 @@ class FullLikeModuleFloat2D(torch.nn.Module): def forward(self, a): return torch.ops.aten.full_like(a, 10) + @register_test_case(module_factory=lambda: FullLikeModuleFloat2D()) def FullLikeModuleFloat2D_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) class FullLikeModuleFloat3D(torch.nn.Module): + def __init__(self): super().__init__() @@ -917,12 +1051,14 @@ class FullLikeModuleFloat3D(torch.nn.Module): def forward(self, a): return torch.ops.aten.full_like(a, 15, dtype=torch.float32) + @register_test_case(module_factory=lambda: FullLikeModuleFloat3D()) def FullLikeModuleFloat3D_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.float64)) class FullLikeModuleFloat3DStatic(torch.nn.Module): + def __init__(self): super().__init__() @@ -934,12 +1070,14 @@ class FullLikeModuleFloat3DStatic(torch.nn.Module): def forward(self, a): return torch.ops.aten.full_like(a, 15.3, dtype=torch.float32) + 
@register_test_case(module_factory=lambda: FullLikeModuleFloat3DStatic()) def FullLikeModuleFloat3DStatic_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).to(torch.float64)) class FullLikeModuleFalsePinMemory(torch.nn.Module): + def __init__(self): super().__init__() @@ -949,15 +1087,22 @@ class FullLikeModuleFalsePinMemory(torch.nn.Module): ([-1, -1], torch.int64, True), ]) def forward(self, a): - return torch.ops.aten.full_like(a, 5, dtype=torch.int64, pin_memory=False) + return torch.ops.aten.full_like(a, + 5, + dtype=torch.int64, + pin_memory=False) + @register_test_case(module_factory=lambda: FullLikeModuleFalsePinMemory()) def FullLikeModuleFalsePinMemory_basic(module, tu: TestUtils): module.forward(torch.randint(100, (10, 4))) + # ============================================================================== + class ZeroFloat32Module(torch.nn.Module): + def __init__(self): super().__init__() @@ -969,12 +1114,14 @@ class ZeroFloat32Module(torch.nn.Module): def forward(self, tensor): return torch.ops.aten.zero_(tensor) + @register_test_case(module_factory=lambda: ZeroFloat32Module()) def ZeroFloat32Module_basic(module, tu: TestUtils): module.forward(torch.rand(3, 2)) class ZeroInt32Module(torch.nn.Module): + def __init__(self): super().__init__() @@ -986,12 +1133,14 @@ class ZeroInt32Module(torch.nn.Module): def forward(self, tensor): return torch.ops.aten.zero_(tensor) + @register_test_case(module_factory=lambda: ZeroInt32Module()) def ZeroInt32Module_basic(module, tu: TestUtils): module.forward(torch.randint(100, (10, 4), dtype=torch.int32)) class ZeroInt64Module(torch.nn.Module): + def __init__(self): super().__init__() @@ -1003,13 +1152,17 @@ class ZeroInt64Module(torch.nn.Module): def forward(self, tensor): return torch.ops.aten.zero_(tensor) + @register_test_case(module_factory=lambda: ZeroInt64Module()) def ZeroInt64Module_basic(module, tu: TestUtils): module.forward(torch.randint(100, (10, 4))) + # ============================================================================== + class NewEmptyModuleDefaultDtype(torch.nn.Module): + def __init__(self): super().__init__() @@ -1021,12 +1174,14 @@ class NewEmptyModuleDefaultDtype(torch.nn.Module): def forward(self, a): return torch.ops.aten.new_empty(a, [3, 4]).fill_(0) + @register_test_case(module_factory=lambda: NewEmptyModuleDefaultDtype()) def NewEmptyModuleDefaultDtype_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3)) class NewEmptyModuleInt2D(torch.nn.Module): + def __init__(self): super().__init__() @@ -1038,12 +1193,14 @@ class NewEmptyModuleInt2D(torch.nn.Module): def forward(self, a): return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.int64).fill_(0) + @register_test_case(module_factory=lambda: NewEmptyModuleInt2D()) def NewEmptyModuleInt2D_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) class NewEmptyModuleInt3D(torch.nn.Module): + def __init__(self): super().__init__() @@ -1053,7 +1210,9 @@ class NewEmptyModuleInt3D(torch.nn.Module): ([-1, -1], torch.float32, True), ]) def forward(self, a): - return torch.ops.aten.new_empty(a, [3, 4, 5], dtype=torch.int64).fill_(0) + return torch.ops.aten.new_empty(a, [3, 4, 5], + dtype=torch.int64).fill_(0) + @register_test_case(module_factory=lambda: NewEmptyModuleInt3D()) def NewEmptyModuleInt3D_basic(module, tu: TestUtils): @@ -1061,6 +1220,7 @@ def NewEmptyModuleInt3D_basic(module, tu: TestUtils): class NewEmptyModuleFloat2D(torch.nn.Module): + def __init__(self): super().__init__() @@ -1070,7 +1230,9 @@ class 
NewEmptyModuleFloat2D(torch.nn.Module): ([-1, -1, -1], torch.int64, True), ]) def forward(self, a): - return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.float32).fill_(0) + return torch.ops.aten.new_empty(a, [3, 4], + dtype=torch.float32).fill_(0) + @register_test_case(module_factory=lambda: NewEmptyModuleFloat2D()) def NewEmptyModuleFloat2D_basic(module, tu: TestUtils): @@ -1078,6 +1240,7 @@ def NewEmptyModuleFloat2D_basic(module, tu: TestUtils): class NewEmptyModuleFloat3D(torch.nn.Module): + def __init__(self): super().__init__() @@ -1087,7 +1250,9 @@ class NewEmptyModuleFloat3D(torch.nn.Module): ([-1, -1], torch.int64, True), ]) def forward(self, a): - return torch.ops.aten.new_empty(a, [3, 4, 5], dtype=torch.float32).fill_(0) + return torch.ops.aten.new_empty(a, [3, 4, 5], + dtype=torch.float32).fill_(0) + @register_test_case(module_factory=lambda: NewEmptyModuleFloat3D()) def NewEmptyModuleFloat3D_basic(module, tu: TestUtils): @@ -1095,6 +1260,7 @@ def NewEmptyModuleFloat3D_basic(module, tu: TestUtils): class NewEmptyModuleFalsePinMemory(torch.nn.Module): + def __init__(self): super().__init__() @@ -1104,7 +1270,10 @@ class NewEmptyModuleFalsePinMemory(torch.nn.Module): ([-1, -1], torch.int64, True), ]) def forward(self, a): - return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.float32, pin_memory=False).fill_(0) + return torch.ops.aten.new_empty(a, [3, 4], + dtype=torch.float32, + pin_memory=False).fill_(0) + @register_test_case(module_factory=lambda: NewEmptyModuleFalsePinMemory()) def NewEmptyModuleFalsePinMemory_basic(module, tu: TestUtils): @@ -1112,6 +1281,7 @@ def NewEmptyModuleFalsePinMemory_basic(module, tu: TestUtils): class NewEmptyModuleNonDefaultFloatDtype(torch.nn.Module): + def __init__(self): super().__init__() @@ -1123,12 +1293,15 @@ class NewEmptyModuleNonDefaultFloatDtype(torch.nn.Module): def forward(self, a): return torch.ops.aten.new_empty(a, [3, 4]).fill_(0) -@register_test_case(module_factory=lambda: NewEmptyModuleNonDefaultFloatDtype()) + +@register_test_case( + module_factory=lambda: NewEmptyModuleNonDefaultFloatDtype()) def NewEmptyModuleNonDefaultFloatDtype_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3).to(torch.float64)) class NewEmptyModuleNonDefaultIntDtype(torch.nn.Module): + def __init__(self): super().__init__() @@ -1140,11 +1313,14 @@ class NewEmptyModuleNonDefaultIntDtype(torch.nn.Module): def forward(self, a): return torch.ops.aten.new_empty(a, [3, 4]).fill_(0) + @register_test_case(module_factory=lambda: NewEmptyModuleNonDefaultIntDtype()) def NewEmptyModuleNonDefaultIntDtype_basic(module, tu: TestUtils): module.forward(torch.randint(10, (2, 3)).to(torch.int32)) + class NewEmptyModuleLayoutIntDtype(torch.nn.Module): + def __init__(self): super().__init__() @@ -1154,8 +1330,75 @@ class NewEmptyModuleLayoutIntDtype(torch.nn.Module): ([-1, -1], torch.int32, True), ]) def forward(self, a): - return torch.ops.aten.new_empty(a, [3, 4], layout = 0).fill_(0) + return torch.ops.aten.new_empty(a, [3, 4], layout=0).fill_(0) + @register_test_case(module_factory=lambda: NewEmptyModuleLayoutIntDtype()) def NewEmptyModuleLayoutIntDtype_basic(module, tu: TestUtils): module.forward(torch.randint(10, (2, 3)).to(torch.int32)) + + +# ============================================================================== + + +class MaskedFillScalarDefaultModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.bool, True), + ]) + def forward(self, x, 
mask): + return torch.ops.aten.masked_fill(x, mask, value=0.5) + + +@register_test_case(module_factory=lambda: MaskedFillScalarDefaultModule()) +def MaskedFillScalarDefaultModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3), + torch.randint(0, 2, (2, 3)).to(dtype=torch.bool)) + + +class MaskedFillScalarIntValueModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.bool, True), + ]) + def forward(self, x, mask): + return torch.ops.aten.masked_fill(x, mask, value=5) + + +@register_test_case(module_factory=lambda: MaskedFillScalarIntValueModule()) +def MaskedFillScalarIntValueModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3), + torch.randint(0, 2, (2, 3)).to(dtype=torch.bool)) + + +class MaskedFillScalarFloatValueModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.bool, True), + ]) + def forward(self, x, mask): + return torch.ops.aten.masked_fill(x, mask, value=-0.01) + + +@register_test_case(module_factory=lambda: MaskedFillScalarFloatValueModule()) +def MaskedFillScalarFloatValueModule_basic(module, tu: TestUtils): + module.forward(torch.randint(-10, 10, (2, 3)), + torch.randint(0, 2, (2, 3)).to(dtype=torch.bool))
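

For context beyond the diff itself: the new linalg payload lowers `aten.masked_fill.Scalar` to an elementwise select, picking the scalar fill value (converted to the result dtype) wherever the mask is true and keeping the original input elsewhere. Below is a minimal PyTorch-level sketch of those semantics, illustrative only and not part of the patch; the `torch.where` comparison is an added reference, not something the test suite uses.

    import torch

    # masked_fill keeps input elements where the mask is False and substitutes
    # the scalar where the mask is True -- the same elementwise select the new
    # payload emits.
    x = torch.rand(2, 3)
    mask = torch.randint(0, 2, (2, 3)).to(dtype=torch.bool)

    filled = torch.ops.aten.masked_fill(x, mask, value=0.5)
    reference = torch.where(mask, torch.tensor(0.5), x)
    assert torch.equal(filled, reference)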