mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten::max_pool2d_with_indices_backward op
This commit adds lowering of `aten::max_pool2d_with_indices_backward` op. This commit also fixes formatting issues in basic.py. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/759/head snapshot-20220414.389
parent
91d3e7ba15
commit
1bccb4fc8a
|
@ -77,6 +77,10 @@ SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
|
||||||
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
|
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
|
||||||
Type dtype);
|
Type dtype);
|
||||||
|
|
||||||
|
// Return the number of elements of a tensor if the shape is static; otherwise,
|
||||||
|
// return -1.
|
||||||
|
int64_t getNumberOfElements(RankedTensorType inputType);
|
||||||
|
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -2938,6 +2938,36 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_indices_backward", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$grad_output,
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
ListOfTorchIntType:$kernel_size,
|
||||||
|
ListOfTorchIntType:$stride,
|
||||||
|
ListOfTorchIntType:$padding,
|
||||||
|
ListOfTorchIntType:$dilation,
|
||||||
|
Torch_BoolType:$ceil_mode,
|
||||||
|
AnyTorchTensorType:$indices
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenMaxPool2dWithIndicesBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 8, 1);
|
||||||
|
}
|
||||||
|
void AtenMaxPool2dWithIndicesBackwardOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 8, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
|
def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -305,6 +305,222 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// The original implementation of the op is as follows:
|
||||||
|
//
|
||||||
|
// Indices and GradOutput Layout: [N, C, H, W] or [C, H, W]
|
||||||
|
// Input Layout: [N, C, Hin, Win] or [C, Hin, Win]
|
||||||
|
//
|
||||||
|
// for i in range(N):
|
||||||
|
// for j in range(C):
|
||||||
|
// for k in range(H):
|
||||||
|
// for l in range(W):
|
||||||
|
// index = indices[i, j, k, l]
|
||||||
|
// result[i, j, index/Win, index%Win] += gradOutput[i, j, k, l]
|
||||||
|
//
|
||||||
|
// OR
|
||||||
|
//
|
||||||
|
// for i in range(C):
|
||||||
|
// for j in range(H):
|
||||||
|
// for k in range(W):
|
||||||
|
// index = indices[i, j, k]
|
||||||
|
// result[i, index/Win, index%Win] += gradOutput[i, j, k]
|
||||||
|
//
|
||||||
|
class ConvertAtenMaxPool2dWithIndicesBackwardOp
|
||||||
|
: public OpConversionPattern<AtenMaxPool2dWithIndicesBackwardOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenMaxPool2dWithIndicesBackwardOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
MLIRContext *context = op->getContext();
|
||||||
|
Value gradOutput = adaptor.grad_output();
|
||||||
|
Value input = adaptor.self();
|
||||||
|
RankedTensorType gradOutputType =
|
||||||
|
gradOutput.getType().cast<RankedTensorType>();
|
||||||
|
Type gradOutputElemType = gradOutputType.getElementType();
|
||||||
|
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||||
|
Type inputElemType = inputType.getElementType();
|
||||||
|
int64_t tensorOperandRank = inputType.getRank();
|
||||||
|
|
||||||
|
// `TMTensor::ScatterOp` expects indices of element type i32.
|
||||||
|
Value indices = convertTensorToDtype(
|
||||||
|
rewriter, loc, op.indices(),
|
||||||
|
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
|
||||||
|
indices = typeConverter->materializeTargetConversion(
|
||||||
|
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
|
||||||
|
RankedTensorType indicesType = indices.getType().cast<RankedTensorType>();
|
||||||
|
Type indicesElemType = indicesType.getElementType();
|
||||||
|
|
||||||
|
// The element type of the `input` and `grad_output` should be same.
|
||||||
|
if (inputElemType != gradOutputElemType)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op,
|
||||||
|
"Input element type should be same as the grad_output element type.");
|
||||||
|
|
||||||
|
// Since the scatter op requires indices to be a 2-d tensor, we create a new
|
||||||
|
// 5-d/4-d tensor (depending on the original indices layout) comprising the
|
||||||
|
// index values. We will collapse this tensor into a 2-d tensor. The
|
||||||
|
// algorithm for the creation of updated indices tensor is as follows:
|
||||||
|
//
|
||||||
|
// for i in range(N):
|
||||||
|
// for j in range(C):
|
||||||
|
// for k in range(H):
|
||||||
|
// for l in range(W):
|
||||||
|
// for m in range(4):
|
||||||
|
// if m == 0:
|
||||||
|
// updatedIndices[N][C][H][W][0] = i
|
||||||
|
// if m == 1:
|
||||||
|
// updatedIndices[N][C][H][W][1] = j
|
||||||
|
// if m == 2:
|
||||||
|
// updatedIndices[N][C][H][W][2] =
|
||||||
|
// originalIndices[i, j, k, l] / Win
|
||||||
|
// if m == 3:
|
||||||
|
// updatedIndices[N][C][H][W][3] =
|
||||||
|
// originalIndices[i, j, k, l] % Win
|
||||||
|
//
|
||||||
|
// OR
|
||||||
|
//
|
||||||
|
// for j in range(C):
|
||||||
|
// for k in range(H):
|
||||||
|
// for l in range(W):
|
||||||
|
// for m in range(3):
|
||||||
|
// if m == 0:
|
||||||
|
// updatedIndices[C][H][W][0] = i
|
||||||
|
// if m == 1:
|
||||||
|
// updatedIndices[C][H][W][1] = originalIndices[i, j, k, l] / Win
|
||||||
|
// if m == 2:
|
||||||
|
// updatedIndices[C][H][W][2] = originalIndices[i, j, k, l] % Win
|
||||||
|
|
||||||
|
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
|
||||||
|
|
||||||
|
SmallVector<AffineExpr> originalIndicesDimExprs, updatedIndicesDimExprs;
|
||||||
|
for (int64_t i = 0; i < tensorOperandRank; i++) {
|
||||||
|
originalIndicesDimExprs.push_back(rewriter.getAffineDimExpr(i));
|
||||||
|
updatedIndicesDimExprs.push_back(rewriter.getAffineDimExpr(i));
|
||||||
|
}
|
||||||
|
updatedIndicesDimExprs.push_back(
|
||||||
|
rewriter.getAffineDimExpr(tensorOperandRank));
|
||||||
|
|
||||||
|
SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
|
||||||
|
{originalIndicesDimExprs, updatedIndicesDimExprs});
|
||||||
|
SmallVector<StringRef> iteratorTypes(tensorOperandRank + 1, "parallel");
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> updatedIndicesShape =
|
||||||
|
getAsOpFoldResult(getTensorSizes(rewriter, loc, indices));
|
||||||
|
updatedIndicesShape.push_back(rewriter.getIndexAttr(tensorOperandRank));
|
||||||
|
|
||||||
|
Value initTensor = rewriter.create<linalg::InitTensorOp>(
|
||||||
|
loc, updatedIndicesShape, indicesElemType);
|
||||||
|
|
||||||
|
Value wIn = inputShape[tensorOperandRank - 1];
|
||||||
|
SmallVector<Value> cstValues;
|
||||||
|
for (int64_t i = 0; i < tensorOperandRank; i++)
|
||||||
|
cstValues.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
|
||||||
|
|
||||||
|
Value updatedIndices =
|
||||||
|
rewriter
|
||||||
|
.create<linalg::GenericOp>(
|
||||||
|
loc, initTensor.getType(), indices, initTensor, indexingMaps,
|
||||||
|
iteratorTypes,
|
||||||
|
[tensorOperandRank, wIn, cstValues,
|
||||||
|
indicesElemType](OpBuilder &b, Location loc, ValueRange args) {
|
||||||
|
Value index = castIntToIndex(b, loc, args[0]);
|
||||||
|
Value updatedIndex = cstValues[0];
|
||||||
|
Value lastDim =
|
||||||
|
b.create<linalg::IndexOp>(loc, tensorOperandRank);
|
||||||
|
|
||||||
|
for (int64_t i = tensorOperandRank - 1; i >= 0; i--) {
|
||||||
|
Value result;
|
||||||
|
if (i == tensorOperandRank - 1)
|
||||||
|
result = b.create<arith::RemSIOp>(loc, index, wIn);
|
||||||
|
if (i == tensorOperandRank - 2)
|
||||||
|
result = b.create<arith::FloorDivSIOp>(loc, index, wIn);
|
||||||
|
if (i == tensorOperandRank - 3 ||
|
||||||
|
i == tensorOperandRank - 4)
|
||||||
|
result = b.create<linalg::IndexOp>(loc, i);
|
||||||
|
|
||||||
|
Value pred = b.create<arith::CmpIOp>(
|
||||||
|
loc, arith::CmpIPredicate::eq, lastDim, cstValues[i]);
|
||||||
|
Value addAmount = b.create<arith::SelectOp>(
|
||||||
|
loc, pred, result, cstValues[0]);
|
||||||
|
updatedIndex =
|
||||||
|
b.create<arith::AddIOp>(loc, updatedIndex, addAmount);
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedIndex = b.create<arith::IndexCastOp>(
|
||||||
|
loc, indicesElemType, updatedIndex);
|
||||||
|
b.create<linalg::YieldOp>(loc, updatedIndex);
|
||||||
|
})
|
||||||
|
.getResult(0);
|
||||||
|
|
||||||
|
// Creating a new tensor initialized with zeros and size same as the input
|
||||||
|
// tensor.
|
||||||
|
Value outputTensor =
|
||||||
|
createZeroInitTensor(rewriter, loc, inputShape, inputElemType);
|
||||||
|
|
||||||
|
// Collapsing `gradOutput` into a 1-d tensor.
|
||||||
|
SmallVector<ReassociationIndices> reassociationCollapse(1);
|
||||||
|
for (auto i = 0; i < gradOutputType.getRank(); i++)
|
||||||
|
reassociationCollapse[0].push_back(i);
|
||||||
|
RankedTensorType gradOutputFlattenedType;
|
||||||
|
int64_t numelGradOutput = getNumberOfElements(gradOutputType);
|
||||||
|
gradOutputFlattenedType =
|
||||||
|
RankedTensorType::get({numelGradOutput}, gradOutputElemType);
|
||||||
|
Value gradOutputFlattened = rewriter.create<tensor::CollapseShapeOp>(
|
||||||
|
loc, gradOutputFlattenedType, gradOutput, reassociationCollapse);
|
||||||
|
|
||||||
|
// Collapsing updated indices into a 2-d tensor.
|
||||||
|
SmallVector<ReassociationIndices> reassociationCollapseIndices(2);
|
||||||
|
for (auto i = 0; i < tensorOperandRank; i++)
|
||||||
|
reassociationCollapseIndices[0].push_back(i);
|
||||||
|
reassociationCollapseIndices[1].push_back(tensorOperandRank);
|
||||||
|
int64_t numelIndices = getNumberOfElements(indicesType);
|
||||||
|
Value indicesCollapsed = rewriter.create<tensor::CollapseShapeOp>(
|
||||||
|
loc,
|
||||||
|
RankedTensorType::get({numelIndices, tensorOperandRank},
|
||||||
|
indicesElemType),
|
||||||
|
updatedIndices, reassociationCollapseIndices);
|
||||||
|
|
||||||
|
bool invalidInputTypeFound = false;
|
||||||
|
Value scatterOp = createTMTensorScatterOp(
|
||||||
|
rewriter, loc, /*updates=*/gradOutputFlattened,
|
||||||
|
/*indices=*/indicesCollapsed, /*original=*/outputTensor,
|
||||||
|
/*uniqueIndices=*/false,
|
||||||
|
[&](OpBuilder &b, Location loc, Value valuesElement,
|
||||||
|
Value inputElement) {
|
||||||
|
Value yieldValue = valuesElement;
|
||||||
|
if (inputElement.getType().isa<mlir::IntegerType>()) {
|
||||||
|
yieldValue =
|
||||||
|
b.create<arith::AddIOp>(loc, inputElement, valuesElement);
|
||||||
|
} else if (inputElement.getType().isa<mlir::FloatType>()) {
|
||||||
|
yieldValue =
|
||||||
|
b.create<arith::AddFOp>(loc, inputElement, valuesElement);
|
||||||
|
} else {
|
||||||
|
invalidInputTypeFound = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
b.create<TMTensor::YieldOp>(loc, yieldValue);
|
||||||
|
});
|
||||||
|
|
||||||
|
if (invalidInputTypeFound) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op,
|
||||||
|
"unimplemented: input tensor must be of integer type or float type");
|
||||||
|
}
|
||||||
|
|
||||||
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, scatterOp);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// The pass
|
// The pass
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
@ -339,6 +555,9 @@ public:
|
||||||
target.addIllegalOp<ValsemVariantAtenIndexPutImplOp>();
|
target.addIllegalOp<ValsemVariantAtenIndexPutImplOp>();
|
||||||
patterns.add<ConvertValsemVariantAtenIndexPutImplOp>(typeConverter,
|
patterns.add<ConvertValsemVariantAtenIndexPutImplOp>(typeConverter,
|
||||||
context);
|
context);
|
||||||
|
target.addIllegalOp<AtenMaxPool2dWithIndicesBackwardOp>();
|
||||||
|
patterns.add<ConvertAtenMaxPool2dWithIndicesBackwardOp>(typeConverter,
|
||||||
|
context);
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns))))
|
std::move(patterns))))
|
||||||
|
|
|
@ -272,6 +272,18 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
|
||||||
llvm_unreachable("convertScalarToDtype should handle all the types");
|
llvm_unreachable("convertScalarToDtype should handle all the types");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the number of elements of a tensor if the shape is static; otherwise,
|
||||||
|
// return -1.
|
||||||
|
int64_t getNumberOfElements(RankedTensorType inputType) {
|
||||||
|
if (!inputType.hasStaticShape())
|
||||||
|
return -1;
|
||||||
|
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||||
|
int64_t numel = 1;
|
||||||
|
for (int64_t i = 0; i < inputType.getRank(); i++)
|
||||||
|
numel *= inputShape[i];
|
||||||
|
return numel;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -521,7 +521,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Take dtype from second operand.
|
// Take dtype from second operand.
|
||||||
if (isa<AtenNllLossBackwardOp>(op)) {
|
if (isa<AtenNllLossBackwardOp, AtenMaxPool2dWithIndicesBackwardOp>(op)) {
|
||||||
auto self = operands[1]->getValue();
|
auto self = operands[1]->getValue();
|
||||||
auto knowledge =
|
auto knowledge =
|
||||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||||
|
|
|
@ -1572,6 +1572,9 @@ module {
|
||||||
}
|
}
|
||||||
return %none : !torch.none
|
return %none : !torch.none
|
||||||
}
|
}
|
||||||
|
func @"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>) -> !torch.list<int> {
|
||||||
|
return %arg1 : !torch.list<int>
|
||||||
|
}
|
||||||
func @"__torch_mlir_shape_fn.aten.adaptive_avg_pool2d"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
func @"__torch_mlir_shape_fn.aten.adaptive_avg_pool2d"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.adaptive_avg_pool2d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.adaptive_avg_pool2d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
|
|
|
@ -563,6 +563,9 @@ def aten〇resize_(self: List[int], size: List[int], memory_format: Optional[int
|
||||||
def aten〇max_pool2d(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> List[int]:
|
def aten〇max_pool2d(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> List[int]:
|
||||||
return upstream_shape_helpers.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode)
|
return upstream_shape_helpers.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode)
|
||||||
|
|
||||||
|
def aten〇max_pool2d_with_indices_backward(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]:
|
||||||
|
return self
|
||||||
|
|
||||||
def aten〇adaptive_avg_pool2d(self: List[int], output_size: List[int]) -> List[int]:
|
def aten〇adaptive_avg_pool2d(self: List[int], output_size: List[int]) -> List[int]:
|
||||||
return upstream_shape_helpers.adaptive_avg_pool2d(self, output_size)
|
return upstream_shape_helpers.adaptive_avg_pool2d(self, output_size)
|
||||||
|
|
||||||
|
|
|
@ -325,6 +325,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit(
|
emit(
|
||||||
"aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
|
"aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
|
||||||
)
|
)
|
||||||
|
emit(
|
||||||
|
"aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
|
||||||
|
)
|
||||||
emit(
|
emit(
|
||||||
"aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
|
"aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
|
||||||
)
|
)
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue