mirror of https://github.com/llvm/torch-mlir
[Torch] fold aten.log (#3223)
parent
122eb69a98
commit
634a796933
|
@ -256,51 +256,6 @@ def Torch_AtenLeakyRelu_Op : Torch_Op<"aten.leaky_relu_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenLogOp : Torch_Op<"aten.log", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::log : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenLogOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenLogOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenLog_Op : Torch_Op<"aten.log_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::log_ : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
Torch_NonValueTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalNonValueTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenLog_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenLog_Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSeluOp : Torch_Op<"aten.selu", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -4085,6 +4040,52 @@ def Torch_AtenNe_ScalarOp : Torch_Op<"aten.ne_.Scalar", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenLogOp : Torch_Op<"aten.log", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::log : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenLogOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenLogOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenLog_Op : Torch_Op<"aten.log_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::log_ : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
Torch_NonValueTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalNonValueTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenLog_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenLog_Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenFloorOp : Torch_Op<"aten.floor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -1241,16 +1241,20 @@ llvm::SmallVector<APInt> getFoldValueAtIndexInt(llvm::ArrayRef<Attribute> attrs,
|
|||
llvm::SmallVector<APInt> splattrs;
|
||||
|
||||
for (auto attr : attrs) {
|
||||
bool isunsigned = false;
|
||||
// Note that i1 is neither signed nor unsigned.
|
||||
// But we should trait i1 as unsigned, otherwise that
|
||||
// APInt(1,1).getSExtValue() return allOnes 64-bit integer.
|
||||
// So here only distinguish signed integer.
|
||||
bool isSigned = false;
|
||||
if (auto dense = dyn_cast<ElementsAttr>(attr)) {
|
||||
isunsigned = dyn_cast<IntegerType>(dense.getElementType()).isUnsigned();
|
||||
isSigned = dyn_cast<IntegerType>(dense.getElementType()).isSigned();
|
||||
if (dense.isSplat()) {
|
||||
splattrs.push_back(dense.getSplatValue<APInt>());
|
||||
} else {
|
||||
splattrs.push_back(dense.getValues<APInt>()[idx]);
|
||||
}
|
||||
} else if (auto intattr = dyn_cast<IntegerAttr>(attr)) {
|
||||
isunsigned = cast<IntegerType>(intattr.getType()).isUnsigned();
|
||||
isSigned = cast<IntegerType>(intattr.getType()).isSigned();
|
||||
splattrs.push_back(intattr.getValue());
|
||||
} else {
|
||||
return {};
|
||||
|
@ -1258,10 +1262,10 @@ llvm::SmallVector<APInt> getFoldValueAtIndexInt(llvm::ArrayRef<Attribute> attrs,
|
|||
|
||||
auto &apint = splattrs.back();
|
||||
if (apint.getBitWidth() < bitwidth) {
|
||||
if (isunsigned) {
|
||||
apint = apint.zextOrTrunc(bitwidth);
|
||||
} else {
|
||||
if (isSigned) {
|
||||
apint = apint.sextOrTrunc(bitwidth);
|
||||
} else {
|
||||
apint = apint.zextOrTrunc(bitwidth);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1795,6 +1799,81 @@ OpFoldResult AtenNeScalarOp::fold(FoldAdaptor adaptor) {
|
|||
return comparisonScaleFolder(self, other, resultTy, fpFold, intFold);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenLogOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
using UnaryPromoteFpOperator = std::function<double(double)>;
|
||||
using UnaryPromoteIntOperator = std::function<double(APInt, bool)>;
|
||||
|
||||
static OpFoldResult unaryPromoteFolder(DenseElementsAttr operand,
|
||||
ValueTensorType resultTy,
|
||||
UnaryPromoteFpOperator fpFolder,
|
||||
UnaryPromoteIntOperator intFolder) {
|
||||
constexpr int64_t kMaxFold = 16;
|
||||
if (!resultTy.hasDtype() || !resultTy.hasSizes())
|
||||
return nullptr;
|
||||
if (!isa<mlir::FloatType>(resultTy.getDtype()))
|
||||
return nullptr;
|
||||
|
||||
auto fpTy = dyn_cast<mlir::FloatType>(operand.getType().getElementType());
|
||||
auto intTy = dyn_cast<mlir::IntegerType>(operand.getType().getElementType());
|
||||
if (!fpTy && !intTy)
|
||||
return nullptr;
|
||||
|
||||
auto resultBTy = resultTy.toBuiltinTensor().clone(resultTy.getDtype());
|
||||
bool splat = operand.isSplat();
|
||||
bool withinMaxFold =
|
||||
resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold;
|
||||
if (!splat && !withinMaxFold)
|
||||
return nullptr;
|
||||
|
||||
const int64_t numValues = splat ? 1 : resultBTy.getNumElements();
|
||||
|
||||
llvm::SmallVector<Attribute> operands = {operand};
|
||||
llvm::SmallVector<APFloat> folded;
|
||||
for (int i = 0, s = numValues; i < s; ++i) {
|
||||
double fold = 0.0;
|
||||
if (fpTy) {
|
||||
auto inputs = getFoldValueAtIndexFp(operands, i);
|
||||
fold = fpFolder(inputs[0]);
|
||||
}
|
||||
if (intTy) {
|
||||
auto inputs =
|
||||
getFoldValueAtIndexInt(operands, intTy.getIntOrFloatBitWidth(), i);
|
||||
fold = intFolder(inputs[0], intTy.isSigned());
|
||||
}
|
||||
|
||||
APFloat val(fold);
|
||||
bool unused;
|
||||
val.convert(
|
||||
cast<mlir::FloatType>(resultBTy.getElementType()).getFloatSemantics(),
|
||||
APFloat::rmNearestTiesToEven, &unused);
|
||||
folded.push_back(val);
|
||||
}
|
||||
return DenseElementsAttr::get(resultBTy, folded);
|
||||
}
|
||||
|
||||
OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) {
|
||||
auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
|
||||
auto resultType = dyn_cast<ValueTensorType>(getType());
|
||||
if (!self || !resultType)
|
||||
return nullptr;
|
||||
|
||||
// Note that i1 is neither signed nor unsigned.
|
||||
// But we should trait i1 as unsigned, otherwise that
|
||||
// APInt(1,1).getSExtValue() return allOnes 64-bit integer.
|
||||
auto intFold = [](APInt a, bool isSigned) -> double {
|
||||
if (isSigned)
|
||||
return std::log(a.getSExtValue());
|
||||
else
|
||||
return std::log(a.getZExtValue());
|
||||
};
|
||||
auto fpFold = [](double a) -> double { return std::log(a); };
|
||||
|
||||
return unaryPromoteFolder(self, resultType, fpFold, intFold);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenFloorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -268,7 +268,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::relu : (Tensor) -> (Tensor)",
|
||||
"aten::relu6 : (Tensor) -> (Tensor)",
|
||||
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
|
||||
"aten::log : (Tensor) -> (Tensor)",
|
||||
"aten::selu : (Tensor) -> (Tensor)",
|
||||
"aten::sigmoid : (Tensor) -> (Tensor)",
|
||||
"aten::sinh : (Tensor) -> (Tensor)",
|
||||
|
@ -356,6 +355,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit_with_mutating_variants("aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::log : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
|
||||
|
|
|
@ -2888,3 +2888,36 @@ func.func @aten_tensor_tensor_ne() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4
|
|||
%fpBool = torch.aten.ne.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1>
|
||||
return %intBool, %uintBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @aten_log$fold_splat_i1
|
||||
func.func @aten_log$fold_splat_i1() -> !torch.vtensor<[4], f32> {
|
||||
// CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
|
||||
// CHECK: return %[[RET]] : !torch.vtensor<[4],f32>
|
||||
%cst = torch.vtensor.literal(dense<true> : tensor<4xi1>) : !torch.vtensor<[4], i1>
|
||||
%result = torch.aten.log %cst : !torch.vtensor<[4], i1> -> !torch.vtensor<[4], f32>
|
||||
return %result : !torch.vtensor<[4], f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @aten_log$fold_splat_si32
|
||||
func.func @aten_log$fold_splat_si32() -> !torch.vtensor<[4], f32> {
|
||||
// CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<1.09861231> : tensor<4xf32>) : !torch.vtensor<[4],f32>
|
||||
// CHECK: return %[[RET]] : !torch.vtensor<[4],f32>
|
||||
%cst = torch.vtensor.literal(dense<3> : tensor<4xsi32>) : !torch.vtensor<[4], si32>
|
||||
%result = torch.aten.log %cst : !torch.vtensor<[4], si32> -> !torch.vtensor<[4], f32>
|
||||
return %result : !torch.vtensor<[4], f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @aten_log$fold_splat_f32
|
||||
func.func @aten_log$fold_splat_f32() -> !torch.vtensor<[4], f32> {
|
||||
// CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<1.09861231> : tensor<4xf32>) : !torch.vtensor<[4],f32>
|
||||
// CHECK: return %[[RET]] : !torch.vtensor<[4],f32>
|
||||
%cst = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4], f32>
|
||||
%result = torch.aten.log %cst : !torch.vtensor<[4], f32> -> !torch.vtensor<[4], f32>
|
||||
return %result : !torch.vtensor<[4], f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue