mirror of https://github.com/llvm/torch-mlir
Add aten.isclose support and its torch-to-tosa lowering (#2512)
Add aten.isclose op Add its torch-to-tosa lowering Update the TorchToTosa/basic.mlir tests To test e2e tosa lowering: `python -m e2e_testing.main -v -c=tosa` --------- Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>pull/2511/head
parent
e649e06b7b
commit
f2c53b8ca5
|
@ -18,6 +18,8 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
|||
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
"UnflattenStaticModule_basic",
|
||||
"IscloseStaticModule_basic",
|
||||
"IscloseStaticModuleTrue_basic",
|
||||
}
|
||||
|
||||
TORCHDYNAMO_XFAIL_SET = {
|
||||
|
@ -928,6 +930,8 @@ STABLEHLO_CRASHING_SET = {
|
|||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
"IscloseStaticModule_basic",
|
||||
"IscloseStaticModuleTrue_basic",
|
||||
"TileBigDimsSizeModule_basic",
|
||||
"TileSmallDimsSizeModule_basic",
|
||||
"IndexPutImpl2DNoneIndexStaticModule_basic",
|
||||
|
|
|
@ -4162,6 +4162,33 @@ def Torch_AtenViewAsRealOp : Torch_Op<"aten.view_as_real", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenIscloseOp : Torch_Op<"aten.isclose", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$other,
|
||||
Torch_FloatType:$rtol,
|
||||
Torch_FloatType:$atol,
|
||||
Torch_BoolType:$equal_nan
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenIscloseOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 5, 1);
|
||||
}
|
||||
void AtenIscloseOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 5, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -3920,6 +3920,59 @@ LogicalResult ConvertAtenOp<AtenLeTensorOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
|
||||
AtenIscloseOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// check args
|
||||
double rtol, atol;
|
||||
bool equalNan;
|
||||
if (!matchPattern(op.getRtol(), m_TorchConstantFloat(&rtol)))
|
||||
return rewriter.notifyMatchFailure(op, "rtol must be a scalar constant");
|
||||
if (!matchPattern(op.getAtol(), m_TorchConstantFloat(&atol)))
|
||||
return rewriter.notifyMatchFailure(op, "atol must be a scalar constant");
|
||||
if (!matchPattern(op.getEqualNan(), m_TorchConstantBool(&equalNan)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: equal_nan is expected to be false");
|
||||
|
||||
// check tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
auto otherType = adaptor.getOther().getType().dyn_cast<TensorType>();
|
||||
if (!selfType || !otherType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types input are currently supported");
|
||||
if (!selfType.hasStaticShape() || !otherType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types with static shape are supported");
|
||||
if (!selfType.getElementType().isa<mlir::FloatType>() ||
|
||||
!otherType.getElementType().isa<mlir::FloatType>()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only FP element type is supported");
|
||||
}
|
||||
|
||||
auto rhsSubOp = rewriter.create<tosa::SubOp>(
|
||||
op->getLoc(), selfType, adaptor.getSelf(), adaptor.getOther());
|
||||
auto rhsAbsOp =
|
||||
rewriter.create<tosa::AbsOp>(op->getLoc(), selfType, rhsSubOp);
|
||||
|
||||
auto lhsAbsOp =
|
||||
rewriter.create<tosa::AbsOp>(op->getLoc(), otherType, adaptor.getOther());
|
||||
auto rtolConstOp =
|
||||
tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(rtol));
|
||||
auto mulOp = rewriter.create<tosa::MulOp>(op->getLoc(), otherType,
|
||||
rtolConstOp, lhsAbsOp, /*shift=*/0);
|
||||
auto atolConstOp =
|
||||
tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(atol));
|
||||
auto addOp =
|
||||
rewriter.create<tosa::AddOp>(op->getLoc(), otherType, atolConstOp, mulOp);
|
||||
|
||||
auto outType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<tosa::GreaterEqualOp>(op, outType, addOp,
|
||||
rhsAbsOp);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
||||
AtenClampOp op, OpAdaptor adaptor,
|
||||
|
@ -5134,6 +5187,7 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
|
||||
INSERT_ATENOP_PATTERN(AtenCatOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSqrtOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIscloseOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||
|
|
|
@ -7480,6 +7480,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.isclose\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.unsqueeze\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unsqueeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -9093,6 +9097,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.isclose\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool) -> !torch.int {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
|
|
@ -844,6 +844,9 @@ def aten〇lt〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
|
|||
def aten〇le〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.broadcast(self, other)
|
||||
|
||||
def aten〇isclose〡shape(self: List[int], other: List[int], rtol: float = 1.0000000000000001e-05, atol: float = 1e-08, equal_nan: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.broadcast(self, other)
|
||||
|
||||
def aten〇unsqueeze〡shape(self: List[int], dim: int) -> List[int]:
|
||||
return upstream_shape_functions.unsqueeze(self, dim)
|
||||
|
||||
|
@ -2171,6 +2174,10 @@ def aten〇logical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
|
|||
def aten〇logical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
return torch.bool
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2))
|
||||
def aten〇isclose〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], rtol: float = 1.0000000000000001e-05, atol: float = 1e-08, equal_nan: bool = False) -> int:
|
||||
return torch.bool
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4, 32, 16), (3, 4, 32, 16), (3, 4, 32, 16)]))
|
||||
def aten〇scaled_dot_product_attention〡dtype(query_rank_dtype: Tuple[int, int], key_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int], attn_mask_rank_dtype: Optional[Tuple[int, int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> int:
|
||||
_, query_dtype = query_rank_dtype
|
||||
|
|
|
@ -342,6 +342,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::imag : (Tensor) -> (Tensor)")
|
||||
emit("aten::view_as_complex : (Tensor) -> (Tensor)")
|
||||
emit("aten::view_as_real : (Tensor) -> (Tensor)")
|
||||
emit("aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)")
|
||||
|
||||
# Ops with dynamic number of outputs
|
||||
emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])")
|
||||
|
|
|
@ -4580,3 +4580,48 @@ class Add_Module(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: Add_Module())
|
||||
def Add_Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class IscloseStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([5, 5], torch.float32, True),
|
||||
([5, 5], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.isclose(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IscloseStaticModule())
|
||||
def IscloseStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 5), tu.rand(5, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class IscloseStaticModuleTrue(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_buffer('tensor', torch.ones(1))
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([5, 5], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.isclose(x, self.tensor)
|
||||
|
||||
@register_test_case(module_factory=lambda: IscloseStaticModuleTrue())
|
||||
def IscloseStaticModuleTrue_basic(module, tu: TestUtils):
|
||||
module.forward(torch.ones(5, 5))
|
||||
|
|
|
@ -1155,3 +1155,32 @@ func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !to
|
|||
%0 = torch.aten.remainder.Scalar %arg0, %int2 : !torch.vtensor<[2, 4],f32>, !torch.int -> !torch.vtensor<[2, 4],f32>
|
||||
return %0 : !torch.vtensor<[2, 4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @forward(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[5,5],f32>,
|
||||
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> {
|
||||
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
|
||||
// CHECK: %[[ATOL:.*]] = torch.constant.float 1.000000e-08
|
||||
// CHECK: %[[RTOL:.*]] = torch.constant.float 1.000000e-05
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[VAL_2:.*]] = tosa.sub %[[VAL_0]], %[[VAL_1]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = tosa.abs %[[VAL_2]] : (tensor<5x5xf32>) -> tensor<5x5xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.abs %[[VAL_1]] : (tensor<5x5xf32>) -> tensor<5x5xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_4]] {shift = 0 : i32} : (tensor<f32>, tensor<5x5xf32>) -> tensor<5x5xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<9.99999993E-9> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_6]] : (tensor<f32>, tensor<5x5xf32>) -> tensor<5x5xf32>
|
||||
// CHECK: %[[VAL_9:.*]] = tosa.greater_equal %[[VAL_8]], %[[VAL_3]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1>
|
||||
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1>
|
||||
// CHECK: return %[[VAL_10]] : !torch.vtensor<[5,5],i1>
|
||||
// CHECK: }
|
||||
func.func @forward(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> {
|
||||
%float1.000000e-08 = torch.constant.float 1.000000e-08
|
||||
%float1.000000e-05 = torch.constant.float 1.000000e-05
|
||||
%false = torch.constant.bool false
|
||||
%0 = torch.aten.isclose %arg0, %arg1, %float1.000000e-05, %float1.000000e-08, %false : !torch.vtensor<[5,5],f32>, !torch.vtensor<[5,5],f32>, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[5,5],i1>
|
||||
return %0 : !torch.vtensor<[5,5],i1>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue