mirror of https://github.com/llvm/torch-mlir
[torch] Folders for `torch.aten.*.tensor` operators [add, sub, mul] (#2878)
Simple folders for limited-size `aten` tensor operations. This is primarily useful for folding shape computations, which unfortunately can be expressed with `aten` operators; add, sub, and mul are common examples.
parent cea51897a5
commit e80054a3cc
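As a rough illustration of the arithmetic these folders constant-evaluate, the standalone sketch below mirrors the `aten.add.Tensor` semantics (`self + alpha * other`) on splat operands. `foldAddLike` and the sample values are hypothetical, chosen only to echo the new lit tests; they are not torch-mlir API.

// Standalone sketch of the fold arithmetic for aten.add.Tensor / aten.sub.Tensor:
// result[i] = self[i] + alpha * other[i] (sub uses '-' instead of '+').
// foldAddLike is a hypothetical helper for illustration, not part of the patch.
#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<int64_t> foldAddLike(const std::vector<int64_t> &self,
                                        const std::vector<int64_t> &other,
                                        int64_t alpha) {
  std::vector<int64_t> out(self.size());
  for (size_t i = 0; i < self.size(); ++i)
    out[i] = self[i] + alpha * other[i];
  return out;
}

int main() {
  std::vector<int64_t> self(4, 7), other(4, 11);
  auto folded = foldAddLike(self, other, /*alpha=*/2);
  assert(folded[0] == 29); // matches the dense<29> splat in the new tests below
  return 0;
}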
@@ -3790,6 +3790,7 @@ def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}

@@ -3839,6 +3840,7 @@ def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}

@@ -3889,6 +3891,7 @@ def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}
@@ -1106,6 +1106,177 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
  return success();
}

//===----------------------------------------------------------------------===//
// NAry folder helpers
//===----------------------------------------------------------------------===//

static bool checkSameDTypes(llvm::ArrayRef<Attribute> attrs) {
  bool allFp = true;
  bool allInt = true;

  for (auto attr : attrs) {
    if (!attr)
      return false;

    Type attrty;
    if (auto dense = dyn_cast_or_null<ElementsAttr>(attr))
      attrty = dense.getType();
    if (auto fp = dyn_cast_or_null<mlir::FloatAttr>(attr))
      attrty = fp.getType();
    if (auto integer = dyn_cast_or_null<mlir::IntegerAttr>(attr))
      attrty = integer.getType();
    if (auto shaped = dyn_cast_or_null<ShapedType>(attrty))
      attrty = shaped.getElementType();
    allFp &= isa<mlir::FloatType>(attrty);
    allInt &= isa<mlir::IntegerType>(attrty);
  }

  return allFp || allInt;
}

static bool checkAllSplats(llvm::ArrayRef<Attribute> attrs) {
  for (auto attr : attrs) {
    if (auto dense = dyn_cast_or_null<ElementsAttr>(attr)) {
      if (!dense.isSplat())
        return false;
    }
  }

  return true;
}

llvm::SmallVector<double> getFoldValueAtIndexFp(llvm::ArrayRef<Attribute> attrs,
                                                int64_t idx = 0) {
  llvm::SmallVector<double> splattrs;

  for (auto attr : attrs) {
    if (auto dense = dyn_cast<ElementsAttr>(attr)) {
      if (dense.isSplat()) {
        splattrs.push_back(dense.getSplatValue<APFloat>().convertToDouble());
      } else {
        splattrs.push_back(dense.getValues<APFloat>()[idx].convertToDouble());
      }
    } else if (auto intattr = dyn_cast<FloatAttr>(attr)) {
      splattrs.push_back(intattr.getValueAsDouble());
    } else {
      return {};
    }
  }

  return splattrs;
}

llvm::SmallVector<APInt> getFoldValueAtIndexInt(llvm::ArrayRef<Attribute> attrs,
                                                int64_t bitwidth,
                                                int64_t idx = 0) {
  llvm::SmallVector<APInt> splattrs;

  for (auto attr : attrs) {
    bool isunsigned = false;
    if (auto dense = dyn_cast<ElementsAttr>(attr)) {
      isunsigned = dyn_cast<IntegerType>(dense.getElementType()).isUnsigned();
      if (dense.isSplat()) {
        splattrs.push_back(dense.getSplatValue<APInt>());
      } else {
        splattrs.push_back(dense.getValues<APInt>()[idx]);
      }
    } else if (auto intattr = dyn_cast<IntegerAttr>(attr)) {
      isunsigned = cast<IntegerType>(intattr.getType()).isUnsigned();
      splattrs.push_back(intattr.getValue());
    } else {
      return {};
    }

    auto &apint = splattrs.back();
    if (apint.getBitWidth() < bitwidth) {
      if (isunsigned) {
        apint = apint.zextOrTrunc(bitwidth);
      } else {
        apint = apint.sextOrTrunc(bitwidth);
      }
    }
  }

  return splattrs;
}

using NAryFoldFpOperator = std::function<double(ArrayRef<double>)>;
using NAryFoldIntOperator = std::function<APInt(ArrayRef<APInt>)>;

static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
                                     NAryFoldFpOperator fpFolder,
                                     NAryFoldIntOperator intFolder) {
  constexpr int64_t maxFold = 16;
  if (!checkSameDTypes(operands))
    return nullptr;

  auto resultTy = dyn_cast<ValueTensorType>(ty);
  if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes())
    return nullptr;

  auto dty = resultTy.getDtype();
  auto resultBTy = resultTy.toBuiltinTensor().clone(dty);

  auto fpTy = dyn_cast<mlir::FloatType>(dty);
  auto intTy = dyn_cast<mlir::IntegerType>(dty);
  if (!fpTy && !intTy)
    return nullptr;

  bool allSplats = checkAllSplats(operands);
  bool withinMaxFold =
      resultBTy.hasStaticShape() && resultBTy.getNumElements() <= maxFold;

  if (!allSplats && !withinMaxFold)
    return nullptr;

  // We do not support broadcasting in the non-splat case so validate same
  // shaped inputs / outputs:
  if (!allSplats) {
    auto resultShape = resultBTy.getShape();
    for (int i = 0, s = operands.size(); i < s; ++i) {
      if (auto dense = dyn_cast<DenseElementsAttr>(operands[i])) {
        if (dense.isSplat())
          continue;
        auto operandShape = cast<ShapedType>(dense.getType()).getShape();
        if (operandShape.size() != resultShape.size())
          return nullptr;
        for (int i = 0, s = operandShape.size(); i < s; ++i)
          if (operandShape[i] != resultShape[i])
            return nullptr;
      }
    }
  }

  const int64_t numValues = allSplats ? 1 : resultBTy.getNumElements();

  if (fpTy) {
    llvm::SmallVector<APFloat> folded;
    for (int i = 0, s = numValues; i < s; ++i) {
      auto inputs = getFoldValueAtIndexFp(operands, i);
      double fold = fpFolder(inputs);

      APFloat val(fold);
      bool unused;
      val.convert(fpTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
                  &unused);
      folded.push_back(val);
    }
    return DenseElementsAttr::get(resultBTy, folded);
  }

  if (intTy) {
    llvm::SmallVector<APInt> folded;
    for (int i = 0, s = numValues; i < s; ++i) {
      auto inputs =
          getFoldValueAtIndexInt(operands, dty.getIntOrFloatBitWidth(), i);
      folded.push_back(intFolder(inputs));
    }
    return DenseElementsAttr::get(resultBTy, folded);
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenAddTensorOp
//===----------------------------------------------------------------------===//
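One detail worth calling out in `getFoldValueAtIndexInt` above: operands narrower than the result dtype are widened before the fold, zero-extending unsigned values and sign-extending signed ones so the arithmetic happens at a single bit width. Below is a minimal standalone sketch of that widening step using only `llvm::APInt`; the bit widths and values are made up for illustration and are not taken from the patch.

// Sketch of the widening performed before integer folding: narrower operands
// are zext'd (unsigned) or sext'd (signed) to the result bitwidth.
#include "llvm/ADT/APInt.h"
#include <cassert>

int main() {
  llvm::APInt narrowSigned(/*numBits=*/32, /*val=*/-5, /*isSigned=*/true);
  llvm::APInt narrowUnsigned(/*numBits=*/32, /*val=*/11);

  // Assume a 64-bit result dtype, as with the si64 tensors in the new tests.
  llvm::APInt wideSigned = narrowSigned.sextOrTrunc(64);
  llvm::APInt wideUnsigned = narrowUnsigned.zextOrTrunc(64);

  assert(wideSigned.getSExtValue() == -5);  // sign is preserved
  assert(wideUnsigned.getZExtValue() == 11);
  return 0;
}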
@@ -1116,6 +1287,20 @@ void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
  });
}

OpFoldResult AtenAddTensorOp::fold(FoldAdaptor adaptor) {
  auto fpFold = [](llvm::ArrayRef<double> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] + (inputs[1] * inputs[2]);
  };

  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] + (inputs[1] * inputs[2]);
  };

  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenAddScalarOp
//===----------------------------------------------------------------------===//

@@ -1136,6 +1321,20 @@ void AtenSubTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
  });
}

OpFoldResult AtenSubTensorOp::fold(FoldAdaptor adaptor) {
  auto fpFold = [](llvm::ArrayRef<double> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] - (inputs[1] * inputs[2]);
  };

  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] - (inputs[1] * inputs[2]);
  };

  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenSubScalarOp
//===----------------------------------------------------------------------===//

@@ -1166,6 +1365,20 @@ void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
  });
}

OpFoldResult AtenMulTensorOp::fold(FoldAdaptor adaptor) {
  auto fpFold = [](llvm::ArrayRef<double> inputs) {
    assert(inputs.size() == 2);
    return inputs[0] * inputs[1];
  };

  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
    assert(inputs.size() == 2);
    return inputs[0] * inputs[1];
  };

  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenEqTensorOp
//===----------------------------------------------------------------------===//
@@ -340,9 +340,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     # Elementwise tensor compute ops that don't have the standard mutating
     # variants.
     emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True)
-    emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
-    emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
-    emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
+    emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True, has_folder=True)
+    emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
+    emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
     emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
     emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
     emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
@@ -1916,9 +1916,8 @@ func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor
 }

 // CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT6]] = torch.constant.int 6
-// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR0]] : !torch.vtensor<[],si64>
+// CHECK: %[[INT6:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[INT6]] : !torch.vtensor<[],si64>
 func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
   %0 = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64>
   %1 = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
@@ -0,0 +1,143 @@
// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s

// CHECK-LABEL: @fold_aten_add_splat_int
func.func @fold_aten_add_splat_int() -> !torch.vtensor<[4],si64> {
  // CHECK: torch.vtensor.literal(dense<29> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %int2 = torch.constant.int 2
  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
  return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_add_splat_int_mismatch
func.func @fold_aten_add_splat_int_mismatch() -> !torch.vtensor<[4],si64> {
  // CHECK: torch.vtensor.literal(dense<29> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi32>) : !torch.vtensor<[4],si32>
  %int2 = torch.constant.int 2
  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si32>, !torch.int -> !torch.vtensor<[4],si64>
  return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_add_splat_float
func.func @fold_aten_add_splat_float() -> !torch.vtensor<[4],f32> {
  // CHECK: torch.vtensor.literal(dense<2.900000e+01> : tensor<4xf32>)
  %int2 = torch.constant.float 2.0
  %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
  return %0 : !torch.vtensor<[4],f32>
}

// -----

// CHECK-LABEL: @fold_aten_add_splat_float_mismatch
func.func @fold_aten_add_splat_float_mismatch() -> !torch.vtensor<[4],f32> {
  // CHECK: torch.vtensor.literal(dense<2.900000e+01> : tensor<4xf32>)
  %int2 = torch.constant.float 2.0
  %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf64>) : !torch.vtensor<[4],f64>
  %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],f64>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
  return %0 : !torch.vtensor<[4],f32>
}

// -----

// CHECK-LABEL: @fold_aten_add_arr0_int
func.func @fold_aten_add_arr0_int() -> !torch.vtensor<[4],si64> {
  // CHECK: torch.vtensor.literal(dense<[28, 29, 30, 31]> : tensor<4xsi64>)
  %cst_7 = torch.vtensor.literal(dense<[6,7,8,9]> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %int2 = torch.constant.int 2
  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
  return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_add_arr1_int
func.func @fold_aten_add_arr1_int() -> !torch.vtensor<[4],si64> {
  // CHECK: torch.vtensor.literal(dense<[27, 29, 31, 33]> : tensor<4xsi64>)
  %int2 = torch.constant.int 2
  %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_11 = torch.vtensor.literal(dense<[10,11,12,13]> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
  return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_add_arr0_float
func.func @fold_aten_add_arr0_float() -> !torch.vtensor<[4],f32> {
  // CHECK: torch.vtensor.literal(dense<[2.800000e+01, 2.900000e+01, 3.000000e+01, 3.100000e+01]> : tensor<4xf32>)
  %int2 = torch.constant.float 2.0
  %cst_7 = torch.vtensor.literal(dense<[6.0, 7.0, 8.0, 9.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
  return %0 : !torch.vtensor<[4],f32>
}

// -----

// CHECK-LABEL: @fold_aten_add_arr1_float
func.func @fold_aten_add_arr1_float() -> !torch.vtensor<[4],f32> {
  // CHECK: torch.vtensor.literal(dense<[2.700000e+01, 2.900000e+01, 3.100000e+01, 3.300000e+01]> : tensor<4xf32>)
  %fp_2 = torch.constant.float 2.0
  %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %cst_11 = torch.vtensor.literal(dense<[10.0,11.0,12.0,13.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %fp_2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
  return %0 : !torch.vtensor<[4],f32>
}

// -----

// CHECK-LABEL: @fold_aten_sub_splat_int
func.func @fold_aten_sub_splat_int() -> !torch.vtensor<[4],si64> {
  // CHECK: torch.vtensor.literal(dense<-15> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %int_2 = torch.constant.int 2
  %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %0 = torch.aten.sub.Tensor %cst_7, %cst_11, %int_2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
  return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_sub_splat_float
func.func @fold_aten_sub_splat_float() -> !torch.vtensor<[4],f32> {
  // CHECK: torch.vtensor.literal(dense<-1.500000e+01> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %fp_2 = torch.constant.float 2.0
  %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %0 = torch.aten.sub.Tensor %cst_7, %cst_11, %fp_2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
  return %0 : !torch.vtensor<[4],f32>
}

// -----

// CHECK-LABEL: @fold_aten_mul_splat_int
func.func @fold_aten_mul_splat_int() -> !torch.vtensor<[4],si64> {
  // CHECK: torch.vtensor.literal(dense<77> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %0 = torch.aten.mul.Tensor %cst_7, %cst_11 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64>
  return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_mul_splat_float
func.func @fold_aten_mul_splat_float() -> !torch.vtensor<[4],f32> {
  // CHECK: torch.vtensor.literal(dense<7.700000e+01> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %0 = torch.aten.mul.Tensor %cst_7, %cst_11 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],f32>
  return %0 : !torch.vtensor<[4],f32>
}