[MLIR][TORCH] Add E2E support for aten.upsample_nearest2d_backward.vec op

Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>
pull/1558/head
Vivek Khandelwal 2022-11-01 18:38:04 +05:30
parent db5a496eb4
commit fedf8c0640
10 changed files with 284 additions and 3 deletions

View File

@ -621,4 +621,6 @@ LTC_XFAIL_SET = {
"Fill_TensorFloat32WithFloat32_basic",
"Fill_TensorFloat32WithFloat64_basic",
"Fill_TensorFloat32WithInt64_basic",
"UpSampleNearest2dBackwardVec_basic",
"UpSampleNearest2dBackwardOutputSizeNone_basic",
}

View File

@ -49,6 +49,10 @@ SmallVector<Value>
castIntVectorToIndexVector(OpBuilder &b, Location loc,
SmallVectorImpl<Value> &intValues);
SmallVector<Value>
castIndexVectorToInt64Vector(OpBuilder &b, Location loc,
SmallVectorImpl<Value> &indexValues);
Value getDimOp(OpBuilder &b, Location loc, Value v, int dim);
SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,

View File

@ -4893,6 +4893,32 @@ def Torch_AtenMseLossOp : Torch_Op<"aten.mse_loss", [
}];
}
def Torch_AtenUpsampleNearest2dBackwardVecOp : Torch_Op<"aten.upsample_nearest2d_backward.vec", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::upsample_nearest2d_backward.vec : (Tensor, int[]?, int[], float[]?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchOptionalListOfTorchIntType:$output_size,
AnyTorchListOfTorchIntType:$input_size,
AnyTorchOptionalListOfTorchFloatType:$scale_factors
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUpsampleNearest2dBackwardVecOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenUpsampleNearest2dBackwardVecOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -896,6 +896,190 @@ public:
};
} // namespace
static Value getGradOutputValue(OpBuilder &builder, Location loc,
Value gradOutput, Type gradOutputElemType,
Value numBatch, Value numChannel,
Value inputIndexH, Value inputIndexW,
Value kernelIndexH, Value kernelIndexW,
SmallVector<Value> &gradOutputSizeIndexValues,
SmallVector<Value, 2> &scaleFactorsIntValues) {
Value constantOne = builder.create<arith::ConstantIndexOp>(loc, 1);
Value outputIndexH = builder.create<arith::MulIOp>(
loc, inputIndexH, castIntToIndex(builder, loc, scaleFactorsIntValues[0]));
outputIndexH = builder.create<arith::AddIOp>(loc, outputIndexH, kernelIndexH);
Value outputIndexW = builder.create<arith::MulIOp>(
loc, inputIndexW, castIntToIndex(builder, loc, scaleFactorsIntValues[1]));
outputIndexW = builder.create<arith::AddIOp>(loc, outputIndexW, kernelIndexW);
// Handling corner cases.
Value gradOutputHMinusOne = builder.create<arith::SubIOp>(
loc, gradOutputSizeIndexValues[2], constantOne);
Value predH = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, outputIndexH, gradOutputHMinusOne);
outputIndexH = builder.create<arith::SelectOp>(loc, predH, outputIndexH,
gradOutputHMinusOne);
Value gradOutputWMinusOne = builder.create<arith::SubIOp>(
loc, gradOutputSizeIndexValues[3], constantOne);
Value predW = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, outputIndexW, gradOutputWMinusOne);
outputIndexW = builder.create<arith::SelectOp>(loc, predW, outputIndexW,
gradOutputWMinusOne);
Value gradOutputValue = builder.create<tensor::ExtractOp>(
loc, gradOutput,
ValueRange{numBatch, numChannel, outputIndexH, outputIndexW});
Value constantZero =
builder.create<arith::ConstantOp>(loc, builder.getF32FloatAttr(0.0));
Value pred = builder.create<arith::AndIOp>(loc, predH, predW);
Value result = builder.create<arith::SelectOp>(
loc, pred, gradOutputValue,
convertScalarToDtype(builder, loc, constantZero, gradOutputElemType));
return result;
}
// The implementation of the `aten.upsample_nearest2d_backward.vec` op's
// lowering is as follows:
// gradOutput: Tensor of size [n, c, oh, ow]
// outTensor: Tensor of size [n, c, ih, iw], initialized with zero
// kh = ceil(oh/ih), kw = ceil(ow/iw)
//
// for i in range(n):
// for j in range(c):
// for p in range(ih):
// for q in range(iw):
// for x in range(kh):
// for y in range(kw):
// outTensor[i, j, p, q] += gradOutput[i, j, (p*kh)+x, (q*kw)+y]
namespace {
class ConvertAtenUpsampleNearest2dBackwardVecOp
: public OpConversionPattern<AtenUpsampleNearest2dBackwardVecOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenUpsampleNearest2dBackwardVecOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value gradOutput = adaptor.grad_output();
Type resultType = getTypeConverter()->convertType(op.getResult().getType());
auto gradOutputType = gradOutput.getType().cast<RankedTensorType>();
auto gradOutputRank = gradOutputType.getRank();
Type elementType = gradOutputType.getElementType();
SmallVector<Value> gradOutputSizeIndexValues =
getTensorSizes(rewriter, loc, gradOutput);
SmallVector<Value> gradOutputSizeIntValues =
castIndexVectorToInt64Vector(rewriter, loc, gradOutputSizeIndexValues);
SmallVector<Value, 2> scaleFactorsFloatValues;
SmallVector<Value, 4> inputSizeTorchInt;
if (!getListConstructElements(op.input_size(), inputSizeTorchInt))
return rewriter.notifyMatchFailure(
op, "unimplemented: the input_size is not constructed from "
"ListConstruct");
SmallVector<Value, 4> inputSizeIntValues;
inputSizeIntValues = getTypeConvertedValues(
rewriter, loc, getTypeConverter(), inputSizeTorchInt);
// The dimension at which the scaling starts.
unsigned hDimOffset = 2;
if (!op.scale_factors().getType().isa<Torch::NoneType>()) {
SmallVector<Value, 2> scaleFactorsTorchFloat;
if (!getListConstructElements(op.scale_factors(), scaleFactorsTorchFloat))
return rewriter.notifyMatchFailure(
op, "unimplemented: the scale_factors is not constructed from "
"ListConstruct");
scaleFactorsFloatValues = getTypeConvertedValues(
rewriter, loc, getTypeConverter(), scaleFactorsTorchFloat);
} else {
for (unsigned i = hDimOffset; i < gradOutputRank; i++) {
auto scaleFactorVal = rewriter.create<arith::DivFOp>(
loc,
convertScalarToDtype(rewriter, loc, gradOutputSizeIntValues[i],
mlir::Float32Type::get(op->getContext())),
convertScalarToDtype(rewriter, loc, inputSizeIntValues[i],
mlir::Float32Type::get(op->getContext())));
scaleFactorsFloatValues.push_back(scaleFactorVal);
}
}
SmallVector<Value, 2> scaleFactorsIntValues;
for (auto v : scaleFactorsFloatValues)
scaleFactorsIntValues.push_back(convertScalarToDtype(
rewriter, loc, rewriter.create<math::CeilOp>(loc, v),
mlir::IntegerType::get(op->getContext(), 64)));
Value outTensor = createZeroInitTensor(
rewriter, loc,
castIntVectorToIndexVector(rewriter, loc, inputSizeIntValues),
elementType);
Value kernelTensor = rewriter.create<tensor::EmptyOp>(
loc,
getAsOpFoldResult(
castIntVectorToIndexVector(rewriter, loc, scaleFactorsIntValues)),
elementType);
unsigned kernelRank = scaleFactorsIntValues.size();
SmallVector<AffineExpr> affineExprs;
for (unsigned i = 0; i < gradOutputRank; i++)
affineExprs.push_back(rewriter.getAffineDimExpr(i));
AffineMap outputMap =
AffineMap::get(gradOutputRank + kernelRank,
/*symbolCount=*/0, affineExprs, op->getContext());
affineExprs.clear();
for (unsigned i = gradOutputRank; i < gradOutputRank + kernelRank; i++)
affineExprs.push_back(rewriter.getAffineDimExpr(i));
AffineMap kernelMap =
AffineMap::get(gradOutputRank + kernelRank,
/*symbolCount=*/0, affineExprs, op->getContext());
SmallVector<AffineMap> indexingMaps{kernelMap, outputMap};
SmallVector<StringRef> iteratorTypes(gradOutputRank,
getParallelIteratorTypeName());
iteratorTypes.push_back(getReductionIteratorTypeName());
iteratorTypes.push_back(getReductionIteratorTypeName());
Value finalRes =
rewriter
.create<linalg::GenericOp>(
loc, outTensor.getType(), ValueRange{kernelTensor},
ValueRange{outTensor},
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value n = rewriter.create<linalg::IndexOp>(loc, 0);
Value c = rewriter.create<linalg::IndexOp>(loc, 1);
Value ih = rewriter.create<linalg::IndexOp>(loc, 2);
Value iw = rewriter.create<linalg::IndexOp>(loc, 3);
Value kh = rewriter.create<linalg::IndexOp>(loc, 4);
Value kw = rewriter.create<linalg::IndexOp>(loc, 5);
Value accValue = getGradOutputValue(
rewriter, loc, gradOutput, elementType, n, c, ih, iw, kh,
kw, gradOutputSizeIndexValues, scaleFactorsIntValues);
Value outputVal = args[1];
outputVal =
rewriter.create<arith::AddFOp>(loc, outputVal, accValue);
b.create<linalg::YieldOp>(loc, outputVal);
})
->getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
return success();
}
};
} // namespace
void mlir::torch::torch_to_linalg::
populateIndirectDataMovementPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
@ -913,4 +1097,6 @@ void mlir::torch::torch_to_linalg::
patterns.add<ConvertAtenEmbeddingBagPaddingIdxOp>(typeConverter, context);
target.addIllegalOp<AtenUpsampleNearest2dVecOp>();
patterns.add<ConvertAtenUpsampleNearest2dVecOp>(typeConverter, context);
target.addIllegalOp<AtenUpsampleNearest2dBackwardVecOp>();
patterns.add<ConvertAtenUpsampleNearest2dBackwardVecOp>(typeConverter, context);
}

View File

@ -156,6 +156,15 @@ castIntVectorToIndexVector(OpBuilder &b, Location loc,
return indexValues;
}
SmallVector<Value>
castIndexVectorToInt64Vector(OpBuilder &b, Location loc,
SmallVectorImpl<Value> &indexValues) {
SmallVector<Value> intValues;
for (Value v : indexValues)
intValues.push_back(castIndexToInt64(b, loc, v));
return intValues;
}
Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
return b.createOrFold<tensor::DimOp>(loc, v, dim);
}

View File

@ -700,8 +700,8 @@ void TypeAnalysis::visitOperation(Operation *op,
AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp,
AtenFillTensorOp>(op)) {
AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
AtenUpsampleNearest2dBackwardVecOp>(op)) {
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
}

View File

@ -6076,6 +6076,9 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward.vec\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<float>>) -> !torch.list<int> {\n"
" return %arg2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.avg_pool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.avg_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"

View File

@ -690,6 +690,9 @@ def atenmax_pool2d_with_indices(self: List[int], kernel_size: List[int], stri
def atenmax_pool2d_with_indices_backward(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]:
return self
def atenupsample_nearest2d_backwardvec(grad_output: List[int], output_size: Optional[List[int]], input_size: List[int], scale_factors: Optional[List[float]]) -> List[int]:
return input_size
# TODO: This should be upstreamed.
# See https://github.com/pytorch/pytorch/pull/76889 for an example.
def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool, divisor_override: Optional[int]):

View File

@ -407,6 +407,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
emit("aten::upsample_nearest2d_backward.vec : (Tensor, int[]?, int[], float[]?) -> (Tensor)")
# Misc tensor ops.
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")

View File

@ -3021,4 +3021,51 @@ class SingleTensorTupleReturn(torch.nn.Module):
@register_test_case(module_factory=lambda: SingleTensorTupleReturn())
def SingleTensorTupleReturn_basic(module, tu: TestUtils):
module.forward(torch.randn(2, 4))
module.forward(torch.randn(2, 4))
# ==============================================================================
class UpSampleNearest2dBackwardVec(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, input):
return torch.ops.aten.upsample_nearest2d_backward(input,
output_size=[4, 8],
input_size=[1, 1, 2, 3],
scale_factors=None)
@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardVec())
def UpSampleNearest2dBackwardVec_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 4, 8))
class UpSampleNearest2dBackwardOutputSizeNone(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float64, True),
])
def forward(self, input):
return torch.ops.aten.upsample_nearest2d_backward(input,
output_size=None,
input_size=[1, 1, 2, 3],
scale_factors=[3.0, 4.0])
@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardOutputSizeNone())
def UpSampleNearest2dBackwardOutputSizeNone_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 6, 12).to(torch.float64))