mirror of https://github.com/llvm/torch-mlir
OnnxToTorch support for onnx.InstanceNormalization op (#2710)
https://github.com/nod-ai/SHARK-Turbine/issues/327
pull/2924/head
parent 78e10ff09b
commit d29157b33f
@@ -6203,6 +6203,37 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
  }];
}

def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$input,
    AnyTorchOptionalTensorType:$weight,
    AnyTorchOptionalTensorType:$bias,
    AnyTorchOptionalTensorType:$running_mean,
    AnyTorchOptionalTensorType:$running_var,
    Torch_BoolType:$use_input_stats,
    Torch_FloatType:$momentum,
    Torch_FloatType:$eps,
    Torch_BoolType:$cudnn_enabled
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 9, 1);
    }
    void AtenInstanceNormOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 9, 1);
    }
  }];
}

def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
    AllowsTypeRefinement,
    HasValueSemantics,
@@ -392,6 +392,37 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
            binder.op, resultType, lhs, rhs);
        return success();
      });
  patterns.onOp(
      "InstanceNormalization", 6,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        llvm::SmallVector<Value> operands;
        float eps;

        if (binder.tensorOperands(operands, 3) ||
            binder.tensorResultType(resultType) || operands.size() != 3 ||
            binder.f32FloatAttr(eps, "epsilon", 1e-05f)) {
          return failure();
        }
        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        Value boolTrue =
            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
        Value boolFalse =
            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
        auto epsValue = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getF64FloatAttr(eps));

        auto momentum = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getF64FloatAttr(0.0f));
        rewriter.replaceOpWithNewOp<Torch::AtenInstanceNormOp>(
            binder.op, resultType, /* input */ operands[0],
            /* weight */ operands[1],
            /* bias */ operands[2], /* running mean */ none,
            /* running var */ none,
            /* use input stats */ boolTrue, momentum, epsValue,
            /* cudnn enabled */ boolFalse);
        return success();
      });
  patterns.onOp(
      "Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
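In eager PyTorch terms, the replacement emitted above corresponds roughly to the call below. This is an illustrative sketch, not code from this patch; the input shape is hypothetical and simply mirrors the lit test further down.

import torch

# Illustrative NCHW input with per-channel scale/B operands, as for onnx.InstanceNormalization.
x = torch.rand(1, 2, 1, 3)
scale = torch.rand(2)  # ONNX 'scale' -> aten weight
b = torch.rand(2)      # ONNX 'B'     -> aten bias

# The pattern fixes the running stats to None, use_input_stats to True,
# momentum to 0.0, cudnn_enabled to False, and reads eps from the
# 'epsilon' attribute (default 1e-05).
out = torch.ops.aten.instance_norm(
    x, scale, b,
    None, None,   # running_mean, running_var
    True,         # use_input_stats
    0.0,          # momentum
    1e-05,        # eps
    False,        # cudnn_enabled
)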
@@ -8784,6 +8784,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
"    return %3 : !torch.tuple<list<int>, list<int>, list<int>>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.instance_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list<int> {\n"
"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
"    %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
@@ -9643,6 +9647,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    %3 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
"    return %3 : !torch.tuple<int, int, int>\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.instance_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n"
"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
"    return %0#1 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
"    return %0#1 : !torch.int\n"
@@ -3962,6 +3962,151 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
};
} // namespace

namespace {
class DecomposeAtenInstanceNormOp
    : public OpRewritePattern<AtenInstanceNormOp> {
  using OpRewritePattern<AtenInstanceNormOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenInstanceNormOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto context = op.getContext();

    auto inputTy = op.getInput().getType().cast<BaseTensorType>();
    int64_t inputRank = inputTy.getSizes().size();
    auto reduceDimInts =
        llvm::SmallVector<int64_t>({inputRank - 2, inputRank - 1});

    SmallVector<int64_t> reducedShape(inputTy.getSizes());
    reducedShape[inputRank - 1] = 1;
    reducedShape[inputRank - 2] = 1;

    Type dtype = inputTy.getOptionalDtype();
    Type reducedTy = ValueTensorType::get(op.getContext(),
                                          llvm::ArrayRef(reducedShape), dtype);

    auto sizeListType = ListType::get(IntType::get(context));
    SmallVector<Value> reduceDimVals;
    reduceDimVals.reserve(reduceDimInts.size());
    std::transform(reduceDimInts.begin(), reduceDimInts.end(),
                   std::back_inserter(reduceDimVals), [&](int64_t d) {
                     return rewriter.create<Torch::ConstantIntOp>(
                         loc, rewriter.getI64IntegerAttr(d));
                   });
    Value reduceDimList =
        rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
    Value none = rewriter.create<Torch::ConstantNoneOp>(loc);

    Value one = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));

    // mean(x)
    Value inputMean = rewriter.create<AtenMeanDimOp>(
        loc, reducedTy, op.getInput(), reduceDimList, cstTrue, none);

    // x - mean(x)
    Value inputMeanExpanded =
        rewriter.create<AtenExpandAsOp>(loc, inputTy, inputMean, op.getInput());
    Value inputSubMean = rewriter.create<AtenSubTensorOp>(
        loc, inputTy, op.getInput(), inputMeanExpanded, one);
    // (x - mean(x))^2
    Value inputSubMeanSquare = rewriter.create<AtenMulTensorOp>(
        loc, inputTy, inputSubMean, inputSubMean);

    Value variancesum = rewriter.create<AtenSumDimIntListOp>(
        loc, reducedTy, inputSubMeanSquare, reduceDimList, cstTrue,
        /*dtype=*/none);

    Value hw = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(inputTy.getSizes()[inputRank - 1] *
                                        inputTy.getSizes()[inputRank - 2]));
    Value inputVar =
        rewriter.create<AtenDivScalarOp>(loc, reducedTy, variancesum, hw);

    // rsqrt(var(x) + eps)
    Value inputVarPlusEps = rewriter.create<AtenAddScalarOp>(
        loc, reducedTy, inputVar, op.getEps(), one);
    Value inputRsqrtVar =
        rewriter.create<AtenRsqrtOp>(loc, reducedTy, inputVarPlusEps);

    // (x - mean(x)) * rsqrt(var(x) + eps)
    Value inputRsqrtVarExpanded = rewriter.create<AtenExpandAsOp>(
        loc, inputTy, inputRsqrtVar, op.getInput());
    Value inputNormalized = rewriter.create<AtenMulTensorOp>(
        loc, inputTy, inputSubMean, inputRsqrtVarExpanded);
    Value out = rewriter.create<TensorStaticInfoCastOp>(
        loc, op.getResult().getType(), inputNormalized);

    Value weight = op.getWeight();
    auto weightTy = weight.getType().cast<BaseTensorType>();
    dtype = weightTy.getOptionalDtype();

    SmallVector<int64_t> weightShape(weightTy.getSizes());
    SmallVector<int64_t> newWeightShape;
    newWeightShape.push_back(1);
    newWeightShape.append(weightShape);

    Value zero = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(0));
    Type newWeightTy = ValueTensorType::get(
        op.getContext(), llvm::ArrayRef(newWeightShape), dtype);
    weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, zero);

    Value two = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(2));
    newWeightShape.push_back(1);
    newWeightTy = ValueTensorType::get(op.getContext(),
                                       llvm::ArrayRef(newWeightShape), dtype);
    weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, two);

    Value three = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(3));
    newWeightShape.push_back(1);
    newWeightTy = ValueTensorType::get(op.getContext(),
                                       llvm::ArrayRef(newWeightShape), dtype);
    weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, three);

    Value weightExpanded =
        rewriter.create<AtenExpandAsOp>(loc, inputTy, weight, op.getInput());

    Value bias = op.getBias();
    auto biasTy = bias.getType().cast<BaseTensorType>();
    dtype = biasTy.getOptionalDtype();

    SmallVector<int64_t> biasShape(biasTy.getSizes());
    SmallVector<int64_t> newBiasShape;
    newBiasShape.push_back(1);
    newBiasShape.append(biasShape);

    Type newBiasTy = ValueTensorType::get(op.getContext(),
                                          llvm::ArrayRef(newBiasShape), dtype);
    bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, zero);

    newBiasShape.push_back(1);
    newBiasTy = ValueTensorType::get(op.getContext(),
                                     llvm::ArrayRef(newBiasShape), dtype);
    bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, two);

    newBiasShape.push_back(1);
    newBiasTy = ValueTensorType::get(op.getContext(),
                                     llvm::ArrayRef(newBiasShape), dtype);
    bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, three);

    Value biasExpanded =
        rewriter.create<AtenExpandAsOp>(loc, inputTy, bias, op.getInput());

    out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out,
                                           weightExpanded);
    out = rewriter.create<AtenAddTensorOp>(loc, out.getType(), out,
                                           biasExpanded, one);

    rewriter.replaceOp(op, out);
    return success();
  }
};
} // namespace
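Conceptually, the pattern above normalizes each (batch, channel) slice with statistics taken over the two trailing spatial dimensions and then applies the affine weight and bias after unsqueezing them to rank 4. Below is a minimal eager-mode sketch of the same arithmetic, assuming a static 4-D NCHW input as the pattern itself does; the helper name is purely illustrative and not part of this patch.

import torch

def instance_norm_decomposed(x, weight, bias, eps):
    # Reduce over the last two (spatial) dims, keeping them for broadcasting,
    # matching reduceDimInts = {rank-2, rank-1} with keepdim=true above.
    mean = x.mean(dim=(-2, -1), keepdim=True)
    centered = x - mean
    # Biased variance: sum((x - mean)^2) / (H * W), as in the pattern.
    var = centered.mul(centered).sum(dim=(-2, -1), keepdim=True) \
          / (x.shape[-1] * x.shape[-2])
    normalized = centered * torch.rsqrt(var + eps)
    # weight/bias are broadcast as (1, C, 1, 1), mirroring the three unsqueezes.
    w = weight.reshape(1, -1, 1, 1)
    b = bias.reshape(1, -1, 1, 1)
    return normalized * w + b

x, w, b = torch.rand(1, 2, 1, 3), torch.rand(2), torch.rand(2)
ref = torch.ops.aten.instance_norm(x, w, b, None, None, True, 0.0, 1e-05, False)
torch.testing.assert_close(instance_norm_decomposed(x, w, b, 1e-05), ref)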

namespace {
class DecomposeAtenNativeLayerNormOp
    : public OpRewritePattern<AtenNativeLayerNormOp> {
@@ -6733,6 +6878,7 @@ public:
        DecomposeAtenAddCLikeOp<AtenAddcmulOp, AtenMulTensorOp>>(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenInstanceNormOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
@@ -409,6 +409,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  });
  target.addIllegalOp<AtenAddcmulOp>();
  target.addIllegalOp<AtenAddcdivOp>();
  target.addIllegalOp<AtenInstanceNormOp>();
  target.addIllegalOp<AtenLayerNormOp>();
  target.addIllegalOp<AtenNativeLayerNormOp>();
  target.addIllegalOp<AtenGroupNormOp>();
@@ -233,6 +233,7 @@ TORCHDYNAMO_XFAIL_SET = {
    # END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'

    # START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
    "AtenInstanceNormModule_basic",
    "BatchNorm1DModule_basic",
    "BatchNorm1DWith2DInputModule_basic",
    "BatchNorm2DModule_basic",
@@ -898,6 +899,7 @@ TOSA_PASS_SET = {
    "AtenEyeModuleFalsePinMemory_basic",
    "AtenEyeModuleFloat2D_basic",
    "AtenRoundIntModule_basic",
    "AtenInstanceNormModule_basic",
    "AtenToDeviceModule_basic",
    "BaddbmmBroadcast1DInputModule_basic",
    "BaddbmmBroadcast2DInputModule_basic",
@@ -1306,6 +1308,8 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
    "Conv2dNoPaddingModule_basic",
    "Conv2dWithPaddingDilationStrideModule_basic",
    "Conv2dWithPaddingModule_basic",

    "AtenInstanceNormModule_basic",
}

LTC_CRASHING_SET = {
@@ -1415,6 +1415,9 @@ def aten〇group_norm〡shape(input: List[int], num_groups: int, weight: Optiona
def aten〇native_group_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[List[int], List[int], List[int]]:
    return upstream_shape_functions.unary(input), [N, group], [N, group]

def aten〇instance_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
    return upstream_shape_functions.unary(input)

def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
    return upstream_shape_functions.slice(self, dim, start, end, step)
@@ -2048,6 +2051,11 @@ def aten〇native_group_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r
    assert not is_integer_dtype(input_dtype)
    return input_dtype, input_dtype, input_dtype

# device is not supported hence unable to check the dtype function
def aten〇instance_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> int:
    input_rank, input_dtype = input_rank_dtype
    return input_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int:
    self_rank, self_dtype = self_rank_dtype
@@ -437,6 +437,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit(
        "aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
    )
    emit(
        "aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
    )
    emit(
        "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
    )
@@ -489,3 +489,21 @@ class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 2, 3))

class AtenInstanceNormModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1, 2, 1, 3], torch.float32, True),
        ([2], torch.float32, True),
        ([2], torch.float32, True)
    ])
    def forward(self, x, w, b):
        return torch.ops.aten.instance_norm(x, w, b, None,
                None, True, 0.0, 1e-05, False)

@register_test_case(module_factory=lambda: AtenInstanceNormModule())
def AtenInstanceNormModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 2, 1, 3), tu.rand(2), tu.rand(2))
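Outside the e2e harness, the aten call used by the module above can be sanity-checked against the functional API, which has the same semantics for these arguments. A rough eager-mode check, not part of this patch:

import torch
import torch.nn.functional as F

x, w, b = torch.rand(1, 2, 1, 3), torch.rand(2), torch.rand(2)
out = torch.ops.aten.instance_norm(x, w, b, None, None, True, 0.0, 1e-05, False)
ref = F.instance_norm(x, running_mean=None, running_var=None, weight=w, bias=b,
                      use_input_stats=True, momentum=0.0, eps=1e-05)
torch.testing.assert_close(out, ref)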
@@ -603,6 +603,15 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3

// -----

// CHECK-LABEL: func.func @test_instancenorm
func.func @test_instancenorm(%arg0: !torch.vtensor<[1,2,1,3],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.instance_norm %arg0, %arg1, %arg2, %none, %none, %true, %float0.000000e00, %float9.999990e-06, %false : !torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.none, !torch.none, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[1,2,1,3],f32>
  %0 = torch.operator "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32>
  return %0 : !torch.vtensor<[1,2,1,3],f32>
}

// -----

// CHECK-LABEL: func.func @test_not_2d
func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>