[MLIR][TORCH] add E2E support for aten.new_full (#2425)

* implement aten.new_full

* remove extraneous tests
Arham Khan 2023-09-12 09:29:08 -05:00 committed by GitHub
parent 23b72244b1
commit 82456eefed
8 changed files with 221 additions and 0 deletions
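
For reference, `aten.new_full` builds a tensor of the requested shape filled with a scalar, inheriting dtype and device from the input tensor unless an explicit dtype is passed. A minimal eager-mode sketch of the semantics this change targets:

import torch

# The input tensor supplies the default dtype (and device) for the result.
self_tensor = torch.zeros(2, 3, dtype=torch.float32)

# Shape comes from `size`, value from `fill_value`.
filled = self_tensor.new_full((3, 4), 5)
assert filled.shape == (3, 4)
assert filled.dtype == torch.float32  # inherited from self_tensor
assert bool((filled == 5).all())

# An explicit dtype overrides the inherited one.
assert self_tensor.new_full((3, 4), 5, dtype=torch.int64).dtype == torch.int64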


@@ -561,6 +561,14 @@ STABLEHLO_PASS_SET = {
"FullModuleFloat3D_basic",
"FullModuleInt2D_basic",
"FullModuleInt3D_basic",
"NewFullModuleDefaultDtype_basic",
"NewFullModuleFalsePinMemory_basic",
"NewFullModuleFloat2D_basic",
"NewFullModuleFloat3DStatic_basic",
"NewFullModuleFloat3D_basic",
"NewFullModuleInt2DStatic_basic",
"NewFullModuleInt2D_basic",
"NewFullModuleInt3D_basic",
"GatherStaticModule_basic",
"GatherModule_basic",
"Gather2DInputModdule_basic",
@@ -1149,6 +1157,12 @@ TOSA_PASS_SET = {
"FullLikeModuleFloat3DStatic_basic",
"FullModuleDefaultDtype_basic",
"FullModuleFloat3D_basic",
"NewFullModuleDefaultDtype_basic",
"NewFullModuleFalsePinMemory_basic",
"NewFullModuleFloat2D_basic",
"NewFullModuleFloat3DStatic_basic",
"NewFullModuleFloat3D_basic",
"NewFullModuleInt2DStatic_basic",
"MaskedFillScalarDefaultModule_basic",
"NumToTensorFloatModule_basic",
"LiftFreshCopyModule_basic",


@@ -9853,6 +9853,35 @@ def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [
}];
}
def Torch_AtenNewFullOp : Torch_Op<"aten.new_full", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$size,
AnyTorchScalarType:$fill_value,
AnyTorchOptionalIntType:$dtype,
AnyTorchOptionalIntType:$layout,
AnyTorchOptionalDeviceType:$device,
AnyTorchOptionalBoolType:$pin_memory
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNewFullOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenNewFullOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}
def Torch_AtenBaddbmmOp : Torch_Op<"aten.baddbmm", [
AllowsTypeRefinement,
HasValueSemantics,


@@ -7226,6 +7226,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.full_like\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.new_full\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.zeros_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -10542,6 +10545,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.new_full\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.number, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
" torch.prim.If.yield %0#1 : !torch.int\n"
" } else {\n"
" %3 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %3 : !torch.int\n"
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.new_zeros\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"


@@ -3166,6 +3166,33 @@ public:
};
} // namespace
namespace {
// Decompose `aten.new_full` op into `aten.full` op.
class DecomposeAtenNewFullOp : public OpRewritePattern<AtenNewFullOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNewFullOp op,
PatternRewriter &rewriter) const override {
Value dtype = op.getDtype();
if (dtype.getType().isa<Torch::NoneType>()) {
BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
if (!tensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected input tensor to have a dtype");
}
dtype =
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
}
rewriter.replaceOpWithNewOp<AtenFullOp>(
op, op.getType(), op.getSize(), op.getFillValue(), dtype, op.getLayout(), op.getDevice(),
op.getPinMemory());
return success();
}
};
} // namespace
namespace {
// Decompose `aten.indexPut` op into `valsem.aten.indexPutImpl` op.
class DecomposeAtenIndexPutOp : public OpRewritePattern<AtenIndexPutOp> {
@@ -5177,6 +5204,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenLinearOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewFullOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandAsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
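
The decomposition above lowers `aten.new_full` to `aten.full`, filling in the dtype from the input tensor when the caller leaves it as None. A tensor-level sketch of the same rewrite (the helper name `new_full_via_full` is made up for illustration):

import torch
from typing import Optional, Sequence

def new_full_via_full(self_tensor: torch.Tensor,
                      size: Sequence[int],
                      fill_value: float,
                      dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    # Mirrors DecomposeAtenNewFullOp: when no dtype is given, reuse the
    # input tensor's dtype, then build the result with a plain full.
    if dtype is None:
        dtype = self_tensor.dtype
    return torch.full(tuple(size), fill_value, dtype=dtype)

x = torch.ones(2, 2, dtype=torch.int64)
assert torch.equal(new_full_via_full(x, (3, 4), 7), x.new_full((3, 4), 7))

Reusing `aten.full` means no new backend lowerings are needed, which is why the only backend-facing change in this commit is adding the new test names to the StableHLO and TOSA pass sets.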


@@ -437,6 +437,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenLinearOp>();
target.addIllegalOp<AtenMishOp>();
target.addIllegalOp<AtenFullLikeOp>();
target.addIllegalOp<AtenNewFullOp>();
target.addIllegalOp<AtenIndexPutOp>();
target.addIllegalOp<AtenExpandAsOp>();
target.addIllegalOp<Aten_ToCopyOp>();


@@ -650,6 +650,9 @@ def aten〇full〡shape(size: List[int], fill_value: float, dtype: Optional[int]
def aten〇full_like〡shape(self: List[int], fill_value: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
return self
def aten〇new_full〡shape(self: List[int], size: List[int], fill_value: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return size
def aten〇zeros_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
return upstream_shape_functions.unary(self)
@@ -3244,6 +3247,16 @@ def aten〇full_like〡dtype(self_rank_dtype: Tuple[int, int], fill_value: Union
self_rank, self_dtype = self_rank_dtype
return self_dtype if dtype is None else dtype
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0.0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0.0, dtype=torch.float16) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0.0, dtype=torch.int32) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0.0, dtype=torch.complex64))
def aten〇new_full〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], fill_value: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype if dtype is None else dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.float16) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.int32) +
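
The shape and dtype rules registered here are simple: the result shape is exactly `size`, and the result dtype is the input's dtype unless an explicit dtype is given. A quick eager-mode check mirroring the `@check_dtype_function` cases above:

import torch

# Shape rule: the result shape is exactly `size`, independent of self.
assert torch.rand(2, 3).new_full((1,), 0.0).shape == (1,)

# Dtype rule: no explicit dtype inherits from self; an explicit dtype wins.
assert torch.rand(1).new_full((1,), 0).dtype == torch.float32
assert torch.rand(1).new_full((1,), 0.0, dtype=torch.int32).dtype == torch.int32
assert torch.rand(1).new_full((1,), 0.0, dtype=torch.complex64).dtype == torch.complex64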


@@ -599,6 +599,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::numpy_T : (Tensor) -> (Tensor)")
emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")


@@ -1093,6 +1093,126 @@ class FullLikeModuleFalsePinMemory(torch.nn.Module):
def FullLikeModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward(tu.randint(10, 4, high=100))
# ==============================================================================
class NewFullModuleDefaultDtype(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.new_full(a, (3,4), 5)
@register_test_case(module_factory=lambda: NewFullModuleDefaultDtype())
def NewFullModuleDefaultDtype_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3))
class NewFullModuleInt2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.new_full(a, (3,4), 10.5)
@register_test_case(module_factory=lambda: NewFullModuleInt2D())
def NewFullModuleInt2D_basic(module, tu: TestUtils):
module.forward(tu.randint(4, 5, high=10))
class NewFullModuleInt3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.int32, True),
])
def forward(self, a):
return torch.ops.aten.new_full(a, (3,4), 5.0, dtype=torch.int64)
@register_test_case(module_factory=lambda: NewFullModuleInt3D())
def NewFullModuleInt3D_basic(module, tu: TestUtils):
module.forward(tu.randint(10, 4, 5, high=100).to(torch.int32))
class NewFullModuleFloat3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, a):
return torch.ops.aten.new_full(a, (3,4), 15, dtype=torch.float32)
@register_test_case(module_factory=lambda: NewFullModuleFloat3D())
def NewFullModuleFloat3D_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
class NewFullModuleFloat3DStatic(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4, 5], torch.float64, True),
])
def forward(self, a):
return torch.ops.aten.new_full(a, (3,4), 15.3, dtype=torch.float32)
@register_test_case(module_factory=lambda: NewFullModuleFloat3DStatic())
def NewFullModuleFloat3DStatic_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
class NewFullModuleFalsePinMemory(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.new_full(a,
(3,4),
5,
dtype=torch.int64,
pin_memory=False)
@register_test_case(module_factory=lambda: NewFullModuleFalsePinMemory())
def NewFullModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward(tu.randint(10, 4, high=100))
# ==============================================================================
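
The new test modules mostly differ in how the fill value's type interacts with the result dtype. A quick eager-mode check of the two less obvious combinations, outside the e2e harness:

import torch

# NewFullModuleInt2D: float fill value, int64 input, no explicit dtype.
# The result keeps the inherited int64 dtype and the value is cast to it.
a = torch.randint(0, 10, (4, 5))
assert a.new_full((3, 4), 10.5).dtype == torch.int64

# NewFullModuleFloat3D: integer fill value with an explicit dtype override,
# giving a float32 result from a float64 input.
b = torch.rand(3, 4, 5, dtype=torch.float64)
assert b.new_full((3, 4), 15, dtype=torch.float32).dtype == torch.float32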