[MLIR][TORCH] Fix aten.upsample_nearest2d_backward op

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
pull/1587/head
Vivek Khandelwal 2022-11-11 13:26:48 +05:30
parent d571d050fd
commit a558034c1a
7 changed files with 71 additions and 62 deletions

@@ -4893,28 +4893,29 @@ def Torch_AtenMseLossOp : Torch_Op<"aten.mse_loss", [
   }];
 }
 
-def Torch_AtenUpsampleNearest2dBackwardVecOp : Torch_Op<"aten.upsample_nearest2d_backward.vec", [
+def Torch_AtenUpsampleNearest2dBackwardOp : Torch_Op<"aten.upsample_nearest2d_backward", [
   AllowsTypeRefinement,
   HasValueSemantics,
   ReadOnly
 ]> {
-  let summary = "Generated op for `aten::upsample_nearest2d_backward.vec : (Tensor, int[]?, int[], float[]?) -> (Tensor)`";
+  let summary = "Generated op for `aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$grad_output,
-    AnyTorchOptionalListOfTorchIntType:$output_size,
+    AnyTorchListOfTorchIntType:$output_size,
     AnyTorchListOfTorchIntType:$input_size,
-    AnyTorchOptionalListOfTorchFloatType:$scale_factors
+    AnyTorchOptionalFloatType:$scales_h,
+    AnyTorchOptionalFloatType:$scales_w
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenUpsampleNearest2dBackwardVecOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 4, 1);
+    ParseResult AtenUpsampleNearest2dBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
     }
-    void AtenUpsampleNearest2dBackwardVecOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 4, 1);
+    void AtenUpsampleNearest2dBackwardOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
     }
   }];
 }
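For reference, the regenerated ODS declaration now tracks the standard PyTorch schema, `aten::upsample_nearest2d_backward(Tensor grad_output, int[] output_size, int[] input_size, float? scales_h=None, float? scales_w=None) -> Tensor`, instead of the `.vec` overload. A minimal eager-mode sketch of the contract the op encodes (shapes chosen arbitrarily; the gradient always comes back with `input_size`'s shape):

import torch

# grad_output has the forward output's shape; input_size is the forward
# input's full NCHW size. scales_h/scales_w are optional and may be None.
grad_output = torch.rand(1, 1, 6, 12, dtype=torch.float64)
grad_input = torch.ops.aten.upsample_nearest2d_backward(
    grad_output,
    output_size=[6, 12],
    input_size=[1, 1, 2, 3],
    scales_h=3.0,
    scales_w=4.0)
assert list(grad_input.shape) == [1, 1, 2, 3]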

@@ -955,13 +955,13 @@ static Value getGradOutputValue(OpBuilder &builder, Location loc,
 //     for y in range(kw):
 //       outTensor[i, j, p, q] += gradOutput[i, j, (p*kh)+x, (q*kw)+y]
 
 namespace {
-class ConvertAtenUpsampleNearest2dBackwardVecOp
-    : public OpConversionPattern<AtenUpsampleNearest2dBackwardVecOp> {
+class ConvertAtenUpsampleNearest2dBackwardOp
+    : public OpConversionPattern<AtenUpsampleNearest2dBackwardOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(AtenUpsampleNearest2dBackwardVecOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenUpsampleNearest2dBackwardOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
@@ -976,7 +976,6 @@ public:
         getTensorSizes(rewriter, loc, gradOutput);
     SmallVector<Value> gradOutputSizeIntValues =
         castIndexVectorToInt64Vector(rewriter, loc, gradOutputSizeIndexValues);
-    SmallVector<Value, 2> scaleFactorsFloatValues;
 
     SmallVector<Value, 4> inputSizeTorchInt;
     if (!getListConstructElements(op.input_size(), inputSizeTorchInt))
@@ -990,24 +989,32 @@ public:
     // The dimension at which the scaling starts.
     unsigned hDimOffset = 2;
 
-    if (!op.scale_factors().getType().isa<Torch::NoneType>()) {
-      SmallVector<Value, 2> scaleFactorsTorchFloat;
-      if (!getListConstructElements(op.scale_factors(), scaleFactorsTorchFloat))
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: the scale_factors is not constructed from "
-                "ListConstruct");
-      scaleFactorsFloatValues = getTypeConvertedValues(
-          rewriter, loc, getTypeConverter(), scaleFactorsTorchFloat);
+    SmallVector<Value, 2> scaleFactorsFloatValues;
+    if (!op.scales_h().getType().isa<Torch::NoneType>()) {
+      scaleFactorsFloatValues.push_back(adaptor.scales_h());
     } else {
-      for (unsigned i = hDimOffset; i < gradOutputRank; i++) {
-        auto scaleFactorVal = rewriter.create<arith::DivFOp>(
-            loc,
-            convertScalarToDtype(rewriter, loc, gradOutputSizeIntValues[i],
-                                 mlir::Float32Type::get(op->getContext())),
-            convertScalarToDtype(rewriter, loc, inputSizeIntValues[i],
-                                 mlir::Float32Type::get(op->getContext())));
-        scaleFactorsFloatValues.push_back(scaleFactorVal);
-      }
+      auto scaleFactorVal = rewriter.create<arith::DivFOp>(
+          loc,
+          convertScalarToDtype(rewriter, loc,
+                               gradOutputSizeIntValues[hDimOffset],
+                               mlir::Float32Type::get(op->getContext())),
+          convertScalarToDtype(rewriter, loc, inputSizeIntValues[hDimOffset],
+                               mlir::Float32Type::get(op->getContext())));
+      scaleFactorsFloatValues.push_back(scaleFactorVal);
+    }
+
+    if (!op.scales_w().getType().isa<Torch::NoneType>()) {
+      scaleFactorsFloatValues.push_back(adaptor.scales_w());
+    } else {
+      auto scaleFactorVal = rewriter.create<arith::DivFOp>(
+          loc,
+          convertScalarToDtype(rewriter, loc,
+                               gradOutputSizeIntValues[hDimOffset + 1],
+                               mlir::Float32Type::get(op->getContext())),
+          convertScalarToDtype(rewriter, loc,
+                               inputSizeIntValues[hDimOffset + 1],
+                               mlir::Float32Type::get(op->getContext())));
+      scaleFactorsFloatValues.push_back(scaleFactorVal);
     }
 
     SmallVector<Value, 2> scaleFactorsIntValues;
@@ -1097,6 +1104,6 @@ void mlir::torch::torch_to_linalg::
   patterns.add<ConvertAtenEmbeddingBagPaddingIdxOp>(typeConverter, context);
   target.addIllegalOp<AtenUpsampleNearest2dVecOp>();
   patterns.add<ConvertAtenUpsampleNearest2dVecOp>(typeConverter, context);
-  target.addIllegalOp<AtenUpsampleNearest2dBackwardVecOp>();
-  patterns.add<ConvertAtenUpsampleNearest2dBackwardVecOp>(typeConverter, context);
+  target.addIllegalOp<AtenUpsampleNearest2dBackwardOp>();
+  patterns.add<ConvertAtenUpsampleNearest2dBackwardOp>(typeConverter, context);
 }
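The new scale handling replaces the old loop over a `scale_factors` list: when `scales_h` or `scales_w` is absent, the lowering derives that scale factor by dividing the grad-output extent by the input extent along the corresponding spatial dimension (height at dim 2, width at dim 3). A plain-Python sketch of this fallback; `compute_scale_factors` is a hypothetical name used only for illustration:

from typing import List, Optional

def compute_scale_factors(grad_output_size: List[int], input_size: List[int],
                          scales_h: Optional[float],
                          scales_w: Optional[float]) -> List[float]:
    # Scaling starts at the first spatial dimension (hDimOffset = 2 above).
    h = 2
    scale_h = (scales_h if scales_h is not None
               else grad_output_size[h] / input_size[h])
    scale_w = (scales_w if scales_w is not None
               else grad_output_size[h + 1] / input_size[h + 1])
    return [scale_h, scale_w]

# Both scales derived when None:
assert compute_scale_factors([1, 1, 6, 12], [1, 1, 2, 3], None, None) == [3.0, 4.0]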

@@ -701,7 +701,7 @@ void TypeAnalysis::visitOperation(Operation *op,
           AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
           AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
           AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
-          AtenUpsampleNearest2dBackwardVecOp>(op)) {
+          AtenUpsampleNearest2dBackwardOp>(op)) {
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
 
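Listing the renamed op in this case means the result's type knowledge is taken from the first operand, so the gradient keeps `grad_output`'s element type. This matches eager PyTorch behavior, e.g.:

import torch

g = torch.rand(1, 1, 6, 12, dtype=torch.float64)
r = torch.ops.aten.upsample_nearest2d_backward(g, [6, 12], [1, 1, 2, 3], 3.0, 4.0)
assert r.dtype == g.dtype  # result dtype propagated from grad_output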

@@ -6076,7 +6076,7 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
 "    return %arg1 : !torch.list<int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward.vec\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<float>>) -> !torch.list<int> {\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.list<int> {\n"
 "    return %arg2 : !torch.list<int>\n"
 "  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.avg_pool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"

@@ -690,7 +690,7 @@ def aten〇max_pool2d_with_indices(self: List[int], kernel_size: List[int], stri
 def aten〇max_pool2d_with_indices_backward(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]:
     return self
 
-def aten〇upsample_nearest2d_backward〇vec(grad_output: List[int], output_size: Optional[List[int]], input_size: List[int], scale_factors: Optional[List[float]]) -> List[int]:
+def aten〇upsample_nearest2d_backward(grad_output: List[int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
     return input_size
 
 # TODO: This should be upstreamed.

@@ -407,7 +407,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
     emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
-    emit("aten::upsample_nearest2d_backward.vec : (Tensor, int[]?, int[], float[]?) -> (Tensor)")
+    emit("aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)")
 
     # Misc tensor ops.
     emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")

@@ -3010,7 +3010,30 @@ def AtenToDeviceModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-class UpSampleNearest2dBackwardVec(torch.nn.Module):
+class UpSampleNearest2dBackward(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float64, True),
+    ])
+    def forward(self, input):
+        return torch.ops.aten.upsample_nearest2d_backward(input,
+                                                          output_size=[6, 12],
+                                                          input_size=[1, 1, 2, 3],
+                                                          scales_h=3.0,
+                                                          scales_w=4.0)
+
+
+@register_test_case(module_factory=lambda: UpSampleNearest2dBackward())
+def UpSampleNearest2dBackward_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 6, 12).to(torch.float64))
+
+
+class UpSampleNearest2dBackwardScalesNone(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
@@ -3024,31 +3047,9 @@ class UpSampleNearest2dBackwardVec(torch.nn.Module):
         return torch.ops.aten.upsample_nearest2d_backward(input,
                                                           output_size=[4, 8],
                                                           input_size=[1, 1, 2, 3],
-                                                          scale_factors=None)
+                                                          scales_h=None,
+                                                          scales_w=None)
 
 
-@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardVec())
-def UpSampleNearest2dBackwardVec_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardScalesNone())
+def UpSampleNearest2dBackwardScalesNone_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 1, 4, 8))
-
-
-class UpSampleNearest2dBackwardOutputSizeNone(torch.nn.Module):
-
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1, -1, -1], torch.float64, True),
-    ])
-    def forward(self, input):
-        return torch.ops.aten.upsample_nearest2d_backward(input,
-                                                          output_size=None,
-                                                          input_size=[1, 1, 2, 3],
-                                                          scale_factors=[3.0, 4.0])
-
-@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardOutputSizeNone())
-def UpSampleNearest2dBackwardOutputSizeNone_basic(module, tu: TestUtils):
-    module.forward(tu.rand(1, 1, 6, 12).to(torch.float64))
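For intuition, the loop nest in the lowering's comment can be written as a plain-Python reference and cross-checked against the op in eager mode. This is only a sketch, valid for exact integer scale factors, with `kh`/`kw` playing the role of the per-dimension scales and `upsample_nearest2d_backward_ref` a hypothetical helper name:

import torch

def upsample_nearest2d_backward_ref(grad_output, input_size, kh, kw):
    # Each input cell accumulates the kh*kw grad_output cells that the
    # forward nearest-neighbor upsample replicated it into.
    n, c, hin, win = input_size
    out = torch.zeros(input_size, dtype=grad_output.dtype)
    for i in range(n):
        for j in range(c):
            for p in range(hin):
                for q in range(win):
                    for x in range(kh):
                        for y in range(kw):
                            out[i, j, p, q] += grad_output[i, j, p * kh + x, q * kw + y]
    return out

g = torch.rand(1, 1, 6, 12, dtype=torch.float64)
expected = torch.ops.aten.upsample_nearest2d_backward(g, [6, 12], [1, 1, 2, 3], 3.0, 4.0)
assert torch.allclose(upsample_nearest2d_backward_ref(g, [1, 1, 2, 3], 3, 4), expected)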