mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.mse_loss op
This commit adds decomposition for the `aten.mse_loss` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>rewrite-getitem
parent
2f097d3976
commit
ca87033d2f
|
@ -479,7 +479,8 @@ TOSA_PASS_SET = {
|
||||||
"ToDtypeBoolLayoutNoneStaticModule_basic",
|
"ToDtypeBoolLayoutNoneStaticModule_basic",
|
||||||
"ToCopyBoolDTypeStaticModule_basic",
|
"ToCopyBoolDTypeStaticModule_basic",
|
||||||
"HardTanhIntModule_basic",
|
"HardTanhIntModule_basic",
|
||||||
"AtenRoundIntModule_basic"
|
"AtenRoundIntModule_basic",
|
||||||
|
"MseLossNoReductionModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
LTC_XFAIL_SET = {
|
LTC_XFAIL_SET = {
|
||||||
|
|
|
@ -4693,6 +4693,31 @@ def Torch_AtenFrobeniusNormDimOp : Torch_Op<"aten.frobenius_norm.dim", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenMseLossOp : Torch_Op<"aten.mse_loss", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$target,
|
||||||
|
Torch_IntType:$reduction
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenMseLossOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||||
|
}
|
||||||
|
void AtenMseLossOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 3, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
|
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -2845,6 +2845,57 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenMseLossOp : public OpRewritePattern<AtenMseLossOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenMseLossOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
|
// The `reduction` arg would have only three valid values.
|
||||||
|
// 0 means no reduction.
|
||||||
|
// 1 means mean reduction.
|
||||||
|
// 2 means sum reduction.
|
||||||
|
int64_t reductionType;
|
||||||
|
if (!matchPattern(op.reduction(), m_TorchConstantInt(&reductionType)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Expected a constant integer value for reduction");
|
||||||
|
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
BaseTensorType resultType = op.getType().cast<BaseTensorType>();
|
||||||
|
BaseTensorType inputType = op.self().getType().cast<BaseTensorType>();
|
||||||
|
if (!inputType.hasSizes())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Expected the input tensor to have sizes");
|
||||||
|
BaseTensorType subType =
|
||||||
|
inputType
|
||||||
|
.getWithSizesAndDtype(llvm::makeArrayRef(inputType.getSizes()),
|
||||||
|
resultType.getDtype())
|
||||||
|
.cast<BaseTensorType>();
|
||||||
|
|
||||||
|
Value sub = createTensorSub(rewriter, loc, subType, op.self(), op.target());
|
||||||
|
Value result = rewriter.create<AtenSquareOp>(loc, subType, sub);
|
||||||
|
if (reductionType == torch_upstream::Reduction::None) {
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||||
|
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||||
|
if (reductionType == torch_upstream::Reduction::Mean)
|
||||||
|
result = rewriter.create<AtenMeanDimOp>(loc, resultType, result,
|
||||||
|
/*dim=*/cstNone,
|
||||||
|
/*keepdim=*/cstFalse,
|
||||||
|
/*dtype=*/cstNone);
|
||||||
|
else
|
||||||
|
result = rewriter.create<AtenSumDimIntListOp>(
|
||||||
|
loc, resultType, result, /*dim=*/cstNone, /*keepdim=*/cstFalse,
|
||||||
|
/*dtype=*/cstNone);
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeComplexOpsPass
|
class DecomposeComplexOpsPass
|
||||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||||
|
@ -3040,6 +3091,8 @@ public:
|
||||||
target.addIllegalOp<AtenLiftFreshCopyOp>();
|
target.addIllegalOp<AtenLiftFreshCopyOp>();
|
||||||
patterns.add<DecomposeAtenIndexTensorHackedTwinOp>(context);
|
patterns.add<DecomposeAtenIndexTensorHackedTwinOp>(context);
|
||||||
target.addIllegalOp<AtenIndexTensorHackedTwinOp>();
|
target.addIllegalOp<AtenIndexTensorHackedTwinOp>();
|
||||||
|
patterns.add<DecomposeAtenMseLossOp>(context);
|
||||||
|
target.addIllegalOp<AtenMseLossOp>();
|
||||||
|
|
||||||
for (std::string opName : legalOps) {
|
for (std::string opName : legalOps) {
|
||||||
target.addLegalOp(OperationName(opName, context));
|
target.addLegalOp(OperationName(opName, context));
|
||||||
|
|
|
@ -756,7 +756,8 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
// Promote the two dtypes assuming non-zero rank.
|
// Promote the two dtypes assuming non-zero rank.
|
||||||
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
|
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
|
||||||
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenMvOp,
|
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenMvOp,
|
||||||
AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) {
|
AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp,
|
||||||
|
AtenMseLossOp>(op)) {
|
||||||
auto knowledge =
|
auto knowledge =
|
||||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||||
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
|
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
|
||||||
|
|
|
@ -6743,6 +6743,18 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.mse_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.list<int> {\n"
|
||||||
|
" %int0 = torch.constant.int 0\n"
|
||||||
|
" %0 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" %1 = torch.prim.If %0 -> (!torch.list<int>) {\n"
|
||||||
|
" %2 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
|
" torch.prim.If.yield %2 : !torch.list<int>\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||||
|
" torch.prim.If.yield %2 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
|
" return %1 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
|
||||||
" %true = torch.constant.bool true\n"
|
" %true = torch.constant.bool true\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
|
|
|
@ -1047,6 +1047,11 @@ def aten〇nll_loss_forward(self: List[int], target: List[int], weight: Optional
|
||||||
def aten〇nll_loss_backward(grad_output: List[int], self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int, total_weight: List[int]) -> List[int]:
|
def aten〇nll_loss_backward(grad_output: List[int], self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int, total_weight: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
|
def aten〇mse_loss(self: List[int], target: List[int], reduction: int = 1) -> List[int]:
|
||||||
|
if reduction == 0:
|
||||||
|
return upstream_shape_functions.unary(self)
|
||||||
|
return []
|
||||||
|
|
||||||
@check_shape_function([
|
@check_shape_function([
|
||||||
Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
|
Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
|
||||||
])
|
])
|
||||||
|
|
|
@ -406,6 +406,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
|
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
|
||||||
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
|
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
|
||||||
emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
|
emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
|
||||||
|
emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
|
||||||
|
|
||||||
# Misc tensor ops.
|
# Misc tensor ops.
|
||||||
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
|
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
|
||||||
|
|
|
@ -601,3 +601,61 @@ class ReduceFrobeniusNormKeepDimModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: ReduceFrobeniusNormKeepDimModule())
|
@register_test_case(module_factory=lambda: ReduceFrobeniusNormKeepDimModule())
|
||||||
def ReduceFrobeniusNormKeepDimModule_basic(module, tu: TestUtils):
|
def ReduceFrobeniusNormKeepDimModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.rand(3, 4, 5))
|
module.forward(torch.rand(3, 4, 5))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class MseLossNoReductionModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1 , -1], torch.float32, True),
|
||||||
|
([-1 , -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.ops.aten.mse_loss(x, y, reduction=0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: MseLossNoReductionModule())
|
||||||
|
def MseLossNoReductionModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 4), tu.rand(2, 4))
|
||||||
|
|
||||||
|
|
||||||
|
class MseLossMeanReductionModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1 , -1], torch.float32, True),
|
||||||
|
([-1 , -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.ops.aten.mse_loss(x, y, reduction=1)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: MseLossMeanReductionModule())
|
||||||
|
def MseLossMeanReductionModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 4), tu.rand(2, 4))
|
||||||
|
|
||||||
|
|
||||||
|
class MseLossSumReductionWithDifferentElemTypeModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1 , -1], torch.float32, True),
|
||||||
|
([-1 , -1], torch.float64, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.ops.aten.mse_loss(x, y, reduction=2)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: MseLossSumReductionWithDifferentElemTypeModule())
|
||||||
|
def MseLossSumReductionWithDifferentElemTypeModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 4), tu.rand(2, 4).to(torch.float64))
|
||||||
|
|
|
@ -991,3 +991,59 @@ func.func @torch.aten.roll(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int,
|
||||||
%2 = torch.aten.roll %arg0, %0, %1 : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
|
%2 = torch.aten.roll %arg0, %0, %1 : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
|
||||||
return %2 : !torch.vtensor<[?,?],f32>
|
return %2 : !torch.vtensor<[?,?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.mse_loss$no_reduction(
|
||||||
|
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?],f32>,
|
||||||
|
// CHECK-SAME: %[[TARGET:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
// CHECK: %[[REDUCTION:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
|
||||||
|
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[SELF]], %[[TARGET]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[SUB_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB]], %[[SUB]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: return %[[SUB_SQUARE]] : !torch.vtensor<[?,?],f32>
|
||||||
|
func.func @torch.aten.mse_loss$no_reduction(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%0 = torch.aten.mse_loss %arg0, %arg1, %int0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.mse_loss$mean_reduction(
|
||||||
|
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?],f32>,
|
||||||
|
// CHECK-SAME: %[[TARGET:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
// CHECK: %[[REDUCTION:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
|
||||||
|
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[SELF]], %[[TARGET]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[SUB_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB]], %[[SUB]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[SUB_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_SQUARE]], %[[NONE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[NUMEL:.*]] = torch.aten.numel %[[SUB_SQUARE]] : !torch.vtensor<[?,?],f32> -> !torch.int
|
||||||
|
// CHECK: %[[SUB_SQUARE_MEAN:.*]] = torch.aten.div.Scalar %[[SUB_SQUARE_SUM]], %[[NUMEL]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: return %[[SUB_SQUARE_MEAN]] : !torch.vtensor<[?,?],f32>
|
||||||
|
func.func @torch.aten.mse_loss$mean_reduction(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.aten.mse_loss %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.mse_loss$sum_reduction(
|
||||||
|
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?],f32>,
|
||||||
|
// CHECK-SAME: %[[TARGET:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
// CHECK: %[[REDUCTION:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
|
||||||
|
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[SELF]], %[[TARGET]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[SUB_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB]], %[[SUB]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[SUB_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_SQUARE]], %[[NONE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: return %[[SUB_SQUARE_SUM]] : !torch.vtensor<[?,?],f32>
|
||||||
|
func.func @torch.aten.mse_loss$sum_reduction(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%0 = torch.aten.mse_loss %arg0, %arg1, %int2 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue