make (array) of ArrayObject return the contents
[hiphop-php.git] / hphp / util / db-query.cpp
blob7acfb0206aed570f9850f28eb1aef68ed782c474
1 /*
2 +----------------------------------------------------------------------+
3 | HipHop for PHP |
4 +----------------------------------------------------------------------+
5 | Copyright (c) 2010-2013 Facebook, Inc. (http://www.facebook.com) |
6 +----------------------------------------------------------------------+
7 | This source file is subject to version 3.01 of the PHP license, |
8 | that is bundled with this package in the file LICENSE, and is |
9 | available through the world-wide-web at the following url: |
10 | http://www.php.net/license/3_01.txt |
11 | If you did not receive a copy of the PHP license and are unable to |
12 | obtain it through the world-wide-web, please send a note to |
13 | license@php.net so we can mail you a copy immediately. |
14 +----------------------------------------------------------------------+
17 #include "hphp/util/db-query.h"
18 #include "hphp/util/db-conn.h"
19 #include "hphp/util/db-dataset.h"
20 #include "util.h"
22 namespace HPHP {
23 ///////////////////////////////////////////////////////////////////////////////
25 DBQuery::DBQuery(DBConn *conn, const char *sql, ...)
26 : m_conn(conn), m_insert(false) {
27 assert(sql && *sql);
28 va_list ap;
29 va_start(ap, sql);
30 Util::string_vsnprintf(m_base, sql, ap);
31 va_end(ap);
34 ///////////////////////////////////////////////////////////////////////////////
36 void DBQuery::filterBy(const char *fmt, Op op /* = And */) {
37 assert(fmt && *fmt);
39 if (m_where.empty()) {
40 m_where = " where ";
41 } else {
42 switch (op) {
43 case And: m_where += " and "; break;
44 case Or: m_where += " or "; break;
45 default: break;
48 m_where += fmt;
51 void DBQuery::filterBy(const char *fmt, const char *value, Op op /* = And */) {
52 assert(m_conn);
54 string escaped;
55 m_conn->escapeString(value, escaped);
56 char *where = (char*)malloc(strlen(fmt) + escaped.size() - 1);
57 sprintf(where, fmt, escaped.c_str());
59 filterBy(where, op);
60 free(where);
63 void DBQuery::filterBy(const char *fmt, const std::string &value,
64 Op op /* = And */) {
65 filterBy(fmt, value.c_str(), op);
68 void DBQuery::filterBy(const char *fmt, int value, Op op /* = And */) {
69 char *where = (char*)malloc(strlen(fmt) + 16);
70 sprintf(where, fmt, value);
72 filterBy(where, op);
73 free(where);
76 void DBQuery::filterBy(const char *fmt, unsigned int value,
77 Op op /* = And */) {
78 filterBy(fmt, (int)value, op);
81 void DBQuery::filterBy(const char *fmt, DBQueryFilterPtr filter,
82 Op op /* = And */) {
83 assert(!filter->isEmpty());
84 assert(!m_filter);
86 m_filter = filter;
87 filterBy(fmt, op);
90 void DBQuery::orderBy(const char *field, bool ascending /* = true */) {
91 m_order += m_order.empty() ? " ORDER BY " : ",";
92 m_order += field;
93 if (!ascending) {
94 m_order += " DESC";
98 void DBQuery::limit(int count, int offset /* = 0 */) {
99 m_limit = " LIMIT ";
100 if (offset) {
101 m_limit += boost::lexical_cast<string>(offset) + ", ";
103 m_limit += boost::lexical_cast<string>(count);
106 void DBQuery::insert(const char *fmt, ...) {
107 va_list ap;
108 va_start(ap, fmt);
109 format(fmt, ap);
110 va_end(ap);
111 m_insert = true;
112 m_values.push_back(m_format);
115 void DBQuery::append(const char *extra) {
116 m_extra += (extra ? extra : "");
119 void DBQuery::setField(const char *fmt) {
120 m_values.push_back(fmt);
123 void DBQuery::setField(const char *fmt, const char *value) {
124 setField(fmt, value, strlen(value));
127 void DBQuery::setField(const char *fmt, const std::string &value) {
128 setField(fmt, value.data(), value.length());
131 void DBQuery::setField(const char *fmt, const char *binary, int len) {
132 assert(m_conn);
134 string escaped;
135 m_conn->escapeString(binary, len, escaped);
137 char *buffer = (char*)malloc(strlen(fmt) + escaped.size());
138 if (!buffer) {
139 throw std::bad_alloc();
141 sprintf(buffer, fmt, escaped.c_str());
142 setField(buffer);
143 free(buffer);
146 void DBQuery::setField(const char *fmt, int value) {
147 setField(fmt, boost::lexical_cast<string>(value).c_str());
150 void DBQuery::setField(const char *fmt, unsigned int value) {
151 setField(fmt, (int)value);
154 ///////////////////////////////////////////////////////////////////////////////
156 int DBQuery::execute() {
157 return execute(nullptr);
160 int DBQuery::execute(DBDataSet &ds) {
161 return execute(&ds);
164 int DBQuery::execute(DBDataSet *ds) {
165 assert(m_conn);
166 assert(m_conn->isOpened());
168 int affected = 0;
169 for (const char *sql = getFirstSql(); sql; sql = getNextSql()) {
170 affected += m_conn->execute(sql, ds);
172 return affected;
175 int DBQuery::execute(int &result) {
176 DBDataSet ds;
177 int affected = execute(ds);
178 result = 0;
179 for (ds.moveFirst(); ds.getRow(); ds.moveNext()) {
180 result += ds.getIntField(0);
182 return affected;
185 int DBQuery::execute(unsigned int &result) {
186 DBDataSet ds;
187 int affected = execute(ds);
188 result = 0;
189 for (ds.moveFirst(); ds.getRow(); ds.moveNext()) {
190 result += ds.getUIntField(0);
192 return affected;
195 ///////////////////////////////////////////////////////////////////////////////
197 const char *DBQuery::getFirstSql() {
198 if (m_filter) {
199 const char *where = m_filter->getFirst(m_where);
200 assert(where);
201 return getSql(where);
203 return getSql(m_where.c_str());
206 const char *DBQuery::getNextSql() {
207 if (m_filter) {
208 const char *where = m_filter->getNext(m_where);
209 if (where) return getSql(where);
211 return nullptr;
214 const char *DBQuery::getSql(const char *where) {
215 if (m_values.empty()) {
216 m_sql = m_base + where + m_order + m_limit + m_extra;
217 } else if (m_insert) {
219 int total = m_base.size() + 8 + m_extra.size();
220 for (unsigned int i = 0; i < m_values.size(); i++) {
221 total += m_values[i].size() + 4;
223 m_sql.reserve(total);
225 m_sql = m_base;
226 m_sql += " VALUES ";
227 for (unsigned int i = 0; i < m_values.size(); i++) {
228 if (i > 0) m_sql += ", ";
229 m_sql += "(";
230 m_sql += m_values[i];
231 m_sql += ")";
233 m_sql += m_extra;
234 } else {
235 m_sql = m_base;
236 m_sql += " SET ";
237 for (unsigned int i = 0; i < m_values.size(); i++) {
238 if (i > 0) m_sql += ",";
239 m_sql += m_values[i];
241 m_sql += where;
242 m_sql += m_limit;
243 m_sql += m_extra;
245 return m_sql.c_str();
248 ///////////////////////////////////////////////////////////////////////////////
250 const char *DBQuery::format(const char *fmt, ...) {
251 va_list ap;
252 va_start(ap, fmt);
253 format(fmt, ap);
254 va_end(ap);
255 return m_format.c_str();
258 const char *DBQuery::format(const char *fmt, va_list ap) {
259 m_format = fmt;
261 for (string::size_type pos = m_format.find('%');
262 pos != string::npos && pos < m_format.length() - 1;
263 pos = m_format.find('%', pos + 1)) {
264 switch (m_format[pos+1]) {
265 case 's':
267 assert(m_conn);
268 const char *value = va_arg(ap, const char *);
269 string escaped;
270 m_conn->escapeString(value, escaped);
271 m_format.replace(pos, 2, escaped);
272 pos += escaped.size();
274 break;
275 case 'd':
277 int value = va_arg(ap, int);
279 char buf[12];
280 sprintf(buf, "%d", value);
281 m_format.replace(pos, 2, buf);
282 pos += strlen(buf);
284 break;
285 case 'p':
287 long value = va_arg(ap, long);
289 char buf[20];
290 sprintf(buf, "%ld", value);
291 m_format.replace(pos, 2, buf);
292 pos += strlen(buf);
294 break;
295 case '%':
296 m_format.erase(pos, 1);
297 break;
298 default:
299 assert(false);
303 return m_format.c_str();
306 std::string DBQuery::escapeFieldName(const char *fieldNameList) {
307 assert(fieldNameList);
308 string ret = "`";
309 ret += fieldNameList;
310 ret += "`";
311 Util::replaceAll(ret, ",", "`,`");
312 return ret;
315 std::string DBQuery::escapeFieldName(const std::string &fieldNameList) {
316 return escapeFieldName(fieldNameList.c_str());
319 ///////////////////////////////////////////////////////////////////////////////