mirror of https://github.com/llvm/torch-mlir
[Torch] support binary_cross_entropy_with_logits decomposition (#3741)
parent
f03d32afa1
commit
7f63cb225d
|
@ -9127,6 +9127,33 @@ def Torch_AtenBinaryCrossEntropyBackwardOp : Torch_Op<"aten.binary_cross_entropy
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ODS definition of the `aten.binary_cross_entropy_with_logits` op.
// Operands: `self` and `target` tensors, optional `weight` and `pos_weight`
// tensors, and an integer `reduction`; there is a single `result`.
// Assembly is custom: parse/print forward to parseDefaultTorchOp /
// printDefaultTorchOp, passing 5 and 1, which matches the operand and result
// counts declared below.
// NOTE(review): the summary marks this as a generated op -- presumably it
// should be regenerated rather than edited by hand; confirm against the
// generator.
def Torch_AtenBinaryCrossEntropyWithLogitsOp : Torch_Op<"aten.binary_cross_entropy_with_logits", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$target,
|
||||||
|
AnyTorchOptionalTensorType:$weight,
|
||||||
|
AnyTorchOptionalTensorType:$pos_weight,
|
||||||
|
Torch_IntType:$reduction
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenBinaryCrossEntropyWithLogitsOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 5, 1);
|
||||||
|
}
|
||||||
|
void AtenBinaryCrossEntropyWithLogitsOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 5, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenLogSigmoidForwardOp : Torch_Op<"aten.log_sigmoid_forward", [
|
def Torch_AtenLogSigmoidForwardOp : Torch_Op<"aten.log_sigmoid_forward", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -10215,6 +10215,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int, !torch.int, !torch.float) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int, !torch.int, !torch.float) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.int) -> !torch.list<int> {\n"
|
||||||
|
" %int0 = torch.constant.int 0\n"
|
||||||
|
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||||
|
" %1 = torch.aten.eq.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
|
||||||
|
" %3 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
|
" torch.prim.If.yield %3 : !torch.list<int>\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
|
" return %2 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||||
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
|
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||||
|
@ -14494,6 +14506,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" }\n"
|
" }\n"
|
||||||
" return %4 : !torch.int\n"
|
" return %4 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" return %0#1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
|
|
|
@ -8510,6 +8510,77 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenBinaryCrossEntropyWithLogitsOp
|
||||||
|
: public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
|
||||||
|
using OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenBinaryCrossEntropyWithLogitsOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
auto self = op.getSelf();
|
||||||
|
auto target = op.getTarget();
|
||||||
|
auto posWeight = op.getPosWeight();
|
||||||
|
auto weight = op.getWeight();
|
||||||
|
auto reduction = op.getReduction();
|
||||||
|
|
||||||
|
Value loss;
|
||||||
|
auto one =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
auto _one =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
|
||||||
|
|
||||||
|
auto _target =
|
||||||
|
rewriter.create<AtenMulScalarOp>(loc, target.getType(), target, _one);
|
||||||
|
auto _target_1 = rewriter.create<AtenAddScalarOp>(loc, _target.getType(),
|
||||||
|
_target, one, one);
|
||||||
|
Value mm =
|
||||||
|
rewriter.create<AtenMulTensorOp>(loc, self.getType(), _target_1, self);
|
||||||
|
Value logSigm =
|
||||||
|
rewriter.create<AtenLogSigmoidOp>(loc, self.getType(), self);
|
||||||
|
|
||||||
|
if (!isa<Torch::NoneType>(posWeight.getType())) {
|
||||||
|
auto logWeight = rewriter.create<AtenAddScalarOp>(
|
||||||
|
loc, posWeight.getType(),
|
||||||
|
rewriter.create<AtenSubScalarOp>(loc, posWeight.getType(), posWeight,
|
||||||
|
one, one),
|
||||||
|
one, one);
|
||||||
|
loss = rewriter.create<AtenSubTensorOp>(
|
||||||
|
loc, mm.getType(), mm,
|
||||||
|
rewriter.create<AtenMulTensorOp>(loc, logWeight.getType(), logWeight,
|
||||||
|
logSigm),
|
||||||
|
one);
|
||||||
|
} else {
|
||||||
|
loss =
|
||||||
|
rewriter.create<AtenSubTensorOp>(loc, mm.getType(), mm, logSigm, one);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isa<Torch::NoneType>(weight.getType())) {
|
||||||
|
loss =
|
||||||
|
rewriter.create<AtenMulTensorOp>(loc, loss.getType(), loss, weight);
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply loss reduction.
|
||||||
|
int64_t reductionInt;
|
||||||
|
if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "no reduction type is appointed!");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto none = rewriter.create<ConstantNoneOp>(loc);
|
||||||
|
Value res;
|
||||||
|
if (reductionInt == 1) {
|
||||||
|
res = rewriter.create<AtenMeanOp>(loc, op.getType(), loss, none);
|
||||||
|
} else if (reductionInt == 2) {
|
||||||
|
res = rewriter.create<AtenSumOp>(loc, op.getType(), loss, none);
|
||||||
|
} else {
|
||||||
|
res = loss;
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, res);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
||||||
using OpRewritePattern<AtenOneHotOp>::OpRewritePattern;
|
using OpRewritePattern<AtenOneHotOp>::OpRewritePattern;
|
||||||
|
@ -9643,6 +9714,8 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
|
||||||
|
patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHannWindowPeriodicOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenHannWindowPeriodicOp>(patterns);
|
||||||
|
|
|
@ -1961,6 +1961,14 @@ def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int =
|
||||||
def aten〇cross_entropy_loss〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]:
|
def aten〇cross_entropy_loss〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]:
|
||||||
return upstream_shape_functions.cross_entropy_loss(self, target, weight, reduction, ignore_index, label_smoothing)
|
return upstream_shape_functions.cross_entropy_loss(self, target, weight, reduction, ignore_index, label_smoothing)
|
||||||
|
|
||||||
|
# Shape function for aten::binary_cross_entropy_with_logits.
# When `reduction` is 0 the result has the same shape as `self` (a copy of the
# list is returned); for every other `reduction` value the result is the empty
# shape list. The shapes of `target`, `weight` and `pos_weight` are not
# consulted.
def aten〇binary_cross_entropy_with_logits〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, pos_weight: Optional[List[int]] = None, reduction: int = 1) -> List[int]:
|
||||||
|
scalar_shape: List[int] = []
|
||||||
|
if reduction == 0:
|
||||||
|
result_shape = upstream_shape_functions._copy(self)
|
||||||
|
else:
|
||||||
|
result_shape = scalar_shape
|
||||||
|
return result_shape
|
||||||
|
|
||||||
@check_shape_function([
|
@check_shape_function([
|
||||||
Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
|
Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
|
||||||
])
|
])
|
||||||
|
@ -4909,6 +4917,10 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U
|
||||||
return dtype
|
return dtype
|
||||||
return aten〇std〡dtype(self_rank_dtype)
|
return aten〇std〡dtype(self_rank_dtype)
|
||||||
|
|
||||||
|
# Dtype function for aten::binary_cross_entropy_with_logits.
# The result dtype is the dtype component of `self_rank_dtype`; the rank of
# `self` and the rank/dtype of `target`, `weight` and `pos_weight`, as well as
# `reduction`, are ignored.
def aten〇binary_cross_entropy_with_logits〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]] = None, pos_weight_rank_dtype: Optional[Tuple[int, int]] = None, reduction: int = 1) -> int:
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
return self_dtype
|
||||||
|
|
||||||
@check_dtype_function(
|
@check_dtype_function(
|
||||||
_check_tensors_with_the_same_dtype(
|
_check_tensors_with_the_same_dtype(
|
||||||
tensor_shapes=[(3,3)],
|
tensor_shapes=[(3,3)],
|
||||||
|
|
|
@ -739,6 +739,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit(
|
emit(
|
||||||
"aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)"
|
"aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)"
|
||||||
)
|
)
|
||||||
|
emit(
|
||||||
|
"aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)")
|
emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)")
|
||||||
emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)")
|
emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)")
|
||||||
|
|
|
@ -2294,6 +2294,29 @@ def CrossEntropyLossNoReductionModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(8, 2), tu.randint(8, high=2))
|
module.forward(tu.rand(8, 2), tu.randint(8, high=2))
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryCrossEntropyWithLogitsStaticModule(torch.nn.Module):
    """Module wrapping `aten.binary_cross_entropy_with_logits` with `reduction=0`.

    Both inputs are annotated as static `[8, 2]` float32 tensors; no `weight`
    or `pos_weight` is passed.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([8, 2], torch.float32, True),
            ([8, 2], torch.float32, True),
        ]
    )
    def forward(self, input, target):
        # Only the two tensors and the reduction mode are supplied.
        loss = torch.ops.aten.binary_cross_entropy_with_logits(
            input, target, reduction=0
        )
        return loss
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: BinaryCrossEntropyWithLogitsStaticModule())
def BinaryCrossEntropyWithLogitsStaticModule_basic(module, tu: TestUtils):
    """Run the module once on two `tu.rand(8, 2)` tensors (input, then target)."""
    logits = tu.rand(8, 2)
    labels = tu.rand(8, 2)
    module.forward(logits, labels)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue