mirror of https://github.com/llvm/torch-mlir
[MHLO] Init MHLO basic op conversion (#1092)
* [MHLO] Init MHLO basic Op Conversion

  Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
  Co-authored-by: Jiawei Wu <xremold@gmail.com>
  Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
  Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
  Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>

* [NFC] Remove 'from @llvm-project' annotation

Co-authored-by: wujiawei.jw <wujiawei.jw@bytedance.com>

pull/1116/head
parent e23fbc89f2
commit 052d2f84dc
File diff suppressed because it is too large
@@ -1,5 +1,6 @@
add_mlir_conversion_library(TorchMLIRTorchToMhlo
  TorchToMhlo.cpp
  MhloLegalizeUtils.cpp
  BasicOp.cpp
  GatherOp.cpp
  ViewLikeOps.cpp
@@ -0,0 +1,318 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "./MhloLegalizeUtils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace mlir {
namespace mhlo {

// Create a 32-bit float constant operator from a float
Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
                                  float val) {
  auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
  auto const_attr = DenseElementsAttr::get(const_type, val);

  auto const_op =
      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Create a 64-bit float constant operator from a double
Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
                                  double val) {
  auto const_type = RankedTensorType::get({}, rewriter.getF64Type());
  auto const_attr = DenseElementsAttr::get(const_type, val);

  auto const_op =
      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Templated function to create a constant op for given type and shape.
// T: storage C type.
// Default template creates a constant tensor in T.
template <typename T>
llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
                                     ArrayRef<T> vec, ArrayRef<int64_t> shape) {
  uint64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != num_total_elements) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type =
      RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Template specialization for APInt
template <>
llvm::Optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
                                            Operation *op, ArrayRef<APInt> vec,
                                            ArrayRef<int64_t> shape) {
  uint64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != num_total_elements) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }
  auto const_type = RankedTensorType::get(
      shape, rewriter.getIntegerType(vec[0].getBitWidth()));
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Template specialization for float
template <>
llvm::Optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
                                            Operation *op, ArrayRef<float> vec,
                                            ArrayRef<int64_t> shape) {
  uint64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != num_total_elements) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

template <>
llvm::Optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
                                             Operation *op, ArrayRef<double> vec,
                                             ArrayRef<int64_t> shape) {
  uint64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;
  }

  if (vec.size() != num_total_elements) {
    op->emitOpError("getConstTensor(): number of elements mismatch.");
    return llvm::None;
  }

  auto const_type = RankedTensorType::get(shape, rewriter.getF64Type());
  auto const_attr = DenseElementsAttr::get(const_type, vec);

  auto const_op =
      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Template instantiation
template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter &,
                                                       Operation *,
                                                       ArrayRef<int32_t> vec,
                                                       ArrayRef<int64_t> shape);

template llvm::Optional<Value> getConstTensor<int64_t>(PatternRewriter &,
                                                       Operation *,
                                                       ArrayRef<int64_t> vec,
                                                       ArrayRef<int64_t> shape);

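// Example (illustrative only, not in the original patch):
// getConstTensor<int64_t>(rewriter, op, {0, 1, 2}, {3}) materializes
// `mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>` and returns its result
// Value; an empty `shape` yields a rank-zero (scalar) constant tensor.
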
template <typename T>
static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
                           const int64_t &intValue) {
  if (isFloat) {
    // Do a round-trip check here instead of numeric limits due to
    // compiler warnings around double <-> int conversion.
    return (doubleValue == static_cast<double>(static_cast<T>(doubleValue)));
  } else {
    assert(isInt);
    return (intValue >= std::numeric_limits<T>::min()) &&
           (intValue <= std::numeric_limits<T>::max());
  }
  return true;
}

template <typename T>
Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
                          T val, Type dtype, llvm::ArrayRef<int64_t> dshape) {
  auto const_type = RankedTensorType::get(dshape, dtype);
  auto const_attr = SplatElementsAttr::get(const_type, val);
  auto const_op =
      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
                                      Operation *op, Value torchScalarValue,
                                      Value &mhloTensor, Type dtype,
                                      llvm::ArrayRef<int64_t> dshape,
                                      bool doBroadcast) {
  // Retrieve a const float or int value but create the out Tensor with dtype.
  double doubleValue;
  auto isFloat =
      matchPattern(torchScalarValue, m_TorchConstantFloat(&doubleValue));

  int64_t intValue;
  auto isInt = matchPattern(torchScalarValue, m_TorchConstantInt(&intValue));

  if (!isFloat && !isInt)
    return op->emitError("Unable to extract the scalar constant");

  if (dtype.isa<mlir::FloatType>()) {
    if (doBroadcast) {
      mhloTensor = getSplatConstTensor<float>(
          rewriter, op, (isFloat ? doubleValue : intValue), dtype, dshape);
    } else {
      mhloTensor = mhlo::getConstTensor<float>(
                       rewriter, op, (isFloat ? doubleValue : intValue), dshape)
                       .getValue();
    }
  } else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
    auto w = intType.getWidth();
    if (w != 32 && w != 64)
      return op->emitError("Unsupported integer type") << intType;

    if (w == 32) {
      if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
        return op->emitError("Supplied value of scalar constant exceeds limits "
                             "of destination type");
      }
      int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
                          : static_cast<int32_t>(intValue);
      if (doBroadcast) {
        mhloTensor =
            getSplatConstTensor<int32_t>(rewriter, op, d, dtype, dshape);
      } else {
        mhloTensor =
            mhlo::getConstTensor<int32_t>(rewriter, op, {d}, dshape).getValue();
      }
    } else if (w == 64) {
      if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
        return op->emitError("Supplied value of scalar constant exceeds limits "
                             "of destination type");
      }
      int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
      if (doBroadcast) {
        mhloTensor =
            getSplatConstTensor<int64_t>(rewriter, op, d, dtype, dshape);
      } else {
        mhloTensor =
            mhlo::getConstTensor<int64_t>(rewriter, op, {d}, dshape).getValue();
      }
    }
  } else {
    return op->emitError("Unsupported element type");
  }

  return success();
}

LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
                                     Operation *op, Value alphaScalar,
                                     Value &alphaTensor, Type dtype,
                                     llvm::ArrayRef<int64_t> dshape,
                                     bool checkForUnity) {
  if (succeeded(torchScalarToMhloTensor(rewriter, op, alphaScalar, alphaTensor,
                                        dtype, dshape)))
    return success();

  // `alpha` has not been specified.
  int64_t alphaValue;
  if (!matchPattern(alphaScalar, m_TorchConstantInt(&alphaValue)))
    return op->emitError("Currently only scalar constants are supported for "
                         "alpha in MHLO operation");
  // When no alpha has been specified, this must be 1.
  if (checkForUnity && alphaValue != 1)
    return op->emitError("Unsupported integer value for alpha");

  alphaTensor =
      mlir::mhlo::getMhloConstTensorSingleF32(rewriter, op, alphaValue);

  return success();
}

Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
                          TensorType outType) {
  // Two tensors are "broadcastable" if the following rules hold:
  //   - Each tensor has at least one dimension.
  //   - When iterating over the dimension sizes, starting at the trailing
  //     dimension, the dimension sizes must either be equal, one of them is 1,
  //     or one of them does not exist.
  Operation *op = input.getDefiningOp();
  TensorType in_type = input.getType().dyn_cast<TensorType>();

  if (in_type.getElementType() != outType.getElementType()) {
    TensorType promoted_type =
        in_type.cloneWith(in_type.getShape(), outType.getElementType());
    input =
        rewriter.create<mhlo::ConvertOp>(op->getLoc(), promoted_type, input);
  }

  ArrayRef<int64_t> inShape = in_type.getShape();
  ArrayRef<int64_t> outShape = outType.getShape();

  bool do_bcast = (inShape.size() != outShape.size());
  SmallVector<int64_t> bcastDims;
  for (size_t i = 0; i < inShape.size(); ++i) {
    // Iterate over the dimension sizes, starting at the trailing dimension.
    size_t outPos = outShape.size() - 1 - i;
    size_t inPos = inShape.size() - 1 - i;
    int64_t outDim = outShape[outPos];
    int64_t inDim = inShape[inPos];
    if (inDim == outDim) {
      bcastDims.push_back(outPos);
    } else if (inDim != outDim && inDim == 1) {
      bcastDims.push_back(outPos);
      do_bcast = true;
    } else {
      op->emitError("The size of tensor a (")
          << inDim << ") must match the size of tensor b (" << outDim
          << ") at non-singleton dimension " << inPos;
    }
  }
  std::reverse(bcastDims.begin(), bcastDims.end());
  if (!do_bcast) {
    return input;
  }
  DenseIntElementsAttr bcast_attr = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<long int>(bcastDims.size())},
                            rewriter.getI64Type()),
      bcastDims);
  auto bcast_op = rewriter.create<mhlo::BroadcastInDimOp>(op->getLoc(), outType,
                                                          input, bcast_attr);
  return bcast_op.getResult();
}
} // namespace mhlo
} // namespace mlir
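For reference, a minimal standalone sketch (not part of the patch) of how the trailing-dimension walk in promoteAndBroadcast arrives at the broadcast_dimensions attribute; the shapes and the helper name computeBcastDims are illustrative only, chosen to match the $bcast test cases further below.

// Illustrative sketch only: mirrors the trailing-dimension matching loop in
// promoteAndBroadcast to show which broadcast_dimensions it would emit.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper, not part of the patch.
static std::vector<int64_t>
computeBcastDims(const std::vector<int64_t> &inShape,
                 const std::vector<int64_t> &outShape) {
  std::vector<int64_t> bcastDims;
  for (size_t i = 0; i < inShape.size(); ++i) {
    size_t outPos = outShape.size() - 1 - i; // walk from the trailing dims
    size_t inPos = inShape.size() - 1 - i;
    if (inShape[inPos] == outShape[outPos] || inShape[inPos] == 1)
      bcastDims.push_back(static_cast<int64_t>(outPos));
  }
  std::reverse(bcastDims.begin(), bcastDims.end());
  return bcastDims;
}

int main() {
  // tensor<64xf32> broadcast into tensor<4x64xf32>: the single input dim maps
  // to output dim 1, so broadcast_dimensions = dense<1>.
  for (int64_t d : computeBcastDims({64}, {4, 64}))
    std::cout << d << " "; // prints: 1
  std::cout << "\n";
  // tensor<8x1x64xf32> into tensor<8x4x64xf32>: every input dim participates,
  // so broadcast_dimensions = dense<[0, 1, 2]>.
  for (int64_t d : computeBcastDims({8, 1, 64}, {8, 4, 64}))
    std::cout << d << " "; // prints: 0 1 2
  std::cout << "\n";
  return 0;
}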
@@ -0,0 +1,62 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
#define TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace mhlo {

using mlir::ConversionPatternRewriter;

// Create a 32-bit float constant operator from a float
Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
                                  float val);

// Create a 64-bit float constant operator from a double
Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
                                  double val);

// Templated function to create a constant op for given type and shape.
// T: storage C type.
// Default template creates a constant tensor in T.
// To create INT48 MHLO constant, need to pass in llvm::APInt instead.
template <typename T>
llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
                                     ArrayRef<T> vec, ArrayRef<int64_t> shape);

template <typename T>
Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
                          T val, Type dtype, llvm::ArrayRef<int64_t> dshape);

LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
                                      Operation *op, Value torchScalarValue,
                                      Value &mhloTensor, Type dtype,
                                      llvm::ArrayRef<int64_t> dshape,
                                      bool doBroadcast = true);

LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
                                     Operation *op, Value alphaScalar,
                                     Value &alphaTensor, Type dtype,
                                     llvm::ArrayRef<int64_t> dshape,
                                     bool checkForUnity);

Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
                          TensorType outType);
} // namespace mhlo
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
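The conversion patterns that consume these helpers are in the large file diff the viewer suppressed above. As a hedged illustration only (the class name ConvertAtenAddScalarOp and the accessor names self()/other() are assumptions, not taken from the patch), a lowering for torch.aten.add.Scalar could look roughly like the sketch below, which is consistent with the addscalar$basic test that follows; the real pattern presumably also routes the alpha operand through torchAlphaToMhloTensor, which is omitted here.

// Hypothetical sketch, not part of the patch: assumed pattern class and
// accessor names; only the helper calls mirror MhloLegalizeUtils.h above.
namespace {
class ConvertAtenAddScalarOp : public OpConversionPattern<AtenAddScalarOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenAddScalarOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto outType =
        getTypeConverter()->convertType(op.getType()).cast<TensorType>();
    // Promote/broadcast the tensor operand to the result type.
    Value lhs = mhlo::promoteAndBroadcast(rewriter, adaptor.self(), outType);
    // Materialize the Torch scalar as a splat mhlo.constant of the same shape.
    Value rhs;
    if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
                                             outType.getElementType(),
                                             outType.getShape())))
      return op.emitError("unsupported scalar operand");
    rewriter.replaceOpWithNewOp<mhlo::AddOp>(op, outType, lhs, rhs);
    return success();
  }
};
} // namespace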
@@ -10,3 +10,695 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
  %0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.addscalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 9
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<9.000000e+00> : tensor<4x64xf32>
// CHECK: %[[VAL_5:.*]] = mhlo.add %[[VAL_1]], %[[VAL_4]] : tensor<4x64xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],f32>
func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
  %int9 = torch.constant.int 9
  %int1 = torch.constant.int 1
  %0 = torch.aten.add.Scalar %arg0, %int9, %int1 : !torch.vtensor<[4,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,64],f32>
  return %0 : !torch.vtensor<[4,64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.addtensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_4:.*]] = mhlo.add %[[VAL_2]], %[[VAL_3]] : tensor<4x64xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32>
func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
  %int1 = torch.constant.int 1
  %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32>
  return %0 : !torch.vtensor<[4,64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.addtensor$promote(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],si32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],si64>) -> !torch.vtensor<[4,64],si64> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],si32> -> tensor<4x64xi32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],si64> -> tensor<4x64xi64>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
// CHECK: %[[VAL_5:.*]] = mhlo.convert(%[[VAL_2]]) : (tensor<4x64xi32>) -> tensor<4x64xi64>
// CHECK: %[[VAL_6:.*]] = mhlo.add %[[VAL_5]], %[[VAL_3]] : tensor<4x64xi64>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x64xi64> -> !torch.vtensor<[4,64],si64>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,64],si64>
func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[4,64],si32>, %arg1: !torch.vtensor<[4,64],si64>) -> !torch.vtensor<[4,64],si64> {
  %int1 = torch.constant.int 1
  %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[4,64],si32>, !torch.vtensor<[4,64],si64>, !torch.int -> !torch.vtensor<[4,64],si64>
  return %0 : !torch.vtensor<[4,64],si64>
}

// -----

// CHECK-LABEL: func.func @torch.aten.addtensor$bcast(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[64],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
// CHECK: %[[VAL_5:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32>
// CHECK: %[[VAL_6:.*]] = mhlo.add %[[VAL_5]], %[[VAL_3]] : tensor<4x64xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,64],f32>
func.func @torch.aten.addtensor$bcast(%arg0: !torch.vtensor<[64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
  %int1 = torch.constant.int 1
  %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32>
  return %0 : !torch.vtensor<[4,64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.addtensor$alpha(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<2.000000e+00> : tensor<4x64xf32>
// CHECK: %[[VAL_6:.*]] = mhlo.multiply %[[VAL_3]], %[[VAL_5]] : tensor<4x64xf32>
// CHECK: %[[VAL_7:.*]] = mhlo.add %[[VAL_2]], %[[VAL_6]] : tensor<4x64xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %4 : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[4,64],f32>
func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
  %int2 = torch.constant.int 2
  %0 = torch.aten.add.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32>
  return %0 : !torch.vtensor<[4,64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.mulscalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 9
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<9.000000e+00> : tensor<4x64xf32>
// CHECK: %[[VAL_4:.*]] = mhlo.multiply %[[VAL_1]], %[[VAL_3]] : tensor<4x64xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32>
func.func @torch.aten.mulscalar$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
  %int9 = torch.constant.int 9
  %0 = torch.aten.mul.Scalar %arg0, %int9 : !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32>
  return %0 : !torch.vtensor<[4,64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.multensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_4:.*]] = mhlo.multiply %[[VAL_2]], %[[VAL_3]] : tensor<4x64xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32>
func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
  %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32> -> !torch.vtensor<[4,64],f32>
  return %0 : !torch.vtensor<[4,64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.multensor$bcast(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[8,4,64],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[8,1,64],f32>) -> !torch.vtensor<[8,4,64],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[8,4,64],f32> -> tensor<8x4x64xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[8,1,64],f32> -> tensor<8x1x64xf32>
// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x64xf32>) -> tensor<8x4x64xf32>
// CHECK: %[[VAL_5:.*]] = mhlo.multiply %[[VAL_2]], %[[VAL_4]] : tensor<8x4x64xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<8x4x64xf32> -> !torch.vtensor<[8,4,64],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[8,4,64],f32>
func.func @torch.aten.multensor$bcast(%arg0: !torch.vtensor<[8,4,64],f32>, %arg1: !torch.vtensor<[8,1,64],f32>) -> !torch.vtensor<[8,4,64],f32> {
  %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[8,4,64],f32>, !torch.vtensor<[8,1,64],f32> -> !torch.vtensor<[8,4,64],f32>
  return %0 : !torch.vtensor<[8,4,64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.subscalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 9
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<9.000000e+00> : tensor<4x64xf32>
// CHECK: %[[VAL_5:.*]] = mhlo.subtract %[[VAL_1]], %[[VAL_4]] : tensor<4x64xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],f32>
func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
  %int9 = torch.constant.int 9
  %int1 = torch.constant.int 1
  %0 = torch.aten.sub.Scalar %arg0, %int9, %int1 : !torch.vtensor<[4,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,64],f32>
  return %0 : !torch.vtensor<[4,64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.subtensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_4:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_3]] : tensor<4x64xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32>
func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
  %int1 = torch.constant.int 1
  %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32>
  return %0 : !torch.vtensor<[4,64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.subtensor$promote(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],si32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],si64>) -> !torch.vtensor<[4,64],si64> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],si32> -> tensor<4x64xi32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],si64> -> tensor<4x64xi64>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
// CHECK: %[[VAL_5:.*]] = mhlo.convert(%[[VAL_2]]) : (tensor<4x64xi32>) -> tensor<4x64xi64>
// CHECK: %[[VAL_6:.*]] = mhlo.subtract %[[VAL_5]], %[[VAL_3]] : tensor<4x64xi64>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x64xi64> -> !torch.vtensor<[4,64],si64>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,64],si64>
func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[4,64],si32>, %arg1: !torch.vtensor<[4,64],si64>) -> !torch.vtensor<[4,64],si64> {
  %int1 = torch.constant.int 1
  %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[4,64],si32>, !torch.vtensor<[4,64],si64>, !torch.int -> !torch.vtensor<[4,64],si64>
  return %0 : !torch.vtensor<[4,64],si64>
}

// -----

// CHECK-LABEL: func.func @torch.aten.subtensor$bcast(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[64],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
// CHECK: %[[VAL_5:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32>
// CHECK: %[[VAL_6:.*]] = mhlo.subtract %[[VAL_5]], %[[VAL_3]] : tensor<4x64xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,64],f32>
func.func @torch.aten.subtensor$bcast(%arg0: !torch.vtensor<[64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
  %int1 = torch.constant.int 1
  %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32>
  return %0 : !torch.vtensor<[4,64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.subtensor$alpha(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<2.000000e+00> : tensor<4x64xf32>
// CHECK: %[[VAL_6:.*]] = mhlo.multiply %[[VAL_3]], %[[VAL_5]] : tensor<4x64xf32>
// CHECK: %[[VAL_7:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_6]] : tensor<4x64xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %4 : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[4,64],f32>
func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
  %int2 = torch.constant.int 2
  %0 = torch.aten.sub.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32>
  return %0 : !torch.vtensor<[4,64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.divscalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 9
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<9.000000e+00> : tensor<4x64xf32>
// CHECK: %[[VAL_4:.*]] = mhlo.divide %[[VAL_1]], %[[VAL_3]] : tensor<4x64xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32>
func.func @torch.aten.divscalar$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
  %int9 = torch.constant.int 9
  %0 = torch.aten.div.Scalar %arg0, %int9 : !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32>
  return %0 : !torch.vtensor<[4,64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.divtensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_4:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_3]] : tensor<4x64xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32>
func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
  %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32> -> !torch.vtensor<[4,64],f32>
  return %0 : !torch.vtensor<[4,64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.divtensor$bcast(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[8,4,64],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[8,1,64],f32>) -> !torch.vtensor<[8,4,64],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[8,4,64],f32> -> tensor<8x4x64xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[8,1,64],f32> -> tensor<8x1x64xf32>
// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x64xf32>) -> tensor<8x4x64xf32>
// CHECK: %[[VAL_5:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_4]] : tensor<8x4x64xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<8x4x64xf32> -> !torch.vtensor<[8,4,64],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[8,4,64],f32>
func.func @torch.aten.divtensor$bcast(%arg0: !torch.vtensor<[8,4,64],f32>, %arg1: !torch.vtensor<[8,1,64],f32>) -> !torch.vtensor<[8,4,64],f32> {
  %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[8,4,64],f32>, !torch.vtensor<[8,1,64],f32> -> !torch.vtensor<[8,4,64],f32>
  return %0 : !torch.vtensor<[8,4,64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.log$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = mhlo.log %[[VAL_1]] : tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
  %0 = torch.aten.log %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.exp$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = mhlo.exponential %[[VAL_1]] : tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
  %0 = torch.aten.exp %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.clone$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = "mhlo.copy"(%[[VAL_1]]) : (tensor<4x64xf32>) -> tensor<4x64xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[4,64],f32>
func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
  %none = torch.constant.none
  %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[4,64],f32>, !torch.none -> !torch.vtensor<[4,64],f32>
  return %0 : !torch.vtensor<[4,64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[VAL_1]] : !torch.vtensor<[],f32>
func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
  %0 = torch.vtensor.literal(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32>
  return %0 : !torch.vtensor<[],f32>
}

// -----

// CHECK-LABEL: func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<2xi64>
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<2xi64> -> !torch.vtensor<[2],si64>
// CHECK: return %[[VAL_1]] : !torch.vtensor<[2],si64>
func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
  %0 = torch.vtensor.literal(dense<1> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
  return %0 : !torch.vtensor<[2],si64>
}

// -----

// CHECK-LABEL: func.func @torch.aten.gt.scalar(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],i1> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 3
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<3.000000e+00> : tensor<f32>
// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<4x64xf32>
// CHECK: %[[VAL_5:.*]] = "mhlo.compare"(%[[VAL_1]], %[[VAL_4]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<4x64xf32>, tensor<4x64xf32>) -> tensor<4x64xi1>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],i1>
func.func @torch.aten.gt.scalar(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],i1> {
  %int3 = torch.constant.int 3
  %0 = torch.aten.gt.Scalar %arg0, %int3 : !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],i1>
  return %0 : !torch.vtensor<[4,64],i1>
}

// -----

// CHECK-LABEL: func.func @torch.aten.gt.tensor(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32>
// CHECK: %[[VAL_5:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_4]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<4x64xf32>, tensor<4x64xf32>) -> tensor<4x64xi1>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],i1>
func.func @torch.aten.gt.tensor(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> {
  %0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[4,64],i1>
  return %0 : !torch.vtensor<[4,64],i1>
}

// -----

// CHECK-LABEL: func.func @torch.aten.gt.tensor$convert(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],si32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],si32> -> tensor<4x64xi32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
// CHECK: %[[VAL_4:.*]] = mhlo.convert(%[[VAL_3]]) : (tensor<64xf32>) -> tensor<64xi32>
// CHECK: %[[VAL_5:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_4]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xi32>) -> tensor<4x64xi32>
// CHECK: %[[VAL_6:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_5]]) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<4x64xi32>, tensor<4x64xi32>) -> tensor<4x64xi1>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,64],i1>

func.func @torch.aten.gt.tensor$convert(%arg0: !torch.vtensor<[4,64],si32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> {
  %0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],si32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[4,64],i1>
  return %0 : !torch.vtensor<[4,64],i1>
}

// -----

// CHECK-LABEL: func.func @torch.aten.lt.tensor(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32>
// CHECK: %[[VAL_5:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_4]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<4x64xf32>, tensor<4x64xf32>) -> tensor<4x64xi1>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],i1>
func.func @torch.aten.lt.tensor(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> {
  %0 = torch.aten.lt.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[4,64],i1>
  return %0 : !torch.vtensor<[4,64],i1>
}

// -----

// CHECK-LABEL: func.func @torch.aten.eq.tensor(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32>
// CHECK: %[[VAL_5:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_4]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<4x64xf32>, tensor<4x64xf32>) -> tensor<4x64xi1>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],i1>
func.func @torch.aten.eq.tensor(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> {
  %0 = torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[4,64],i1>
  return %0 : !torch.vtensor<[4,64],i1>
}

// -----

// CHECK-LABEL: func.func @torch.aten.ne.tensor(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32>
// CHECK: %[[VAL_5:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_4]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction NE>} : (tensor<4x64xf32>, tensor<4x64xf32>) -> tensor<4x64xi1>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],i1>
func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> {
  %0 = torch.aten.ne.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[4,64],i1>
  return %0 : !torch.vtensor<[4,64],i1>
}

// -----

// CHECK-LABEL: func.func @torch.aten.batch_norm(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> {
// CEHCK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,5,5],f32> -> tensor<2x3x5x5xf32>
// CEHCK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
// CEHCK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
// CEHCK: %true = torch.constant.bool true
// CEHCK: %[[VAL_4:.*]] = mhlo.constant dense<0> : tensor<i64>
// CEHCK: %float1.000000e-01 = torch.constant.float 1.000000e-01
// CEHCK: %float1.000000e-05 = torch.constant.float 1.000000e-05
// CEHCK: %int1 = torch.constant.int 1
// CEHCK: %[[VAL_5:.*]] = mhlo.constant dense<1> : tensor<i64>
// CEHCK: %[[VAL_6:.*]] = mhlo.add %[[VAL_4]], %[[VAL_5]] : tensor<i64>
// CEHCK: %[[VAL_7:.*]], %batch_mean, %batch_var = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>)
// CEHCK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<2x3x5x5xf32> -> !torch.vtensor<[2,3,5,5],f32>
// CEHCK: return %[[VAL_8]] : !torch.vtensor<[2,3,5,5],f32>

func.func @torch.aten.batch_norm(%arg0: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> {
  %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
  %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
  %true = torch.constant.bool true
  %2 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
  %float1.000000e-01 = torch.constant.float 1.000000e-01
  %float1.000000e-05 = torch.constant.float 1.000000e-05
  %int1 = torch.constant.int 1
  %3 = torch.aten.add.Scalar %2, %int1, %int1 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
  %4 = torch.aten.batch_norm %arg0, %1, %0, %0, %1, %true, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[2,3,5,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,5,5],f32>
  return %4 : !torch.vtensor<[2,3,5,5],f32>
}

// CHECK-LABEL: func.func @torch.aten.batch_norm$none_bias_weight(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> {
// CEHCK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,5,5],f32> -> tensor<2x3x5x5xf32>
// CEHCK: %none = torch.constant.none
// CEHCK: %1 = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
// CEHCK: %2 = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
// CEHCK: %true = torch.constant.bool true
// CEHCK: %[[VAL_2:.*]] = mhlo.constant dense<0> : tensor<i64>
// CEHCK: %float1.000000e-01 = torch.constant.float 1.000000e-01
// CEHCK: %float1.000000e-05 = torch.constant.float 1.000000e-05
// CEHCK: %int1 = torch.constant.int 1
// CEHCK: %[[VAL_3:.*]] = mhlo.constant dense<1> : tensor<i64>
// CEHCK: %[[VAL_4:.*]] = mhlo.add %[[VAL_2]], %[[VAL_3]] : tensor<i64>
// CEHCK: %[[VAL_5:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
// CEHCK: %[[VAL_6:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
// CEHCK: %[[VAL_7:.*]], %batch_mean, %batch_var = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_5]], %[[VAL_6]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>)
// CEHCK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<2x3x5x5xf32> -> !torch.vtensor<[2,3,5,5],f32>
// CEHCK: return %[[VAL_8]] : !torch.vtensor<[2,3,5,5],f32>
func.func @torch.aten.batch_norm$none_bias_weight(%arg0: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> {
  %none = torch.constant.none
  %0 = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
  %1 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
  %true = torch.constant.bool true
  %2 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
  %float1.000000e-01 = torch.constant.float 1.000000e-01
  %float1.000000e-05 = torch.constant.float 1.000000e-05
  %int1 = torch.constant.int 1
  %3 = torch.aten.add.Scalar %2, %int1, %int1 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
  %4 = torch.aten.batch_norm %arg0, %none, %none, %1, %0, %true, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[2,3,5,5],f32>, !torch.none, !torch.none, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,5,5],f32>
  return %4 : !torch.vtensor<[2,3,5,5],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.batch_norm$inference(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,5,5],f32> -> tensor<2x3x5x5xf32>
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK: %true = torch.constant.bool true
// CHECK: %false = torch.constant.bool false
// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<0> : tensor<i64>
// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<1> : tensor<i64>
// CHECK: %[[VAL_6:.*]] = mhlo.add %[[VAL_4]], %[[VAL_5]] : tensor<i64>
// CHECK: %[[VAL_7:.*]] = "mhlo.batch_norm_inference"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]], %[[VAL_2]], %[[VAL_3]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<2x3x5x5xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<2x3x5x5xf32> -> !torch.vtensor<[2,3,5,5],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[2,3,5,5],f32>
func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> {
%0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
%1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
%true = torch.constant.bool true
%false = torch.constant.bool false
%2 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%float1.000000e-01 = torch.constant.float 1.000000e-01
%float1.000000e-05 = torch.constant.float 1.000000e-05
%int1 = torch.constant.int 1
%3 = torch.aten.add.Scalar %2, %int1, %int1 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
%4 = torch.aten.batch_norm %arg0, %1, %0, %0, %1, %false, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[2,3,5,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,5,5],f32>
return %4 : !torch.vtensor<[2,3,5,5],f32>
}

// -----
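// Tests lowering of torch.aten.relu on a float tensor to mhlo.maximum against a zero constant.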
// CHECK-LABEL: func.func @torch.aten.relu(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,5],f32>) -> !torch.vtensor<[2,5],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,5],f32> -> tensor<2x5xf32>
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<2x5xf32>
// CHECK: %[[VAL_3:.*]] = mhlo.maximum %[[VAL_1]], %[[VAL_2]] : tensor<2x5xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<2x5xf32> -> !torch.vtensor<[2,5],f32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[2,5],f32>
func.func @torch.aten.relu(%arg0: !torch.vtensor<[2,5],f32>) -> !torch.vtensor<[2,5],f32> {
%0 = torch.aten.relu %arg0 : !torch.vtensor<[2,5],f32> -> !torch.vtensor<[2,5],f32>
return %0 : !torch.vtensor<[2,5],f32>
}

// -----
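// Same relu lowering, but for a signed 8-bit integer tensor; the zero splat is integer-typed.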
// CHECK-LABEL: func.func @torch.aten.relu$int8(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,5],si8>) -> !torch.vtensor<[2,5],si8> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,5],si8> -> tensor<2x5xi8>
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0> : tensor<2x5xi8>
// CHECK: %[[VAL_3:.*]] = mhlo.maximum %[[VAL_1]], %[[VAL_2]] : tensor<2x5xi8>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<2x5xi8> -> !torch.vtensor<[2,5],si8>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[2,5],si8>
func.func @torch.aten.relu$int8(%arg0: !torch.vtensor<[2,5],si8>) -> !torch.vtensor<[2,5],si8> {
%0 = torch.aten.relu %arg0 : !torch.vtensor<[2,5],si8> -> !torch.vtensor<[2,5],si8>
return %0 : !torch.vtensor<[2,5],si8>
}

// -----
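// Tests lowering of torch.aten.reciprocal to mhlo.divide of a broadcast constant one by the input.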
// CHECK-LABEL: func.func @torch.aten.reciprocal(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5,5],f32>) -> !torch.vtensor<[5,5,5],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5,5],f32> -> tensor<5x5x5xf32>
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[VAL_3:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x5x5xf32>
// CHECK: %[[VAL_4:.*]] = mhlo.divide %[[VAL_3]], %[[VAL_1]] : tensor<5x5x5xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<5x5x5xf32> -> !torch.vtensor<[5,5,5],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[5,5,5],f32>
func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[5,5,5],f32>) -> !torch.vtensor<[5,5,5],f32> {
%0 = torch.aten.reciprocal %arg0 : !torch.vtensor<[5,5,5],f32> -> !torch.vtensor<[5,5,5],f32>
return %0 : !torch.vtensor<[5,5,5],f32>
}

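// Tests lowering of torch.aten.native_layer_norm: the input is reshaped to a
// batch-norm-friendly layout, normalized with mhlo.batch_norm_training,
// reshaped back, then scaled and shifted by the affine parameters.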
// CHECK-LABEL: func.func @torch.aten.native_layer_norm(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,7,4,5],f32> -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<4x5xf32>
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<4x5xf32>
// CHECK: %int4 = torch.constant.int 4
// CHECK: %int5 = torch.constant.int 5
// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
// CHECK: %true = torch.constant.bool true
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<[1, 21, 20]> : tensor<3xi64>
// CHECK: %[[VAL_6:.*]] = "mhlo.dynamic_reshape"(%[[VAL_1]], %[[VAL_5]]) : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<21xf32>
// CHECK: %[[VAL_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<21xf32>
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "mhlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
// CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
// CHECK: %[[VAL_13:.*]] = "mhlo.dynamic_reshape"(%[[VAL_9]], %[[VAL_12]]) : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
// CHECK: %[[VAL_15:.*]] = "mhlo.dynamic_reshape"(%[[VAL_10]], %[[VAL_14]]) : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
// CHECK: %[[VAL_16:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
// CHECK: %[[VAL_17:.*]] = "mhlo.dynamic_reshape"(%[[VAL_11]], %[[VAL_16]]) : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
// CHECK: %[[VAL_18:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_19:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_20:.*]] = mhlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32>
// CHECK: %[[VAL_21:.*]] = mhlo.add %[[VAL_20]], %[[VAL_19]] : tensor<3x7x4x5xf32>
// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<3x7x4x5xf32> -> !torch.vtensor<[3,7,4,5],f32>
// CHECK: return %[[VAL_22]] : !torch.vtensor<[3,7,4,5],f32>
func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> {
%0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<4x5xf32>) : !torch.vtensor<[4,5],f32>
%1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<4x5xf32>) : !torch.vtensor<[4,5],f32>
%int4 = torch.constant.int 4
%int5 = torch.constant.int 5
%float1.000000e-05 = torch.constant.float 1.000000e-05
%true = torch.constant.bool true
%2 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
%result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %2, %1, %0, %float1.000000e-05 : !torch.vtensor<[3,7,4,5],f32>, !torch.list<int>, !torch.vtensor<[4,5],f32>, !torch.vtensor<[4,5],f32>, !torch.float -> !torch.vtensor<[3,7,4,5],f32>, !torch.vtensor<[3,7,1,1],f32>, !torch.vtensor<[3,7,1,1],f32>
return %result0 : !torch.vtensor<[3,7,4,5],f32>
}

// -----
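// Tests that torch.aten.contiguous lowers to a no-op: the converted IR simply round-trips the input tensor.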
// CHECK-LABEL: func.func @torch.aten.contiguous(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %int0 = torch.constant.int 0
// CHECK: %[[VAL_2:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32>
// CHECK: return %[[VAL_2]] : !torch.vtensor<[4,64],f32>
func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> {
%int0 = torch.constant.int 0
%0 = torch.aten.contiguous %arg0, %int0 : !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32>
return %0 : !torch.vtensor<[4,64],f32>
}

// -----
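// Tests lowering of torch.prim.NumToTensor.Scalar with a constant int to a scalar mhlo.constant.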
// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[],si64> {
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<i64>
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[VAL_1]] : !torch.vtensor<[],si64>
func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> {
%int1 = torch.constant.int 1
%0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[], si64>
return %0 : !torch.vtensor<[], si64>
}

// -----
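// Tests lowering of torch.aten.broadcast_to with a static target shape to mhlo.broadcast_in_dim.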
// CHECK-LABEL: func.func @torch.aten.broadcast_to$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[8,4,64],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 64
// CHECK: %[[VAL_3:.*]] = torch.constant.int 4
// CHECK: %[[VAL_4:.*]] = torch.constant.int 8
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_1]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<8x4x64xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<8x4x64xf32> -> !torch.vtensor<[8,4,64],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[8,4,64],f32>
func.func @torch.aten.broadcast_to$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[8,4,64],f32> {
%int64 = torch.constant.int 64
%int4 = torch.constant.int 4
%int8 = torch.constant.int 8
%0 = torch.prim.ListConstruct %int8, %int4, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.broadcast_to %arg0, %0 : !torch.vtensor<[4,64],f32>, !torch.list<int> -> !torch.vtensor<[8,4,64],f32>
return %1 : !torch.vtensor<[8,4,64],f32>
}

// -----
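// Tests lowering of torch.aten.permute with permutation [1, 0] to mhlo.transpose.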
// CHECK-LABEL: func.func @torch.aten.permute$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[64,4],f32>
func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[4,64],f32>, !torch.list<int> -> !torch.vtensor<[64,4],f32>
return %1 : !torch.vtensor<[64,4],f32>
}

// -----
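// Tests lowering of torch.aten.transpose.int over dims 0 and 1 to mhlo.transpose.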
// CHECK-LABEL: func.func @torch.aten.transpose$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x3xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.aten.transpose.int %arg0, %int0, %int1 : !torch.vtensor<[4,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32>
}