add e2e support for torch.log10 (#2479)

pull/2485/head snapshot-20230929.976
saienduri 2023-09-28 10:17:03 -07:00 committed by GitHub
parent 8abfa5b196
commit 4e1dd3bf10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 111 additions and 2 deletions

View File

@ -2527,6 +2527,51 @@ def Torch_AtenLog2_Op : Torch_Op<"aten.log2_", [
}]; }];
} }
def Torch_AtenLog10Op : Torch_Op<"aten.log10", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::log10 : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLog10Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLog10Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenLog10_Op : Torch_Op<"aten.log10_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::log10_ : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLog10_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLog10_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [ def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -235,6 +235,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return createCalculationForMathOpWithDtypeConversion<math::Log2Op>( return createCalculationForMathOpWithDtypeConversion<math::Log2Op>(
b, converter, payloadArgs[0], op); b, converter, payloadArgs[0], op);
} }
if (isa<AtenLog10Op>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::Log10Op>(
b, converter, payloadArgs[0], op);
}
if (isa<AtenLog1pOp>(op)) { if (isa<AtenLog1pOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::Log1pOp>( return createCalculationForMathOpWithDtypeConversion<math::Log1pOp>(
b, converter, payloadArgs[0], op); b, converter, payloadArgs[0], op);
@ -1177,7 +1181,7 @@ public:
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp,
AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp,
AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp, AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp,
@ -1712,7 +1716,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp,
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp,

View File

@ -6322,6 +6322,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.log10\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.log1p\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.log1p\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -8291,6 +8295,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
" return %1 : !torch.int\n" " return %1 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.log10\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"

View File

@ -122,6 +122,9 @@ def atendetach〡shape(self: List[int]) -> List[int]:
def atenlog2〡shape(self: List[int]) -> List[int]: def atenlog2〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self) return upstream_shape_functions.unary(self)
def atenlog10〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def atenlog1p〡shape(self: List[int]) -> List[int]: def atenlog1p〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self) return upstream_shape_functions.unary(self)
@ -1438,6 +1441,11 @@ def atenlog2〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype self_rank, self_dtype = self_rank_dtype
return _get_dtype_of_floating_point_op(self_dtype) return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atenlog10〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atenlog1p〡dtype(self_rank_dtype: Tuple[int, int]) -> int: def atenlog1p〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype self_rank, self_dtype = self_rank_dtype

View File

@ -294,6 +294,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::clamp_max : (Tensor, Scalar) -> (Tensor)", "aten::clamp_max : (Tensor, Scalar) -> (Tensor)",
"aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::log2 : (Tensor) -> (Tensor)", "aten::log2 : (Tensor) -> (Tensor)",
"aten::log10 : (Tensor) -> (Tensor)",
"aten::sqrt : (Tensor) -> (Tensor)", "aten::sqrt : (Tensor) -> (Tensor)",
"aten::log1p : (Tensor) -> (Tensor)", "aten::log1p : (Tensor) -> (Tensor)",
"aten::rsqrt : (Tensor) -> (Tensor)", "aten::rsqrt : (Tensor) -> (Tensor)",

View File

@ -1683,6 +1683,48 @@ def ElementwiseLog2IntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
# ==============================================================================
class ElementwiseLog10Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.log10(a)
@register_test_case(module_factory=lambda: ElementwiseLog10Module())
def ElementwiseLog10Module_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwiseLog10IntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, a):
return torch.log10(a)
@register_test_case(module_factory=lambda: ElementwiseLog10IntModule())
def ElementwiseLog10IntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
# ============================================================================== # ==============================================================================