mirror of https://github.com/llvm/torch-mlir
parent fa0b24a73c
commit 9ec0683e92
@@ -18,6 +18,8 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
     "TableBatchEmbeddingModule_basic",
     "MobilenetV2Module_basic",
     "MobilenetV3Module_basic",
+    "ConvolutionModule3D_basic",
+    "ConvolutionModule1D_basic",
 }
 
 REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
@@ -156,4 +158,5 @@ TOSA_PASS_SET = {
     "GeluBackwardModule_basic",
     "ElementwiseNeIntScalarModule_basic",
     "ElementwiseNeFloatTensorModule_basic",
+    "ConvolutionModule2DStatic_basic",
 }

@@ -2728,6 +2728,68 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
   }];
 }
 
+def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    ListOfTorchIntType:$stride,
+    ListOfTorchIntType:$padding,
+    ListOfTorchIntType:$dilation,
+    Torch_BoolType:$transposed,
+    ListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConvolutionOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 9, 1);
+    }
+    void AtenConvolutionOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 9, 1);
+    }
+  }];
+}
+
+def Torch_AtenConvolutionOverrideableOp : Torch_Op<"aten.convolution_overrideable", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    ListOfTorchIntType:$stride,
+    ListOfTorchIntType:$padding,
+    ListOfTorchIntType:$dilation,
+    Torch_BoolType:$transposed,
+    ListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConvolutionOverrideableOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 9, 1);
+    }
+    void AtenConvolutionOverrideableOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 9, 1);
+    }
+  }];
+}
+
 def Torch_AtenNativeBatchNormOp : Torch_Op<"aten.native_batch_norm", [
     AllowsTypeRefinement,
     HasValueSemantics,

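Note: the nine ODS operands above mirror the aten::convolution schema. A quick PyTorch-level sketch of a direct call with those operands (illustrative only, not part of this patch):

import torch

# input is (N, C, H, W) and weight is (F, C, kH, kW), matching the comments
# in the linalg lowering below.
x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
y = torch.ops.aten.convolution(x, w, None,
                               stride=[1, 1], padding=[0, 0], dilation=[1, 1],
                               transposed=False, output_padding=[0, 0], groups=1)
print(y.shape)  # torch.Size([1, 4, 6, 6])
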
@@ -449,58 +449,65 @@ public:
 } // namespace
 
 namespace {
-class ConvertAtenConv2dOp : public OpConversionPattern<AtenConv2dOp> {
+class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(AtenConv2dOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     MLIRContext *context = op->getContext();
     Value input = adaptor.input();   /* in form of N*C*H*W */
     Value weight = adaptor.weight(); /* in form of F*C*H*W */
-    Value groups = adaptor.groups();
 
     Type elementType =
         input.getType().cast<RankedTensorType>().getElementType();
     if (!elementType.isa<mlir::FloatType>())
       return op.emitError("unimplemented: non-floating point type");
+    size_t inRank = input.getType().cast<RankedTensorType>().getRank();
+    if (inRank != 4)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only 2D convolution currently supported");
 
     Type intType = IntegerType::get(context, 64);
     auto castIndexToInt = [&](Value v) {
       return rewriter.create<arith::IndexCastOp>(loc, intType, v);
     };
 
-    Value N = getDimOp(rewriter, loc, input, 0);
-    Value Hin = getDimOp(rewriter, loc, input, 2);
-    Value Win = getDimOp(rewriter, loc, input, 3);
-    Value F = getDimOp(rewriter, loc, weight, 0);
-    Value weightH = getDimOp(rewriter, loc, weight, 2);
-    Value weightW = getDimOp(rewriter, loc, weight, 3);
-
     // Pattern match against the op's original operands, because otherwise we
     // will get the lowered version of the operands which is harder to pattern
     // match.
     SmallVector<int64_t> paddingInts;
     if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts))) {
       return rewriter.notifyMatchFailure(
           op, "only support constant padding values");
     }
 
-    SmallVector<int64_t, 2> strideInts;
+    SmallVector<int64_t> strideInts;
     if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
       return rewriter.notifyMatchFailure(op,
                                          "only support constant int strides");
-    SmallVector<int64_t, 2> dilationInts;
+    SmallVector<int64_t> dilationInts;
     if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
       return rewriter.notifyMatchFailure(op,
                                          "only support constant int dilations");
-    Value c1 =
-        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(intType, 1));
-    Value groupEqual1 = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, groups, c1);
-    rewriter.create<cf::AssertOp>(
-        loc, groupEqual1, rewriter.getStringAttr("expect groups to be 1"));
 
+    Value N = getDimOp(rewriter, loc, input, 0);
+    SmallVector<Value> inDims;
+    for (size_t i = 2; i < inRank; i++)
+      inDims.push_back(getDimOp(rewriter, loc, input, i));
+    Value F = getDimOp(rewriter, loc, weight, 0);
+    SmallVector<Value> weightDims;
+    for (size_t i = 2; i < inRank; i++)
+      weightDims.push_back(getDimOp(rewriter, loc, weight, i));
+
+    // Guard unused values (transposed, groups)
+    int64_t group_size;
+    if (!matchPattern(op.groups(), m_TorchConstantInt(&group_size)) ||
+        group_size != 1)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only group size of 1 supported");
+    bool transposed = true;
+    if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed)) ||
+        transposed)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only non-transposed convolution supported");
+
     // Pad the input tensor according to padding.
     SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
@@ -516,15 +523,14 @@ public:
     SmallVector<Value> strideIntValues =
         getAsConstantIntValues(rewriter, loc, strideInts);
 
-    Value Hout = torch_to_linalg::getOutputDimForConvOps(
-        rewriter, loc, Hin, paddingIntValues[0], dilationIntValues[0],
-        castIndexToInt(weightH), strideIntValues[0]);
-    Value Wout = torch_to_linalg::getOutputDimForConvOps(
-        rewriter, loc, Win, paddingIntValues[1], dilationIntValues[1],
-        castIndexToInt(weightW), strideIntValues[1]);
+    SmallVector<Value> outDims{N, F};
+    for (size_t i = 0; i < inRank - 2; i++)
+      outDims.push_back(torch_to_linalg::getOutputDimForConvOps(
+          rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
+          castIndexToInt(weightDims[i]), strideIntValues[i]));
 
-    Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, ValueRange{N, F, Hout, Wout}, elementType);
+    Value initTensor =
+        rewriter.create<linalg::InitTensorOp>(loc, outDims, elementType);
 
     Value bias = adaptor.bias();
     Value biasInitTensor;
@@ -559,14 +565,17 @@ public:
 
     auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
     auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
-    Value conv2d =
+
+    // TODO: add 1D and 3D case
+    Value conv =
         rewriter
             .create<linalg::Conv2DNchwFchwOp>(
                 loc, biasInitTensor.getType(), ValueRange{paddedInput, weight},
                 biasInitTensor, stridesAttr, dilationAttr)
             .getResult(0);
 
     Type newResultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv2d);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
     return success();
   }
 };
@@ -584,6 +593,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
   patterns.add<ConvertAtenBmmOp>(typeConverter, context);
   target.addIllegalOp<AtenLinearOp>();
   patterns.add<ConvertAtenLinearOp>(typeConverter, context);
-  target.addIllegalOp<AtenConv2dOp>();
-  patterns.add<ConvertAtenConv2dOp>(typeConverter, context);
+  target.addIllegalOp<AtenConvolutionOp>();
+  patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
 }

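For reference, the per-dimension size that torch_to_linalg::getOutputDimForConvOps materializes corresponds to the standard non-transposed convolution output formula; a small Python sketch of it (illustrative only, not part of this patch):

def conv_out_dim(in_dim, padding, dilation, kernel, stride):
    # floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
    return (in_dim + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# ConvolutionModule2DStrided below: size-10 input, 2x2 kernel, stride 3, padding 2.
print(conv_out_dim(10, 2, 1, 2, 3))  # 5
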
@@ -1677,8 +1677,8 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
 }
 
 template <>
-LogicalResult ConvertAtenOp<AtenConv2dOp>::matchAndRewrite(
-    AtenConv2dOp op, OpAdaptor adaptor,
+LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
+    AtenConvolutionOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
 
   auto input = adaptor.input();
@@ -1692,13 +1692,19 @@ LogicalResult ConvertAtenOp<AtenConv2dOp>::matchAndRewrite(
 
   if (!inputTy || !weightTy || !outputTy)
     return op.emitError(
-        "Input, weight and output to Conv2d must be ranked tensors");
+        "Input, weight and output to Convolution must be ranked tensors");
 
   auto inputElemTy = inputTy.getElementType();
   auto weightElemTy = weightTy.getElementType();
+  auto inputShape = inputTy.getShape();
+  auto weightShape = weightTy.getShape();
+
+  if (inputTy.getRank() != 4)
+    return op.emitError("Unimplemented: only 2D convolutions supported");
+
   if (!weightTy.hasStaticShape())
     return op.emitError("Unimplemented: TOSA only supports static weight");
 
   // Bias is optional. TOSA mandates a zero tensor here, so construct one if
   // required.
   auto bias = adaptor.bias();
@@ -3140,7 +3146,7 @@ public:
     INSERT_ATENOP_PATTERN(AtenArgmaxOp);
     INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
     INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
-    INSERT_ATENOP_PATTERN(AtenConv2dOp);
+    INSERT_ATENOP_PATTERN(AtenConvolutionOp);
     INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
     INSERT_ATENOP_PATTERN(AtenReshapeOp);
     INSERT_ATENOP_PATTERN(AtenBatchNormOp);

@@ -737,6 +737,47 @@ public:
 };
 } // namespace
 
+// Decompose aten.convolution_overrideable to aten.convolution
+namespace {
+class DecomposeAtenConvolutionOverrideableOp
+    : public OpRewritePattern<AtenConvolutionOverrideableOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenConvolutionOverrideableOp op,
+                                PatternRewriter &rewriter) const override {
+
+    rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
+        op, op->getResultTypes(), op.input(), op.weight(), op.bias(),
+        op.stride(), op.padding(), op.dilation(), op.transposed(),
+        op.output_padding(), op.groups());
+
+    return success();
+  }
+};
+} // namespace
+
+// Decompose aten.conv2d to aten.convolution
+namespace {
+class DecomposeAtenConv2dOp : public OpRewritePattern<AtenConv2dOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenConv2dOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Value emptyList = rewriter.create<PrimListConstructOp>(
+        op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
+        SmallVector<Value>());
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
+    rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
+        op, op->getResultTypes(), op.input(), op.weight(), op.bias(),
+        op.stride(), op.padding(), op.dilation(), cstFalse, emptyList,
+        op.groups());
+
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.addmm into aten.mm and aten.add.Tensor op.
 namespace {
 class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
@@ -1674,6 +1715,10 @@ class DecomposeComplexOpsPass
     patterns.add<DecomposeAtenLayerNormOp>(context);
     target.addIllegalOp<AtenNativeBatchNormOp>();
     patterns.add<DecomposeAtenNativeBatchNormOp>(context);
+    target.addIllegalOp<AtenConvolutionOverrideableOp>();
+    patterns.add<DecomposeAtenConvolutionOverrideableOp>(context);
+    target.addIllegalOp<AtenConv2dOp>();
+    patterns.add<DecomposeAtenConv2dOp>(context);
     patterns.add<DecomposeAtenArangeOp>(context);
     target.addIllegalOp<AtenArangeOp>();
     patterns.add<DecomposeAtenArangeStartOp>(context);

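The conv2d decomposition above only supplies the operands conv2d leaves implicit (transposed=False and an unused output_padding). A PyTorch-level sketch of the same equivalence (illustrative check, not part of this patch; zeros are passed for output_padding since it is ignored when transposed is false):

import torch

x = torch.randn(2, 3, 8, 8)
w = torch.randn(5, 3, 3, 3)
b = torch.randn(5)

# aten::conv2d(input, weight, bias, stride, padding, dilation, groups) ...
y_conv2d = torch.ops.aten.conv2d(x, w, b, [2, 2], [1, 1], [1, 1], 1)
# ... computes the same result as aten::convolution with transposed=False.
y_convolution = torch.ops.aten.convolution(x, w, b, [2, 2], [1, 1], [1, 1],
                                           False, [0, 0], 1)
assert torch.allclose(y_conv2d, y_convolution)
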
@@ -549,7 +549,8 @@ ChangeResult TypeAnalyzer::visitOperation(
   }
 
   // Promote the two dtypes assuming non-zero rank.
-  if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp>(op)) {
+  if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
+          AtenConvolutionOverrideableOp>(op)) {
    auto knowledge =
         ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(

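The dtype refinement above reuses the binary promotion rule already applied to mm/bmm/matmul, i.e. the result element type is the promotion of the input and weight element types; roughly what torch.promote_types reports (sketch for intuition only, not part of this patch):

import torch

print(torch.promote_types(torch.float16, torch.float32))  # torch.float32
print(torch.promote_types(torch.float32, torch.float64))  # torch.float64
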
@@ -2295,6 +2295,10 @@ module {
     } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool
     return %1 : !torch.bool
   }
+  func @"__torch_mlir_shape_fn.aten.convolution"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int) -> !torch.list<int> {
+    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.conv_output_size(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func @"__torch_mlir_shape_fn.aten.batch_norm"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list<int> {
     return %arg0 : !torch.list<int>
   }

@@ -762,6 +762,9 @@ def aten〇topk(self: List[int], k: int, dim: int = -1, largest: bool = True, so
 def aten〇conv2d(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> List[int]:
     return upstream_shape_helpers.conv2d(input, weight, bias, stride, padding, dilation, groups)
 
+def aten〇convolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> List[int]:
+    return upstream_shape_helpers.conv_output_size(input, weight, bias, stride, padding, dilation, groups)
+
 def aten〇batch_norm(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
     # Torch's symbolic shape analysis is a bit looser about optional
     # arguments than we are, so their batch_norm helper function works

@@ -308,6 +308,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit(
         "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
     )
+    emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
+    emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
     emit(
         "aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
     )

@@ -133,3 +133,130 @@ class Conv2dWithPaddingDilationStrideStaticModule(torch.nn.Module):
 def Conv2dWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
     t = tu.rand(5, 2, 10, 20)
     module.forward(t)
+
+# ==============================================================================
+
+class ConvolutionModule1D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.convolution(inputVec,
+                                          weight,
+                                          bias=None,
+                                          stride=[1],
+                                          padding=[0],
+                                          dilation=[1],
+                                          transposed=False,
+                                          output_padding=[0],
+                                          groups=1)
+
+@register_test_case(module_factory=lambda: ConvolutionModule1D())
+def ConvolutionModule1D_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 3, 10), torch.randn(3, 3, 2))
+
+class ConvolutionModule2D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.convolution(inputVec,
+                                          weight,
+                                          bias=None,
+                                          stride=[1, 1],
+                                          padding=[0, 0],
+                                          dilation=[1, 1],
+                                          transposed=False,
+                                          output_padding=[0, 0],
+                                          groups=1)
+
+@register_test_case(module_factory=lambda: ConvolutionModule2D())
+def ConvolutionModule2D_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
+
+class ConvolutionModule3D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.convolution(inputVec,
+                                          weight,
+                                          bias=None,
+                                          stride=[1, 1, 1],
+                                          padding=[0, 0, 0],
+                                          dilation=[1, 1, 1],
+                                          transposed=False,
+                                          output_padding=[0, 0, 0],
+                                          groups=1)
+
+@register_test_case(module_factory=lambda: ConvolutionModule3D())
+def ConvolutionModule3D_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 3, 10, 10, 10), torch.randn(3, 3, 2, 2, 2))
+
+class ConvolutionModule2DStatic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 3, 10, 10], torch.float32, True),
+        ([3, 3, 2, 2], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.convolution(inputVec,
+                                          weight,
+                                          bias=None,
+                                          stride=[1, 1],
+                                          padding=[0, 0],
+                                          dilation=[1, 1],
+                                          transposed=False,
+                                          output_padding=[0, 0],
+                                          groups=1)
+
+@register_test_case(module_factory=lambda: ConvolutionModule2DStatic())
+def ConvolutionModule2DStatic_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
+
+class ConvolutionModule2DStrided(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.convolution(inputVec,
+                                          weight,
+                                          bias=None,
+                                          stride=[3, 3],
+                                          padding=[2, 2],
+                                          dilation=[1, 1],
+                                          transposed=False,
+                                          output_padding=[0, 0],
+                                          groups=1)
+
+@register_test_case(module_factory=lambda: ConvolutionModule2DStrided())
+def ConvolutionModule2DStrided_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))