mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten::max_pool2d_with_indices_backward op
This commit adds lowering of `aten::max_pool2d_with_indices_backward` op. This commit also fixes formatting issues in basic.py. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/759/head snapshot-20220414.389
parent
91d3e7ba15
commit
1bccb4fc8a
|
@ -77,6 +77,10 @@ SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
|
|||
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
|
||||
Type dtype);
|
||||
|
||||
// Returns the number of elements of the tensor type `inputType` when its shape
|
||||
// is static; otherwise returns -1.
|
||||
int64_t getNumberOfElements(RankedTensorType inputType);
|
||||
|
||||
} // namespace Torch
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -2938,6 +2938,36 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
|
|||
}];
|
||||
}
|
||||
|
||||
// Torch dialect op for `aten::max_pool2d_with_indices_backward`. It declares
// eight operands (`grad_output`, `self`, four int lists, the `ceil_mode` bool
// and the `indices` tensor) and a single tensor result. Parsing and printing
// are forwarded to `parseDefaultTorchOp` / `printDefaultTorchOp` with the
// arguments 8 and 1 (presumably the operand and result counts).
// NOTE(review): the summary marks this as a generated op, so manual edits here
// are presumably overwritten on regeneration -- confirm before hand-editing.
def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_indices_backward", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$grad_output,
|
||||
AnyTorchTensorType:$self,
|
||||
ListOfTorchIntType:$kernel_size,
|
||||
ListOfTorchIntType:$stride,
|
||||
ListOfTorchIntType:$padding,
|
||||
ListOfTorchIntType:$dilation,
|
||||
Torch_BoolType:$ceil_mode,
|
||||
AnyTorchTensorType:$indices
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenMaxPool2dWithIndicesBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 8, 1);
|
||||
}
|
||||
void AtenMaxPool2dWithIndicesBackwardOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 8, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -305,6 +305,222 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// The original implementation of the op is as follows:
|
||||
//
|
||||
// Indices and GradOutput Layout: [N, C, H, W] or [C, H, W]
|
||||
// Input Layout: [N, C, Hin, Win] or [C, Hin, Win]
|
||||
//
|
||||
// for i in range(N):
|
||||
// for j in range(C):
|
||||
// for k in range(H):
|
||||
// for l in range(W):
|
||||
// index = indices[i, j, k, l]
|
||||
// result[i, j, index/Win, index%Win] += gradOutput[i, j, k, l]
|
||||
//
|
||||
// OR
|
||||
//
|
||||
// for i in range(C):
|
||||
// for j in range(H):
|
||||
// for k in range(W):
|
||||
// index = indices[i, j, k]
|
||||
// result[i, index/Win, index%Win] += gradOutput[i, j, k]
|
||||
//
|
||||
class ConvertAtenMaxPool2dWithIndicesBackwardOp
|
||||
: public OpConversionPattern<AtenMaxPool2dWithIndicesBackwardOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenMaxPool2dWithIndicesBackwardOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
Value gradOutput = adaptor.grad_output();
|
||||
Value input = adaptor.self();
|
||||
RankedTensorType gradOutputType =
|
||||
gradOutput.getType().cast<RankedTensorType>();
|
||||
Type gradOutputElemType = gradOutputType.getElementType();
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
Type inputElemType = inputType.getElementType();
|
||||
int64_t tensorOperandRank = inputType.getRank();
|
||||
|
||||
// `TMTensor::ScatterOp` expects indices of element type i32.
|
||||
Value indices = convertTensorToDtype(
|
||||
rewriter, loc, op.indices(),
|
||||
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
|
||||
indices = typeConverter->materializeTargetConversion(
|
||||
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
|
||||
RankedTensorType indicesType = indices.getType().cast<RankedTensorType>();
|
||||
Type indicesElemType = indicesType.getElementType();
|
||||
|
||||
// The element type of the `input` and `grad_output` should be same.
|
||||
if (inputElemType != gradOutputElemType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"Input element type should be same as the grad_output element type.");
|
||||
|
||||
// Since the scatter op requires indices to be a 2-d tensor, we create a new
|
||||
// 5-d/4-d tensor (depending on the original indices layout) comprising the
|
||||
// index values. We will collapse this tensor into a 2-d tensor. The
|
||||
// algorithm for the creation of updated indices tensor is as follows:
|
||||
//
|
||||
// for i in range(N):
|
||||
// for j in range(C):
|
||||
// for k in range(H):
|
||||
// for l in range(W):
|
||||
// for m in range(4):
|
||||
// if m == 0:
|
||||
// updatedIndices[N][C][H][W][0] = i
|
||||
// if m == 1:
|
||||
// updatedIndices[N][C][H][W][1] = j
|
||||
// if m == 2:
|
||||
// updatedIndices[N][C][H][W][2] =
|
||||
// originalIndices[i, j, k, l] / Win
|
||||
// if m == 3:
|
||||
// updatedIndices[N][C][H][W][3] =
|
||||
// originalIndices[i, j, k, l] % Win
|
||||
//
|
||||
// OR
|
||||
//
|
||||
// for j in range(C):
|
||||
// for k in range(H):
|
||||
// for l in range(W):
|
||||
// for m in range(3):
|
||||
// if m == 0:
|
||||
// updatedIndices[C][H][W][0] = i
|
||||
// if m == 1:
|
||||
// updatedIndices[C][H][W][1] = originalIndices[i, j, k, l] / Win
|
||||
// if m == 2:
|
||||
// updatedIndices[C][H][W][2] = originalIndices[i, j, k, l] % Win
|
||||
|
||||
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
|
||||
|
||||
SmallVector<AffineExpr> originalIndicesDimExprs, updatedIndicesDimExprs;
|
||||
for (int64_t i = 0; i < tensorOperandRank; i++) {
|
||||
originalIndicesDimExprs.push_back(rewriter.getAffineDimExpr(i));
|
||||
updatedIndicesDimExprs.push_back(rewriter.getAffineDimExpr(i));
|
||||
}
|
||||
updatedIndicesDimExprs.push_back(
|
||||
rewriter.getAffineDimExpr(tensorOperandRank));
|
||||
|
||||
SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
|
||||
{originalIndicesDimExprs, updatedIndicesDimExprs});
|
||||
SmallVector<StringRef> iteratorTypes(tensorOperandRank + 1, "parallel");
|
||||
|
||||
SmallVector<OpFoldResult> updatedIndicesShape =
|
||||
getAsOpFoldResult(getTensorSizes(rewriter, loc, indices));
|
||||
updatedIndicesShape.push_back(rewriter.getIndexAttr(tensorOperandRank));
|
||||
|
||||
Value initTensor = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, updatedIndicesShape, indicesElemType);
|
||||
|
||||
Value wIn = inputShape[tensorOperandRank - 1];
|
||||
SmallVector<Value> cstValues;
|
||||
for (int64_t i = 0; i < tensorOperandRank; i++)
|
||||
cstValues.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
|
||||
|
||||
Value updatedIndices =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, initTensor.getType(), indices, initTensor, indexingMaps,
|
||||
iteratorTypes,
|
||||
[tensorOperandRank, wIn, cstValues,
|
||||
indicesElemType](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value index = castIntToIndex(b, loc, args[0]);
|
||||
Value updatedIndex = cstValues[0];
|
||||
Value lastDim =
|
||||
b.create<linalg::IndexOp>(loc, tensorOperandRank);
|
||||
|
||||
for (int64_t i = tensorOperandRank - 1; i >= 0; i--) {
|
||||
Value result;
|
||||
if (i == tensorOperandRank - 1)
|
||||
result = b.create<arith::RemSIOp>(loc, index, wIn);
|
||||
if (i == tensorOperandRank - 2)
|
||||
result = b.create<arith::FloorDivSIOp>(loc, index, wIn);
|
||||
if (i == tensorOperandRank - 3 ||
|
||||
i == tensorOperandRank - 4)
|
||||
result = b.create<linalg::IndexOp>(loc, i);
|
||||
|
||||
Value pred = b.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, lastDim, cstValues[i]);
|
||||
Value addAmount = b.create<arith::SelectOp>(
|
||||
loc, pred, result, cstValues[0]);
|
||||
updatedIndex =
|
||||
b.create<arith::AddIOp>(loc, updatedIndex, addAmount);
|
||||
}
|
||||
|
||||
updatedIndex = b.create<arith::IndexCastOp>(
|
||||
loc, indicesElemType, updatedIndex);
|
||||
b.create<linalg::YieldOp>(loc, updatedIndex);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
// Creating a new tensor initialized with zeros and size same as the input
|
||||
// tensor.
|
||||
Value outputTensor =
|
||||
createZeroInitTensor(rewriter, loc, inputShape, inputElemType);
|
||||
|
||||
// Collapsing `gradOutput` into a 1-d tensor.
|
||||
SmallVector<ReassociationIndices> reassociationCollapse(1);
|
||||
for (auto i = 0; i < gradOutputType.getRank(); i++)
|
||||
reassociationCollapse[0].push_back(i);
|
||||
RankedTensorType gradOutputFlattenedType;
|
||||
int64_t numelGradOutput = getNumberOfElements(gradOutputType);
|
||||
gradOutputFlattenedType =
|
||||
RankedTensorType::get({numelGradOutput}, gradOutputElemType);
|
||||
Value gradOutputFlattened = rewriter.create<tensor::CollapseShapeOp>(
|
||||
loc, gradOutputFlattenedType, gradOutput, reassociationCollapse);
|
||||
|
||||
// Collapsing updated indices into a 2-d tensor.
|
||||
SmallVector<ReassociationIndices> reassociationCollapseIndices(2);
|
||||
for (auto i = 0; i < tensorOperandRank; i++)
|
||||
reassociationCollapseIndices[0].push_back(i);
|
||||
reassociationCollapseIndices[1].push_back(tensorOperandRank);
|
||||
int64_t numelIndices = getNumberOfElements(indicesType);
|
||||
Value indicesCollapsed = rewriter.create<tensor::CollapseShapeOp>(
|
||||
loc,
|
||||
RankedTensorType::get({numelIndices, tensorOperandRank},
|
||||
indicesElemType),
|
||||
updatedIndices, reassociationCollapseIndices);
|
||||
|
||||
bool invalidInputTypeFound = false;
|
||||
Value scatterOp = createTMTensorScatterOp(
|
||||
rewriter, loc, /*updates=*/gradOutputFlattened,
|
||||
/*indices=*/indicesCollapsed, /*original=*/outputTensor,
|
||||
/*uniqueIndices=*/false,
|
||||
[&](OpBuilder &b, Location loc, Value valuesElement,
|
||||
Value inputElement) {
|
||||
Value yieldValue = valuesElement;
|
||||
if (inputElement.getType().isa<mlir::IntegerType>()) {
|
||||
yieldValue =
|
||||
b.create<arith::AddIOp>(loc, inputElement, valuesElement);
|
||||
} else if (inputElement.getType().isa<mlir::FloatType>()) {
|
||||
yieldValue =
|
||||
b.create<arith::AddFOp>(loc, inputElement, valuesElement);
|
||||
} else {
|
||||
invalidInputTypeFound = true;
|
||||
return;
|
||||
}
|
||||
b.create<TMTensor::YieldOp>(loc, yieldValue);
|
||||
});
|
||||
|
||||
if (invalidInputTypeFound) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"unimplemented: input tensor must be of integer type or float type");
|
||||
}
|
||||
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, scatterOp);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// The pass
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -339,6 +555,9 @@ public:
|
|||
target.addIllegalOp<ValsemVariantAtenIndexPutImplOp>();
|
||||
patterns.add<ConvertValsemVariantAtenIndexPutImplOp>(typeConverter,
|
||||
context);
|
||||
target.addIllegalOp<AtenMaxPool2dWithIndicesBackwardOp>();
|
||||
patterns.add<ConvertAtenMaxPool2dWithIndicesBackwardOp>(typeConverter,
|
||||
context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
|
|
|
@ -272,6 +272,18 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
|
|||
llvm_unreachable("convertScalarToDtype should handle all the types");
|
||||
}
|
||||
|
||||
// Return the number of elements of a tensor if the shape is static; otherwise,
|
||||
// return -1.
|
||||
int64_t getNumberOfElements(RankedTensorType inputType) {
|
||||
if (!inputType.hasStaticShape())
|
||||
return -1;
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
int64_t numel = 1;
|
||||
for (int64_t i = 0; i < inputType.getRank(); i++)
|
||||
numel *= inputShape[i];
|
||||
return numel;
|
||||
}
|
||||
|
||||
} // namespace Torch
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -521,7 +521,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
}
|
||||
|
||||
// Take dtype from second operand.
|
||||
if (isa<AtenNllLossBackwardOp>(op)) {
|
||||
if (isa<AtenNllLossBackwardOp, AtenMaxPool2dWithIndicesBackwardOp>(op)) {
|
||||
auto self = operands[1]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
|
|
|
@ -1572,6 +1572,9 @@ module {
|
|||
}
|
||||
return %none : !torch.none
|
||||
}
|
||||
// Shape function for `aten.max_pool2d_with_indices_backward`: the result shape
// is the second argument (%arg1), returned unchanged; all other arguments are
// unused.
func @"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>) -> !torch.list<int> {
|
||||
return %arg1 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.adaptive_avg_pool2d"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.adaptive_avg_pool2d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
|
|
|
@ -563,6 +563,9 @@ def aten〇resize_(self: List[int], size: List[int], memory_format: Optional[int
|
|||
def aten〇max_pool2d(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> List[int]:
    # The shape computation is handed to the upstream helper, with every
    # argument forwarded in its original order.
    pooled_shape = upstream_shape_helpers.max_pool2d(
        self, kernel_size, stride, padding, dilation, ceil_mode)
    return pooled_shape
|
||||
|
||||
def aten〇max_pool2d_with_indices_backward(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]:
    # The result takes the shape of `self`; none of the other arguments
    # influence it.
    result_shape = self
    return result_shape
|
||||
|
||||
def aten〇adaptive_avg_pool2d(self: List[int], output_size: List[int]) -> List[int]:
    # Delegates to the upstream helper, passing both arguments through as-is.
    pooled_shape = upstream_shape_helpers.adaptive_avg_pool2d(self, output_size)
    return pooled_shape
|
||||
|
||||
|
|
|
@ -325,6 +325,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
|
||||
)
|
||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue