mirror of https://github.com/llvm/torch-mlir
Add decomposition for aten.roll (#1170)
* Add decomposition for aten.roll * add e2e unittest * refine type of torch.roll * fix aten::cat output typepull/1261/head
parent
1106b9aeae
commit
3d0e18bbe7
|
@ -127,6 +127,7 @@ MHLO_PASS_SET = {
|
||||||
"ReshapeAliasCollapseModule_basic",
|
"ReshapeAliasCollapseModule_basic",
|
||||||
"ReshapeAliasExpandModule_basic",
|
"ReshapeAliasExpandModule_basic",
|
||||||
"ReshapeExpandModule_basic",
|
"ReshapeExpandModule_basic",
|
||||||
|
"RollModule_basic",
|
||||||
"TestMultipleTensorReturn_basic",
|
"TestMultipleTensorReturn_basic",
|
||||||
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
||||||
"BaddbmmStaticModule_basic",
|
"BaddbmmStaticModule_basic",
|
||||||
|
@ -447,6 +448,7 @@ LTC_XFAIL_SET = {
|
||||||
"QuantizedMLP_basic",
|
"QuantizedMLP_basic",
|
||||||
"RandLikeDtypeModule_basic",
|
"RandLikeDtypeModule_basic",
|
||||||
"RandLikeModule_basic",
|
"RandLikeModule_basic",
|
||||||
|
"RollModule_basic",
|
||||||
"ScalarImplicitFloatModule_basic",
|
"ScalarImplicitFloatModule_basic",
|
||||||
"ScalarImplicitIntModule_basic",
|
"ScalarImplicitIntModule_basic",
|
||||||
"SliceEndSleStartModule_basic",
|
"SliceEndSleStartModule_basic",
|
||||||
|
|
|
@ -3478,6 +3478,31 @@ def Torch_Aten_ConvolutionDeprecatedOp : Torch_Op<"aten._convolution.deprecated"
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenRollOp : Torch_Op<"aten.roll", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::roll : (Tensor, int[], int[]) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchListOfTorchIntType:$shifts,
|
||||||
|
AnyTorchListOfTorchIntType:$dims
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenRollOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||||
|
}
|
||||||
|
void AtenRollOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 3, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
|
def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -979,7 +979,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
||||||
|
|
||||||
size_t posDim = toPositiveDim(dim, outType.getRank());
|
size_t posDim = toPositiveDim(dim, outType.getRank());
|
||||||
rewriter.replaceOpWithNewOp<mhlo::ConcatenateOp>(
|
rewriter.replaceOpWithNewOp<mhlo::ConcatenateOp>(
|
||||||
op, ValueRange(builtinTensors), posDim);
|
op, outType, ValueRange(builtinTensors), posDim);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -709,6 +709,77 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Decompose aten.roll into aten.slice and aten.cat ops.
|
||||||
|
// https://pytorch.org/docs/stable/generated/torch.roll.html
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenRollOp : public OpRewritePattern<AtenRollOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenRollOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
SmallVector<Value> shifts;
|
||||||
|
if (!getListConstructElements(op.shifts(), shifts))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: shifts not list of Scalar");
|
||||||
|
SmallVector<Value> dims;
|
||||||
|
if (!getListConstructElements(op.dims(), dims))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: dims not list of Scalar");
|
||||||
|
|
||||||
|
if (shifts.size() != dims.size())
|
||||||
|
return op.emitError("list sizes of shifts and dims are not the same");
|
||||||
|
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
Value constNone = rewriter.create<ConstantNoneOp>(loc);
|
||||||
|
Value constZero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(0));
|
||||||
|
Value constOne = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
auto self = op.self();
|
||||||
|
auto selfTy = self.getType().cast<BaseTensorType>();
|
||||||
|
// roll(input, shift, dim) = cat({
|
||||||
|
// slice(input, dim, -shift, none),
|
||||||
|
// slice(input, dim, 0, -shift)}, dim)
|
||||||
|
auto imitateRoll = [&](Value input, Value shift, Value dim,
|
||||||
|
int64_t cstDim) {
|
||||||
|
Value negShift = rewriter.create<AtenNegIntOp>(loc, shift);
|
||||||
|
ArrayRef<int64_t> inputShape = selfTy.getSizes();
|
||||||
|
SmallVector<int64_t> sizes;
|
||||||
|
sizes.append(inputShape.begin(), inputShape.end());
|
||||||
|
sizes[cstDim] = ShapedType::kDynamicSize;
|
||||||
|
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
|
||||||
|
selfTy.getDtype());
|
||||||
|
Value slice0 = rewriter.create<AtenSliceTensorOp>(
|
||||||
|
loc, sliceTy, input, dim, negShift, constNone, constOne);
|
||||||
|
Value slice1 = rewriter.create<AtenSliceTensorOp>(
|
||||||
|
loc, sliceTy, input, dim, constZero, negShift, constOne);
|
||||||
|
|
||||||
|
Type listType = Torch::ListType::get(sliceTy);
|
||||||
|
Value slices = rewriter.create<PrimListConstructOp>(
|
||||||
|
loc, listType, llvm::ArrayRef<Value>{slice0, slice1});
|
||||||
|
return rewriter.create<AtenCatOp>(loc, self.getType(), slices, dim);
|
||||||
|
};
|
||||||
|
int rank = getTensorRank(self);
|
||||||
|
if (rank < 0)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
|
||||||
|
Value output = self;
|
||||||
|
auto nShifts = shifts.size();
|
||||||
|
for (size_t k = 0; k < nShifts; ++k) {
|
||||||
|
auto dim = dims[k];
|
||||||
|
int64_t cstDim = -1;
|
||||||
|
if (!matchPattern(dim, m_TorchConstantInt(&cstDim)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: dim must be constant");
|
||||||
|
|
||||||
|
cstDim = toPositiveDim(cstDim, rank);
|
||||||
|
output = imitateRoll(output, shifts[k], dim, cstDim);
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, output);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Decompose aten.repeat into aten.expand and aten.view ops.
|
// Decompose aten.repeat into aten.expand and aten.view ops.
|
||||||
//
|
//
|
||||||
// Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html
|
// Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html
|
||||||
|
@ -2555,6 +2626,8 @@ public:
|
||||||
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(
|
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(
|
||||||
context);
|
context);
|
||||||
target.addIllegalOp<AtenZerosLikeOp>();
|
target.addIllegalOp<AtenZerosLikeOp>();
|
||||||
|
patterns.add<DecomposeAtenRollOp>(context);
|
||||||
|
target.addIllegalOp<AtenRollOp>();
|
||||||
patterns.add<DecomposeAtenRepeatOp>(context);
|
patterns.add<DecomposeAtenRepeatOp>(context);
|
||||||
target.addIllegalOp<AtenRepeatOp>();
|
target.addIllegalOp<AtenRepeatOp>();
|
||||||
patterns.add<DecomposeAtenExpandOp>(context);
|
patterns.add<DecomposeAtenExpandOp>(context);
|
||||||
|
|
|
@ -658,7 +658,8 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
|
AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
|
||||||
AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroOp,
|
AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroOp,
|
||||||
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
|
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
|
||||||
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp>(
|
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp,
|
||||||
|
AtenRollOp>(
|
||||||
op)) {
|
op)) {
|
||||||
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
||||||
}
|
}
|
||||||
|
|
|
@ -4213,6 +4213,10 @@ module {
|
||||||
}
|
}
|
||||||
return %7 : !torch.list<int>
|
return %7 : !torch.list<int>
|
||||||
}
|
}
|
||||||
|
func.func @__torch_mlir_shape_fn.aten.roll(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {
|
||||||
|
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||||
|
return %0 : !torch.list<int>
|
||||||
|
}
|
||||||
func.func @__torch__.torch.jit._shape_functions.expand(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
func.func @__torch__.torch.jit._shape_functions.expand(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||||
%int-1 = torch.constant.int -1
|
%int-1 = torch.constant.int -1
|
||||||
%true = torch.constant.bool true
|
%true = torch.constant.bool true
|
||||||
|
|
|
@ -635,6 +635,9 @@ def aten〇repeat(self: List[int], repeats: List[int]) -> List[int]:
|
||||||
out.append(self[i] * repeats[i + leading_rank])
|
out.append(self[i] * repeats[i + leading_rank])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def aten〇roll(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]:
|
||||||
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
def aten〇expand(self: List[int], size: List[int], implicit: bool = False) -> List[int]:
|
def aten〇expand(self: List[int], size: List[int], implicit: bool = False) -> List[int]:
|
||||||
return upstream_shape_functions.expand(self, size)
|
return upstream_shape_functions.expand(self, size)
|
||||||
|
|
||||||
|
|
|
@ -338,6 +338,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
|
emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
|
||||||
emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
|
emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
|
||||||
emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)")
|
emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)")
|
||||||
|
emit("aten::roll : (Tensor, int[], int[]) -> (Tensor)"),
|
||||||
emit("aten::flip : (Tensor, int[]) -> (Tensor)")
|
emit("aten::flip : (Tensor, int[]) -> (Tensor)")
|
||||||
emit(
|
emit(
|
||||||
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
|
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
|
||||||
|
|
|
@ -1047,6 +1047,27 @@ def BroadcastToModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class RollModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([3, -1, 2], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return x.roll([2, -1], [0, 2])
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: RollModule())
|
||||||
|
def RollModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 1, 2))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class RepeatModule(torch.nn.Module):
|
class RepeatModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -1065,7 +1086,6 @@ class RepeatModule(torch.nn.Module):
|
||||||
def RepeatModule_basic(module, tu: TestUtils):
|
def RepeatModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 1, 2))
|
module.forward(tu.rand(3, 1, 2))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1336,6 +1336,7 @@ func.func @torch.aten.std.dim(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vten
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.flatten.using_ints(
|
// CHECK-LABEL: func.func @torch.aten.flatten.using_ints(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> {
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> {
|
||||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
@ -1350,3 +1351,34 @@ func.func @torch.aten.flatten.using_ints(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
|
||||||
%1 = torch.aten.flatten.using_ints %arg0, %int0, %int3: !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
|
%1 = torch.aten.flatten.using_ints %arg0, %int0, %int3: !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
|
||||||
return %1 : !torch.vtensor<[?],f32>
|
return %1 : !torch.vtensor<[?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.roll(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int, %[[ARG2:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[INT:.*]]-2 = torch.constant.int -2
|
||||||
|
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT]]-2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT1_0:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[T2:.*]] = torch.aten.neg.int %[[ARG1]] : !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[T3:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[INT1]], %[[T2]], %[[NONE]], %[[INT1]]_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[T4:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[INT1]], %[[INT0]], %[[T2]], %[[INT1]]_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[T5:.*]] = torch.prim.ListConstruct %[[T3]], %[[T4]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor<[?,?],f32>>
|
||||||
|
// CHECK: %[[T6:.*]] = torch.aten.cat %[[T5]], %[[INT1]] : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[T7:.*]] = torch.aten.neg.int %[[ARG2]] : !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[T8:.*]] = torch.aten.slice.Tensor %[[T6]], %[[INT]]-2, %[[T7]], %[[NONE]], %[[INT]]1_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[T9:.*]] = torch.aten.slice.Tensor %[[T6]], %[[INT]]-2, %[[INT]]0, %[[T7]], %[[INT]]1_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[T10:.*]] = torch.prim.ListConstruct %[[T8]], %[[T9]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor<[?,?],f32>>
|
||||||
|
// CHECK: %[[T11:.*]] = torch.aten.cat %[[T10]], %[[INT]]-2 : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: return %[[T11]] : !torch.vtensor<[?,?],f32>
|
||||||
|
func.func @torch.aten.roll(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
%0 = torch.prim.ListConstruct %arg1, %arg2: (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int-2 = torch.constant.int -2
|
||||||
|
%1 = torch.prim.ListConstruct %int1, %int-2: (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%2 = torch.aten.roll %arg0, %0, %1 : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
|
||||||
|
return %2 : !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue