mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Fix aten.upsample_nearest2d_backward op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>

parent: d571d050fd
commit: a558034c1a
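
This patch switches torch-mlir from the `aten::upsample_nearest2d_backward.vec` overload (optional `output_size`, list-valued `scale_factors`) to the plain `aten::upsample_nearest2d_backward` overload (required `output_size`, per-dimension optional `scales_h`/`scales_w`). A minimal sketch of the new calling convention, reusing the shapes from the e2e tests in this diff (whether the `.vec` overload is also exposed depends on the PyTorch build):

    import torch

    grad_output = torch.rand(1, 1, 6, 12, dtype=torch.float64)

    # Overload handled before this patch:
    #   aten::upsample_nearest2d_backward.vec : (Tensor, int[]?, int[], float[]?) -> (Tensor)
    # torch.ops.aten.upsample_nearest2d_backward.vec(
    #     grad_output, output_size=[6, 12], input_size=[1, 1, 2, 3],
    #     scale_factors=None)

    # Overload handled after this patch:
    #   aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)
    grad_input = torch.ops.aten.upsample_nearest2d_backward(
        grad_output, output_size=[6, 12], input_size=[1, 1, 2, 3],
        scales_h=3.0, scales_w=4.0)
    print(grad_input.shape)  # torch.Size([1, 1, 2, 3])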
@@ -4893,28 +4893,29 @@ def Torch_AtenMseLossOp : Torch_Op<"aten.mse_loss", [
   }];
 }
 
-def Torch_AtenUpsampleNearest2dBackwardVecOp : Torch_Op<"aten.upsample_nearest2d_backward.vec", [
+def Torch_AtenUpsampleNearest2dBackwardOp : Torch_Op<"aten.upsample_nearest2d_backward", [
   AllowsTypeRefinement,
   HasValueSemantics,
   ReadOnly
 ]> {
-  let summary = "Generated op for `aten::upsample_nearest2d_backward.vec : (Tensor, int[]?, int[], float[]?) -> (Tensor)`";
+  let summary = "Generated op for `aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$grad_output,
-    AnyTorchOptionalListOfTorchIntType:$output_size,
+    AnyTorchListOfTorchIntType:$output_size,
     AnyTorchListOfTorchIntType:$input_size,
-    AnyTorchOptionalListOfTorchFloatType:$scale_factors
+    AnyTorchOptionalFloatType:$scales_h,
+    AnyTorchOptionalFloatType:$scales_w
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenUpsampleNearest2dBackwardVecOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 4, 1);
+    ParseResult AtenUpsampleNearest2dBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
     }
-    void AtenUpsampleNearest2dBackwardVecOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 4, 1);
+    void AtenUpsampleNearest2dBackwardOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
     }
   }];
 }
@@ -955,13 +955,13 @@ static Value getGradOutputValue(OpBuilder &builder, Location loc,
 //     for y in range(kw):
 //       outTensor[i, j, p, q] += gradOutput[i, j, (p*kh)+x, (q*kw)+y]
 namespace {
-class ConvertAtenUpsampleNearest2dBackwardVecOp
-    : public OpConversionPattern<AtenUpsampleNearest2dBackwardVecOp> {
+class ConvertAtenUpsampleNearest2dBackwardOp
+    : public OpConversionPattern<AtenUpsampleNearest2dBackwardOp> {
 
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(AtenUpsampleNearest2dBackwardVecOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenUpsampleNearest2dBackwardOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
     Location loc = op->getLoc();
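
The comment at the top of this hunk gives the algorithm the lowering implements. As a reference point, here is a plain-Python sketch of that accumulation, assuming the grad_output spatial dims are exact integer multiples (kh, kw) of the input spatial dims; the helper name and loop structure are illustrative, not part of the patch:

    import torch

    def upsample_nearest2d_backward_ref(grad_output, input_size, kh, kw):
        n, c, hin, win = input_size
        out = torch.zeros(input_size, dtype=grad_output.dtype)
        for i in range(n):
            for j in range(c):
                for p in range(hin):
                    for q in range(win):
                        for x in range(kh):
                            for y in range(kw):
                                # Each input-gradient cell sums the kh x kw window of
                                # grad_output that nearest-neighbor upsampling copied it to.
                                out[i, j, p, q] += grad_output[i, j, (p * kh) + x, (q * kw) + y]
        return out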
@@ -976,7 +976,6 @@ public:
         getTensorSizes(rewriter, loc, gradOutput);
     SmallVector<Value> gradOutputSizeIntValues =
         castIndexVectorToInt64Vector(rewriter, loc, gradOutputSizeIndexValues);
-    SmallVector<Value, 2> scaleFactorsFloatValues;
 
     SmallVector<Value, 4> inputSizeTorchInt;
     if (!getListConstructElements(op.input_size(), inputSizeTorchInt))
@@ -990,24 +989,32 @@ public:
     // The dimension at which the scaling starts.
     unsigned hDimOffset = 2;
 
-    if (!op.scale_factors().getType().isa<Torch::NoneType>()) {
-      SmallVector<Value, 2> scaleFactorsTorchFloat;
-      if (!getListConstructElements(op.scale_factors(), scaleFactorsTorchFloat))
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: the scale_factors is not constructed from "
-                "ListConstruct");
-      scaleFactorsFloatValues = getTypeConvertedValues(
-          rewriter, loc, getTypeConverter(), scaleFactorsTorchFloat);
+    SmallVector<Value, 2> scaleFactorsFloatValues;
+    if (!op.scales_h().getType().isa<Torch::NoneType>()) {
+      scaleFactorsFloatValues.push_back(adaptor.scales_h());
     } else {
-      for (unsigned i = hDimOffset; i < gradOutputRank; i++) {
-        auto scaleFactorVal = rewriter.create<arith::DivFOp>(
-            loc,
-            convertScalarToDtype(rewriter, loc, gradOutputSizeIntValues[i],
-                                 mlir::Float32Type::get(op->getContext())),
-            convertScalarToDtype(rewriter, loc, inputSizeIntValues[i],
-                                 mlir::Float32Type::get(op->getContext())));
-        scaleFactorsFloatValues.push_back(scaleFactorVal);
-      }
+      auto scaleFactorVal = rewriter.create<arith::DivFOp>(
+          loc,
+          convertScalarToDtype(rewriter, loc,
+                               gradOutputSizeIntValues[hDimOffset],
+                               mlir::Float32Type::get(op->getContext())),
+          convertScalarToDtype(rewriter, loc, inputSizeIntValues[hDimOffset],
+                               mlir::Float32Type::get(op->getContext())));
+      scaleFactorsFloatValues.push_back(scaleFactorVal);
     }
+
+    if (!op.scales_w().getType().isa<Torch::NoneType>()) {
+      scaleFactorsFloatValues.push_back(adaptor.scales_w());
+    } else {
+      auto scaleFactorVal = rewriter.create<arith::DivFOp>(
+          loc,
+          convertScalarToDtype(rewriter, loc,
+                               gradOutputSizeIntValues[hDimOffset + 1],
+                               mlir::Float32Type::get(op->getContext())),
+          convertScalarToDtype(rewriter, loc,
+                               inputSizeIntValues[hDimOffset + 1],
+                               mlir::Float32Type::get(op->getContext())));
+      scaleFactorsFloatValues.push_back(scaleFactorVal);
+    }
 
     SmallVector<Value, 2> scaleFactorsIntValues;
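
When `scales_h`/`scales_w` are absent, the hunk above recovers each scale as the floating-point ratio of the grad_output extent to the input extent at that dimension (with `hDimOffset = 2` for NCHW layout). The same logic in a small Python sketch, with a hypothetical helper name:

    def infer_scale_factors(grad_output_size, input_size, scales_h=None, scales_w=None):
        h_dim_offset = 2  # scaling starts at the first spatial dim of NCHW
        if scales_h is None:
            scales_h = grad_output_size[h_dim_offset] / input_size[h_dim_offset]
        if scales_w is None:
            scales_w = grad_output_size[h_dim_offset + 1] / input_size[h_dim_offset + 1]
        return scales_h, scales_w

    # e.g. infer_scale_factors([1, 1, 6, 12], [1, 1, 2, 3]) -> (3.0, 4.0)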
@@ -1097,6 +1104,6 @@ void mlir::torch::torch_to_linalg::
   patterns.add<ConvertAtenEmbeddingBagPaddingIdxOp>(typeConverter, context);
   target.addIllegalOp<AtenUpsampleNearest2dVecOp>();
   patterns.add<ConvertAtenUpsampleNearest2dVecOp>(typeConverter, context);
-  target.addIllegalOp<AtenUpsampleNearest2dBackwardVecOp>();
-  patterns.add<ConvertAtenUpsampleNearest2dBackwardVecOp>(typeConverter, context);
+  target.addIllegalOp<AtenUpsampleNearest2dBackwardOp>();
+  patterns.add<ConvertAtenUpsampleNearest2dBackwardOp>(typeConverter, context);
 }
@@ -701,7 +701,7 @@ void TypeAnalysis::visitOperation(Operation *op,
           AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
           AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
           AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
-          AtenUpsampleNearest2dBackwardVecOp>(op)) {
+          AtenUpsampleNearest2dBackwardOp>(op)) {
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
 
@@ -6076,7 +6076,7 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
 "    return %arg1 : !torch.list<int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward.vec\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<float>>) -> !torch.list<int> {\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.list<int> {\n"
 "    return %arg2 : !torch.list<int>\n"
 "  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.avg_pool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
@@ -690,7 +690,7 @@ def aten〇max_pool2d_with_indices(self: List[int], kernel_size: List[int], stri
 def aten〇max_pool2d_with_indices_backward(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]:
     return self
 
-def aten〇upsample_nearest2d_backward〇vec(grad_output: List[int], output_size: Optional[List[int]], input_size: List[int], scale_factors: Optional[List[float]]) -> List[int]:
+def aten〇upsample_nearest2d_backward(grad_output: List[int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
     return input_size
 
 # TODO: This should be upstreamed.
@@ -407,7 +407,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
     emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
-    emit("aten::upsample_nearest2d_backward.vec : (Tensor, int[]?, int[], float[]?) -> (Tensor)")
+    emit("aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)")
 
     # Misc tensor ops.
     emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
@@ -3010,7 +3010,30 @@ def AtenToDeviceModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
-class UpSampleNearest2dBackwardVec(torch.nn.Module):
+class UpSampleNearest2dBackward(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float64, True),
+    ])
+    def forward(self, input):
+        return torch.ops.aten.upsample_nearest2d_backward(input,
+                                                          output_size=[6, 12],
+                                                          input_size=[1, 1, 2, 3],
+                                                          scales_h=3.0,
+                                                          scales_w=4.0)
+
+
+@register_test_case(module_factory=lambda: UpSampleNearest2dBackward())
+def UpSampleNearest2dBackward_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 6, 12).to(torch.float64))
+
+
+class UpSampleNearest2dBackwardScalesNone(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
@@ -3024,31 +3047,9 @@ class UpSampleNearest2dBackwardVec(torch.nn.Module):
         return torch.ops.aten.upsample_nearest2d_backward(input,
                                                           output_size=[4, 8],
                                                           input_size=[1, 1, 2, 3],
-                                                          scale_factors=None)
+                                                          scales_h=None,
+                                                          scales_w=None)
 
-@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardVec())
-def UpSampleNearest2dBackwardVec_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardScalesNone())
+def UpSampleNearest2dBackwardScalesNone_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 1, 4, 8))
-
-
-class UpSampleNearest2dBackwardOutputSizeNone(torch.nn.Module):
-
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1, -1, -1], torch.float64, True),
-    ])
-    def forward(self, input):
-        return torch.ops.aten.upsample_nearest2d_backward(input,
-                                                          output_size=None,
-                                                          input_size=[1, 1, 2, 3],
-                                                          scale_factors=[3.0, 4.0])
-
-
-@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardOutputSizeNone())
-def UpSampleNearest2dBackwardOutputSizeNone_basic(module, tu: TestUtils):
-    module.forward(tu.rand(1, 1, 6, 12).to(torch.float64))