[MLIR][TORCH] Add E2E support for aten::max_pool2d_with_indices_backward op

This commit adds lowering of `aten::max_pool2d_with_indices_backward` op.

This commit also fixes formatting issues in basic.py.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/759/head snapshot-20220414.389
Vivek Khandelwal 2022-04-14 17:46:39 +05:30
parent 91d3e7ba15
commit 1bccb4fc8a
9 changed files with 680 additions and 30 deletions

View File

@@ -77,6 +77,10 @@ SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
Type dtype);
// Return the number of elements of a tensor if the shape is static; otherwise,
// return -1.
int64_t getNumberOfElements(RankedTensorType inputType);
} // namespace Torch
} // namespace torch
} // namespace mlir

View File

@@ -2938,6 +2938,36 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
}];
}
def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_indices_backward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$self,
ListOfTorchIntType:$kernel_size,
ListOfTorchIntType:$stride,
ListOfTorchIntType:$padding,
ListOfTorchIntType:$dilation,
Torch_BoolType:$ceil_mode,
AnyTorchTensorType:$indices
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxPool2dWithIndicesBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 1);
}
void AtenMaxPool2dWithIndicesBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 1);
}
}];
}
def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@@ -305,6 +305,222 @@ public:
};
} // namespace
namespace {
// The original implementation of the op is as follows:
//
// Indices and GradOutput Layout: [N, C, H, W] or [C, H, W]
// Input Layout: [N, C, Hin, Win] or [C, Hin, Win]
//
// for i in range(N):
//   for j in range(C):
//     for k in range(H):
//       for l in range(W):
//         index = indices[i, j, k, l]
//         result[i, j, index/Win, index%Win] += gradOutput[i, j, k, l]
//
// OR
//
// for i in range(C):
//   for j in range(H):
//     for k in range(W):
//       index = indices[i, j, k]
//       result[i, index/Win, index%Win] += gradOutput[i, j, k]
//
class ConvertAtenMaxPool2dWithIndicesBackwardOp
: public OpConversionPattern<AtenMaxPool2dWithIndicesBackwardOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenMaxPool2dWithIndicesBackwardOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
MLIRContext *context = op->getContext();
Value gradOutput = adaptor.grad_output();
Value input = adaptor.self();
RankedTensorType gradOutputType =
gradOutput.getType().cast<RankedTensorType>();
Type gradOutputElemType = gradOutputType.getElementType();
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
Type inputElemType = inputType.getElementType();
int64_t tensorOperandRank = inputType.getRank();
// `TMTensor::ScatterOp` expects indices of element type i32.
Value indices = convertTensorToDtype(
rewriter, loc, op.indices(),
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
indices = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
RankedTensorType indicesType = indices.getType().cast<RankedTensorType>();
Type indicesElemType = indicesType.getElementType();
// The element type of the `input` and `grad_output` should be the same.
if (inputElemType != gradOutputElemType)
return rewriter.notifyMatchFailure(
op,
"Input element type should be same as the grad_output element type.");
// Since the scatter op requires the indices to be a 2-d tensor, we first
// create a new 5-d/4-d tensor (depending on the original indices layout)
// holding the full index coordinates, and then collapse it into a 2-d
// tensor. The algorithm for building the updated indices tensor is:
//
// for i in range(N):
//   for j in range(C):
//     for k in range(H):
//       for l in range(W):
//         for m in range(4):
//           if m == 0:
//             updatedIndices[i][j][k][l][0] = i
//           if m == 1:
//             updatedIndices[i][j][k][l][1] = j
//           if m == 2:
//             updatedIndices[i][j][k][l][2] =
//                 originalIndices[i, j, k, l] / Win
//           if m == 3:
//             updatedIndices[i][j][k][l][3] =
//                 originalIndices[i, j, k, l] % Win
//
// OR
//
// for i in range(C):
//   for j in range(H):
//     for k in range(W):
//       for m in range(3):
//         if m == 0:
//           updatedIndices[i][j][k][0] = i
//         if m == 1:
//           updatedIndices[i][j][k][1] = originalIndices[i, j, k] / Win
//         if m == 2:
//           updatedIndices[i][j][k][2] = originalIndices[i, j, k] % Win
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
SmallVector<AffineExpr> originalIndicesDimExprs, updatedIndicesDimExprs;
for (int64_t i = 0; i < tensorOperandRank; i++) {
originalIndicesDimExprs.push_back(rewriter.getAffineDimExpr(i));
updatedIndicesDimExprs.push_back(rewriter.getAffineDimExpr(i));
}
updatedIndicesDimExprs.push_back(
rewriter.getAffineDimExpr(tensorOperandRank));
SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
{originalIndicesDimExprs, updatedIndicesDimExprs});
SmallVector<StringRef> iteratorTypes(tensorOperandRank + 1, "parallel");
SmallVector<OpFoldResult> updatedIndicesShape =
getAsOpFoldResult(getTensorSizes(rewriter, loc, indices));
updatedIndicesShape.push_back(rewriter.getIndexAttr(tensorOperandRank));
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, updatedIndicesShape, indicesElemType);
Value wIn = inputShape[tensorOperandRank - 1];
SmallVector<Value> cstValues;
for (int64_t i = 0; i < tensorOperandRank; i++)
cstValues.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
Value updatedIndices =
rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), indices, initTensor, indexingMaps,
iteratorTypes,
[tensorOperandRank, wIn, cstValues,
indicesElemType](OpBuilder &b, Location loc, ValueRange args) {
Value index = castIntToIndex(b, loc, args[0]);
Value updatedIndex = cstValues[0];
Value lastDim =
b.create<linalg::IndexOp>(loc, tensorOperandRank);
for (int64_t i = tensorOperandRank - 1; i >= 0; i--) {
Value result;
if (i == tensorOperandRank - 1)
result = b.create<arith::RemSIOp>(loc, index, wIn);
if (i == tensorOperandRank - 2)
result = b.create<arith::FloorDivSIOp>(loc, index, wIn);
if (i == tensorOperandRank - 3 ||
i == tensorOperandRank - 4)
result = b.create<linalg::IndexOp>(loc, i);
Value pred = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, lastDim, cstValues[i]);
Value addAmount = b.create<arith::SelectOp>(
loc, pred, result, cstValues[0]);
updatedIndex =
b.create<arith::AddIOp>(loc, updatedIndex, addAmount);
}
updatedIndex = b.create<arith::IndexCastOp>(
loc, indicesElemType, updatedIndex);
b.create<linalg::YieldOp>(loc, updatedIndex);
})
.getResult(0);
// Create a new tensor of zeros with the same size as the input tensor.
Value outputTensor =
createZeroInitTensor(rewriter, loc, inputShape, inputElemType);
// Collapsing `gradOutput` into a 1-d tensor.
SmallVector<ReassociationIndices> reassociationCollapse(1);
for (auto i = 0; i < gradOutputType.getRank(); i++)
reassociationCollapse[0].push_back(i);
RankedTensorType gradOutputFlattenedType;
int64_t numelGradOutput = getNumberOfElements(gradOutputType);
gradOutputFlattenedType =
RankedTensorType::get({numelGradOutput}, gradOutputElemType);
Value gradOutputFlattened = rewriter.create<tensor::CollapseShapeOp>(
loc, gradOutputFlattenedType, gradOutput, reassociationCollapse);
// Collapsing updated indices into a 2-d tensor.
SmallVector<ReassociationIndices> reassociationCollapseIndices(2);
for (auto i = 0; i < tensorOperandRank; i++)
reassociationCollapseIndices[0].push_back(i);
reassociationCollapseIndices[1].push_back(tensorOperandRank);
int64_t numelIndices = getNumberOfElements(indicesType);
Value indicesCollapsed = rewriter.create<tensor::CollapseShapeOp>(
loc,
RankedTensorType::get({numelIndices, tensorOperandRank},
indicesElemType),
updatedIndices, reassociationCollapseIndices);
bool invalidInputTypeFound = false;
Value scatterOp = createTMTensorScatterOp(
rewriter, loc, /*updates=*/gradOutputFlattened,
/*indices=*/indicesCollapsed, /*original=*/outputTensor,
/*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value valuesElement,
Value inputElement) {
Value yieldValue = valuesElement;
if (inputElement.getType().isa<mlir::IntegerType>()) {
yieldValue =
b.create<arith::AddIOp>(loc, inputElement, valuesElement);
} else if (inputElement.getType().isa<mlir::FloatType>()) {
yieldValue =
b.create<arith::AddFOp>(loc, inputElement, valuesElement);
} else {
invalidInputTypeFound = true;
return;
}
b.create<TMTensor::YieldOp>(loc, yieldValue);
});
if (invalidInputTypeFound) {
return rewriter.notifyMatchFailure(
op,
"unimplemented: input tensor must be of integer type or float type");
}
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, scatterOp);
return success();
}
};
} // namespace
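
As a sanity check on the pattern above, here is a hedged NumPy sketch (not part of this commit; function names are illustrative) of both formulations for the NCHW case: the direct loop nest from the comment at the top of the pattern, and the scatter-style rewrite that builds a (numel, rank) coordinate tensor, flattens grad_output, and accumulates into a zero-initialized output.

    import numpy as np

    def max_pool2d_indices_backward_ref(grad_output, indices, input_shape):
        # Direct loop nest: each grad_output element is added to the input
        # position encoded by the flat index into the Hin x Win plane.
        N, C, H, W = grad_output.shape
        w_in = input_shape[-1]
        result = np.zeros(input_shape, dtype=grad_output.dtype)
        for i in range(N):
            for j in range(C):
                for k in range(H):
                    for l in range(W):
                        idx = indices[i, j, k, l]
                        result[i, j, idx // w_in, idx % w_in] += grad_output[i, j, k, l]
        return result

    def max_pool2d_indices_backward_scatter(grad_output, indices, input_shape):
        # Scatter formulation mirroring the lowering: expand each flat index
        # into an (n, c, h, w) coordinate, collapse the coordinates to 2-d and
        # grad_output to 1-d, then scatter-add (indices are not unique).
        N, C, H, W = indices.shape
        w_in = input_shape[-1]
        updated = np.empty((N, C, H, W, 4), dtype=np.int32)
        n_idx, c_idx = np.meshgrid(np.arange(N), np.arange(C), indexing="ij")
        updated[..., 0] = n_idx[:, :, None, None]
        updated[..., 1] = c_idx[:, :, None, None]
        updated[..., 2] = indices // w_in
        updated[..., 3] = indices % w_in
        scatter_indices = updated.reshape(-1, 4)  # collapsed to 2-d
        updates = grad_output.reshape(-1)         # collapsed to 1-d
        result = np.zeros(input_shape, dtype=grad_output.dtype)
        for coord, upd in zip(scatter_indices, updates):
            result[tuple(coord)] += upd
        return result

    grad = np.random.rand(2, 3, 2, 2).astype(np.float32)
    idx = np.random.randint(0, 36, size=(2, 3, 2, 2))
    assert np.allclose(
        max_pool2d_indices_backward_ref(grad, idx, (2, 3, 6, 6)),
        max_pool2d_indices_backward_scatter(grad, idx, (2, 3, 6, 6)))
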
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
@@ -339,6 +555,9 @@ public:
target.addIllegalOp<ValsemVariantAtenIndexPutImplOp>();
patterns.add<ConvertValsemVariantAtenIndexPutImplOp>(typeConverter,
context);
target.addIllegalOp<AtenMaxPool2dWithIndicesBackwardOp>();
patterns.add<ConvertAtenMaxPool2dWithIndicesBackwardOp>(typeConverter,
context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))

View File

@@ -272,6 +272,18 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
llvm_unreachable("convertScalarToDtype should handle all the types");
}
// Return the number of elements of a tensor if the shape is static; otherwise,
// return -1.
int64_t getNumberOfElements(RankedTensorType inputType) {
if (!inputType.hasStaticShape())
return -1;
ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t numel = 1;
for (int64_t i = 0; i < inputType.getRank(); i++)
numel *= inputShape[i];
return numel;
}
} // namespace Torch
} // namespace torch
} // namespace mlir
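
A hedged Python analogue of the new helper, to make the contract concrete (a dynamic dimension is modeled here as None; the MLIR side checks hasStaticShape instead):

    def number_of_elements(shape):
        # Static shape: return the product of the extents.
        # Any dynamic (unknown) dimension: return -1.
        if any(d is None for d in shape):
            return -1
        count = 1
        for d in shape:
            count *= d
        return count

    number_of_elements([2, 3, 4])     # 24
    number_of_elements([2, None, 4])  # -1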

View File

@@ -521,7 +521,7 @@ ChangeResult TypeAnalyzer::visitOperation(
}
// Take dtype from second operand.
-  if (isa<AtenNllLossBackwardOp>(op)) {
+  if (isa<AtenNllLossBackwardOp, AtenMaxPool2dWithIndicesBackwardOp>(op)) {
auto self = operands[1]->getValue();
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());

View File

@@ -1572,6 +1572,9 @@ module {
}
return %none : !torch.none
}
func @"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>) -> !torch.list<int> {
return %arg1 : !torch.list<int>
}
func @"__torch_mlir_shape_fn.aten.adaptive_avg_pool2d"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.adaptive_avg_pool2d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>

View File

@@ -563,6 +563,9 @@ def aten〇resize_(self: List[int], size: List[int], memory_format: Optional[int
def aten〇max_pool2d(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> List[int]:
return upstream_shape_helpers.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode)
def aten〇max_pool2d_with_indices_backward(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]:
return self
def aten〇adaptive_avg_pool2d(self: List[int], output_size: List[int]) -> List[int]:
return upstream_shape_helpers.adaptive_avg_pool2d(self, output_size)
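
Since the gradient w.r.t. the input has the input's shape, the shape function simply forwards `self`. A hedged illustration of the contract (shape values are made up):

    aten〇max_pool2d_with_indices_backward(
        grad_output=[2, 3, 4, 4], self=[2, 3, 8, 8],
        kernel_size=[2, 2], stride=[2, 2], padding=[0, 0],
        dilation=[1, 1], ceil_mode=False, indices=[2, 3, 4, 4])
    # returns [2, 3, 8, 8], i.e. the shape of `self`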

View File

@@ -325,6 +325,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit(
"aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
)
emit(
"aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
)
emit(
"aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
)

File diff suppressed because it is too large.