mirror of https://github.com/llvm/torch-mlir
[MHLO] support non-constant torch scalar in BasicOps (#1134)
See RFC https://github.com/llvm/torch-mlir/issues/999 Co-authored-by: Bairen Yi yibairen.byron@bytedance.com Co-authored-by: Jiawei Wu xremold@gmail.com Co-authored-by: Tianyou Guo tianyou.gty@alibaba-inc.com Co-authored-by: Xu Yan yancey.yx@alibaba-inc.com Co-authored-by: Ziheng Jiang ziheng.jiang@bytedance.compull/1142/head
parent
82af44da2f
commit
0b23af27d3
|
@ -159,23 +159,15 @@ public:
|
|||
}
|
||||
|
||||
if (!rhsType) {
|
||||
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
|
||||
outElemTy, {})))
|
||||
return op.emitError("currently only scalar constants are supported for "
|
||||
"conversion in MHLO operation");
|
||||
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
|
||||
}
|
||||
|
||||
lhs = mhlo::promoteType(rewriter, lhs, outType);
|
||||
rhs = mhlo::promoteType(rewriter, rhs, outType);
|
||||
|
||||
if (!skipMultiplyAlpha(op.alpha())) {
|
||||
Value alpha;
|
||||
if (failed(mhlo::torchAlphaToMhloTensor(rewriter, op.getOperation(),
|
||||
op.alpha(), alpha, outElemTy, {},
|
||||
/*checkForUnity=*/false))) {
|
||||
return op.emitError("currently only scalar constants are supported for "
|
||||
"alpha in conversion to MHLO operation");
|
||||
}
|
||||
Value alpha =
|
||||
mhlo::scalarToMhloTensor(rewriter, op, adaptor.alpha(), outElemTy);
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
|
||||
bcastDimensions);
|
||||
|
@ -216,13 +208,13 @@ public:
|
|||
return op.emitError(
|
||||
"only floating-point or integer datatype legalization supported");
|
||||
}
|
||||
if (!rhsType) {
|
||||
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
|
||||
outElemTy, {})))
|
||||
return op.emitError("currently only scalar constants are supported for "
|
||||
"conversion in MHLO operation");
|
||||
}
|
||||
|
||||
Value lhsTensor = lhs;
|
||||
if (std::is_same<AtenOpT, AtenSquareOp>()) {
|
||||
rhs = lhs;
|
||||
} else if (!rhsType) {
|
||||
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
|
||||
}
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
lhs = mhlo::promoteType(rewriter, lhs, outType);
|
||||
rhs = mhlo::promoteType(rewriter, rhs, outType);
|
||||
|
@ -263,11 +255,7 @@ public:
|
|||
}
|
||||
|
||||
if (!rhsTy) {
|
||||
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
|
||||
lhsElemTy, {}))) {
|
||||
return op.emitError("currently only scalar constants are supported for "
|
||||
"conversion in MHLO operation");
|
||||
}
|
||||
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), lhsElemTy);
|
||||
}
|
||||
|
||||
// TODO: what is the PyTorch default type promotion?
|
||||
|
@ -569,12 +557,8 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
|||
.cast<RankedTensorType>();
|
||||
auto outputShape = outputType.getShape();
|
||||
auto outputElemType = outputType.getElementType();
|
||||
Value mhloTensor;
|
||||
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.a(), mhloTensor,
|
||||
outputElemType, outputShape,
|
||||
false))) {
|
||||
return op->emitError("failed lowering PrimNumToTensorScalarOp to MHLO");
|
||||
}
|
||||
Value mhloTensor =
|
||||
mhlo::scalarToMhloTensor(rewriter, op, adaptor.a(), outputElemType);
|
||||
rewriter.replaceOp(op, mhloTensor);
|
||||
return success();
|
||||
}
|
||||
|
@ -1020,4 +1004,4 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
|
||||
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
}
|
||||
}
|
||||
|
|
|
@ -174,93 +174,15 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
|
|||
return const_op.getResult();
|
||||
}
|
||||
|
||||
// TODO: Support for variable scalar.
|
||||
LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value torchScalarValue,
|
||||
Value &mhloTensor, Type dtype,
|
||||
llvm::ArrayRef<int64_t> dshape,
|
||||
bool doBroadcast) {
|
||||
// Retrieve a const float or int value but create the out Tensor with dtype.
|
||||
double doubleValue;
|
||||
auto isFloat =
|
||||
matchPattern(torchScalarValue, m_TorchConstantFloat(&doubleValue));
|
||||
|
||||
int64_t intValue;
|
||||
auto isInt = matchPattern(torchScalarValue, m_TorchConstantInt(&intValue));
|
||||
|
||||
if (!isFloat && !isInt)
|
||||
return op->emitError("Unable to extract the scalar constant");
|
||||
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
if (doBroadcast) {
|
||||
mhloTensor = getSplatConstTensor<float>(
|
||||
rewriter, op, (isFloat ? doubleValue : intValue), dtype, dshape);
|
||||
} else {
|
||||
mhloTensor = mhlo::getConstTensor<float>(
|
||||
rewriter, op, (isFloat ? doubleValue : intValue), dshape)
|
||||
.getValue();
|
||||
}
|
||||
} else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
|
||||
auto w = intType.getWidth();
|
||||
if (w != 32 && w != 64)
|
||||
return op->emitError("Unsupported integer type") << intType;
|
||||
|
||||
if (w == 32) {
|
||||
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
|
||||
return op->emitError("Supplied value of scalar constant exceeds limits "
|
||||
"of destination type");
|
||||
}
|
||||
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
|
||||
: static_cast<int32_t>(intValue);
|
||||
if (doBroadcast) {
|
||||
mhloTensor =
|
||||
getSplatConstTensor<int32_t>(rewriter, op, d, dtype, dshape);
|
||||
} else {
|
||||
mhloTensor =
|
||||
mhlo::getConstTensor<int32_t>(rewriter, op, {d}, dshape).getValue();
|
||||
}
|
||||
} else if (w == 64) {
|
||||
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
|
||||
return op->emitError("Supplied value of scalar constant exceeds limits "
|
||||
"of destination type");
|
||||
}
|
||||
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
|
||||
if (doBroadcast) {
|
||||
mhloTensor =
|
||||
getSplatConstTensor<int64_t>(rewriter, op, d, dtype, dshape);
|
||||
} else {
|
||||
mhloTensor =
|
||||
mhlo::getConstTensor<int64_t>(rewriter, op, {d}, dshape).getValue();
|
||||
}
|
||||
}
|
||||
} else
|
||||
return op->emitError("Usupported element type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value alphaScalar,
|
||||
Value &alphaTensor, Type dtype,
|
||||
llvm::ArrayRef<int64_t> dshape,
|
||||
bool checkForUnity) {
|
||||
if (succeeded(torchScalarToMhloTensor(rewriter, op, alphaScalar, alphaTensor,
|
||||
dtype, dshape)))
|
||||
return success();
|
||||
|
||||
// `alpha` has not been specified.
|
||||
int64_t alphaValue;
|
||||
if (!matchPattern(alphaScalar, m_TorchConstantInt(&alphaValue)))
|
||||
return op->emitError("Currently only scalar constants are supported for "
|
||||
"alpha in MHLO operation");
|
||||
// When no alpha has been specified, this must be 1.
|
||||
if (checkForUnity && alphaValue != 1)
|
||||
return op->emitError("Unsupported integer value for alpha");
|
||||
|
||||
alphaTensor =
|
||||
mlir::mhlo::getMhloConstTensorSingleF32(rewriter, op, alphaValue);
|
||||
|
||||
return success();
|
||||
Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
Value scalarValue, Type dtype) {
|
||||
auto tensor = rewriter.create<tensor::FromElementsOp>(
|
||||
op->getLoc(), ArrayRef<Value>{scalarValue});
|
||||
auto dtype_tensor =
|
||||
rewriter.create<mhlo::ConvertOp>(op->getLoc(), tensor, dtype);
|
||||
return rewriter.create<mhlo::ReshapeOp>(
|
||||
op->getLoc(), RankedTensorType::get(mlir::ArrayRef<int64_t>{}, dtype),
|
||||
dtype_tensor);
|
||||
}
|
||||
|
||||
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
|
||||
|
@ -439,4 +361,4 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
|
|||
.getResult();
|
||||
}
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
} // namespace mlir
|
||||
|
|
|
@ -47,17 +47,8 @@ template <typename T>
|
|||
Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
T val, Type dtype, llvm::ArrayRef<int64_t> dshape);
|
||||
|
||||
LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value torchScalarValue,
|
||||
Value &mhloTensor, Type dtype,
|
||||
llvm::ArrayRef<int64_t> dshape,
|
||||
bool doBroadcast = true);
|
||||
|
||||
LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value alphaScalar,
|
||||
Value &alphaTensor, Type dtype,
|
||||
llvm::ArrayRef<int64_t> dshape,
|
||||
bool checkForUnity);
|
||||
Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
Value scalarValue, Type dtype);
|
||||
|
||||
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
|
||||
|
||||
|
|
|
@ -41,11 +41,15 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[],si64> {
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<i64>
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<i64> -> !torch.vtensor<[],si64>
|
||||
// CHECK: return %[[VAL_1]] : !torch.vtensor<[],si64>
|
||||
// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic(
|
||||
// CHECK-SAME: ) -> !torch.vtensor<[],si64> {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]]
|
||||
// CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64>
|
||||
// CHECK: %[[T2:.*]] = mhlo.convert %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T3:.*]] = "mhlo.reshape"(%[[T2]]) : (tensor<1xi64>) -> tensor<i64>
|
||||
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<i64> -> !torch.vtensor<[],si64>
|
||||
// CHECK: return %[[T4]] : !torch.vtensor<[],si64>
|
||||
func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[], si64>
|
||||
|
@ -251,4 +255,4 @@ func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) ->
|
|||
%2 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %2, %1, %0, %float1.000000e-05 : !torch.vtensor<[3,7,4,5],f32>, !torch.list<int>, !torch.vtensor<[4,5],f32>, !torch.vtensor<[4,5],f32>, !torch.float -> !torch.vtensor<[3,7,4,5],f32>, !torch.vtensor<[3,7,1,1],f32>, !torch.vtensor<[3,7,1,1],f32>
|
||||
return %result0 : !torch.vtensor<[3,7,4,5],f32>
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.gelu(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[STR:.*]] = torch.constant.str "none"
|
||||
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
|
@ -22,13 +21,14 @@ func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
|
|||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.tanh$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = mhlo.tanh %[[VAL_1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.tanh$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = mhlo.tanh %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
|
@ -36,12 +36,12 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.log$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = mhlo.log %[[VAL_1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.log$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = mhlo.log %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.log %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
|
@ -49,43 +49,44 @@ func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.exp$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = mhlo.exponential %[[VAL_1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.exp$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = mhlo.exponential %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.exp %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.neg$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = mhlo.negate %[[VAL_1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.neg$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = mhlo.negate %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.addscalar$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %int9 = torch.constant.int 9
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_3:.*]] = chlo.broadcast_add %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.addscalar$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[INT9:.*]] = torch.constant.int 9
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = chlo.broadcast_add %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int9 = torch.constant.int 9
|
||||
%int1 = torch.constant.int 1
|
||||
|
@ -95,17 +96,23 @@ func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.addscalar$alpha(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %int9 = torch.constant.int 9
|
||||
// CHECK: %int2 = torch.constant.int 2
|
||||
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_4:.*]] = chlo.broadcast_multiply %[[VAL_2]], %[[VAL_3]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[VAL_5:.*]] = chlo.broadcast_add %[[VAL_1]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.addscalar$alpha(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[INT9:.*]] = torch.constant.int 9
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
|
||||
// CHECK: %[[T3:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T4:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T5:.*]] = "mhlo.reshape"(%[[T4]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T6:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
|
||||
// CHECK: %[[T7:.*]] = mhlo.convert(%[[T6]]) : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T8:.*]] = "mhlo.reshape"(%[[T7]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T9:.*]] = chlo.broadcast_multiply %[[T5]], %[[T8]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[T10:.*]] = chlo.broadcast_add %[[T0]], %[[T9]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T11]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.addscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int9 = torch.constant.int 9
|
||||
%int2 = torch.constant.int 2
|
||||
|
@ -115,15 +122,14 @@ func.func @torch.aten.addscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.addtensor$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %[[VAL_4:.*]] = chlo.broadcast_add %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.addtensor$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[T2:.*]] = chlo.broadcast_add %[[T0]], %[[T1]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
|
@ -132,17 +138,19 @@ func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.addtensor$alpha(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %int2 = torch.constant.int 2
|
||||
// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_5:.*]] = chlo.broadcast_multiply %[[VAL_3]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = chlo.broadcast_add %[[VAL_2]], %[[VAL_5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.addtensor$alpha(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
|
||||
// CHECK: %[[T3:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
|
||||
// CHECK: %[[T4:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T5:.*]] = "mhlo.reshape"(%[[T4]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T6:.*]] = chlo.broadcast_multiply %[[T1]], %[[T5]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T7:.*]] = chlo.broadcast_add %[[T0]], %[[T6]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T8]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.aten.add.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
|
@ -151,16 +159,15 @@ func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.addtensor$promote(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %[[VAL_4:.*]] = mhlo.convert(%[[VAL_2]]) : (tensor<?x?xi32>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_5:.*]] = chlo.broadcast_add %[[VAL_4]], %[[VAL_3]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
|
||||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],si64>
|
||||
// CHECK-LABEL: func.func @torch.aten.addtensor$promote(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],si32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[T2:.*]] = mhlo.convert(%[[T0]]) : (tensor<?x?xi32>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[T3:.*]] = chlo.broadcast_add %[[T2]], %[[T1]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
|
||||
// CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64>
|
||||
func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si64>, !torch.int -> !torch.vtensor<[?,?],si64>
|
||||
|
@ -169,15 +176,18 @@ func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.subscalar$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %int9 = torch.constant.int 9
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_3:.*]] = chlo.broadcast_subtract %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.subscalar$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[INT9:.*]] = torch.constant.int 9
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = chlo.broadcast_subtract %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int9 = torch.constant.int 9
|
||||
%int1 = torch.constant.int 1
|
||||
|
@ -187,17 +197,23 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.subscalar$alpha(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %int9 = torch.constant.int 9
|
||||
// CHECK: %int2 = torch.constant.int 2
|
||||
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_4:.*]] = chlo.broadcast_multiply %[[VAL_2]], %[[VAL_3]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[VAL_5:.*]] = chlo.broadcast_subtract %[[VAL_1]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.subscalar$alpha(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[INT9:.*]] = torch.constant.int 9
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
|
||||
// CHECK: %[[T3:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T4:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T5:.*]] = "mhlo.reshape"(%[[T4]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T6:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
|
||||
// CHECK: %[[T7:.*]] = mhlo.convert(%[[T6]]) : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T8:.*]] = "mhlo.reshape"(%[[T7]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T9:.*]] = chlo.broadcast_multiply %[[T5]], %[[T8]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[T10:.*]] = chlo.broadcast_subtract %[[T0]], %[[T9]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T11]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.subscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int9 = torch.constant.int 9
|
||||
%int2 = torch.constant.int 2
|
||||
|
@ -207,15 +223,14 @@ func.func @torch.aten.subscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.subtensor$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %[[VAL_4:.*]] = chlo.broadcast_subtract %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.subtensor$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[T2:.*]] = chlo.broadcast_subtract %[[T0]], %[[T1]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
|
@ -224,17 +239,19 @@ func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.subtensor$alpha(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %int2 = torch.constant.int 2
|
||||
// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_5:.*]] = chlo.broadcast_multiply %[[VAL_3]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = chlo.broadcast_subtract %[[VAL_2]], %[[VAL_5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.subtensor$alpha(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
|
||||
// CHECK: %[[T3:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
|
||||
// CHECK: %[[T4:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T5:.*]] = "mhlo.reshape"(%[[T4]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T6:.*]] = chlo.broadcast_multiply %[[T1]], %[[T5]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T7:.*]] = chlo.broadcast_subtract %[[T0]], %[[T6]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T8]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.aten.sub.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
|
@ -243,16 +260,15 @@ func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.subtensor$promote(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %[[VAL_4:.*]] = mhlo.convert(%[[VAL_2]]) : (tensor<?x?xi32>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_5:.*]] = chlo.broadcast_subtract %[[VAL_4]], %[[VAL_3]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
|
||||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],si64>
|
||||
// CHECK-LABEL: func.func @torch.aten.subtensor$promote(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],si32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[T2:.*]] = mhlo.convert(%[[T0]]) : (tensor<?x?xi32>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[T3:.*]] = chlo.broadcast_subtract %[[T2]], %[[T1]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
|
||||
// CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64>
|
||||
func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si64>, !torch.int -> !torch.vtensor<[?,?],si64>
|
||||
|
@ -261,14 +277,17 @@ func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.mulscalar$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %int9 = torch.constant.int 9
|
||||
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_3:.*]] = chlo.broadcast_multiply %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.mulscalar$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[INT9:.*]] = torch.constant.int 9
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
|
||||
// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.mulscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int9 = torch.constant.int 9
|
||||
%0 = torch.aten.mul.Scalar %arg0, %int9 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
|
@ -277,14 +296,13 @@ func.func @torch.aten.mulscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.multensor$basic(
|
||||
// CHECK-SAME: %[[VLA_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VLA_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VLA_2:.*]] = torch_c.to_builtin_tensor %[[VLA_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VLA_3:.*]] = torch_c.to_builtin_tensor %[[VLA_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VLA_4:.*]] = chlo.broadcast_multiply %[[VLA_2]], %[[VLA_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VLA_5:.*]] = torch_c.from_builtin_tensor %[[VLA_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VLA_5]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.multensor$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = chlo.broadcast_multiply %[[T0]], %[[T1]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
|
@ -292,14 +310,17 @@ func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.divscalar$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %int9 = torch.constant.int 9
|
||||
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_3:.*]] = chlo.broadcast_divide %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.divscalar$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[INT9:.*]] = torch.constant.int 9
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
|
||||
// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = chlo.broadcast_divide %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.divscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int9 = torch.constant.int 9
|
||||
%0 = torch.aten.div.Scalar %arg0, %int9 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
|
@ -308,14 +329,13 @@ func.func @torch.aten.divscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.divtensor$basic(
|
||||
// CHECK-SAME: %[[VLA_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VLA_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VLA_2:.*]] = torch_c.to_builtin_tensor %[[VLA_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VLA_3:.*]] = torch_c.to_builtin_tensor %[[VLA_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VLA_4:.*]] = chlo.broadcast_divide %[[VLA_2]], %[[VLA_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VLA_5:.*]] = torch_c.from_builtin_tensor %[[VLA_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VLA_5]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.divtensor$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
|
@ -323,14 +343,17 @@ func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.gt.scalar(
|
||||
// CHECK-SAME: %arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %int3 = torch.constant.int 3
|
||||
// CHECK: %1 = mhlo.constant dense<3.000000e+00> : tensor<f32>
|
||||
// CHECK: %2 = chlo.broadcast_compare %0, %1 {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
|
||||
// CHECK: %3 = torch_c.from_builtin_tensor %2 : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %3 : !torch.vtensor<[?,?],i1>
|
||||
// CHECK-LABEL: func.func @torch.aten.gt.scalar(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT3]]
|
||||
// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = chlo.broadcast_compare %[[T0]], %[[T4]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[T6]] : !torch.vtensor<[?,?],i1>
|
||||
func.func @torch.aten.gt.scalar(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
%int3 = torch.constant.int 3
|
||||
%0 = torch.aten.gt.Scalar %arg0, %int3 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],i1>
|
||||
|
@ -339,14 +362,13 @@ func.func @torch.aten.gt.scalar(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.gt.tensor(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
|
||||
// CHECK-LABEL: func.func @torch.aten.gt.tensor(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
|
||||
// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
|
||||
func.func @torch.aten.gt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
%0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
|
||||
return %0 : !torch.vtensor<[?,?],i1>
|
||||
|
@ -354,14 +376,13 @@ func.func @torch.aten.gt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.lt.tensor(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
|
||||
// CHECK-LABEL: func.func @torch.aten.lt.tensor(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
|
||||
// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
|
||||
func.func @torch.aten.lt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
%0 = torch.aten.lt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
|
||||
return %0 : !torch.vtensor<[?,?],i1>
|
||||
|
@ -369,14 +390,13 @@ func.func @torch.aten.lt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.eq.tensor(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
|
||||
// CHECK-LABEL: func.func @torch.aten.eq.tensor(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
|
||||
// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
|
||||
func.func @torch.aten.eq.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
%0 = torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
|
||||
return %0 : !torch.vtensor<[?,?],i1>
|
||||
|
@ -384,14 +404,13 @@ func.func @torch.aten.eq.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.ne.tensor(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction NE>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
|
||||
// CHECK-LABEL: func.func @torch.aten.ne.tensor(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
|
||||
// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction NE>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
|
||||
func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
%0 = torch.aten.ne.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
|
||||
return %0 : !torch.vtensor<[?,?],i1>
|
||||
|
@ -399,15 +418,15 @@ func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.permute$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_5:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32>
|
||||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[64,4],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.permute$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T2:.*]] = "mhlo.transpose"(%[[T0]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32>
|
||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32>
|
||||
// CHECK: return %[[T3]] : !torch.vtensor<[64,4],f32>
|
||||
func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
|
@ -418,14 +437,106 @@ func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.relu(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = mhlo.maximum %[[VAL_1]], %[[VAL_2]] : tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-LABEL: func.func @torch.aten.relu(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = mhlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.relu %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.addscalar$variable(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.float) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_f64 %[[ARG1]]
|
||||
// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64>
|
||||
// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xf64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64>
|
||||
// CHECK: %[[T6:.*]] = mhlo.convert(%[[T5]]) : (tensor<1xf64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T7:.*]] = "mhlo.reshape"(%[[T6]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T8:.*]] = chlo.broadcast_multiply %[[T4]], %[[T7]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[T9:.*]] = chlo.broadcast_add %[[T0]], %[[T8]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.addscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.float) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.add.Scalar %arg0, %arg1, %arg1: !torch.vtensor<[?,?],f32>, !torch.float, !torch.float -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.addtensor$variable(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG2:.*]]: !torch.float) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.to_f64 %[[ARG2]]
|
||||
// CHECK: %[[T3:.*]] = tensor.from_elements %[[T2]] : tensor<1xf64>
|
||||
// CHECK: %[[T4:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xf64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T5:.*]] = "mhlo.reshape"(%[[T4]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T6:.*]] = chlo.broadcast_multiply %[[T1]], %[[T5]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T7:.*]] = chlo.broadcast_add %[[T0]], %[[T6]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T8]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.addtensor$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>, %arg2: !torch.float) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.add.Tensor %arg0, %arg1, %arg2: !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.mulscalar$variable(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
|
||||
// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.mulscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.mul.Scalar %arg0, %arg1: !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.divscalar$variable(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
|
||||
// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = chlo.broadcast_divide %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.divscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.div.Scalar %arg0, %arg1: !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.gt.scalar$variable(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
|
||||
// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = chlo.broadcast_compare %[[T0]], %[[T4]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[T6]] : !torch.vtensor<[?,?],i1>
|
||||
func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,?],i1> {
|
||||
%0 = torch.aten.gt.Scalar %arg0, %arg1: !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],i1>
|
||||
return %0 : !torch.vtensor<[?,?],i1>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue