[Torch Dialect] support decomposition of aten.linspace (#3006)

pull/3026/head
Yuanqiang Liu 2024-03-14 08:28:33 +08:00 committed by GitHub
parent 43c6996a31
commit 870e63bc3c
6 changed files with 193 additions and 0 deletions
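In eager terms, the DecomposeAtenLinspaceOp pattern added below rewrites aten.linspace into an arange followed by a scalar multiply/add and a final dtype cast. A minimal PyTorch sketch of that decomposition, for illustration only (linspace_decomposed is a hypothetical helper name, not code from this patch):

import torch

def linspace_decomposed(start, end, steps, dtype=None):
    if steps == 1:
        # steps == 1 is special-cased: arange(0, 1) + start yields [start]
        result = torch.arange(0, steps) + start
    else:
        # [0, ..., steps - 1] * ((end - start) / (steps - 1)) + start
        result = torch.arange(0, steps) * ((end - start) / (steps - 1)) + start
    # cast to the requested dtype; like the pattern, default to float32
    return result.to(dtype if dtype is not None else torch.float32)

# quick sanity check against eager linspace
assert torch.allclose(linspace_decomposed(-10.1, 10.1, 10), torch.linspace(-10.1, 10.1, 10))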


@@ -8334,6 +8334,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %5 = call @__torch__.torch.jit._shape_functions.arange_end(%0, %1, %2, %3, %4) : (!torch.union<float, int>, !torch.any, !torch.any, !torch.any, !torch.any) -> !torch.list<int>\n"
" return %5 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.linspace\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" %0 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.add.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.add.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -12568,6 +12572,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" return %1 : !torch.int\n" " return %1 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.linspace\"(%arg0: !torch.number, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.normal_functional\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.normal_functional\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n" " %str = torch.constant.str \"AssertionError: \"\n"


@@ -6331,6 +6331,78 @@ public:
};
} // namespace
namespace {
class DecomposeAtenLinspaceOp : public OpRewritePattern<AtenLinspaceOp> {
public:
using OpRewritePattern<AtenLinspaceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenLinspaceOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
MLIRContext *context = getContext();
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
Value none = rewriter.create<ConstantNoneOp>(loc);
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value addStart;
int64_t steps;
if (matchPattern(op.getSteps(), m_TorchConstantInt(&steps)) && steps == 1) {
// specifically handle steps == 1
Value arange = rewriter.create<AtenArangeStartOp>(
loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(),
op.getDevice(), op.getPinMemory());
addStart = rewriter.create<AtenAddScalarOp>(loc, baseType, arange,
op.getStart(), one);
} else {
// handle steps != 1 or dynamic steps
Value neOrNot = rewriter.create<AtenNeIntOp>(loc, op.getSteps(), one);
rewriter.create<RuntimeAssertOp>(
loc, neOrNot,
rewriter.getStringAttr("linspace's dynamic steps must not be 1"));
// create arange: [0, ..., steps - 1]
Value arange = rewriter.create<AtenArangeStartOp>(
loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(),
op.getDevice(), op.getPinMemory());
// calculate (end - start) / (steps - 1)
Value sub;
if (op.getEnd().getType().isa<Torch::FloatType>() ||
op.getStart().getType().isa<Torch::FloatType>()) {
sub = rewriter.create<AtenSubOp>(loc, Torch::FloatType::get(context),
op.getEnd(), op.getStart());
} else {
sub = rewriter.create<AtenSubIntOp>(loc, op.getEnd(), op.getStart());
}
Value div = rewriter.create<AtenDivOp>(
loc, sub, rewriter.create<AtenSubIntOp>(loc, op.getSteps(), one));
// calculate [0, ..., steps - 1] * ((end - start) / (steps - 1)) + start
Value mulScalar =
rewriter.create<AtenMulScalarOp>(loc, baseType, arange, div);
addStart = rewriter.create<AtenAddScalarOp>(loc, baseType, mulScalar,
op.getStart(), one);
}
// to dtype
Value result;
if (!op.getDtype().getType().isa<Torch::NoneType>()) {
result = rewriter.create<AtenToDtypeOp>(
loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal,
/*copy=*/falseVal, /*memory_format=*/none);
} else {
Value f32Type = rewriter.create<ConstantIntOp>(
loc, (int)torch_upstream::ScalarType::Float);
result = rewriter.create<AtenToDtypeOp>(
loc, op.getType(), addStart, f32Type, /*non_blocking=*/falseVal,
/*copy=*/falseVal, /*memory_format=*/none);
}
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
namespace {
class DecomposeAtenVarMeanOp : public OpRewritePattern<AtenVarMeanOp> {
public:
@@ -7216,6 +7288,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinspaceOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<


@@ -424,6 +424,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenConvTranspose2dInputOp>();
target.addIllegalOp<AtenArangeOp>();
target.addIllegalOp<AtenArangeStartOp>();
target.addIllegalOp<AtenLinspaceOp>();
target.addIllegalOp<AtenArgmaxOp>();
target.addIllegalOp<AtenArgminOp>();
target.addIllegalOp<AtenSquareOp>();


@@ -841,6 +841,11 @@ STABLEHLO_PASS_SET = {
"ZerosModuleFloat3D_basic",
"ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic",
"LinspaceDtypeModule_basic",
"LinspaceEmptyModule_basic",
"LinspaceModule_basic",
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
}
STABLEHLO_CRASHING_SET = {
@@ -1260,6 +1265,9 @@ TOSA_PASS_SET = {
"_LogSoftmaxModuleStable_basic",
"_LogSoftmaxModule_basic",
"_SoftmaxModule_basic",
"LinspaceModule_basic",
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
}
MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {


@@ -1124,6 +1124,9 @@ def aten〇arange〇start〡shape(start: float, end: float, dtype: Optional[int]
def aten〇arange〡shape(end: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return upstream_shape_functions.arange_end(end, dtype, layout, device, pin_memory)
def aten〇linspace〡shape(start: float, end: float, steps: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return [steps]
@check_shape_function([
Invocation(TensorOfShape(2, 3), TensorOfShape(2, 3)), # Basic case.
Invocation(TensorOfShape(2, 3), TensorOfShape(3)), # Rank broadcasting.
@@ -4248,6 +4251,16 @@ def aten〇randn〡dtype(size: List[int], dtype: Optional[int] = None, layout: O
assert not is_integer_dtype(dtype)
return dtype
@check_dtype_function([Invocation(start=1, end=10, steps=9),
Invocation(start=1, end=10, steps=9, dtype=torch.int32),
Invocation(start=1, end=10, steps=9, dtype=torch.double),
Invocation(start=1, end=10, steps=9, dtype=torch.complex64),
Invocation(start=1, end=10, steps=9, dtype=torch.complex128)])
def aten〇linspace〡dtype(start: Union[int, float, complex], end: Union[int, float, complex], steps: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
if dtype is None:
return torch.float32
return dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(
num_of_tensors=1,
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}))
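For reference, the shape and dtype rules added above mirror eager PyTorch: linspace produces a rank-1 tensor with steps elements, and the dtype defaults to float32 when none is given. An illustrative check in plain PyTorch, not part of the patch:

import torch

out = torch.linspace(1, 10, 9)
assert out.shape == (9,)               # shape rule: [steps]
assert out.dtype == torch.float32      # default dtype is float32
assert torch.linspace(1, 10, 9, dtype=torch.int32).dtype == torch.int32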


@@ -62,6 +62,7 @@ class ArangeZeroElementOutputModule(torch.nn.Module):
def ArangeZeroElementOutputModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class ArangeStartIntModule(torch.nn.Module):
def __init__(self):
@@ -130,6 +131,7 @@ class ArangeNegativeStartFloatModule(torch.nn.Module):
def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class ArangeStartStepIntModule(torch.nn.Module):
def __init__(self):
@@ -198,6 +200,7 @@ class ArangeStartNegativeStepFloatModule(torch.nn.Module):
def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class ArangeDtypeFloatModule(torch.nn.Module):
def __init__(self):
@@ -232,6 +235,7 @@ class ArangeDtypeIntModule(torch.nn.Module):
def ArangeDtypeIntModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class ArangeFalsePinMemoryModule(torch.nn.Module):
def __init__(self):
@@ -298,3 +302,81 @@ class ArangeStartOutDtypeModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ArangeStartOutDtypeModule())
def ArangeStartOutDtypeModule_basic(module, tu: TestUtils):
module.forward(torch.zeros(12).to(torch.int64))
# ==============================================================================
class LinspaceModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.linspace(-10.1, 10.1, 10)
@register_test_case(module_factory=lambda: LinspaceModule())
def LinspaceModule_basic(module, tu: TestUtils):
module.forward()
class LinspaceDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.linspace(-10.1, 10.1, 10, dtype=torch.int64)
@register_test_case(module_factory=lambda: LinspaceDtypeModule())
def LinspaceDtypeModule_basic(module, tu: TestUtils):
module.forward()
class LinspaceEmptyModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.linspace(-10.1, 10.1, 0)
@register_test_case(module_factory=lambda: LinspaceEmptyModule())
def LinspaceEmptyModule_basic(module, tu: TestUtils):
module.forward()
class LinspaceOneSizeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.linspace(-10.1, 10.1, 1)
@register_test_case(module_factory=lambda: LinspaceOneSizeModule())
def LinspaceOneSizeModule_basic(module, tu: TestUtils):
module.forward()
class LinspaceTwoSizeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.linspace(-10.1, 10.1, 2)
@register_test_case(module_factory=lambda: LinspaceTwoSizeModule())
def LinspaceTwoSizeModule_basic(module, tu: TestUtils):
module.forward()