mirror of https://github.com/llvm/torch-mlir
Add support for reflection_pad1d (#2706)
Adds a lowering to Linalg for reflection_pad1d. Based on ideas/code from draft PR https://github.com/llvm/torch-mlir/pull/2693. --------- Co-authored-by: Kumar Deepak <kumar@xilinx.com>pull/2722/head
parent
6660a26594
commit
9adad9bc40
|
@ -7869,6 +7869,30 @@ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenReflectionPad1dOp : Torch_Op<"aten.reflection_pad1d", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$padding
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenReflectionPad1dOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenReflectionPad1dOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenPadOp : Torch_Op<"aten.pad", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -107,6 +107,143 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
|||
return success();
|
||||
}
|
||||
|
||||
// Example:
|
||||
// input = tensor([[[0., 1., 2., 3.],
|
||||
// [4., 5., 6., 7.]]])
|
||||
// torch.ops.aten.reflection_pad1d(input, (3,1)) ; padding_left = 3, padding_right = 1
|
||||
// tensor([[[3., 2., 1., 0., 1., 2., 3., 2.],
|
||||
// [7., 6., 5., 4., 5., 6., 7., 6.]]])
|
||||
// Checks: 1) Each of padding_left and padding_right must be non-negative less than size of last dimension
|
||||
// Implementation: a) Construct a result tensor of shape of input tensor except for the last dimension.
|
||||
// The last dimension of the result tensor should be last dimension of input tensor +
|
||||
// left padding size + right padding size. INitialize result tensor to all zeros
|
||||
// b) Setup affine map to take slice from input tensor of size left padding starting from
|
||||
// second column onwards as first column is reflection boundary
|
||||
// c) Reflect the affine map to have resultant slice reflected
|
||||
// d) Take the slice and write from begining in result tensor
|
||||
// e) write the original tensor next into result tensor
|
||||
// f) Setup affine map to take slice from input tensor of right padding size ending
|
||||
// at second last column as last column is reflection boundary for right padding
|
||||
// g) Reflect the affine map to have resultant slice reflected
|
||||
// h) Take the slice and write from left padding size + orignal tensor last dim size
|
||||
// into result tensor
|
||||
// Uses the ideas/code used for AtenReflectionPad2dOp
|
||||
namespace {
|
||||
class ConvertAtenReflectionPad1dOp
|
||||
: public OpConversionPattern<AtenReflectionPad1dOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenReflectionPad1dOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> padInts;
|
||||
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only constant int padding range is supported");
|
||||
|
||||
MLIRContext *context = rewriter.getContext();
|
||||
Location loc = op.getLoc();
|
||||
|
||||
// Lambda Unitility Functions
|
||||
// Create an Integer expression of x + y
|
||||
auto createIAdd = [&](Value x, Value y) {
|
||||
return rewriter.create<arith::AddIOp>(loc, x, y);
|
||||
};
|
||||
|
||||
// Create an integer expression of x - y
|
||||
auto createISub = [&](Value x, Value y) {
|
||||
return rewriter.create<arith::SubIOp>(loc, x, y);
|
||||
};
|
||||
|
||||
enum PadLocation {PAD_LEFT = 0, PAD_RIGHT = 1, PAD_CENTER=2};
|
||||
|
||||
Value input = adaptor.getSelf();
|
||||
Type indexType = rewriter.getIndexType();
|
||||
Value zero = getConstant(rewriter, loc, 0, indexType);
|
||||
Value one = getConstant(rewriter, loc, 1, indexType);
|
||||
auto inputType = llvm::cast<RankedTensorType>(input.getType());
|
||||
auto outputType = llvm::cast<RankedTensorType>(getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
unsigned numDims = inputType.getRank();
|
||||
assert(numDims >= 2 && "Not enough input dimensions");
|
||||
int64_t lastDim = numDims - 1;
|
||||
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
|
||||
Value lastDimSize = inputShape[lastDim]; // input [1,2,4], then lastDim = 2, inputShape[2] will give 4
|
||||
|
||||
Value tileWidth[3], extractOffset[3], insertOffset[3];
|
||||
|
||||
tileWidth[PAD_LEFT] = getConstant(rewriter, loc, padInts[PAD_LEFT], indexType);
|
||||
tileWidth[PAD_RIGHT] = getConstant(rewriter, loc, padInts[PAD_RIGHT], indexType);
|
||||
tileWidth[PAD_CENTER] = lastDimSize;
|
||||
|
||||
extractOffset[PAD_LEFT] = one;
|
||||
// for (1,2,4) input, padding (3,1) lastDimSize=4, 4 - 1 - 1 = 2 [3,5, 6,7], so start offset to 6, which is right
|
||||
// lasDimSize - (tileWidth[PAD_RIGHT] + one)
|
||||
extractOffset[PAD_RIGHT] = createISub(lastDimSize, createIAdd(tileWidth[PAD_RIGHT], one));
|
||||
extractOffset[PAD_CENTER] = zero;
|
||||
|
||||
insertOffset[PAD_LEFT] = zero;
|
||||
insertOffset[PAD_RIGHT] = createIAdd(lastDimSize, tileWidth[PAD_LEFT]);
|
||||
insertOffset[PAD_CENTER] = tileWidth[PAD_LEFT];
|
||||
|
||||
|
||||
SmallVector<Value> resultShape{inputShape};
|
||||
// Result's last dimension will have shape lastDimSize + left padding size + right padding size
|
||||
resultShape[lastDim] = createIAdd(resultShape[lastDim], createIAdd(tileWidth[PAD_LEFT], tileWidth[PAD_RIGHT]));
|
||||
Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, inputType.getElementType());
|
||||
|
||||
// Helper to reflect/reverse the i-th dimension of an affine map without symbols. This only works if applied on a tensor
|
||||
// for which the corresponding dimension has a statically known size
|
||||
auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, int64_t size) {
|
||||
AffineExpr d = map.getResult(i);
|
||||
return map.replace(d, size - d - 1, numDims, 0); // left reflect for (3,1) on input shape (1,2,4). size = 3, lastDim=2, numDims=3
|
||||
};
|
||||
|
||||
SmallVector<utils::IteratorType> iteratorTypes{numDims, utils::IteratorType::parallel};
|
||||
auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context);
|
||||
SmallVector<Value> allOneStrides(numDims, one);
|
||||
|
||||
auto addTileToResult = [&](PadLocation padPosition) {
|
||||
// Create the tile by extracting a slice from the input tensor.
|
||||
SmallVector<Value> extractShape{inputShape};
|
||||
extractShape[lastDim] = tileWidth[padPosition];
|
||||
SmallVector<Value> extractOffsets(numDims, zero);
|
||||
extractOffsets[lastDim] = extractOffset[padPosition];
|
||||
Value tile = rewriter.create<tensor::ExtractSliceOp>(
|
||||
loc, input, extractOffsets, extractShape, allOneStrides);
|
||||
|
||||
|
||||
auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context);
|
||||
// Setup the affine map function to resverse the tile along the horizontal for left and right slices
|
||||
if(padPosition < PAD_CENTER) {
|
||||
inputMap = reflectDim(inputMap, numDims, lastDim, padInts[padPosition]);
|
||||
// Take reflected slice as per inputMap
|
||||
tile = rewriter.create<linalg::GenericOp>(loc, llvm::cast<RankedTensorType>(tile.getType()), tile,
|
||||
tile, ArrayRef({inputMap, idMap}), iteratorTypes,
|
||||
[](OpBuilder &b, Location nestedLoc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(nestedLoc, args[0]);
|
||||
}).getResult(0);
|
||||
}
|
||||
// Insert the tile in the resultTensor
|
||||
SmallVector<Value> insertOffsets(numDims, zero);
|
||||
insertOffsets[lastDim] = insertOffset[padPosition];
|
||||
resultTensor = rewriter.create<tensor::InsertSliceOp>(loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides);
|
||||
};
|
||||
|
||||
if(padInts[PAD_LEFT] > 0)
|
||||
addTileToResult(PAD_LEFT);
|
||||
if(padInts[PAD_RIGHT] > 0)
|
||||
addTileToResult(PAD_RIGHT);
|
||||
addTileToResult(PAD_CENTER);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputType, resultTensor);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
namespace {
|
||||
class ConvertAtenFlattenUsingIntsOp
|
||||
: public OpConversionPattern<AtenFlattenUsingIntsOp> {
|
||||
|
@ -1413,6 +1550,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
|||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
target.addIllegalOp<AtenReflectionPad1dOp>();
|
||||
patterns.add<ConvertAtenReflectionPad1dOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenFlattenUsingIntsOp>();
|
||||
patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenViewOp>();
|
||||
|
|
|
@ -8331,6 +8331,41 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.reflection_pad1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %int-1 = torch.constant.int -1\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %1 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %2 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %4 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %5 = torch.aten.lt.int %3, %2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %6 = torch.prim.If %5 -> (!torch.bool) {\n"
|
||||
" %8 = torch.aten.lt.int %4, %2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %8 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %6 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %7 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %7 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -8952,6 +8987,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
|
||||
" %2 = torch.aten.eq.int %1, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
|
|
@ -1271,6 +1271,21 @@ def aten〇constant_pad_nd〡shape(self: List[int], pad: List[int], value: float
|
|||
def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]:
|
||||
return pad_shape_fn(self, pad)
|
||||
|
||||
#Padding size must be smaller than the size of the last dimension
|
||||
@check_shape_function([ErrorInvocation(TensorOfShape(1, 2, 4), padding=[4,1]),
|
||||
Invocation(TensorOfShape(1, 2, 4), padding=[3,3]),
|
||||
ErrorInvocation(TensorOfShape(1, 2, 4), padding=[1,4]),
|
||||
ErrorInvocation(TensorOfShape(1, 4), padding=[4,1]),
|
||||
Invocation(TensorOfShape(1, 4), padding=[3,3]),
|
||||
ErrorInvocation(TensorOfShape(1, 4), padding=[1,4])])
|
||||
def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List[int]:
|
||||
assert len(self) >= 2
|
||||
hdim = self[-1]
|
||||
padding_left = padding[0]
|
||||
padding_right = padding[1]
|
||||
assert padding_left < hdim and padding_right < hdim
|
||||
return pad_shape_fn(self, padding)
|
||||
|
||||
# TODO: upstream this
|
||||
def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
|
||||
assert len(indices) <= len(self), "More indices than dimensions to index"
|
||||
|
@ -1804,6 +1819,18 @@ def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[i
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
|
||||
@check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1),
|
||||
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]),
|
||||
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2]),
|
||||
Invocation(TensorOfShape(2, 3, 4), padding=[2,1]),
|
||||
Invocation(TensorOfShape(5, 5, 4), padding=[1,2]),
|
||||
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[3,2,1])])
|
||||
def aten〇reflection_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
assert len(padding) == 2, 'padding size expected to be 2'
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇contiguous〡dtype(self_rank_dtype: Tuple[int, int], memory_format: int = 0) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -541,6 +541,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
|
||||
# Misc tensor ops.
|
||||
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
|
||||
emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)")
|
||||
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
|
||||
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
|
||||
|
|
|
@ -552,8 +552,80 @@ def ConstantPadNdPartialStaticModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
# ==============================================================================
|
||||
class ReflectionPad1dModule3dInput(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1, 2, 4], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.reflection_pad1d(x, (3,1))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReflectionPad1dModule3dInput())
|
||||
def ReflectionPad1dModule3dInput_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1,2,4))
|
||||
|
||||
|
||||
class ReflectionPad1dModule2dInput(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([2, 4], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.reflection_pad1d(x, (3,2))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReflectionPad1dModule2dInput())
|
||||
def ReflectionPad1dModule2dInput_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2,4))
|
||||
|
||||
class ReflectionPad1dModule3dInputLeft(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1, 4, 5], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.reflection_pad1d(x, (2,0))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReflectionPad1dModule3dInputLeft())
|
||||
def ReflectionPad1dModule3dInput_Left(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1,4,5))
|
||||
|
||||
class ReflectionPad1dModule2dInputRight(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 6], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.reflection_pad1d(x, (0,3))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReflectionPad1dModule2dInputRight())
|
||||
def ReflectionPad1dModule2dInput_Right(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3,6))
|
||||
|
||||
# ==============================================================================
|
||||
class TransposeIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue