mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] support decomposition of aten.linspace (#3006)
parent
43c6996a31
commit
870e63bc3c
|
@ -8334,6 +8334,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %5 = call @__torch__.torch.jit._shape_functions.arange_end(%0, %1, %2, %3, %4) : (!torch.union<float, int>, !torch.any, !torch.any, !torch.any, !torch.any) -> !torch.list<int>\n"
|
" %5 = call @__torch__.torch.jit._shape_functions.arange_end(%0, %1, %2, %3, %4) : (!torch.union<float, int>, !torch.any, !torch.any, !torch.any, !torch.any) -> !torch.list<int>\n"
|
||||||
" return %5 : !torch.list<int>\n"
|
" return %5 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.linspace\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||||
|
" %0 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list<int>\n"
|
||||||
|
" return %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.add.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.add.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -12568,6 +12572,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" }\n"
|
" }\n"
|
||||||
" return %1 : !torch.int\n"
|
" return %1 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.linspace\"(%arg0: !torch.number, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
|
||||||
|
" %int6 = torch.constant.int 6\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %0 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||||
|
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
|
||||||
|
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
|
||||||
|
" torch.prim.If.yield %2 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" return %1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.normal_functional\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.normal_functional\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
|
|
|
@ -6331,6 +6331,78 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenLinspaceOp : public OpRewritePattern<AtenLinspaceOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<AtenLinspaceOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenLinspaceOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
MLIRContext *context = getContext();
|
||||||
|
|
||||||
|
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
|
||||||
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||||
|
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
|
||||||
|
Value zero =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||||
|
Value one =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
|
||||||
|
Value addStart;
|
||||||
|
int64_t steps;
|
||||||
|
if (matchPattern(op.getSteps(), m_TorchConstantInt(&steps)) && steps == 1) {
|
||||||
|
// specically handle steps == 1
|
||||||
|
Value arange = rewriter.create<AtenArangeStartOp>(
|
||||||
|
loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(),
|
||||||
|
op.getDevice(), op.getPinMemory());
|
||||||
|
addStart = rewriter.create<AtenAddScalarOp>(loc, baseType, arange,
|
||||||
|
op.getStart(), one);
|
||||||
|
} else {
|
||||||
|
// handle steps != 1 or dynamic steps
|
||||||
|
Value neOrNot = rewriter.create<AtenNeIntOp>(loc, op.getSteps(), one);
|
||||||
|
rewriter.create<RuntimeAssertOp>(
|
||||||
|
loc, neOrNot,
|
||||||
|
rewriter.getStringAttr("linspace's dynamic steps must not be 1"));
|
||||||
|
// create arange: [0, ..., steps - 1]
|
||||||
|
Value arange = rewriter.create<AtenArangeStartOp>(
|
||||||
|
loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(),
|
||||||
|
op.getDevice(), op.getPinMemory());
|
||||||
|
// calculate (end - start) / (steps - 1)
|
||||||
|
Value sub;
|
||||||
|
if (op.getEnd().getType().isa<Torch::FloatType>() ||
|
||||||
|
op.getStart().getType().isa<Torch::FloatType>()) {
|
||||||
|
sub = rewriter.create<AtenSubOp>(loc, Torch::FloatType::get(context),
|
||||||
|
op.getEnd(), op.getStart());
|
||||||
|
} else {
|
||||||
|
sub = rewriter.create<AtenSubIntOp>(loc, op.getEnd(), op.getStart());
|
||||||
|
}
|
||||||
|
Value div = rewriter.create<AtenDivOp>(
|
||||||
|
loc, sub, rewriter.create<AtenSubIntOp>(loc, op.getSteps(), one));
|
||||||
|
// calculate [0, ..., steps - 1] * ((end - start) / (steps - 1)) + start
|
||||||
|
Value mulScalar =
|
||||||
|
rewriter.create<AtenMulScalarOp>(loc, baseType, arange, div);
|
||||||
|
addStart = rewriter.create<AtenAddScalarOp>(loc, baseType, mulScalar,
|
||||||
|
op.getStart(), one);
|
||||||
|
}
|
||||||
|
// to dtype
|
||||||
|
Value result;
|
||||||
|
if (!op.getDtype().getType().isa<Torch::NoneType>()) {
|
||||||
|
result = rewriter.create<AtenToDtypeOp>(
|
||||||
|
loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal,
|
||||||
|
/*copy=*/falseVal, /*memory_format=*/none);
|
||||||
|
} else {
|
||||||
|
Value f32Type = rewriter.create<ConstantIntOp>(
|
||||||
|
loc, (int)torch_upstream::ScalarType::Float);
|
||||||
|
result = rewriter.create<AtenToDtypeOp>(
|
||||||
|
loc, op.getType(), addStart, f32Type, /*non_blocking=*/falseVal,
|
||||||
|
/*copy=*/falseVal, /*memory_format=*/none);
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenVarMeanOp : public OpRewritePattern<AtenVarMeanOp> {
|
class DecomposeAtenVarMeanOp : public OpRewritePattern<AtenVarMeanOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -7216,6 +7288,7 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLinspaceOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<
|
addPatternIfTargetOpIsIllegal<
|
||||||
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
|
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<
|
addPatternIfTargetOpIsIllegal<
|
||||||
|
|
|
@ -424,6 +424,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<AtenConvTranspose2dInputOp>();
|
target.addIllegalOp<AtenConvTranspose2dInputOp>();
|
||||||
target.addIllegalOp<AtenArangeOp>();
|
target.addIllegalOp<AtenArangeOp>();
|
||||||
target.addIllegalOp<AtenArangeStartOp>();
|
target.addIllegalOp<AtenArangeStartOp>();
|
||||||
|
target.addIllegalOp<AtenLinspaceOp>();
|
||||||
target.addIllegalOp<AtenArgmaxOp>();
|
target.addIllegalOp<AtenArgmaxOp>();
|
||||||
target.addIllegalOp<AtenArgminOp>();
|
target.addIllegalOp<AtenArgminOp>();
|
||||||
target.addIllegalOp<AtenSquareOp>();
|
target.addIllegalOp<AtenSquareOp>();
|
||||||
|
|
|
@ -841,6 +841,11 @@ STABLEHLO_PASS_SET = {
|
||||||
"ZerosModuleFloat3D_basic",
|
"ZerosModuleFloat3D_basic",
|
||||||
"ZerosModuleInt2D_basic",
|
"ZerosModuleInt2D_basic",
|
||||||
"ZerosModuleInt3D_basic",
|
"ZerosModuleInt3D_basic",
|
||||||
|
"LinspaceDtypeModule_basic",
|
||||||
|
"LinspaceEmptyModule_basic",
|
||||||
|
"LinspaceModule_basic",
|
||||||
|
"LinspaceOneSizeModule_basic",
|
||||||
|
"LinspaceTwoSizeModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
STABLEHLO_CRASHING_SET = {
|
STABLEHLO_CRASHING_SET = {
|
||||||
|
@ -1260,6 +1265,9 @@ TOSA_PASS_SET = {
|
||||||
"_LogSoftmaxModuleStable_basic",
|
"_LogSoftmaxModuleStable_basic",
|
||||||
"_LogSoftmaxModule_basic",
|
"_LogSoftmaxModule_basic",
|
||||||
"_SoftmaxModule_basic",
|
"_SoftmaxModule_basic",
|
||||||
|
"LinspaceModule_basic",
|
||||||
|
"LinspaceOneSizeModule_basic",
|
||||||
|
"LinspaceTwoSizeModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
||||||
|
|
|
@ -1124,6 +1124,9 @@ def aten〇arange〇start〡shape(start: float, end: float, dtype: Optional[int]
|
||||||
def aten〇arange〡shape(end: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
def aten〇arange〡shape(end: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||||
return upstream_shape_functions.arange_end(end, dtype, layout, device, pin_memory)
|
return upstream_shape_functions.arange_end(end, dtype, layout, device, pin_memory)
|
||||||
|
|
||||||
|
def aten〇linspace〡shape(start: float, end: float, steps: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||||
|
return [steps]
|
||||||
|
|
||||||
@check_shape_function([
|
@check_shape_function([
|
||||||
Invocation(TensorOfShape(2, 3), TensorOfShape(2, 3)), # Basic case.
|
Invocation(TensorOfShape(2, 3), TensorOfShape(2, 3)), # Basic case.
|
||||||
Invocation(TensorOfShape(2, 3), TensorOfShape(3)), # Rank broadcasting.
|
Invocation(TensorOfShape(2, 3), TensorOfShape(3)), # Rank broadcasting.
|
||||||
|
@ -4248,6 +4251,16 @@ def aten〇randn〡dtype(size: List[int], dtype: Optional[int] = None, layout: O
|
||||||
assert not is_integer_dtype(dtype)
|
assert not is_integer_dtype(dtype)
|
||||||
return dtype
|
return dtype
|
||||||
|
|
||||||
|
@check_dtype_function([Invocation(start=1, end=10, steps=9),
|
||||||
|
Invocation(start=1, end=10, steps=9, dtype=torch.int32),
|
||||||
|
Invocation(start=1, end=10, steps=9, dtype=torch.double),
|
||||||
|
Invocation(start=1, end=10, steps=9, dtype=torch.complex64),
|
||||||
|
Invocation(start=1, end=10, steps=9, dtype=torch.complex128)])
|
||||||
|
def aten〇linspace〡dtype(start: Union[int, float, complex], end: Union[int, float, complex], steps: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||||
|
if dtype is None:
|
||||||
|
return torch.float32
|
||||||
|
return dtype
|
||||||
|
|
||||||
@check_dtype_function(_check_tensors_with_the_same_dtype(
|
@check_dtype_function(_check_tensors_with_the_same_dtype(
|
||||||
num_of_tensors=1,
|
num_of_tensors=1,
|
||||||
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}))
|
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}))
|
||||||
|
|
|
@ -62,6 +62,7 @@ class ArangeZeroElementOutputModule(torch.nn.Module):
|
||||||
def ArangeZeroElementOutputModule_basic(module, tu: TestUtils):
|
def ArangeZeroElementOutputModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ArangeStartIntModule(torch.nn.Module):
|
class ArangeStartIntModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -130,6 +131,7 @@ class ArangeNegativeStartFloatModule(torch.nn.Module):
|
||||||
def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils):
|
def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ArangeStartStepIntModule(torch.nn.Module):
|
class ArangeStartStepIntModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -198,6 +200,7 @@ class ArangeStartNegativeStepFloatModule(torch.nn.Module):
|
||||||
def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils):
|
def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ArangeDtypeFloatModule(torch.nn.Module):
|
class ArangeDtypeFloatModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -232,6 +235,7 @@ class ArangeDtypeIntModule(torch.nn.Module):
|
||||||
def ArangeDtypeIntModule_basic(module, tu: TestUtils):
|
def ArangeDtypeIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ArangeFalsePinMemoryModule(torch.nn.Module):
|
class ArangeFalsePinMemoryModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -298,3 +302,81 @@ class ArangeStartOutDtypeModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: ArangeStartOutDtypeModule())
|
@register_test_case(module_factory=lambda: ArangeStartOutDtypeModule())
|
||||||
def ArangeStartOutDtypeModule_basic(module, tu: TestUtils):
|
def ArangeStartOutDtypeModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.zeros(12).to(torch.int64))
|
module.forward(torch.zeros(12).to(torch.int64))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class LinspaceModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.linspace(-10.1, 10.1, 10)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: LinspaceModule())
|
||||||
|
def LinspaceModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
class LinspaceDtypeModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.linspace(-10.1, 10.1, 10, dtype=torch.int64)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: LinspaceDtypeModule())
|
||||||
|
def LinspaceDtypeModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
class LinspaceEmptyModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.linspace(-10.1, 10.1, 0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: LinspaceEmptyModule())
|
||||||
|
def LinspaceEmptyModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
class LinspaceOneSizeModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.linspace(-10.1, 10.1, 1)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: LinspaceOneSizeModule())
|
||||||
|
def LinspaceOneSizeModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
class LinspaceTwoSizeModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.linspace(-10.1, 10.1, 2)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: LinspaceTwoSizeModule())
|
||||||
|
def LinspaceTwoSizeModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
Loading…
Reference in New Issue