mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.upsample_nearest2d_backward.vec op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>

pull/1558/head
parent db5a496eb4
commit fedf8c0640
@@ -621,4 +621,6 @@ LTC_XFAIL_SET = {
     "Fill_TensorFloat32WithFloat32_basic",
     "Fill_TensorFloat32WithFloat64_basic",
     "Fill_TensorFloat32WithInt64_basic",
+    "UpSampleNearest2dBackwardVec_basic",
+    "UpSampleNearest2dBackwardOutputSizeNone_basic",
 }
@@ -49,6 +49,10 @@ SmallVector<Value>
 castIntVectorToIndexVector(OpBuilder &b, Location loc,
                            SmallVectorImpl<Value> &intValues);

+SmallVector<Value>
+castIndexVectorToInt64Vector(OpBuilder &b, Location loc,
+                             SmallVectorImpl<Value> &indexValues);
+
 Value getDimOp(OpBuilder &b, Location loc, Value v, int dim);

 SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
@@ -4893,6 +4893,32 @@ def Torch_AtenMseLossOp : Torch_Op<"aten.mse_loss", [
  }];
 }

+def Torch_AtenUpsampleNearest2dBackwardVecOp : Torch_Op<"aten.upsample_nearest2d_backward.vec", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::upsample_nearest2d_backward.vec : (Tensor, int[]?, int[], float[]?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$grad_output,
+    AnyTorchOptionalListOfTorchIntType:$output_size,
+    AnyTorchListOfTorchIntType:$input_size,
+    AnyTorchOptionalListOfTorchFloatType:$scale_factors
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenUpsampleNearest2dBackwardVecOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenUpsampleNearest2dBackwardVecOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
     AllowsTypeRefinement,
     HasValueSemantics,
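For context, the definition above is generated from the upstream JIT operator registry (see the `emit` call later in this patch). A quick way to eyeball the upstream schema it mirrors is an eager PyTorch session; this is a hypothetical snippet, not part of the patch, and it assumes a PyTorch build that exposes per-overload access:

import torch

# Prints something resembling the summary string in the ODS definition above:
# aten::upsample_nearest2d_backward.vec(Tensor grad_output, int[]? output_size,
#     int[] input_size, float[]? scale_factors) -> Tensor
print(torch.ops.aten.upsample_nearest2d_backward.vec._schema)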
@@ -896,6 +896,190 @@ public:
 };
 } // namespace

+static Value getGradOutputValue(OpBuilder &builder, Location loc,
+                                Value gradOutput, Type gradOutputElemType,
+                                Value numBatch, Value numChannel,
+                                Value inputIndexH, Value inputIndexW,
+                                Value kernelIndexH, Value kernelIndexW,
+                                SmallVector<Value> &gradOutputSizeIndexValues,
+                                SmallVector<Value, 2> &scaleFactorsIntValues) {
+  Value constantOne = builder.create<arith::ConstantIndexOp>(loc, 1);
+
+  Value outputIndexH = builder.create<arith::MulIOp>(
+      loc, inputIndexH, castIntToIndex(builder, loc, scaleFactorsIntValues[0]));
+  outputIndexH = builder.create<arith::AddIOp>(loc, outputIndexH, kernelIndexH);
+
+  Value outputIndexW = builder.create<arith::MulIOp>(
+      loc, inputIndexW, castIntToIndex(builder, loc, scaleFactorsIntValues[1]));
+  outputIndexW = builder.create<arith::AddIOp>(loc, outputIndexW, kernelIndexW);
+
+  // Handling corner cases.
+  Value gradOutputHMinusOne = builder.create<arith::SubIOp>(
+      loc, gradOutputSizeIndexValues[2], constantOne);
+  Value predH = builder.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::sle, outputIndexH, gradOutputHMinusOne);
+  outputIndexH = builder.create<arith::SelectOp>(loc, predH, outputIndexH,
+                                                 gradOutputHMinusOne);
+
+  Value gradOutputWMinusOne = builder.create<arith::SubIOp>(
+      loc, gradOutputSizeIndexValues[3], constantOne);
+  Value predW = builder.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::sle, outputIndexW, gradOutputWMinusOne);
+  outputIndexW = builder.create<arith::SelectOp>(loc, predW, outputIndexW,
+                                                 gradOutputWMinusOne);
+
+  Value gradOutputValue = builder.create<tensor::ExtractOp>(
+      loc, gradOutput,
+      ValueRange{numBatch, numChannel, outputIndexH, outputIndexW});
+  Value constantZero =
+      builder.create<arith::ConstantOp>(loc, builder.getF32FloatAttr(0.0));
+  Value pred = builder.create<arith::AndIOp>(loc, predH, predW);
+  Value result = builder.create<arith::SelectOp>(
+      loc, pred, gradOutputValue,
+      convertScalarToDtype(builder, loc, constantZero, gradOutputElemType));
+
+  return result;
+}
+
+// The implementation of the `aten.upsample_nearest2d_backward.vec` op's
+// lowering is as follows:
+//   gradOutput: Tensor of size [n, c, oh, ow]
+//   outTensor: Tensor of size [n, c, ih, iw], initialized with zero
+//   kh = ceil(oh/ih), kw = ceil(ow/iw)
+//
+//   for i in range(n):
+//     for j in range(c):
+//       for p in range(ih):
+//         for q in range(iw):
+//           for x in range(kh):
+//             for y in range(kw):
+//               outTensor[i, j, p, q] += gradOutput[i, j, (p*kh)+x, (q*kw)+y]
+namespace {
+class ConvertAtenUpsampleNearest2dBackwardVecOp
+    : public OpConversionPattern<AtenUpsampleNearest2dBackwardVecOp> {
+
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenUpsampleNearest2dBackwardVecOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Location loc = op->getLoc();
+    Value gradOutput = adaptor.grad_output();
+
+    Type resultType = getTypeConverter()->convertType(op.getResult().getType());
+    auto gradOutputType = gradOutput.getType().cast<RankedTensorType>();
+    auto gradOutputRank = gradOutputType.getRank();
+    Type elementType = gradOutputType.getElementType();
+
+    SmallVector<Value> gradOutputSizeIndexValues =
+        getTensorSizes(rewriter, loc, gradOutput);
+    SmallVector<Value> gradOutputSizeIntValues =
+        castIndexVectorToInt64Vector(rewriter, loc, gradOutputSizeIndexValues);
+    SmallVector<Value, 2> scaleFactorsFloatValues;
+
+    SmallVector<Value, 4> inputSizeTorchInt;
+    if (!getListConstructElements(op.input_size(), inputSizeTorchInt))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: the input_size is not constructed from "
+              "ListConstruct");
+    SmallVector<Value, 4> inputSizeIntValues;
+    inputSizeIntValues = getTypeConvertedValues(
+        rewriter, loc, getTypeConverter(), inputSizeTorchInt);
+
+    // The dimension at which the scaling starts.
+    unsigned hDimOffset = 2;
+
+    if (!op.scale_factors().getType().isa<Torch::NoneType>()) {
+      SmallVector<Value, 2> scaleFactorsTorchFloat;
+      if (!getListConstructElements(op.scale_factors(), scaleFactorsTorchFloat))
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: the scale_factors is not constructed from "
+                "ListConstruct");
+      scaleFactorsFloatValues = getTypeConvertedValues(
+          rewriter, loc, getTypeConverter(), scaleFactorsTorchFloat);
+    } else {
+      for (unsigned i = hDimOffset; i < gradOutputRank; i++) {
+        auto scaleFactorVal = rewriter.create<arith::DivFOp>(
+            loc,
+            convertScalarToDtype(rewriter, loc, gradOutputSizeIntValues[i],
+                                 mlir::Float32Type::get(op->getContext())),
+            convertScalarToDtype(rewriter, loc, inputSizeIntValues[i],
+                                 mlir::Float32Type::get(op->getContext())));
+        scaleFactorsFloatValues.push_back(scaleFactorVal);
+      }
+    }
+
+    SmallVector<Value, 2> scaleFactorsIntValues;
+    for (auto v : scaleFactorsFloatValues)
+      scaleFactorsIntValues.push_back(convertScalarToDtype(
+          rewriter, loc, rewriter.create<math::CeilOp>(loc, v),
+          mlir::IntegerType::get(op->getContext(), 64)));
+
+    Value outTensor = createZeroInitTensor(
+        rewriter, loc,
+        castIntVectorToIndexVector(rewriter, loc, inputSizeIntValues),
+        elementType);
+
+    Value kernelTensor = rewriter.create<tensor::EmptyOp>(
+        loc,
+        getAsOpFoldResult(
+            castIntVectorToIndexVector(rewriter, loc, scaleFactorsIntValues)),
+        elementType);
+    unsigned kernelRank = scaleFactorsIntValues.size();
+
+    SmallVector<AffineExpr> affineExprs;
+    for (unsigned i = 0; i < gradOutputRank; i++)
+      affineExprs.push_back(rewriter.getAffineDimExpr(i));
+
+    AffineMap outputMap =
+        AffineMap::get(gradOutputRank + kernelRank,
+                       /*symbolCount=*/0, affineExprs, op->getContext());
+
+    affineExprs.clear();
+    for (unsigned i = gradOutputRank; i < gradOutputRank + kernelRank; i++)
+      affineExprs.push_back(rewriter.getAffineDimExpr(i));
+
+    AffineMap kernelMap =
+        AffineMap::get(gradOutputRank + kernelRank,
+                       /*symbolCount=*/0, affineExprs, op->getContext());
+
+    SmallVector<AffineMap> indexingMaps{kernelMap, outputMap};
+    SmallVector<StringRef> iteratorTypes(gradOutputRank,
+                                         getParallelIteratorTypeName());
+    iteratorTypes.push_back(getReductionIteratorTypeName());
+    iteratorTypes.push_back(getReductionIteratorTypeName());
+
+    Value finalRes =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, outTensor.getType(), ValueRange{kernelTensor},
+                ValueRange{outTensor},
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value n = rewriter.create<linalg::IndexOp>(loc, 0);
+                  Value c = rewriter.create<linalg::IndexOp>(loc, 1);
+                  Value ih = rewriter.create<linalg::IndexOp>(loc, 2);
+                  Value iw = rewriter.create<linalg::IndexOp>(loc, 3);
+                  Value kh = rewriter.create<linalg::IndexOp>(loc, 4);
+                  Value kw = rewriter.create<linalg::IndexOp>(loc, 5);
+                  Value accValue = getGradOutputValue(
+                      rewriter, loc, gradOutput, elementType, n, c, ih, iw, kh,
+                      kw, gradOutputSizeIndexValues, scaleFactorsIntValues);
+                  Value outputVal = args[1];
+                  outputVal =
+                      rewriter.create<arith::AddFOp>(loc, outputVal, accValue);
+                  b.create<linalg::YieldOp>(loc, outputVal);
+                })
+            ->getResult(0);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
+    return success();
+  }
+};
+} // namespace
+
 void mlir::torch::torch_to_linalg::
     populateIndirectDataMovementPatternsAndLegality(
         TypeConverter &typeConverter, RewritePatternSet &patterns,
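To make the loop-nest comment above concrete, here is a minimal NumPy sketch of the same accumulation scheme (illustrative only, not part of the patch; `upsample_nearest2d_backward_ref` is a made-up name). The bounds guard plays the role of the `predH`/`predW` select logic in `getGradOutputValue`, which masks out-of-range kernel taps to zero:

import math
import numpy as np

def upsample_nearest2d_backward_ref(grad_output, input_size):
    n, c, ih, iw = input_size
    _, _, oh, ow = grad_output.shape
    # kh/kw correspond to the ceil'd scale factors in the lowering.
    kh, kw = math.ceil(oh / ih), math.ceil(ow / iw)
    out = np.zeros(input_size, dtype=grad_output.dtype)
    for i in range(n):
        for j in range(c):
            for p in range(ih):
                for q in range(iw):
                    for x in range(kh):
                        for y in range(kw):
                            h, w = p * kh + x, q * kw + y
                            # Out-of-range taps contribute zero, mirroring the
                            # predH/predW selects in getGradOutputValue.
                            if h < oh and w < ow:
                                out[i, j, p, q] += grad_output[i, j, h, w]
    return out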
@@ -913,4 +1097,6 @@ void mlir::torch::torch_to_linalg::
   patterns.add<ConvertAtenEmbeddingBagPaddingIdxOp>(typeConverter, context);
   target.addIllegalOp<AtenUpsampleNearest2dVecOp>();
   patterns.add<ConvertAtenUpsampleNearest2dVecOp>(typeConverter, context);
+  target.addIllegalOp<AtenUpsampleNearest2dBackwardVecOp>();
+  patterns.add<ConvertAtenUpsampleNearest2dBackwardVecOp>(typeConverter, context);
 }
@@ -156,6 +156,15 @@ castIntVectorToIndexVector(OpBuilder &b, Location loc,
   return indexValues;
 }

+SmallVector<Value>
+castIndexVectorToInt64Vector(OpBuilder &b, Location loc,
+                             SmallVectorImpl<Value> &indexValues) {
+  SmallVector<Value> intValues;
+  for (Value v : indexValues)
+    intValues.push_back(castIndexToInt64(b, loc, v));
+  return intValues;
+}
+
 Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
   return b.createOrFold<tensor::DimOp>(loc, v, dim);
 }
@@ -700,8 +700,8 @@ void TypeAnalysis::visitOperation(Operation *op,
           AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
           AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
           AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
-          AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp,
-          AtenFillTensorOp>(op)) {
+          AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
+          AtenUpsampleNearest2dBackwardVecOp>(op)) {
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
@@ -6076,6 +6076,9 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
 "    return %arg1 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward.vec\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<float>>) -> !torch.list<int> {\n"
+"    return %arg2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.avg_pool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.avg_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -690,6 +690,9 @@ def aten〇max_pool2d_with_indices(self: List[int], kernel_size: List[int], stri
 def aten〇max_pool2d_with_indices_backward(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]:
     return self

+def aten〇upsample_nearest2d_backward〇vec(grad_output: List[int], output_size: Optional[List[int]], input_size: List[int], scale_factors: Optional[List[float]]) -> List[int]:
+    return input_size
+
 # TODO: This should be upstreamed.
 # See https://github.com/pytorch/pytorch/pull/76889 for an example.
 def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool, divisor_override: Optional[int]):
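Because the shape transfer function simply forwards `input_size`, checking it is a one-liner (a hypothetical snippet, assuming the function above is importable from the shape library module):

# The result shape equals input_size regardless of grad_output's shape or the
# optional arguments.
assert aten〇upsample_nearest2d_backward〇vec(
    grad_output=[1, 1, 4, 8], output_size=[4, 8],
    input_size=[1, 1, 2, 3], scale_factors=None) == [1, 1, 2, 3]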
@@ -407,6 +407,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
     emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
+    emit("aten::upsample_nearest2d_backward.vec : (Tensor, int[]?, int[], float[]?) -> (Tensor)")

     # Misc tensor ops.
     emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
@@ -3021,4 +3021,51 @@ class SingleTensorTupleReturn(torch.nn.Module):

 @register_test_case(module_factory=lambda: SingleTensorTupleReturn())
 def SingleTensorTupleReturn_basic(module, tu: TestUtils):
     module.forward(torch.randn(2, 4))
+
+
+# ==============================================================================
+
+
+class UpSampleNearest2dBackwardVec(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, input):
+        return torch.ops.aten.upsample_nearest2d_backward(input,
+                                                          output_size=[4, 8],
+                                                          input_size=[1, 1, 2, 3],
+                                                          scale_factors=None)
+
+
+@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardVec())
+def UpSampleNearest2dBackwardVec_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 4, 8))
+
+
+class UpSampleNearest2dBackwardOutputSizeNone(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float64, True),
+    ])
+    def forward(self, input):
+        return torch.ops.aten.upsample_nearest2d_backward(input,
+                                                          output_size=None,
+                                                          input_size=[1, 1, 2, 3],
+                                                          scale_factors=[3.0, 4.0])
+
+
+@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardOutputSizeNone())
+def UpSampleNearest2dBackwardOutputSizeNone_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 6, 12).to(torch.float64))
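The golden values these e2e tests compare against come from eager PyTorch, and since the op is the adjoint of nearest-neighbor upsampling it can also be cross-checked against autograd directly. A sketch of such a check (illustrative only, not part of the test suite):

import torch

x = torch.randn(1, 1, 2, 3, requires_grad=True)
y = torch.nn.functional.interpolate(x, size=[4, 8], mode="nearest")
grad_out = torch.randn(1, 1, 4, 8)
y.backward(grad_out)
# The backward op applied to grad_out should reproduce x.grad.
ref = torch.ops.aten.upsample_nearest2d_backward(
    grad_out, output_size=[4, 8], input_size=[1, 1, 2, 3],
    scale_factors=None)
assert torch.allclose(x.grad, ref)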