Register fake_quantize_cachemask ops and add their decompose patterns (#3556)

Test:

`cmake --build build --target check-torch-mlir-all`
pull/3559/head
Ze Zhang 2024-07-23 11:33:12 -07:00 committed by GitHub
parent 21ad890009
commit d1e172f418
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 242 additions and 0 deletions

View File

@ -4650,6 +4650,35 @@ def Torch_AtenFakeQuantizePerTensorAffineTensorQparamsOp : Torch_Op<"aten.fake_q
}];
}
def Torch_Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp : Torch_Op<"aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams : (Tensor, Tensor, Tensor, Tensor, int, int) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$scale,
AnyTorchTensorType:$zero_point,
AnyTorchTensorType:$fake_quant_enabled,
Torch_IntType:$quant_min,
Torch_IntType:$quant_max
);
let results = (outs
AnyTorchOptionalTensorType:$output,
AnyTorchOptionalTensorType:$mask
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 2);
}
void Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 2);
}
}];
}
def Torch_AtenFakeQuantizePerChannelAffineOp : Torch_Op<"aten.fake_quantize_per_channel_affine", [
AllowsTypeRefinement,
HasValueSemantics,
@ -4678,6 +4707,35 @@ def Torch_AtenFakeQuantizePerChannelAffineOp : Torch_Op<"aten.fake_quantize_per_
}];
}
def Torch_AtenFakeQuantizePerChannelAffineCachemaskOp : Torch_Op<"aten.fake_quantize_per_channel_affine_cachemask", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::fake_quantize_per_channel_affine_cachemask : (Tensor, Tensor, Tensor, int, int, int) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$scale,
AnyTorchTensorType:$zero_point,
Torch_IntType:$axis,
Torch_IntType:$quant_min,
Torch_IntType:$quant_max
);
let results = (outs
AnyTorchOptionalTensorType:$output,
AnyTorchOptionalTensorType:$mask
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFakeQuantizePerChannelAffineCachemaskOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 2);
}
void AtenFakeQuantizePerChannelAffineCachemaskOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 2);
}
}];
}
def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -6367,10 +6367,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.int, %arg5: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.fake_quantize_per_channel_affine\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.fake_quantize_per_channel_affine_cachemask\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -10735,6 +10747,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.int, %arg5: !torch.int) -> !torch.tuple<int, int> {\n"
" %int11 = torch.constant.int 11\n"
" %int15 = torch.constant.int 15\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int1 = torch.constant.int 1\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.ne.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" %4 = torch.prim.TupleConstruct %3, %int11 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %4 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_channel_affine\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int) -> !torch.int {\n"
" %int15 = torch.constant.int 15\n"
" %none = torch.constant.none\n"
@ -10756,6 +10793,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_channel_affine_cachemask\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int) -> !torch.tuple<int, int> {\n"
" %int11 = torch.constant.int 11\n"
" %int15 = torch.constant.int 15\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int1 = torch.constant.int 1\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.ne.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" %4 = torch.prim.TupleConstruct %3, %int11 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %4 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.cosh\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"

View File

@ -9029,6 +9029,61 @@ public:
};
} // namespace
namespace {
// Decompose aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams
// into aten.fake_quantize_per_tensor_affine.tensor_qparams
// when the second result is unused.
class DecomposeAten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp
: public OpRewritePattern<
Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp> {
public:
using OpRewritePattern<
Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp>::
OpRewritePattern;
LogicalResult
matchAndRewrite(Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp op,
PatternRewriter &rewriter) const override {
if (!op->getResult(1).use_empty())
return failure();
auto newOp =
rewriter.create<AtenFakeQuantizePerTensorAffineTensorQparamsOp>(
op.getLoc(), op->getResult(0).getType(), op.getSelf(),
op.getScale(), op.getZeroPoint(), op.getQuantMin(),
op.getQuantMax());
rewriter.replaceAllUsesWith(op->getResult(0), newOp);
rewriter.eraseOp(op);
return success();
}
};
} // namespace
namespace {
// Decompose aten.fake_quantize_per_channel_affine_cachemask
// into aten.fake_quantize_per_channel_affine
// when the second result is unused.
class DecomposeAtenFakeQuantizePerChannelAffineCachemaskOp
: public OpRewritePattern<AtenFakeQuantizePerChannelAffineCachemaskOp> {
public:
using OpRewritePattern<
AtenFakeQuantizePerChannelAffineCachemaskOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenFakeQuantizePerChannelAffineCachemaskOp op,
PatternRewriter &rewriter) const override {
if (!op->getResult(1).use_empty())
return failure();
auto newOp = rewriter.create<AtenFakeQuantizePerChannelAffineOp>(
op.getLoc(), op->getResult(0).getType(), op.getSelf(), op.getScale(),
op.getZeroPoint(), op.getAxis(), op.getQuantMin(), op.getQuantMax());
rewriter.replaceAllUsesWith(op->getResult(0), newOp);
rewriter.eraseOp(op);
return success();
}
};
} // namespace
namespace {
// Decompose aten.fmax/fmin to aten.maximum/minimum + aten.where(nanMask)
template <typename AtenFOpT, typename AtenOpT>
@ -9306,6 +9361,11 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgSlogdetOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp>(
patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenFakeQuantizePerChannelAffineCachemaskOp>(patterns);
// More specific conv ops
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTbcOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenConv1dOp>(patterns);

View File

@ -138,9 +138,15 @@ def atenfake_quantize_per_tensor_affine_cachemask〡shape(self: List[int], sc
def atenfake_quantize_per_tensor_affinetensor_qparams〡shape(self: List[int], scale: List[int], zero_point: List[int], quant_min: int, quant_max: int) -> List[int]:
return upstream_shape_functions.unary(self)
def aten_fake_quantize_per_tensor_affine_cachemask_tensor_qparams〡shape(self: List[int], scale: List[int], zero_point: List[int], fake_quant_enabled: List[int], quant_min: int, quant_max: int) -> Tuple[List[int], List[int]]:
return (upstream_shape_functions.unary(self), upstream_shape_functions.unary(self))
def atenfake_quantize_per_channel_affine〡shape(self: List[int], scale: List[int], zero_point: List[int], axis: int, quant_min: int, quant_max: int) -> List[int]:
return upstream_shape_functions.unary(self)
def atenfake_quantize_per_channel_affine_cachemask〡shape(self: List[int], scale: List[int], zero_point: List[int], axis: int, quant_min: int, quant_max: int) -> Tuple[List[int], List[int]]:
return (upstream_shape_functions.unary(self), upstream_shape_functions.unary(self))
def atensin〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
@ -2372,6 +2378,14 @@ def atenfake_quantize_per_tensor_affinetensor_qparams〡dtype(self_rank_dt
assert self_dtype != torch.bfloat16
return self_dtype
# note: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams doesn't support "meta" device, use "cpu" instead.
@check_dtype_function(Invocation(TensorOfShape(3, 3, dtype=dtype, device="cpu"), TensorOfShape(1, dtype=torch.float32, device="cpu"), TensorOfShape(1, dtype=torch.int32, device="cpu"), TensorOfShape(1, dtype=torch.int32, device="cpu"), 0, 255) for dtype in [torch.float64, torch.float32, torch.float16])
def aten_fake_quantize_per_tensor_affine_cachemask_tensor_qparams〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], fake_quant_enabled_rank_dtype: Tuple[int, int], quant_min: int, quant_max: int) -> Tuple[int, int]:
self_rank, self_dtype = self_rank_dtype
assert is_float_dtype(self_dtype)
assert self_dtype != torch.bfloat16
return (self_rank_dtype[1], torch.bool)
# note: fake_quantize_per_channel_affine doesn't support "meta" device, use "cpu" instead.
@check_dtype_function(Invocation(TensorOfShape(3, 3, dtype=dtype, device="cpu"), TensorOfShape(3, dtype=torch.float32, device="cpu"), TensorOfShape(3, dtype=torch.int32, device="cpu"), 0, 0, 255) for dtype in [torch.float64, torch.float32, torch.float16])
def atenfake_quantize_per_channel_affine〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], axis: int, quant_min: int, quant_max: int) -> int:
@ -2380,6 +2394,14 @@ def atenfake_quantize_per_channel_affine〡dtype(self_rank_dtype: Tuple[int,
assert self_dtype != torch.bfloat16
return self_dtype
# note: fake_quantize_per_channel_affine_cachemask doesn't support "meta" device, use "cpu" instead.
@check_dtype_function(Invocation(TensorOfShape(3, 3, dtype=dtype, device="cpu"), TensorOfShape(3, dtype=torch.float32, device="cpu"), TensorOfShape(3, dtype=torch.int32, device="cpu"), 0, 0, 255) for dtype in [torch.float64, torch.float32, torch.float16])
def atenfake_quantize_per_channel_affine_cachemask〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], axis: int, quant_min: int, quant_max: int) -> Tuple[int, int]:
self_rank, self_dtype = self_rank_dtype
assert is_float_dtype(self_dtype)
assert self_dtype != torch.bfloat16
return (self_rank_dtype[1], torch.bool)
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atencosh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype

View File

@ -464,9 +464,15 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit(
"aten::fake_quantize_per_tensor_affine.tensor_qparams : (Tensor, Tensor, Tensor, int, int) -> (Tensor)"
)
emit(
"aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams : (Tensor, Tensor, Tensor, Tensor, int, int) -> (Tensor, Tensor)"
)
emit(
"aten::fake_quantize_per_channel_affine : (Tensor, Tensor, Tensor, int, int, int) -> (Tensor)"
)
emit(
"aten::fake_quantize_per_channel_affine_cachemask : (Tensor, Tensor, Tensor, int, int, int) -> (Tensor, Tensor)"
)
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
emit("aten::fmax : (Tensor, Tensor) -> (Tensor)")

View File

@ -97,3 +97,37 @@ func.func @torch.aten.one_hot$fold(%arg0: !torch.vtensor<[3],si64>, %arg1: !torc
%0 = torch.aten.one_hot %arg0, %arg1 : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,?],si64>
return %0 : !torch.vtensor<[3,?],si64>
}
// -----
// CHECK-LABEL: func.func @torch.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[1],f32>,
// CHECK-SAME: %[[ARG_2:.*]]: !torch.vtensor<[1],si32>, %[[ARG_3:.*]]: !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[CONST1:.*]] = torch.constant.int 127
// CHECK: %[[CONST2:.*]] = torch.constant.int -128
// CHECK: %[[RESULT:.*]] = torch.aten.fake_quantize_per_tensor_affine.tensor_qparams %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[CONST2]], %[[CONST1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],si32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],si32>, %arg3: !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,?,?],f32> {
%int127 = torch.constant.int 127
%int-128 = torch.constant.int -128
%0:2 = torch.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams %arg0, %arg1, %arg2, %arg3, %int-128, %int127 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],i1>
return %0#0 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.fake_quantize_per_channel_affine_cachemask(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG_2:.*]]: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[CONST0:.*]] = torch.constant.int 0
// CHECK: %[[CONST1:.*]] = torch.constant.int 127
// CHECK: %[[CONST2:.*]] = torch.constant.int -128
// CHECK: %[[RESULT:.*]] = torch.aten.fake_quantize_per_channel_affine %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[CONST0]], %[[CONST2]], %[[CONST1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.fake_quantize_per_channel_affine_cachemask(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>, %arg2: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int0 = torch.constant.int 0
%int127 = torch.constant.int 127
%int-128 = torch.constant.int -128
%0:2 = torch.aten.fake_quantize_per_channel_affine_cachemask %arg0, %arg1, %arg2, %int0, %int-128, %int127 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],i1>
return %0#0 : !torch.vtensor<[?,?,?,?],f32>
}