mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] add support for adaptive_avgpool_1d (#2342)
This commit adds support for `aten::adaptive_avg_pool1d`: the op is emitted in the Torch dialect, given shape and dtype transfer functions, decomposed into `aten::avg_pool1d`, and covered by new e2e tests, with the STABLEHLO, MAKE_FX_TOSA, and LTC test sets updated accordingly.

Rebased commits included in this squash:

* [MLIR][TORCH] Fix aten.cumsum lowering for int32 input (#2351)
* [Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op and configure crashing e2e sets for the stablehlo backend (#2340)
* [Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op (#2358)
* [Torch Dialect] emit aten.tile op and decompose it into aten.repeat (#2355)
* LTC->MLIR Debug Info support (#1922): propagate the Lazy->JIT->MLIR scope name and attach op names to location information (default format {scope}/{op_name}), so the type of computation can be identified from an execution profile
* build: update llvm tag to 41895843 (llvm: 41895843b5915bb78e9d02aa711fa10f7174db43, mhlo: 4726d31f7025da66de0dea709bd56c462edb83c2)
* update PyTorch version to 2.1.0.dev20230729 (#2354) (torch commit b638df0afb83572724032c824c64e481bb4499a0, torchvision 0.16.0.dev20230729)
* update PyTorch version to 2.1.0.dev20230730 (#2356) (torch commit 0ff243ff350268cc98fe03fa6364375ee2824742, torchvision 0.16.0.dev20230730)
* update PyTorch version to 2.1.0.dev20230731 (#2359) (torch commit 6298ac688f8caafe30d71ff2ea2e20fbb32065c7, torchvision 0.16.0.dev20230731)
* update PyTorch version to 2.1.0.dev20230802 (#2366) (torch commit c89b16917755c2abbef7b6420e340baf9ae8089e, torchvision 0.16.0.dev20230802)
* update PyTorch version to 2.1.0.dev20230803 (#2372) (torch commit f89c73be3a3e8274d025ac46a33a780853841c9e, torchvision 0.16.0.dev20230803)
* Change Python version from 3.10 to 3.11 in installation instructions (#2370)
* Add CITATION file (#2371)
* Add packaging as an install dependency (#2369); needed by `torch_mlir._version`, resolves #2368
* Prevent failed stable CI job from cancelling nightly jobs (#2373): CI jobs that use stable PyTorch are not required to pass for a patch to merge into `main`, so a failing stable job should not cancel the other required jobs

Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>
Co-authored-by: Mark Browning <mark@cerebras.net>
Co-authored-by: Vimal Patel <vimal@polymagelabs.com>
parent 2fbb4c79f0
commit 38b049eb1a
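For reference, `aten::adaptive_avg_pool1d` splits the last dimension of the input into `output_size` roughly equal windows and averages each one. A minimal PyTorch sketch of the two shapes exercised by the new e2e tests (illustrative only, not part of the diff):

import torch

x = torch.rand(1, 512, 7)
y = torch.nn.AdaptiveAvgPool1d(7)(x)  # non-unit output size: shape (1, 512, 7)
z = torch.nn.AdaptiveAvgPool1d(1)(x)  # unit output size: shape (1, 512, 1)
assert y.shape == (1, 512, 7) and z.shape == (1, 512, 1)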
@@ -379,6 +379,7 @@ STABLEHLO_PASS_SET = {
    "ConstantBoolParameterModule_basic",
    "MaskedFillScalarIntValueStaticModule_basic",
    "MaskedFillScalarFloatValueStaticModule_basic",
    "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
    "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
    "AddSizeIntModule_basic",
    "AddSizeIntNegDimModule_basic",
@@ -781,6 +782,7 @@ STABLEHLO_PASS_SET = {
    "ReshapeExpandModule_basic",
    "RollModule_basic",
    "TestMultipleTensorReturn_basic",
    "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
    "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
    "BaddbmmStaticModule_basic",
    "BaddbmmBroadcast1DInputModule_basic",
@@ -1197,6 +1199,8 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
    "SliceWholeTensorModule_basic",
    "TensorFloatModule_basic",
    "TensorIntModule_basic",
    "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
    "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
}) - {
    ### Test failing in make_fx_tosa but not in tosa
@@ -1239,6 +1243,8 @@ LTC_XFAIL_SET = {
    "_ConvolutionDeprecated2DBenchmarkModule_basic",
    "_ConvolutionDeprecated2DCudnnModule_basic",
    "_ConvolutionDeprecated2DDeterministicModule_basic",
    "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
    "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
    "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
    "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
    "AddIntModule_basic",
@@ -5323,6 +5323,30 @@ def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
  }];
}

def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchListOfTorchIntType:$output_size
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenAdaptiveAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenAdaptiveAvgPool1dOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
}

def Torch_AtenTopkOp : Torch_Op<"aten.topk", [
    AllowsTypeRefinement,
    HasValueSemantics,
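The generated op takes two operands (`$self`, `$output_size`) and produces one result, which is what the `parseDefaultTorchOp(parser, result, 2, 1)` and `printDefaultTorchOp(printer, *this, 2, 1)` calls encode. A minimal sketch of observing the op end to end; the `torch_mlir.compile` entry point and its `output_type="torch"` option are assumed to be available as in releases from this period (illustrative, not part of the diff):

import torch
import torch_mlir

class Pool(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        return self.pool(x)

# Stop at the Torch dialect; depending on which decompositions run in the
# pipeline, the printed IR shows either aten.adaptive_avg_pool1d or the
# aten.avg_pool1d it is decomposed into.
module = torch_mlir.compile(Pool(), torch.rand(1, 512, 7), output_type="torch")
print(module)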
@@ -6952,6 +6952,67 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %23 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @__torch__.adaptive_avg_pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int1 = torch.constant.int 1\n"
" %int2 = torch.constant.int 2\n"
" %int3 = torch.constant.int 3\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %3 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %12 = torch.aten.eq.int %11, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %12 : !torch.bool\n"
" }\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %5, %true, init() {\n"
" ^bb0(%arg2: !torch.int):\n"
" %11 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.ne.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %12 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %6 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %7 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %8 = torch.aten.sub.int %7, %int1 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.Loop %8, %true, init() {\n"
" ^bb0(%arg2: !torch.int):\n"
" %11 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.append.t %6, %11 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %9 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %10 = torch.aten.append.t %6, %9 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" return %6 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.avg_pool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.avg_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
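This C++ string table is not edited by hand: it is the serialized form of the Python shape function added further down in this diff (aten〇adaptive_avg_pool1d〡shape and its adaptive_avg_pool1d helper), presumably regenerated with the repository's build_tools/update_abstract_interp_lib.sh script.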
@@ -8266,6 +8327,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
@@ -3326,6 +3326,85 @@ public:
};
} // namespace

namespace {
// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.
//
// The logic of this decomposition is the same as in
// DecomposeAtenAdaptiveAvgPool2dOp; currently only the following two
// cases are supported:
//  1. inputSize = outputSize
//  2. outputSize = 1
class DecomposeAtenAdaptiveAvgPool1dOp
    : public OpRewritePattern<AtenAdaptiveAvgPool1dOp> {
  using OpRewritePattern<AtenAdaptiveAvgPool1dOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenAdaptiveAvgPool1dOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    MLIRContext *context = op.getContext();

    Value input = op.getSelf();
    std::optional<unsigned> maybeRank = getTensorRank(input);
    if (!maybeRank) {
      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
    }
    unsigned rank = *maybeRank;
    Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(rank - 1));
    Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);

    Value outputShape = op.getOutputSize();
    SmallVector<Value> outputShapeSizesTorchInt;
    getListConstructElements(outputShape, outputShapeSizesTorchInt);
    Value outputSize = outputShapeSizesTorchInt[0];

    Value constantOne = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));
    Value constantZero = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(0));
    Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
    Value constantTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);

    int64_t outputSizeInt;
    if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
      return rewriter.notifyMatchFailure(
          op, "the output size of adaptive_pool_1d must be a constant int");
    }

    SmallVector<Value, 1> kernelSize;
    if (outputSizeInt == 1) {
      BaseTensorType inputTensorType = input.getType().cast<BaseTensorType>();
      ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
      kernelSize.push_back(
          inputShape[rank - 1] == kUnknownSize
              ? inputSize
              : rewriter.create<Torch::ConstantIntOp>(
                    loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
    } else {
      Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
      rewriter.create<RuntimeAssertOp>(
          loc, cond,
          "unimplemented: only support cases where input and output size are "
          "equal for non-unit output size");
      kernelSize.push_back(constantOne);
    }

    Value kernelSizeList = rewriter.create<PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
    Value strideList = rewriter.create<PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(context)),
        ValueRange{constantOne});
    Value paddingSizeList = rewriter.create<PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(context)),
        ValueRange{constantZero});

    rewriter.replaceOpWithNewOp<AtenAvgPool1dOp>(
        op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
        /*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue);
    return success();
  }
};
} // namespace

namespace {
// Decompose `aten.adaptiveAvgPool2d` op into `aten.avgPool2d` op.
//
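The decomposition above only fires when the output size is a constant 1 (the avg_pool1d kernel covers the whole last dimension) or is asserted equal to the input size (a kernel and stride of 1). A quick numerical check of those two cases against the PyTorch reference ops (illustrative only, not part of the diff):

import torch
import torch.nn.functional as F

x = torch.rand(2, 4, 20)

# outputSize == 1: average over the whole last dimension.
assert torch.allclose(F.adaptive_avg_pool1d(x, 1), F.avg_pool1d(x, kernel_size=20))

# inputSize == outputSize: a kernel and stride of 1 leave the input unchanged.
assert torch.allclose(F.adaptive_avg_pool1d(x, 20), F.avg_pool1d(x, kernel_size=1, stride=1))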
@@ -4800,6 +4879,7 @@ public:
    addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenClampMaxOp>(patterns);
@@ -446,6 +446,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  target.addIllegalOp<AtenPadOp>();
  target.addIllegalOp<AtenToDtypeLayoutOp>();
  target.addIllegalOp<AtenToDeviceOp>();
  target.addIllegalOp<AtenAdaptiveAvgPool1dOp>();
  target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
  target.addIllegalOp<AtenClampMinOp>();
  target.addIllegalOp<AtenClampMaxOp>();
@@ -565,9 +565,28 @@ def avg_pool1d(input: List[int], kernel_size: List[int], stride: List[int], padd
    else:
        return [nbatch, nInputPlane, outputLength]

# TODO: This should be upstreamed.
# See https://github.com/pytorch/pytorch/pull/76889 for an example.
def adaptive_avg_pool1d(self: List[int], out: List[int]):
    assert len(out) == 1
    assert len(self) == 2 or len(self) == 3

    for i in range(len(self)):
        assert self[i] != 0

    shape: List[int] = []
    for i in range(len(self) - 1):
        shape.append(self[i])
    shape.append(out[0])

    return shape

def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]:
    return avg_pool1d(self, kernel_size, stride, padding, ceil_mode, count_include_pad)

def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]:
    return adaptive_avg_pool1d(self, output_size)

def aten〇avg_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> List[int]:
    return avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
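Assuming the `adaptive_avg_pool1d` helper above is in scope, the shape transfer keeps every dimension except the last, which becomes `output_size[0]` (illustrative calls):

print(adaptive_avg_pool1d([2, 4, 20], [7]))   # [2, 4, 7]
print(adaptive_avg_pool1d([512, 7], [1]))     # [512, 1]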
@@ -1407,6 +1426,11 @@ def aten〇abs〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
        return torch.float32
    return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], output_size=[2]))
def aten〇adaptive_avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2]))
def aten〇avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> int:
    self_rank, self_dtype = self_rank_dtype
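The dtype transfer function simply propagates the input dtype regardless of output_size. Assuming the definition above is in scope (illustrative call with a rank-3 float32 input):

import torch
print(aten〇adaptive_avg_pool1d〡dtype((3, torch.float32), [7]))   # torch.float32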
@@ -418,6 +418,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
    emit_with_mutating_variants("aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)")
    emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
    emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)")
    emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
    emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
    emit("aten::permute : (Tensor, int[]) -> (Tensor)")
@@ -771,4 +771,88 @@ class AvgPool1dStaticModule(torch.nn.Module):

@register_test_case(module_factory=lambda: AvgPool1dStaticModule())
def AvgPool1dStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(2, 4, 20, high=100))
    module.forward(tu.randint(2, 4, 20, high=100))


# ==============================================================================


class AdaptiveAvgPool1dNonUnitOutputSizeStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.aap1d = torch.nn.AdaptiveAvgPool1d(7)

    @export
    @annotate_args([
        None,
        ([1, 512, 7], torch.float32, True),
    ])
    def forward(self, x):
        return self.aap1d(x)


@register_test_case(
    module_factory=lambda: AdaptiveAvgPool1dNonUnitOutputSizeStaticModule())
def AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic(
        module, tu: TestUtils):
    module.forward(tu.rand(1, 512, 7))


class AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.aap1d = torch.nn.AdaptiveAvgPool1d(7)

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return self.aap1d(x)


@register_test_case(
    module_factory=lambda: AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule())
def AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic(
        module, tu: TestUtils):
    module.forward(tu.rand(1, 512, 7))


class AdaptiveAvgPool1dUnitOutputSizeStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.aap1d = torch.nn.AdaptiveAvgPool1d(1)

    @export
    @annotate_args([
        None,
        ([1, 512, 7], torch.float32, True),
    ])
    def forward(self, x):
        return self.aap1d(x)


@register_test_case(
    module_factory=lambda: AdaptiveAvgPool1dUnitOutputSizeStaticModule())
def AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic(
        module, tu: TestUtils):
    module.forward(tu.rand(1, 512, 7))


class AdaptiveAvgPool1dUnitOutputSizeDynamicModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.aap1d = torch.nn.AdaptiveAvgPool1d(1)

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return self.aap1d(x)


@register_test_case(
    module_factory=lambda: AdaptiveAvgPool1dUnitOutputSizeDynamicModule())
def AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic(
        module, tu: TestUtils):
    module.forward(tu.rand(1, 512, 7))
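For a quick eager-mode sanity check of the new test modules outside the e2e harness, assuming the classes above are in scope (illustrative only):

import torch

m = AdaptiveAvgPool1dUnitOutputSizeStaticModule()
print(m.forward(torch.rand(1, 512, 7)).shape)   # torch.Size([1, 512, 1])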