[TORCH] The torch definition related to aten.gelu has changed.

New str argument approximation is added.
pull/602/head
Prashant Kumar 2022-02-17 15:37:14 +00:00
parent ed9bd556b3
commit abbde7d439
4 changed files with 57 additions and 23 deletions

View File

@ -1342,14 +1342,15 @@ def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::gelu : (Tensor) -> (Tensor)`";
let summary = "Generated op for `aten::gelu : (Tensor, str) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
AnyTorchTensorType:$self,
Torch_StringType:$approximate
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
let assemblyFormat = "$self `,` $approximate attr-dict `:` qualified(type($self)) `,` qualified(type($approximate)) `->` qualified(type($result))";
}
def Torch_AtenPowTensorScalarOp : Torch_Op<"aten.pow.Tensor_Scalar", [
@ -1383,6 +1384,20 @@ def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [
let assemblyFormat = "$grad_output `,` $self `,` $threshold attr-dict `:` qualified(type($grad_output)) `,` qualified(type($self)) `,` qualified(type($threshold)) `->` qualified(type($result))";
}
def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::fill_.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $value attr-dict `:` qualified(type($self)) `,` qualified(type($value)) `->` qualified(type($result))";
}
def Torch_AtenUniform_Op : Torch_Op<"aten.uniform_", [
AllowsTypeRefinement
]> {
@ -2103,20 +2118,6 @@ def Torch_AtenSizeOp : Torch_Op<"aten.size", [
let hasCanonicalizer = 1;
}
def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::fill_.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $value attr-dict `:` qualified(type($self)) `,` qualified(type($value)) `->` qualified(type($result))";
}
def Torch_AtenBoolTensorOp : Torch_Op<"aten.Bool.Tensor", [
AllowsTypeRefinement,
HasValueSemantics
@ -3830,15 +3831,16 @@ def Torch_AtenGeluBackwardOp : Torch_Op<"aten.gelu_backward", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::gelu_backward : (Tensor, Tensor) -> (Tensor)`";
let summary = "Generated op for `aten::gelu_backward : (Tensor, Tensor, str) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad,
AnyTorchTensorType:$self
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$self,
Torch_StringType:$approximate
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$grad `,` $self attr-dict `:` qualified(type($grad)) `,` qualified(type($self)) `->` qualified(type($result))";
let assemblyFormat = "$grad_output `,` $self `,` $approximate attr-dict `:` qualified(type($grad_output)) `,` qualified(type($self)) `,` qualified(type($approximate)) `->` qualified(type($result))";
}
def Torch_Aten_LogSoftmaxBackwardDataOp : Torch_Op<"aten._log_softmax_backward_data", [

View File

@ -60,6 +60,21 @@ struct torch_constant_float_op_binder {
return false;
}
};
struct torch_constant_str_op_binder {
std::string &bind_value;
/// Creates a matcher instance that binds the value to bv if match succeeds.
torch_constant_str_op_binder(std::string &bv) : bind_value(bv) {}
bool match(Operation *op) {
if (auto constantString = dyn_cast<Torch::ConstantStrOp>(op)) {
bind_value = constantString.value().str();
return true;
}
return false;
}
};
} // namespace detail
/// Matches the integer stored in a `torch.constant.bool`.
@ -74,6 +89,12 @@ m_TorchConstantFloat(double *bind_value) {
return detail::torch_constant_float_op_binder(bind_value);
}
/// Matches the string value stored in a `torch.constant.str`.
inline detail::torch_constant_str_op_binder
m_TorchConstantStr(std::string &bind_value) {
return detail::torch_constant_str_op_binder(bind_value);
}
namespace detail {
/// Matches the bool stored in a `torch.constant.bool`.
struct torch_constant_bool_op_binder {

View File

@ -1630,6 +1630,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
gelu.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
// TODO: Take approximation into account.
std::string approximate;
if (!matchPattern(gelu.approximate(), m_TorchConstantStr(approximate)) ||
approximate != "none")
return nullptr;
Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[0]);
return b.create<arith::MulFOp>(loc, payloadArgs[0], cdf);
}
@ -1641,6 +1646,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
geluBackward.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
// TODO: Take approximation into account.
std::string approximate;
if (!matchPattern(geluBackward.approximate(),
m_TorchConstantStr(approximate)) ||
approximate != "none")
return nullptr;
Type elementType = payloadArgs[1].getType();
Value cstAlpha0 = b.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 1.12837916709551257390));

View File

@ -497,7 +497,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::gelu : (Tensor) -> (Tensor)")
emit("aten::gelu : (Tensor, str) -> (Tensor)")
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
@ -693,7 +693,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
# backprop ops
emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
emit("aten::tanh_backward : (Tensor, Tensor) -> (Tensor)")
emit("aten::gelu_backward : (Tensor, Tensor) -> (Tensor)")
emit("aten::gelu_backward : (Tensor, Tensor, str) -> (Tensor)")
emit("aten::_log_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")