mirror of https://github.com/llvm/torch-mlir
[NFC] Change to *cast instead of .*cast variants (#3405)
Member casts have been deprecated. Changing over a bunch of the member cast calls to the global templated variants to remove deprecation warnings.
parent 4e05e2cd1e
commit afca88a058
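The rewrite is mechanical: each deprecated member-style cast (x.isa<T>(), x.cast<T>(), x.dyn_cast<T>()) becomes the corresponding global templated form (isa<T>(x), cast<T>(x), dyn_cast<T>(x)), and the CMake block that silenced the deprecation warnings is dropped. A minimal sketch of the pattern follows; rankOrZero is a hypothetical helper written only for illustration and assumes an MLIR build recent enough to provide the free-function casts.

#include <cstdint>
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

// Returns the rank of a ranked tensor value, or 0 for anything else.
static int64_t rankOrZero(mlir::Value v) {
  // Before (deprecated member cast, now emits a deprecation warning):
  //   if (auto rtt = v.getType().dyn_cast<mlir::RankedTensorType>())
  //     return rtt.getRank();
  // After (global templated cast, as used throughout the diff below):
  if (auto rtt = mlir::dyn_cast<mlir::RankedTensorType>(v.getType()))
    return rtt.getRank();
  return 0;
}

The same substitution is applied inside TableGen predicate and builder strings (for example, $_self.cast<T>() becomes cast<T>($_self)), which is why the .td files change alongside the C++ sources.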
@ -54,13 +54,6 @@ cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLI
|
||||||
|
|
||||||
option(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER "Enables the ONNX C importer" OFF)
|
option(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER "Enables the ONNX C importer" OFF)
|
||||||
|
|
||||||
# TODO(#3299): migrate to from member x.cast<T>() to mlir::cast<T>(x).
|
|
||||||
if(MSVC)
|
|
||||||
add_compile_options(/wd4996)
|
|
||||||
else()
|
|
||||||
add_compile_options(-Wno-deprecated-declarations)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
macro(torch_mlir_enable_werror)
|
macro(torch_mlir_enable_werror)
|
||||||
if(TORCH_MLIR_ENABLE_WERROR_FLAG)
|
if(TORCH_MLIR_ENABLE_WERROR_FLAG)
|
||||||
if(NOT MSVC)
|
if(NOT MSVC)
|
||||||
|
|
|
@ -125,7 +125,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
||||||
llvm::copy_if(getInputOperands(),
|
llvm::copy_if(getInputOperands(),
|
||||||
std::back_inserter(result),
|
std::back_inserter(result),
|
||||||
[](OpOperand *opOperand) {
|
[](OpOperand *opOperand) {
|
||||||
return opOperand->get().getType().template isa<MemRefType>();
|
return isa<MemRefType>(opOperand->get().getType());
|
||||||
});
|
});
|
||||||
return result;
|
return result;
|
||||||
}]
|
}]
|
||||||
|
@ -144,7 +144,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
||||||
llvm::copy_if(getInputOperands(),
|
llvm::copy_if(getInputOperands(),
|
||||||
std::back_inserter(result),
|
std::back_inserter(result),
|
||||||
[](OpOperand *opOperand) {
|
[](OpOperand *opOperand) {
|
||||||
return opOperand->get().getType().template isa<RankedTensorType>();
|
return isa<RankedTensorType>(opOperand->get().getType());
|
||||||
});
|
});
|
||||||
return result;
|
return result;
|
||||||
}]
|
}]
|
||||||
|
@ -200,7 +200,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
||||||
llvm::copy_if(getOutputOperands(),
|
llvm::copy_if(getOutputOperands(),
|
||||||
std::back_inserter(result),
|
std::back_inserter(result),
|
||||||
[](OpOperand *opOperand) {
|
[](OpOperand *opOperand) {
|
||||||
return opOperand->get().getType().template isa<MemRefType>();
|
return isa<MemRefType>(opOperand->get().getType());
|
||||||
});
|
});
|
||||||
return result;
|
return result;
|
||||||
}]
|
}]
|
||||||
|
@ -219,7 +219,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
||||||
llvm::copy_if(getOutputOperands(),
|
llvm::copy_if(getOutputOperands(),
|
||||||
std::back_inserter(result),
|
std::back_inserter(result),
|
||||||
[](OpOperand *opOperand) {
|
[](OpOperand *opOperand) {
|
||||||
return opOperand->get().getType().template isa<RankedTensorType>();
|
return isa<RankedTensorType>(opOperand->get().getType());
|
||||||
});
|
});
|
||||||
return result;
|
return result;
|
||||||
}]
|
}]
|
||||||
|
@ -238,7 +238,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
||||||
llvm::transform(getOutputBufferOperands(),
|
llvm::transform(getOutputBufferOperands(),
|
||||||
std::back_inserter(result),
|
std::back_inserter(result),
|
||||||
[](OpOperand *opOperands) {
|
[](OpOperand *opOperands) {
|
||||||
return opOperands->get().getType().cast<MemRefType>();
|
return cast<MemRefType>(opOperands->get().getType());
|
||||||
});
|
});
|
||||||
return result;
|
return result;
|
||||||
}]
|
}]
|
||||||
|
@ -257,7 +257,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
||||||
llvm::transform(getOutputTensorOperands(),
|
llvm::transform(getOutputTensorOperands(),
|
||||||
std::back_inserter(result),
|
std::back_inserter(result),
|
||||||
[](OpOperand *opOperands) {
|
[](OpOperand *opOperands) {
|
||||||
return opOperands->get().getType().cast<RankedTensorType>();
|
return cast<RankedTensorType>(opOperands->get().getType());
|
||||||
});
|
});
|
||||||
return result;
|
return result;
|
||||||
}]
|
}]
|
||||||
|
@ -318,7 +318,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
||||||
/*args=*/(ins "OpOperand *":$opOperand),
|
/*args=*/(ins "OpOperand *":$opOperand),
|
||||||
/*methodBody=*/"",
|
/*methodBody=*/"",
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
if (!opOperand->get().getType().template isa<RankedTensorType>())
|
if (!isa<RankedTensorType>(opOperand->get().getType()))
|
||||||
return false;
|
return false;
|
||||||
if (opOperand->getOperandNumber() < $_op.getNumInputs())
|
if (opOperand->getOperandNumber() < $_op.getNumInputs())
|
||||||
return true;
|
return true;
|
||||||
|
@ -334,7 +334,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
||||||
/*args=*/(ins "OpOperand *":$opOperand),
|
/*args=*/(ins "OpOperand *":$opOperand),
|
||||||
/*methodBody=*/"",
|
/*methodBody=*/"",
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
if (!opOperand->get().getType().template isa<RankedTensorType>())
|
if (!isa<RankedTensorType>(opOperand->get().getType()))
|
||||||
return false;
|
return false;
|
||||||
if (opOperand->getOperandNumber() >= $_op.getNumInputs())
|
if (opOperand->getOperandNumber() >= $_op.getNumInputs())
|
||||||
return true;
|
return true;
|
||||||
|
@ -367,7 +367,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
assert(opOperand->getOwner() == this->getOperation());
|
assert(opOperand->getOwner() == this->getOperation());
|
||||||
if (auto shapedType =
|
if (auto shapedType =
|
||||||
opOperand->get().getType().template dyn_cast<ShapedType>())
|
dyn_cast<ShapedType>(opOperand->get().getType()))
|
||||||
return shapedType.getRank();
|
return shapedType.getRank();
|
||||||
return 0;
|
return 0;
|
||||||
}]
|
}]
|
||||||
|
@ -383,7 +383,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
assert(opOperand->getOwner() == this->getOperation());
|
assert(opOperand->getOwner() == this->getOperation());
|
||||||
if (auto shapedType =
|
if (auto shapedType =
|
||||||
opOperand->get().getType().template dyn_cast<ShapedType>())
|
dyn_cast<ShapedType>(opOperand->get().getType()))
|
||||||
return shapedType.getShape();
|
return shapedType.getShape();
|
||||||
return {};
|
return {};
|
||||||
}]
|
}]
|
||||||
|
@ -398,7 +398,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
||||||
/*methodBody=*/"",
|
/*methodBody=*/"",
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
assert(opOperand->getOwner() == this->getOperation());
|
assert(opOperand->getOwner() == this->getOperation());
|
||||||
return !opOperand->get().getType().template isa<ShapedType>();
|
return !isa<ShapedType>(opOperand->get().getType());
|
||||||
}]
|
}]
|
||||||
>,
|
>,
|
||||||
//===------------------------------------------------------------------===//
|
//===------------------------------------------------------------------===//
|
||||||
|
@ -416,10 +416,10 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
||||||
return this->getOperation()->getNumResults() == 0 &&
|
return this->getOperation()->getNumResults() == 0 &&
|
||||||
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
|
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
|
||||||
return isScalar(opOperand) ||
|
return isScalar(opOperand) ||
|
||||||
opOperand->get().getType().template isa<MemRefType>();
|
isa<MemRefType>(opOperand->get().getType());
|
||||||
}) &&
|
}) &&
|
||||||
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
|
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
|
||||||
return opOperand->get().getType().template isa<MemRefType>();
|
return isa<MemRefType>(opOperand->get().getType());
|
||||||
});
|
});
|
||||||
}]
|
}]
|
||||||
>,
|
>,
|
||||||
|
@ -435,10 +435,10 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
||||||
return
|
return
|
||||||
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
|
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
|
||||||
return isScalar(opOperand) ||
|
return isScalar(opOperand) ||
|
||||||
opOperand->get().getType().template isa<RankedTensorType>();
|
isa<RankedTensorType>(opOperand->get().getType());
|
||||||
}) &&
|
}) &&
|
||||||
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
|
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
|
||||||
return opOperand->get().getType().template isa<RankedTensorType>();
|
return isa<RankedTensorType>(opOperand->get().getType());
|
||||||
});
|
});
|
||||||
}]
|
}]
|
||||||
>,
|
>,
|
||||||
|
@ -478,8 +478,8 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void setOperandSegmentAt(unsigned idx, unsigned val) {
|
void setOperandSegmentAt(unsigned idx, unsigned val) {
|
||||||
auto attr = (*this)->getAttr("operand_segment_sizes")
|
auto attr = cast<DenseIntElementsAttr>((*this)->getAttr("operand_segment_sizes")
|
||||||
.cast<DenseIntElementsAttr>();
|
);
|
||||||
unsigned i = 0;
|
unsigned i = 0;
|
||||||
auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32),
|
auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32),
|
||||||
[&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
|
[&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
|
||||||
|
|
|
@ -88,7 +88,7 @@ def TMTensor_ScanOp : TMTensor_Op<"scan",
|
||||||
return getOutputOperand(0)->get();
|
return getOutputOperand(0)->get();
|
||||||
}
|
}
|
||||||
ShapedType getOperandType() {
|
ShapedType getOperandType() {
|
||||||
return input().getType().cast<ShapedType>();
|
return cast<ShapedType>(input().getType());
|
||||||
}
|
}
|
||||||
int64_t getOperandRank() {
|
int64_t getOperandRank() {
|
||||||
return getOperandType().getRank();
|
return getOperandType().getRank();
|
||||||
|
@ -151,10 +151,10 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
|
||||||
let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
|
let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
|
||||||
|
|
||||||
int64_t getIndexDepth() {
|
int64_t getIndexDepth() {
|
||||||
return getInputOperand(1)
|
return cast<ShapedType>(getInputOperand(1)
|
||||||
->get()
|
->get()
|
||||||
.getType()
|
.getType()
|
||||||
.cast<ShapedType>()
|
)
|
||||||
.getShape()
|
.getShape()
|
||||||
.back();
|
.back();
|
||||||
}
|
}
|
||||||
|
@ -164,7 +164,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
|
||||||
}
|
}
|
||||||
|
|
||||||
ShapedType getUpdateType() {
|
ShapedType getUpdateType() {
|
||||||
return updates().getType().cast<ShapedType>();
|
return cast<ShapedType>(updates().getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
Value indices() {
|
Value indices() {
|
||||||
|
@ -172,7 +172,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
|
||||||
}
|
}
|
||||||
|
|
||||||
ShapedType getIndicesType() {
|
ShapedType getIndicesType() {
|
||||||
return indices().getType().cast<ShapedType>();
|
return cast<ShapedType>(indices().getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
Value original() {
|
Value original() {
|
||||||
|
@ -180,11 +180,11 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
|
||||||
}
|
}
|
||||||
|
|
||||||
ShapedType getOriginalType() {
|
ShapedType getOriginalType() {
|
||||||
return original().getType().cast<ShapedType>();
|
return cast<ShapedType>(original().getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t getUpdateSliceRank() {
|
int64_t getUpdateSliceRank() {
|
||||||
return updates().getType().cast<ShapedType>().getRank() - 1;
|
return cast<ShapedType>(updates().getType()).getRank() - 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isScalarUpdate() {
|
bool isScalarUpdate() {
|
||||||
|
@ -224,7 +224,7 @@ def TMTensor_SortOp : TMTensor_Op<"sort",
|
||||||
return getOutputs()[index];
|
return getOutputs()[index];
|
||||||
}
|
}
|
||||||
ShapedType getOperandType(int index) {
|
ShapedType getOperandType(int index) {
|
||||||
return operand(index).getType().cast<ShapedType>();
|
return cast<ShapedType>(operand(index).getType());
|
||||||
}
|
}
|
||||||
int64_t getOperandRank() {
|
int64_t getOperandRank() {
|
||||||
return getOperandType(0).getRank();
|
return getOperandType(0).getRank();
|
||||||
|
@ -291,16 +291,16 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
|
||||||
return getOutputOperand(0)->get();
|
return getOutputOperand(0)->get();
|
||||||
}
|
}
|
||||||
ShapedType getQueryType() {
|
ShapedType getQueryType() {
|
||||||
return getQuery().getType().cast<ShapedType>();
|
return cast<ShapedType>(getQuery().getType());
|
||||||
}
|
}
|
||||||
ShapedType getKeyType() {
|
ShapedType getKeyType() {
|
||||||
return getKey().getType().cast<ShapedType>();
|
return cast<ShapedType>(getKey().getType());
|
||||||
}
|
}
|
||||||
ShapedType getValueType() {
|
ShapedType getValueType() {
|
||||||
return getValue().getType().cast<ShapedType>();
|
return cast<ShapedType>(getValue().getType());
|
||||||
}
|
}
|
||||||
ShapedType getOutputType() {
|
ShapedType getOutputType() {
|
||||||
return getOutput().getType().cast<ShapedType>();
|
return cast<ShapedType>(getOutput().getType());
|
||||||
}
|
}
|
||||||
int64_t getQueryRank() {
|
int64_t getQueryRank() {
|
||||||
return getQueryType().getRank();
|
return getQueryType().getRank();
|
||||||
|
|
|
@ -61,12 +61,12 @@ struct onnx_list_of_constant_ints_op_binder {
|
||||||
|
|
||||||
bool match(Operation *op) {
|
bool match(Operation *op) {
|
||||||
auto constOp = dyn_cast<Torch::OperatorOp>(op);
|
auto constOp = dyn_cast<Torch::OperatorOp>(op);
|
||||||
if (!constOp || !constOp.getName().equals("onnx.Constant"))
|
if (!constOp || !(constOp.getName() == "onnx.Constant"))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
if (DenseResourceElementsAttr attr =
|
if (DenseResourceElementsAttr attr =
|
||||||
constOp->getAttr("torch.onnx.value")
|
dyn_cast_or_null<DenseResourceElementsAttr>(
|
||||||
.dyn_cast_or_null<DenseResourceElementsAttr>()) {
|
constOp->getAttr("torch.onnx.value"))) {
|
||||||
// Bytes are stored in little endian order. Big endian support will
|
// Bytes are stored in little endian order. Big endian support will
|
||||||
// require swizzling.
|
// require swizzling.
|
||||||
if (!Endian::little) {
|
if (!Endian::little) {
|
||||||
|
|
|
@ -190,7 +190,7 @@ struct torch_list_of_optional_constant_ints_op_binder {
|
||||||
int64_t num;
|
int64_t num;
|
||||||
if (matchPattern(value, m_TorchConstantInt(&num)))
|
if (matchPattern(value, m_TorchConstantInt(&num)))
|
||||||
bind_values.push_back(num);
|
bind_values.push_back(num);
|
||||||
else if (value.getType().isa<Torch::NoneType>())
|
else if (isa<Torch::NoneType>(value.getType()))
|
||||||
bind_values.push_back(std::nullopt);
|
bind_values.push_back(std::nullopt);
|
||||||
else
|
else
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -442,8 +442,8 @@ def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
Type getKeyType() { return getType().cast<DictType>().getKeyType(); }
|
Type getKeyType() { return cast<DictType>(getType()).getKeyType(); }
|
||||||
Type getValueType() { return getType().cast<DictType>().getValueType(); }
|
Type getValueType() { return cast<DictType>(getType()).getValueType(); }
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1003,7 +1003,7 @@ def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [
|
||||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||||
TypesMatchWith<"operand is corresponding !torch.vtensor",
|
TypesMatchWith<"operand is corresponding !torch.vtensor",
|
||||||
"result", "operand",
|
"result", "operand",
|
||||||
"$_self.cast<NonValueTensorType>().getWithValueSemantics()">,
|
"cast<NonValueTensorType>($_self).getWithValueSemantics()">,
|
||||||
]> {
|
]> {
|
||||||
let summary = "Create a !torch.tensor with the same contents as the operand";
|
let summary = "Create a !torch.tensor with the same contents as the operand";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -1036,7 +1036,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
|
||||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||||
TypesMatchWith<"operand is corresponding !torch.tensor",
|
TypesMatchWith<"operand is corresponding !torch.tensor",
|
||||||
"result", "operand",
|
"result", "operand",
|
||||||
"$_self.cast<ValueTensorType>().getWithoutValueSemantics()">,
|
"cast<ValueTensorType>($_self).getWithoutValueSemantics()">,
|
||||||
]> {
|
]> {
|
||||||
let summary = "Create a !torch.vtensor with the same contents as the operand";
|
let summary = "Create a !torch.vtensor with the same contents as the operand";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -1064,7 +1064,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
|
||||||
def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
|
def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
|
||||||
TypesMatchWith<"overwritten tensor type is corresponding !torch.tensor of value tensor type",
|
TypesMatchWith<"overwritten tensor type is corresponding !torch.tensor of value tensor type",
|
||||||
"value", "overwritten",
|
"value", "overwritten",
|
||||||
"$_self.cast<ValueTensorType>().getWithoutValueSemantics()">
|
"cast<ValueTensorType>($_self).getWithoutValueSemantics()">
|
||||||
]> {
|
]> {
|
||||||
let summary = "Ovewrite the contents of tensor with values from another.";
|
let summary = "Ovewrite the contents of tensor with values from another.";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -199,7 +199,7 @@ def Torch_ValueTensorType : AnyTorchTensorType<"ValueTensor", "vtensor"> {
|
||||||
}
|
}
|
||||||
|
|
||||||
def AnyTorchTensorType : Type<
|
def AnyTorchTensorType : Type<
|
||||||
CPred<"$_self.isa<::mlir::torch::Torch::BaseTensorType>()">,
|
CPred<"isa<::mlir::torch::Torch::BaseTensorType>($_self)">,
|
||||||
"Any Torch tensor type"
|
"Any Torch tensor type"
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
@ -410,11 +410,11 @@ def AnyTorchOptionalDeviceType:
|
||||||
def AnyTorchOptionalGeneratorType:
|
def AnyTorchOptionalGeneratorType:
|
||||||
OptionalOf<Torch_GeneratorType, "Optional torch Generator type">;
|
OptionalOf<Torch_GeneratorType, "Optional torch Generator type">;
|
||||||
|
|
||||||
def IsListTypePred : CPred<"$_self.isa<::mlir::torch::Torch::ListType>()">;
|
def IsListTypePred : CPred<"isa<::mlir::torch::Torch::ListType>($_self)">;
|
||||||
class ListOf<list<Type> allowedTypes, string descr> :
|
class ListOf<list<Type> allowedTypes, string descr> :
|
||||||
ContainerType<AnyTypeOf<allowedTypes>,
|
ContainerType<AnyTypeOf<allowedTypes>,
|
||||||
IsListTypePred,
|
IsListTypePred,
|
||||||
"$_self.cast<::mlir::torch::Torch::ListType>().getContainedType()",
|
"cast<::mlir::torch::Torch::ListType>($_self).getContainedType()",
|
||||||
descr, "::mlir::torch::Torch::ListType">;
|
descr, "::mlir::torch::Torch::ListType">;
|
||||||
|
|
||||||
def AnyTorchListOfTorchBoolType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;
|
def AnyTorchListOfTorchBoolType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;
|
||||||
|
|
|
@ -26,7 +26,7 @@ bool torchMlirTypeIsValidSubtype(MlirType subtype, MlirType type) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchNnModule(MlirType t) {
|
bool torchMlirTypeIsATorchNnModule(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::NnModuleType>();
|
return isa<Torch::NnModuleType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchNnModuleTypeGet(MlirContext context,
|
MlirType torchMlirTorchNnModuleTypeGet(MlirContext context,
|
||||||
|
@ -43,7 +43,7 @@ MlirTypeID torchMlirTorchNnModuleTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchOptional(MlirType t) {
|
bool torchMlirTypeIsATorchOptional(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::OptionalType>();
|
return isa<Torch::OptionalType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
|
MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
|
||||||
|
@ -64,7 +64,7 @@ MlirTypeID torchMlirTorchOptionalTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchTuple(MlirType t) {
|
bool torchMlirTypeIsATorchTuple(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::TupleType>();
|
return isa<Torch::TupleType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchTupleTypeGet(MlirContext context,
|
MlirType torchMlirTorchTupleTypeGet(MlirContext context,
|
||||||
|
@ -95,7 +95,7 @@ MlirTypeID torchMlirTorchTupleTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchUnion(MlirType t) {
|
bool torchMlirTypeIsATorchUnion(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::UnionType>();
|
return isa<Torch::UnionType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchUnionTypeGet(MlirContext context,
|
MlirType torchMlirTorchUnionTypeGet(MlirContext context,
|
||||||
|
@ -126,7 +126,7 @@ MlirTypeID torchMlirTorchUnionTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchList(MlirType t) {
|
bool torchMlirTypeIsATorchList(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::ListType>();
|
return isa<Torch::ListType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchListTypeGet(MlirType containedType) {
|
MlirType torchMlirTorchListTypeGet(MlirType containedType) {
|
||||||
|
@ -146,7 +146,7 @@ MlirTypeID torchMlirTorchListTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchDevice(MlirType t) {
|
bool torchMlirTypeIsATorchDevice(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::DeviceType>();
|
return isa<Torch::DeviceType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchDeviceTypeGet(MlirContext context) {
|
MlirType torchMlirTorchDeviceTypeGet(MlirContext context) {
|
||||||
|
@ -162,7 +162,7 @@ MlirTypeID torchMlirTorchDeviceTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchGenerator(MlirType t) {
|
bool torchMlirTypeIsATorchGenerator(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::GeneratorType>();
|
return isa<Torch::GeneratorType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchGeneratorTypeGet(MlirContext context) {
|
MlirType torchMlirTorchGeneratorTypeGet(MlirContext context) {
|
||||||
|
@ -178,7 +178,7 @@ MlirTypeID torchMlirTorchGeneratorTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchBool(MlirType t) {
|
bool torchMlirTypeIsATorchBool(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::BoolType>();
|
return isa<Torch::BoolType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchBoolTypeGet(MlirContext context) {
|
MlirType torchMlirTorchBoolTypeGet(MlirContext context) {
|
||||||
|
@ -194,7 +194,7 @@ MlirTypeID torchMlirTorchBoolTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchInt(MlirType t) {
|
bool torchMlirTypeIsATorchInt(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::IntType>();
|
return isa<Torch::IntType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchIntTypeGet(MlirContext context) {
|
MlirType torchMlirTorchIntTypeGet(MlirContext context) {
|
||||||
|
@ -210,7 +210,7 @@ MlirTypeID torchMlirTorchIntTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchFloat(MlirType t) {
|
bool torchMlirTypeIsATorchFloat(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::FloatType>();
|
return isa<Torch::FloatType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchFloatTypeGet(MlirContext context) {
|
MlirType torchMlirTorchFloatTypeGet(MlirContext context) {
|
||||||
|
@ -226,7 +226,7 @@ MlirTypeID torchMlirTorchFloatTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchLinearParams(MlirType t) {
|
bool torchMlirTypeIsATorchLinearParams(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::LinearParamsType>();
|
return isa<Torch::LinearParamsType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) {
|
MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) {
|
||||||
|
@ -242,7 +242,7 @@ MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchQInt8(MlirType t) {
|
bool torchMlirTypeIsATorchQInt8(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::QInt8Type>();
|
return isa<Torch::QInt8Type>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchQInt8TypeGet(MlirContext context) {
|
MlirType torchMlirTorchQInt8TypeGet(MlirContext context) {
|
||||||
|
@ -258,7 +258,7 @@ MlirTypeID torchMlirTorchQInt8TypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchQUInt8(MlirType t) {
|
bool torchMlirTypeIsATorchQUInt8(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::QUInt8Type>();
|
return isa<Torch::QUInt8Type>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchQUInt8TypeGet(MlirContext context) {
|
MlirType torchMlirTorchQUInt8TypeGet(MlirContext context) {
|
||||||
|
@ -274,7 +274,7 @@ MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchNonValueTensor(MlirType t) {
|
bool torchMlirTypeIsATorchNonValueTensor(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::NonValueTensorType>();
|
return isa<Torch::NonValueTensorType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
|
MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
|
||||||
|
@ -341,7 +341,7 @@ MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchValueTensor(MlirType t) {
|
bool torchMlirTypeIsATorchValueTensor(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::ValueTensorType>();
|
return isa<Torch::ValueTensorType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
|
MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
|
||||||
|
@ -408,7 +408,7 @@ MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchNone(MlirType t) {
|
bool torchMlirTypeIsATorchNone(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::NoneType>();
|
return isa<Torch::NoneType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchNoneTypeGet(MlirContext context) {
|
MlirType torchMlirTorchNoneTypeGet(MlirContext context) {
|
||||||
|
@ -424,7 +424,7 @@ MlirTypeID torchMlirTorchNoneTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchString(MlirType t) {
|
bool torchMlirTypeIsATorchString(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::StringType>();
|
return isa<Torch::StringType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchStringTypeGet(MlirContext context) {
|
MlirType torchMlirTorchStringTypeGet(MlirContext context) {
|
||||||
|
@ -440,7 +440,7 @@ MlirTypeID torchMlirTorchStringTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchAny(MlirType t) {
|
bool torchMlirTypeIsATorchAny(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::AnyType>();
|
return isa<Torch::AnyType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchAnyTypeGet(MlirContext context) {
|
MlirType torchMlirTorchAnyTypeGet(MlirContext context) {
|
||||||
|
@ -456,7 +456,7 @@ MlirTypeID torchMlirTorchAnyTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchNumber(MlirType t) {
|
bool torchMlirTypeIsATorchNumber(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::NumberType>();
|
return isa<Torch::NumberType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchNumberTypeGet(MlirContext context) {
|
MlirType torchMlirTorchNumberTypeGet(MlirContext context) {
|
||||||
|
@ -472,7 +472,7 @@ MlirTypeID torchMlirTorchNumberTypeGetTypeID() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool torchMlirTypeIsATorchDict(MlirType t) {
|
bool torchMlirTypeIsATorchDict(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::DictType>();
|
return isa<Torch::DictType>(unwrap(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) {
|
MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) {
|
||||||
|
|
|
@ -546,12 +546,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
Value shuffledPaddingList =
|
Value shuffledPaddingList =
|
||||||
createConstantIntList(binder, rewriter, padding);
|
createConstantIntList(binder, rewriter, padding);
|
||||||
Value zero;
|
Value zero;
|
||||||
if (resultTypeOut.getDtype().isa<FloatType>()) {
|
if (isa<FloatType>(resultTypeOut.getDtype())) {
|
||||||
zero = rewriter.create<Torch::ConstantFloatOp>(
|
zero = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||||
rewriter.getF64FloatAttr(
|
rewriter.getF64FloatAttr(
|
||||||
std::numeric_limits<double>::lowest()));
|
std::numeric_limits<double>::lowest()));
|
||||||
} else if (resultTypeOut.getDtype().isa<IntegerType>()) {
|
} else if (isa<IntegerType>(resultTypeOut.getDtype())) {
|
||||||
zero = rewriter.create<Torch::ConstantIntOp>(
|
zero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(
|
binder.getLoc(), rewriter.getI64IntegerAttr(
|
||||||
std::numeric_limits<int64_t>::lowest()));
|
std::numeric_limits<int64_t>::lowest()));
|
||||||
|
@ -1295,7 +1295,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
binder.tensorResultType(resultType))
|
binder.tensorResultType(resultType))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto inputTensorType = operand.getType().cast<Torch::ValueTensorType>();
|
auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
|
||||||
if (!inputTensorType || !inputTensorType.hasSizes()) {
|
if (!inputTensorType || !inputTensorType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "Expected input type having sizes");
|
binder.op, "Expected input type having sizes");
|
||||||
|
@ -1509,10 +1509,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
|
|
||||||
if (!constantValue) {
|
if (!constantValue) {
|
||||||
auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
|
auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
|
||||||
if (dataTensorType.getDtype().isa<IntegerType>())
|
if (isa<IntegerType>(dataTensorType.getDtype()))
|
||||||
constantValue = rewriter.create<Torch::ConstantIntOp>(
|
constantValue = rewriter.create<Torch::ConstantIntOp>(
|
||||||
loc, rewriter.getI64IntegerAttr(0));
|
loc, rewriter.getI64IntegerAttr(0));
|
||||||
if (dataTensorType.getDtype().isa<FloatType>())
|
if (isa<FloatType>(dataTensorType.getDtype()))
|
||||||
constantValue = rewriter.create<Torch::ConstantFloatOp>(
|
constantValue = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
loc, rewriter.getF64FloatAttr(0.0f));
|
loc, rewriter.getF64FloatAttr(0.0f));
|
||||||
|
|
||||||
|
|
|
@ -1023,9 +1023,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||||
Value constFalse =
|
Value constFalse =
|
||||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
||||||
auto size = data.getType()
|
auto size =
|
||||||
.dyn_cast<Torch::ValueTensorType>()
|
dyn_cast<Torch::ValueTensorType>(data.getType()).getOptionalSizes();
|
||||||
.getOptionalSizes();
|
|
||||||
auto f64ResultType = rewriter.getType<Torch::ValueTensorType>(
|
auto f64ResultType = rewriter.getType<Torch::ValueTensorType>(
|
||||||
size, rewriter.getF64Type());
|
size, rewriter.getF64Type());
|
||||||
Value dataCast = rewriter.create<Torch::AtenToDtypeOp>(
|
Value dataCast = rewriter.create<Torch::AtenToDtypeOp>(
|
||||||
|
@ -2906,8 +2905,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
scalesValueList = noneVal;
|
scalesValueList = noneVal;
|
||||||
sizesValueList = getValueList(sizeOperand);
|
sizesValueList = getValueList(sizeOperand);
|
||||||
}
|
}
|
||||||
if (scalesValueList.getType().isa<Torch::NoneType>() &&
|
if (isa<Torch::NoneType>(scalesValueList.getType()) &&
|
||||||
sizesValueList.getType().isa<Torch::NoneType>()) {
|
isa<Torch::NoneType>(sizesValueList.getType())) {
|
||||||
return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode");
|
return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode");
|
||||||
}
|
}
|
||||||
rewriter
|
rewriter
|
||||||
|
|
|
@ -1868,9 +1868,8 @@ public:
|
||||||
const TypeConverter *typeConverter = getTypeConverter();
|
const TypeConverter *typeConverter = getTypeConverter();
|
||||||
|
|
||||||
auto input = adaptor.getSelf();
|
auto input = adaptor.getSelf();
|
||||||
RankedTensorType resultType =
|
RankedTensorType resultType = cast<RankedTensorType>(
|
||||||
typeConverter->convertType(op->getResult(0).getType())
|
typeConverter->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
|
|
||||||
SmallVector<Value> resultShape;
|
SmallVector<Value> resultShape;
|
||||||
SmallVector<Value> offsets;
|
SmallVector<Value> offsets;
|
||||||
|
@ -2107,9 +2106,8 @@ public:
|
||||||
|
|
||||||
auto input = adaptor.getSelf();
|
auto input = adaptor.getSelf();
|
||||||
|
|
||||||
RankedTensorType resultType =
|
RankedTensorType resultType = cast<RankedTensorType>(
|
||||||
typeConverter->convertType(op->getResult(0).getType())
|
typeConverter->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
|
|
||||||
SmallVector<Value> resultShape;
|
SmallVector<Value> resultShape;
|
||||||
SmallVector<Value> offsets;
|
SmallVector<Value> offsets;
|
||||||
|
@ -2343,9 +2341,8 @@ public:
|
||||||
op, "diagonal dimensions cannot be identical");
|
op, "diagonal dimensions cannot be identical");
|
||||||
|
|
||||||
Type elementType = inputType.getElementType();
|
Type elementType = inputType.getElementType();
|
||||||
RankedTensorType outputType = getTypeConverter()
|
RankedTensorType outputType = cast<RankedTensorType>(
|
||||||
->convertType(op->getResult(0).getType())
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
Value dim1Size, dim2Size;
|
Value dim1Size, dim2Size;
|
||||||
|
@ -2581,9 +2578,8 @@ public:
|
||||||
})
|
})
|
||||||
.getResult(0);
|
.getResult(0);
|
||||||
|
|
||||||
RankedTensorType resultType = getTypeConverter()
|
RankedTensorType resultType = cast<RankedTensorType>(
|
||||||
->convertType(op->getResult(0).getType())
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, resultTensor);
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, resultTensor);
|
||||||
return success();
|
return success();
|
||||||
|
@ -2608,9 +2604,8 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
// Conversion is completed specified by information in the sparse tensor
|
// Conversion is completed specified by information in the sparse tensor
|
||||||
// type. Thus, we can rewrite all legalizedNames to the same construct.
|
// type. Thus, we can rewrite all legalizedNames to the same construct.
|
||||||
RankedTensorType resultType = getTypeConverter()
|
RankedTensorType resultType = cast<RankedTensorType>(
|
||||||
->convertType(op->getResult(0).getType())
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(
|
rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(
|
||||||
op, resultType, adaptor.getOperands()[0]);
|
op, resultType, adaptor.getOperands()[0]);
|
||||||
return success();
|
return success();
|
||||||
|
|
|
@ -845,7 +845,7 @@ public:
|
||||||
outputSizeIntValues = getTypeConvertedValues(
|
outputSizeIntValues = getTypeConvertedValues(
|
||||||
rewriter, loc, getTypeConverter(), outputSizeTorchInt);
|
rewriter, loc, getTypeConverter(), outputSizeTorchInt);
|
||||||
|
|
||||||
if (!op.getScalesH().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getScalesH().getType())) {
|
||||||
// Convert float values to int values.
|
// Convert float values to int values.
|
||||||
// int_value = (int64_t)ceil(float_value)
|
// int_value = (int64_t)ceil(float_value)
|
||||||
Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.getScalesH());
|
Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.getScalesH());
|
||||||
|
@ -858,7 +858,7 @@ public:
|
||||||
scaleFactorsInt.push_back(scaleFactorVal);
|
scaleFactorsInt.push_back(scaleFactorVal);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!op.getScalesW().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getScalesW().getType())) {
|
||||||
// Convert float values to int values.
|
// Convert float values to int values.
|
||||||
// int_value = (int64_t)ceil(float_value)
|
// int_value = (int64_t)ceil(float_value)
|
||||||
Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.getScalesW());
|
Value ceilVal = rewriter.create<math::CeilOp>(loc, adaptor.getScalesW());
|
||||||
|
@ -1006,7 +1006,7 @@ public:
|
||||||
unsigned hDimOffset = 2;
|
unsigned hDimOffset = 2;
|
||||||
|
|
||||||
SmallVector<Value, 2> scaleFactorsFloatValues;
|
SmallVector<Value, 2> scaleFactorsFloatValues;
|
||||||
if (!op.getScalesH().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getScalesH().getType())) {
|
||||||
scaleFactorsFloatValues.push_back(adaptor.getScalesH());
|
scaleFactorsFloatValues.push_back(adaptor.getScalesH());
|
||||||
} else {
|
} else {
|
||||||
auto scaleFactorVal = rewriter.create<arith::DivFOp>(
|
auto scaleFactorVal = rewriter.create<arith::DivFOp>(
|
||||||
|
@ -1019,7 +1019,7 @@ public:
|
||||||
scaleFactorsFloatValues.push_back(scaleFactorVal);
|
scaleFactorsFloatValues.push_back(scaleFactorVal);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!op.getScalesW().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getScalesW().getType())) {
|
||||||
scaleFactorsFloatValues.push_back(adaptor.getScalesW());
|
scaleFactorsFloatValues.push_back(adaptor.getScalesW());
|
||||||
} else {
|
} else {
|
||||||
auto scaleFactorVal = rewriter.create<arith::DivFOp>(
|
auto scaleFactorVal = rewriter.create<arith::DivFOp>(
|
||||||
|
|
|
@ -41,7 +41,7 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
|
||||||
return;
|
return;
|
||||||
int64_t minSI = -(1 << (numBits - 1));
|
int64_t minSI = -(1 << (numBits - 1));
|
||||||
Value minSIValue = rewriter.create<arith::ConstantIntOp>(
|
Value minSIValue = rewriter.create<arith::ConstantIntOp>(
|
||||||
loc, minSI, zp.getType().cast<mlir::IntegerType>().getWidth());
|
loc, minSI, cast<mlir::IntegerType>(zp.getType()).getWidth());
|
||||||
zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
|
zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
|
||||||
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
|
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
|
||||||
arg = torch_to_linalg::createElementwiseLinalgGeneric(
|
arg = torch_to_linalg::createElementwiseLinalgGeneric(
|
||||||
|
@ -1057,10 +1057,10 @@ public:
|
||||||
loc, getAsOpFoldResult(outDims), accumulatorDType);
|
loc, getAsOpFoldResult(outDims), accumulatorDType);
|
||||||
|
|
||||||
Value outputTensor;
|
Value outputTensor;
|
||||||
if (accumulatorDType != resultDTy && !bias.getType().isa<Torch::NoneType>())
|
if (accumulatorDType != resultDTy && !isa<Torch::NoneType>(bias.getType()))
|
||||||
bias = torch_to_linalg::convertTensorToElementType(rewriter, loc, bias,
|
bias = torch_to_linalg::convertTensorToElementType(rewriter, loc, bias,
|
||||||
accumulatorDType);
|
accumulatorDType);
|
||||||
if (bias.getType().isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(bias.getType())) {
|
||||||
Value c0;
|
Value c0;
|
||||||
if (isa<mlir::FloatType>(accumulatorDType)) {
|
if (isa<mlir::FloatType>(accumulatorDType)) {
|
||||||
c0 = rewriter.create<arith::ConstantOp>(
|
c0 = rewriter.create<arith::ConstantOp>(
|
||||||
|
|
|
@ -409,10 +409,8 @@ public:
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
RankedTensorType selfType = cast<RankedTensorType>(self.getType());
|
RankedTensorType selfType = cast<RankedTensorType>(self.getType());
|
||||||
Type elementType = selfType.getElementType();
|
Type elementType = selfType.getElementType();
|
||||||
RankedTensorType indicesRankedTensorType =
|
RankedTensorType indicesRankedTensorType = cast<RankedTensorType>(
|
||||||
getTypeConverter()
|
getTypeConverter()->convertType(op->getResult(1).getType()));
|
||||||
->convertType(op->getResult(1).getType())
|
|
||||||
.cast<RankedTensorType>();
|
|
||||||
|
|
||||||
// TODO: Add support for 3D inputs.
|
// TODO: Add support for 3D inputs.
|
||||||
if (selfType.getRank() == 3)
|
if (selfType.getRank() == 3)
|
||||||
|
@ -717,10 +715,10 @@ public:
|
||||||
|
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
const TypeConverter *typeConverter = opConversionPattern.getTypeConverter();
|
const TypeConverter *typeConverter = opConversionPattern.getTypeConverter();
|
||||||
outputType = typeConverter->convertType(op.getResult0().getType())
|
outputType = cast<RankedTensorType>(
|
||||||
.template cast<RankedTensorType>();
|
typeConverter->convertType(op.getResult0().getType()));
|
||||||
auxTensorType = typeConverter->convertType(op.getResult1().getType())
|
auxTensorType = cast<RankedTensorType>(
|
||||||
.template cast<RankedTensorType>();
|
typeConverter->convertType(op.getResult1().getType()));
|
||||||
Type auxTensorElementType = auxTensorType.getElementType();
|
Type auxTensorElementType = auxTensorType.getElementType();
|
||||||
auto smallestFPValueAttr = rewriter.getFloatAttr(
|
auto smallestFPValueAttr = rewriter.getFloatAttr(
|
||||||
elementType,
|
elementType,
|
||||||
|
@ -799,8 +797,8 @@ public:
|
||||||
|
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
const TypeConverter *typeConverter = opConversionPattern.getTypeConverter();
|
const TypeConverter *typeConverter = opConversionPattern.getTypeConverter();
|
||||||
outputType = typeConverter->convertType(op.getResult().getType())
|
outputType = cast<RankedTensorType>(
|
||||||
.template cast<RankedTensorType>();
|
typeConverter->convertType(op.getResult().getType()));
|
||||||
buffVal = rewriter.create<arith::ConstantOp>(
|
buffVal = rewriter.create<arith::ConstantOp>(
|
||||||
loc, elementType, rewriter.getFloatAttr(elementType, 0));
|
loc, elementType, rewriter.getFloatAttr(elementType, 0));
|
||||||
auxTensor = rewriter.create<tensor::EmptyOp>(
|
auxTensor = rewriter.create<tensor::EmptyOp>(
|
||||||
|
|
|
@ -42,9 +42,8 @@ public:
|
||||||
|
|
||||||
if (train)
|
if (train)
|
||||||
return failure();
|
return failure();
|
||||||
auto resultType = getTypeConverter()
|
auto resultType = cast<RankedTensorType>(
|
||||||
->convertType(op->getResult(0).getType())
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
|
||||||
adaptor.getInput());
|
adaptor.getInput());
|
||||||
return success();
|
return success();
|
||||||
|
@ -60,8 +59,8 @@ static Value toLinearIndex(OpBuilder &b, Location loc,
|
||||||
Value result =
|
Value result =
|
||||||
b.create<arith::ConstantOp>(loc, b.getZeroAttr(b.getI64Type()));
|
b.create<arith::ConstantOp>(loc, b.getZeroAttr(b.getI64Type()));
|
||||||
for (auto [index, stride] : llvm::zip(indicesIntValues, shapeIntValues)) {
|
for (auto [index, stride] : llvm::zip(indicesIntValues, shapeIntValues)) {
|
||||||
assert(index.getType().isa<mlir::IntegerType>() &&
|
assert(isa<mlir::IntegerType>(index.getType()) &&
|
||||||
stride.getType().isa<mlir::IntegerType>() &&
|
isa<mlir::IntegerType>(stride.getType()) &&
|
||||||
"Input arrays to `toLinearIndex` must only contain values of type "
|
"Input arrays to `toLinearIndex` must only contain values of type "
|
||||||
"`mlir::IntegerType`");
|
"`mlir::IntegerType`");
|
||||||
Value mul = b.create<arith::MulIOp>(loc, result, stride);
|
Value mul = b.create<arith::MulIOp>(loc, result, stride);
|
||||||
|
@ -129,7 +128,7 @@ public:
|
||||||
if (!isa<mlir::FloatType>(elemTy))
|
if (!isa<mlir::FloatType>(elemTy))
|
||||||
return rewriter.notifyMatchFailure(op, "This op only support float type");
|
return rewriter.notifyMatchFailure(op, "This op only support float type");
|
||||||
|
|
||||||
if (!generator.getType().isa<Torch::NoneType>())
|
if (!isa<Torch::NoneType>(generator.getType()))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "The generator has to be None because only global default "
|
op, "The generator has to be None because only global default "
|
||||||
"generator is supported");
|
"generator is supported");
|
||||||
|
@ -180,7 +179,7 @@ public:
|
||||||
b.create<arith::MulFOp>(loc, updateFloat, scale);
|
b.create<arith::MulFOp>(loc, updateFloat, scale);
|
||||||
Value res = b.create<arith::AddFOp>(loc, updateScaled, min);
|
Value res = b.create<arith::AddFOp>(loc, updateScaled, min);
|
||||||
Value truncRes = res;
|
Value truncRes = res;
|
||||||
if (elemTy.isa<Float16Type, Float32Type>())
|
if (isa<Float16Type, Float32Type>(elemTy))
|
||||||
truncRes = b.create<arith::TruncFOp>(loc, elemTy, res);
|
truncRes = b.create<arith::TruncFOp>(loc, elemTy, res);
|
||||||
b.create<linalg::YieldOp>(loc, truncRes);
|
b.create<linalg::YieldOp>(loc, truncRes);
|
||||||
})
|
})
|
||||||
|
|
|
@ -86,11 +86,8 @@ public:
|
||||||
bool isUnsigned = false;
|
bool isUnsigned = false;
|
||||||
if (!isa<mlir::FloatType>(inElementType)) {
|
if (!isa<mlir::FloatType>(inElementType)) {
|
||||||
if (isa<mlir::IntegerType>(inElementType)) {
|
if (isa<mlir::IntegerType>(inElementType)) {
|
||||||
auto integerTy = op.getSelf()
|
auto integerTy = dyn_cast<mlir::IntegerType>(
|
||||||
.getType()
|
cast<BaseTensorType>(op.getSelf().getType()).getDtype());
|
||||||
.template cast<BaseTensorType>()
|
|
||||||
.getDtype()
|
|
||||||
.template dyn_cast<mlir::IntegerType>();
|
|
||||||
isUnsigned = integerTy.isUnsigned();
|
isUnsigned = integerTy.isUnsigned();
|
||||||
} else {
|
} else {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -280,7 +277,7 @@ public:
|
||||||
|
|
||||||
static Value createAbsOpForNormOps(OpBuilder &b, Location loc, Value elem,
|
static Value createAbsOpForNormOps(OpBuilder &b, Location loc, Value elem,
|
||||||
Type resultElementType) {
|
Type resultElementType) {
|
||||||
if (elem.getType().isa<mlir::ComplexType>()) {
|
if (isa<mlir::ComplexType>(elem.getType())) {
|
||||||
return b.create<complex::AbsOp>(loc, elem);
|
return b.create<complex::AbsOp>(loc, elem);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -376,11 +373,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
|
||||||
if (isa<mlir::FloatType>(resultElementType))
|
if (isa<mlir::FloatType>(resultElementType))
|
||||||
return b.create<arith::MaximumFOp>(loc, self, result);
|
return b.create<arith::MaximumFOp>(loc, self, result);
|
||||||
else if (isa<mlir::IntegerType>(resultElementType)) {
|
else if (isa<mlir::IntegerType>(resultElementType)) {
|
||||||
IntegerType intType = max.getSelf()
|
IntegerType intType = dyn_cast<mlir::IntegerType>(
|
||||||
.getType()
|
cast<BaseTensorType>(max.getSelf().getType()).getDtype());
|
||||||
.cast<BaseTensorType>()
|
|
||||||
.getDtype()
|
|
||||||
.dyn_cast<mlir::IntegerType>();
|
|
||||||
if (intType.isUnsigned())
|
if (intType.isUnsigned())
|
||||||
return b.create<arith::MaxUIOp>(loc, self, result);
|
return b.create<arith::MaxUIOp>(loc, self, result);
|
||||||
if (intType.isSigned())
|
if (intType.isSigned())
|
||||||
|
@ -393,11 +387,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
|
||||||
if (isa<mlir::FloatType>(resultElementType))
|
if (isa<mlir::FloatType>(resultElementType))
|
||||||
return b.create<arith::MinimumFOp>(loc, self, result);
|
return b.create<arith::MinimumFOp>(loc, self, result);
|
||||||
else if (isa<mlir::IntegerType>(resultElementType)) {
|
else if (isa<mlir::IntegerType>(resultElementType)) {
|
||||||
IntegerType intType = min.getSelf()
|
IntegerType intType = dyn_cast<mlir::IntegerType>(
|
||||||
.getType()
|
cast<BaseTensorType>(min.getSelf().getType()).getDtype());
|
||||||
.cast<BaseTensorType>()
|
|
||||||
.getDtype()
|
|
||||||
.dyn_cast<mlir::IntegerType>();
|
|
||||||
if (intType.isUnsigned())
|
if (intType.isUnsigned())
|
||||||
return b.create<arith::MinUIOp>(loc, self, result);
|
return b.create<arith::MinUIOp>(loc, self, result);
|
||||||
if (intType.isSigned())
|
if (intType.isSigned())
|
||||||
|
@ -657,9 +648,8 @@ public:
|
||||||
return opInfo;
|
return opInfo;
|
||||||
|
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
auto resultType = getTypeConverter()
|
auto resultType = cast<RankedTensorType>(
|
||||||
->convertType(op->getResult(0).getType())
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
Type elemType = resultType.getElementType();
|
Type elemType = resultType.getElementType();
|
||||||
LogicalResult elemTypeCheck =
|
LogicalResult elemTypeCheck =
|
||||||
validateReductionElementType(op, elemType, rewriter);
|
validateReductionElementType(op, elemType, rewriter);
|
||||||
|
|
|
@@ -179,15 +179,13 @@ public:

     for (auto i : {TOP, VCENTER, BOTTOM}) {
       for (auto j : {LEFT, HCENTER, RIGHT}) {
-        auto constVtile{
+        auto constVtile{dyn_cast_or_null<mlir::IntegerAttr>(
             mlir::dyn_cast<mlir::arith::ConstantOp>(vTile[i].getDefiningOp())
-                .getValue()
-                .dyn_cast_or_null<mlir::IntegerAttr>()};
+                .getValue())};

-        auto constHtile{
+        auto constHtile{dyn_cast_or_null<mlir::IntegerAttr>(
             mlir::dyn_cast<mlir::arith::ConstantOp>(hTile[j].getDefiningOp())
-                .getValue()
-                .dyn_cast_or_null<mlir::IntegerAttr>()};
+                .getValue())};

         auto vSize = constVtile.getInt();
         auto hSize = constHtile.getInt();
@@ -369,8 +367,8 @@ public:
     for (auto size : resultSize)
       resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));

-    auto resultType = typeConverter->convertType(op.getType())
-                          .template cast<RankedTensorType>();
+    auto resultType =
+        cast<RankedTensorType>(typeConverter->convertType(op.getType()));
     Type resultElementType;
     if (isa<Torch::NoneType>(op.getDtype().getType())) {
       resultElementType = resultType.getElementType();
@@ -426,7 +424,7 @@ public:
           op, "unimplemented: pin_memory must be either None or false");

     // Only `none`, `contiguous` and `preserve` memory_format is supported.
-    if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
+    if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
       int64_t memoryFormat;
       if (!matchPattern(op.getMemoryFormat(),
                         m_TorchConstantInt(&memoryFormat)))
@@ -441,7 +439,7 @@ public:
     }

     // TODO: Add support for device arg other than cpu.
-    if (!op.getDevice().getType().isa<Torch::NoneType>()) {
+    if (!isa<Torch::NoneType>(op.getDevice().getType())) {
       std::string device;
       if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
         return rewriter.notifyMatchFailure(
@@ -453,7 +451,7 @@ public:

     // TODO: Add support for non-strided layout.
     // torch.layout is by default strided i.e. 0.
-    if (!op.getLayout().getType().isa<Torch::NoneType>()) {
+    if (!isa<Torch::NoneType>(op.getLayout().getType())) {
       int64_t tensorLayout;
       if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
         return rewriter.notifyMatchFailure(
@@ -478,7 +476,7 @@ public:
     auto resultType =
         cast<RankedTensorType>(typeConverter->convertType(op.getType()));
     Type resultElementType;
-    if (op.getDtype().getType().isa<Torch::NoneType>()) {
+    if (isa<Torch::NoneType>(op.getDtype().getType())) {
       resultElementType = getDefaultDtypeForTorchScalar(
           Torch::FloatType::get(op->getContext()));
     } else {
@@ -527,7 +525,7 @@ public:

     // The pin_memory should be either `False` or `none`.
     bool pinMemory;
-    if (!op.getPinMemory().getType().isa<Torch::NoneType>() &&
+    if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
         (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
          pinMemory)) {
       return rewriter.notifyMatchFailure(
@@ -536,9 +534,8 @@ public:

     Location loc = op.getLoc();
     const TypeConverter *typeConverter = this->getTypeConverter();
-    RankedTensorType resultType =
-        typeConverter->convertType(op->getResult(0).getType())
-            .cast<RankedTensorType>();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        typeConverter->convertType(op->getResult(0).getType()));
     Type dtype = resultType.getElementType();
     Value start =
         convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype);
@@ -138,17 +138,16 @@ public:
       requires_grad = tensorFloatOp.getRequiresGrad();
     }
     // TODO: Dtype conversion.
-    if (!dtype.getType().isa<Torch::NoneType>())
+    if (!isa<Torch::NoneType>(dtype.getType()))
       return rewriter.notifyMatchFailure(op, "Unimplemented non-None dtype");

     // TODO: Device information.
-    if (!device.getType().isa<Torch::NoneType>())
+    if (!isa<Torch::NoneType>(device.getType()))
       return rewriter.notifyMatchFailure(
           op, "Unimplemented non-None device information");

-    RankedTensorType resultType = getTypeConverter()
-                                      ->convertType(op->getResult(0).getType())
-                                      .cast<RankedTensorType>();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
     Type outElementType = resultType.getElementType();
     Value elemValProm =
         convertScalarToDtype(rewriter, loc, elemVal, outElementType);
@@ -171,9 +170,8 @@ public:
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
     Location loc = op.getLoc();
-    RankedTensorType resultType = getTypeConverter()
-                                      ->convertType(op->getResult(0).getType())
-                                      .cast<RankedTensorType>();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
     Type outElementType = resultType.getElementType();
     Value elemVal = adaptor.getA();
     Value elemValProm =
@@ -422,7 +422,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }
   if (auto clone = dyn_cast<AtenCloneOp>(op)) {
     int64_t memoryFormat;
-    if (!clone.getMemoryFormat().getType().isa<Torch::NoneType>() &&
+    if (!isa<Torch::NoneType>(clone.getMemoryFormat().getType()) &&
         (!matchPattern(clone.getMemoryFormat(),
                        m_TorchConstantInt(&memoryFormat)) ||
          (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
@@ -434,24 +434,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return payloadArgs[0];
   }
   if (auto bitwiseAndTensor = dyn_cast<AtenBitwiseAndTensorOp>(op)) {
-    if (bitwiseAndTensor.getType()
-            .cast<ValueTensorType>()
-            .getDtype()
-            .isa<mlir::FloatType>()) {
+    if (isa<mlir::FloatType>(
+            cast<ValueTensorType>(bitwiseAndTensor.getType()).getDtype())) {
       bitwiseAndTensor.emitError(
           "Bitwise_And does not support floating point dtype");
       return nullptr;
     }
-    Type dtype = converter->convertType(bitwiseAndTensor.getType())
-                     .cast<RankedTensorType>()
+    Type dtype = cast<RankedTensorType>(
+                     converter->convertType(bitwiseAndTensor.getType()))
                      .getElementType();
     Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
     return b.create<arith::AndIOp>(loc, lhs, rhs);
   }
   if (auto bitwiseAndScalar = dyn_cast<AtenBitwiseAndScalarOp>(op)) {
-    Type dtype = converter->convertType(bitwiseAndScalar.getType())
-                     .cast<RankedTensorType>()
+    Type dtype = cast<RankedTensorType>(
+                     converter->convertType(bitwiseAndScalar.getType()))
                      .getElementType();
     if (!isa<mlir::IntegerType>(dtype)) {
       bitwiseAndScalar.emitError(
@@ -469,32 +467,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return b.create<arith::AndIOp>(loc, self, other);
   }
   if (auto bitwiseOrTensor = dyn_cast<AtenBitwiseOrTensorOp>(op)) {
-    if (bitwiseOrTensor.getType()
-            .cast<ValueTensorType>()
-            .getDtype()
-            .isa<mlir::FloatType>()) {
+    if (isa<mlir::FloatType>(
+            cast<ValueTensorType>(bitwiseOrTensor.getType()).getDtype())) {
       bitwiseOrTensor.emitError(
           "Bitwise_Or does not support floating point dtype");
       return nullptr;
     }
-    Type dtype = converter->convertType(bitwiseOrTensor.getType())
-                     .cast<RankedTensorType>()
+    Type dtype = cast<RankedTensorType>(
+                     converter->convertType(bitwiseOrTensor.getType()))
                      .getElementType();
     Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
     return b.create<arith::OrIOp>(loc, lhs, rhs);
   }
   if (auto bitwiseXorTensor = dyn_cast<AtenBitwiseXorTensorOp>(op)) {
-    if (bitwiseXorTensor.getType()
-            .cast<ValueTensorType>()
-            .getDtype()
-            .isa<mlir::FloatType>()) {
+    if (isa<mlir::FloatType>(
+            cast<ValueTensorType>(bitwiseXorTensor.getType()).getDtype())) {
       bitwiseXorTensor.emitError(
           "Bitwise_Xor does not support floating point dtype");
       return nullptr;
     }
-    Type dtype = converter->convertType(bitwiseXorTensor.getType())
-                     .cast<RankedTensorType>()
+    Type dtype = cast<RankedTensorType>(
+                     converter->convertType(bitwiseXorTensor.getType()))
                      .getElementType();
     Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
@@ -502,8 +496,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }
   if (auto bitwiseRightShiftTensor =
           dyn_cast<AtenBitwiseRightShiftTensorOp>(op)) {
-    Type dtype = converter->convertType(bitwiseRightShiftTensor.getType())
-                     .cast<RankedTensorType>()
+    Type dtype = cast<RankedTensorType>(
+                     converter->convertType(bitwiseRightShiftTensor.getType()))
                      .getElementType();
     if (!isa<mlir::IntegerType>(dtype)) {
       bitwiseRightShiftTensor.emitError(
@@ -516,8 +510,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }
   if (auto bitwiseLeftShiftTensor =
           dyn_cast<AtenBitwiseLeftShiftTensorOp>(op)) {
-    Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType())
-                     .cast<RankedTensorType>()
+    Type dtype = cast<RankedTensorType>(
+                     converter->convertType(bitwiseLeftShiftTensor.getType()))
                      .getElementType();
     if (!isa<mlir::IntegerType>(dtype)) {
       bitwiseLeftShiftTensor.emitError(
@@ -557,7 +551,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return createEqual(b, loc, floatDtype, self, zero);
   }
   if (isa<AtenAbsOp>(op)) {
-    if (payloadArgs[0].getType().isa<IntegerType>())
+    if (isa<IntegerType>(payloadArgs[0].getType()))
       return b.create<math::AbsIOp>(loc, payloadArgs[0]);
     return b.create<math::AbsFOp>(loc, payloadArgs[0]);
   }
@@ -653,20 +647,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return b.create<arith::SelectOp>(loc, cmp, arg, zeroPoint);
   }
   if (auto round = dyn_cast<AtenRoundOp>(op)) {
-    if (!round.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
+    if (!isa<mlir::FloatType>(
+            cast<ValueTensorType>(round.getType()).getDtype())) {
       round.emitError("unimplemented: non-floating point dtype");
       return nullptr;
     }
     return b.create<math::RoundEvenOp>(loc, payloadArgs[0]);
   }
   if (auto prelu = dyn_cast<AtenPreluOp>(op)) {
-    if (!prelu.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
+    if (!isa<mlir::FloatType>(
+            cast<ValueTensorType>(prelu.getType()).getDtype())) {
       prelu.emitError("unimplemented: non-floating point dtype");
       return nullptr;
     }
@@ -685,10 +675,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
   }
   if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
-    if (!gelu.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
+    if (!isa<mlir::FloatType>(
+            cast<ValueTensorType>(gelu.getType()).getDtype())) {
       gelu.emitError("unimplemented: non-floating point dtype");
       return nullptr;
     }
@@ -732,10 +720,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return nullptr;
   }
   if (auto geluBackward = dyn_cast<AtenGeluBackwardOp>(op)) {
-    if (!geluBackward.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
+    if (!isa<mlir::FloatType>(
+            cast<ValueTensorType>(geluBackward.getType()).getDtype())) {
       geluBackward.emitError("unimplemented: non-floating point dtype");
       return nullptr;
     }
@@ -770,10 +756,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }
   if (auto hardtanhBackward = dyn_cast<AtenHardtanhBackwardOp>(op)) {
     AtenHardtanhBackwardOp::Adaptor adaptor(operands);
-    if (!hardtanhBackward.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
+    if (!isa<mlir::FloatType>(
+            cast<ValueTensorType>(hardtanhBackward.getType()).getDtype())) {
       hardtanhBackward.emitError("unimplemented: non-floating point dtype");
       return nullptr;
     }
@@ -967,10 +951,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }

   if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
-    if (!pow.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
+    if (!isa<mlir::FloatType>(
+            cast<ValueTensorType>(pow.getType()).getDtype())) {
       pow.emitError("unimplemented: non-floating point dtype");
       return nullptr;
     }
@@ -1047,10 +1029,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }

   if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) {
-    if (!lerp.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
+    if (!isa<mlir::FloatType>(
+            cast<ValueTensorType>(lerp.getType()).getDtype())) {
       lerp.emitError("unimplemented: non-floating point dtype");
       return nullptr;
     }
@@ -1064,8 +1044,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }
   if (auto minimum = dyn_cast<AtenMinimumOp>(op)) {
     Type dtype = cast<BaseTensorType>(minimum.getType()).getDtype();
-    Type elemTy = converter->convertType(minimum.getType())
-                      .cast<RankedTensorType>()
+    Type elemTy =
+        cast<RankedTensorType>(converter->convertType(minimum.getType()))
                       .getElementType();
     Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy);
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy);
@@ -1074,8 +1054,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }
   if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
     Type dtype = cast<BaseTensorType>(maximum.getType()).getDtype();
-    Type elemTy = converter->convertType(maximum.getType())
-                      .cast<RankedTensorType>()
+    Type elemTy =
+        cast<RankedTensorType>(converter->convertType(maximum.getType()))
                       .getElementType();
     Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy);
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy);
@@ -1086,8 +1066,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     AtenClampOp::Adaptor adaptor(operands);
     auto min = adaptor.getMin();
     auto max = adaptor.getMax();
-    if (min.getType().isa<Torch::OptionalType>() ||
-        max.getType().isa<Torch::OptionalType>()) {
+    if (isa<Torch::OptionalType>(min.getType()) ||
+        isa<Torch::OptionalType>(max.getType())) {
       clamp.emitError("unimplemented: runtime optional type");
       return nullptr;
     }
@@ -1125,9 +1105,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     };

     auto result = payloadArgs[0];
-    if (!min.getType().isa<Torch::NoneType>())
+    if (!isa<Torch::NoneType>(min.getType()))
       result = cmpSelect(result, min, /*getMax=*/false);
-    if (!max.getType().isa<Torch::NoneType>())
+    if (!isa<Torch::NoneType>(max.getType()))
       result = cmpSelect(result, max, /*getMax=*/true);
     return result;
   }
@@ -1135,8 +1115,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     AtenClampTensorOp::Adaptor adaptor(operands);
     auto min = adaptor.getMin();
     auto max = adaptor.getMax();
-    if (min.getType().isa<Torch::OptionalType>() ||
-        max.getType().isa<Torch::OptionalType>()) {
+    if (isa<Torch::OptionalType>(min.getType()) ||
+        isa<Torch::OptionalType>(max.getType())) {
       clampTensor.emitError("unimplemented: runtime optional type");
       return nullptr;
     }
@@ -1145,7 +1125,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
                      .getElementType();
     bool isMinNone = true;
     auto result = payloadArgs[0];
-    if (!min.getType().isa<Torch::NoneType>()) {
+    if (!isa<Torch::NoneType>(min.getType())) {
       isMinNone = false;
       auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
       Value pred;
@@ -1163,7 +1143,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
       }
       result = b.create<arith::SelectOp>(loc, pred, minPromoted, result);
     }
-    if (!max.getType().isa<Torch::NoneType>()) {
+    if (!isa<Torch::NoneType>(max.getType())) {
       max = isMinNone ? payloadArgs[1] : payloadArgs[2];
       auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
       Value pred;
@@ -1252,8 +1232,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return b.create<arith::DivFOp>(loc, self, other);
   }
   if (auto remScalar = dyn_cast<AtenRemainderScalarOp>(op)) {
-    Type newResultType = converter->convertType(remScalar.getType())
-                             .cast<RankedTensorType>()
+    Type newResultType =
+        cast<RankedTensorType>(converter->convertType(remScalar.getType()))
                              .getElementType();

     Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
@@ -1272,8 +1252,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return result;
   }
   if (auto remTensor = dyn_cast<AtenRemainderTensorOp>(op)) {
-    Type newResultType = converter->convertType(remTensor.getType())
-                             .cast<RankedTensorType>()
+    Type newResultType =
+        cast<RankedTensorType>(converter->convertType(remTensor.getType()))
                              .getElementType();

     Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
@@ -1292,8 +1272,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return result;
   }
   if (auto fmod = dyn_cast<AtenFmodTensorOp>(op)) {
-    Type newResultType = converter->convertType(fmod.getType())
-                             .cast<RankedTensorType>()
+    Type newResultType =
+        cast<RankedTensorType>(converter->convertType(fmod.getType()))
                              .getElementType();

     Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
@@ -1420,8 +1400,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }

   if (auto bitwiseNot = dyn_cast<AtenBitwiseNotOp>(op)) {
-    Type elementType = converter->convertType(bitwiseNot.getType())
-                           .cast<RankedTensorType>()
+    Type elementType =
+        cast<RankedTensorType>(converter->convertType(bitwiseNot.getType()))
                            .getElementType();
     if (isa<mlir::FloatType>(elementType)) {
       bitwiseNot.emitError("Bitwise_Not does not support floating point dtype");
@@ -1607,10 +1587,9 @@ public:

     Location loc = op->getLoc();
     auto tensorOperands = llvm::to_vector<6>(llvm::make_filter_range(
-        operands, [](Value v) { return v.getType().isa<RankedTensorType>(); }));
+        operands, [](Value v) { return isa<RankedTensorType>(v.getType()); }));
-    auto resultType = getTypeConverter()
-                          ->convertType(op->getResult(0).getType())
-                          .cast<RankedTensorType>();
+    auto resultType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
     bool hadErrorCreatingPayload = false;
     Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
         rewriter, loc, tensorOperands, resultType.getElementType(),
@@ -1657,7 +1636,7 @@ public:
       return rewriter.notifyMatchFailure(op, "dim must be constant");

     // TODO: Incorporate the weight argument.
-    if (!weight.getType().isa<mlir::torch::Torch::NoneType>())
+    if (!isa<mlir::torch::Torch::NoneType>(weight.getType()))
       return rewriter.notifyMatchFailure(
           op, "Unimplemented, the weight operand is not incorporated.");

@@ -1672,9 +1651,8 @@ public:
       return rewriter.notifyMatchFailure(
           op, "expected input and target to be rank <= 2");
     }
-    RankedTensorType resultType = getTypeConverter()
-                                      ->convertType(op->getResult(0).getType())
-                                      .cast<RankedTensorType>();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
     Type elementType = resultType.getElementType();

     Value zeroVal = rewriter.create<arith::ConstantOp>(
@@ -1948,7 +1926,7 @@ public:
     Value input = adaptor.getSelf();
     Value target = adaptor.getTarget();
     Value weight = adaptor.getWeight();
-    bool weightIsNone = op.getWeight().getType().isa<Torch::NoneType>();
+    bool weightIsNone = isa<Torch::NoneType>(op.getWeight().getType());
     Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.getIgnoreIndex());
     Value totalWeight = adaptor.getTotalWeight();

@@ -2069,9 +2047,8 @@ public:
             })
             ->getResult(0);

-    RankedTensorType resultType = getTypeConverter()
-                                      ->convertType(op->getResult(0).getType())
-                                      .cast<RankedTensorType>();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, gradInput);
     return success();
   }
@@ -2214,9 +2191,8 @@ public:
   LogicalResult
   matchAndRewrite(TensorStaticInfoCastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    RankedTensorType resultType = getTypeConverter()
-                                      ->convertType(op->getResult(0).getType())
-                                      .cast<RankedTensorType>();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
                                                 adaptor.getOperand());
     return success();
@@ -2243,7 +2219,7 @@ public:
     if (succeeded(checkNotNone(rewriter, op, eps)))
       handleEps = true;

-    if (handleEps && !eps.getType().isa<mlir::FloatType>()) {
+    if (handleEps && !isa<mlir::FloatType>(eps.getType())) {
       op.emitError("Logit does not support non-floating point type");
       return failure();
     }
@@ -2317,9 +2293,8 @@ public:
   LogicalResult
   matchAndRewrite(AtenIntReprOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    RankedTensorType resultType = getTypeConverter()
-                                      ->convertType(op->getResult(0).getType())
-                                      .cast<RankedTensorType>();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
                                                 adaptor.getSelf());
     return success();
@@ -2362,8 +2337,8 @@ public:
     zeropoint = converter->materializeTargetConversion(
         rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint);

-    auto resultType = converter->convertType(op->getResult(0).getType())
-                          .cast<RankedTensorType>();
+    auto resultType = cast<RankedTensorType>(
+        converter->convertType(op->getResult(0).getType()));

     llvm::SmallVector<Value> dynSizes;
     for (auto [index, dim] : llvm::enumerate(resultType.getShape())) {
@@ -2553,9 +2528,8 @@ public:
       return res;
     };

-    auto resultType = getTypeConverter()
-                          ->convertType(op.getResult().getType())
-                          .cast<RankedTensorType>();
+    auto resultType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op.getResult().getType()));
     SmallVector<Value> resultSize{};
     if (resultType.isDynamicDim(0))
       resultSize.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
@@ -2675,7 +2649,7 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
                                 SmallVector<Value> scaleValues,
                                 std::string coordStr) {

-  auto inputType = input.getType().cast<RankedTensorType>();
+  auto inputType = cast<RankedTensorType>(input.getType());
   auto inputRank = inputType.getRank();

   SmallVector<Value> indices;
@@ -2725,7 +2699,7 @@ static Value BilinearInterpolate(OpBuilder &b,
                                  SmallVector<Value> scaleValues,
                                  std::string coordStr) {
   unsigned dimOffset = 2;
-  auto inputType = input.getType().cast<RankedTensorType>();
+  auto inputType = cast<RankedTensorType>(input.getType());
   auto inputRank = inputType.getRank();

   Value cstOneEps =
@@ -2877,7 +2851,7 @@ public:

     Location loc = op->getLoc();
     Value input = adaptor.getInput();
-    auto inputType = input.getType().cast<RankedTensorType>();
+    auto inputType = cast<RankedTensorType>(input.getType());
     auto inputRank = inputType.getRank();
     if (mode.substr(0, 8) == "bilinear" && inputRank != 4)
       return rewriter.notifyMatchFailure(
@@ -2893,7 +2867,7 @@ public:
           loc, rewriter.getIntegerType(64), inputSize));
     }

-    if (!op.getScaleFactor().getType().isa<Torch::NoneType>()) {
+    if (!isa<Torch::NoneType>(op.getScaleFactor().getType())) {
       bool recompScale;
       if (!matchPattern(op.getRecomputeScaleFactor(),
                         m_TorchConstantBool(&recompScale)))
@@ -52,7 +52,7 @@ Value torch_to_linalg::getPaddedTensor(
 Value torch_to_linalg::getZeroPaddedTensor(
     Operation *op, OpBuilder &b, Value &input,
     SmallVectorImpl<int64_t> &paddingInts) {
-  assert(input.getType().isa<RankedTensorType>() &&
+  assert(isa<RankedTensorType>(input.getType()) &&
          "input must be RankedTensorType");
   Location loc = op->getLoc();
   Value c0 = b.create<arith::ConstantOp>(
@@ -67,7 +67,7 @@ Value torch_to_linalg::getZeroPaddedTensor(
 Value torch_to_linalg::getDynamicZeroPaddedTensor(
     Operation *op, OpBuilder &b, Value &input, SmallVectorImpl<Value> &padding,
     int unpaddedDims, Value pad) {
-  assert(input.getType().isa<RankedTensorType>() &&
+  assert(isa<RankedTensorType>(input.getType()) &&
          "input must be RankedTensorType");
   unsigned int inRank = cast<RankedTensorType>(input.getType()).getRank();
   Location loc = op->getLoc();
@@ -252,7 +252,7 @@ public:
     // "block" arguments
     for (const auto &barg : enumerate(op.getRegion().front().getArguments())) {
       Value to = block->getArgument(barg.index());
-      if (to.getType().isa<mlir::IndexType>())
+      if (isa<mlir::IndexType>(to.getType()))
         to =
             rewriter.create<arith::IndexCastOp>(loc, rewriter.getI64Type(), to);
       Type targetType = to.getType();
@@ -146,9 +146,9 @@ public:
     if (!selfType) {
       return op.emitError("only Tensor types supported in StableHLO");
     }
-    auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
-                       ->convertType(op.getType())
-                       .template cast<TensorType>();
+    auto outType = cast<TensorType>(
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
     self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
     rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
     return success();
@@ -203,9 +203,9 @@ public:
     auto selfTy = cast<TensorType>(self.getType());
     if (!selfTy)
       return op.emitError("only Tensor types supported in StableHLO");
-    auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
-                        ->convertType(op.getType())
-                        .template cast<TensorType>();
+    auto resultTy = cast<TensorType>(
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));

     if (isa<mlir::FloatType>(resultTy.getElementType())) {
       Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy);
@@ -231,9 +231,9 @@ public:
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {

-    auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
-                       ->convertType(op.getType())
-                       .template dyn_cast<TensorType>();
+    auto outType = dyn_cast<TensorType>(
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));

     if (!outType)
       return op.emitError("only Tensor types supported in StableHLO");
@@ -321,9 +321,9 @@ public:
     if (!lhsTy || !rhsTy)
       return op.emitError("only Tensor types supported");

-    auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()
-                     ->convertType(op.getType())
-                     .template cast<TensorType>();
+    auto outTy = cast<TensorType>(
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));

     lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
     rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
@@ -354,9 +354,9 @@ public:
     if (!lhsType)
       return op.emitError("only Tensor types supported in StableHLO");

-    TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
-                             ->convertType(op.getType())
-                             .template cast<TensorType>();
+    TensorType outType = cast<TensorType>(
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));

     Type outElemTy = outType.getElementType();
     if (!outElemTy.isIntOrFloat()) {
@@ -607,9 +607,9 @@ public:
     if (!lhsTy)
       return op.emitError("lhs must be a ranked tensor type");

-    TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
-                             ->convertType(op.getType())
-                             .template cast<TensorType>();
+    TensorType outType = cast<TensorType>(
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
     Type outElemTy = outType.getElementType();
     lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
     if (!rhsTy) {
@@ -917,9 +917,9 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
   if (!lhsType)
     return op.emitError("only Tensor types supported in StableHLO");

-  auto outType = OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter()
-                     ->convertType(op.getType())
-                     .template cast<TensorType>();
+  auto outType = cast<TensorType>(
+      OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter()
+          ->convertType(op.getType()));

   Type outElemTy = outType.getElementType();
   if (!outElemTy.isIntOrFloat()) {
@@ -1421,9 +1421,9 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(

   // Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp.
   SmallVector<APFloat> zeroConstVec(
-      numFeatureDimSize, APFloat::getZero(inputTy.getElementType()
-                                              .cast<mlir::FloatType>()
-                                              .getFloatSemantics()));
+      numFeatureDimSize,
+      APFloat::getZero(
+          cast<mlir::FloatType>(inputTy.getElementType()).getFloatSemantics()));
   SmallVector<APFloat> oneConstVec(
       numFeatureDimSize,
       APFloat(
@@ -1633,9 +1633,8 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
   Location loc = op->getLoc();

   // Get element type of resultType as dtype
-  auto outType = this->getTypeConverter()
-                     ->convertType(op.getType())
-                     .cast<RankedTensorType>();
+  auto outType = cast<RankedTensorType>(
+      this->getTypeConverter()->convertType(op.getType()));
   auto dtype = outType.getElementType();
   if (!isa<mlir::IntegerType>(dtype) && !isa<mlir::FloatType>(dtype)) {
     return rewriter.notifyMatchFailure(
@@ -1678,7 +1677,7 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
     AtenConstantPadNdOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Value self = adaptor.getSelf();
-  auto selfTy = self.getType().cast<RankedTensorType>();
+  auto selfTy = cast<RankedTensorType>(self.getType());
   auto selfElemTy = selfTy.getElementType();
   int64_t rank = selfTy.getRank();

@@ -2029,7 +2028,7 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(

   Value self = adaptor.getSelf();

-  auto selfTy = self.getType().cast<RankedTensorType>();
+  auto selfTy = cast<RankedTensorType>(self.getType());
   if (!selfTy.hasStaticShape()) {
     return op->emitError("dynamic shaped input is not supported");
   }
@@ -2062,7 +2061,7 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
       cmpTypeAttr);

   auto resTy =
-      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));

   auto bcastTy = resTy.clone(rewriter.getI1Type());
   auto bcastAttr = rewriter.getDenseI64ArrayAttr({selfRank - 2, selfRank - 1});
@@ -2071,15 +2070,15 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(

   auto resElemTy = resTy.getElementType();
   Value zeroTensor;
-  if (resElemTy.isa<mlir::FloatType>()) {
+  if (isa<mlir::FloatType>(resElemTy)) {
     auto constAttr = SplatElementsAttr::get(
         resTy, llvm::APFloat::getZero(
-                   resElemTy.cast<FloatType>().getFloatSemantics(), false));
+                   cast<FloatType>(resElemTy).getFloatSemantics(), false));
     zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
-  } else if (resElemTy.isa<mlir::IntegerType>()) {
+  } else if (isa<mlir::IntegerType>(resElemTy)) {
     auto constAttr = SplatElementsAttr::get(
         resTy,
-        llvm::APInt::getZero(resElemTy.cast<mlir::IntegerType>().getWidth()));
+        llvm::APInt::getZero(cast<mlir::IntegerType>(resElemTy).getWidth()));
     zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
   } else {
     return op.emitError("element type is not float or integer");
@@ -157,8 +157,8 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
   Value builtinTypeStart = adaptor.getStart();
   Value builtinTypeEnd = adaptor.getEnd();

-  if (torchTypeStart.getType().isa<OptionalType>() ||
-      torchTypeEnd.getType().isa<OptionalType>())
+  if (isa<OptionalType>(torchTypeStart.getType()) ||
+      isa<OptionalType>(torchTypeEnd.getType()))
     return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");

   int64_t step;
@@ -349,11 +349,11 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "offsets must be a vector with static shape equal to 1");

-  if (!op.getPaddingIdx().getType().isa<Torch::NoneType>())
+  if (!isa<Torch::NoneType>(op.getPaddingIdx().getType()))
     return rewriter.notifyMatchFailure(
         op, "Unimplemented: padding_idx should be none");

-  if (!op.getPerSampleWeights().getType().isa<Torch::NoneType>())
+  if (!isa<Torch::NoneType>(op.getPerSampleWeights().getType()))
     return rewriter.notifyMatchFailure(
         op, "Unimplemented: per_sample_weights should be none");

@@ -453,25 +453,22 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
       loc, getTypeConverter()->convertType(op.getType(0)),
       stablehloReduceOp.getResult(0), outShapeTensor);

-  RankedTensorType resultType = getTypeConverter()
-                                    ->convertType(op->getResult(1).getType())
-                                    .cast<RankedTensorType>();
+  RankedTensorType resultType = cast<RankedTensorType>(
+      getTypeConverter()->convertType(op->getResult(1).getType()));
   Value resultB =
       createInitialValueForGatherScatterOp(op, resultType, rewriter);
   if (!resultB)
     return failure();

-  resultType = getTypeConverter()
-                   ->convertType(op->getResult(2).getType())
-                   .cast<RankedTensorType>();
+  resultType = cast<RankedTensorType>(
+      getTypeConverter()->convertType(op->getResult(2).getType()));
   Value resultC =
       createInitialValueForGatherScatterOp(op, resultType, rewriter);
   if (!resultC)
     return failure();

-  resultType = getTypeConverter()
-                   ->convertType(op->getResult(3).getType())
-                   .cast<RankedTensorType>();
+  resultType = cast<RankedTensorType>(
+      getTypeConverter()->convertType(op->getResult(3).getType()));
   Value resultD =
       createInitialValueForGatherScatterOp(op, resultType, rewriter);
   if (!resultD)
@ -612,9 +609,8 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
|
||||||
|
|
||||||
auto input = adaptor.getSelf();
|
auto input = adaptor.getSelf();
|
||||||
|
|
||||||
RankedTensorType resultType =
|
RankedTensorType resultType = cast<RankedTensorType>(
|
||||||
typeConverter->convertType(op->getResult(0).getType())
|
typeConverter->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
|
|
||||||
SmallVector<Value> resultShape;
|
SmallVector<Value> resultShape;
|
||||||
SmallVector<Value> offsets;
|
SmallVector<Value> offsets;
|
||||||
|
|
|
@ -350,9 +350,9 @@ public:
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||||
op,
|
op,
|
||||||
ConvertAtenOp<AtenOpT>::getTypeConverter()
|
cast<RankedTensorType>(
|
||||||
->convertType(op.getType())
|
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(
|
||||||
.template cast<RankedTensorType>(),
|
op.getType())),
|
||||||
output);
|
output);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
|
@ -730,9 +730,8 @@ public:
|
||||||
// If transposed is set to true,
|
// If transposed is set to true,
|
||||||
// the weight shape changes to [IC, (OC//G), KH, KW]
|
// the weight shape changes to [IC, (OC//G), KH, KW]
|
||||||
auto weightTy = cast<RankedTensorType>(weight.getType());
|
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||||
auto outTy = getTypeConverter()
|
auto outTy =
|
||||||
->convertType(op.getType())
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
.template cast<RankedTensorType>();
|
|
||||||
if (!inputTy || !weightTy || !outTy) {
|
if (!inputTy || !weightTy || !outTy) {
|
||||||
return op.emitError("input, weight and output must be ranked tensors");
|
return op.emitError("input, weight and output must be ranked tensors");
|
||||||
}
|
}
|
||||||
|
|
|
@ -216,10 +216,10 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
||||||
auto *secondIdxArg = std::next(secondValArg);
|
auto *secondIdxArg = std::next(secondValArg);
|
||||||
|
|
||||||
stablehlo::ComparisonTypeAttr compareTypeAttr;
|
stablehlo::ComparisonTypeAttr compareTypeAttr;
|
||||||
if (inputTy.getElementType().isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(inputTy.getElementType())) {
|
||||||
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||||
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
|
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
|
||||||
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
|
} else if (isa<mlir::IntegerType>(inputTy.getElementType())) {
|
||||||
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||||
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
|
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
|
||||||
}
|
}
|
||||||
|
@@ -395,9 +395,8 @@ public:
RankedTensorType inputTy = cast<RankedTensorType>(input.getType());
Type inputElemTy = inputTy.getElementType();
int64_t inputRank = inputTy.getRank();
- RankedTensorType outTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
- ->convertType(op.getType())
- .template cast<RankedTensorType>();
+ RankedTensorType outTy = cast<RankedTensorType>(
+ ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType()));
auto outShape = outTy.getShape();

if (inputRank <= Dim) {

@@ -242,10 +242,10 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
auto *secondIdxArg = std::next(secondValArg);

stablehlo::ComparisonTypeAttr compareTypeAttr;
- if (inputTy.getElementType().isa<mlir::FloatType>()) {
+ if (isa<mlir::FloatType>(inputTy.getElementType())) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
- } else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
+ } else if (isa<mlir::IntegerType>(inputTy.getElementType())) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
}

@@ -535,12 +535,10 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
"AtenMaxDimOp to StableHLO");
}

- RankedTensorType valResultType = getTypeConverter()
- ->convertType(op.getResult(0).getType())
- .template cast<RankedTensorType>();
- RankedTensorType idxResultType = getTypeConverter()
- ->convertType(op.getResult(1).getType())
- .template cast<RankedTensorType>();
+ RankedTensorType valResultType = cast<RankedTensorType>(
+ getTypeConverter()->convertType(op.getResult(0).getType()));
+ RankedTensorType idxResultType = cast<RankedTensorType>(
+ getTypeConverter()->convertType(op.getResult(1).getType()));
Type idxElementType = idxResultType.getElementType();
if (!isa<mlir::IntegerType>(idxElementType)) {
return op.emitError("Aten.max.dim needs integer-like result");
@@ -636,9 +634,8 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
- auto outTy = getTypeConverter()
- ->convertType(op.getType())
- .template dyn_cast<RankedTensorType>();
+ auto outTy =
+ dyn_cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");

@@ -271,7 +271,7 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

auto getOptionalVal = [&](Value val) -> std::optional<Value> {
- if (val.getType().isa<Torch::NoneType>()) {
+ if (isa<Torch::NoneType>(val.getType())) {
return std::nullopt;
} else {
return val;

@@ -451,7 +451,7 @@ template <>
LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
PrimsSplitDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
+ auto selfType = dyn_cast<TensorType>(adaptor.getA().getType());
if (!selfType) {
return op.emitError("only tensor types are currently supported");
}

@@ -292,7 +292,7 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
arith::CmpIPredicate predicate = isDescending ? ge : le;
compareOp = rewriter.create<arith::CmpIOp>(
loc, predicate, block->getArgument(0), block->getArgument(1));
- } else if (elementTypes[0].isa<mlir::FloatType>()) {
+ } else if (isa<mlir::FloatType>(elementTypes[0])) {
// Case for using arith::CmpFOp.
arith::CmpFPredicate predicate =
isDescending ? arith::CmpFPredicate::OGE : arith::CmpFPredicate::OLE;
@@ -349,8 +349,8 @@ public:
b.create<TMTensor::YieldOp>(loc, updatesElement);
});

- auto resultType = typeConverter->convertType(op->getResult(0).getType())
- .cast<RankedTensorType>();
+ auto resultType = cast<RankedTensorType>(
+ typeConverter->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
return success();
}

@@ -381,7 +381,7 @@ public:
// Check whether the input is a 1-d tensor of integer type or not.
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
if (inputType.getRank() != 1 ||
- !inputType.getElementType().isa<mlir::IntegerType>())
+ !isa<mlir::IntegerType>(inputType.getElementType()))
return rewriter.notifyMatchFailure(
op,
"Input tensor has to be a one-dimensional tensor of integer type.");

@@ -395,7 +395,7 @@ public:
"Unimplemented: Integer width not equal to 64 are not supported.");

// TODO: Incorporate the weight argument.
- if (!weights.getType().isa<mlir::torch::Torch::NoneType>())
+ if (!isa<mlir::torch::Torch::NoneType>(weights.getType()))
return rewriter.notifyMatchFailure(
op, "Unimplemented: the weights operand is not incorporated.");

@@ -439,8 +439,8 @@ public:
indices = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(indices.getType()), indices);

- auto resultType = typeConverter->convertType(op->getResult(0).getType())
- .cast<RankedTensorType>();
+ auto resultType = cast<RankedTensorType>(
+ typeConverter->convertType(op->getResult(0).getType()));
Type resultElemType = resultType.getElementType();

SmallVector<Value, 1> inputSizeDynamic =

@@ -686,8 +686,8 @@ public:
auto valuesType = cast<ValueTensorType>(values.getType());
int64_t inputRank = inputType.getSizes().size();
auto valuesTensorType = cast<BaseTensorType>(op.getValues().getType());
- auto resultType = typeConverter->convertType(op->getResult(0).getType())
- .cast<RankedTensorType>();
+ auto resultType = cast<RankedTensorType>(
+ typeConverter->convertType(op->getResult(0).getType()));

if (!valuesTensorType.hasSizes())
return rewriter.notifyMatchFailure(

@@ -823,10 +823,10 @@ public:
Value inputElement) {
Value yieldValue = valuesElement;
if (accumulate) {
- if (inputElement.getType().isa<mlir::IntegerType>()) {
+ if (isa<mlir::IntegerType>(inputElement.getType())) {
yieldValue =
b.create<arith::AddIOp>(loc, inputElement, valuesElement);
- } else if (inputElement.getType().isa<mlir::FloatType>()) {
+ } else if (isa<mlir::FloatType>(inputElement.getType())) {
yieldValue =
b.create<arith::AddFOp>(loc, inputElement, valuesElement);
} else {

@@ -1042,10 +1042,10 @@ public:
[&](OpBuilder &b, Location loc, Value valuesElement,
Value inputElement) {
Value yieldValue = valuesElement;
- if (inputElement.getType().isa<mlir::IntegerType>()) {
+ if (isa<mlir::IntegerType>(inputElement.getType())) {
yieldValue =
b.create<arith::AddIOp>(loc, inputElement, valuesElement);
- } else if (inputElement.getType().isa<mlir::FloatType>()) {
+ } else if (isa<mlir::FloatType>(inputElement.getType())) {
yieldValue =
b.create<arith::AddFOp>(loc, inputElement, valuesElement);
} else {
@@ -1204,33 +1204,33 @@ public:
Value result;
if (reduceEnum == torch_upstream::ReductionType::SUM ||
reduceEnum == torch_upstream::ReductionType::MEAN) {
- if (update.getType().isa<mlir::IntegerType>()) {
+ if (isa<mlir::IntegerType>(update.getType())) {
result = b.create<arith::AddIOp>(loc, update, current);
- } else if (update.getType().isa<mlir::FloatType>()) {
+ } else if (isa<mlir::FloatType>(update.getType())) {
result = b.create<arith::AddFOp>(loc, update, current);
} else {
llvm_unreachable("Only integer/float types supported!");
}
} else if (reduceEnum == torch_upstream::ReductionType::PROD) {
- if (update.getType().isa<mlir::IntegerType>()) {
+ if (isa<mlir::IntegerType>(update.getType())) {
result = b.create<arith::MulIOp>(loc, update, current);
- } else if (update.getType().isa<mlir::FloatType>()) {
+ } else if (isa<mlir::FloatType>(update.getType())) {
result = b.create<arith::MulFOp>(loc, update, current);
} else {
llvm_unreachable("Only integer/float types supported!");
}
} else if (reduceEnum == torch_upstream::ReductionType::MAX) {
- if (update.getType().isa<mlir::IntegerType>()) {
+ if (isa<mlir::IntegerType>(update.getType())) {
result = b.create<arith::MaxSIOp>(loc, update, current);
- } else if (update.getType().isa<mlir::FloatType>()) {
+ } else if (isa<mlir::FloatType>(update.getType())) {
result = b.create<arith::MaximumFOp>(loc, update, current);
} else {
llvm_unreachable("Only integer/float types supported!");
}
} else if (reduceEnum == torch_upstream::ReductionType::MIN) {
- if (update.getType().isa<mlir::IntegerType>()) {
+ if (isa<mlir::IntegerType>(update.getType())) {
result = b.create<arith::MinSIOp>(loc, update, current);
- } else if (update.getType().isa<mlir::FloatType>()) {
+ } else if (isa<mlir::FloatType>(update.getType())) {
result = b.create<arith::MinimumFOp>(loc, update, current);
} else {
llvm_unreachable("Only integer/float types supported!");

@@ -1285,9 +1285,8 @@ public:
})
.getResult()[0];
}
- auto resultType = getTypeConverter()
- ->convertType(op->getResult(0).getType())
- .cast<RankedTensorType>();
+ auto resultType = cast<RankedTensorType>(
+ getTypeConverter()->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);

return success();
@@ -1392,9 +1391,8 @@ public:

Location loc = op.getLoc();
Value input = adaptor.getSelf();
- auto resultType = getTypeConverter()
- ->convertType(op->getResult(0).getType())
- .cast<RankedTensorType>();
+ auto resultType = cast<RankedTensorType>(
+ getTypeConverter()->convertType(op->getResult(0).getType()));
Type elementType = resultType.getElementType();
Type inputElementType =
cast<RankedTensorType>(input.getType()).getElementType();

@@ -1414,7 +1412,7 @@ public:

int64_t inputRank = resultType.getRank();
Value dtype = op.getDtype();
- if (!dtype.getType().isa<Torch::NoneType>())
+ if (!isa<Torch::NoneType>(dtype.getType()))
return rewriter.notifyMatchFailure(
op, "unsupported: dtype argument not supported");

@@ -1444,7 +1442,7 @@ public:
rewriter, loc, input, output, acc, dim, /*inclusive=*/true,
[](OpBuilder &b, Location loc, Value input, Value acc) {
Value sum =
- (input.getType().isa<mlir::FloatType>()
+ (isa<mlir::FloatType>(input.getType())
? b.create<arith::AddFOp>(loc, input, acc)->getResult(0)
: b.create<arith::AddIOp>(loc, input, acc)->getResult(0));
b.create<TMTensor::YieldOp>(loc, sum);

@@ -1472,7 +1470,7 @@ public:
cast<ShapedType>(adaptor.getQuery().getType()).getElementType();

// Verify inputs (only support defaults)
- if (!mask.getType().isa<Torch::NoneType>())
+ if (!isa<Torch::NoneType>(mask.getType()))
return rewriter.notifyMatchFailure(op.getLoc(),
"attention masking not supported");
double dropout;

@@ -1483,7 +1481,7 @@ public:
if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal)
return rewriter.notifyMatchFailure(
op.getLoc(), "causal attention masking not supported");
- if (!scale.getType().isa<Torch::NoneType>()) {
+ if (!isa<Torch::NoneType>(scale.getType())) {
double scaleFloat;
if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
scaleFloat != 1.0)
@@ -1,5 +1,5 @@
//===----------------------------------------------------------------------===//
- //
+ ////
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -47,7 +47,7 @@ public:
return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA");

- if (selfTy.getElementType().isa<mlir::FloatType>()) {
+ if (isa<mlir::FloatType>(selfTy.getElementType())) {
rewriter.replaceOpWithNewOp<TosaOpT>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(

@@ -99,9 +99,9 @@ public:
return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA");

- auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()
- ->convertType(op.getType())
- .template cast<TensorType>();
+ auto outTy = cast<TensorType>(
+ OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+ op.getType()));

auto binaryOp =
tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);

@@ -248,9 +248,9 @@ public:
}

// Get output type: tensor<i32/i64/f32>
- auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
- ->convertType(op.getType())
- .template cast<TensorType>();
+ auto outType = cast<TensorType>(
+ OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+ op.getType()));

Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {

@@ -373,9 +373,9 @@ public:
std::is_same<AtenOpT, AtenLtScalarOp>());

// Promote lhs and rhs dtypes for bitwise operators.
- TensorType resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
- ->convertType(op.getType())
- .template cast<TensorType>();
+ TensorType resultTy = cast<TensorType>(
+ OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+ op.getType()));
if (isBitwiseOp) {
lhs = tosa::promoteType(rewriter, lhs, resultTy);
rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy);

@@ -416,9 +416,9 @@ public:
return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA");

- auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
- ->convertType(op.getType())
- .template cast<TensorType>();
+ auto outType = cast<TensorType>(
+ OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+ op.getType()));

Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat())

@@ -444,9 +444,9 @@ public:
}

if (isa<mlir::FloatType>(outElemTy) || isa<mlir::IntegerType>(outElemTy)) {
- auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
- ->convertType(op.getType())
- .template cast<TensorType>();
+ auto outType = cast<TensorType>(
+ OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+ op.getType()));

auto mulOp = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
rhsTensor, /*shift=*/0);

@@ -492,9 +492,9 @@ public:
"conversion in TOSA operation");
}
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
- auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
- ->convertType(op.getType())
- .template cast<TensorType>();
+ auto outType = cast<TensorType>(
+ OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+ op.getType()));

// auto result;
Value result;
@@ -540,7 +540,7 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
auto selfTy = cast<TensorType>(self.getType());
- if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
+ if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
rewriter.replaceOpWithNewOp<tosa::TanhOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();

@@ -557,7 +557,7 @@ LogicalResult ConvertAtenOp<AtenSigmoidOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
auto selfTy = cast<TensorType>(self.getType());
- if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
+ if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();

@@ -584,7 +584,7 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
}

// Rescale the clampIn for quantized types. TBD
- if (!selfTy.getElementType().isa<mlir::FloatType>()) {
+ if (!isa<mlir::FloatType>(selfTy.getElementType())) {
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization currently supported");
}

@@ -604,7 +604,7 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(

Value self = adaptor.getSelf();
auto selfTy = cast<TensorType>(self.getType());
- if (!selfTy.getElementType().isa<mlir::FloatType>()) {
+ if (!isa<mlir::FloatType>(selfTy.getElementType())) {
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization currently supported");
}

@@ -667,9 +667,9 @@ public:
return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA");

- auto outputTy = OpConversionPattern<AtenOpT>::getTypeConverter()
- ->convertType(op.getType())
- .template cast<RankedTensorType>();
+ auto outputTy = cast<RankedTensorType>(
+ OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+ op.getType()));
if (!outputTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor type outputs permitted for reduce_mean");

@@ -828,9 +828,8 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "non-const keepdim parameter unsupported");

- auto resultTy = getTypeConverter()
- ->convertType(op.getResult().getType())
- .cast<RankedTensorType>();
+ auto resultTy = cast<RankedTensorType>(
+ getTypeConverter()->convertType(op.getResult().getType()));
auto outputETy = resultTy.getElementType();

// Create a single instance of tosa.argmax.

@@ -927,9 +926,9 @@ public:
return rewriter.notifyMatchFailure(op,
"Squeeze could not compute new shape");

- auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
- ->convertType(op.getResult().getType())
- .template cast<RankedTensorType>();
+ auto resultTy = cast<RankedTensorType>(
+ OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+ op.getResult().getType()));
auto resultElemTy = resultTy.getElementType();

auto newOutputTy = RankedTensorType::get(

@@ -1017,7 +1016,7 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Pow");

- if (!selfTy.getElementType().isa<mlir::FloatType>())
+ if (!isa<mlir::FloatType>(selfTy.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported");
@@ -1624,9 +1623,9 @@ public:

rewriter.replaceOpWithNewOp<tensor::CastOp>(
op,
- OpConversionPattern<AtenOpT>::getTypeConverter()
- ->convertType(op.getType())
- .template cast<RankedTensorType>(),
+ cast<RankedTensorType>(
+ OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+ op.getType())),
output);

return success();

@@ -1800,9 +1799,9 @@ public:

rewriter.replaceOpWithNewOp<tensor::CastOp>(
op,
- OpConversionPattern<AtenOpT>::getTypeConverter()
- ->convertType(op.getType())
- .template cast<RankedTensorType>(),
+ cast<RankedTensorType>(
+ OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+ op.getType())),
matmulPlusBias);

return success();

@@ -1823,7 +1822,7 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Rsub");

- if (!selfTy.getElementType().isa<mlir::FloatType>())
+ if (!isa<mlir::FloatType>(selfTy.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported");

@@ -1869,9 +1868,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(

auto inputTy = cast<RankedTensorType>(input.getType());
auto weightTy = cast<RankedTensorType>(weight.getType());
- auto outputTy = getTypeConverter()
- ->convertType(op.getType())
- .template cast<RankedTensorType>();
+ auto outputTy =
+ cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));

if (!inputTy || !weightTy || !outputTy)
return rewriter.notifyMatchFailure(

@@ -2208,7 +2206,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
// Note: cudnn_enabled is not handled.

// FIXME: Handle training and momentum.
- if (op.getMomentum().getType().isa<Torch::NoneType>())
+ if (isa<Torch::NoneType>(op.getMomentum().getType()))
return rewriter.notifyMatchFailure(op, "Unsupported None for momentum");

auto meanType = dyn_cast<TensorType>(adaptor.getRunningMean().getType());

@@ -2312,9 +2310,9 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
// Note: cudnn_enabled is not handled.

// FIXME: Handle the None cases for the optional parameters.
- if (adaptor.getWeight().getType().isa<Torch::NoneType>())
+ if (isa<Torch::NoneType>(adaptor.getWeight().getType()))
return rewriter.notifyMatchFailure(op, "Unsupported None for weight");
- if (adaptor.getBias().getType().isa<Torch::NoneType>())
+ if (isa<Torch::NoneType>(adaptor.getBias().getType()))
return rewriter.notifyMatchFailure(op, "Unsupported None for bias");

auto weightType = cast<RankedTensorType>(adaptor.getWeight().getType());

@@ -2453,9 +2451,8 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
ValueTensorLiteralOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {

- auto outputTy = getTypeConverter()
- ->convertType(op.getType())
- .template cast<RankedTensorType>();
+ auto outputTy =
+ cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));

// Tensors with integer types need to be converted to signless integer
// element type. All tensors with element types other than integer can reuse
@@ -3122,7 +3119,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
cast<RankedTensorType>(typeConverter->convertType(op.getType()));

auto indicesType = dyn_cast<RankedTensorType>(indices.getType());
- if (!indicesType || !indicesType.getElementType().isa<IntegerType>())
+ if (!indicesType || !isa<IntegerType>(indicesType.getElementType()))
return rewriter.notifyMatchFailure(
op, "Indices must be of integer tensor type");

@@ -3632,11 +3629,11 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
auto indexTorch = tensorsTorchType[i];
// TODO add support for none index other than i==0, like (index0, None)
// (None, index1)
- if (i == 0 && indexTorch.getType().isa<Torch::NoneType>()) {
+ if (i == 0 && isa<Torch::NoneType>(indexTorch.getType())) {
// convert None to [0,0,0]
auto indexNext = indexTensors[i + 1];
auto indexNextTorch = tensorsTorchType[i + 1];
- if (indexNextTorch.getType().isa<Torch::NoneType>()) {
+ if (isa<Torch::NoneType>(indexNextTorch.getType())) {
return rewriter.notifyMatchFailure(
op, "Multiple None index is not support for now.");
}

@@ -3963,8 +3960,8 @@ LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
if (!selfType.hasStaticShape() || !otherType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Only tensor types with static shape are supported");
- if (!selfType.getElementType().isa<mlir::FloatType>() ||
- !otherType.getElementType().isa<mlir::FloatType>()) {
+ if (!isa<mlir::FloatType>(selfType.getElementType()) ||
+ !isa<mlir::FloatType>(otherType.getElementType())) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only FP element type is supported");
}

@@ -4058,9 +4055,8 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {

const TypeConverter *typeConverter = this->getTypeConverter();
- RankedTensorType resultType =
- typeConverter->convertType(op->getResult(0).getType())
- .cast<RankedTensorType>();
+ RankedTensorType resultType = cast<RankedTensorType>(
+ typeConverter->convertType(op->getResult(0).getType()));

// At this point all tensors should have value semantics, and hence the
// `layout` check can be ignored.

@@ -4068,7 +4064,7 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
// TODO: Add support for pin_memory features.
// The pin_memory should be either `False` or `none`.
bool pinMemory;
- if (!op.getPinMemory().getType().isa<Torch::NoneType>() &&
+ if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
pinMemory)) {
return rewriter.notifyMatchFailure(

@@ -4162,10 +4158,10 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
};

const auto isIntType =
- resultType.getElementType().dyn_cast_or_null<mlir::IntegerType>();
+ dyn_cast_or_null<mlir::IntegerType>(resultType.getElementType());

const auto isDoubleType =
- resultType.getElementType().dyn_cast_or_null<mlir::FloatType>();
+ dyn_cast_or_null<mlir::FloatType>(resultType.getElementType());

auto maybeResult = [&]() -> std::optional<Value> {
// Integer output type, and start / end / range are all integers.

@@ -4218,9 +4214,8 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {

const TypeConverter *typeConverter = this->getTypeConverter();
- RankedTensorType resultType =
- typeConverter->convertType(op->getResult(0).getType())
- .cast<RankedTensorType>();
+ RankedTensorType resultType = cast<RankedTensorType>(
+ typeConverter->convertType(op->getResult(0).getType()));

// Only supports integer operand type, because for the floating point operand
// type result tensor has to be of type `f64` which is not supported in the
@@ -4323,7 +4318,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
}

// Only `none`, `contiguous` and `preserve` memory_format is supported.
- if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
+ if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
int64_t memoryFormat;
if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)))
return rewriter.notifyMatchFailure(

@@ -4336,9 +4331,8 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
"memory_format is supported");
}

- auto resultTy = getTypeConverter()
- ->convertType(op.getResult().getType())
- .cast<RankedTensorType>();
+ auto resultTy = cast<RankedTensorType>(
+ getTypeConverter()->convertType(op.getResult().getType()));

Value result;
if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.getSelf(),

@@ -4779,9 +4773,9 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

- auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
- ->convertType(op.getType())
- .template dyn_cast<TensorType>();
+ auto outType = dyn_cast<TensorType>(
+ OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+ op.getType()));

if (!outType)
return rewriter.notifyMatchFailure(op,

@@ -4841,9 +4835,9 @@ public:
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
- ->convertType(op.getType())
- .template dyn_cast<TensorType>();
+ auto outType = dyn_cast<TensorType>(
+ OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+ op.getType()));

if (!outType || !outType.hasStaticShape())
return rewriter.notifyMatchFailure(

@@ -4875,9 +4869,9 @@ public:
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
- ->convertType(op.getType())
- .template dyn_cast<TensorType>();
+ auto outType = dyn_cast<TensorType>(
+ OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+ op.getType()));

if (!outType || !outType.hasStaticShape())
return rewriter.notifyMatchFailure(

@@ -4947,9 +4941,9 @@ public:
"unimplemented: only contiguous and channels last memory "
"format is supported");
}
- auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
- ->convertType(op.getType())
- .template dyn_cast<TensorType>();
+ auto outType = dyn_cast<TensorType>(
+ OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+ op.getType()));
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, adaptor.getSelf());

return success();

@@ -5077,8 +5071,8 @@ LogicalResult ConvertAtenOp<AtenSqrtOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA");

- auto resultType = typeConverter->convertType(op.getType())
- .template cast<RankedTensorType>();
+ auto resultType =
+ cast<RankedTensorType>(typeConverter->convertType(op.getType()));
auto elementType = resultType.getElementType();

if (isa<mlir::IntegerType>(selfTy.getElementType())) {
@@ -813,9 +813,9 @@ convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
return std::nullopt;

bool input_is_qtype =
- input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
bool output_is_qtype =
- output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());

if (input_is_qtype || output_is_qtype) {
op->emitOpError("ConvertReduceProdOp: input/output tensor should "

@@ -839,9 +839,9 @@ convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
return std::nullopt;

bool input_is_qtype =
- input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
bool output_is_qtype =
- output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());

if (input_is_qtype != output_is_qtype) {
op->emitOpError("ConvertReduceSumOp: input/output tensor should "

@@ -894,9 +894,9 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
return std::nullopt;

bool input_is_qtype =
- input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
bool output_is_qtype =
- output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
+ isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());

if (input_is_qtype != output_is_qtype) {
op->emitOpError("ConvertReduceSumOp: input/output tensor should "

@@ -905,7 +905,7 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
}

// Only supports float type mean() if it's non-quantized
- if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) {
+ if (!input_is_qtype && !isa<mlir::FloatType>(output_type.getElementType())) {
op->emitWarning(
"Failed convertReduceMean: input unquantized type but output element "
"not FloatType!");
@@ -31,7 +31,7 @@ LogicalResult verifyLinalgCompatibleTypes(Operation *op,
return false;
auto tensor = dyn_cast<ValueTensorType>(type);
return !tensor ||
- tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>();
+ dyn_cast_or_null<RankedTensorType>(tensor.toBuiltinTensor());
};

bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) &&

@@ -66,7 +66,7 @@ Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,

// Generate IR: assert(dim >= 0 && dim < inputRank)
void assertIsValidDim(OpBuilder &b, Location loc, Value dim, Value inputRank) {
- assert(dim.getType().isa<IntegerType>() &&
+ assert(isa<IntegerType>(dim.getType()) &&
"dim arg of assertIsValidDim must be integer type");
Value cst0 =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));

@@ -139,12 +139,12 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
}

Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
- assert(v.getType().isa<IntegerType>() && "must be called with integer type");
+ assert(isa<IntegerType>(v.getType()) && "must be called with integer type");
return b.create<arith::IndexCastOp>(loc, b.getIndexType(), v);
}

Value castIndexToInt64(OpBuilder &b, Location loc, Value idx) {
- assert(idx.getType().isa<IndexType>() && "must be called with integer type");
+ assert(isa<IndexType>(idx.getType()) && "must be called with integer type");
return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
}

@@ -375,7 +375,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
Value torchOptionalInt, Value builtinInt,
Value defaultValue, Value dimSize) {
- if (torchOptionalInt.getType().isa<Torch::NoneType>())
+ if (isa<Torch::NoneType>(torchOptionalInt.getType()))
return defaultValue;
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
Value positiveDim =
@@ -149,14 +149,12 @@ static Value getScalarIntValue(Value input, Location loc,

if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
if (inputDtype.isInteger(64)) {
- auto val = valueTensorLiteralOp.getValue()
- .cast<DenseIntElementsAttr>()
+ auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
.getSplatValue<int64_t>();
return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val));
} else {
- auto val = valueTensorLiteralOp.getValue()
- .cast<DenseIntElementsAttr>()
+ auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
.getSplatValue<bool>();
return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val));

@@ -191,8 +189,7 @@ static Value getScalarFloatValue(Value input, Location loc,
return nullptr;

if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
- auto val = valueTensorLiteralOp.getValue()
- .cast<DenseFPElementsAttr>()
+ auto val = cast<DenseFPElementsAttr>(valueTensorLiteralOp.getValue())
.getSplatValue<FloatAttr>()
.getValueAsDouble();
return rewriter.create<Torch::ConstantFloatOp>(

@@ -1946,7 +1943,7 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) {
auto resultType = dyn_cast<ValueTensorType>(getType());
if (resultType && resultType.hasDtype() &&
- resultType.getDtype().isa<mlir::IntegerType>()) {
+ isa<mlir::IntegerType>(resultType.getDtype())) {
return getSelf();
}
return {};

@@ -2136,7 +2133,7 @@ traceKnownSizeTensorType(Value value, std::optional<int64_t> dim) {
// Limit the loop count to 6 to avoid indefinite compilation times from
// unbounded IR traversals.
for (auto idx = 0; idx < 6; ++idx) {
- if (!value || !value.getType().isa<BaseTensorType>())
+ if (!value || !isa<BaseTensorType>(value.getType()))
return failure();

auto tensorType = cast<BaseTensorType>(value.getType());
@@ -2518,7 +2515,7 @@ OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {

OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
// Constant fold int -> float conversion.
- if (auto integerAttr = adaptor.getA().dyn_cast_or_null<IntegerAttr>()) {
+ if (auto integerAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getA())) {
return FloatAttr::get(
mlir::Float64Type::get(getContext()),
static_cast<double>(integerAttr.getValue().getSExtValue()));

@@ -2535,7 +2532,7 @@ OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {

OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) {
// Constant fold float -> int conversion.
- if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
+ if (auto floatAttr = dyn_cast_or_null<FloatAttr>(adaptor.getA())) {
return IntegerAttr::get(
mlir::IntegerType::get(getContext(), 64),
static_cast<int64_t>(floatAttr.getValue().convertToDouble()));

@@ -2549,7 +2546,7 @@ OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) {

OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) {
// Constant fold float -> int conversion.
- if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
+ if (auto floatAttr = dyn_cast_or_null<FloatAttr>(adaptor.getA())) {
return IntegerAttr::get(
mlir::IntegerType::get(getContext(), 64),
static_cast<long>(floatAttr.getValue().convertToDouble()));

@@ -2695,9 +2692,8 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
- auto attr = properties.as<Properties *>()
- ->getValue()
- .dyn_cast_or_null<ElementsAttr>();
+ auto attr =
+ dyn_cast_or_null<ElementsAttr>(properties.as<Properties *>()->getValue());
if (!attr)
return failure();
RankedTensorType tensorType = cast<RankedTensorType>(attr.getType());
@@ -2723,10 +2719,10 @@ static bool areSizesAndDtypesCompatible(BaseTensorType a, BaseTensorType b) {
 
 bool NonValueTensorLiteralOp::isCompatibleReturnTypes(TypeRange inferred,
                                                       TypeRange actual) {
-  if (!actual[0].isa<BaseTensorType>())
+  if (!isa<BaseTensorType>(actual[0]))
     return false;
-  return areSizesAndDtypesCompatible(inferred[0].cast<BaseTensorType>(),
-                                     actual[0].cast<BaseTensorType>());
+  return areSizesAndDtypesCompatible(cast<BaseTensorType>(inferred[0]),
+                                     cast<BaseTensorType>(actual[0]));
 }
 
 //===----------------------------------------------------------------------===//
@@ -2737,9 +2733,8 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
     MLIRContext *context, std::optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<Type> &inferredReturnTypes) {
-  auto attr = properties.as<Properties *>()
-                  ->getValue()
-                  .dyn_cast_or_null<ElementsAttr>();
+  auto attr =
+      dyn_cast_or_null<ElementsAttr>(properties.as<Properties *>()->getValue());
   if (!attr)
     return failure();
   RankedTensorType tensorType = cast<RankedTensorType>(attr.getType());
@@ -2760,8 +2755,8 @@ OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) {
 
 bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
                                                mlir::TypeRange outputs) {
-  return areSizesAndDtypesCompatible(inputs[0].cast<BaseTensorType>(),
-                                     outputs[0].cast<BaseTensorType>());
+  return areSizesAndDtypesCompatible(cast<BaseTensorType>(inputs[0]),
+                                     cast<BaseTensorType>(outputs[0]));
 }
 
 void TensorStaticInfoCastOp::getCanonicalizationPatterns(
@ -3072,7 +3067,7 @@ OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) {
|
||||||
if (!operandType)
|
if (!operandType)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
if (operandType.hasDtype()) {
|
if (operandType.hasDtype()) {
|
||||||
bool isFloatType = operandType.getDtype().isa<mlir::FloatType>();
|
bool isFloatType = isa<mlir::FloatType>(operandType.getDtype());
|
||||||
return IntegerAttr::get(IntegerType::get(getContext(), 1), isFloatType);
|
return IntegerAttr::get(IntegerType::get(getContext(), 1), isFloatType);
|
||||||
}
|
}
|
||||||
// doesn't has dtype
|
// doesn't has dtype
|
||||||
|
@ -3130,12 +3125,12 @@ void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
int64_t start;
|
int64_t start;
|
||||||
int64_t end;
|
int64_t end;
|
||||||
int64_t step;
|
int64_t step;
|
||||||
if (op.getStart().getType().isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(op.getStart().getType())) {
|
||||||
start = 0;
|
start = 0;
|
||||||
} else if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) {
|
} else if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
if (op.getEnd().getType().isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(op.getEnd().getType())) {
|
||||||
end = listElements.size();
|
end = listElements.size();
|
||||||
} else if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
|
} else if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -3228,7 +3223,7 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
// things.
|
// things.
|
||||||
Value replacement = tupleConstruct.getElements()[i];
|
Value replacement = tupleConstruct.getElements()[i];
|
||||||
if (replacement.getType() != op.getType()) {
|
if (replacement.getType() != op.getType()) {
|
||||||
if (op.getType().isa<BaseTensorType>()) {
|
if (isa<BaseTensorType>(op.getType())) {
|
||||||
replacement = rewriter.create<Torch::TensorStaticInfoCastOp>(
|
replacement = rewriter.create<Torch::TensorStaticInfoCastOp>(
|
||||||
op.getLoc(), op.getType(), replacement);
|
op.getLoc(), op.getType(), replacement);
|
||||||
} else {
|
} else {
|
||||||
|
@@ -3384,8 +3379,8 @@ using BinaryIntOperatorFn = std::function<int64_t(int64_t, int64_t)>;
 static OpFoldResult
 atenBinaryIntOperatorFoldHelper(ArrayRef<Attribute> operands,
                                 BinaryIntOperatorFn f) {
-  auto intLhs = operands[0].dyn_cast_or_null<IntegerAttr>();
-  auto intRhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+  auto intLhs = dyn_cast_or_null<IntegerAttr>(operands[0]);
+  auto intRhs = dyn_cast_or_null<IntegerAttr>(operands[1]);
   if (!intLhs || !intRhs) {
     return nullptr;
   }
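The fold helpers above guard on IntegerAttr operands before extracting values. A self-contained sketch of that shape, written against plain MLIR (the function name and signature are invented for illustration; the attribute APIs are the real ones):

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;

// Fold lhs + rhs when both operands are constant integer attributes;
// otherwise return a null Attribute so the caller keeps the op as-is.
static Attribute foldConstantAdd(ArrayRef<Attribute> operands) {
  auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]);
  auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]);
  if (!lhs || !rhs)
    return {};
  // Torch semantics: !torch.int is 64-bit signed, hence getSExtValue().
  int64_t sum = lhs.getValue().getSExtValue() + rhs.getValue().getSExtValue();
  return IntegerAttr::get(lhs.getType(), sum);
}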
@@ -3711,7 +3706,7 @@ OpFoldResult AtenAddOp::fold(FoldAdaptor adaptor) {
     return nullptr;
   }
 
-  if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
+  if (isa<IntegerAttr>(adaptor.getA()) && isa<IntegerAttr>(adaptor.getB())) {
     return atenBinaryIntOperatorFoldHelper(
         adaptor.getOperands(),
         [](int64_t a, int64_t b) -> int64_t { return a + b; });
@ -3730,7 +3725,7 @@ OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
|
if (isa<IntegerAttr>(adaptor.getA()) && isa<IntegerAttr>(adaptor.getB())) {
|
||||||
return atenBinaryIntOperatorFoldHelper(
|
return atenBinaryIntOperatorFoldHelper(
|
||||||
adaptor.getOperands(),
|
adaptor.getOperands(),
|
||||||
[](int64_t a, int64_t b) -> int64_t { return a * b; });
|
[](int64_t a, int64_t b) -> int64_t { return a * b; });
|
||||||
|
@ -3749,7 +3744,7 @@ OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
|
if (isa<IntegerAttr>(adaptor.getA()) && isa<IntegerAttr>(adaptor.getB())) {
|
||||||
return atenBinaryIntOperatorFoldHelper(
|
return atenBinaryIntOperatorFoldHelper(
|
||||||
adaptor.getOperands(),
|
adaptor.getOperands(),
|
||||||
[](int64_t a, int64_t b) -> int64_t { return a - b; });
|
[](int64_t a, int64_t b) -> int64_t { return a - b; });
|
||||||
|
@ -3806,7 +3801,7 @@ OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) {
|
||||||
if (!adaptor.getA()) {
|
if (!adaptor.getA()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto floatValue = adaptor.getA().dyn_cast_or_null<FloatAttr>();
|
auto floatValue = dyn_cast_or_null<FloatAttr>(adaptor.getA());
|
||||||
if (!floatValue) {
|
if (!floatValue) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -3834,7 +3829,7 @@ OpFoldResult AtenNegFloatOp::fold(FoldAdaptor adaptor) {
|
||||||
if (!adaptor.getA()) {
|
if (!adaptor.getA()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto value = adaptor.getA().dyn_cast_or_null<FloatAttr>();
|
auto value = dyn_cast_or_null<FloatAttr>(adaptor.getA());
|
||||||
if (!value) {
|
if (!value) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -4487,8 +4482,8 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
|
||||||
if (getA() == getB())
|
if (getA() == getB())
|
||||||
return getA();
|
return getA();
|
||||||
|
|
||||||
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
|
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
|
||||||
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
|
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
|
||||||
if (!lhs || !rhs)
|
if (!lhs || !rhs)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
// Torch semantics are that !torch.int is 64-bit signed.
|
// Torch semantics are that !torch.int is 64-bit signed.
|
||||||
|
@ -4556,8 +4551,8 @@ OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {
|
||||||
if (getA() == getB())
|
if (getA() == getB())
|
||||||
return getA();
|
return getA();
|
||||||
|
|
||||||
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
|
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
|
||||||
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
|
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
|
||||||
if (!lhs || !rhs)
|
if (!lhs || !rhs)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
// Torch semantics are that !torch.int is 64-bit signed.
|
// Torch semantics are that !torch.int is 64-bit signed.
|
||||||
|
@@ -4644,8 +4639,8 @@ LogicalResult AtenNormScalarOp::verify() {
   // Check if dtype is one of those supported by norm operation.
   // ComplexType will match any torch complex types, but each float must be
   // checked individually.
-  if (!inTensorDtype.isa<mlir::ComplexType, mlir::Float16Type,
-                         mlir::Float32Type, mlir::Float64Type>()) {
+  if (!isa<mlir::ComplexType, mlir::Float16Type, mlir::Float32Type,
+           mlir::Float64Type>(inTensorDtype)) {
     return emitOpError(
                "expected a float or complex type for input tensor, but got ")
            << inTensorDtype;
@@ -190,8 +190,8 @@ static bool isValidTorchDtype(Type dtype) {
   // Builtin floating point types.
   if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(dtype))
     return true;
-  if (dtype.isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
-                Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
+  if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+          Float8E4M3FNUZType, Float8E4M3B11FNUZType>(dtype))
     return true;
 
   if (isa<Torch::StringType>(dtype))
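As the dtype checks above show, the free-function `isa<>` takes the same list of candidate types as the member form, with the value moved into the argument list. A small sketch against MLIR's builtin float types (the helper name is illustrative):

#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;

// True for the 16-bit float dtypes handled specially above.
static bool isHalfPrecisionFloat(Type dtype) {
  // One variadic isa<> call replaces a chain of dtype.isa<...>() members.
  return isa<Float16Type, BFloat16Type>(dtype);
}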
@@ -228,9 +228,9 @@ Type BaseTensorType::getWithSizesAndDtypeFrom(BaseTensorType other) const {
 
 Type BaseTensorType::getWithSizesAndDtype(
     std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype) const {
-  if (isa<NonValueTensorType>())
+  if (mlir::isa<NonValueTensorType>(*this))
     return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype);
-  if (isa<ValueTensorType>())
+  if (mlir::isa<ValueTensorType>(*this))
     return ValueTensorType::get(getContext(), optionalSizes, optionalDtype);
   llvm_unreachable("not a BaseTensorType!");
 }
@@ -248,9 +248,9 @@ Type BaseTensorType::getWithSizesAndDtypeAndSparsity(
 }
 
 ValueTensorType BaseTensorType::getWithValueSemantics() const {
-  if (auto tensor = dyn_cast<NonValueTensorType>())
+  if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
     return tensor.getWithValueSemantics();
-  if (auto tensor = dyn_cast<ValueTensorType>())
+  if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
     return tensor;
   llvm_unreachable("not a BaseTensorType!");
 }
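The `*this` casts above are the Type-side counterpart: a method on a type class passes itself to `mlir::isa`/`mlir::dyn_cast` instead of calling the member form. A generic sketch on builtin types, where ShapedType stands in for the project-specific BaseTensorType hierarchy (helper name is illustrative):

#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;

// Return the rank of `t` when it is a ranked shaped type, -1 otherwise.
static int64_t rankOrMinusOne(Type t) {
  // dyn_cast yields a null ShapedType when `t` is not shaped.
  if (auto shaped = dyn_cast<ShapedType>(t))
    return shaped.hasRank() ? shaped.getRank() : -1;
  return -1;
}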
@ -110,7 +110,7 @@ public:
|
||||||
continue;
|
continue;
|
||||||
auto it = typeBoundMap.find({call.getCallee(), operand.index()});
|
auto it = typeBoundMap.find({call.getCallee(), operand.index()});
|
||||||
if (it != typeBoundMap.end()) {
|
if (it != typeBoundMap.end()) {
|
||||||
if (auto valueTensorType = it->second.dyn_cast<ValueTensorType>()) {
|
if (auto valueTensorType = dyn_cast<ValueTensorType>(it->second)) {
|
||||||
newOperands.push_back(copyTensorToType(
|
newOperands.push_back(copyTensorToType(
|
||||||
rewriter, call->getLoc(), valueTensorType, operand.value()));
|
rewriter, call->getLoc(), valueTensorType, operand.value()));
|
||||||
continue;
|
continue;
|
||||||
|
@ -215,11 +215,11 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
|
||||||
for (int i = 0, e = func.getNumArguments(); i != e; i++) {
|
for (int i = 0, e = func.getNumArguments(); i != e; i++) {
|
||||||
if (func.getArgAttr(i, "torch.type_bound"))
|
if (func.getArgAttr(i, "torch.type_bound"))
|
||||||
return false;
|
return false;
|
||||||
if (func.getArgumentTypes()[i].isa<Torch::NoneType>())
|
if (isa<Torch::NoneType>(func.getArgumentTypes()[i]))
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
for (int i = 0, e = func.getNumResults(); i != e; i++) {
|
for (int i = 0, e = func.getNumResults(); i != e; i++) {
|
||||||
if (func.getFunctionType().getResults()[i].isa<Torch::NoneType>())
|
if (isa<Torch::NoneType>(func.getFunctionType().getResults()[i]))
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -38,7 +38,7 @@ static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
|
||||||
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
||||||
if (failed(resDtype))
|
if (failed(resDtype))
|
||||||
return false;
|
return false;
|
||||||
return resDtype->isa<mlir::FloatType>();
|
return isa<mlir::FloatType>(*resDtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to compute the return type of the reduction function.
|
// Helper function to compute the return type of the reduction function.
|
||||||
|
@@ -99,19 +99,15 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
                                      Operation *op, Value input, Value dim,
                                      bool keepDim) {
   Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
-  BaseTensorType valueType =
-      computeReductionType(rewriter, op, cast<BaseTensorType>(input.getType()),
-                           dim, keepDim)
-          .cast<BaseTensorType>();
+  BaseTensorType valueType = cast<BaseTensorType>(computeReductionType(
+      rewriter, op, cast<BaseTensorType>(input.getType()), dim, keepDim));
   if (!valueType)
     return nullptr;
   BaseTensorType indexType =
-      valueType
-          .getWithSizesAndDtype(
-              !valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
-                                    : llvm::ArrayRef(valueType.getSizes()),
-              IntegerType::get(op->getContext(), 64, IntegerType::Signed))
-          .cast<BaseTensorType>();
+      cast<BaseTensorType>(valueType.getWithSizesAndDtype(
+          !valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
+                                : llvm::ArrayRef(valueType.getSizes()),
+          IntegerType::get(op->getContext(), 64, IntegerType::Signed)));
   return rewriter
       .create<AtenMaxDimOp>(loc, valueType, indexType, input, dim, keepDimCst)
       .getValues();
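Where the old code appended `.cast<BaseTensorType>()` to a multi-line call chain, the new code wraps the whole expression in one leading `cast<>`, as in the hunk above. A reduced sketch with builtin tensor types (`makeUnrankedF32` is a hypothetical stand-in for helpers like computeReductionType that return a plain Type):

#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;

// Hypothetical helper: returns some tensor type, erased to Type.
static Type makeUnrankedF32(MLIRContext *ctx) {
  return UnrankedTensorType::get(Float32Type::get(ctx));
}

static TensorType getResultTensorType(MLIRContext *ctx) {
  // Wrap the whole call instead of a trailing `...).cast<TensorType>()`.
  return cast<TensorType>(makeUnrankedF32(ctx));
}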
@ -1059,7 +1055,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenEyeMOp op,
|
LogicalResult matchAndRewrite(AtenEyeMOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto outType = op.getType().dyn_cast<BaseTensorType>();
|
auto outType = dyn_cast<BaseTensorType>(op.getType());
|
||||||
if (!outType)
|
if (!outType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types input are currently supported");
|
op, "Only tensor types input are currently supported");
|
||||||
|
@ -1659,11 +1655,9 @@ public:
|
||||||
unsigned inputRank = *maybeInputRank;
|
unsigned inputRank = *maybeInputRank;
|
||||||
if (!indicesTensorType.hasSizes())
|
if (!indicesTensorType.hasSizes())
|
||||||
return failure();
|
return failure();
|
||||||
BaseTensorType valueTensorType =
|
BaseTensorType valueTensorType = cast<BaseTensorType>(
|
||||||
inputType
|
inputType.getWithSizesAndDtype(indicesTensorType.getOptionalSizes(),
|
||||||
.getWithSizesAndDtype(indicesTensorType.getOptionalSizes(),
|
inputType.getOptionalDtype()));
|
||||||
inputType.getOptionalDtype())
|
|
||||||
.cast<BaseTensorType>();
|
|
||||||
|
|
||||||
// If the dim type is `NoneType` i.e. reduce along all the dimensions.
|
// If the dim type is `NoneType` i.e. reduce along all the dimensions.
|
||||||
// `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so
|
// `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so
|
||||||
|
@ -1671,10 +1665,8 @@ public:
|
||||||
// happens on the 0th dimension.
|
// happens on the 0th dimension.
|
||||||
if (isa<Torch::NoneType>(dim.getType())) {
|
if (isa<Torch::NoneType>(dim.getType())) {
|
||||||
BaseTensorType flattenType =
|
BaseTensorType flattenType =
|
||||||
inputType
|
cast<BaseTensorType>(inputType.getWithSizesAndDtype(
|
||||||
.getWithSizesAndDtype({kUnknownSize},
|
{kUnknownSize}, inputType.getOptionalDtype()));
|
||||||
inputType.getOptionalDtype())
|
|
||||||
.cast<BaseTensorType>();
|
|
||||||
dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||||
Value end = rewriter.create<ConstantIntOp>(
|
Value end = rewriter.create<ConstantIntOp>(
|
||||||
loc, rewriter.getI64IntegerAttr(inputRank - 1));
|
loc, rewriter.getI64IntegerAttr(inputRank - 1));
|
||||||
|
@ -3003,7 +2995,7 @@ public:
|
||||||
bool dimIsNone = false;
|
bool dimIsNone = false;
|
||||||
int64_t dim;
|
int64_t dim;
|
||||||
Value dimValue = op.getDim();
|
Value dimValue = op.getDim();
|
||||||
if (dimValue.getType().isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(dimValue.getType())) {
|
||||||
dimIsNone = true;
|
dimIsNone = true;
|
||||||
dim = inputRank - 1;
|
dim = inputRank - 1;
|
||||||
} else {
|
} else {
|
||||||
|
@ -3887,10 +3879,9 @@ public:
|
||||||
gradOutputViewSizesInt[0] = kUnknownSize;
|
gradOutputViewSizesInt[0] = kUnknownSize;
|
||||||
gradOutputViewSizesInt[1] = 1;
|
gradOutputViewSizesInt[1] = 1;
|
||||||
BaseTensorType gradOutputTypeForView =
|
BaseTensorType gradOutputTypeForView =
|
||||||
gradOutputTy
|
cast<BaseTensorType>(gradOutputTy.getWithSizesAndDtype(
|
||||||
.getWithSizesAndDtype(llvm::ArrayRef(gradOutputViewSizesInt),
|
llvm::ArrayRef(gradOutputViewSizesInt),
|
||||||
gradOutputTy.getOptionalDtype())
|
gradOutputTy.getOptionalDtype()));
|
||||||
.cast<BaseTensorType>();
|
|
||||||
Value gradOutputView = rewriter.create<Torch::AtenViewOp>(
|
Value gradOutputView = rewriter.create<Torch::AtenViewOp>(
|
||||||
loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList);
|
loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList);
|
||||||
|
|
||||||
|
@ -3918,10 +3909,9 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
BaseTensorType gradWeightTy =
|
BaseTensorType gradWeightTy =
|
||||||
inputTransposedTy
|
cast<BaseTensorType>(inputTransposedTy.getWithSizesAndDtype(
|
||||||
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt),
|
llvm::ArrayRef(gradWeightSizesInt),
|
||||||
inputTransposedTy.getOptionalDtype())
|
inputTransposedTy.getOptionalDtype()));
|
||||||
.cast<BaseTensorType>();
|
|
||||||
|
|
||||||
Value numGroup = rewriter.create<AtenSizeIntOp>(loc, input, cstZero);
|
Value numGroup = rewriter.create<AtenSizeIntOp>(loc, input, cstZero);
|
||||||
gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
|
gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
|
||||||
|
@ -3937,10 +3927,9 @@ public:
|
||||||
for (unsigned i = 0; i < gradWeightTy.getSizes().size() - 2; i++) {
|
for (unsigned i = 0; i < gradWeightTy.getSizes().size() - 2; i++) {
|
||||||
gradWeightSizesInt[i + 2] = weightSizes[i + 2];
|
gradWeightSizesInt[i + 2] = weightSizes[i + 2];
|
||||||
BaseTensorType gradWeightNarrowTy =
|
BaseTensorType gradWeightNarrowTy =
|
||||||
gradWeightTy
|
cast<BaseTensorType>(gradWeightTy.getWithSizesAndDtype(
|
||||||
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt),
|
llvm::ArrayRef(gradWeightSizesInt),
|
||||||
gradWeightTy.getOptionalDtype())
|
gradWeightTy.getOptionalDtype()));
|
||||||
.cast<BaseTensorType>();
|
|
||||||
|
|
||||||
Value dim = rewriter.create<ConstantIntOp>(
|
Value dim = rewriter.create<ConstantIntOp>(
|
||||||
loc, rewriter.getI64IntegerAttr(i + 2));
|
loc, rewriter.getI64IntegerAttr(i + 2));
|
||||||
|
@ -3970,10 +3959,9 @@ public:
|
||||||
gradWeightViewShapeValue);
|
gradWeightViewShapeValue);
|
||||||
|
|
||||||
BaseTensorType gradWeightTypeForView =
|
BaseTensorType gradWeightTypeForView =
|
||||||
gradWeightTy
|
cast<BaseTensorType>(gradWeightTy.getWithSizesAndDtype(
|
||||||
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightViewShapeInt),
|
llvm::ArrayRef(gradWeightViewShapeInt),
|
||||||
gradWeightTy.getOptionalDtype())
|
gradWeightTy.getOptionalDtype()));
|
||||||
.cast<BaseTensorType>();
|
|
||||||
gradWeight = rewriter.create<Torch::AtenViewOp>(
|
gradWeight = rewriter.create<Torch::AtenViewOp>(
|
||||||
loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList);
|
loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList);
|
||||||
|
|
||||||
|
@ -3986,10 +3974,9 @@ public:
|
||||||
gradWeightViewShapeInt[gradWeightDimsOrder[i]]);
|
gradWeightViewShapeInt[gradWeightDimsOrder[i]]);
|
||||||
}
|
}
|
||||||
BaseTensorType gradWeightTypeForMoveDim =
|
BaseTensorType gradWeightTypeForMoveDim =
|
||||||
gradWeightTy
|
cast<BaseTensorType>(gradWeightTy.getWithSizesAndDtype(
|
||||||
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightMoveDimShape),
|
llvm::ArrayRef(gradWeightMoveDimShape),
|
||||||
gradWeightTy.getOptionalDtype())
|
gradWeightTy.getOptionalDtype()));
|
||||||
.cast<BaseTensorType>();
|
|
||||||
|
|
||||||
gradWeight = rewriter.create<AtenMovedimIntOp>(
|
gradWeight = rewriter.create<AtenMovedimIntOp>(
|
||||||
loc, gradWeightTypeForMoveDim, gradWeight, /*source=*/cstZero,
|
loc, gradWeightTypeForMoveDim, gradWeight, /*source=*/cstZero,
|
||||||
|
@ -4009,8 +3996,7 @@ public:
|
||||||
Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
||||||
loc, transposedType, gradOutput, cstZero, cstOne);
|
loc, transposedType, gradOutput, cstZero, cstOne);
|
||||||
// Convolve input with grad_output.
|
// Convolve input with grad_output.
|
||||||
if (failed(
|
if (failed(getTransposedType(cast<BaseTensorType>(op.getResultTypes()[1]),
|
||||||
getTransposedType(op.getResultTypes()[1].cast<BaseTensorType>(),
|
|
||||||
0, 1, transposedType)))
|
0, 1, transposedType)))
|
||||||
return failure();
|
return failure();
|
||||||
gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
|
gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
|
||||||
|
@ -4063,7 +4049,7 @@ public:
|
||||||
|
|
||||||
// TODO: Handle integer type operands.
|
// TODO: Handle integer type operands.
|
||||||
auto inputType = cast<BaseTensorType>(input.getType());
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype())) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: non-floating point dtype");
|
op, "unimplemented: non-floating point dtype");
|
||||||
}
|
}
|
||||||
|
@ -4125,7 +4111,7 @@ public:
|
||||||
MLIRContext *context = op.getContext();
|
MLIRContext *context = op.getContext();
|
||||||
|
|
||||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>() ||
|
if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype()) ||
|
||||||
!isNoneOrFloatDtype(context, dtype)) {
|
!isNoneOrFloatDtype(context, dtype)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only floating-point type is supported");
|
op, "only floating-point type is supported");
|
||||||
|
@ -4133,7 +4119,7 @@ public:
|
||||||
|
|
||||||
SmallVector<Value> dimListElements;
|
SmallVector<Value> dimListElements;
|
||||||
if (!getListConstructElements(dimList, dimListElements) &&
|
if (!getListConstructElements(dimList, dimListElements) &&
|
||||||
!dimList.getType().isa<Torch::NoneType>()) {
|
!isa<Torch::NoneType>(dimList.getType())) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expected `dim` to be `None` or constructed from list construct");
|
op, "expected `dim` to be `None` or constructed from list construct");
|
||||||
}
|
}
|
||||||
|
@ -4215,7 +4201,7 @@ public:
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
|
if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype()))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only support floating type input for training mode");
|
op, "only support floating type input for training mode");
|
||||||
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
||||||
|
@ -4243,7 +4229,7 @@ public:
|
||||||
Value input = op.getInput();
|
Value input = op.getInput();
|
||||||
Value prob = op.getP();
|
Value prob = op.getP();
|
||||||
bool train = false;
|
bool train = false;
|
||||||
if (!op.getTrain().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getTrain().getType())) {
|
||||||
if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) {
|
if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "train must be a boolean constant or none");
|
op, "train must be a boolean constant or none");
|
||||||
|
@ -4263,7 +4249,7 @@ public:
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype())) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only support floating type input for training mode");
|
op, "only support floating type input for training mode");
|
||||||
}
|
}
|
||||||
|
@ -4332,7 +4318,7 @@ public:
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
BaseTensorType inputTensorTy = cast<BaseTensorType>(self.getType());
|
BaseTensorType inputTensorTy = cast<BaseTensorType>(self.getType());
|
||||||
if (!inputTensorTy.hasDtype() ||
|
if (!inputTensorTy.hasDtype() ||
|
||||||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
|
!isa<mlir::FloatType>(inputTensorTy.getDtype())) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"Only aten.std support floating type");
|
"Only aten.std support floating type");
|
||||||
}
|
}
|
||||||
|
@ -4388,7 +4374,7 @@ public:
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType());
|
BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType());
|
||||||
if (!inputTensorType.hasDtype() ||
|
if (!inputTensorType.hasDtype() ||
|
||||||
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
|
!isa<mlir::FloatType>(inputTensorType.getDtype())) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "aten.std.dim expects input tensor of floating-point type");
|
op, "aten.std.dim expects input tensor of floating-point type");
|
||||||
}
|
}
|
||||||
|
@ -4413,7 +4399,7 @@ public:
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType());
|
BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType());
|
||||||
if (!inputTensorType.hasDtype() ||
|
if (!inputTensorType.hasDtype() ||
|
||||||
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
|
!isa<mlir::FloatType>(inputTensorType.getDtype())) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op,
|
op,
|
||||||
"aten.std.correction expects input tensor of floating-point type");
|
"aten.std.correction expects input tensor of floating-point type");
|
||||||
|
@ -4506,7 +4492,7 @@ public:
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
Type resultType = op.getType();
|
Type resultType = op.getType();
|
||||||
auto inputType = cast<BaseTensorType>(input.getType());
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype())) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"only support floating-point type");
|
"only support floating-point type");
|
||||||
}
|
}
|
||||||
|
@ -4547,7 +4533,7 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
|
||||||
op, "can't decompose bernoulli like ops without sizes or dtype");
|
op, "can't decompose bernoulli like ops without sizes or dtype");
|
||||||
}
|
}
|
||||||
// The `prob` is expected to be a float type tensor.
|
// The `prob` is expected to be a float type tensor.
|
||||||
if (!probType.getDtype().isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(probType.getDtype())) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "probabilities must be a float type tensor");
|
op, "probabilities must be a float type tensor");
|
||||||
}
|
}
|
||||||
|
@ -4582,7 +4568,7 @@ public:
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
if (!isa<Torch::NoneType>(op.getGenerator().getType()))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "The generator has to be None because only global default "
|
op, "The generator has to be None because only global default "
|
||||||
"generator is supported");
|
"generator is supported");
|
||||||
|
@ -4640,7 +4626,7 @@ public:
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
Value prob = op.getP();
|
Value prob = op.getP();
|
||||||
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
if (!isa<Torch::NoneType>(op.getGenerator().getType()))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "The generator has to be None because only global default "
|
op, "The generator has to be None because only global default "
|
||||||
"generator is supported");
|
"generator is supported");
|
||||||
|
@ -4665,7 +4651,7 @@ public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(AtenExponentialOp op,
|
LogicalResult matchAndRewrite(AtenExponentialOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
if (!isa<Torch::NoneType>(op.getGenerator().getType()))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "The generator has to be None because only global default "
|
op, "The generator has to be None because only global default "
|
||||||
"generator is supported");
|
"generator is supported");
|
||||||
|
@ -4706,7 +4692,7 @@ public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(AtenNormalFunctionalOp op,
|
LogicalResult matchAndRewrite(AtenNormalFunctionalOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
if (!isa<Torch::NoneType>(op.getGenerator().getType()))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "The generator has to be None because only global default "
|
op, "The generator has to be None because only global default "
|
||||||
"generator is supported");
|
"generator is supported");
|
||||||
|
@ -4984,10 +4970,10 @@ class DecomposeAtenNativeLayerNormOp
|
||||||
|
|
||||||
Value weight = op.getWeight();
|
Value weight = op.getWeight();
|
||||||
Value bias = op.getBias();
|
Value bias = op.getBias();
|
||||||
if (!weight.getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(weight.getType())) {
|
||||||
out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out, weight);
|
out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out, weight);
|
||||||
}
|
}
|
||||||
if (!bias.getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(bias.getType())) {
|
||||||
out =
|
out =
|
||||||
rewriter.create<AtenAddTensorOp>(loc, out.getType(), out, bias, one);
|
rewriter.create<AtenAddTensorOp>(loc, out.getType(), out, bias, one);
|
||||||
}
|
}
|
||||||
|
@ -5238,13 +5224,13 @@ class DecomposeAtenNativeGroupNormOp
|
||||||
loc, ListType::get(IntType::get(context)), viewShape);
|
loc, ListType::get(IntType::get(context)), viewShape);
|
||||||
|
|
||||||
Value groupNormOutput = reshapedOutput;
|
Value groupNormOutput = reshapedOutput;
|
||||||
if (!weight.getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(weight.getType())) {
|
||||||
auto weightReshaped = rewriter.create<AtenViewOp>(
|
auto weightReshaped = rewriter.create<AtenViewOp>(
|
||||||
loc, baseType, weight, /*shape=*/viewShapeSizeList);
|
loc, baseType, weight, /*shape=*/viewShapeSizeList);
|
||||||
groupNormOutput = rewriter.create<AtenMulTensorOp>(
|
groupNormOutput = rewriter.create<AtenMulTensorOp>(
|
||||||
loc, inputType, groupNormOutput, weightReshaped);
|
loc, inputType, groupNormOutput, weightReshaped);
|
||||||
}
|
}
|
||||||
if (!bias.getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(bias.getType())) {
|
||||||
auto biasReshaped = rewriter.create<AtenViewOp>(
|
auto biasReshaped = rewriter.create<AtenViewOp>(
|
||||||
loc, baseType, bias, /*shape=*/viewShapeSizeList);
|
loc, baseType, bias, /*shape=*/viewShapeSizeList);
|
||||||
groupNormOutput = rewriter.create<AtenAddTensorOp>(
|
groupNormOutput = rewriter.create<AtenAddTensorOp>(
|
||||||
|
@ -5297,8 +5283,8 @@ class DecomposeAtenNativeBatchNormOp
|
||||||
|
|
||||||
// In the inference mode, the `runningMean` and `runningVar` must not be
|
// In the inference mode, the `runningMean` and `runningVar` must not be
|
||||||
// None.
|
// None.
|
||||||
if (runningMean.getType().isa<Torch::NoneType>() ||
|
if (isa<Torch::NoneType>(runningMean.getType()) ||
|
||||||
runningVar.getType().isa<Torch::NoneType>())
|
isa<Torch::NoneType>(runningVar.getType()))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "running stats must not be None in inference mode");
|
op, "running stats must not be None in inference mode");
|
||||||
|
|
||||||
|
@ -5354,7 +5340,7 @@ class DecomposeAtenNativeBatchNormOp
|
||||||
// 2. bias = bias.view(1, C, 1?, 1?, 1?)
|
// 2. bias = bias.view(1, C, 1?, 1?, 1?)
|
||||||
// 3. output = normalizedInput * weight + bias
|
// 3. output = normalizedInput * weight + bias
|
||||||
Value batchNormOutput = normalizedInput;
|
Value batchNormOutput = normalizedInput;
|
||||||
if (!weight.getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(weight.getType())) {
|
||||||
// Rank of `weight` must be exactly 1.
|
// Rank of `weight` must be exactly 1.
|
||||||
std::optional<unsigned> weightRank = getTensorRank(weight);
|
std::optional<unsigned> weightRank = getTensorRank(weight);
|
||||||
if (!weightRank || *weightRank != 1)
|
if (!weightRank || *weightRank != 1)
|
||||||
|
@ -5364,7 +5350,7 @@ class DecomposeAtenNativeBatchNormOp
|
||||||
batchNormOutput = rewriter.create<AtenMulTensorOp>(
|
batchNormOutput = rewriter.create<AtenMulTensorOp>(
|
||||||
loc, batchNormOutput.getType(), batchNormOutput, weight);
|
loc, batchNormOutput.getType(), batchNormOutput, weight);
|
||||||
}
|
}
|
||||||
if (!bias.getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(bias.getType())) {
|
||||||
// Rank of `bias` must be exactly 1.
|
// Rank of `bias` must be exactly 1.
|
||||||
std::optional<unsigned> biasRank = getTensorRank(bias);
|
std::optional<unsigned> biasRank = getTensorRank(bias);
|
||||||
if (!biasRank || *biasRank != 1)
|
if (!biasRank || *biasRank != 1)
|
||||||
|
@ -5444,7 +5430,7 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
|
||||||
LogicalResult matchAndRewrite(OpTy op,
|
LogicalResult matchAndRewrite(OpTy op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Value dtype = op.getDtype();
|
Value dtype = op.getDtype();
|
||||||
if (dtype.getType().isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(dtype.getType())) {
|
||||||
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
|
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
|
||||||
if (!tensorType.hasDtype()) {
|
if (!tensorType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -5518,7 +5504,7 @@ public:
|
||||||
return transposeWeight;
|
return transposeWeight;
|
||||||
};
|
};
|
||||||
|
|
||||||
if (bias.getType().isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(bias.getType())) {
|
||||||
auto weightRank = weightType.getSizes().size();
|
auto weightRank = weightType.getSizes().size();
|
||||||
if (weightRank > 2 || weightRank <= 0)
|
if (weightRank > 2 || weightRank <= 0)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -5622,7 +5608,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenNewFullOp op,
|
LogicalResult matchAndRewrite(AtenNewFullOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Value dtype = op.getDtype();
|
Value dtype = op.getDtype();
|
||||||
if (dtype.getType().isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(dtype.getType())) {
|
||||||
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
|
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
|
||||||
if (!tensorType.hasDtype()) {
|
if (!tensorType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -5718,7 +5704,7 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
||||||
Value dtype = op.getDtype();
|
Value dtype = op.getDtype();
|
||||||
if (dtype.getType().isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(dtype.getType())) {
|
||||||
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
|
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
|
||||||
if (!tensorType.hasDtype()) {
|
if (!tensorType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@@ -5743,9 +5729,9 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
                                 PatternRewriter &rewriter) const override {
 
     Value value = op.getValue();
-    if (value.getType().isa<Torch::OptionalType>())
+    if (isa<Torch::OptionalType>(value.getType()))
       return rewriter.notifyMatchFailure(op, "optional type not supported");
-    if (value.getType().isa<Torch::NoneType>())
+    if (isa<Torch::NoneType>(value.getType()))
       value = rewriter.create<Torch::ConstantFloatOp>(
           op.getLoc(), rewriter.getF64FloatAttr(0));
 
@ -5765,7 +5751,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenToDtypeLayoutOp op,
|
LogicalResult matchAndRewrite(AtenToDtypeLayoutOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
// TODO: Add support for pinMemory arg equal to `True`.
|
// TODO: Add support for pinMemory arg equal to `True`.
|
||||||
if (!op.getPinMemory().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getPinMemory().getType())) {
|
||||||
bool pinMemory;
|
bool pinMemory;
|
||||||
if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)))
|
if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -5776,7 +5762,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Add support for device arg other than cpu.
|
// TODO: Add support for device arg other than cpu.
|
||||||
if (!op.getDevice().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getDevice().getType())) {
|
||||||
std::string device;
|
std::string device;
|
||||||
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
|
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -5788,7 +5774,7 @@ public:
|
||||||
|
|
||||||
// TODO: Add support for non-strided layout.
|
// TODO: Add support for non-strided layout.
|
||||||
// torch.layout is by default strided i.e. 0.
|
// torch.layout is by default strided i.e. 0.
|
||||||
if (!op.getLayout().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getLayout().getType())) {
|
||||||
int64_t tensorLayout;
|
int64_t tensorLayout;
|
||||||
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -6254,7 +6240,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
||||||
Type newOutputType = outputTensorType.getWithSizesAndDtype(
|
Type newOutputType = outputTensorType.getWithSizesAndDtype(
|
||||||
outputTensorType.getSizes(), rewriter.getF64Type());
|
outputTensorType.getSizes(), rewriter.getF64Type());
|
||||||
if (!inputTensorTy.hasDtype() ||
|
if (!inputTensorTy.hasDtype() ||
|
||||||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
|
!isa<mlir::FloatType>(inputTensorTy.getDtype())) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "support floating-point type input only");
|
op, "support floating-point type input only");
|
||||||
}
|
}
|
||||||
|
@ -6391,14 +6377,14 @@ public:
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
int64_t correctionValInt;
|
int64_t correctionValInt;
|
||||||
double correctionValFloat = 1.0;
|
double correctionValFloat = 1.0;
|
||||||
if (!op.getCorrection().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getCorrection().getType())) {
|
||||||
if (op.getCorrection().getType().isa<Torch::FloatType>()) {
|
if (isa<Torch::FloatType>(op.getCorrection().getType())) {
|
||||||
if (!matchPattern(op.getCorrection(),
|
if (!matchPattern(op.getCorrection(),
|
||||||
m_TorchConstantFloat(&correctionValFloat)))
|
m_TorchConstantFloat(&correctionValFloat)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only support constant int or float correction value for "
|
op, "Only support constant int or float correction value for "
|
||||||
"aten.var");
|
"aten.var");
|
||||||
} else if (op.getCorrection().getType().isa<Torch::IntType>()) {
|
} else if (isa<Torch::IntType>(op.getCorrection().getType())) {
|
||||||
if (!matchPattern(op.getCorrection(),
|
if (!matchPattern(op.getCorrection(),
|
||||||
m_TorchConstantInt(&correctionValInt)))
|
m_TorchConstantInt(&correctionValInt)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -6525,11 +6511,9 @@ public:
|
||||||
if (!inputType.hasSizes())
|
if (!inputType.hasSizes())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Expected the input tensor to have sizes");
|
op, "Expected the input tensor to have sizes");
|
||||||
BaseTensorType subType =
|
BaseTensorType subType = cast<BaseTensorType>(
|
||||||
inputType
|
inputType.getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()),
|
||||||
.getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()),
|
resultType.getOptionalDtype()));
|
||||||
resultType.getOptionalDtype())
|
|
||||||
.cast<BaseTensorType>();
|
|
||||||
|
|
||||||
Value sub =
|
Value sub =
|
||||||
createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
|
createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
|
||||||
|
@ -6566,7 +6550,7 @@ public:
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||||
Value ord = op.getP();
|
Value ord = op.getP();
|
||||||
if (ord.getType().isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(ord.getType())) {
|
||||||
ord = rewriter.create<Torch::ConstantFloatOp>(
|
ord = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
loc, rewriter.getF64FloatAttr(2.0));
|
loc, rewriter.getF64FloatAttr(2.0));
|
||||||
}
|
}
|
||||||
|
@ -6609,10 +6593,8 @@ public:
|
||||||
loc, rewriter.getF64FloatAttr((double)cstHigh));
|
loc, rewriter.getF64FloatAttr((double)cstHigh));
|
||||||
|
|
||||||
BaseTensorType floatResultType =
|
BaseTensorType floatResultType =
|
||||||
resultTensorType
|
cast<BaseTensorType>(resultTensorType.getWithSizesAndDtype(
|
||||||
.getWithSizesAndDtype(resultTensorType.getSizes(),
|
resultTensorType.getSizes(), rewriter.getF32Type()));
|
||||||
rewriter.getF32Type())
|
|
||||||
.cast<BaseTensorType>();
|
|
||||||
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
||||||
loc, floatResultType, op.getSize(), /*dtype=*/none,
|
loc, floatResultType, op.getSize(), /*dtype=*/none,
|
||||||
/*layout=*/op.getLayout(),
|
/*layout=*/op.getLayout(),
|
||||||
|
@ -6704,7 +6686,7 @@ public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(PrimsVarOp op,
|
LogicalResult matchAndRewrite(PrimsVarOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
if (!op.getOutputDtype().getType().isa<Torch::NoneType>())
|
if (!isa<Torch::NoneType>(op.getOutputDtype().getType()))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Unimplemented non-None dtype for prims::var op");
|
op, "Unimplemented non-None dtype for prims::var op");
|
||||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
||||||
|
@ -6816,7 +6798,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenRandnLikeOp op,
|
LogicalResult matchAndRewrite(AtenRandnLikeOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
// Only `none`, `contiguous` and `preserve` memory_format is supported.
|
// Only `none`, `contiguous` and `preserve` memory_format is supported.
|
||||||
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
|
||||||
int64_t memoryFormat;
|
int64_t memoryFormat;
|
||||||
if (!matchPattern(op.getMemoryFormat(),
|
if (!matchPattern(op.getMemoryFormat(),
|
||||||
m_TorchConstantInt(&memoryFormat)))
|
m_TorchConstantInt(&memoryFormat)))
|
||||||
|
@ -6913,8 +6895,8 @@ public:
|
||||||
op.getDevice(), op.getPinMemory());
|
op.getDevice(), op.getPinMemory());
|
||||||
// calculate (end - start) / (steps - 1)
|
// calculate (end - start) / (steps - 1)
|
||||||
Value sub;
|
Value sub;
|
||||||
if (op.getEnd().getType().isa<Torch::FloatType>() ||
|
if (isa<Torch::FloatType>(op.getEnd().getType()) ||
|
||||||
op.getStart().getType().isa<Torch::FloatType>()) {
|
isa<Torch::FloatType>(op.getStart().getType())) {
|
||||||
sub = rewriter.create<AtenSubOp>(loc, Torch::FloatType::get(context),
|
sub = rewriter.create<AtenSubOp>(loc, Torch::FloatType::get(context),
|
||||||
op.getEnd(), op.getStart());
|
op.getEnd(), op.getStart());
|
||||||
} else {
|
} else {
|
||||||
|
@ -6930,7 +6912,7 @@ public:
|
||||||
}
|
}
|
||||||
// to dtype
|
// to dtype
|
||||||
Value result;
|
Value result;
|
||||||
if (!op.getDtype().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getDtype().getType())) {
|
||||||
result = rewriter.create<AtenToDtypeOp>(
|
result = rewriter.create<AtenToDtypeOp>(
|
||||||
loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal,
|
loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal,
|
||||||
/*copy=*/falseVal, /*memory_format=*/none);
|
/*copy=*/falseVal, /*memory_format=*/none);
|
||||||
|
@@ -7344,11 +7326,8 @@ public:
 
     auto selfType = cast<BaseTensorType>(self.getType());
     auto indexType = cast<BaseTensorType>(index.getType());
-    BaseTensorType srcType =
-        selfType
-            .getWithSizesAndDtype(indexType.getOptionalSizes(),
-                                  selfType.getOptionalDtype())
-            .cast<BaseTensorType>();
+    BaseTensorType srcType = cast<BaseTensorType>(selfType.getWithSizesAndDtype(
+        indexType.getOptionalSizes(), selfType.getOptionalDtype()));
     Value src =
         createInitTensor(rewriter, loc, srcType, op.getValue(), sizeList);
     rewriter.replaceOpWithNewOp<AtenScatterSrcOp>(op, op.getType(), self,
|
@ -7372,7 +7351,7 @@ public:
|
||||||
"expected result type to have dtype");
|
"expected result type to have dtype");
|
||||||
}
|
}
|
||||||
// TODO: support complex type in future.
|
// TODO: support complex type in future.
|
||||||
if (outType.getDtype().isa<mlir::ComplexType>()) {
|
if (isa<mlir::ComplexType>(outType.getDtype())) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"doesn't support complex type now");
|
"doesn't support complex type now");
|
||||||
}
|
}
|
||||||
|
@ -7488,7 +7467,7 @@ static FailureOr<Value> createNewIndices(Operation *op,
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
MLIRContext *context = op->getContext();
|
MLIRContext *context = op->getContext();
|
||||||
|
|
||||||
auto inputType = input.getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasSizes()) {
|
if (!inputType.hasSizes()) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
@ -7497,7 +7476,7 @@ static FailureOr<Value> createNewIndices(Operation *op,
|
||||||
|
|
||||||
int64_t maxIndexRank = 0;
|
int64_t maxIndexRank = 0;
|
||||||
for (auto index : oldIndices) {
|
for (auto index : oldIndices) {
|
||||||
auto indexType = index.getType().dyn_cast<BaseTensorType>();
|
auto indexType = dyn_cast<BaseTensorType>(index.getType());
|
||||||
if (!indexType) // None index
|
if (!indexType) // None index
|
||||||
continue;
|
continue;
|
||||||
if (!indexType.hasSizes())
|
if (!indexType.hasSizes())
|
||||||
|
@ -7586,15 +7565,13 @@ public:
|
||||||
int64_t inputRank = inputSizes.size();
|
int64_t inputRank = inputSizes.size();
|
||||||
|
|
||||||
auto isTensor = [](Value v) {
|
auto isTensor = [](Value v) {
|
||||||
return v.getType().isa<Torch::BaseTensorType>();
|
return isa<Torch::BaseTensorType>(v.getType());
|
||||||
};
|
};
|
||||||
|
|
||||||
// directly replace aten.Index.Tensor with aten.index.Tensor_hacked_twin
|
// directly replace aten.Index.Tensor with aten.index.Tensor_hacked_twin
|
||||||
if (llvm::all_of(indices, isTensor)) {
|
if (llvm::all_of(indices, isTensor)) {
|
||||||
// By default, we regard the first index type as the list element type.
|
// By default, we regard the first index type as the list element type.
|
||||||
auto indexElemType = indices[0]
|
auto indexElemType = cast<BaseTensorType>(indices[0].getType())
|
||||||
.getType()
|
|
||||||
.template cast<BaseTensorType>()
|
|
||||||
.getWithSizesAndDtype(std::nullopt, nullptr);
|
.getWithSizesAndDtype(std::nullopt, nullptr);
|
||||||
auto newIndices = rewriter.create<PrimListConstructOp>(
|
auto newIndices = rewriter.create<PrimListConstructOp>(
|
||||||
loc, Torch::ListType::get(indexElemType), indices);
|
loc, Torch::ListType::get(indexElemType), indices);
|
||||||
|
@ -7684,7 +7661,7 @@ public:
|
||||||
"failed to get elements of `indices`");
|
"failed to get elements of `indices`");
|
||||||
|
|
||||||
auto input = op.getSelf();
|
auto input = op.getSelf();
|
||||||
auto inputType = input.getType().template cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasSizes()) {
|
if (!inputType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only input with shape information is supported");
|
op, "only input with shape information is supported");
|
||||||
|
@@ -7693,15 +7670,13 @@ public:
     int64_t inputRank = inputSizes.size();

     auto isTensor = [](Value v) {
-      return v.getType().isa<Torch::BaseTensorType>();
+      return isa<Torch::BaseTensorType>(v.getType());
     };

     // directly replace current op with aten.index_put.hacked_twin
     if (llvm::all_of(indices, isTensor)) {
       // By default, we regard the first index type as the list element type.
-      auto indexElemType = indices[0]
-                               .getType()
-                               .template cast<BaseTensorType>()
-                               .getWithSizesAndDtype(std::nullopt, nullptr);
+      auto indexElemType = cast<BaseTensorType>(indices[0].getType())
+                               .getWithSizesAndDtype(std::nullopt, nullptr);
       auto newIndex = rewriter.create<PrimListConstructOp>(
           loc, Torch::ListType::get(indexElemType), indices);
@@ -7831,7 +7806,7 @@ public:

     // default ord value is 2 for vector_norm
     auto ord = op.getOrd();
-    if (ord.getType().isa<Torch::NoneType>()) {
+    if (isa<Torch::NoneType>(ord.getType())) {
       ord = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));
     }
     rewriter.replaceOpWithNewOp<Torch::AtenLinalgVectorNormOp>(
@@ -63,8 +63,8 @@ public:
 };

 static bool isTypeTriviallySafe(Type type) {
-  return type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
-                  Torch::StringType, Torch::NoneType, Torch::ValueTensorType>();
+  return isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
+             Torch::StringType, Torch::NoneType, Torch::ValueTensorType>(type);
 }

 static bool isUseTreatedWithValueSemantics(OpOperand &use) {
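A note for readers applying the same migration elsewhere: the free-function form of `isa<>` is variadic, which is the shape used by `isTypeTriviallySafe` above and `checkType` below. The snippet that follows is not part of this commit; it is a minimal sketch of the call pattern using builtin MLIR types in place of the Torch dialect types, and the helper name `isScalarLike` is hypothetical.

// Minimal sketch (not from this commit): variadic free-function isa<> on an
// mlir::Type, equivalent to the deprecated
// type.isa<IntegerType, FloatType, IndexType>().
#include "mlir/IR/BuiltinTypes.h"

static bool isScalarLike(mlir::Type type) {
  // Returns true if `type` is any one of the listed candidates.
  return mlir::isa<mlir::IntegerType, mlir::FloatType, mlir::IndexType>(type);
}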
@@ -36,8 +36,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
 static LogicalResult checkType(Operation *op, Type type,
                                bool actuallyEmitDiagnostics) {
   // Allow various scalar types that backends are expected to be able to handle.
-  if (type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
-               Torch::DeviceType>())
+  if (isa<Torch::IntType, Torch::FloatType, Torch::BoolType, Torch::DeviceType>(
+          type))
     return success();

   // Backends are not expected to support dynamic computations on these types,
@@ -187,7 +187,7 @@ public:
       auto it = originalReturnTypes.find(i);
       if (it == originalReturnTypes.end())
         continue;
-      auto originalType = it->second.cast<NonValueTensorType>();
+      auto originalType = cast<NonValueTensorType>(it->second);
       rewriter.setInsertionPoint(returnOp);
       Value newReturnValue = copyTensorToType(rewriter, returnOp->getLoc(),
                                               originalType, operand.get());
@@ -350,7 +350,7 @@ public:
       auto it = originalTypes.find(operand.get());
       if (it == originalTypes.end())
         continue;
-      auto originalType = it->second.cast<BaseTensorType>();
+      auto originalType = cast<BaseTensorType>(it->second);
       rewriter.setInsertionPoint(op);
       Value newReturnValue = copyTensorToType(rewriter, op->getLoc(),
                                               originalType, operand.get());
@@ -118,7 +118,7 @@ public:
     if (auto optionalType =
             dyn_cast<OptionalType>(listType.getContainedType())) {
       if (!llvm::all_of(listConstruct.getElements(), [](Value val) {
-            return val.getType().isa<NonValueTensorType, Torch::NoneType>();
+            return isa<NonValueTensorType, Torch::NoneType>(val.getType());
           })) {
         rewriter.cancelOpModification(op);
         return rewriter.notifyMatchFailure(
@@ -81,7 +81,7 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
   if (name.starts_with("valsem."))
     name = name.drop_front(strlen("valsem."));
   if (isa<OperatorOp>(op))
-    name = cast<OperatorOp>(op)->getAttr("name").cast<StringAttr>().getValue();
+    name = cast<StringAttr>(cast<OperatorOp>(op)->getAttr("name")).getValue();
   std::string libFuncName =
       (getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str();
   auto libFunc = library.lookupSymbol<func::FuncOp>(libFuncName);
@@ -191,8 +191,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
   // to match the library function signature.
   if (auto unionType = dyn_cast<Torch::UnionType>(desiredType)) {
     if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
-          return containedType
-              .isa<Torch::IntType, Torch::FloatType, Torch::NoneType>();
+          return isa<Torch::IntType, Torch::FloatType, Torch::NoneType>(
+              containedType);
         }))
       return b.create<DerefineOp>(loc, desiredType, operand).getResult();
   }
@@ -179,11 +179,10 @@ public:
           "should have concrete Scalar Type.");
     }
     Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType());
-    auto impliedTypeFromInputType =
+    auto impliedTypeFromInputType = cast<BaseTensorType>(
         cast<BaseTensorType>(originalResultType)
            .getWithSizesAndDtype(originalResultType.getOptionalSizes(),
-                                 inputType)
-            .cast<BaseTensorType>();
+                                 inputType));

     op.getResult().setType(impliedTypeFromInputType);
     return success();
@@ -97,11 +97,10 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
   }

   auto originalResultType = cast<BaseTensorType>(result.getType());
-  auto impliedTypesFromShape =
+  auto impliedTypesFromShape = cast<BaseTensorType>(
       cast<BaseTensorType>(originalResultType)
          .getWithSizesAndDtype(ArrayRef(sizes),
-                               originalResultType.getOptionalDtype())
-          .cast<BaseTensorType>();
+                               originalResultType.getOptionalDtype()));

   return updateCalculateOpResultTypes(op, resultNum, impliedTypesFromShape,
                                       rewriter);
@@ -74,7 +74,7 @@ LogicalResult FromBuiltinTensorOp::verify() {
 //===----------------------------------------------------------------------===//

 OpFoldResult FromI1Op::fold(FoldAdaptor adaptor) {
-  auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::BoolAttr>();
+  auto attr = dyn_cast_or_null<mlir::BoolAttr>(adaptor.getOperand());
   if (attr) {
     return attr;
   } else {
@@ -87,7 +87,7 @@ OpFoldResult FromI1Op::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//

 OpFoldResult ToI1Op::fold(FoldAdaptor adaptor) {
-  auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::BoolAttr>();
+  auto attr = dyn_cast_or_null<mlir::BoolAttr>(adaptor.getOperand());
   if (attr) {
     return attr;
   } else {
@@ -100,7 +100,7 @@ OpFoldResult ToI1Op::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//

 OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) {
-  auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>();
+  auto attr = dyn_cast_or_null<mlir::IntegerAttr>(adaptor.getOperand());
   if (attr) {
     return attr;
   } else {
@@ -113,7 +113,7 @@ OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//

 OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) {
-  auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>();
+  auto attr = dyn_cast_or_null<mlir::IntegerAttr>(adaptor.getOperand());
   if (attr) {
     return attr;
   } else {
@@ -126,7 +126,7 @@ OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//

 OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) {
-  auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>();
+  auto attr = dyn_cast_or_null<mlir::FloatAttr>(adaptor.getOperand());
   if (attr) {
     return attr;
   } else {
@@ -139,7 +139,7 @@ OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//

 OpFoldResult FromF64Op::fold(FoldAdaptor adaptor) {
-  auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>();
+  auto attr = dyn_cast_or_null<mlir::FloatAttr>(adaptor.getOperand());
   if (attr) {
     return attr;
   } else {
@@ -91,7 +91,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
           return std::nullopt;
         // Other input type to be converted to i64 are handled by other
         // materializers.
-        if (!inputs[0].getType().isa<Torch::IntType>())
+        if (!isa<Torch::IntType>(inputs[0].getType()))
           return std::nullopt;
         assert(inputs.size() == 1);
         return builder.create<ToI64Op>(loc, inputs[0]).getResult();
@@ -145,7 +145,7 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
           return std::nullopt;
         // Other input type to be converted to i64 are handled by other
         // materializers.
-        if (!inputs[0].getType().isa<Torch::GeneratorType>())
+        if (!isa<Torch::GeneratorType>(inputs[0].getType()))
          return std::nullopt;
        assert(inputs.size() == 1);
        return builder.create<GeneratorToI64Op>(loc, inputs[0]).getResult();
@@ -56,7 +56,7 @@ void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); }
 static bool isArgMemRefTypeValid(Type type) {
   if (auto memRefType = dyn_cast<MemRefType>(type)) {
     Type elemTy = memRefType.getElementType();
-    if (elemTy.isa<Float16Type, Float32Type, Float64Type>()) {
+    if (isa<Float16Type, Float32Type, Float64Type>(elemTy)) {
       return true;
     } else if (auto integerTy = dyn_cast<IntegerType>(elemTy)) {
       if (integerTy.isSignlessInteger(64))
@@ -70,7 +70,7 @@ static bool isArgMemRefTypeValid(Type type) {
       if (integerTy.isSignlessInteger(1))
         return true;
     } else if (auto complexTy = dyn_cast<ComplexType>(elemTy)) {
-      return complexTy.getElementType().isa<Float32Type, Float64Type>();
+      return isa<Float32Type, Float64Type>(complexTy.getElementType());
     }
   }
   return false;
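The fold() hunks above rely on the null-tolerant variant of the free-function cast. The sketch below is not part of this commit; it is a minimal illustration of how the three forms used throughout this change behave, and the helper name intValueOrZero is hypothetical (mlir::Attribute and mlir::IntegerAttr are the real MLIR types).

// Minimal sketch (not from this commit): dyn_cast_or_null<> tolerates a null
// attribute (as when adaptor.getOperand() is not a constant), dyn_cast<>
// requires a non-null input and returns null on a type mismatch, and cast<>
// additionally asserts that the type matches.
#include "mlir/IR/BuiltinAttributes.h"

static int64_t intValueOrZero(mlir::Attribute attr) {
  if (auto intAttr = mlir::dyn_cast_or_null<mlir::IntegerAttr>(attr))
    return intAttr.getInt();
  return 0;
}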