New ops support & enhancements (#1494)

* New ops support & enhancements

* Enabled previously xfailed LTC tests
Gleb Kazantaev 2022-10-14 10:28:21 -04:00 committed by GitHub
parent 7df9179f85
commit bdb5083d33
7 changed files with 57 additions and 43 deletions


@@ -5,7 +5,6 @@ blacklist:
- index.Tensor # Error: TODO not sure if there are other valid types to handle here
- index_put # Error: TODO not sure if there are other valid types to handle here
- index_put_ # Error: TODO not sure if there are other valid types to handle here
- stack # Error: TODO not sure if there are other valid types to handle here
# Additional ops for which autogen is supported but which don't compile yet
- _convolution
@@ -20,6 +19,10 @@ blacklist:
- rsub
- slice.Tensor # Disabled in favour of slice_copy.Tensor
- zeros
- ones
- arange
- arange.start
- arange.start_step
# Disabled in favour of functionalized alternatives
- _reshape_alias
@@ -33,7 +36,10 @@ blacklist:
- unsqueeze
- view
# whitelist:
whitelist:
# Enabled for consistency with TS backend
- arange.start_out
# List of ops to autogen even if not supported by Torch-MLIR explicitly
#- split_copy.Tensor
#- split_with_sizes_copy
@@ -74,7 +80,6 @@ supported:
symint:
- empty.memory_format
- new_empty_strided
- expand
- expand_copy
- narrow_copy
- slice_backward
@@ -98,6 +103,7 @@ non_native:
properties:
- ShapeCompute
- TreatScalarsAsConstants
- func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
- func: cast(Tensor input, ScalarType dtype, ScalarType? stype) -> Tensor
opkind: ltc_cast
properties:
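
Taken together, the config hunks above adjust the LTC autogen lists: ones and the arange overloads join the blacklist, arange.start_out is whitelisted for consistency with the TS backend, and expand is now declared as a non_native node whose signature carries an explicit is_scalar_expand flag. As a rough illustration of how the blacklist and whitelist sections interact, a minimal sketch (a hypothetical helper, not the actual torch-mlir generator) could look like this:

import yaml

def select_ops(candidate_ops, config_text):
    # Hypothetical filter: drop blacklisted ops, then force-include whitelisted
    # ones, so arange.start_out survives even though arange.start is blacklisted.
    config = yaml.safe_load(config_text)
    blacklist = set(config.get("blacklist") or [])
    whitelist = set(config.get("whitelist") or [])
    return sorted((set(candidate_ops) - blacklist) | whitelist)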


@@ -502,27 +502,7 @@ LTC_XFAIL_SET = {
"DropoutTrainModule_basic",
"ElementwiseAtenFloorDivideBroadcastModule_basic",
"ElementwiseAtenFloorDivideModule_basic",
"ElementwiseWhereScalarModule_basic",
"ElementwiseWhereScalarOtherModule_basic",
"ElementwiseWhereScalarSelfModule_basic",
"ElementwiseWhereSelfModule_basic",
"EmptyLikeMemoryFormatModule_basic",
"EmptyLikeModule_defaultDtype",
"EmptyLikeModule_falsePinMemory",
"EmptyLikeModule_float",
"EmptyLikeModule_int",
"EqIntModule_basic",
"Fill_TensorFloat64WithFloat32_basic",
"Fill_TensorFloat64WithFloat64_basic",
"Fill_TensorFloat64WithInt64_basic",
"FullLikeModuleDefaultDtype_basic",
"FullLikeModuleFalsePinMemory_basic",
"FullLikeModuleFloat2D_basic",
"FullLikeModuleFloat3DStatic_basic",
"FullLikeModuleFloat3D_basic",
"FullLikeModuleInt2DStatic_basic",
"FullLikeModuleInt2D_basic",
"FullLikeModuleInt3D_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
"GeIntModule_basic",
@@ -582,25 +562,6 @@ LTC_XFAIL_SET = {
"DivIntModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
"NewEmptyModuleDefaultDtype_basic",
"NewEmptyModuleFalsePinMemory_basic",
"NewEmptyModuleFloat2D_basic",
"NewEmptyModuleFloat3D_basic",
"NewEmptyModuleInt2D_basic",
"NewEmptyModuleInt3D_basic",
"NewEmptyModuleLayoutIntDtype_basic",
"NewEmptyModuleNonDefaultFloatDtype_basic",
"NewEmptyModuleNonDefaultIntDtype_basic",
"NewOnesModuleDefaultDtype_basic",
"NewOnesModuleFalsePinMemory_basic",
"NewOnesModuleFloat2D_basic",
"NewOnesModuleFloat3D_basic",
"NewOnesModuleInt2D_basic",
"NewOnesModuleInt3D_basic",
"OnesLikeModule_defaultDtype",
"OnesLikeModule_falsePinMemory",
"OnesLikeModule_float",
"OnesLikeModule_int",
"QuantizedMLP_basic",
"RandLikeDtypeModule_basic",
"RandLikeModule_basic",


@@ -6711,6 +6711,29 @@ def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [
}];
}
def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::FloatImplicit : (Tensor) -> (float)`";
let arguments = (ins
AnyTorchTensorType:$a
);
let results = (outs
Torch_FloatType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFloatImplicitOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenFloatImplicitOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [
AllowsTypeRefinement,
HasValueSemantics,


@@ -33,7 +33,7 @@ namespace lazy {
struct TorchMlirIrBuilder : IrBuilder {
NodePtr MakeDeviceData(const std::shared_ptr<BackendData>& data) const override { return MakeNode<DeviceData>(data); }
NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type) const override { return MakeNode<Scalar>(value, type); }
NodePtr MakeExpand(const Value& input0, const std::vector<int64_t>& size, const bool& is_scalar_expand) const override { UNIMPLEMENTED_FUNCTION_ERROR(); }
NodePtr MakeExpand(const Value& input0, const std::vector<int64_t>& size, const bool& is_scalar_expand) const override { return MakeNode<Expand>(input0, size, is_scalar_expand); }
NodePtr MakeView(const Value& input0, const std::vector<int64_t>& output_size) const override { UNIMPLEMENTED_FUNCTION_ERROR(); }
NodePtr MakeCast(const Value& input0, const at::ScalarType& dtype, const c10::optional<at::ScalarType>& stype = c10::nullopt) const override { return MakeNode<Cast>(input0, dtype, stype); }
NodePtr MakeTensorList(const OpList& inputs) const override { return MakeNode<TorchMlirTensorList>(inputs); }


@@ -284,5 +284,22 @@ TorchMlirOpVector Scalar::Lower(
return {loctx->graph()->insertConstant(at::scalar_tensor(value, options))};
}
TorchMlirOpVector Expand::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
arguments.emplace_back(size);
auto expand_out = LowerBuiltin(this, function, arguments);
if (is_scalar_expand) {
// The aten::expand operation sets all strides to 0 when the original is
// of rank 0. This leads to false positives when checking for internal
// memory overlap, because at::has_internal_overlap returns
// MemOverlap::YES when a stride is set to 0.
TORCH_CHECK_EQ(expand_out.size(), 1);
return {GenerateClone(expand_out.front(), function)};
}
return expand_out;
}
} // namespace lazy
} // namespace torch
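
The is_scalar_expand branch above is easy to motivate in eager PyTorch: expanding a rank-0 tensor yields a view whose strides are all zero, and in-place kernels reject such views as internal memory overlap, which is why the lowering clones the result instead of returning the aliasing view. A small eager-mode demonstration (no LTC backend required):

import torch

scalar = torch.tensor(1.0)   # rank-0 tensor
view = scalar.expand(2, 3)   # aliasing view: every element shares one storage slot
print(view.stride())         # (0, 0)

try:
    view.add_(1)             # in-place write into the overlapping view
except RuntimeError as err:
    # PyTorch refuses the write because several output elements map to the same
    # memory location; cloning, as Expand::Lower does, avoids this failure mode.
    print("rejected:", err)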


@@ -42,6 +42,12 @@ std::vector<torch::lazy::Shape> compute_shape_hardtanh(
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_where(
const at::Tensor & condition,
const at::Tensor & self,
const at::Tensor & other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
} // namespace lazy
} // namespace torch
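
The new shape function reports the result of aten::where.self with self's dtype and sizes, which matches eager behaviour whenever condition, self and other already share a common shape (inputs that still require broadcasting beyond self's shape are not covered by this helper). A quick eager-mode check:

import torch

cond = torch.rand(4, 5) > 0.5
a = torch.ones(4, 5)
b = torch.zeros(4, 5)
out = torch.where(cond, a, b)
print(out.shape, out.dtype)  # torch.Size([4, 5]) torch.float32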


@@ -488,6 +488,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
emit("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)")
emit("aten::IntImplicit : (Tensor) -> (int)")
emit("aten::FloatImplicit : (Tensor) -> (float)")
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)