mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] support decomposition of aten.linspace (#3006)
parent
43c6996a31
commit
870e63bc3c
|
@ -8334,6 +8334,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %5 = call @__torch__.torch.jit._shape_functions.arange_end(%0, %1, %2, %3, %4) : (!torch.union<float, int>, !torch.any, !torch.any, !torch.any, !torch.any) -> !torch.list<int>\n"
|
||||
" return %5 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.linspace\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" %0 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.add.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -12568,6 +12572,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.linspace\"(%arg0: !torch.number, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
|
||||
" torch.prim.If.yield %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.normal_functional\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
|
|
|
@ -6331,6 +6331,78 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenLinspaceOp : public OpRewritePattern<AtenLinspaceOp> {
|
||||
public:
|
||||
using OpRewritePattern<AtenLinspaceOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenLinspaceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = getContext();
|
||||
|
||||
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
|
||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
|
||||
Value zero =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
Value one =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
|
||||
Value addStart;
|
||||
int64_t steps;
|
||||
if (matchPattern(op.getSteps(), m_TorchConstantInt(&steps)) && steps == 1) {
|
||||
// specically handle steps == 1
|
||||
Value arange = rewriter.create<AtenArangeStartOp>(
|
||||
loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(),
|
||||
op.getDevice(), op.getPinMemory());
|
||||
addStart = rewriter.create<AtenAddScalarOp>(loc, baseType, arange,
|
||||
op.getStart(), one);
|
||||
} else {
|
||||
// handle steps != 1 or dynamic steps
|
||||
Value neOrNot = rewriter.create<AtenNeIntOp>(loc, op.getSteps(), one);
|
||||
rewriter.create<RuntimeAssertOp>(
|
||||
loc, neOrNot,
|
||||
rewriter.getStringAttr("linspace's dynamic steps must not be 1"));
|
||||
// create arange: [0, ..., steps - 1]
|
||||
Value arange = rewriter.create<AtenArangeStartOp>(
|
||||
loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(),
|
||||
op.getDevice(), op.getPinMemory());
|
||||
// calculate (end - start) / (steps - 1)
|
||||
Value sub;
|
||||
if (op.getEnd().getType().isa<Torch::FloatType>() ||
|
||||
op.getStart().getType().isa<Torch::FloatType>()) {
|
||||
sub = rewriter.create<AtenSubOp>(loc, Torch::FloatType::get(context),
|
||||
op.getEnd(), op.getStart());
|
||||
} else {
|
||||
sub = rewriter.create<AtenSubIntOp>(loc, op.getEnd(), op.getStart());
|
||||
}
|
||||
Value div = rewriter.create<AtenDivOp>(
|
||||
loc, sub, rewriter.create<AtenSubIntOp>(loc, op.getSteps(), one));
|
||||
// calculate [0, ..., steps - 1] * ((end - start) / (steps - 1)) + start
|
||||
Value mulScalar =
|
||||
rewriter.create<AtenMulScalarOp>(loc, baseType, arange, div);
|
||||
addStart = rewriter.create<AtenAddScalarOp>(loc, baseType, mulScalar,
|
||||
op.getStart(), one);
|
||||
}
|
||||
// to dtype
|
||||
Value result;
|
||||
if (!op.getDtype().getType().isa<Torch::NoneType>()) {
|
||||
result = rewriter.create<AtenToDtypeOp>(
|
||||
loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal,
|
||||
/*copy=*/falseVal, /*memory_format=*/none);
|
||||
} else {
|
||||
Value f32Type = rewriter.create<ConstantIntOp>(
|
||||
loc, (int)torch_upstream::ScalarType::Float);
|
||||
result = rewriter.create<AtenToDtypeOp>(
|
||||
loc, op.getType(), addStart, f32Type, /*non_blocking=*/falseVal,
|
||||
/*copy=*/falseVal, /*memory_format=*/none);
|
||||
}
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenVarMeanOp : public OpRewritePattern<AtenVarMeanOp> {
|
||||
public:
|
||||
|
@ -7216,6 +7288,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLinspaceOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
|
|
|
@ -424,6 +424,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenConvTranspose2dInputOp>();
|
||||
target.addIllegalOp<AtenArangeOp>();
|
||||
target.addIllegalOp<AtenArangeStartOp>();
|
||||
target.addIllegalOp<AtenLinspaceOp>();
|
||||
target.addIllegalOp<AtenArgmaxOp>();
|
||||
target.addIllegalOp<AtenArgminOp>();
|
||||
target.addIllegalOp<AtenSquareOp>();
|
||||
|
|
|
@ -841,6 +841,11 @@ STABLEHLO_PASS_SET = {
|
|||
"ZerosModuleFloat3D_basic",
|
||||
"ZerosModuleInt2D_basic",
|
||||
"ZerosModuleInt3D_basic",
|
||||
"LinspaceDtypeModule_basic",
|
||||
"LinspaceEmptyModule_basic",
|
||||
"LinspaceModule_basic",
|
||||
"LinspaceOneSizeModule_basic",
|
||||
"LinspaceTwoSizeModule_basic",
|
||||
}
|
||||
|
||||
STABLEHLO_CRASHING_SET = {
|
||||
|
@ -1260,6 +1265,9 @@ TOSA_PASS_SET = {
|
|||
"_LogSoftmaxModuleStable_basic",
|
||||
"_LogSoftmaxModule_basic",
|
||||
"_SoftmaxModule_basic",
|
||||
"LinspaceModule_basic",
|
||||
"LinspaceOneSizeModule_basic",
|
||||
"LinspaceTwoSizeModule_basic",
|
||||
}
|
||||
|
||||
MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
||||
|
|
|
@ -1124,6 +1124,9 @@ def aten〇arange〇start〡shape(start: float, end: float, dtype: Optional[int]
|
|||
def aten〇arange〡shape(end: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
return upstream_shape_functions.arange_end(end, dtype, layout, device, pin_memory)
|
||||
|
||||
def aten〇linspace〡shape(start: float, end: float, steps: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
return [steps]
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(2, 3), TensorOfShape(2, 3)), # Basic case.
|
||||
Invocation(TensorOfShape(2, 3), TensorOfShape(3)), # Rank broadcasting.
|
||||
|
@ -4248,6 +4251,16 @@ def aten〇randn〡dtype(size: List[int], dtype: Optional[int] = None, layout: O
|
|||
assert not is_integer_dtype(dtype)
|
||||
return dtype
|
||||
|
||||
@check_dtype_function([Invocation(start=1, end=10, steps=9),
|
||||
Invocation(start=1, end=10, steps=9, dtype=torch.int32),
|
||||
Invocation(start=1, end=10, steps=9, dtype=torch.double),
|
||||
Invocation(start=1, end=10, steps=9, dtype=torch.complex64),
|
||||
Invocation(start=1, end=10, steps=9, dtype=torch.complex128)])
|
||||
def aten〇linspace〡dtype(start: Union[int, float, complex], end: Union[int, float, complex], steps: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
if dtype is None:
|
||||
return torch.float32
|
||||
return dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(
|
||||
num_of_tensors=1,
|
||||
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}))
|
||||
|
|
|
@ -62,6 +62,7 @@ class ArangeZeroElementOutputModule(torch.nn.Module):
|
|||
def ArangeZeroElementOutputModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ArangeStartIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -130,6 +131,7 @@ class ArangeNegativeStartFloatModule(torch.nn.Module):
|
|||
def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ArangeStartStepIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -198,6 +200,7 @@ class ArangeStartNegativeStepFloatModule(torch.nn.Module):
|
|||
def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ArangeDtypeFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -232,6 +235,7 @@ class ArangeDtypeIntModule(torch.nn.Module):
|
|||
def ArangeDtypeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ArangeFalsePinMemoryModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -298,3 +302,81 @@ class ArangeStartOutDtypeModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ArangeStartOutDtypeModule())
|
||||
def ArangeStartOutDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.zeros(12).to(torch.int64))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class LinspaceModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.linspace(-10.1, 10.1, 10)
|
||||
|
||||
@register_test_case(module_factory=lambda: LinspaceModule())
|
||||
def LinspaceModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
class LinspaceDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.linspace(-10.1, 10.1, 10, dtype=torch.int64)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: LinspaceDtypeModule())
|
||||
def LinspaceDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
class LinspaceEmptyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.linspace(-10.1, 10.1, 0)
|
||||
|
||||
@register_test_case(module_factory=lambda: LinspaceEmptyModule())
|
||||
def LinspaceEmptyModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
class LinspaceOneSizeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.linspace(-10.1, 10.1, 1)
|
||||
|
||||
@register_test_case(module_factory=lambda: LinspaceOneSizeModule())
|
||||
def LinspaceOneSizeModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
class LinspaceTwoSizeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.linspace(-10.1, 10.1, 2)
|
||||
|
||||
@register_test_case(module_factory=lambda: LinspaceTwoSizeModule())
|
||||
def LinspaceTwoSizeModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
|
Loading…
Reference in New Issue