mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] add E2E support for aten.new_full (#2425)
* implement aten.new_full * remove extraneous testspull/2458/head
parent
23b72244b1
commit
82456eefed
|
@ -561,6 +561,14 @@ STABLEHLO_PASS_SET = {
|
|||
"FullModuleFloat3D_basic",
|
||||
"FullModuleInt2D_basic",
|
||||
"FullModuleInt3D_basic",
|
||||
"NewFullModuleDefaultDtype_basic",
|
||||
"NewFullModuleFalsePinMemory_basic",
|
||||
"NewFullModuleFloat2D_basic",
|
||||
"NewFullModuleFloat3DStatic_basic",
|
||||
"NewFullModuleFloat3D_basic",
|
||||
"NewFullModuleInt2DStatic_basic",
|
||||
"NewFullModuleInt2D_basic",
|
||||
"NewFullModuleInt3D_basic",
|
||||
"GatherStaticModule_basic",
|
||||
"GatherModule_basic",
|
||||
"Gather2DInputModdule_basic",
|
||||
|
@ -1149,6 +1157,12 @@ TOSA_PASS_SET = {
|
|||
"FullLikeModuleFloat3DStatic_basic",
|
||||
"FullModuleDefaultDtype_basic",
|
||||
"FullModuleFloat3D_basic",
|
||||
"NewFullModuleDefaultDtype_basic",
|
||||
"NewFullModuleFalsePinMemory_basic",
|
||||
"NewFullModuleFloat2D_basic",
|
||||
"NewFullModuleFloat3DStatic_basic",
|
||||
"NewFullModuleFloat3D_basic",
|
||||
"NewFullModuleInt2DStatic_basic",
|
||||
"MaskedFillScalarDefaultModule_basic",
|
||||
"NumToTensorFloatModule_basic",
|
||||
"LiftFreshCopyModule_basic",
|
||||
|
|
|
@ -9853,6 +9853,35 @@ def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenNewFullOp : Torch_Op<"aten.new_full", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$size,
|
||||
AnyTorchScalarType:$fill_value,
|
||||
AnyTorchOptionalIntType:$dtype,
|
||||
AnyTorchOptionalIntType:$layout,
|
||||
AnyTorchOptionalDeviceType:$device,
|
||||
AnyTorchOptionalBoolType:$pin_memory
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenNewFullOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 7, 1);
|
||||
}
|
||||
void AtenNewFullOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 7, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenBaddbmmOp : Torch_Op<"aten.baddbmm", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -7226,6 +7226,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_shape_fn.aten.full_like\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.new_full\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" return %arg1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.zeros_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -10542,6 +10545,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.new_full\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.number, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %0#1 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %3 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
|
||||
" torch.prim.If.yield %3 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.new_zeros\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
|
|
@ -3166,6 +3166,33 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.new_full` op into `aten.full` op.
|
||||
class DecomposeAtenNewFullOp : public OpRewritePattern<AtenNewFullOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenNewFullOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value dtype = op.getDtype();
|
||||
if (dtype.getType().isa<Torch::NoneType>()) {
|
||||
BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
|
||||
if (!tensorType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected input tensor to have a dtype");
|
||||
}
|
||||
dtype =
|
||||
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<AtenFullOp>(
|
||||
op, op.getType(), op.getSize(), op.getFillValue(), dtype, op.getLayout(), op.getDevice(),
|
||||
op.getPinMemory());
|
||||
|
||||
return success();
|
||||
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.indexPut` op into `valsem.aten.indexPutImpl` op.
|
||||
class DecomposeAtenIndexPutOp : public OpRewritePattern<AtenIndexPutOp> {
|
||||
|
@ -5177,6 +5204,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenLinearOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNewFullOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandAsOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
|
||||
|
|
|
@ -437,6 +437,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenLinearOp>();
|
||||
target.addIllegalOp<AtenMishOp>();
|
||||
target.addIllegalOp<AtenFullLikeOp>();
|
||||
target.addIllegalOp<AtenNewFullOp>();
|
||||
target.addIllegalOp<AtenIndexPutOp>();
|
||||
target.addIllegalOp<AtenExpandAsOp>();
|
||||
target.addIllegalOp<Aten_ToCopyOp>();
|
||||
|
|
|
@ -650,6 +650,9 @@ def aten〇full〡shape(size: List[int], fill_value: float, dtype: Optional[int]
|
|||
def aten〇full_like〡shape(self: List[int], fill_value: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
|
||||
return self
|
||||
|
||||
def aten〇new_full〡shape(self: List[int], size: List[int], fill_value: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
return size
|
||||
|
||||
def aten〇zeros_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
@ -3244,6 +3247,16 @@ def aten〇full_like〡dtype(self_rank_dtype: Tuple[int, int], fill_value: Union
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype if dtype is None else dtype
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0.0) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0.0, dtype=torch.float16) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0.0, dtype=torch.int32) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0.0, dtype=torch.complex64))
|
||||
def aten〇new_full〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], fill_value: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype if dtype is None else dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.float16) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.int32) +
|
||||
|
|
|
@ -599,6 +599,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::numpy_T : (Tensor) -> (Tensor)")
|
||||
emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
|
||||
emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
|
||||
emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||
|
|
|
@ -1093,6 +1093,126 @@ class FullLikeModuleFalsePinMemory(torch.nn.Module):
|
|||
def FullLikeModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(10, 4, high=100))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NewFullModuleDefaultDtype(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_full(a, (3,4), 5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NewFullModuleDefaultDtype())
|
||||
def NewFullModuleDefaultDtype_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3))
|
||||
|
||||
|
||||
class NewFullModuleInt2D(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_full(a, (3,4), 10.5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NewFullModuleInt2D())
|
||||
def NewFullModuleInt2D_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, 5, high=10))
|
||||
|
||||
|
||||
class NewFullModuleInt3D(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_full(a, (3,4), 5.0, dtype=torch.int64)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NewFullModuleInt3D())
|
||||
def NewFullModuleInt3D_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(10, 4, 5, high=100).to(torch.int32))
|
||||
|
||||
|
||||
class NewFullModuleFloat3D(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_full(a, (3,4), 15, dtype=torch.float32)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NewFullModuleFloat3D())
|
||||
def NewFullModuleFloat3D_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||
|
||||
|
||||
class NewFullModuleFloat3DStatic(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4, 5], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_full(a, (3,4), 15.3, dtype=torch.float32)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NewFullModuleFloat3DStatic())
|
||||
def NewFullModuleFloat3DStatic_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||
|
||||
|
||||
class NewFullModuleFalsePinMemory(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_full(a,
|
||||
(3,4),
|
||||
5,
|
||||
dtype=torch.int64,
|
||||
pin_memory=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NewFullModuleFalsePinMemory())
|
||||
def NewFullModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(10, 4, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
Loading…
Reference in New Issue